mighty.trainer.TrainerEmbedding
- class mighty.trainer.TrainerEmbedding(model: Module, criterion: Module, data_loader: DataLoader, optimizer: Optimizer, scheduler: _LRScheduler | ReduceLROnPlateau = None, accuracy_measure: Accuracy = AccuracyEmbedding(metric=cosine, cache=False), **kwargs)
An (unsupervised) trainer that transforms input data into linearly separable embedding vectors that form clusters.
- Parameters:
- model : nn.Module
A neural network to train.
- criterion : nn.Module
A loss function.
- data_loader : DataLoader
A data loader.
- optimizer : Optimizer
An optimizer (Adam, SGD, etc.).
- scheduler : _LRScheduler or ReduceLROnPlateau or None
A learning rate scheduler. Default: None
- accuracy_measure : AccuracyEmbedding, optional
Calculates the accuracy of embedding vectors. Default: AccuracyEmbedding()
- **kwargs
Passed to the base class.
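This page does not describe how AccuracyEmbedding scores embeddings. As an assumption only, one common way to measure whether embeddings "form clusters" under a cosine metric is nearest-centroid classification: average the training embeddings per label, then assign each test embedding to the centroid with the highest cosine similarity. The pure-Python sketch below illustrates that idea; every function name in it is hypothetical and not part of mighty. Labels are used only for evaluation, which fits an unsupervised trainer.

```python
import math
from collections import defaultdict


def cosine_similarity(a, b):
    """Cosine similarity of two non-zero vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def centroids_by_label(embeddings, labels):
    """Mean embedding vector for each label."""
    sums = {}
    counts = defaultdict(int)
    for vec, label in zip(embeddings, labels):
        if label not in sums:
            sums[label] = [0.0] * len(vec)
        sums[label] = [s + v for s, v in zip(sums[label], vec)]
        counts[label] += 1
    return {label: [s / counts[label] for s in total] for label, total in sums.items()}


def nearest_centroid_accuracy(train_emb, train_labels, test_emb, test_labels):
    """Fraction of test embeddings whose most cosine-similar centroid has the right label."""
    centroids = centroids_by_label(train_emb, train_labels)
    correct = 0
    for vec, label in zip(test_emb, test_labels):
        predicted = max(centroids, key=lambda c: cosine_similarity(vec, centroids[c]))
        correct += predicted == label
    return correct / len(test_labels)


# Two well-separated toy clusters in 2-D.
train_emb = [[1.0, 0.1], [0.9, 0.0], [0.0, 1.0], [0.1, 0.8]]
train_labels = [0, 0, 1, 1]
test_emb = [[0.8, 0.2], [0.2, 0.9]]
test_labels = [0, 1]
accuracy = nearest_centroid_accuracy(train_emb, train_labels, test_emb, test_labels)
```

Cosine similarity ignores vector length, so only the direction of an embedding matters for the assignment.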
Methods
- __init__(model, criterion, data_loader, ...)
- checkpoint_path([best]): Get the checkpoint path, given the mode.
- full_forward_pass([train]): Fixes the model weights, evaluates the epoch score and updates the monitor.
- is_test_mode()
- log_trainer(): Logs the trainer in a Visdom text field.
- monitor_functions(): Override this method to register Visdom callbacks on each epoch.
- open_monitor([offline]): Opens a Visdom monitor.
- restore([checkpoint_path, best, strict]): Restores the trainer progress and the model from the path.
- save([best]): Saves the trainer and the model parameters to self.checkpoint_path(best).
- test()
- train([n_epochs, mutual_info_layers, ...]): User-entry function to train the model for n_epochs.
- train_batch(batch): The core function of a trainer to update the model parameters, given a batch.
- train_epoch(epoch): Trains an epoch.
- train_mask([mask_explain_params]): Trains a mask to see which part of an image is crucial from the network's perspective (saliency map).
- training_finished(): Training is finished callback.
- training_started(): Training is started callback.
- update_accuracy([train]): Updates the accuracy of the model.
- update_best_score(score[, score_type]): If score is greater than self.best_score, save the model.
Attributes
- epoch: The current epoch, int.
- watch_modules
- checkpoint_path(best=None)
Get the checkpoint path, given the mode.
- Parameters:
- best : str or None
Tag name. If set, the path will be expanded to
".../best/tag". Default: None
- Returns:
- Path
Checkpoint path.
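The "best" expansion can be pictured with a small pathlib sketch. Only the ".../best/tag" suffix comes from the documentation above; the root directory argument and the "latest.pt" file name used for the non-best case are hypothetical placeholders, not mighty's actual layout.

```python
from pathlib import PurePosixPath


def checkpoint_path(root, best=None):
    """Return the latest checkpoint path, or the tagged "best" one (hypothetical layout)."""
    root = PurePosixPath(root)
    if best is None:
        # Latest checkpoint; the file name is a placeholder assumption.
        return root / "latest.pt"
    # Tagged best checkpoint, expanded to ".../best/<tag>" as documented.
    return root / "best" / best


latest = checkpoint_path("checkpoints/run")
best_loss = checkpoint_path("checkpoints/run", best="loss")
```

PurePosixPath keeps the example independent of the host operating system; a real trainer would more likely use Path so the directories can be created on disk.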
- property epoch
The current epoch, int.
- full_forward_pass(train=True)
Evaluates the epoch score with the model weights fixed (not updated) and updates the monitor.
- Parameters:
- train : bool
Either train (True) or test (False) batches to run. In both cases, the model is set to evaluation mode via self.model.eval().
- Returns:
- loss : torch.Tensor
The loss of a full forward pass.
- is_unsupervised()
- Returns:
- bool
True if the training is unsupervised, False otherwise.
- log_trainer()
Logs the trainer in a Visdom text field.
- monitor_functions()
Override this method to register Visdom callbacks on each epoch.
- open_monitor(offline=False)
Opens a Visdom monitor.
- Parameters:
- offline : bool
Online (False) or offline (True) monitoring.
- restore(checkpoint_path=None, best=None, strict=True)
Restores the trainer progress and the model from the path.
- Parameters:
- checkpoint_path : Path or None
Trainer checkpoint path to restore. If None, the default path self.checkpoint_path() is used. Default: None
- best : str or None
Best or latest (refer to Trainer.checkpoint_path()).
- strict : bool
Whether the model parameters are loaded strictly.
- Returns:
- checkpoint_state : dict
The loaded state of a trainer.
- save(best=None)
Saves the trainer and the model parameters to self.checkpoint_path(best).
- Parameters:
- best : str or None
Tag name. If set, the path will be expanded to ".../best/tag". Default: None
See also
restore : Restore the training progress.
- state_dict()
- Returns:
- dict
A dict of the trainer state to be saved.
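The save, restore and state_dict methods work as a round trip: state_dict() collects what must survive, save() writes it to checkpoint_path(best), and restore() reads it back. The sketch below mimics that trio with a hypothetical CheckpointingTrainer that stores JSON instead of model tensors; the state keys, the "latest.json" file name and the JSON format are illustrative assumptions, and the strict flag is omitted because there is no real model to load.

```python
import json
import tempfile
from pathlib import Path


class CheckpointingTrainer:
    """Hypothetical stand-in that mirrors the save/restore/state_dict trio."""

    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.epoch = 0
        self.best_score = float("-inf")  # json writes this as -Infinity and reads it back
        self.model_params = {"weight": [0.0, 0.0]}

    def checkpoint_path(self, best=None):
        if best is None:
            return self.checkpoint_dir / "latest.json"  # placeholder file name
        return self.checkpoint_dir / "best" / best  # ".../best/<tag>"

    def state_dict(self):
        # Everything needed to resume training, as a plain dict.
        return {"epoch": self.epoch, "best_score": self.best_score,
                "model_params": self.model_params}

    def save(self, best=None):
        path = self.checkpoint_path(best)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(self.state_dict()))

    def restore(self, checkpoint_path=None, best=None):
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path(best)
        state = json.loads(Path(checkpoint_path).read_text())
        self.epoch = state["epoch"]
        self.best_score = state["best_score"]
        self.model_params = state["model_params"]
        return state


with tempfile.TemporaryDirectory() as tmp:
    trainer = CheckpointingTrainer(tmp)
    trainer.epoch = 7
    trainer.model_params["weight"] = [0.5, -1.5]
    trainer.save()            # latest checkpoint
    trainer.save(best="loss")  # tagged best checkpoint
    fresh = CheckpointingTrainer(tmp)
    state = fresh.restore()
    best_state = CheckpointingTrainer(tmp).restore(best="loss")
```

Keeping state_dict() separate from save() lets the same dict be logged, compared or written elsewhere without touching the file system.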
- train(n_epochs=10, mutual_info_layers=0, mask_explain_params=None)
User-entry function to train the model for n_epochs.
- Parameters:
- n_epochs : int
The number of epochs to run. Default: 10
- mutual_info_layers : int, optional
Evaluate the mutual information [1] from the last mutual_info_layers layers at each epoch. If set to 0, skip the (time-consuming) mutual information estimation. Default: 0
- mask_explain_params : dict or None, optional
If not None, a dictionary with parameters for MaskTrainer, which is used to show the “saliency map” [2]. Default: None
- Returns:
- loss_epochs : list
A list of per-epoch losses.
References
- train_batch(batch)
The core function of a trainer to update the model parameters, given a batch.
- Parameters:
- batch : torch.Tensor or tuple of torch.Tensor
An (X, Y) or X batch of input data.
- Returns:
- loss : torch.Tensor
The batch loss.
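Because train_batch accepts either an (X, Y) pair or a bare X, an embedding trainer has to separate the two cases before computing the loss. A minimal sketch of that dispatch, using a hypothetical split_batch helper that is not part of mighty; it assumes a bare X is a single tensor-like object rather than a Python tuple or list, which is how PyTorch's default collation distinguishes the two.

```python
def split_batch(batch):
    """Return (inputs, labels) from an (X, Y) pair or a bare X batch.

    Assumes a bare X is a single tensor-like object, not a tuple or list.
    """
    if isinstance(batch, (tuple, list)):
        inputs = batch[0]
        labels = batch[1] if len(batch) > 1 else None
    else:
        inputs, labels = batch, None
    return inputs, labels


inputs, labels = split_batch(("X", "Y"))  # supervised-style pair
only_inputs, no_labels = split_batch(3.0)  # bare X, labels unavailable
```

An unsupervised update needs only the inputs; labels, when present, can still be kept for accuracy evaluation.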
- train_epoch(epoch)
Trains an epoch.
- Parameters:
- epoch : int
Epoch ID.
- train_mask(mask_explain_params={})
Trains a mask to see which part of an image is crucial from the network's perspective (saliency map).
- Parameters:
- mask_explain_params : dict, optional
MaskTrainer keyword arguments.
- training_finished()
Training is finished callback.
This function is called right before exiting the Trainer.train() function.
- training_started()
Training is started callback.
This function is called before training the first epoch.
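The callbacks fit into train() roughly as sketched below. This is a toy model of the control flow under stated assumptions: it assumes train() calls training_started(), then for each epoch calls train_epoch() followed by full_forward_pass() to record the epoch loss, and finally training_finished(); mighty's actual order and bookkeeping may differ. The hypothetical LoopTrainer has a single parameter w that learns the mean of the data by gradient descent on a squared loss.

```python
class LoopTrainer:
    """Hypothetical toy trainer illustrating the assumed train() control flow."""

    def __init__(self, batches, lr=0.25):
        self.batches = batches  # list of batches, each a list of floats
        self.lr = lr
        self.w = 0.0  # the single "model" parameter
        self.epoch = 0
        self.events = []

    def training_started(self):
        self.events.append("started")

    def training_finished(self):
        self.events.append("finished")

    def train_batch(self, batch):
        # Squared loss (w - x)^2 averaged over the batch, then one gradient step.
        loss = sum((self.w - x) ** 2 for x in batch) / len(batch)
        grad = sum(2 * (self.w - x) for x in batch) / len(batch)
        self.w -= self.lr * grad
        return loss

    def train_epoch(self, epoch):
        for batch in self.batches:
            self.train_batch(batch)

    def full_forward_pass(self, train=True):
        # Evaluate without updating w; the toy has no test split, so `train` is ignored.
        samples = [x for batch in self.batches for x in batch]
        return sum((self.w - x) ** 2 for x in samples) / len(samples)

    def train(self, n_epochs=10):
        loss_epochs = []
        self.training_started()
        for _ in range(n_epochs):
            self.train_epoch(self.epoch)
            loss_epochs.append(self.full_forward_pass(train=True))
            self.epoch += 1
        self.training_finished()
        return loss_epochs


trainer = LoopTrainer(batches=[[1.0, 3.0], [2.0, 2.0]])
losses = trainer.train(n_epochs=3)
```

The epoch loss here is measured after the epoch's updates with the parameter held fixed, which mirrors the description of full_forward_pass rather than averaging the per-batch training losses.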
- update_accuracy(train=True)
Updates the accuracy of the model.
- Parameters:
- train : bool
Either train (True) or test (False) mode.
- Returns:
- accuracy : torch.Tensor
A scalar with the accuracy value.
- update_best_score(score, score_type='loss')
If score is greater than self.best_score, save the model.
The internal best score is updated and the current model is saved as “best” if the object's best_score_type tag matches its class best_score_type.
- Parameters:
- score : float
The model score at the current epoch. The higher, the better. The simplest way to use this function is to set score = -loss.
- score_type : str, optional
A keyword that determines the criterion for the “best” score. The name of the tag is irrelevant. Default: ‘loss’
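The "higher is better" convention and the score = -loss trick can be checked with a small sketch. The hypothetical BestScoreTracker below records the epochs at which a save would happen instead of writing files, and it leaves out the score_type matching described above; it is an illustration of the comparison logic, not mighty's implementation.

```python
class BestScoreTracker:
    """Hypothetical helper mirroring the documented rule: higher score is better."""

    def __init__(self):
        self.best_score = float("-inf")
        self.saved_epochs = []  # stands in for save(best=...)

    def update_best_score(self, score, epoch):
        if score > self.best_score:
            self.best_score = score
            self.saved_epochs.append(epoch)
            return True
        return False


tracker = BestScoreTracker()
epoch_losses = [0.9, 0.7, 0.8, 0.5]
for epoch, loss in enumerate(epoch_losses):
    # Negating the loss turns "lower loss" into "higher score".
    tracker.update_best_score(-loss, epoch)
```

With these losses, the model would be saved at epochs 0, 1 and 3; epoch 2 is skipped because its loss (0.8) is worse than the best so far (0.7).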