"""
Mean and variance online measures
---------------------------------
.. autosummary::
:toctree: toctree/utils/
MeanOnline
MeanOnlineBatch
SumOnlineBatch
MeanOnlineLabels
VarianceOnline
VarianceOnlineBatch
VarianceOnlineLabels
"""
from collections import defaultdict
import torch
import torch.utils.data
__all__ = [
"MeanOnline",
"VarianceOnline",
"MeanOnlineBatch",
"SumOnlineBatch",
"VarianceOnlineBatch",
"MeanOnlineLabels",
"VarianceOnlineLabels"
]
[docs]
class MeanOnline:
"""
Online sample mean aggregate. Works with scalars, vectors, and
n-dimensional tensors.
Parameters
----------
tensor : torch.Tensor or None
The initial tensor, if provided.
"""
def __init__(self, tensor=None):
self.mean = None
self.count = 0
self.is_active = True
if tensor is not None:
self.update(tensor)
[docs]
def activate(self, is_active):
"""
Activates or deactivates the updates.
Parameters
----------
is_active : bool
New state.
"""
self.is_active = is_active
[docs]
def update(self, tensor):
"""
Update sample mean (and variance) from a batch of new values.
Parameters
----------
tensor : torch.Tensor
Next tensor sample.
"""
if not self.is_active:
return
tensor = tensor.float()
self.count += 1
if self.mean is None:
self.mean = tensor.clone()
else:
self.mean += (tensor - self.mean) / self.count
[docs]
def get_mean(self):
"""
Returns
-------
torch.Tensor
The mean of all tensors.
"""
if self.mean is None:
return None
else:
return self.mean.clone()
[docs]
def reset(self):
"""
Reset the mean and the count.
"""
self.mean = None
self.count = 0
[docs]
class VarianceOnline(MeanOnline):
"""
Welford's online algorithm for population mean and variance estimation.
"""
def __init__(self, tensor=None):
self.M2 = None
super().__init__(tensor)
[docs]
def update(self, tensor):
if not self.is_active:
return
tensor = tensor.float()
self.count += 1
if self.mean is None:
self.mean = torch.zeros_like(tensor)
self.M2 = torch.zeros_like(tensor)
delta_var = tensor - self.mean
self.mean += delta_var / self.count
delta_var.mul_(tensor - self.mean)
self.M2 += delta_var
[docs]
def get_mean_std(self, unbiased=True):
"""
Return mean and std of all samples.
Parameters
----------
unbiased : bool, optional
Biased (False) or unbiased (True) variance estimate.
Default: True
Returns
-------
mean : torch.Tensor
The mean of all samples.
std : torch.Tensor
The std of all samples.
"""
if self.mean is None:
return None, None
if self.count > 1:
count = self.count - 1 if unbiased else self.count
std = torch.sqrt(self.M2 / count)
else:
# with 1 update both biased & unbiased sample variance is zero
std = torch.zeros_like(self.mean)
return self.mean.clone(), std
[docs]
def reset(self):
super().reset()
self.var = None
[docs]
class MeanOnlineBatch(MeanOnline):
"""
Online mean measure that updates 1d vector mean from a batch of vectors
(2d tensor).
"""
[docs]
def update(self, tensor):
if not self.is_active:
return
tensor = tensor.float()
batch_size = tensor.shape[0]
self.count += batch_size
if self.mean is None:
self.mean = tensor.mean(dim=0)
else:
self.mean += (tensor.sum(dim=0) -
self.mean * batch_size) / self.count
[docs]
class SumOnlineBatch:
"""
Online sum measure.
"""
def __init__(self):
self.sum = None
self.count = 0
self.is_active = True
def activate(self, is_active):
self.is_active = is_active
def update(self, tensor: torch.Tensor):
if not self.is_active:
return
tensor = tensor.float()
self.count += tensor.shape[0]
if self.sum is None:
self.sum = tensor.sum(dim=0)
else:
self.sum += tensor.sum(dim=0)
def get_sum(self):
if self.sum is None:
return None
return self.sum.clone()
def reset(self):
self.sum = None
self.count = 0
[docs]
class VarianceOnlineBatch(VarianceOnline):
"""
Welford's online algorithm for population mean and variance estimation
from batches of 1d vectors.
"""
[docs]
def update(self, tensor):
if not self.is_active:
return
tensor = tensor.float()
batch_size = tensor.shape[0]
self.count += batch_size
if self.mean is None:
self.mean = torch.zeros_like(tensor[0])
self.M2 = torch.zeros_like(tensor[0])
delta_var = tensor - self.mean
delta_mean = tensor.sum(dim=0).sub_(self.mean * batch_size).div_(self.count)
self.mean.add_(delta_mean)
delta_var.mul_(tensor - self.mean)
self.M2 += torch.sum(delta_var, dim=0)
[docs]
class MeanOnlineLabels:
"""
Keep track of population mean for each unique class label.
Parameters
----------
cls : type, optional
The generator class of online mean: either :class:`MeanOnline` or
:class`MeanOnlineBatch`.
Default: MeanOnlineBatch
"""
def __init__(self, cls=MeanOnlineBatch):
self.online = defaultdict(cls)
self.is_active = True
def __len__(self):
return len(self.online)
[docs]
def activate(self, is_active: bool):
"""
Activates or deactivates the updates.
Parameters
----------
is_active : bool
New state.
"""
self.is_active = is_active
[docs]
def labels(self):
"""
Returns
-------
list
Unique sorted class labels.
"""
return sorted(self.online.keys())
[docs]
def update(self, tensor, labels):
"""
Update sample mean (and variance) from a batch of new values, split
by labels.
Parameters
----------
tensor : (B, V) torch.Tensor
A tensor sample.
labels : (B,) torch.Tensor
Batch labels.
"""
if not self.is_active:
return
tensor = tensor.float()
for label in labels.unique(sorted=False):
self.online[label.item()].update(tensor[labels == label])
[docs]
def get_mean_labels(self):
"""
Returns
-------
mean_sorted : (C, V) torch.Tensor
Mean tensor for each of `C` unique class labels.
labels_sorted : (C,) torch.Tensor
Class labels, associated with `mean_sorted`.
"""
if len(self) == 0:
# no updates yet
return None, None
labels_sorted = self.labels()
mean_sorted = [self.online[label].get_mean() for label in
labels_sorted]
mean_sorted = torch.stack(mean_sorted, dim=0)
return mean_sorted, labels_sorted
[docs]
def get_mean(self):
"""
Returns
-------
mean_sorted : (C, V) torch.Tensor
Mean tensor for each of `C` unique class labels.
"""
mean_sorted, _ = self.get_mean_labels()
return mean_sorted
[docs]
def reset(self):
"""
Reset the mean and the count.
"""
self.online.clear()
[docs]
class VarianceOnlineLabels(MeanOnlineLabels):
"""
Keep track of population mean and std for each unique class label.
"""
def __init__(self):
super().__init__(cls=VarianceOnlineBatch)
[docs]
def get_mean_std_labels(self, unbiased=True):
"""
Return the mean and std for each unique label individually.
Parameters
----------
unbiased : bool, optional
Biased (False) or unbiased (True) variance estimate.
Default: True
Returns
-------
mean_sorted : (C, V) torch.Tensor
Mean tensor for each of `C` unique class labels.
std_sorted : (C, V) torch.Tensor
Std tensor for each of `C` unique class labels.
labels_sorted : (C,) torch.Tensor
Class labels, associated with `mean_sorted` and `std_sorted`.
"""
if len(self) == 0:
# no updates yet
return None, None, None
labels_sorted = self.labels()
mean_std = [self.online[label].get_mean_std(unbiased) for label in
labels_sorted]
mean_sorted, std_sorted = zip(*mean_std)
mean_sorted = torch.stack(mean_sorted, dim=0)
std_sorted = torch.stack(std_sorted, dim=0)
return mean_sorted, std_sorted, labels_sorted
[docs]
def get_mean_std(self, unbiased=True):
"""
Return the mean and std for each unique label individually without
the labels themselves.
Parameters
----------
unbiased : bool, optional
Biased (False) or unbiased (True) variance estimate.
Default: True
Returns
-------
mean_sorted : (C, V) torch.Tensor
Mean tensor for each of `C` unique class labels.
std_sorted : (C, V) torch.Tensor
Std tensor for each of `C` unique class labels.
"""
mean_sorted, std_sorted, _ = self.get_mean_std_labels(unbiased)
return mean_sorted, std_sorted