Source code for mighty.utils.prepare
"""
Convert a model to the train or test mode
-----------------------------------------
.. autosummary::
:toctree: toctree/utils/
ModelMode
prepare_eval
"""
import torch.nn as nn
__all__ = [
"ModelMode",
"prepare_eval"
]
[docs]
class ModelMode:
"""
Stores the model state with its parameters to be restored later on.
Parameters
----------
mode : bool
Original model mode extracted as ``model.training``.
requires_grad : dict
A dict with keys that match ``model.named_parameters()`` dict which
store a boolean state of each model parameter.
"""
def __init__(self, mode, requires_grad):
self.mode = mode
self.requires_grad = requires_grad
[docs]
def restore(self, model):
"""
Restore the original state of the model and its parameters.
Parameters
----------
model : nn.Module
A model that was used as the input to :func:`prepare_eval`
function.
"""
model.train(self.mode)
for name, param in model.named_parameters():
param.requires_grad_(self.requires_grad[name])
[docs]
def prepare_eval(model):
"""
Sets the model and its parameters to the eval state.
Parameters
----------
model : nn.Module
An input model.
Returns
-------
ModelMode
A model mode state that can recover the original input state.
"""
mode_saved = model.training
requires_grad_saved = {}
model.eval()
for name, param in model.named_parameters():
requires_grad_saved[name] = param.requires_grad
param.requires_grad_(False)
return ModelMode(mode=mode_saved, requires_grad=requires_grad_saved)