from typing import Union
import torch
import torch.nn as nn
import torch.utils.data
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from torch.optim.optimizer import Optimizer
from mighty.loss import LossPenalty
from mighty.models import AutoencoderLinear
from mighty.monitor.monitor import MonitorAutoencoder
from mighty.utils.var_online import MeanOnline
from mighty.utils.signal import peak_to_signal_noise_ratio
from mighty.utils.common import input_from_batch, batch_to_cuda
from mighty.utils.data import DataLoader
from .embedding import TrainerEmbedding
__all__ = [
"TrainerAutoencoder"
]
[docs]
class TrainerAutoencoder(TrainerEmbedding):
"""
An unsupervised AutoEncoder trainer that not only transforms inputs to
meaningful embeddings but also aims to restore the input signal from it.
Parameters
----------
model : nn.Module
A neural network to train.
criterion : nn.Module
A loss function.
data_loader : DataLoader
A data loader.
optimizer : Optimizer
An optimizer (Adam, SGD, etc.).
scheduler : _LRScheduler or ReduceLROnPlateau, or None
A learning rate scheduler.
Default: None
accuracy_measure : AccuracyEmbedding, optional
Calculates the accuracy of embedding vectors.
Default: ``AccuracyEmbedding()``
**kwargs
Passed to the base class.
"""
watch_modules = TrainerEmbedding.watch_modules + (AutoencoderLinear,)
def __init__(self,
model: nn.Module,
criterion: nn.Module,
data_loader: DataLoader,
optimizer: Optimizer,
scheduler: Union[_LRScheduler, ReduceLROnPlateau] = None,
**kwargs):
super().__init__(model, criterion=criterion, data_loader=data_loader,
optimizer=optimizer, scheduler=scheduler, **kwargs)
def _init_monitor(self, mutual_info) -> MonitorAutoencoder:
monitor = MonitorAutoencoder(
mutual_info=mutual_info,
normalize_inverse=self.data_loader.normalize_inverse
)
return monitor
def _init_online_measures(self):
online = super()._init_online_measures()
# peak signal-to-noise ratio
online['psnr-train'] = MeanOnline()
online['psnr-test'] = MeanOnline()
return online
def _get_loss(self, batch, output):
input = input_from_batch(batch)
latent, reconstructed = output
if isinstance(self.criterion, LossPenalty):
loss = self.criterion(reconstructed, input, latent)
else:
loss = self.criterion(reconstructed, input)
return loss
def _on_forward_pass_batch(self, batch, output, train):
input = input_from_batch(batch)
latent, reconstructed = output
if isinstance(self.criterion, nn.BCEWithLogitsLoss):
reconstructed = reconstructed.sigmoid()
psnr = peak_to_signal_noise_ratio(input, reconstructed)
fold = 'train' if train else 'test'
if torch.isfinite(psnr):
self.online[f'psnr-{fold}'].update(psnr.cpu())
super()._on_forward_pass_batch(batch, latent, train)
def _epoch_finished(self, loss):
self.plot_autoencoder()
for fold in ('train', 'test'):
self.monitor.plot_psnr(self.online[f'psnr-{fold}'].get_mean(),
mode=fold)
super()._epoch_finished(loss)
[docs]
def plot_autoencoder(self):
"""
Plots AutoEncoder reconstruction.
"""
batch = self.data_loader.sample()
batch = batch_to_cuda(batch)
mode_saved = self.model.training
self.model.train(False)
with torch.no_grad():
latent, reconstructed = self._forward(batch)
if isinstance(self.criterion, nn.BCEWithLogitsLoss):
reconstructed = reconstructed.sigmoid()
self._plot_autoencoder(batch, reconstructed)
self.model.train(mode_saved)
def _plot_autoencoder(self, batch, reconstructed, mode='train'):
input = input_from_batch(batch)
self.monitor.plot_autoencoder(input, reconstructed, mode=mode)