mighty.trainer.Trainer

class mighty.trainer.Trainer(model: Module, criterion: Module, data_loader: DataLoader, accuracy_measure: Accuracy = None, mutual_info=None, env_suffix='', checkpoint_dir=PosixPath('/home/docs/.cache/mighty/checkpoints'), verbosity=2)

Trainer base class.

Parameters:
model : nn.Module

A neural network to train.

criterion : nn.Module

Loss function.

data_loader : DataLoader

A data loader.

accuracy_measure : Accuracy or None, optional

Calculates the accuracy from the last layer activations. If None, set to AccuracyArgmax for a classification task and AccuracyEmbedding otherwise.

if isinstance(criterion, PairLossSampler):
    accuracy_measure = AccuracyEmbedding()
else:
    # cross entropy loss
    accuracy_measure = AccuracyArgmax()

Default: None

mutual_info : MutualInfo or None, optional

A handle to compute the mutual information I(X; T) and I(Y; T) [1]. If None, don’t compute the mutual information. Default: None

env_suffix : str, optional

The suffix to add to the current environment name. Default: ‘’

checkpoint_dir : Path or str, optional

The path to store the checkpoints. Default: ${HOME}/.cache/mighty/checkpoints

verbosity : int, optional
  • 0 - don’t print anything

  • 1 - show the progress with each epoch

  • 2 - show the progress with each batch

Default: 2

Notes

For the choice of mutual_info refer to https://github.com/dizcza/entropy-estimators

References

[1] Shwartz-Ziv, R., & Tishby, N. (2017). Opening the black box of deep neural networks via information. arXiv preprint arXiv:1703.00810.

Methods

__init__(model, criterion, data_loader[, ...])

checkpoint_path([best])

Get the checkpoint path, given the mode.

full_forward_pass([train])

Fixes the model weights, evaluates the epoch score and updates the monitor.

is_test_mode()

is_unsupervised()

log_trainer()

Logs the trainer in a Visdom text field.

monitor_functions()

Override this method to register Visdom callbacks on each epoch.

open_monitor([offline])

Opens a Visdom monitor.

restore([checkpoint_path, best, strict])

Restores the trainer progress and the model from the path.

save([best])

Saves the trainer and the model parameters to self.checkpoint_path(best).

state_dict()

test()

train([n_epochs, mutual_info_layers, ...])

User-entry function to train the model for n_epochs.

train_batch(batch)

The core function of a trainer to update the model parameters, given a batch.

train_epoch(epoch)

Trains an epoch.

train_mask([mask_explain_params])

Trains a mask to reveal which part of an image is crucial from the network's perspective (saliency map).

training_finished()

Callback invoked when training has finished.

training_started()

Callback invoked when training starts.

update_accuracy([train])

Updates the accuracy of the model.

update_best_score(score[, score_type])

If score is greater than self.best_score, save the model.

Attributes

epoch

The current epoch, int.

watch_modules

checkpoint_path(best=None)

Get the checkpoint path, given the mode.

Parameters:
best : str or None

Tag name. If set, the path will be expanded to ".../best/tag". Default: None

Returns:
Path

Checkpoint path.
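The tag expansion described above can be pictured with a small, framework-free sketch. The function `checkpoint_path_sketch`, its `name` argument, and the `.pt` file name are assumptions made for illustration only; the source states only that a tag expands the path to ".../best/tag".

```python
from pathlib import Path
from typing import Optional


def checkpoint_path_sketch(checkpoint_dir: Path, name: str,
                           best: Optional[str] = None) -> Path:
    """Hypothetical stand-in for Trainer.checkpoint_path(best)."""
    directory = Path(checkpoint_dir)
    if best is not None:
        # A tag expands the directory to ".../best/<tag>".
        directory = directory / "best" / best
    # The "<name>.pt" file name is an assumption made for this sketch.
    return directory / f"{name}.pt"


latest = checkpoint_path_sketch(Path("checkpoints"), "mnist")
tagged = checkpoint_path_sketch(Path("checkpoints"), "mnist", best="loss")
print(latest)
print(tagged)
```

With no tag the sketch returns a path directly under the checkpoint directory; with best="loss" it returns a path under the "best/loss" subdirectory.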

property epoch

The current epoch, int.

full_forward_pass(train=True)

Fixes the model weights, evaluates the epoch score and updates the monitor.

Parameters:
train : bool

Either train (True) or test (False) batches to run. In both cases, the model is set to the evaluation regime via self.model.eval().

Returns:
loss : torch.Tensor

The loss of a full forward pass.

is_unsupervised()
Returns:
bool

True, if the training is unsupervised and False otherwise.

log_trainer()

Logs the trainer in a Visdom text field.

monitor_functions()

Override this method to register Visdom callbacks on each epoch.

open_monitor(offline=False)

Opens a Visdom monitor.

Parameters:
offline : bool

Online (False) or offline (True) monitoring.

restore(checkpoint_path=None, best=None, strict=True)

Restores the trainer progress and the model from the path.

Parameters:
checkpoint_path : Path or None

Trainer checkpoint path to restore. If None, the default path self.checkpoint_path() is used. Default: None

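best : str or None

The tag of the best checkpoint to restore, or None for the latest one (refer to Trainer.checkpoint_path()). Default: None

strict : bool

Whether to load the model strictly or not. Default: True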

Returns:
checkpoint_state : dict

The loaded state of a trainer.

save(best=None)

Saves the trainer and the model parameters to self.checkpoint_path(best).

Parameters:
best : str or None

Tag name. If set, the path will be expanded to ".../best/tag". Default: None

See also

restore : Restore the training progress.

state_dict()
Returns:
dict

A dict of the trainer state to be saved.
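The relationship between state_dict(), save() and restore() can be sketched as a plain round trip of a dictionary through a file. This is only an illustration under stated assumptions: the helpers `save_state` and `restore_state`, the keys "epoch" and "best_score", and the use of pickle are hypothetical stand-ins, not the library's actual serialization.

```python
import pickle
import tempfile
from pathlib import Path


def save_state(state: dict, path: Path) -> None:
    """Hypothetical stand-in for Trainer.save(): write the state to disk."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(state, f)


def restore_state(path: Path) -> dict:
    """Hypothetical stand-in for Trainer.restore(): return the loaded state."""
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    # The keys below are made up for this sketch.
    state = {"epoch": 3, "best_score": -0.5}
    path = Path(tmp) / "best" / "loss" / "toy.pt"
    save_state(state, path)
    restored = restore_state(path)

print(restored)
```

The point of the sketch is that whatever state_dict() returns is what a later restore() hands back as the loaded state.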

train(n_epochs=10, mutual_info_layers=0, mask_explain_params=None)

User-entry function to train the model for n_epochs.

Parameters:
n_epochs : int

The number of epochs to run. Default: 10

mutual_info_layers : int, optional

Evaluate the mutual information [1] from the last mutual_info_layers layers at each epoch. If set to 0, skip the (time-consuming) mutual information estimation. Default: 0

mask_explain_params : dict or None, optional

If not None, a dictionary with parameters for MaskTrainer, which is used to show the “saliency map” [2]. Default: None

Returns:
loss_epochs : list

A list of the epoch losses.

References

[1] Shwartz-Ziv, R., & Tishby, N. (2017). Opening the black box of deep neural networks via information. arXiv preprint arXiv:1703.00810.

[2] Fong, R. C., & Vedaldi, A. (2017). Interpretable explanations of black boxes by meaningful perturbation.

abstractmethod train_batch(batch)

The core function of a trainer to update the model parameters, given a batch.

Parameters:
batch : torch.Tensor or tuple of torch.Tensor

A batch of input data: either (X, Y) or X alone.

Returns:
loss : torch.Tensor

The batch loss.

train_epoch(epoch)

Trains an epoch.

Parameters:
epoch : int

Epoch ID.

train_mask(mask_explain_params={})

Trains a mask to reveal which part of an image is crucial from the network's perspective (saliency map).

Parameters:
mask_explain_params : dict, optional

MaskTrainer keyword arguments.

training_finished()

Callback invoked when training has finished.

This function is called right before exiting the Trainer.train() function.

training_started()

Callback invoked when training starts.

This function is called before training the first epoch.
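The order in which these callbacks fire around the epochs can be illustrated with a framework-free toy. `ToyTrainer`, `MeanSquareTrainer`, the `events` list, and the averaging of batch losses inside train_epoch are assumptions made for this sketch; only the hook order and the abstract train_batch follow the descriptions on this page.

```python
from abc import ABC, abstractmethod


class ToyTrainer(ABC):
    """Hypothetical, framework-free stand-in for the Trainer life cycle."""

    def __init__(self, batches):
        self.batches = batches
        self.events = []  # records the order in which the hooks fire

    def training_started(self):
        self.events.append("started")

    def training_finished(self):
        self.events.append("finished")

    @abstractmethod
    def train_batch(self, batch):
        """Process one batch and return the batch loss."""

    def train_epoch(self, epoch):
        # Averaging the batch losses is an assumption of this sketch.
        losses = [self.train_batch(batch) for batch in self.batches]
        self.events.append(f"epoch {epoch}")
        return sum(losses) / len(losses)

    def train(self, n_epochs=10):
        self.training_started()   # before the first epoch
        loss_epochs = [self.train_epoch(epoch) for epoch in range(n_epochs)]
        self.training_finished()  # right before returning
        return loss_epochs


class MeanSquareTrainer(ToyTrainer):
    def train_batch(self, batch):
        # No parameter update here; the toy only reports a loss per batch.
        return sum(x * x for x in batch) / len(batch)


trainer = MeanSquareTrainer(batches=[[1.0, 2.0], [3.0, 4.0]])
loss_epochs = trainer.train(n_epochs=2)
print(trainer.events)
print(loss_epochs)
```

A subclass that overrides training_started() or training_finished() would therefore run its code exactly once per train() call, outside the epoch loop.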

update_accuracy(train=True)

Updates the accuracy of the model.

Parameters:
train : bool

Either train (True) or test (False) mode.

Returns:
accuracy : torch.Tensor

A scalar with the accuracy value.

update_best_score(score, score_type='loss')

If score is greater than self.best_score, save the model.

The internal best score is updated and the current model is saved as “best” only if the given score_type matches the trainer's best_score_type.

Parameters:
score : float

The model score at the current epoch. The higher, the better. The simplest way to use this function is to set score = -loss.

score_type : str, optional

A keyword that determines the criterion for the “best” score. The name of the tag is irrelevant. Default: ‘loss’
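A minimal sketch of the bookkeeping described for update_best_score, assuming the behaviour stated above: a score counts only when its type matches the tracker's best_score_type, and a strictly higher score triggers a save. The class `BestScoreTracker`, its `saved` list, and the initial best score of negative infinity are hypothetical and exist only for illustration.

```python
class BestScoreTracker:
    """Hypothetical stand-in for the best-score bookkeeping of a trainer."""

    best_score_type = "loss"  # the criterion this tracker listens to

    def __init__(self):
        self.best_score = float("-inf")  # assumed starting value
        self.saved = []  # scores at which a "best" checkpoint would be saved

    def update_best_score(self, score, score_type="loss"):
        if score_type != self.best_score_type:
            return  # a score reported for another criterion is ignored
        if score > self.best_score:
            self.best_score = score
            self.saved.append(score)  # a real trainer would save the model here


tracker = BestScoreTracker()
for loss in [0.9, 0.5, 0.7]:
    tracker.update_best_score(-loss)  # higher score means lower loss
tracker.update_best_score(0.99, score_type="accuracy")  # ignored by this tracker

print(tracker.best_score)
print(tracker.saved)
```

Passing score = -loss turns "lower loss is better" into "higher score is better", so the third epoch (loss 0.7) does not replace the best checkpoint from the second epoch (loss 0.5).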