# Source code for mighty.utils.hooks

"""
Layer hooks
-----------

.. autosummary::
    :toctree: toctree/utils/

    get_layers_ordered

"""

import pickle
import shutil
from pathlib import Path

import torch
import torch.nn as nn

from mighty.utils.common import batch_to_cuda
from mighty.utils.constants import DATA_DIR

DUMPS_DIR = DATA_DIR / "dumps"

__all__ = [
    "get_layers_ordered",
    "DumpActivationsHook"
]


def get_layers_ordered(model, input_sample, ignore_layers=(nn.Sequential,),
                       ignore_children=()):
    """
    Returns a list of ordered layers of the input model.

    Parameters
    ----------
    model : nn.Module
        An input model.
    input_sample : torch.Tensor or iterable of torch.Tensor
        A sample tensor (or an iterable of tensors) to be used with the model.
    ignore_layers : tuple of type
        A tuple of model classes not to include in the final result.
        Default: (nn.Sequential,)
    ignore_children : tuple of type
        A tuple of model classes to skip entering their children.
        Default: ()

    Returns
    -------
    layers_ordered : list of nn.Module
        An ordered list of layers of the input model. Note that some modules
        might be added in the list more than once (if invoked repeatedly
        during a single forward pass).
    """
    hooks = []
    layers_ordered = []

    def append_layer(layer, tensor_input):
        # Forward pre-hook: fires right before `layer` runs, so the append
        # order is the execution order of the forward pass.
        layers_ordered.append(layer)

    def register_hooks(a_model: nn.Module):
        children = tuple(a_model.children())
        if children and not isinstance(a_model, ignore_children):
            for layer in children:
                register_hooks(layer)
        if not (isinstance(a_model, ignore_layers) or a_model is model):
            handle = a_model.register_forward_pre_hook(append_layer)
            hooks.append(handle)

    register_hooks(model)

    # Move the sample to the model's device (taken from the first parameter;
    # a parameterless model is assumed to live on the CPU).
    model_params = tuple(model.parameters())
    device = 'cpu' if len(model_params) == 0 else model_params[0].device.type
    if device != 'cpu':
        if isinstance(input_sample, torch.Tensor):
            input_sample = input_sample.to(device=device)
        else:
            # iterable of tensors
            input_sample = [t.to(device=device) for t in input_sample]

    try:
        with torch.no_grad():
            try:
                model(input_sample)
            except Exception:
                # Best-effort retry: assume the sample merely lacked a batch
                # dimension. Only a Tensor can be unsqueezed; for iterable
                # input, re-raise the original error.
                if not isinstance(input_sample, torch.Tensor):
                    raise
                layers_ordered.clear()
                model(input_sample.unsqueeze(dim=0))
    finally:
        # Always detach the pre-hooks, even if both forward attempts fail,
        # so the model is left unmodified.
        for handle in hooks:
            handle.remove()

    if not layers_ordered:
        # No hooked layer fired (e.g. the model itself is the only layer).
        layers_ordered = [model]
    return layers_ordered
class DumpActivationsHook:
    """
    A use-case for :func:`get_layers_ordered`.

    Registers forward hooks on every leaf layer of the requested types and
    appends each layer's input and output tensors to pickle files under
    ``dumps_dir / model-name``.
    """

    def __init__(self, model: nn.Module, inspect_layers=(nn.Linear, nn.Conv2d),
                 dumps_dir=DUMPS_DIR):
        self.hooks = []
        self.layer_to_name = {}
        self.inspect_layers = inspect_layers
        self.dumps_dir = Path(dumps_dir) / model._get_name()
        # Start from a clean slate: wipe any dumps from a previous run.
        shutil.rmtree(self.dumps_dir, ignore_errors=True)
        self.dumps_dir.mkdir(parents=True)
        self.register_hooks(model)
        print(f"Dumping activations from {self.layer_to_name.values()} layers "
              f"to {self.dumps_dir}.")

    def register_hooks(self, model: nn.Module, prefix=''):
        """Recursively walk the module tree and hook inspected leaf layers."""
        named_children = list(model.named_children())
        if not named_children:
            # Leaf module: hook it only if it's one of the inspected types.
            if isinstance(model, self.inspect_layers):
                self.layer_to_name[model] = prefix.lstrip('.')
                handle = model.register_forward_hook(self.dump_activations)
                self.hooks.append(handle)
            return
        for child_name, child in named_children:
            self.register_hooks(child, prefix=f"{prefix}.{child_name}")

    def dump_activations(self, layer, tensor_input, tensor_output):
        """Forward hook: append the layer's input/output to its dump files."""
        layer_path = self.dumps_dir / self.layer_to_name[layer]
        if isinstance(tensor_input, tuple):
            assert len(tensor_input) == 1, "Expected only 1 input tensor"
            tensor_input, = tensor_input
        # Append-mode ('ab') so successive forward passes accumulate records.
        for suffix, tensor in (("_inp.pkl", tensor_input),
                               ("_out.pkl", tensor_output)):
            with open(f"{layer_path}{suffix}", 'ab') as dump_file:
                pickle.dump(tensor.detach().cpu(), dump_file)

    def remove_hooks(self):
        """Detach all registered forward hooks from the model."""
        while self.hooks:
            self.hooks.pop().remove()