Source code for mighty.trainer.trainer

import subprocess
import sys
import time
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict

import torch
import torch.nn as nn
import torch.utils.data
from tqdm import tqdm, trange

from mighty.loss import PairLossSampler, LossPenalty
from mighty.models import AutoencoderOutput, MLP
from mighty.monitor.accuracy import AccuracyEmbedding, \
    AccuracyArgmax, Accuracy
from mighty.monitor.batch_timer import timer
from mighty.monitor.monitor import Monitor
from mighty.monitor.mutual_info import MutualInfoNeuralEstimation, \
    MutualInfoStub
from mighty.monitor.mutual_info.mutual_info import MutualInfo
from mighty.trainer.mask import MaskTrainer
from mighty.utils.common import find_named_layers, batch_to_cuda, \
    input_from_batch
from mighty.utils.constants import CHECKPOINTS_DIR
from mighty.utils.data import DataLoader
from mighty.utils.prepare import prepare_eval
from mighty.utils.var_online import MeanOnline

# Public names exported by a star-import of this module.
__all__ = [
    "Trainer"
]


class Trainer(ABC):
    """
    Trainer base class.

    Parameters
    ----------
    model : nn.Module
        A neural network to train. It is moved to CUDA in the constructor
        when CUDA is available.
    criterion : nn.Module
        Loss function.
    data_loader : DataLoader
        A data loader.
    accuracy_measure : Accuracy or None, optional
        Calculates the accuracy from the last layer activations.
        If None, set to :code:`AccuracyArgmax` for a classification task
        and :code:`AccuracyEmbedding` otherwise.

        .. code-block:: python

            if isinstance(criterion, PairLossSampler):
                accuracy_measure = AccuracyEmbedding()
            else:
                # cross entropy loss
                accuracy_measure = AccuracyArgmax()

        Default: None
    mutual_info : MutualInfo or None, optional
        A handle to compute the mutual information I(X; T) and I(Y; T) [1]_.
        If None, a :code:`MutualInfoNeuralEstimation(data_loader)` instance
        is created. See the ``mutual_info_layers`` argument of
        :func:`Trainer.train` for how the estimation is enabled.
        Default: None
    env_suffix : str, optional
        The suffix to add to the current environment name. Leading spaces
        are stripped.
        Default: ''
    checkpoint_dir : Path or str, optional
        The path to store the checkpoints.
        Default: ``${HOME}/.mighty/checkpoints``
    verbosity : int, optional
        * 0 - don't print anything
        * 1 - show the progress with each epoch
        * 2 - show the progress with each batch

        Default: 2

    References
    ----------
    .. [1] Shwartz-Ziv, R., & Tishby, N. (2017). Opening the black box of
       deep neural networks via information. arXiv preprint
       arXiv:1703.00810.

    Notes
    -----
    For the choice of ``mutual_info`` refer to
    https://github.com/dizcza/entropy-estimators
    """

    # Layer classes that are registered in the monitor (see '_prepare_train').
    watch_modules = (nn.Linear, nn.Conv2d, MLP, nn.RNNBase, nn.RNNCellBase)

    def __init__(self,
                 model: nn.Module,
                 criterion: nn.Module,
                 data_loader: DataLoader,
                 accuracy_measure: Accuracy = None,
                 mutual_info=None,
                 env_suffix='',
                 checkpoint_dir=CHECKPOINTS_DIR,
                 verbosity=2):
        if torch.cuda.is_available():
            model.cuda()
        self.model = model
        self.criterion = criterion
        self.data_loader = data_loader
        self.train_loader = data_loader.get(train=True)

        # Fall back to the neural estimator when no handle is given.
        if mutual_info is None:
            mutual_info = MutualInfoNeuralEstimation(data_loader)
        self.mutual_info = mutual_info

        self.checkpoint_dir = Path(checkpoint_dir)

        # 'timer' is the module-level object imported from batch_timer.
        self.timer = timer
        self.timer.init(batches_in_epoch=len(self.train_loader))
        self.timer.set_epoch(0)

        # A penalty wrapper is shown as "Wrapper(InnerCriterion)".
        criterion_name = self.criterion.__class__.__name__
        if isinstance(criterion, LossPenalty):
            inner_name = criterion.criterion.__class__.__name__
            criterion_name = f"{criterion_name}({inner_name})"

        # "<date> <Model>: <Dataset> <Trainer> <Criterion>[ <suffix>]"
        env_parts = [
            time.strftime('%Y.%m.%d'),
            f"{model.__class__.__name__}:",
            data_loader.dataset_cls.__name__,
            self.__class__.__name__,
            criterion_name,
        ]
        env_suffix = env_suffix.lstrip(' ')
        if env_suffix:
            env_parts.append(env_suffix)
        self.env_name = ' '.join(env_parts)

        self.verbosity = verbosity

        if accuracy_measure is None:
            if isinstance(self.criterion, PairLossSampler):
                accuracy_measure = AccuracyEmbedding()
            else:
                # cross entropy loss
                accuracy_measure = AccuracyArgmax()
        self.accuracy_measure = accuracy_measure

        self.monitor = self._init_monitor(mutual_info)
        self.online = self._init_online_measures()
        self.best_score = {}
        self.epoch_finished_callbacks = []
        self.n_classes = 0  # no. of classes if supervised
        self.training = True

    @property
    def epoch(self):
        """
        The current epoch, int.
        """
        return self.timer.epoch
def checkpoint_path(self, best=None):
    """
    Build the checkpoint file path for the given mode.

    The directory that holds the checkpoint is created if it does not
    exist yet.

    Parameters
    ----------
    best : str or None
        Tag name. If set, the path will be expanded to ``".../best/tag"``.
        Default: None

    Returns
    -------
    Path
        Checkpoint path, ``<directory>/<env_name>.pt``.
    """
    if best is None:
        folder = self.checkpoint_dir
    else:
        folder = self.checkpoint_dir / "best" / best
    folder.mkdir(parents=True, exist_ok=True)
    return folder / f"{self.env_name}.pt"
def monitor_functions(self):
    """
    Hook to register `Visdom` callbacks that run on each epoch.

    The base implementation does nothing; override it in a subclass.
    """
    return None
def log_trainer(self):
    """
    Logs the trainer in `Visdom` text field.

    The model, criterion, data loader, monitor, accuracy measure and the
    mutual information estimator are logged. Then the directory of the
    running script and each of its parents are probed for a git repository;
    the commit hash of the first repository found is logged. Nothing is
    logged about git if no repository is found or git is not installed.
    """
    self.monitor.log_model(self.model)
    self.monitor.log(f"Criterion: {self.criterion}")
    self.monitor.log(repr(self.data_loader))
    self.monitor.log_self()
    self.monitor.log(repr(self.accuracy_measure))
    self.monitor.log(repr(self.mutual_info))

    # Resolve to an absolute path before walking up: the parent of a
    # relative path like '.' is '.' again, so walking a relative path
    # upwards would never reach the filesystem root.
    script_dir = Path(sys.argv[0]).parent.resolve()
    for git_dir in (script_dir, *script_dir.parents):
        try:
            commit = subprocess.run(['git', '--git-dir',
                                     str(git_dir / '.git'),
                                     'rev-parse', 'HEAD'],
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    universal_newlines=True)
        except FileNotFoundError:
            # Git is not installed
            break
        if commit.returncode == 0:
            self.monitor.log(f"Git location '{str(git_dir)}' "
                             f"commit: {commit.stdout}")
            break
def _init_monitor(self, mutual_info) -> Monitor:
    # The monitor gets the mutual information handle and the inverse
    # normalization taken from the data loader.
    return Monitor(
        mutual_info=mutual_info,
        normalize_inverse=self.data_loader.normalize_inverse,
    )


def _init_online_measures(self) -> Dict[str, MeanOnline]:
    # Online (running) measures, keyed by name; they are reset at the end
    # of every epoch.
    return {"accuracy": MeanOnline()}
@abstractmethod
def train_batch(self, batch):
    """
    Update the model parameters from a single batch.

    This is the core function of a trainer; concrete trainers must
    implement it.

    Parameters
    ----------
    batch : torch.Tensor or tuple of torch.Tensor
        :code:`(X, Y)` or :code:`X` batch of input data.

    Returns
    -------
    loss : torch.Tensor
        The batch loss.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    raise NotImplementedError
def update_best_score(self, score, score_type='loss'):
    """
    Record a new best score and save the model as "best" for its tag.

    The score is stored, the model is saved with ``best=score_type`` and
    the event is logged when no score has been recorded for
    :code:`score_type` yet or when :code:`score` is strictly greater than
    the recorded one. Otherwise nothing happens.

    Parameters
    ----------
    score : float
        The model score at the current epoch. The higher, the better.
        The simplest way to use this function is set
        :code:`score = -loss`.
    score_type : str, optional
        A key-word to determine the criteria for the "best" score; each
        key-word keeps its own best score and checkpoint.
        Default: 'loss'
    """
    # This function can be called multiple times from different functions;
    # only a call that improves the score of its tag saves the model.
    already_scored = score_type in self.best_score
    if already_scored and not score > self.best_score[score_type]:
        return
    self.best_score[score_type] = score
    self.save(best=score_type)
    self.monitor.log(f"[epoch={self.timer.epoch}] "
                     f"'{score_type}' best score: {score}")
def save(self, best=None):
    """
    Write the trainer state to :code:`self.checkpoint_path(best)`.

    Nothing is saved when :code:`self.training` is False. A
    ``PermissionError`` raised while saving is printed, not propagated.

    Parameters
    ----------
    best : str or None
        Tag name. If set, the path will be expanded to ``".../best/tag"``.
        Default: None

    See Also
    --------
    restore : restore the training progress
    """
    if self.training:
        target = self.checkpoint_path(best)
        target.parent.mkdir(parents=True, exist_ok=True)
        try:
            torch.save(self.state_dict(), target)
        except PermissionError as error:
            print(error)
def state_dict(self):
    """
    Collect the trainer state that goes into a checkpoint.

    Returns
    -------
    dict
        The model parameters, the data loader state, the current epoch,
        the environment name and the best scores.
    """
    state = dict(
        model_state=self.model.state_dict(),
        data_norm=self.data_loader.state_dict(),
        epoch=self.timer.epoch,
        env_name=self.env_name,
        best_score=self.best_score,
    )
    return state
def restore(self, checkpoint_path=None, best=None, strict=True):
    """
    Restores the trainer progress and the model from the path.

    On success the model parameters, the environment name, the epoch and
    the best scores are taken from the checkpoint and the monitor is opened
    for the restored environment.

    Parameters
    ----------
    checkpoint_path : Path or None
        Trainer checkpoint path to restore. If None, the default path
        :code:`self.checkpoint_path(best)` is used.
        Default: None
    best : str or None
        Best or latest (refer to :func:`Trainer.checkpoint_path`).
    strict : bool
        Strict model loading or not.

    Returns
    -------
    checkpoint_state : dict or None
        The loaded state of a trainer. None if the checkpoint does not
        exist or the model state could not be loaded.
    """
    source = self.checkpoint_path(best) if checkpoint_path is None \
        else checkpoint_path
    source = Path(source)
    if not source.exists():
        print(f"Checkpoint '{source}' doesn't exist. "
              f"Nothing to restore.")
        return None

    # Without CUDA, tensors saved from a GPU have to be mapped to the CPU.
    map_location = None if torch.cuda.is_available() else 'cpu'
    checkpoint_state = torch.load(source, map_location=map_location,
                                  weights_only=True)
    try:
        self.model.load_state_dict(checkpoint_state['model_state'],
                                   strict=strict)
    except RuntimeError as error:
        print(f"Restoring {source} raised {error}")
        return None

    self.env_name = checkpoint_state['env_name']
    self.timer.set_epoch(checkpoint_state['epoch'])
    self.best_score = checkpoint_state['best_score']
    self.monitor.open(env_name=self.env_name)
    print(f"Restored model state from {source}.")
    return checkpoint_state
def _get_loss(self, batch, output): # In case of unsupervised learning, '_get_loss' is overridden # accordingly. input, labels = batch return self.criterion(output, labels) def _on_forward_pass_batch(self, batch, output, train): if self.is_unsupervised(): # unsupervised, no labels return _, labels = batch if train: self.accuracy_measure.partial_fit(output, labels) else: self.accuracy_measure.true_labels_cached.append(labels.cpu()) labels_pred = self.accuracy_measure.predict(output).cpu() self.accuracy_measure.predicted_labels_cached.append(labels_pred) def _forward(self, batch): input = input_from_batch(batch) return self.model(input)
def is_unsupervised(self):
    """
    Tell whether the data loader provides no labels.

    Returns
    -------
    bool
        True, if the training is unsupervised and False otherwise.
    """
    has_labels = self.data_loader.has_labels
    return not has_labels
def full_forward_pass(self, train=True):
    """
    Fixes the model weights, evaluates the epoch score and updates the
    monitor.

    The loss is averaged over all batches, reported to the monitor, and
    passed (negated) to :func:`Trainer.update_best_score` under the
    ``'loss'`` tag. The accuracy is updated afterwards.

    Parameters
    ----------
    train : bool
        Either train (True) or test (False) batches to run. In both cases,
        the model is set to the evaluation regime via `self.model.eval()`.

    Returns
    -------
    loss : torch.Tensor
        The loss of a full forward pass.
    """
    # Remember the current mode; it is restored after the pass.
    mode_saved = self.model.training
    self.model.train(False)
    self.accuracy_measure.reset_labels()
    loss_online = MeanOnline()

    if train:
        # The description is only passed at the highest verbosity level.
        description = "Full forward pass (eval)" \
            if self.verbosity >= 2 else None
        loader = self.data_loader.eval(description)
        self.mutual_info.start_listening()
    else:
        loader = self.data_loader.get(train)

    # Probabilities and labels for the precision-recall update; they are
    # collected only when the output turns out to have two classes.
    proba_list = []
    labels_list = []

    with torch.no_grad():
        for batch in loader:
            batch = batch_to_cuda(batch)
            output = self._forward(batch)
            # 'n_classes' is 0 until the first output is seen; once it is
            # known and differs from 2, this branch is skipped.
            if isinstance(self.accuracy_measure, AccuracyArgmax) and self.data_loader.has_labels \
                    and self.n_classes in (0, 2):
                proba = self.accuracy_measure.predict_proba(output)
                self.n_classes = proba.shape[1]
                if self.n_classes == 2:
                    proba_list.append(proba)
                    batch_input, batch_labels = batch
                    labels_list.append(batch_labels)
            loss = self._get_loss(batch, output)
            self._on_forward_pass_batch(batch, output, train)
            loss_online.update(loss)

    if len(labels_list):
        labels_true = torch.cat(labels_list)
        proba_list = torch.cat(proba_list)
        self.monitor.update_precision_recall(proba_list, labels_true, train)

    self.mutual_info.finish_listening()
    # Back to the mode the model was in before this pass.
    self.model.train(mode_saved)

    loss = loss_online.get_mean()
    self.monitor.update_loss(loss, mode='train' if train else 'test')
    # The higher the score, the better, hence the negated loss.
    self.update_best_score(score=-loss.item(), score_type='loss')
    self.update_accuracy(train=train)

    return loss
def _epoch_finished(self, loss): for clbk in self.epoch_finished_callbacks: clbk(loss) self.monitor.epoch_finished() self.save() for online_measure in self.online.values(): online_measure.reset() self.accuracy_measure.reset()
def update_accuracy(self, train=True):
    """
    Updates the accuracy of the model.

    The true labels are taken from the accuracy measure cache. The
    predicted labels come from, in order of preference: the cached
    predictions (test mode or an :code:`AccuracyArgmax` measure), the
    measure's :code:`predict_cached()` (if its ``cache`` attribute is
    set), or an extra forward pass over :code:`self.data_loader.eval()`.
    The accuracy is passed to :func:`Trainer.update_best_score` under the
    ``'accuracy-train'`` or ``'accuracy-test'`` tag.

    Parameters
    ----------
    train : bool
        Either train (True) or test (False) mode.

    Returns
    -------
    accuracy : torch.Tensor or None
        A scalar with the accuracy value. None for unsupervised training.
    """
    if self.is_unsupervised():
        return None

    labels_true = torch.cat(self.accuracy_measure.true_labels_cached)

    if not train or isinstance(self.accuracy_measure, AccuracyArgmax):
        labels_pred = torch.cat(
            self.accuracy_measure.predicted_labels_cached)
    elif getattr(self.accuracy_measure, 'cache', False):
        labels_pred = self.accuracy_measure.predict_cached()
    else:
        labels_pred = []
        # This method is called after 'full_forward_pass' has restored the
        # model mode, so the model may be in the training regime here.
        # Force the evaluation regime for the extra pass (no dropout, no
        # batch-norm statistics update) and restore the mode afterwards.
        mode_saved = self.model.training
        self.model.train(False)
        try:
            with torch.no_grad():
                for batch in self.data_loader.eval():
                    batch = batch_to_cuda(batch)
                    output = self._forward(batch)
                    if isinstance(output, AutoencoderOutput):
                        output = output.latent
                    labels_pred.append(
                        self.accuracy_measure.predict(output).cpu())
        finally:
            self.model.train(mode_saved)
        labels_pred = torch.cat(labels_pred, dim=0)

    # The monitor is fed with CPU tensors; warn since the caches are
    # expected to hold CPU tensors already.
    if labels_true.is_cuda:
        warnings.warn("'labels_true' is a cuda tensor")
        labels_true = labels_true.cpu()
    if labels_pred.is_cuda:
        warnings.warn("'labels_pred' is a cuda tensor")
        labels_pred = labels_pred.cpu()

    mode = 'train' if train else 'test'
    accuracy = self.monitor.update_accuracy_epoch(labels_pred, labels_true,
                                                  mode=mode)
    self.update_best_score(accuracy, score_type=f'accuracy-{mode}')

    return accuracy
def train_mask(self, mask_explain_params=None):
    """
    Train mask to see what part of an image is crucial from the network
    perspective (saliency map).

    The first training batch is taken, and the sample with the highest
    predicted probability in it is explained.

    Parameters
    ----------
    mask_explain_params : dict or None, optional
        `MaskTrainer` keyword arguments. None means no extra arguments.
        Default: None

    Returns
    -------
    image : torch.Tensor
        The explained image.
    label : torch.Tensor
        The label of the explained image.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if mask_explain_params is None:
        mask_explain_params = {}
    images, labels = next(iter(self.train_loader))
    mask_trainer = MaskTrainer(self.accuracy_measure,
                               image_shape=images[0].shape,
                               **mask_explain_params)
    mode_saved = prepare_eval(self.model)
    try:
        if torch.cuda.is_available():
            images = images.cuda()
        with torch.no_grad():
            proba = self.accuracy_measure.predict_proba(self.model(images))
        # Pick the sample the model is most confident about.
        proba_max, _ = proba.max(dim=1)
        sample_max_proba = proba_max.argmax()
        image = images[sample_max_proba]
        label = labels[sample_max_proba]
        self.monitor.plot_explain_input_mask(self.model,
                                             mask_trainer=mask_trainer,
                                             image=image,
                                             label=label)
    finally:
        # Restore the model mode even if the explanation fails.
        mode_saved.restore(self.model)
    return image, label
def is_test_mode(self): params = [param for param in self.model.parameters() if param.requires_grad] return len(params) == 0 and not self.model.training def test(self): self.training = False self.monitor.viz.close() self.model.eval() for param in self.model.parameters(): param.requires_grad_(False) self.train(n_epochs=1)
def train_epoch(self, epoch):
    """
    Run all training batches of one epoch.

    In the test mode (see :func:`Trainer.is_test_mode`) the batch loss is
    only evaluated; otherwise :func:`Trainer.train_batch` is called. The
    mean batch loss is reported to the monitor at the end.

    Parameters
    ----------
    epoch : int
        Epoch ID.
    """
    running_loss = MeanOnline()
    # The progress bar is shown at the highest verbosity level only.
    batches = tqdm(self.train_loader,
                   desc="Epoch {:d}".format(epoch),
                   disable=self.verbosity < 2,
                   leave=False)
    for batch in batches:
        batch = batch_to_cuda(batch)
        if not self.is_test_mode():
            loss = self.train_batch(batch)
        else:
            # Evaluate the loss without updating the parameters.
            outputs = self._forward(batch)
            loss = self._get_loss(batch, outputs)
        running_loss.update(loss.detach().cpu())
        for param_name, param in self.model.named_parameters():
            if torch.isnan(param).any():
                warnings.warn(f"NaN parameters in '{param_name}'")
        self.monitor.batch_finished(self.model)
    self.monitor.update_loss(loss=running_loss.get_mean(), mode='batch')
def open_monitor(self, offline=False):
    """
    Opens a `Visdom` monitor, unless it is opened already.

    Parameters
    ----------
    offline : bool
        Online (False) or offline (True) monitoring.
    """
    if self.monitor.viz is not None:
        # visdom can be already initialized via trainer.restore()
        return
    # new environment
    self.monitor.open(env_name=self.env_name, offline=offline)
    self.monitor.clear()
def _prepare_train(self, mutual_info_layers=0):
    # Opens the monitor, registers the user callbacks, logs the trainer
    # and registers every watched layer of the model in the monitor.
    self.open_monitor()
    self.monitor_functions()
    self.log_trainer()
    watched_layers = find_named_layers(self.model,
                                       layer_class=self.watch_modules)
    for layer_name, layer in watched_layers:
        self.monitor.register_layer(layer, prefix=layer_name)

    # The mutual information estimator is prepared only on request and
    # only when it is not a stub.
    if not mutual_info_layers > 0 \
            or isinstance(self.mutual_info, MutualInfoStub):
        return
    # Mutual Information in conv layers are poorly estimated
    monitor_classes = set(self.watch_modules) - {nn.Conv2d}
    self.mutual_info.prepare(model=self.model,
                             monitor_layers=tuple(monitor_classes),
                             monitor_layers_count=mutual_info_layers,
                             verbosity=self.verbosity)
def training_started(self):
    """
    Training is started callback.

    Called by :func:`Trainer.train` before the first epoch is trained.
    The base implementation does nothing.
    """
    return None
def training_finished(self):
    """
    Training is finished callback.

    Called right before exiting the :func:`Trainer.train` function.
    The base implementation does nothing.
    """
    return None
def train(self, n_epochs=10, mutual_info_layers=0,
          mask_explain_params=None):
    """
    User-entry function to train the model for :code:`n_epochs`.

    Each epoch consists of training on all batches, a full forward pass
    on the train and on the test data, an optional saliency mask training
    and the end-of-epoch bookkeeping.

    Parameters
    ----------
    n_epochs : int
        The number of epochs to run.
        Default: 10
    mutual_info_layers : int, optional
        Evaluate the mutual information [1]_ from the last
        :code:`mutual_info_layers` layers at each epoch. If set to 0,
        skip the (time-consuming) mutual information estimation.
        Default: 0
    mask_explain_params : dict or None, optional
        If not None, a dictionary with parameters for
        :class:`MaskTrainer`, that is used to show the "saliency map"
        [2]_. An empty dictionary selects the `MaskTrainer` defaults.
        Default: None

    Returns
    -------
    loss_epochs : list
        A list of epoch loss (train loss of the full forward pass).

    References
    ----------
    .. [1] Shwartz-Ziv, R., & Tishby, N. (2017). Opening the black box of
       deep neural networks via information. arXiv preprint
       arXiv:1703.00810.
    .. [2] Fong, R. C., & Vedaldi, A. (2017). Interpretable explanations
       of black boxes by meaningful perturbation.
    """
    if self.verbosity >= 2:
        print(self.model)
    self.timer.n_epochs = n_epochs
    self._prepare_train(mutual_info_layers)
    if n_epochs == 1:
        self.monitor.viz.with_markers = True
    self.training_started()
    loss_epochs = []
    # The epoch-level progress bar is shown at verbosity 1 only.
    for epoch in trange(self.timer.epoch, self.timer.epoch + n_epochs,
                        disable=self.verbosity != 1):
        self.train_epoch(epoch=epoch)
        loss = self.full_forward_pass(train=True)
        self.full_forward_pass(train=False)
        # Compare with None rather than by truthiness: an empty dict is a
        # valid request to train the mask with the default parameters.
        if mask_explain_params is not None:
            self.train_mask(mask_explain_params)
        self._epoch_finished(loss)
        loss_epochs.append(loss.item())
    self.training_finished()
    return loss_epochs