This notebook discusses supervised contrastive loss (SupCon) and its related NT-Xent loss. The original paper is at https://arxiv.org/abs/2004.11362.
The SupCon paper defines the SupCon loss as: $$\mathcal{L} = \sum_{i\in I}\frac{-1}{|P(i)|}\sum_{p\in P(i)}\log\frac{\exp(z_i\cdot z_p/\tau)}{\sum_{a\in A(i)}\exp(z_i\cdot z_a/\tau)}$$ Here $I$ is the set of indices of the multiviewed batch, $A(i)\equiv I\setminus \{i\}$, and $P(i) \equiv \{p \in A(i) : \hat{y}_p = \hat{y}_i\}$ is the set of indices of all positives in the multiviewed batch distinct from $i$; $|P(i)|$ is its cardinality.
NT-Xent loss is defined as: $$\mathcal{L}=-\sum_{i\in I}\log\frac{\exp(z_i\cdot z_{j(i)}/\tau)}{\sum_{a\in A(i)}\exp(z_i\cdot z_a/\tau)}$$ where $j(i)$ is the index of the other augmented view originating from the same source sample as $i$. Essentially, NT-Xent deals with one positive pair and multiple negative pairs per anchor in the batch contrastive approach, while SupCon extends NT-Xent to allow multiple positive and negative pairs in the multiviewed batch by leveraging class label information. The paper shows that pretraining with SupCon consistently outperforms cross-entropy training under standard data augmentations.
The focus of this notebook is on the math of the SupCon and NT-Xent losses. What are they really calculating? Which pairs are the positive pairs and which are the negative pairs? There are differences in how these losses are implemented in the SupCon paper's code (github) and in the PyTorch Metric Learning package (github). I will demonstrate this with a naive numpy/scipy implementation.
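
As a warm-up, here is a minimal numpy sketch of the SupCon definition above (my own naive translation, not taken from either library). It averages over anchors, as the implementations below do, and it reduces to NT-Xent whenever every anchor has exactly one positive.

In [ ]:
import numpy as np

def naive_supcon(z, labels, tau=1.0):
    # SupCon loss computed literally from the definition above.
    # z: [n, d] L2-normalized embeddings; labels: [n] class ids.
    # Assumes every class appears at least twice in the batch.
    z, labels = np.asarray(z, dtype=float), np.asarray(labels)
    sims = z @ z.T / tau                        # z_i . z_a / tau
    np.fill_diagonal(sims, -np.inf)             # exclude i itself from A(i)
    log_prob = sims - np.log(np.exp(sims).sum(axis=1, keepdims=True))
    per_anchor = []
    for i in range(len(labels)):
        pos = (labels == labels[i]) & (np.arange(len(labels)) != i)   # P(i)
        per_anchor.append(-log_prob[i, pos].mean())
    return np.mean(per_anchor)                  # averaged over anchors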

In [1]:
import numpy as np
from scipy import spatial
from scipy.special import softmax
import torch 
import torch.nn as nn
import torch.nn.functional as F
from pytorch_metric_learning import losses

Below is the SupCon loss implementation in PyTorch by the original authors.

In [2]:
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError(
                    'Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

2 classes x 2 members example

To understand what the SupCon and NT-Xent losses are calculating, I'm making up a simple example with 2 classes, each having 2 members: a total of 4 data points, each in $\mathbb{R}^5$.

In [3]:
embeddings = torch.randn(4,5)
embeddings
Out[3]:
tensor([[-1.0192, -1.2504,  0.6781, -1.4978, -0.2036],
        [-0.3796, -0.6767,  0.7013, -1.0596, -0.9061],
        [ 0.4155, -0.4153, -1.3006,  0.6536, -1.1877],
        [ 0.5762,  1.0520, -0.2545, -0.2044,  0.9800]])

Their labels are known: let's say $a$ and $c$ belong to one class and $b$ and $d$ to the other (the four rows above are $a$, $b$, $c$, $d$ in order).

In [4]:
labels = torch.tensor([0,1,0,1])

PyTorch Metric Learning calculation

In [5]:
# SupCon
loss_func = losses.SupConLoss(temperature=1)
loss_func(embeddings, labels)
Out[5]:
tensor(1.5018)
In [6]:
# NT-Xent
loss_func = losses.NTXentLoss(temperature=1)
loss_func(embeddings, labels)
Out[6]:
tensor(1.5018)

SupCon paper calculation

To use the SupCon loss implemented by the original authors, we first need to rearrange the data into a 3-D tensor of shape [bsz, n_views, dim]. Because this implementation scores pairs with plain dot products, we also L2-normalize the features so that its scores match the cosine similarities used by the other methods.
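
As an aside, here is a generic helper for this reshuffling (my own sketch, not from the repo): it goes from a flat [n_views * bsz, dim] tensor whose view blocks are stacked along dim 0, as in this notebook, to the [bsz, n_views, dim] layout that SupConLoss.forward expects.

In [ ]:
def to_multiview(embeddings, n_views):
    # [n_views * bsz, dim] with view blocks stacked along dim 0
    # -> [bsz, n_views, dim], the layout SupConLoss.forward expects
    bsz = embeddings.shape[0] // n_views
    return torch.stack(torch.split(embeddings, bsz, dim=0), dim=1)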

In [7]:
e1, e2 = torch.split(embeddings, [2, 2], dim=0)
features = torch.cat([e1.unsqueeze(1), e2.unsqueeze(1)], dim=1)
features
Out[7]:
tensor([[[-1.0192, -1.2504,  0.6781, -1.4978, -0.2036],
         [ 0.4155, -0.4153, -1.3006,  0.6536, -1.1877]],

        [[-0.3796, -0.6767,  0.7013, -1.0596, -0.9061],
         [ 0.5762,  1.0520, -0.2545, -0.2044,  0.9800]]])
In [8]:
features = F.normalize(features,dim=2)
features
Out[8]:
tensor([[[-0.4408, -0.5408,  0.2932, -0.6477, -0.0880],
         [ 0.2111, -0.2110, -0.6607,  0.3321, -0.6034]],

        [[-0.2178, -0.3883,  0.4024, -0.6080, -0.5199],
         [ 0.3640,  0.6646, -0.1608, -0.1291,  0.6191]]])

Labels need to be reformatted too: one label per "sample", since each sample in the 3-D tensor now carries 2 views.

In [9]:
labels = torch.tensor([0,1])
In [10]:
# SupCon
criterion = SupConLoss(temperature=1,base_temperature=1,contrast_mode="all")
criterion(features,labels)
Out[10]:
tensor(1.5018)
In [11]:
# NT-Xent
criterion = SupConLoss(temperature=1,base_temperature=1,contrast_mode="all")
criterion(features)
Out[11]:
tensor(1.5018)

The two approaches gave the same result for both losses. But what are they really calculating? Below I redo the computation in a naive way with numpy and scipy to show exactly that.

numpy/scipy calculation

First, I calculate the pairwise cosine similarity between the data points. I fill the diagonal of the similarity matrix with negative infinity so that self-similarities are masked out of the softmax in the next step.

In [12]:
mat = 1-spatial.distance.cdist(embeddings,embeddings,'cosine')
np.fill_diagonal(mat, -np.inf)
mat
Out[12]:
array([[       -inf,  0.8635144 , -0.33467347, -0.53787604],
       [ 0.8635144 ,        -inf, -0.11811661, -0.64539926],
       [-0.33467347, -0.11811661,        -inf, -0.37355578],
       [-0.53787604, -0.64539926, -0.37355578,        -inf]])

Then I take the row-wise softmax. Using the softmax function from scipy I compute it for every pair, positive and negative, although only the entries for the positive pairs are needed later. This is essentially calculating $\frac{\exp(z_i\cdot z_p/\tau)}{\sum_{a\in A(i)}\exp(z_i\cdot z_a/\tau)}$ (for simplicity, the temperature $\tau$ is set to 1 in all calculations).

In [13]:
softmax_mat = softmax(mat,axis = 1)
softmax_mat
Out[13]:
array([[0.        , 0.64599699, 0.19492346, 0.15907955],
       [0.62662548, 0.        , 0.23479623, 0.13857828],
       [0.31214322, 0.38761748, 0.        , 0.3002393 ],
       [0.32502912, 0.29189425, 0.38307663, 0.        ]])
In [14]:
# SupCon == NT-Xent
-(np.log(softmax_mat[0,2])+ np.log(softmax_mat[1,3])+np.log(softmax_mat[2,0])+np.log(softmax_mat[3,1]))/4
Out[14]:
1.5017812656385385

So now we know what the SupCon loss is calculating:
$$(\frac{ac}{ab+ac+ad}+\frac{bd}{ba+bc+bd}+\frac{ca}{ca+cb+cd}+\frac{db}{da+db+dc})/4$$
For simplicity, the exponential and logarithm operations are omitted and each product simply stands for the similarity of that pair; the point is which pairs are positive (the numerators) and which appear only in the denominators (the negatives). Because there are two classes each with two members, every anchor has exactly one positive, so here the NT-Xent loss is equivalent to the SupCon loss.
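
The hard-coded sum above generalizes to any batch. Here is a small helper (my own naming, not from either library) that computes the same quantity from softmax_mat and a label list; I will reuse it for the larger examples below.

In [ ]:
def supcon_from_softmax(softmax_mat, labels):
    # for each anchor: average -log(softmax) over its positives,
    # then average over all anchors (assumes every label appears at least twice)
    labels = np.asarray(labels)
    n = len(labels)
    per_anchor = []
    for i in range(n):
        pos = [j for j in range(n) if j != i and labels[j] == labels[i]]
        per_anchor.append(-np.log(softmax_mat[i, pos]).mean())
    return np.mean(per_anchor)

supcon_from_softmax(softmax_mat, [0, 1, 0, 1])  # should reproduce ~1.5018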

3 classes x 3 members example

This example covers a more complicated situation where we have 3 classes, each with 3 members.

In [15]:
embeddings = torch.randn(9,5)
embeddings
Out[15]:
tensor([[-0.2212, -3.0712,  0.7776,  1.9826,  0.6233],
        [ 0.6597,  1.4035,  0.2884, -0.8971, -0.5095],
        [ 0.5410,  1.2093,  0.8458,  1.2963, -0.3695],
        [ 1.2558, -1.5282, -1.6592, -0.0406, -1.7611],
        [ 0.2712,  0.6679, -0.3595, -0.5791, -1.6139],
        [ 0.6026, -0.4435, -0.7715, -0.2869,  0.1029],
        [-0.5828, -0.4375, -0.5919,  0.3499,  0.2800],
        [-1.1419, -1.1917,  0.5485, -0.2690,  1.1158],
        [ 0.6038, -1.4093, -0.1854, -1.5674, -2.6291]])
In [16]:
labels = torch.tensor([0,1,2,0,1,2,0,1,2])

PyTorch Metric Learning calculation

In [17]:
# SupCon
loss_func = losses.SupConLoss(temperature=1)
loss_func(embeddings, labels)
Out[17]:
tensor(2.1960)
In [18]:
# NT-Xent
loss_func = losses.NTXentLoss(temperature=1)
loss_func(embeddings, labels)
Out[18]:
tensor(2.0615)

SupCon paper calculation

In [19]:
e1, e2, e3 = torch.split(embeddings, [3, 3, 3], dim=0)
features = torch.cat([e1.unsqueeze(1), e2.unsqueeze(1), e3.unsqueeze(1)], dim=1)
features
Out[19]:
tensor([[[-0.2212, -3.0712,  0.7776,  1.9826,  0.6233],
         [ 1.2558, -1.5282, -1.6592, -0.0406, -1.7611],
         [-0.5828, -0.4375, -0.5919,  0.3499,  0.2800]],

        [[ 0.6597,  1.4035,  0.2884, -0.8971, -0.5095],
         [ 0.2712,  0.6679, -0.3595, -0.5791, -1.6139],
         [-1.1419, -1.1917,  0.5485, -0.2690,  1.1158]],

        [[ 0.5410,  1.2093,  0.8458,  1.2963, -0.3695],
         [ 0.6026, -0.4435, -0.7715, -0.2869,  0.1029],
         [ 0.6038, -1.4093, -0.1854, -1.5674, -2.6291]]])
In [20]:
features = F.normalize(features,dim=2)
features
Out[20]:
tensor([[[-0.0583, -0.8092,  0.2049,  0.5224,  0.1642],
         [ 0.4018, -0.4890, -0.5309, -0.0130, -0.5635],
         [-0.5602, -0.4205, -0.5690,  0.3363,  0.2692]],

        [[ 0.3500,  0.7446,  0.1530, -0.4759, -0.2703],
         [ 0.1432,  0.3526, -0.1897, -0.3057, -0.8519],
         [-0.5480, -0.5719,  0.2632, -0.1291,  0.5355]],

        [[ 0.2613,  0.5840,  0.4085,  0.6261, -0.1785],
         [ 0.5394, -0.3970, -0.6906, -0.2568,  0.0921],
         [ 0.1761, -0.4111, -0.0541, -0.4572, -0.7669]]])
In [21]:
labels = torch.tensor([0,1,2])
In [22]:
# SupCon
criterion = SupConLoss(temperature=1,base_temperature=1,contrast_mode="all")
criterion(features,labels)
Out[22]:
tensor(2.1960)
In [23]:
# NT-Xent
criterion = SupConLoss(temperature=1,base_temperature=1,contrast_mode="all")
criterion(features)
Out[23]:
tensor(2.1960)

The two approaches for SupCon calculation agree with each other but the results for NT-Xent loss are different. I will show what each loss is calculating with numpy and scipy below.

numpy/scipy calculation

In [24]:
mat = 1-spatial.distance.cdist(embeddings,embeddings,'cosine')
np.fill_diagonal(mat, -np.inf)
mat
Out[24]:
array([[       -inf, -0.88462103, -0.10639146,  0.16416448, -0.63209395,
         0.02929249,  0.47625965,  0.56912093, -0.05346975],
       [-0.88462103,        -inf,  0.33909669, -0.14616784,  0.65939628,
        -0.11514784, -0.82911448, -0.66067173,  0.17217712],
       [-0.10639146,  0.33909669,        -inf, -0.30501748,  0.12644982,
        -0.55027215, -0.46188373, -0.54601025, -0.36549701],
       [ 0.16416448, -0.14616784, -0.30501748,        -inf,  0.46986255,
         0.7289007 ,  0.12651431, -0.38032699,  0.73850455],
       [-0.63209395,  0.65939628,  0.12644982,  0.46986255,        -inf,
         0.06831768, -0.45264536, -0.74673335,  0.68360896],
       [ 0.02929249, -0.11514784, -0.55027215,  0.7289007 ,  0.06831768,
               -inf,  0.19614875, -0.16784296,  0.34227567],
       [ 0.47625965, -0.82911448, -0.46188373,  0.12651431, -0.45264536,
         0.19614875,        -inf,  0.49842916, -0.25522288],
       [ 0.56912093, -0.66067173, -0.54601025, -0.38032699, -0.74673335,
        -0.16784296,  0.49842916,        -inf, -0.227278  ],
       [-0.05346975,  0.17217712, -0.36549701,  0.73850455,  0.68360896,
         0.34227567, -0.25522288, -0.227278  ,        -inf]])
In [25]:
softmax_mat = softmax(mat,axis = 1)
softmax_mat
Out[25]:
array([[0.        , 0.04929066, 0.10733602, 0.14068455, 0.06345061,
        0.12293407, 0.19221516, 0.21091952, 0.11316942],
       [0.05399649, 0.        , 0.18357746, 0.11299823, 0.25288537,
        0.11655837, 0.05707839, 0.06755002, 0.15535567],
       [0.1351044 , 0.21093203, 0.        , 0.11076621, 0.17052622,
        0.08667514, 0.09468499, 0.08704533, 0.10426568],
       [0.11370849, 0.08337144, 0.07112623, 0.        , 0.15436751,
        0.20001139, 0.10950694, 0.06596647, 0.20194153],
       [0.05677426, 0.20655636, 0.1212224 , 0.17089333, 0.        ,
        0.11437639, 0.06793366, 0.05062491, 0.21161868],
       [0.11301259, 0.09781314, 0.06330312, 0.22749026, 0.11751011,
        0.        , 0.13353391, 0.09279231, 0.15454455],
       [0.19914009, 0.05398112, 0.07793422, 0.14036739, 0.07865755,
        0.15049015, 0.        , 0.20360423, 0.09582523],
       [0.24173924, 0.07067324, 0.07925959, 0.09354211, 0.06484536,
        0.11568786, 0.2252403 , 0.        , 0.1090123 ],
       [0.09591486, 0.12019409, 0.07020599, 0.21175612, 0.20044495,
        0.14248067, 0.0783909 , 0.08061242, 0.        ]])
In [26]:
# SupCon
-((np.log(softmax_mat[0,3])+np.log(softmax_mat[0,6]))/2+\
  (np.log(softmax_mat[1,4])+np.log(softmax_mat[1,7]))/2+\
  (np.log(softmax_mat[2,5])+np.log(softmax_mat[2,8]))/2+\
  (np.log(softmax_mat[3,0])+np.log(softmax_mat[3,6]))/2+\
  (np.log(softmax_mat[4,1])+np.log(softmax_mat[4,7]))/2+\
  (np.log(softmax_mat[5,2])+np.log(softmax_mat[5,8]))/2+\
  (np.log(softmax_mat[6,0])+np.log(softmax_mat[6,3]))/2+\
  (np.log(softmax_mat[7,1])+np.log(softmax_mat[7,4]))/2+\
  (np.log(softmax_mat[8,2])+np.log(softmax_mat[8,5]))/2)/9
Out[26]:
2.1959722995368245

It gets a little crazy, but the calculation above shows how both approaches compute the SupCon loss: for each data point (anchor), the log-softmax terms are averaged over its positive pairs, and each softmax denominator contains every pair that anchor can form; the per-anchor values are then averaged.
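
The helper defined in the 2-class example gives the same number in one line (using the softmax_mat computed just above):

In [ ]:
supcon_from_softmax(softmax_mat, [0, 1, 2, 0, 1, 2, 0, 1, 2])  # should reproduce ~2.1960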

In [27]:
# NTXent
-((np.log(np.exp(mat[0,3])/(np.exp(mat[0,3])+np.exp(mat[0,1])+np.exp(mat[0,2])+np.exp(mat[0,4])+np.exp(mat[0,5])+np.exp(mat[0,7])+np.exp(mat[0,8])))+\
 np.log(np.exp(mat[0,6])/(np.exp(mat[0,6])+np.exp(mat[0,1])+np.exp(mat[0,2])+np.exp(mat[0,4])+np.exp(mat[0,5])+np.exp(mat[0,7])+np.exp(mat[0,8]))))/2+\
(np.log(np.exp(mat[1,4])/(np.exp(mat[1,4])+np.exp(mat[1,0])+np.exp(mat[1,2])+np.exp(mat[1,3])+np.exp(mat[1,5])+np.exp(mat[1,6])+np.exp(mat[1,8])))+\
 np.log(np.exp(mat[1,7])/(np.exp(mat[1,7])+np.exp(mat[1,0])+np.exp(mat[1,2])+np.exp(mat[1,3])+np.exp(mat[1,5])+np.exp(mat[1,6])+np.exp(mat[1,8]))))/2+\
(np.log(np.exp(mat[2,5])/(np.exp(mat[2,5])+np.exp(mat[2,0])+np.exp(mat[2,1])+np.exp(mat[2,3])+np.exp(mat[2,4])+np.exp(mat[2,6])+np.exp(mat[2,7])))+\
 np.log(np.exp(mat[2,8])/(np.exp(mat[2,8])+np.exp(mat[2,0])+np.exp(mat[2,1])+np.exp(mat[2,3])+np.exp(mat[2,4])+np.exp(mat[2,6])+np.exp(mat[2,7]))))/2+\
(np.log(np.exp(mat[3,0])/(np.exp(mat[3,0])+np.exp(mat[3,1])+np.exp(mat[3,2])+np.exp(mat[3,4])+np.exp(mat[3,5])+np.exp(mat[3,7])+np.exp(mat[3,8])))+\
 np.log(np.exp(mat[3,6])/(np.exp(mat[3,6])+np.exp(mat[3,1])+np.exp(mat[3,2])+np.exp(mat[3,4])+np.exp(mat[3,5])+np.exp(mat[3,7])+np.exp(mat[3,8]))))/2+\
(np.log(np.exp(mat[4,1])/(np.exp(mat[4,1])+np.exp(mat[4,0])+np.exp(mat[4,2])+np.exp(mat[4,3])+np.exp(mat[4,5])+np.exp(mat[4,6])+np.exp(mat[4,8])))+\
 np.log(np.exp(mat[4,7])/(np.exp(mat[4,7])+np.exp(mat[4,0])+np.exp(mat[4,2])+np.exp(mat[4,3])+np.exp(mat[4,5])+np.exp(mat[4,6])+np.exp(mat[4,8]))))/2+\
(np.log(np.exp(mat[5,2])/(np.exp(mat[5,2])+np.exp(mat[5,0])+np.exp(mat[5,1])+np.exp(mat[5,3])+np.exp(mat[5,4])+np.exp(mat[5,6])+np.exp(mat[5,7])))+\
 np.log(np.exp(mat[5,8])/(np.exp(mat[5,8])+np.exp(mat[5,0])+np.exp(mat[5,1])+np.exp(mat[5,3])+np.exp(mat[5,4])+np.exp(mat[5,6])+np.exp(mat[5,7]))))/2+\
(np.log(np.exp(mat[6,0])/(np.exp(mat[6,0])+np.exp(mat[6,1])+np.exp(mat[6,2])+np.exp(mat[6,4])+np.exp(mat[6,5])+np.exp(mat[6,7])+np.exp(mat[6,8])))+\
 np.log(np.exp(mat[6,3])/(np.exp(mat[6,3])+np.exp(mat[6,1])+np.exp(mat[6,2])+np.exp(mat[6,4])+np.exp(mat[6,5])+np.exp(mat[6,7])+np.exp(mat[6,8]))))/2+\
(np.log(np.exp(mat[7,1])/(np.exp(mat[7,1])+np.exp(mat[7,0])+np.exp(mat[7,2])+np.exp(mat[7,3])+np.exp(mat[7,5])+np.exp(mat[7,6])+np.exp(mat[7,8])))+\
 np.log(np.exp(mat[7,4])/(np.exp(mat[7,4])+np.exp(mat[7,0])+np.exp(mat[7,2])+np.exp(mat[7,3])+np.exp(mat[7,5])+np.exp(mat[7,6])+np.exp(mat[7,8]))))/2+\
(np.log(np.exp(mat[8,2])/(np.exp(mat[8,2])+np.exp(mat[8,0])+np.exp(mat[8,1])+np.exp(mat[8,3])+np.exp(mat[8,4])+np.exp(mat[8,6])+np.exp(mat[8,7])))+\
 np.log(np.exp(mat[8,5])/(np.exp(mat[8,5])+np.exp(mat[8,0])+np.exp(mat[8,1])+np.exp(mat[8,3])+np.exp(mat[8,4])+np.exp(mat[8,6])+np.exp(mat[8,7]))))/2)/9
Out[27]:
2.0614844973461555


The above shows how PyTorch Metric Learning calculates NT-Xent loss. For each data point, it averages terms in which the numerator is one of the positive pairs that data point can form and the denominator is that one positive pair plus all of the data point's negative pairs. This differs from the SupCon loss, where the denominator contains every pair involving that data point. The author of PyTorch Metric Learning discusses this here.
The SupCon paper's code, on the other hand, has no separate NT-Xent mode: when called without labels it treats the views of the same sample as the only positives and still averages the log-probabilities over all of them, exactly as in SupCon. Since we stacked the members of each class as "views" here, that mask coincides with the class mask, which is why its NT-Xent result equals the SupCon loss in this example.
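
To avoid retyping that wall of terms for the later examples, here is a compact sketch of the same computation (my own helper mirroring the hand-written cell above, not PyTorch Metric Learning's actual code): for each anchor, average over its positives the -log of that one positive over itself plus all of the anchor's negatives, then average over anchors.

In [ ]:
def ntxent_from_sims(mat, labels):
    # mat: pairwise similarity matrix (its diagonal is never used here)
    # assumes every label appears at least twice
    labels = np.asarray(labels)
    n = len(labels)
    per_anchor = []
    for i in range(n):
        pos = [j for j in range(n) if j != i and labels[j] == labels[i]]
        neg_sum = np.exp(mat[i, labels != labels[i]]).sum()
        terms = [-np.log(np.exp(mat[i, p]) / (np.exp(mat[i, p]) + neg_sum))
                 for p in pos]
        per_anchor.append(np.mean(terms))
    return np.mean(per_anchor)

ntxent_from_sims(mat, [0, 1, 2, 0, 1, 2, 0, 1, 2])  # should reproduce ~2.0615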

2 classes x 2 images x 2 views

So far I have only been considering the case where each member within a class is an individual instance, without data augmentation. Now let's consider the situation where the members within a class can be augmented data (e.g., two "views" of the same image, $a_1$ and $a_2$ are two views of $a$).

In [28]:
embeddings = torch.randn(8,5)
embeddings
Out[28]:
tensor([[-0.0031, -1.2745, -0.5312,  0.0224, -0.8322],
        [-1.3428, -0.7471, -1.7685, -2.0902, -2.1853],
        [ 0.2657,  0.0047,  0.3926,  0.0517, -0.2075],
        [ 0.6994,  0.0610, -0.4057,  0.7450,  0.1561],
        [-0.4123, -0.9179, -0.8227,  1.5111, -0.5250],
        [ 0.7570,  0.0307, -0.0365, -0.9719,  0.4156],
        [-0.6779,  0.9506,  0.2841, -0.3664,  2.0260],
        [ 0.1467,  2.2670,  0.7604,  1.1003,  0.3780]])
In [29]:
labels = torch.tensor([0,0,1,1,0,0,1,1])

PyTorch Metric Learning calculation

In [30]:
# SupCon
loss_func = losses.SupConLoss(temperature=1)
loss_func(embeddings, labels)
Out[30]:
tensor(1.8374)
In [31]:
# NT-Xent
loss_func = losses.NTXentLoss(temperature=1)
loss_func(embeddings, labels)
Out[31]:
tensor(1.4141)

SupCon paper calculation

In [32]:
e1, e2 = torch.split(embeddings, [4,4], dim=0)
features = torch.cat([e1.unsqueeze(1), e2.unsqueeze(1)], dim=1)
features
Out[32]:
tensor([[[-0.0031, -1.2745, -0.5312,  0.0224, -0.8322],
         [-0.4123, -0.9179, -0.8227,  1.5111, -0.5250]],

        [[-1.3428, -0.7471, -1.7685, -2.0902, -2.1853],
         [ 0.7570,  0.0307, -0.0365, -0.9719,  0.4156]],

        [[ 0.2657,  0.0047,  0.3926,  0.0517, -0.2075],
         [-0.6779,  0.9506,  0.2841, -0.3664,  2.0260]],

        [[ 0.6994,  0.0610, -0.4057,  0.7450,  0.1561],
         [ 0.1467,  2.2670,  0.7604,  1.1003,  0.3780]]])
In [33]:
features = F.normalize(features,dim=2)
features
Out[33]:
tensor([[[-0.0020, -0.7905, -0.3295,  0.0139, -0.5162],
         [-0.2000, -0.4453, -0.3991,  0.7331, -0.2547]],

        [[-0.3510, -0.1953, -0.4623, -0.5464, -0.5713],
         [ 0.5819,  0.0236, -0.0281, -0.7470,  0.3194]],

        [[ 0.5109,  0.0090,  0.7549,  0.0994, -0.3989],
         [-0.2844,  0.3988,  0.1192, -0.1537,  0.8499]],

        [[ 0.6289,  0.0548, -0.3648,  0.6699,  0.1403],
         [ 0.0551,  0.8512,  0.2855,  0.4131,  0.1419]]])
In [34]:
labels = torch.tensor([0,0,1,1])
In [35]:
# SupCon
criterion = SupConLoss(temperature=1,base_temperature=1,contrast_mode="all")
criterion(features,labels)
Out[35]:
tensor(1.8374)
In [36]:
# NT-Xent
criterion = SupConLoss(temperature=1,base_temperature=1,contrast_mode="all")
criterion(features)
Out[36]:
tensor(1.7731)

This time the two approaches still agree on how to calculate the SupCon loss, but their NT-Xent values differ from each other, and both differ from the SupCon loss.

numpy/scipy calculation

In [37]:
mat = 1-spatial.distance.cdist(embeddings,embeddings,'cosine')
np.fill_diagonal(mat, -np.inf)
mat
Out[37]:
array([[       -inf,  0.59466433, -0.04950174,  0.0124912 ,  0.62554687,
        -0.18579583, -0.79471934, -0.83458033],
       [ 0.59466433,        -inf, -0.35647372, -0.5090221 ,  0.08662948,
         0.02979664, -0.43467342, -0.62441156],
       [-0.04950174, -0.35647372,        -inf,  0.05703969, -0.23304645,
         0.07466048, -0.40608212,  0.23575862],
       [ 0.0124912 , -0.5090221 ,  0.05703969,        -inf,  0.45076226,
        -0.07812047, -0.18416916,  0.27381313],
       [ 0.62554687,  0.08662948, -0.23304645,  0.45076226,        -inf,
        -0.74472811, -0.49741087, -0.23732455],
       [-0.18579583,  0.02979664,  0.07466048, -0.07812047, -0.74472811,
               -inf,  0.22689736, -0.21915818],
       [-0.79471934, -0.43467342, -0.40608212, -0.18416916, -0.49741087,
         0.22689736,        -inf,  0.4149237 ],
       [-0.83458033, -0.62441156,  0.23575862,  0.27381313, -0.23732455,
        -0.21915818,  0.4149237 ,        -inf]])
In [38]:
softmax_mat = softmax(mat,axis = 1)
softmax_mat
Out[38]:
array([[0.        , 0.24618043, 0.12926941, 0.13753681, 0.25390172,
        0.11279867, 0.06135527, 0.05895769],
       [0.28242156, 0.        , 0.10909977, 0.09366405, 0.1699265 ,
        0.1605384 , 0.10089326, 0.08345645],
       [0.14613935, 0.10751051, 0.        , 0.16256892, 0.12163392,
        0.1654589 , 0.10230722, 0.19438116],
       [0.13845257, 0.08218843, 0.14475988, 0.        , 0.21460495,
        0.12645875, 0.11373458, 0.17980084],
       [0.25987005, 0.15160286, 0.11012195, 0.21819673, 0.        ,
        0.06601663, 0.08453993, 0.10965184],
       [0.12984157, 0.16108102, 0.1684723 , 0.14460276, 0.07424591,
        0.        , 0.19617522, 0.12558122],
       [0.0756086 , 0.108377  , 0.11152037, 0.13922902, 0.1017866 ,
        0.21001663, 0.        , 0.25346178],
       [0.0652064 , 0.08045729, 0.19016586, 0.19754199, 0.11848821,
        0.12066038, 0.22747986, 0.        ]])
In [39]:
# SupCon
-((np.log(softmax_mat[0,1])+np.log(softmax_mat[0,4])+np.log(softmax_mat[0,5]))/3+\
  (np.log(softmax_mat[1,0])+np.log(softmax_mat[1,4])+np.log(softmax_mat[1,5]))/3+\
  (np.log(softmax_mat[2,3])+np.log(softmax_mat[2,6])+np.log(softmax_mat[2,7]))/3+\
  (np.log(softmax_mat[3,2])+np.log(softmax_mat[3,6])+np.log(softmax_mat[3,7]))/3+\
  (np.log(softmax_mat[4,0])+np.log(softmax_mat[4,1])+np.log(softmax_mat[4,5]))/3+\
  (np.log(softmax_mat[5,0])+np.log(softmax_mat[5,1])+np.log(softmax_mat[5,4]))/3+\
  (np.log(softmax_mat[6,2])+np.log(softmax_mat[6,3])+np.log(softmax_mat[6,7]))/3+\
  (np.log(softmax_mat[7,2])+np.log(softmax_mat[7,3])+np.log(softmax_mat[7,6]))/3)/8
Out[39]:
1.8373793815717723


The above shows how the SupCon loss is calculated here: it makes no distinction between augmented views and other samples of the same class; they all form positive pairs.
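
The helper from the earlier examples reproduces it directly:

In [ ]:
supcon_from_softmax(softmax_mat, [0, 0, 1, 1, 0, 0, 1, 1])  # should reproduce ~1.8374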

In [40]:
# NT-Xent consistent with PyTorch Metric Learning method
-((np.log(np.exp(mat[0,1])/(np.exp(mat[0,1])+np.exp(mat[0,2])+np.exp(mat[0,3])+np.exp(mat[0,6])+np.exp(mat[0,7])))+\
np.log(np.exp(mat[0,4])/(np.exp(mat[0,4])+np.exp(mat[0,2])+np.exp(mat[0,3])+np.exp(mat[0,6])+np.exp(mat[0,7])))+\
np.log(np.exp(mat[0,5])/(np.exp(mat[0,5])+np.exp(mat[0,2])+np.exp(mat[0,3])+np.exp(mat[0,6])+np.exp(mat[0,7]))))/3+\
(np.log(np.exp(mat[1,0])/(np.exp(mat[1,0])+np.exp(mat[1,2])+np.exp(mat[1,3])+np.exp(mat[1,6])+np.exp(mat[1,7])))+\
np.log(np.exp(mat[1,4])/(np.exp(mat[1,4])+np.exp(mat[1,2])+np.exp(mat[1,3])+np.exp(mat[1,6])+np.exp(mat[1,7])))+\
np.log(np.exp(mat[1,5])/(np.exp(mat[1,5])+np.exp(mat[1,2])+np.exp(mat[1,3])+np.exp(mat[1,6])+np.exp(mat[1,7]))))/3+\
(np.log(np.exp(mat[2,3])/(np.exp(mat[2,3])+np.exp(mat[2,0])+np.exp(mat[2,1])+np.exp(mat[2,4])+np.exp(mat[2,5])))+\
np.log(np.exp(mat[2,6])/(np.exp(mat[2,6])+np.exp(mat[2,0])+np.exp(mat[2,1])+np.exp(mat[2,4])+np.exp(mat[2,5])))+\
np.log(np.exp(mat[2,7])/(np.exp(mat[2,7])+np.exp(mat[2,0])+np.exp(mat[2,1])+np.exp(mat[2,4])+np.exp(mat[2,5]))))/3+\
(np.log(np.exp(mat[3,2])/(np.exp(mat[3,2])+np.exp(mat[3,0])+np.exp(mat[3,1])+np.exp(mat[3,4])+np.exp(mat[3,5])))+\
np.log(np.exp(mat[3,6])/(np.exp(mat[3,6])+np.exp(mat[3,0])+np.exp(mat[3,1])+np.exp(mat[3,4])+np.exp(mat[3,5])))+\
np.log(np.exp(mat[3,7])/(np.exp(mat[3,7])+np.exp(mat[3,0])+np.exp(mat[3,1])+np.exp(mat[3,4])+np.exp(mat[3,5]))))/3+\
(np.log(np.exp(mat[4,0])/(np.exp(mat[4,0])+np.exp(mat[4,2])+np.exp(mat[4,3])+np.exp(mat[4,6])+np.exp(mat[4,7])))+\
np.log(np.exp(mat[4,1])/(np.exp(mat[4,1])+np.exp(mat[4,2])+np.exp(mat[4,3])+np.exp(mat[4,6])+np.exp(mat[4,7])))+\
np.log(np.exp(mat[4,5])/(np.exp(mat[4,5])+np.exp(mat[4,2])+np.exp(mat[4,3])+np.exp(mat[4,6])+np.exp(mat[4,7]))))/3+\
(np.log(np.exp(mat[5,0])/(np.exp(mat[5,0])+np.exp(mat[5,2])+np.exp(mat[5,3])+np.exp(mat[5,6])+np.exp(mat[5,7])))+\
np.log(np.exp(mat[5,1])/(np.exp(mat[5,1])+np.exp(mat[5,2])+np.exp(mat[5,3])+np.exp(mat[5,6])+np.exp(mat[5,7])))+\
np.log(np.exp(mat[5,4])/(np.exp(mat[5,4])+np.exp(mat[5,2])+np.exp(mat[5,3])+np.exp(mat[5,6])+np.exp(mat[5,7]))))/3+\
(np.log(np.exp(mat[6,2])/(np.exp(mat[6,2])+np.exp(mat[6,0])+np.exp(mat[6,1])+np.exp(mat[6,4])+np.exp(mat[6,5])))+\
np.log(np.exp(mat[6,3])/(np.exp(mat[6,3])+np.exp(mat[6,0])+np.exp(mat[6,1])+np.exp(mat[6,4])+np.exp(mat[6,5])))+\
np.log(np.exp(mat[6,7])/(np.exp(mat[6,7])+np.exp(mat[6,0])+np.exp(mat[6,1])+np.exp(mat[6,4])+np.exp(mat[6,5]))))/3+\
(np.log(np.exp(mat[7,2])/(np.exp(mat[7,2])+np.exp(mat[7,0])+np.exp(mat[7,1])+np.exp(mat[7,4])+np.exp(mat[7,5])))+\
np.log(np.exp(mat[7,3])/(np.exp(mat[7,3])+np.exp(mat[7,0])+np.exp(mat[7,1])+np.exp(mat[7,4])+np.exp(mat[7,5])))+\
np.log(np.exp(mat[7,6])/(np.exp(mat[7,6])+np.exp(mat[7,0])+np.exp(mat[7,1])+np.exp(mat[7,4])+np.exp(mat[7,5]))))/3)/8
Out[40]:
1.4140549545242016
$$(\frac{\frac{a_1b_1}{a_1b_1+a_1c_1+a_1d_1+a_1c_2+a_1d_2}+\frac{a_1a_2}{a_1a_2+a_1c_1+a_1d_1+a_1c_2+a_1d_2}+\frac{a_1b_2}{a_1b_2+a_1c_1+a_1d_1+a_1c_2+a_1d_2}}{3}+\frac{\frac{b_1a_1}{b_1a_1+b_1c_1+b_1d_1+b_1c_2+b_1d_2}+\frac{b_1a_2}{b_1a_2+b_1c_1+b_1d_1+b_1c_2+b_1d_2}+\frac{b_1b_2}{b_1b_2+b_1c_1+b_1d_1+b_1c_2+b_1d_2}}{3}+\frac{\frac{c_1d_1}{c_1d_1+c_1a_1+c_1b_1+c_1a_2+c_1b_2}+\frac{c_1c_2}{c_1c_2+c_1a_1+c_1b_1+c_1a_2+c_1b_2}+\frac{c_1d_2}{c_1d_2+c_1a_1+c_1b_1+c_1a_2+c_1b_2}}{3}+\frac{\frac{d_1c_1}{d_1c_1+d_1a_1+d_1b_1+d_1a_2+d_1b_2}+\frac{d_1c_2}{d_1c_2+d_1a_1+d_1b_1+d_1a_2+d_1b_2}+\frac{d_1d_2}{d_1d_2+d_1a_1+d_1b_1+d_1a_2+d_1b_2}}{3}+\frac{\frac{a_2a_1}{a_2a_1+a_2c_1+a_2d_1+a_2c_2+a_2d_2}+\frac{a_2b_1}{a_2b_1+a_2c_1+a_2d_1+a_2c_2+a_2d_2}+\frac{a_2b_2}{a_2b_2+a_2c_1+a_2d_1+a_2c_2+a_2d_2}}{3}+\frac{\frac{b_2a_1}{b_2a_1+b_2c_1+b_2d_1+b_2c_2+b_2d_2}+\frac{b_2b_1}{b_2b_1+b_2c_1+b_2d_1+b_2c_2+b_2d_2}+\frac{b_2a_2}{b_2a_2+b_2c_1+b_2d_1+b_2c_2+b_2d_2}}{3}+\frac{\frac{c_2c_1}{c_2c_1+c_2a_1+c_2b_1+c_2a_2+c_2b_2}+\frac{c_2d_1}{c_2d_1+c_2a_1+c_2b_1+c_2a_2+c_2b_2}+\frac{c_2d_2}{c_2d_2+c_2a_1+c_2b_1+c_2a_2+c_2b_2}}{3}+\frac{\frac{d_2c_1}{d_2c_1+d_2a_1+d_2b_1+d_2a_2+d_2b_2}+\frac{d_2d_1}{d_2d_1+d_2a_1+d_2b_1+d_2a_2+d_2b_2}+\frac{d_2c_2}{d_2c_2+d_2a_1+d_2b_1+d_2a_2+d_2b_2}}{3})/8$$

Using the same idea as before of keeping only one positive pair (plus all of the anchor's negatives) in each denominator, we get the same NT-Xent result as the PyTorch Metric Learning method. The SupCon paper's code, however, computes NT-Xent differently: called without labels, it only treats the 2 views of the same image as positives, ignoring the actual class labels:

In [41]:
# NT-Xent consistent with SupCon paper method
-(np.log(softmax_mat[0,4])+np.log(softmax_mat[1,5])+np.log(softmax_mat[2,6])+np.log(softmax_mat[3,7])+\
np.log(softmax_mat[4,0])+np.log(softmax_mat[5,1])+np.log(softmax_mat[6,2])+np.log(softmax_mat[7,3]))/8
Out[41]:
1.7730605206131949

If we want the two methods to agree, we can change how we label the data: give each image its own label, so that every anchor has exactly one positive (its other view), which matches the SupCon paper's no-label behavior. Below, the PyTorch Metric Learning method reproduces the SupCon paper's NT-Xent loss:

In [42]:
labels = torch.tensor([0,1,2,3,0,1,2,3])
# NT-Xent
loss_func = losses.NTXentLoss(temperature=1)
loss_func(embeddings, labels)
Out[42]:
tensor(1.7731)