Source code for mighty.trainer.gradient
import warnings
from typing import Union
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from torch.optim.optimizer import Optimizer
from mighty.utils.data import DataLoader
from .trainer import Trainer
__all__ = [
"TrainerGrad"
]
[docs]
class TrainerGrad(Trainer):
    """
    The default gradient descent trainer.

    Parameters
    ----------
    model : nn.Module
        A neural network to train.
    criterion : nn.Module
        A loss function.
    data_loader : DataLoader
        A data loader.
    optimizer : Optimizer or None
        An optimizer (Adam, SGD, etc.). If None, losses are still computed
        and back-propagated, but no optimizer step is ever taken.
    scheduler : _LRScheduler or ReduceLROnPlateau, or None
        A learning rate scheduler, stepped by :meth:`_epoch_finished`.
        ``ReduceLROnPlateau`` receives the loss passed to that method as its
        metric; any other scheduler is stepped without arguments.
        Default: None
    **kwargs
        Passed to the base class.
    """

    def __init__(self,
                 model: nn.Module,
                 criterion: nn.Module,
                 data_loader: DataLoader,
                 optimizer: Union[Optimizer, None],
                 scheduler: Union[_LRScheduler, ReduceLROnPlateau, None] = None,
                 **kwargs):
        super().__init__(model, criterion=criterion, data_loader=data_loader,
                         **kwargs)
        self.optimizer = optimizer
        self.scheduler = scheduler

    def monitor_functions(self):
        """
        Register the base-class monitor functions and, when both an optimizer
        and a scheduler are set, a per-param-group learning-rate plot.
        """
        super().monitor_functions()

        def learning_rate(viz):
            # One curve value per optimizer param group, on a log y-axis.
            viz.line_update(y=[group['lr'] for group in self.optimizer.param_groups], opts=dict(
                xlabel='Epoch',
                ylabel='Learning rate',
                title='Learning rate',
                ytype='log',
            ))

        if self.scheduler is not None and self.optimizer is not None:
            self.monitor.register_func(learning_rate)

    def log_trainer(self):
        """
        Log the base-class trainer description, then the optimizer class name
        and the learning rate and weight decay of each of its param groups.
        """
        super().log_trainer()
        if self.optimizer is None:
            return
        lines = [f"Optimizer {self.optimizer.__class__.__name__}:"]
        for group_id, group in enumerate(self.optimizer.param_groups):
            # Not every optimizer defines 'weight_decay' (e.g. LBFGS, Rprop);
            # report None for those instead of raising KeyError.
            lines.append(f"\tgroup {group_id}: lr={group['lr']}, "
                         f"weight_decay={group.get('weight_decay')}")
        self.monitor.log('\n'.join(lines))

    def train_batch(self, batch):
        """
        Perform one gradient-descent update on a single batch.

        Parameters
        ----------
        batch
            A batch from the data loader; passed to ``_forward`` and
            ``_get_loss``.

        Returns
        -------
        loss : torch.Tensor
            The batch loss. If it is NaN, a ``"NaN loss"`` warning is issued
            and neither the backward pass nor the optimizer step is performed.
        """
        if self.optimizer is not None:
            self.optimizer.zero_grad()
        outputs = self._forward(batch)
        loss = self._get_loss(batch, outputs)
        if torch.isnan(loss).item():
            # Skip the update entirely: stepping with zeroed gradients may
            # still move the parameters via momentum or weight decay.
            warnings.warn("NaN loss")
            return loss
        loss.backward()
        if self.optimizer is not None:
            self.optimizer.step(closure=None)
        return loss

    def _epoch_finished(self, loss):
        """
        Step the learning-rate scheduler (if any), then defer to the base
        class.

        Parameters
        ----------
        loss
            The loss value handed over by the base class; used as the metric
            for ``ReduceLROnPlateau``.
        """
        if isinstance(self.scheduler, ReduceLROnPlateau):
            self.scheduler.step(metrics=loss)
        elif self.scheduler is not None:
            # Since PyTorch 2.0 the built-in schedulers derive from
            # ``LRScheduler`` and ``_LRScheduler`` is only a legacy subclass
            # alias, so an isinstance(_LRScheduler) check would silently skip
            # them. Step any non-plateau scheduler without arguments.
            self.scheduler.step()
        super()._epoch_finished(loss)

    def state_dict(self):
        """
        Return the base-class trainer state extended with the state dicts of
        the optimizer (if set), the criterion, and the scheduler (if set),
        under the keys 'optimizer', 'criterion' and 'scheduler'.
        """
        state = super().state_dict()
        if self.optimizer is not None:
            state['optimizer'] = self.optimizer.state_dict()
        state['criterion'] = self.criterion.state_dict()
        if self.scheduler is not None:
            state['scheduler'] = self.scheduler.state_dict()
        return state

    def restore(self, checkpoint_path=None, best=None, strict=True):
        """
        Restore the trainer from a checkpoint via the base class, then load
        the optimizer, criterion and scheduler states stored in it.

        Loading these extra states is best-effort: any failure is reported as
        a warning rather than raised. Optimizer and scheduler states are
        optional in the checkpoint, because :meth:`state_dict` omits them when
        the corresponding attribute is None.

        Parameters
        ----------
        checkpoint_path, best, strict
            Forwarded to the base-class ``restore``; ``strict`` is also used
            when loading the criterion state.

        Returns
        -------
        checkpoint_state
            Whatever the base-class ``restore`` returned (None if nothing was
            restored).
        """
        checkpoint_state = super().restore(checkpoint_path, best=best, strict=strict)
        if checkpoint_state is None:
            return checkpoint_state
        try:
            optimizer_state = checkpoint_state.get('optimizer')
            if self.optimizer is not None and optimizer_state is not None:
                self.optimizer.load_state_dict(optimizer_state)
            self.criterion.load_state_dict(checkpoint_state['criterion'], strict=strict)
            scheduler_state = checkpoint_state.get('scheduler')
            if self.scheduler is not None and scheduler_state is not None:
                self.scheduler.load_state_dict(scheduler_state)
        except Exception as exception:
            # Best-effort: a mismatched optimizer/criterion/scheduler state
            # must not prevent using the already-restored model.
            warnings.warn(f"Couldn't restore the trained state: {exception}")
        return checkpoint_state