# Source code for module: mighty.loss.contrastive

from abc import ABC

import torch
import torch.nn as nn
import warnings


class PairLossSampler(nn.Module, ABC):
    """
    A base class for TripletLossSampler and ContrastiveLossSampler.

    Subclasses implement :meth:`forward` to convert a conventional
    ``(outputs, labels)`` batch into randomly sampled pairs or triplets and
    feed them to ``criterion``.

    Parameters
    ----------
    criterion : nn.Module
        A criterion module to compute the loss, followed by pairs sampling.
    pairs_multiplier : int, optional
        Defines how many pairs to create from a single sample. The typical
        range is ``[1, 10]``.
        Default: 1
    """

    def __init__(self, criterion: nn.Module, pairs_multiplier: int = 1):
        super().__init__()
        self.criterion = criterion
        self.pairs_multiplier = pairs_multiplier

    def extra_repr(self) -> str:
        # Not every criterion defines a 'margin' attribute (e.g. nn.MSELoss);
        # default to None instead of raising AttributeError, matching the
        # `margin is None` branch below.
        margin = getattr(self.criterion, 'margin', None)
        margin_str = '' if margin is None else f", criterion.margin={margin}"
        return f"pairs_multiplier={self.pairs_multiplier}{margin_str}"

    def pairs_to_sample(self, labels):
        """
        Estimates how many pairs to sample for a batch of `labels`.

        The probability of two random samples having the same class is
        :code:`1/n_classes`. On average, each single sample in a batch
        produces :code:`1/n_classes` pairs or
        :code:`1/n_classes * (1 - 1/n_classes)` triplets.

        Parameters
        ----------
        labels : (B,) torch.LongTensor
            A batch of labels.

        Returns
        -------
        n_random_pairs : int
            An estimated number of random permutations to sample to get the
            desired number of pairs or triplets.
        """
        batch_size = len(labels)
        # Number of distinct classes present in this batch; scales the sample
        # count so that rarer same-class matches still yield enough pairs.
        n_unique = len(labels.unique(sorted=False))
        n_random_pairs = self.pairs_multiplier * n_unique * batch_size
        return n_random_pairs

    def _check_non_nan(self, loss: torch.Tensor):
        # A NaN loss typically means the sampled index sets were empty
        # (criterion reduced over zero elements); warn rather than raise so
        # training can continue.
        if torch.isnan(loss):
            warnings.warn("Loss evaluated to NaN probably because there were "
                          "no pairs to sample. Increase the "
                          "'pairs_multiplier'.")

    def forward(self, outputs, labels):
        """
        Converts the input batch into pairs or triplets and computes
        Contrastive or Triplet Loss.

        Parameters
        ----------
        outputs : (B, N) torch.Tensor
            The output of a model.
        labels : (B,) torch.LongTensor
            A batch of the true labels.

        Returns
        -------
        loss : torch.Tensor
            Loss scalar tensor.
        """
        raise NotImplementedError


class ContrastiveLossSampler(PairLossSampler):
    """
    Contrastive Loss [1]_ with random sampling of vector pairs out of the
    conventional :code:`(outputs, labels)` batch.

    A convenient convertor of :code:`(outputs, labels)` batch into
    :code:`(outputs, target)` same-same and same-other pairs.

    Parameters
    ----------
    criterion : nn.Module
        Contrastive Loss module (e.g., ``nn.CosineEmbeddingLoss``) to
        compute the loss, followed by pairs sampling.
    pairs_multiplier : int, optional
        Defines how many pairs to create from a single sample. The typical
        range is ``[1, 10]``.
        Default: 1

    References
    ----------
    .. [1] Hadsell, R., Chopra, S., & LeCun, Y. (2006, June). Dimensionality
       reduction by learning an invariant mapping. In 2006 IEEE Computer
       Society Conference on Computer Vision and Pattern Recognition
       (CVPR'06) (Vol. 2, pp. 1735-1742). IEEE.
    """

    def forward(self, outputs, labels):
        """
        Randomly samples index pairs from the batch and computes the
        contrastive loss on them.

        Parameters
        ----------
        outputs : (B, N) torch.Tensor
            The output of a model.
        labels : (B,) torch.LongTensor
            A batch of the true labels.

        Returns
        -------
        loss : torch.Tensor
            Loss scalar tensor.
        """
        n_samples = len(outputs)
        assert n_samples > 1, "Cannot sample pairs"
        pairs_to_sample = (self.pairs_to_sample(labels),)
        left_indices = torch.randint(low=0, high=n_samples,
                                     size=pairs_to_sample,
                                     device=outputs.device)
        right_indices = torch.randint(low=0, high=n_samples,
                                      size=pairs_to_sample,
                                      device=outputs.device)
        # exclude (a, a) pairs
        indices_different = left_indices != right_indices
        left_indices = left_indices[indices_different]
        right_indices = right_indices[indices_different]
        # exclude [(a, b), (b, a)] duplicate pairs: sort(dim=1) normalizes
        # each pair to (min, max), then unique(dim=0) drops duplicates.
        # NOTE(review): the intermediate sort(dim=0) sorts the two columns
        # independently, which re-pairs indices across rows before the
        # deduplication — confirm this is intended rather than a plain
        # unique(dim=0) on the normalized pairs.
        left_indices, right_indices = torch.stack(
            [left_indices, right_indices], dim=1).sort(dim=1).values.sort(
            dim=0).values.unique(sorted=False, dim=0).t()
        is_same = labels[left_indices] == labels[right_indices]
        # Target is +1 for same-class pairs and -1 for different-class pairs,
        # as expected by e.g. nn.CosineEmbeddingLoss.
        y_target = 2 * is_same - 1
        loss = self.criterion(outputs[left_indices], outputs[right_indices],
                              y_target)
        self._check_non_nan(loss)
        return loss