Source code for mighty.models.serialize

from abc import ABC

import torch.nn as nn


class SerializableModule(nn.Module, ABC):
    """
    A serializable module to easily save and restore the attributes,
    defined in `state_attr`.

    Attributes
    ----------
    state_attr : list of str
        A list of module attribute names to be a part of a state dict - the
        result of :func:`SerializableModule.state_dict`.
    """

    # Class-level default shared by every subclass that does not rebind it;
    # subclasses should assign their own list rather than mutate this one.
    state_attr = []

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """
        Return the module state dict extended with the `state_attr` values.

        Parameters
        ----------
        destination : dict, optional
            A mapping to fill in; forwarded to the parent implementation.
        prefix : str, optional
            A string prepended to every attribute name to build its key.
        keep_vars : bool, optional
            Forwarded to the parent implementation unchanged.

        Returns
        -------
        dict
            The mapping returned by the parent implementation with one extra
            ``prefix + name`` entry per name listed in `state_attr`. The
            attribute values are stored as they are, without copying.
        """
        destination = super().state_dict(destination=destination,
                                         prefix=prefix,
                                         keep_vars=keep_vars)
        for attribute in self.state_attr:
            destination[prefix + attribute] = getattr(self, attribute)
        return destination

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys,
                              error_msgs):
        """
        Restore the `state_attr` attributes and delegate the rest upstream.

        Every ``prefix + name`` entry found in `state_dict` is removed from
        it and assigned to the matching attribute of this module. When an
        entry is absent and `strict` is true, its key is appended to
        `missing_keys`. The remaining entries are then handed to the parent
        implementation with all arguments passed through unchanged.

        Parameters
        ----------
        state_dict : dict
            The state to load from; consumed entries are popped in place.
        prefix : str
            A string prepended to every attribute name to build its key.
        local_metadata : dict
            Forwarded to the parent implementation.
        strict : bool
            Whether to record absent attribute keys in `missing_keys`.
        missing_keys : list of str
            A list extended in place with the keys that were not found.
        unexpected_keys : list of str
            Forwarded to the parent implementation.
        error_msgs : list of str
            Forwarded to the parent implementation.
        """
        for attribute in self.state_attr:
            key = prefix + attribute
            # Look the key up in the mapping itself: a hash lookup instead of
            # scanning a copied list of keys, and it stays accurate after
            # earlier iterations have popped entries.
            if key in state_dict:
                setattr(self, attribute, state_dict.pop(key))
            elif strict:
                missing_keys.append(key)
        super()._load_from_state_dict(state_dict=state_dict,
                                      prefix=prefix,
                                      local_metadata=local_metadata,
                                      strict=strict,
                                      missing_keys=missing_keys,
                                      unexpected_keys=unexpected_keys,
                                      error_msgs=error_msgs)