Source code for mighty.trainer.embedding
from typing import Union
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from torch.optim.optimizer import Optimizer
from mighty.monitor.accuracy import Accuracy, AccuracyEmbedding
from mighty.monitor.monitor import MonitorEmbedding
from mighty.utils.var_online import MeanOnline, VarianceOnlineLabels
from mighty.utils.signal import compute_sparsity
from mighty.utils.data import DataLoader
from .gradient import TrainerGrad
__all__ = [
"TrainerEmbedding"
]
[docs]
class TrainerEmbedding(TrainerGrad):
"""
An (unsupervised) trainer that transforms input data into
linearly-separable embedding vectors that form clusters.
Parameters
----------
model : nn.Module
A neural network to train.
criterion : nn.Module
A loss function.
data_loader : DataLoader
A data loader.
optimizer : Optimizer
An optimizer (Adam, SGD, etc.).
scheduler : _LRScheduler or ReduceLROnPlateau, or None
A learning rate scheduler.
Default: None
accuracy_measure : AccuracyEmbedding, optional
Calculates the accuracy of embedding vectors.
Default: ``AccuracyEmbedding()``
**kwargs
Passed to the base class.
"""
def __init__(self,
model: nn.Module,
criterion: nn.Module,
data_loader: DataLoader,
optimizer: Optimizer,
scheduler: Union[_LRScheduler, ReduceLROnPlateau] = None,
accuracy_measure: Accuracy = AccuracyEmbedding(),
**kwargs):
if not isinstance(accuracy_measure, AccuracyEmbedding):
raise ValueError("'accuracy_measure' must be of instance "
f"{AccuracyEmbedding.__name__}")
super().__init__(model=model,
criterion=criterion,
data_loader=data_loader,
optimizer=optimizer,
scheduler=scheduler,
accuracy_measure=accuracy_measure,
**kwargs)
def _init_monitor(self, mutual_info):
monitor = MonitorEmbedding(
mutual_info=mutual_info,
normalize_inverse=self.data_loader.normalize_inverse
)
return monitor
def _init_online_measures(self):
online = super()._init_online_measures()
online['sparsity'] = MeanOnline() # scalar
online['l1_norm'] = MeanOnline() # (V,) vector
online['clusters'] = VarianceOnlineLabels() # (C, V) tensor
return online
def _on_forward_pass_batch(self, batch, output, train):
if train:
output = output.float()
sparsity = compute_sparsity(output)
self.online['sparsity'].update(sparsity.cpu())
# L1 norm sums the result, we take the batch mean
self.online['l1_norm'].update(output.abs().mean(dim=0).cpu())
if self.data_loader.has_labels:
# supervised
input, labels = batch
self.online['clusters'].update(output, labels)
super()._on_forward_pass_batch(batch, output, train)
def _epoch_finished(self, loss):
self.monitor.update_sparsity(self.online['sparsity'].get_mean(),
mode='train')
self.monitor.update_l1_neuron_norm(self.online['l1_norm'].get_mean())
# mean and std can be Nones
mean, std = self.online['clusters'].get_mean_std()
self.monitor.clusters_heatmap(mean)
self.monitor.update_pairwise_dist(mean, std)
self.monitor.embedding_hist(activations=mean)
super()._epoch_finished(loss)