"""
Monitors
--------
.. autosummary::
:toctree: toctree/monitor
Monitor
MonitorEmbedding
MonitorAutoencoder
Monitor Parameter Records
~~~~~~~~~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: toctree/monitor
ParamRecord
ParamsDict
"""
from collections import UserDict, defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
from sklearn.metrics import confusion_matrix
import torch.nn.functional as F
from mighty.monitor.accuracy import calc_accuracy, Accuracy
from mighty.monitor.batch_timer import timer, ScheduleExp
from mighty.monitor.mutual_info.stub import MutualInfoStub
from mighty.utils.var_online import VarianceOnline
from mighty.monitor.viz import VisdomMighty
from mighty.utils.common import clone_cpu
from mighty.utils.domain import MonitorLevel
# Public API of this module.
__all__ = [
    "ParamRecord",
    "ParamsDict",
    "Monitor",
    "MonitorEmbedding",
    "MonitorAutoencoder",
]
class ParamRecord:
    """
    A parameter record, created by a monitor, that tracks parameter statistics
    like gradient variance, sign flips on update step, etc.

    Parameters
    ----------
    param : nn.Parameter
        Model parameter.
    monitor_level : MonitorLevel, optional
        The extent of keeping the statistics.
        Default: MonitorLevel.DISABLED
    """

    def __init__(self, param, monitor_level=MonitorLevel.DISABLED):
        self.param = param
        self.monitor_level = monitor_level
        self.grad_variance = VarianceOnline()
        self.variance = VarianceOnline()
        # CPU snapshot of the previous weights, used to count sign flips
        self.prev_data = None
        if monitor_level & MonitorLevel.SIGN_FLIPS:
            self.prev_data = clone_cpu(param.data)
        # CPU snapshot of the initial weights (before training)
        self.initial_data = None
        if monitor_level & MonitorLevel.WEIGHT_INITIAL_DIFFERENCE:
            self.initial_data = clone_cpu(param.data)
        # Always record the initial norm: Monitor.update_initial_difference()
        # reads it even when the monitoring level is raised mid-training and
        # `initial_data` is filled in lazily. Previously it was set only under
        # the WEIGHT_INITIAL_DIFFERENCE flag, which led to an AttributeError.
        self.initial_norm = param.data.float().norm(p=2).item()

    @staticmethod
    def count_sign_flips(new_data, prev_data):
        """
        Count the no. of sign flips in `prev_data -> new_data`.

        The default implementation is

        .. code-block:: python

            sum(new_data * prev_data < 0)

        Parameters
        ----------
        new_data, prev_data : torch.Tensor
            New and previous tensors.

        Returns
        -------
        sign_flips : int
            The no. of signs flipped.
        """
        sign_flips = (new_data * prev_data < 0).sum().item()
        return sign_flips

    def update_signs(self):
        """
        Updates the number of sign flips by comparing with the previously
        stored tensor.

        Returns
        -------
        sign_flips : int
            The number of weight entries that changed sign since the
            previous call.
        """
        new_data = clone_cpu(self.param.data)
        if self.prev_data is None:
            # The user might turn on the advanced monitoring in the middle.
            self.prev_data = new_data
        sign_flips = self.count_sign_flips(new_data, self.prev_data)
        self.prev_data = new_data
        return sign_flips

    def update_grad_variance(self):
        """
        Updates the gradient variance, needed for Signal-to-Noise ratio
        estimation.
        """
        if self.param.grad is not None:
            self.grad_variance.update(self.param.grad.data.cpu())

    def reset(self):
        """
        Resets the current state and all saved variables.
        """
        self.grad_variance.reset()
        self.variance.reset()
class ParamsDict(UserDict):
    """
    A dictionary mapping parameter names to their :class:`ParamRecord`.
    """

    def __init__(self):
        super().__init__()
        # cumulative per-parameter sign-flip counts since the last reset
        self.sign_flips = defaultdict(int)
        # how many batch updates have been accumulated since the last reset
        self.n_updates = 0

    def reset(self):
        """Clear the accumulated sign flips and the update counter."""
        self.sign_flips.clear()
        self.n_updates = 0

    def batch_finished(self):
        """
        Batch finished callback that triggers `update*()` methods of the
        stored param records.
        """
        self.n_updates += 1
        # gradient SNR statistics
        for record in self.values():
            if record.monitor_level & MonitorLevel.SIGNAL_TO_NOISE:
                record.update_grad_variance()
        # sign flips are accumulated per parameter name
        for name, record in self.items():
            if record.monitor_level & MonitorLevel.SIGN_FLIPS:
                self.sign_flips[name] += record.update_signs()
        # weight trace variance
        for record in self.values():
            if record.monitor_level & MonitorLevel.WEIGHT_SNR_TRACE:
                record.variance.update(record.param.data.cpu())

    def plot_sign_flips(self, viz, total=True):
        """
        Plots the no. of weight sign flips after a batch update. Refer to
        :func:`ParamRecord.update_signs`.

        Parameters
        ----------
        viz : VisdomMighty
            Visdom server instance.
        total : bool, optional
            Whether to sum across all the parameters (True) or plot individual
            parameter sign flips counts (False).
            Default: True
        """
        opts = dict(
            xlabel='Epoch',
            ylabel='Sign flips',
            title="Sign flips after optimizer.step()",
        )
        if total:
            flips = sum(self.sign_flips.values()) / self.n_updates
        else:
            names, counts = zip(*self.sign_flips.items())
            flips = np.divide(counts, self.n_updates)
            opts['legend'] = list(names)
        viz.line_update(y=flips, opts=opts)
class Monitor:
    """
    Generic Monitor that provides meaningful statistics in interactive Visdom
    plots throughout training.

    Parameters
    ----------
    mutual_info : MutualInfo or None, optional
        The Mutual Information estimator (the same as in a trainer).
        Default: None
    normalize_inverse : NormalizeInverse or None, optional
        The inverse normalization transform, taken from a data loader.
        Default: None
    """

    # Heatmaps/confusion matrices with at most this many classes are
    # annotated with a tick per class; larger ones are not plotted per-class.
    n_classes_format_ytickstep_1 = 10

    def __init__(self, mutual_info=None, normalize_inverse=None):
        self.timer = timer
        self.viz = None
        self._advanced_monitoring_level = MonitorLevel.DISABLED
        self.normalize_inverse = normalize_inverse
        self.param_records = ParamsDict()
        if mutual_info is None:
            # a no-op estimator when mutual info is not tracked
            mutual_info = MutualInfoStub()
        self.mutual_info = mutual_info
        # user-registered plot callbacks, invoked once per epoch
        self.functions = []

    @property
    def is_active(self):
        """
        Returns
        -------
        bool
            Indicator whether a Visdom server is initialized or not.
        """
        return self.viz is not None

    def advanced_monitoring(self, level=MonitorLevel.DEFAULT):
        """
        Sets the extent of monitoring.

        Parameters
        ----------
        level : MonitorLevel, optional
            New monitoring level to apply.

            * DISABLED - only basic metrics are computed (memory tolerable)
            * SIGNAL_TO_NOISE - track SNR of the gradients
            * FULL - SNR, sign flips, weight hist, weight diff

            Default: MonitorLevel.DEFAULT

        Notes
        -----
        Advanced monitoring is memory consuming.
        """
        self._advanced_monitoring_level = level
        for param_record in self.param_records.values():
            param_record.monitor_level = level

    def open(self, env_name: str, offline=False):
        """
        Opens a Visdom server.

        Parameters
        ----------
        env_name : str
            Environment name.
        offline : bool
            Offline mode (True) or online (False).
        """
        self.viz = VisdomMighty(env=env_name, offline=offline)

    def log_model(self, model, space='-'):
        """
        Logs the model.

        Parameters
        ----------
        model : nn.Module
            A PyTorch model.
        space : str, optional
            A space substitution to correctly parse HTML later on.
            Default: '-'
        """
        lines = []
        for line in repr(model).splitlines():
            # replace leading whitespace so HTML rendering keeps the indent
            n_spaces = len(line) - len(line.lstrip())
            line = space * n_spaces + line
            lines.append(line)
        lines = '<br>'.join(lines)
        self.log(lines)

    def log_self(self):
        """
        Logs the monitor itself.
        """
        self.log(f"{self.__class__.__name__}("
                 f"level={self._advanced_monitoring_level})")

    def log(self, text):
        """
        Logs the text.

        Parameters
        ----------
        text : str
            Log text.
        """
        self.viz.log(text)

    def batch_finished(self, model):
        """
        Batch finished monitor callback.

        Parameters
        ----------
        model : nn.Module
            A model that has been trained the last epoch.
        """
        self.param_records.batch_finished()
        self.timer.tick()
        if self.timer.epoch == 0:
            self.batch_finished_first_epoch(model)

    @ScheduleExp()
    def batch_finished_first_epoch(self, model):
        """
        First batch finished monitor callback.

        Called on an exponential schedule to inspect the very beginning of
        the training progress.

        Parameters
        ----------
        model : nn.Module
            A model that has been trained the last epoch.
        """
        # inspect the very beginning of the training progress
        self.mutual_info.force_update(model)
        self.update_mutual_info()
        self.update_gradient_signal_to_noise_ratio()

    def update_loss(self, loss, mode='batch'):
        """
        Update the loss plot with a new value.

        Parameters
        ----------
        loss : torch.Tensor
            Loss tensor. If None, do nothing.
        mode : {'batch', 'epoch'}, optional
            The update mode.
            Default: 'batch'
        """
        if loss is None:
            return
        self.viz.line_update(loss.item(), opts=dict(
            xlabel='Epoch',
            title='Loss'
        ), name=mode)

    def update_accuracy(self, accuracy, mode='batch'):
        """
        Update the accuracy plot with a new value.

        Parameters
        ----------
        accuracy : torch.Tensor or float
            Accuracy scalar.
        mode : {'batch', 'epoch'}, optional
            The update mode.
            Default: 'batch'
        """
        title = 'Accuracy'
        if isinstance(self, MonitorAutoencoder):
            # the ability to identify class ID given an embedding vector
            title = f"{title} embedding"
        self.viz.line_update(accuracy, opts=dict(
            xlabel='Epoch',
            ylabel='Accuracy',
            title=title
        ), name=mode)

    def clear(self):
        """
        Clear out all Visdom plots.
        """
        self.viz.close()

    def register_func(self, func):
        """
        Register a plotting function to call on the end of each epoch.

        The `func` must have only one argument `viz`, a Visdom instance.

        Parameters
        ----------
        func : callable
            User-provided plot function with one argument `viz`.
        """
        self.functions.append(func)

    def update_weight_histogram(self):
        """
        Update the model weights histogram.
        """
        if not self._advanced_monitoring_level & MonitorLevel.WEIGHT_HISTOGRAM:
            return
        for name, param_record in self.param_records.items():
            param_data = param_record.param.data.cpu()
            if param_data.numel() == 1:
                # not used since biases aren't tracked anymore
                self.viz.line_update(y=param_data.item(), opts=dict(
                    xlabel='Epoch',
                    ylabel='Value',
                    title='Layer biases',
                    name=name,
                ))
            else:
                self.viz.histogram(X=param_data.view(-1), win=name, opts=dict(
                    xlabel='Param value',
                    ylabel='# bins (distribution)',
                    title=name,
                ))

    def update_weight_trace_signal_to_noise_ratio(self):
        """
        Update the SNR, mean divided by std, of the model weights.

        If mean / std is large, the network is confident in which direction
        to "move". If mean / std is small, the network is making random walk.
        """
        if not self._advanced_monitoring_level & MonitorLevel.WEIGHT_SNR_TRACE:
            return
        for name, param_record in self.param_records.items():
            mean, std = param_record.variance.get_mean_std()
            snr = mean / std
            # drop NaN/Inf entries (zero-variance weights)
            snr = snr[torch.isfinite(snr)]
            if snr.nelement() == 0:
                continue
            # convert to decibels: 10 * log10(snr^2)
            snr.pow_(2)
            snr.log10_().mul_(10)
            name = f"Weight SNR {name}"
            self.viz.histogram(X=snr.flatten(), win=name, opts=dict(
                xlabel='10 log10[(mean/std)^2], db',
                ylabel='# params (distribution)',
                title=name,
            ))

    def update_gradient_signal_to_noise_ratio(self):
        """
        Update the SNR, mean divided by std, of the model weight gradients.

        Similar to :func:`Monitor.update_weight_trace_signal_to_noise_ratio`
        but on a smaller time scale.
        """
        if not self._advanced_monitoring_level & MonitorLevel.SIGNAL_TO_NOISE:
            # SNR is not monitored
            return
        snr = []
        legend = []
        for name, param_record in self.param_records.items():
            param = param_record.param
            if param.grad is None:
                continue
            mean, std = param_record.grad_variance.get_mean_std()
            param_norm = param.data.norm(p=2).cpu()
            # matrix Frobenius norm is L2 norm
            mean = mean.norm(p=2) / param_norm
            std = std.norm(p=2) / param_norm
            snr.append(mean / std)
            legend.append(name)
            if self._advanced_monitoring_level == MonitorLevel.FULL:
                if (std == 0).all():
                    # skip the first update
                    continue
                self.viz.line_update(y=[mean, std], opts=dict(
                    xlabel='Epoch',
                    ylabel='Normalized Mean and STD',
                    title=f'Gradient Mean and STD: {name}',
                    legend=['||Mean(∇Wi)||', '||STD(∇Wi)||'],
                    xtype='log',
                    ytype='log',
                ))
        if any(snr) and np.isfinite(snr).all():
            # SNR == 1 marks the phase transition from learning to random walk
            snr.append(torch.tensor(1.))
            legend.append('phase-transition')
            self.viz.line_update(y=snr, opts=dict(
                xlabel='Epoch',
                ylabel='||Mean(∇Wi)|| / ||STD(∇Wi)||',
                title='Gradient Signal to Noise Ratio',
                legend=legend,
                xtype='log',
                ytype='log',
            ))

    def update_accuracy_epoch(self, labels_pred, labels_true, mode):
        """
        The callback to calculate and update the epoch accuracy from a batch
        of predicted and true class labels.

        Parameters
        ----------
        labels_pred, labels_true : (N,) torch.Tensor
            Predicted and true class labels.
        mode : str
            Update mode: 'batch' or 'epoch'.

        Returns
        -------
        accuracy : torch.Tensor
            A scalar tensor with one value - accuracy.
        """
        accuracy = calc_accuracy(labels_true, labels_pred)
        self.update_accuracy(accuracy=accuracy, mode=mode)
        title = f"Confusion matrix '{mode}'"
        if len(labels_true.unique()) <= self.n_classes_format_ytickstep_1:
            # don't plot huge matrices
            if labels_true.ndim == 2:
                # multi-label targets: pick the true label that matches
                # the prediction via a one-hot mask
                onehot = F.one_hot(labels_pred, num_classes=labels_true.shape[1])
                labels_true = (labels_true * onehot).argmax(dim=1)
            confusion = confusion_matrix(labels_true, labels_pred)
            self.viz.heatmap(confusion, win=title, opts=dict(
                title=title,
                xlabel='Predicted label',
                ylabel='True label',
            ))
        return accuracy

    def update_mutual_info(self):
        """
        Update the mutual info.
        """
        self.mutual_info.plot(self.viz)

    def epoch_finished(self):
        """
        Epoch finished callback.
        """
        self.update_mutual_info()
        for monitored_function in self.functions:
            monitored_function(self.viz)
        self.update_grad_norm()
        self.update_gradient_signal_to_noise_ratio()
        if self._advanced_monitoring_level & MonitorLevel.SIGN_FLIPS:
            self.param_records.plot_sign_flips(self.viz)
            self.param_records.reset()
        self.update_initial_difference()
        self.update_weight_trace_signal_to_noise_ratio()
        self.update_weight_histogram()
        self.reset()

    def reset(self):
        """
        Reset the parameter records statistics.
        """
        for precord in self.param_records.values():
            precord.reset()

    def register_layer(self, layer: nn.Module, prefix: str):
        """
        Register a layer to monitor.

        Parameters
        ----------
        layer : nn.Module
            A model layer.
        prefix : str
            The layer name.
        """
        for name, param in layer.named_parameters(prefix=prefix):
            # biases are not tracked to save memory
            if param.requires_grad and not name.endswith('.bias'):
                self.param_records[name] = ParamRecord(
                    param,
                    monitor_level=self._advanced_monitoring_level
                )

    def update_initial_difference(self):
        """
        Update the L2 normalized difference between the current and starting
        weights (before training).
        """
        if not self._advanced_monitoring_level \
                & MonitorLevel.WEIGHT_INITIAL_DIFFERENCE:
            return
        legend = []
        dp_normed = []
        for name, precord in self.param_records.items():
            legend.append(name)
            if precord.initial_data is None:
                # Monitoring was enabled mid-training: snapshot the weights
                # now and record the norm, which would otherwise be missing.
                precord.initial_data = precord.param.data.cpu().clone()
                precord.initial_norm = \
                    precord.initial_data.float().norm(p=2).item()
            dp = precord.param.data.cpu() - precord.initial_data
            dp = dp.float().norm(p=2) / precord.initial_norm
            dp_normed.append(dp)
        self.viz.line_update(y=dp_normed, opts=dict(
            xlabel='Epoch',
            ylabel='||W - W_initial|| / ||W_initial||',
            title='How far are the current weights from the initial?',
            legend=legend,
        ))

    def update_grad_norm(self):
        """
        Update the parameters gradient norm.
        """
        grad_norms = []
        legend = []
        for name, param_record in self.param_records.items():
            grad = param_record.param.grad
            if grad is not None:
                grad_norms.append(grad.norm(p=2).cpu())
                legend.append(name)
        if len(grad_norms) > 0:
            self.viz.line_update(y=grad_norms, opts=dict(
                xlabel='Epoch',
                ylabel='Gradient norm, L2',
                title='Gradient norm',
                legend=legend,
            ))

    def plot_psnr(self, psnr, mode='train'):
        """
        If given, plot the Peak Signal to Noise Ratio.

        Used in training autoencoders.

        Parameters
        ----------
        psnr : torch.Tensor or float
            The Peak Signal to Noise Ratio scalar.
        mode : {'train', 'test'}, optional
            The update mode.
            Default: 'train'
        """
        self.viz.line_update(y=psnr, opts=dict(
            xlabel='Epoch',
            ylabel='PSNR',
            title='Peak signal-to-noise ratio',
        ), name=mode)
class MonitorEmbedding(Monitor):
    """
    A monitor for :class:`mighty.trainer.TrainerEmbedding`.
    """

    def update_sparsity(self, sparsity, mode):
        """
        Update the L1 sparsity of the hidden layer activations.

        Parameters
        ----------
        sparsity : torch.Tensor or float
            Sparsity scalar.
        mode : {'train', 'test'}
            The update mode.
        """
        # L1 sparsity
        sparsity_opts = dict(
            xlabel='Epoch',
            ylabel='L1 norm / size',
            title='Output L1 sparsity',
        )
        self.viz.line_update(y=sparsity, opts=sparsity_opts, name=mode)

    def embedding_hist(self, activations):
        """
        Plots embedding activations histogram.

        Parameters
        ----------
        activations : (N,) torch.Tensor
            The averaged embedding vector.
        """
        if activations is None:
            return
        title = "Embedding activations hist"
        hist_opts = dict(
            xlabel='Neuron value',
            ylabel='Count',
            title=title,
            ytype='log',
        )
        self.viz.histogram(X=activations.flatten(), win=title, opts=hist_opts)

    def update_pairwise_dist(self, mean, std):
        r"""
        Updates L1 mean pairwise distance and intra-STD of embedding centroids.

        .. math::
            \text{inter-distance} = \frac{pdist(\text{mean}, p=1)}
            {||\text{mean}||_1}

            \text{intra-STD} = \frac{||\text{std}||_1}{||\text{mean}||_1},

        where :math:`||x||_1 = \frac{\sum_{ij}{|x_{ij}|}}{C}`.

        When these two metrics equalize, the feature vectors, embeddings,
        become indistinguishable. We want :math:`\text{inter-distance}` to be
        high and :math:`\text{intra-STD}` to keep low.

        Parameters
        ----------
        mean, std : torch.Tensor
            Tensors of shape `(C, V)`.
            The mean and standard deviation of `C` clusters (vectors of size
            `V`).

        Returns
        -------
        torch.Tensor
            A scalar, a proxy to signal-to-noise ratio defined as

            .. math::
                \frac{\text{inter-distance}}{\text{intra-STD}}
        """
        if mean is None:
            return
        if mean.shape != std.shape:
            raise ValueError("The mean and std must have the same shape and "
                             "come from VarianceOnline.get_mean_std().")
        # normalize both metrics by the mean L1 norm of the centroids
        scale = mean.norm(p=1, dim=1).mean()
        inter_distance = torch.pdist(mean, p=1).mean() / scale
        intra_std = std.norm(p=1, dim=1).mean() / scale
        self.viz.line_update(
            y=[inter_distance.item(), intra_std.item()],
            opts=dict(
                xlabel='Epoch',
                ylabel='Mean pairwise distance (normalized) and STD',
                legend=['inter-distance', 'intra-STD'],
                title='How much do patterns differ in L1 measure?',
            ))
        return inter_distance / intra_std

    def clusters_heatmap(self, mean, title=None, save=False):
        """
        Cluster centers distribution heatmap.

        Parameters
        ----------
        mean : (C, V) torch.Tensor
            The mean of `C` clusters (vectors of size `V`).
        title : str or None, optional
            An optional title to the plot.
            If None, set to "Embedding activations heatmap".
            Default: None
        save : bool, optional
            Save the heatmap plot in a separate fixed window.
            Default: False
        """
        if mean is None:
            return
        n_classes = mean.shape[0]
        if title is None:
            win = "Embedding activations heatmap"
            title = f"{win}. Epoch {self.timer.epoch}"
        else:
            win = title
        heatmap_opts = dict(
            title=title,
            xlabel='Neuron',
            ylabel='Label',
            rownames=[str(label) for label in range(n_classes)],
        )
        if n_classes <= self.n_classes_format_ytickstep_1:
            heatmap_opts.update(ytickstep=1)
        if save:
            # a fixed, epoch-stamped window keeps the snapshot around
            win = f"{win}. Epoch {self.timer.epoch}"
        self.viz.heatmap(mean.cpu(), win=win, opts=heatmap_opts)

    def update_l1_neuron_norm(self, l1_norm: torch.Tensor):
        """
        Update the L1 neuron norm heatmap, normalized by the batch size.

        Useful to explore which neurons are "dead" and which - "super active".

        Parameters
        ----------
        l1_norm : (V,) torch.Tensor
            L1 norm per neuron in a hidden layer.
        """
        title = 'Neuron L1 norm'
        self.viz.heatmap(l1_norm.unsqueeze(dim=0), win=title, opts=dict(
            title=title,
            xlabel='Neuron ID',
            rownames=['Last layer'],
            width=None,
            height=200,
        ))
class MonitorAutoencoder(MonitorEmbedding):
    """
    A monitor for :class:`mighty.trainer.TrainerAutoencoder`.
    """

    def plot_autoencoder(self, images, reconstructed, *tensors, labels=(),
                         normalize_inverse=True, n_show=10, mode='train'):
        """
        Plot autoencoder reconstructed samples.

        Parameters
        ----------
        images, reconstructed : (B, C, H, W) torch.Tensor
            A batch of input and reconstructed images.
        *tensors : (B, C, H, W) torch.Tensor
            Other tensors of the same size, showing intermediate steps.
        labels : tuple of str, optional
            The labels of additional `tensors`.
        normalize_inverse : bool, optional
            Perform inverse normalization to the input samples to show images
            in the original input domain rather than zero-mean unit-variance
            normalized representation.
            Default: True
        n_show : int, optional
            The number of samples to show. The special value 'all' plots the
            whole batch in chunks of 10 samples.
            Default: 10
        mode : {'train', 'test'}
            The update mode.
            Default: 'train'
        """
        if images.shape != reconstructed.shape:
            raise ValueError("Input & reconstructed image shapes differ")
        # 'all' mode shows the full batch, 10 samples per row
        n_take = 10 if n_show == 'all' else n_show
        n_take = min(images.shape[0], n_take)
        # streams to display: input, reconstruction, any intermediate steps
        combined = [images, reconstructed, *tensors]
        combined = [t.cpu() for t in combined]
        if n_show == 'all':
            # each stream becomes a tuple of chunks of n_take samples
            combined = [t.split(n_take) for t in combined]
        else:
            combined = [[t[:n_show]] for t in combined]
        images_stacked = []
        # interleave the streams chunk by chunk, separated by a white batch
        for batch in zip(*combined):
            for tensor in batch:
                if normalize_inverse and self.normalize_inverse is not None:
                    tensor = self.normalize_inverse(tensor)
                images_stacked.append(tensor)
            empty_space = torch.ones_like(images_stacked[0])
            images_stacked.append(empty_space)
        images_stacked.pop()  # remove the last empty batch pad
        images_stacked = torch.cat(images_stacked, dim=0).squeeze_()
        if images_stacked.ndim == 2:
            # (2xB, L) time series; len(combined) is the number of streams
            images, reconstructed = images_stacked.tensor_split(len(combined))[:2]
            images = images.numpy()
            reconstructed = reconstructed.numpy()
            for i, (im, rec) in enumerate(zip(images, reconstructed)):
                # plot the original and the reconstruction side by side
                ts = np.c_[im, rec]
                title = f"Reconstructed {i}"
                self.viz.line(Y=ts, X=np.arange(len(ts)), win=title, opts=dict(
                    title=title,
                    label=['orig', 'reconstructed']
                ))
                self.viz.update_window_opts(win=title, opts=dict(legend=[], title=title))
        else:
            # (2xB, H, W) images
            images_stacked.clamp_(0, 1)
            if images_stacked.ndim == 3:
                images_stacked.unsqueeze_(1)  # (B, 1, H, W)
            labels = [f'[{mode.upper()}] Original (Top)', 'Reconstructed', *labels]
            self.viz.images(images_stacked,
                            nrow=n_take, win=f'autoencoder {mode}', opts=dict(
                                title=' | '.join(labels),
                                width=1000,
                                height=None,
                            ))

    def plot_reconstruction_error(self, pixel_missed, thresholds,
                                  optimal_id=None):
        """
        Plot the reconstruction pixel-wise error, depending on the threshold.

        When the input images can be considered as binary, like MNIST, this
        plot helps to choose the correct threshold that minimizes incorrectly
        reconstructed binary pixels count.

        Parameters
        ----------
        pixel_missed : (N,) torch.Tensor
            Incorrectly reconstructed pixels count.
        thresholds : (N,) torch.Tensor
            Used thresholds.
        optimal_id : int or None, optional
            The optimal threshold ID used as the "best" threshold. If None,
            set to ``pixel_missed.argmin()``.
        """
        # error as a function of the binarization threshold
        title = "Reconstruction error"
        self.viz.line(Y=pixel_missed, X=thresholds, win=title, opts=dict(
            title=title,
            xlabel="reconstruct threshold",
            ylabel="#incorrect_pixels"
        ))
        if optimal_id is None:
            optimal_id = pixel_missed.argmin()
        # track the best achievable error over epochs
        self.viz.line_update(pixel_missed[optimal_id], opts=dict(
            title="Reconstruction error lowest",
            xlabel="Epoch",
            ylabel="min_thr #incorrect_pixels"
        ))
        # track the threshold that achieves it
        self.viz.line_update(thresholds[optimal_id].item(), opts=dict(
            title="Reconstruction threshold",
            xlabel="Epoch",
            ylabel="Thr. that minimizes the error"
        ))