"""
Accuracy measures
-----------------
.. autosummary::
:toctree: toctree/monitor
AccuracyArgmax
AccuracyEmbedding
"""
from abc import ABC
import torch
import torch.utils.data
import torch.nn.functional as F
from mighty.utils.var_online import MeanOnlineLabels
from mighty.utils.signal import compute_distance
# Public names exported by a star-import of this module.
__all__ = [
    "calc_accuracy",
    "Accuracy",
    "AccuracyArgmax",
    "AccuracyEmbedding",
]
def calc_accuracy(labels_true: torch.Tensor,
                  labels_predicted: torch.Tensor) -> float:
    """
    Compute the fraction of correctly predicted labels.

    Parameters
    ----------
    labels_true : torch.Tensor
        Ground truth labels. Either class indices with the same number of
        dimensions as `labels_predicted`, or - when the number of dimensions
        differs - a `(B, C)` indicator matrix (one-hot or multi-hot) whose
        second dimension enumerates the classes.
    labels_predicted : torch.Tensor
        Predicted class indices of any integer dtype.

    Returns
    -------
    float
        The mean accuracy over the batch.
    """
    # Ground truth and predictions may live on different devices (e.g. labels
    # cached on the CPU, predictions returned on the model's device); tensors
    # on different devices cannot be compared or multiplied.
    labels_predicted = labels_predicted.to(device=labels_true.device)
    if labels_true.ndim == labels_predicted.ndim:
        accuracy = (labels_true == labels_predicted).float().mean()
    else:
        # F.one_hot accepts int64 index tensors only; predictions may arrive
        # as int32, so cast explicitly.
        onehot = F.one_hot(labels_predicted.long(),
                           num_classes=labels_true.shape[1])
        # A sample is counted as correct if its predicted class is marked
        # in the corresponding row of the indicator matrix.
        accuracy = (labels_true * onehot).sum(dim=1).float().mean()
    return accuracy.item()
class Accuracy(ABC):
    """
    Base class of accuracy measures.

    Keeps two running lists - ground truth labels and predicted labels -
    that are filled batch by batch. :meth:`predict_proba` is not implemented
    here and must be provided by a subclass.
    """

    def __init__(self):
        self.true_labels_cached = []
        self.predicted_labels_cached = []

    def reset(self):
        """
        Drop every cached piece of predicted and ground truth data.
        """
        self.reset_labels()

    def reset_labels(self):
        """
        Empty the cached predicted and ground truth **labels** in place.
        """
        for cache in (self.true_labels_cached, self.predicted_labels_cached):
            cache.clear()

    def partial_fit(self, outputs_batch, labels_batch):
        """
        Accumulate one batch of data.

        The base implementation only stores a CPU copy of the true labels
        and ignores `outputs_batch`; subclasses decide what to do with the
        model output (e.g. an embedding-based measure has to keep track of
        the embedding vectors to use them at prediction).

        Parameters
        ----------
        outputs_batch : torch.Tensor or tuple
            The output of a model.
        labels_batch : torch.Tensor
            True labels.
        """
        labels_cpu = labels_batch.cpu()
        self.true_labels_cached.append(labels_cpu)

    def predict(self, outputs_test):
        """
        Predict the labels, given model output.

        Parameters
        ----------
        outputs_test : torch.Tensor or tuple
            The output of a model.

        Returns
        -------
        torch.Tensor
            Predicted labels: the index of the most probable class along
            dimension 1 of :meth:`predict_proba`.
        """
        proba = self.predict_proba(outputs_test)
        return proba.argmax(dim=1)

    def predict_proba(self, outputs_test):
        """
        Compute label probabilities, given model output.

        Parameters
        ----------
        outputs_test : torch.Tensor or tuple
            The output of a model.

        Returns
        -------
        torch.Tensor
            The probabilities of assigning to each class of shape `(., C)`,
            where C is the number of classes.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def __repr__(self):
        return "{}({})".format(type(self).__name__, self.extra_repr())

    def extra_repr(self):
        """
        Extra text placed between the parentheses of ``repr(self)``.
        """
        return ""
class AccuracyArgmax(Accuracy):
    """
    Softmax accuracy.

    The predicted labels are simply ``output.argmax(dim=-1)``.
    """

    def predict(self, outputs_test):
        """
        Predict the labels as the index of the largest model output.

        Parameters
        ----------
        outputs_test : torch.Tensor
            The output of a model; classes are enumerated along the last
            dimension.

        Returns
        -------
        torch.Tensor
            Predicted labels.
        """
        return outputs_test.argmax(dim=-1)

    def predict_proba(self, outputs_test):
        """
        Compute label probabilities as the softmax of the model output.

        Parameters
        ----------
        outputs_test : torch.Tensor
            The output of a model; classes are enumerated along the last
            dimension.

        Returns
        -------
        torch.Tensor
            Class probabilities of the same shape as `outputs_test`.
        """
        # Normalize over the same (last) dimension `predict` takes the
        # argmax over, so that both methods agree for inputs of any rank.
        return outputs_test.softmax(dim=-1)

    def partial_fit(self, outputs_batch, labels_batch):
        """
        Cache the true labels and the labels predicted for this batch.

        Parameters
        ----------
        outputs_batch : torch.Tensor
            The output of a model.
        labels_batch : torch.Tensor
            True labels.
        """
        super().partial_fit(outputs_batch=outputs_batch,
                            labels_batch=labels_batch)
        labels_pred = self.predict(outputs_batch)
        self.predicted_labels_cached.append(labels_pred.cpu())
class AccuracyEmbedding(Accuracy):
    """
    Calculates the accuracy of embedding vectors.

    The mean embedding vector is kept for each class.
    Prediction is based on the closest centroid ID.

    Parameters
    ----------
    metric : str, optional
        The metric to compute pairwise distances with.
        Default: 'cosine'
    cache : bool, optional
        Cache predicted data or not.
        Default: False
    """

    def __init__(self, metric='cosine', cache=False):
        super().__init__()
        self.metric = metric
        self.cache = cache
        # Model outputs seen in `partial_fit`; filled only if `cache` is set.
        self.input_cached = []
        # Online per-label mean of the model outputs.
        self.centroids_dict = MeanOnlineLabels()

    @property
    def centroids(self):
        """
        Returns
        -------
        torch.Tensor
            `(C, N)` mean centroids tensor, where C is the number of unique
            classes, and N is the hidden layer dimensionality.
        """
        return self.centroids_dict.get_mean()

    @property
    def is_fit(self):
        """
        Returns
        -------
        bool
            Whether the accuracy predictor is fit with data or not.
        """
        return len(self.centroids_dict) > 0

    def reset(self):
        """
        Resets the cached labels, the centroids and the cached model outputs.
        """
        super().reset()
        self.centroids_dict.reset()
        self.input_cached.clear()

    def extra_repr(self):
        return f'metric={self.metric}, cache={self.cache}'

    def distances(self, outputs_test):
        """
        Returns the distances to fit centroid means.

        Parameters
        ----------
        outputs_test : (B, D) torch.Tensor
            Hidden layer activations.

        Returns
        -------
        distances : (B, C) torch.Tensor
            Distances to each class (label).
        """
        assert self.is_fit, "Fit the classifier first"
        centroids = torch.as_tensor(self.centroids,
                                    device=outputs_test.device)
        outputs_test = outputs_test.unsqueeze(dim=1)  # (B, 1, D)
        centroids = centroids.unsqueeze(dim=0)  # (1, n_classes, D)
        # Memory efficient: compare against at most 50 centroids at a time
        # rather than against all of them in one broadcast.
        distances = [
            compute_distance(input1=outputs_test,
                             input2=centroids_chunk,
                             metric=self.metric, dim=2)
            for centroids_chunk in centroids.split(split_size=50, dim=1)
        ]
        return torch.cat(distances, dim=1)

    def partial_fit(self, outputs_batch, labels_batch):
        """
        Update the class centroids with a batch of model outputs.

        Parameters
        ----------
        outputs_batch : torch.Tensor
            The output of a model.
        labels_batch : torch.Tensor
            True labels.
        """
        super().partial_fit(outputs_batch=outputs_batch,
                            labels_batch=labels_batch)
        # The centroids are statistics, not part of the autograd graph.
        outputs_batch = outputs_batch.detach()
        self.centroids_dict.update(outputs_batch, labels_batch)
        if self.cache:
            self.input_cached.append(outputs_batch.cpu())

    def predict_cached(self):
        """
        Predicts the output of a model, using cached output activations.

        Returns
        -------
        torch.Tensor
            Predicted labels.

        Raises
        ------
        ValueError
            If caching is turned off or nothing has been cached yet.
        """
        if not self.cache:
            raise ValueError("Caching is turned off")
        if len(self.input_cached) == 0:
            raise ValueError("Empty cached input buffer")
        input_cached = torch.cat(self.input_cached, dim=0)
        return self.predict(input_cached)

    def predict(self, outputs_test):
        """
        Predict the labels as the label of the closest centroid.

        Parameters
        ----------
        outputs_test : (B, D) torch.Tensor
            Hidden layer activations.

        Returns
        -------
        torch.Tensor
            Predicted labels (int32) on the device of `outputs_test`.
        """
        argmin = self.distances(outputs_test).argmin(dim=1).cpu()
        # Column `i` of the distance matrix corresponds to the `i`-th stored
        # label; map the winning column indices back to label values.
        labels_stored = self.centroids_dict.labels()
        labels_stored = torch.IntTensor(labels_stored)
        labels_predicted = labels_stored[argmin]
        return labels_predicted.to(device=outputs_test.device)

    def predict_proba(self, outputs_test):
        """
        Compute label probabilities from the distances to the centroids.

        The closer a centroid, the higher the probability of its class.
        Each row of the result sums to one.

        Parameters
        ----------
        outputs_test : (B, D) torch.Tensor
            Hidden layer activations.

        Returns
        -------
        torch.Tensor
            `(B, C)` class probabilities.
        """
        distances = self.distances(outputs_test)
        n_classes = distances.shape[1]
        if n_classes == 1:
            # The only known class is the only possible prediction.
            return torch.ones_like(distances)
        total = distances.sum(dim=1, keepdim=True)
        # `1 - d_i / sum(d)` summed over the classes equals `C - 1`;
        # rescale so that each row is a proper probability distribution.
        proba = (1 - distances / total) / (n_classes - 1)
        # All distances are zero -> 0 / 0 above. The sample is equally close
        # to every centroid, so fall back to the uniform distribution.
        uniform = torch.full_like(proba, 1 / n_classes)
        proba = torch.where(total > 0, proba, uniform)
        return proba.to(device=outputs_test.device)