# Source code for mighty.loss.triplet

import torch
import torch.nn as nn
import torch.nn.functional as F

from mighty.loss.contrastive import PairLossSampler


class TripletCosineLoss(nn.Module):
    r"""
    Triplet loss built on cosine similarity.

    A triplet consists of an anchor ``a``, a positive example ``p`` (should be
    similar to the anchor) and a negative example ``n`` (should be dissimilar).
    The criterion measures relative similarity: a sample is penalised unless
    the anchor is more cosine-similar to its positive than to its negative by
    at least ``margin``. All input tensors should have shape :math:`(N, D)`.

    The per-sample loss is

    .. math::
        L(a, p, n) = \max \{\cos(a_i, n_i) - \cos(a_i, p_i) + {\rm margin}, 0\}

    and :meth:`forward` returns its mean over the mini-batch.

    Parameters
    ----------
    margin : float, optional
        Required gap between the positive and the negative cosine
        similarities (typically a positive value).
        Default: 0.
    """

    def __init__(self, margin=0.):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """
        Compute the batch-averaged cosine triplet loss.

        Parameters
        ----------
        anchor, positive, negative : torch.Tensor
            Embeddings of shape :math:`(N, D)`; cosine similarity is taken
            along ``dim=1``.

        Returns
        -------
        torch.Tensor
            Scalar tensor: the mean of the per-sample hinge losses.
        """
        cos_to_positive = F.cosine_similarity(anchor, positive, dim=1)
        cos_to_negative = F.cosine_similarity(anchor, negative, dim=1)
        # Positive where the negative is not far enough from the anchor.
        violation = cos_to_negative - cos_to_positive + self.margin
        return F.relu(violation).mean()
class TripletLossSampler(PairLossSampler):
    """
    TripletLoss [1]_ with random sampling of triplets out of the conventional
    :code:`(outputs, labels)` batch.

    A convenient converter of an :code:`(outputs, labels)` batch into
    :code:`(anchor, same, other)` triplets.

    Parameters
    ----------
    criterion : nn.Module
        Triplet loss module (e.g., ``nn.TripletMarginLoss``) applied to the
        sampled :code:`(anchor, same, other)` triplets.
    pairs_multiplier : int, optional
        Defines how many pairs to create from a single sample.
        The typical range is ``[1, 10]``.
        Default: 1

    References
    ----------
    .. [1] Schroff, F., Kalenichenko, D., & Philbin, J. (2015). Facenet: A
       unified embedding for face recognition and clustering. In Proceedings
       of the IEEE conference on computer vision and pattern recognition
       (pp. 815-823).
    """

    def forward(self, outputs, labels):
        """
        Sample random triplets from the batch and apply the criterion.

        Candidate ``(anchor, positive, negative)`` index triplets are drawn
        uniformly at random; their count comes from
        ``self.pairs_to_sample(labels)``. Only candidates with a distinct,
        same-label positive and a different-label negative are kept.

        Parameters
        ----------
        outputs : torch.Tensor
            Batch of embeddings indexed along the first dimension.
        labels : torch.Tensor
            Class label per sample in ``outputs``.

        Returns
        -------
        torch.Tensor
            ``self.criterion`` evaluated on the valid triplets. If no valid
            triplet was sampled, a zero loss that still depends on
            ``outputs`` (so ``backward()`` works and yields zero gradients).
        """
        n_samples = len(outputs)
        size = (self.pairs_to_sample(labels),)
        device = outputs.device
        anchor = torch.randint(low=0, high=n_samples, size=size, device=device)
        positive = torch.randint(low=0, high=n_samples, size=size,
                                 device=device)
        negative = torch.randint(low=0, high=n_samples, size=size,
                                 device=device)

        # Rejection step: keep only candidates forming a valid triplet.
        valid = (anchor != positive) & \
                (labels[anchor] == labels[positive]) & \
                (labels[anchor] != labels[negative])
        anchor = anchor[valid]
        positive = positive[valid]
        negative = negative[valid]

        if anchor.numel() == 0:
            # Every candidate was rejected. A mean-reduced criterion would
            # average over an empty batch and return NaN, so return a zero
            # loss that stays attached to the autograd graph of ``outputs``.
            return (outputs * 0.).sum()

        loss = self.criterion(outputs[anchor], outputs[positive],
                              outputs[negative])
        self._check_non_nan(loss)
        return loss