Source code for mighty.utils.data.normalize

"""
Inverse normalizations of data transforms
-----------------------------------------

.. autosummary::
    :toctree: toctree/utils/

    NormalizeInverse
    get_normalize_inverse
    get_normalize
    dataset_mean_std

"""


import torch
import torch.utils.data
import torchvision.transforms.functional as F
from torchvision.transforms import Normalize, Compose, ToTensor
from tqdm import tqdm

from mighty.utils.common import input_from_batch
from mighty.utils.constants import DATA_DIR, BATCH_SIZE
from mighty.utils.var_online import VarianceOnlineBatch


[docs] class NormalizeInverse(Normalize): """ Undoes the normalization and returns reconstructed images in the input domain. Parameters ---------- mean, std : array-like Sequence of means and standard deviations for each channel. """ def __init__(self, mean, std): mean = torch.as_tensor(mean, dtype=torch.float32) std = torch.as_tensor(std, dtype=torch.float32) std_inv = 1 / (std + 1e-7) mean_inv = -mean * std_inv super().__init__(mean=mean_inv, std=std_inv, inplace=False)
[docs] def __call__(self, tensor): # (B, C, H, W) tensor dtype = tensor.dtype mean = torch.as_tensor(self.mean, dtype=dtype, device=tensor.device) std = torch.as_tensor(self.std, dtype=dtype, device=tensor.device) return F.normalize(tensor, mean=mean, std=std, inplace=False)
[docs] def get_normalize_inverse(transform): """ Traverses the input `transform`, finds a class of ``Normalize``, and returns ``NormalizeInverse``. Parameters ---------- transform: Torchvision transform. Returns ------- NormalizeInverse NormalizeInverse object to undo the normalization in the input `transform` or None, if ``Normalize`` instance is not found. """ normalize = get_normalize(transform) if normalize: return NormalizeInverse(mean=normalize.mean, std=normalize.std) return None
[docs] def get_normalize(transform, normalize_cls=Normalize): """ Traverses the input `transform` and finds an instance of `normalize_cls`. Parameters ---------- transform: Torchvision transform normalize_cls : type, optional A class to look for in the input transform. Default: Normalize Returns ------- Normalize Found normalize instance or None. """ if isinstance(transform, Compose): for child in transform.transforms: norm_inv = get_normalize(child, normalize_cls) if norm_inv is not None: return norm_inv elif isinstance(transform, Normalize): return transform return None
[docs] def dataset_mean_std(dataset_cls: type, cache=True, verbose=False): """ Estimates dataset mean and std. Parameters ---------- dataset_cls : type A dataset class. cache : bool, optional Compute once (True) or every time (False). Default: True verbose : bool, optional Verbosity flag. Default: False Returns ------- mean, std : (C, H, W) torch.Tensor Channel- and pixel-wise dataset mean and std, estimated over all samples. """ mean_std_file = (DATA_DIR / "mean_std" / dataset_cls.__name__ ).with_suffix('.pt') if not cache or not mean_std_file.exists(): dataset = dataset_cls(DATA_DIR, train=True, download=True, transform=ToTensor()) loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False) var_online = VarianceOnlineBatch() for batch in tqdm( loader, desc=f"{dataset_cls.__name__}: running online mean, std", disable=not verbose): input = input_from_batch(batch) var_online.update(input) mean, std = var_online.get_mean_std() mean_std_file.parent.mkdir(exist_ok=True, parents=True) with open(mean_std_file, 'wb') as f: torch.save((mean, std), f) with open(mean_std_file, 'rb') as f: mean, std = torch.load(f) return mean, std