mighty.utils.var_online.MeanOnlineLabels

class mighty.utils.var_online.MeanOnlineLabels(cls=<class 'mighty.utils.var_online.MeanOnlineBatch'>)

Keep track of population mean for each unique class label.

Parameters:
cls : type, optional

The class instantiated to track the online mean of each label: either MeanOnline or MeanOnlineBatch. Default: MeanOnlineBatch.
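The following minimal, dependency-free sketch illustrates what the cls argument controls. It is not mighty's implementation. It assumes that one running-mean object is created lazily for each new label from the given class, using a defaultdict. The names RunningMean and PerLabelMean are hypothetical, and plain floats stand in for torch tensors.

```python
from collections import defaultdict


class RunningMean:
    """Hypothetical scalar running mean (a stand-in for MeanOnlineBatch)."""

    def __init__(self):
        self.mean = 0.0
        self.count = 0

    def update(self, values):
        # Batch-weighted incremental update; works for any batch size.
        n = len(values)
        if n == 0:
            return
        self.count += n
        self.mean += (sum(values) - n * self.mean) / self.count


class PerLabelMean:
    """Hypothetical per-label container with a pluggable mean class."""

    def __init__(self, cls=RunningMean):
        # A new cls() instance is created the first time a label is seen.
        self.online = defaultdict(cls)

    def update(self, values, labels):
        # Route each label's values to that label's running mean.
        for label in set(labels):
            subset = [v for v, lab in zip(values, labels) if lab == label]
            self.online[label].update(subset)


tracker = PerLabelMean()
tracker.update([1.0, 2.0, 10.0], [0, 0, 1])
tracker.update([3.0, 20.0], [0, 1])
print({k: m.mean for k, m in sorted(tracker.online.items())})  # {0: 2.0, 1: 15.0}
```

Passing a different class as cls changes only how each label's mean is accumulated. The routing of samples to labels stays the same.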

Methods

__init__([cls])

activate(is_active)

Activates or deactivates the updates.

get_mean()

Return the mean tensor of each unique class label.

get_mean_labels()

Return the mean tensors together with their class labels.

labels()

Return the unique sorted class labels.

reset()

Reset the mean and the count.

update(tensor, labels)

Update sample mean (and variance) from a batch of new values, split by labels.

activate(is_active: bool)

Activates or deactivates the updates.

Parameters:
is_active : bool

New state: True enables updates, False disables them.
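The following minimal sketch shows the assumed semantics of activate, using a hypothetical GatedMean class on plain floats. While the class is deactivated, update calls leave both the mean and the count unchanged. This behaviour is an assumption. The documentation above says only that updates are activated or deactivated.

```python
class GatedMean:
    """Hypothetical running mean whose updates can be switched off."""

    def __init__(self):
        self.mean = 0.0
        self.count = 0
        self.is_active = True

    def activate(self, is_active: bool):
        self.is_active = is_active

    def update(self, values):
        # Ignore new data entirely while deactivated.
        if not self.is_active or not values:
            return
        n = len(values)
        self.count += n
        self.mean += (sum(values) - n * self.mean) / self.count


m = GatedMean()
m.update([2.0, 4.0])    # mean 3.0, count 2
m.activate(False)
m.update([100.0])       # ignored
m.activate(True)
m.update([6.0])         # mean 4.0, count 3
print(m.mean, m.count)  # 4.0 3
```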

get_mean()
Returns:
mean_sorted : (C, V) torch.Tensor

Mean tensor for each of C unique class labels.

get_mean_labels()
Returns:
mean_sorted : (C, V) torch.Tensor

Mean tensor for each of C unique class labels.

labels_sorted : (C,) torch.Tensor

Class labels associated with the rows of mean_sorted.
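The pairing of mean_sorted and labels_sorted can be sketched as follows. The sketch assumes the per-label means are kept in a label-to-mean mapping, which is a hypothetical layout, and mean_labels_sorted is a hypothetical helper. Rows are emitted in ascending label order, so row i of mean_sorted belongs to labels_sorted[i]. With torch tensors, the rows would typically be combined with torch.stack to obtain the (C, V) result.

```python
def mean_labels_sorted(means):
    """Hypothetical helper: return (mean_sorted, labels_sorted) from a label->mean dict.

    Rows of mean_sorted line up with labels_sorted, mirroring the
    (C, V) / (C,) shapes documented for get_mean_labels().
    """
    labels_sorted = sorted(means)
    mean_sorted = [means[label] for label in labels_sorted]
    return mean_sorted, labels_sorted


means = {7: [0.5, 1.5], 2: [3.0, 4.0], 5: [1.0, 1.0]}
mean_sorted, labels_sorted = mean_labels_sorted(means)
print(labels_sorted)  # [2, 5, 7]
print(mean_sorted)    # [[3.0, 4.0], [1.0, 1.0], [0.5, 1.5]]
```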

labels()
Returns:
list

Unique sorted class labels.

reset()

Reset the mean and the count.

update(tensor, labels)

Update sample mean (and variance) from a batch of new values, split by labels.

Parameters:
tensor : (B, V) torch.Tensor

A batch of B samples, each a vector of size V.

labels : (B,) torch.Tensor

Class label of each sample in the batch.
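The following sketch shows how a batch can be split by labels and folded into per-label running means. It assumes a batch-weighted update, as the name MeanOnlineBatch suggests, and uses lists of floats in place of the (B, V) and (B,) tensors. Consider a label with n previously seen samples and running mean m. A batch of b new rows with column sum s gives m' = m + (s - b*m) / (n + b), which equals the exact mean of all n + b rows. The helper update_per_label and its dict-based storage are hypothetical.

```python
def update_per_label(means, counts, batch, labels):
    """Hypothetical helper: fold a (B, V) batch into per-label running means.

    means:  dict label -> list of V floats (running mean)
    counts: dict label -> number of samples seen so far
    batch:  list of B rows, each a list of V floats
    labels: list of B labels
    """
    for label in sorted(set(labels)):
        rows = [row for row, lab in zip(batch, labels) if lab == label]
        b = len(rows)
        n_new = counts.get(label, 0) + b
        old = means.get(label, [0.0] * len(rows[0]))
        # Column sums of this label's rows.
        col_sums = [sum(col) for col in zip(*rows)]
        # m' = m + (s - b * m) / (n + b), applied per column.
        means[label] = [m + (s - b * m) / n_new for m, s in zip(old, col_sums)]
        counts[label] = n_new


means, counts = {}, {}
update_per_label(means, counts, [[1.0, 0.0], [3.0, 2.0], [5.0, 5.0]], [0, 0, 1])
update_per_label(means, counts, [[2.0, 4.0]], [0])
print(means)   # {0: [2.0, 2.0], 1: [5.0, 5.0]}
print(counts)  # {0: 3, 1: 1}
```

Updating the mean in place avoids storing past samples. The result matches recomputing the mean from scratch, up to floating-point rounding.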