mighty.trainer.TrainerEmbedding

class mighty.trainer.TrainerEmbedding(model: Module, criterion: Module, data_loader: DataLoader, optimizer: Optimizer, scheduler: _LRScheduler | ReduceLROnPlateau = None, accuracy_measure: Accuracy = AccuracyEmbedding(metric=cosine, cache=False), **kwargs)

An (unsupervised) trainer that transforms input data into linearly separable embedding vectors that form clusters.
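The trainer maps inputs to embeddings that should cluster by class, and the default accuracy measure is configured with a cosine metric. One common way to score such embeddings is nearest-centroid classification under cosine similarity. The sketch below assumes that approach; it is not taken from mighty's AccuracyEmbedding, it uses plain Python lists in place of torch.Tensor, and the function names are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two non-zero vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def centroid_accuracy(train_emb, train_labels, test_emb, test_labels):
    """Hypothetical helper: fraction of test embeddings whose most
    cosine-similar class centroid (computed from train_emb) matches the label."""
    # Accumulate per-class sums and counts, then average them into centroids.
    sums, counts = {}, {}
    for vec, label in zip(train_emb, train_labels):
        acc = sums.setdefault(label, [0.0] * len(vec))
        for i, v in enumerate(vec):
            acc[i] += v
        counts[label] = counts.get(label, 0) + 1
    centroids = {label: [s / counts[label] for s in total]
                 for label, total in sums.items()}

    # Predict the class of the closest centroid (highest cosine similarity).
    correct = 0
    for vec, label in zip(test_emb, test_labels):
        predicted = max(centroids, key=lambda c: cosine_similarity(vec, centroids[c]))
        correct += int(predicted == label)
    return correct / len(test_labels)


train_emb = [[1.0, 0.1], [0.9, 0.0], [0.0, 1.0], [0.1, 0.8]]
train_labels = [0, 0, 1, 1]
test_emb = [[0.8, 0.2], [0.2, 0.9]]
print(centroid_accuracy(train_emb, train_labels, test_emb, [0, 1]))  # 1.0
```

Cosine similarity ignores vector length, so only the direction of an embedding matters for which cluster it is assigned to.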

Parameters:
model : nn.Module

A neural network to train.

criterion : nn.Module

A loss function.

data_loader : DataLoader

A data loader.

optimizer : Optimizer

An optimizer (Adam, SGD, etc.).

scheduler : _LRScheduler or ReduceLROnPlateau, or None

A learning rate scheduler. Default: None

accuracy_measure : AccuracyEmbedding, optional

Calculates the accuracy of embedding vectors. Default: AccuracyEmbedding()

**kwargs

Passed to the base class.

Methods

__init__(model, criterion, data_loader, ...)

checkpoint_path([best])

Get the checkpoint path, given the mode.

full_forward_pass([train])

Fixes the model weights, evaluates the epoch score and updates the monitor.

is_test_mode()

is_unsupervised()

log_trainer()

Logs the trainer in a Visdom text field.

monitor_functions()

Override this method to register Visdom callbacks on each epoch.

open_monitor([offline])

Opens a Visdom monitor.

restore([checkpoint_path, best, strict])

Restores the trainer progress and the model from the path.

save([best])

Saves the trainer and the model parameters to self.checkpoint_path(best).

state_dict()

test()

train([n_epochs, mutual_info_layers, ...])

User-entry function to train the model for n_epochs.

train_batch(batch)

The core function of a trainer to update the model parameters, given a batch.

train_epoch(epoch)

Trains an epoch.

train_mask([mask_explain_params])

Trains a mask to see which part of an image is crucial from the network's perspective (saliency map).

training_finished()

Callback invoked when training is finished.

training_started()

Callback invoked when training is started.

update_accuracy([train])

Updates the accuracy of the model.

update_best_score(score[, score_type])

If score is greater than self.best_score, save the model.

Attributes

epoch

The current epoch, int.

watch_modules

checkpoint_path(best=None)

Get the checkpoint path, given the mode.

Parameters:
best : str or None

Tag name. If set, the path will be expanded to ".../best/tag". Default: None

Returns:
Path

Checkpoint path.

property epoch

The current epoch, int.

full_forward_pass(train=True)

Fixes the model weights, evaluates the epoch score and updates the monitor.

Parameters:
train : bool

Either train (True) or test (False) batches to run. In both cases, the model is set to evaluation mode via self.model.eval().

Returns:
loss : torch.Tensor

The loss of a full forward pass.

is_unsupervised()
Returns:
bool

True if the training is unsupervised, False otherwise.

log_trainer()

Logs the trainer in a Visdom text field.

monitor_functions()

Override this method to register Visdom callbacks on each epoch.

open_monitor(offline=False)

Opens a Visdom monitor.

Parameters:
offline : bool

Online (False) or offline (True) monitoring.

restore(checkpoint_path=None, best=None, strict=True)

Restores the trainer progress and the model from the path.

Parameters:
checkpoint_path : Path or None

Trainer checkpoint path to restore. If None, the default path self.checkpoint_path() is used. Default: None

best : str or None

Best or latest (refer to Trainer.checkpoint_path()).

strict : bool

Whether the model is loaded strictly.

Returns:
checkpoint_state : dict

The loaded state of a trainer.

save(best=None)

Saves the trainer and the model parameters to self.checkpoint_path(best).

Parameters:
best : str or None

Tag name. If set, the path will be expanded to ".../best/tag". Default: None

See also

restore

restore the training progress

state_dict()
Returns:
dict

A dict of the trainer state to be saved.
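The save, restore and state_dict methods form a round trip: save writes the trainer state to checkpoint_path(best), and restore reads it back, falling back to the default checkpoint path when no explicit path is given. The sketch below is a toy under stated assumptions: it uses pickle rather than whatever serialization mighty actually uses, and the TinyTrainer class, its fields and the "checkpoint.pkl" file name are all hypothetical.

```python
import pickle
from pathlib import Path
from typing import Optional


class TinyTrainer:
    """Hypothetical trainer whose whole state is an epoch counter and a
    dict of model parameters, persisted with pickle."""

    def __init__(self, checkpoint_dir: Path):
        self.checkpoint_dir = checkpoint_dir
        self.epoch = 0
        self.model_params = {"w": 0.0}

    def checkpoint_path(self, best: Optional[str] = None) -> Path:
        # A tag expands the path to <checkpoint_dir>/best/<tag>.
        path = self.checkpoint_dir if best is None else self.checkpoint_dir / "best" / best
        return path / "checkpoint.pkl"

    def state_dict(self) -> dict:
        # Copy mutable members so later training does not alter the snapshot.
        return {"epoch": self.epoch, "model_params": dict(self.model_params)}

    def save(self, best: Optional[str] = None) -> None:
        path = self.checkpoint_path(best)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(self.state_dict(), f)

    def restore(self, checkpoint_path: Optional[Path] = None,
                best: Optional[str] = None) -> dict:
        # Fall back to the default path when no explicit path is given.
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path(best)
        with open(checkpoint_path, "rb") as f:
            state = pickle.load(f)
        self.epoch = state["epoch"]
        self.model_params = dict(state["model_params"])
        return state
```

A fresh TinyTrainer pointed at the same directory recovers the epoch and parameters saved earlier; pickle files should only be loaded from trusted sources.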

train(n_epochs=10, mutual_info_layers=0, mask_explain_params=None)

User-entry function to train the model for n_epochs.

Parameters:
n_epochs : int

The number of epochs to run. Default: 10

mutual_info_layers : int, optional

Evaluate the mutual information [1] from the last mutual_info_layers layers at each epoch. If set to 0, skip the (time-consuming) mutual information estimation. Default: 0

mask_explain_params : dict or None, optional

If not None, a dictionary of parameters for MaskTrainer, which is used to show the “saliency map” [2]. Default: None

Returns:
loss_epochs : list

A list of per-epoch losses.

References

[1] Shwartz-Ziv, R., & Tishby, N. (2017). Opening the black box of deep neural networks via information. arXiv preprint arXiv:1703.00810.

[2] Fong, R. C., & Vedaldi, A. (2017). Interpretable explanations of black boxes by meaningful perturbation. In Proceedings of the IEEE International Conference on Computer Vision (ICCV).

train_batch(batch)

The core function of a trainer to update the model parameters, given a batch.

Parameters:
batch : torch.Tensor or tuple of torch.Tensor

A batch of input data, either (X, Y) or X alone.

Returns:
loss : torch.Tensor

The batch loss.
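Because a batch may arrive either as an (X, Y) pair or as X alone, a trainer has to normalize it before the forward pass; an unsupervised trainer can then simply ignore the labels. A minimal sketch of that normalization, assuming a pair arrives as a two-element tuple or list while X itself is never a Python list, with a hypothetical split_batch name and strings standing in for tensors:

```python
from typing import Any, Optional, Tuple


def split_batch(batch: Any) -> Tuple[Any, Optional[Any]]:
    """Hypothetical helper: return (X, Y) for a labelled batch and
    (X, None) for an unlabelled one."""
    if isinstance(batch, (tuple, list)):
        inputs, labels = batch  # raises ValueError unless exactly two items
        return inputs, labels
    return batch, None


print(split_batch(("images", "labels")))  # ('images', 'labels')
print(split_batch("images"))              # ('images', None)
```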

train_epoch(epoch)

Trains an epoch.

Parameters:
epoch : int

Epoch ID.

train_mask(mask_explain_params={})

Trains a mask to see which part of an image is crucial from the network's perspective (saliency map).

Parameters:
mask_explain_params : dict, optional

MaskTrainer keyword arguments.

training_finished()

Callback invoked when training is finished.

This function is called right before exiting the Trainer.train() function.

training_started()

Callback invoked when training is started.

This function is called before training the first epoch.
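The documented hooks imply a fixed call order inside train(): training_started() before the first epoch, one train_epoch(epoch) per epoch, and training_finished() right before returning the list of epoch losses. The skeleton below only records that order; the LoopSketch class and its toy loss are hypothetical, and the monitoring, mutual information and mask training steps of the real method are omitted.

```python
from typing import List


class LoopSketch:
    """Hypothetical trainer skeleton that records the order of hook calls."""

    def __init__(self):
        self.calls: List[str] = []

    def training_started(self) -> None:
        self.calls.append("training_started")

    def training_finished(self) -> None:
        self.calls.append("training_finished")

    def train_epoch(self, epoch: int) -> float:
        self.calls.append(f"train_epoch({epoch})")
        return 1.0 / (epoch + 1)  # toy loss that decreases every epoch

    def train(self, n_epochs: int = 10) -> List[float]:
        loss_epochs = []
        self.training_started()  # before the first epoch
        for epoch in range(n_epochs):
            loss_epochs.append(self.train_epoch(epoch))
        self.training_finished()  # right before returning
        return loss_epochs
```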

update_accuracy(train=True)

Updates the accuracy of the model.

Parameters:
train : bool

Either train (True) or test (False) mode.

Returns:
accuracy : torch.Tensor

A scalar with the accuracy value.

update_best_score(score, score_type='loss')

If score is greater than self.best_score, save the model.

The internal best score is updated and the current model is saved as “best” if the object’s best_score_type tag matches its class best_score_type.

Parameters:
score : float

The model score at the current epoch; the higher, the better. The simplest way to use this function is to set score = -loss.

score_type : str, optional

A keyword that determines the criterion for the “best” score. The name of the tag is irrelevant. Default: ‘loss’
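Since higher scores are better, a loss is passed in negated, so the lowest loss seen so far is kept as the best score. A minimal sketch of that bookkeeping, assuming a hypothetical BestScoreTracker whose saved_epochs list stands in for saving the model and which ignores scores whose score_type differs from the one it tracks; the extra epoch argument exists only for illustration:

```python
import math


class BestScoreTracker:
    """Hypothetical tracker: higher score is better; a 'save' is recorded
    whenever the score improves for the tracked score type."""

    def __init__(self, score_type: str = "loss"):
        self.score_type = score_type
        self.best_score = -math.inf
        self.saved_epochs = []

    def update_best_score(self, score: float, epoch: int, score_type: str = "loss") -> bool:
        if score_type != self.score_type:
            return False  # scores of another type are ignored
        if score > self.best_score:
            self.best_score = score
            self.saved_epochs.append(epoch)  # stands in for saving the "best" model
            return True
        return False


tracker = BestScoreTracker()
for epoch, loss in enumerate([0.9, 0.5, 0.7, 0.4]):
    tracker.update_best_score(-loss, epoch)  # score = -loss
print(tracker.saved_epochs)  # [0, 1, 3]
```

Epoch 2 is skipped because its loss (0.7) is worse than the best loss seen so far (0.5).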