[SimCLR] A Simple Framework for Contrastive Learning of Visual Representations
{Instance Discrimination Methods, Contrastive loss, Augmentation, Projection, Model Scale}
Paper: http://proceedings.mlr.press/v119/chen20j/chen20j.pdf
SimCLR is a contrastive self-supervised learning algorithm that requires neither a specialized architecture nor a memory bank.
SimCLR first learns generic representations of images on an unlabelled dataset. The model can then be fine-tuned with a small number of labelled images to achieve good performance on a given classification task (e.g., a medical imaging task).
The generic representations are learned by simultaneously maximizing agreement between differently transformed views of the same image and minimizing agreement between transformed views of different images, following a method called contrastive learning.
SimCLR randomly draws examples from the original dataset, transforming each example twice using a combination of simple augmentations, creating two sets of corresponding views.
It then computes the image representation using a CNN based on the ResNet architecture.
SimCLR computes a non-linear projection of the image representation using a fully-connected network (i.e., MLP), which amplifies the invariant features and maximizes the ability of the network to identify different transformations of the same image.
The trained model not only does well at identifying different transformations of the same image but also learns representations of similar concepts (e.g., chairs vs. dogs), which later can be associated with labels through fine-tuning.
SimCLR Framework [Source]
Updating the parameters of a neural network using this contrastive objective causes representations of corresponding views to “attract” each other, while representations of non-corresponding views “repel” each other.
ImageNet images are of different resolutions, so random crops are typically applied.
To remove confounding effects when studying individual augmentations:
First, randomly crop an image and resize it to a standard resolution.
Then apply a single augmentation or a pair of augmentations to one branch, while keeping the other branch as the identity mapping.
This is less effective than applying augmentations to both branches, but sufficient for the ablation.
Figure. Two branches of the network
Generate batches of size N from the raw images, e.g., N between 256 and 8192.
Random transformation function T = random(crop + flip + color jitter + grayscale + Gaussian blur).
For each image in a batch, the random transformation is applied twice to get a pair of augmented views. Thus, a batch of N images yields 2N augmented images.
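The batch-to-2N step above can be sketched in plain NumPy. This is a toy illustration, not the torchvision pipeline: `random_view` and `make_views` are hypothetical helpers, images are assumed to be HWC floats in [0, 1], the crop is not resized back, and a simple brightness scaling stands in for full color jitter. Views of image k land at rows k and k + N, so each view has exactly one positive and 2(N − 1) negatives.

```python
import numpy as np

def random_view(img, rng, crop=24):
    """One random view: random crop, random horizontal flip, brightness jitter (toy T)."""
    h, w, _ = img.shape
    top = rng.integers(0, h - crop + 1)
    left = rng.integers(0, w - crop + 1)
    view = img[top:top + crop, left:left + crop]
    if rng.random() < 0.5:
        view = view[:, ::-1]                 # horizontal flip (axis 1 is width in HWC)
    scale = rng.uniform(0.6, 1.4)            # brightness jitter as a stand-in for ColorJitter
    return np.clip(view * scale, 0.0, 1.0)

def make_views(batch, rng, crop=24):
    """Turn N images into 2N views: rows k and k + N are two views of image k."""
    first = [random_view(img, rng, crop) for img in batch]
    second = [random_view(img, rng, crop) for img in batch]
    return np.stack(first + second)

rng = np.random.default_rng(0)
batch = rng.random((4, 32, 32, 3))           # N = 4 toy RGB images
views = make_views(batch, rng)
n = len(batch)
print(views.shape)                           # (8, 24, 24, 3)
print("negatives per view:", 2 * (n - 1))    # 6
```

The k / k + N layout matches how `nt_xent_loss` in this section concatenates `out_1` and `out_2`.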
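```python
from torchvision import transforms

def simclr_augmentation(input_height=224, jitter_strength=1.0):
    # input_height and jitter_strength (s) defaults are assumed values
    # Color jitter: brightness, contrast, saturation at 0.8*s, hue at 0.2*s
    color_jitter = transforms.ColorJitter(
        0.8 * jitter_strength,
        0.8 * jitter_strength,
        0.8 * jitter_strength,
        0.2 * jitter_strength,
    )
    data_transforms = [
        transforms.RandomResizedCrop(size=input_height),  # random crop, resized back
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([color_jitter], p=0.8),    # jitter 80% of the time
        transforms.RandomGrayscale(p=0.2),                # grayscale 20% of the time
        transforms.ToTensor(),
    ]
    return transforms.Compose(data_transforms)

# Two independent draws of T give the two views of one image:
# t = simclr_augmentation(); x_i, x_j = t(img), t(img)
```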
```python
import cv2
import numpy as np

class GaussianBlur:
    # Implements Gaussian blur as described in the SimCLR paper
    def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0):
        self.min = min  # smallest sigma
        self.max = max  # largest sigma
        # kernel size is set to be 10% of the image height/width;
        # cv2.GaussianBlur requires an odd kernel size, so round even values up
        self.kernel_size = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
        self.p = p

    def __call__(self, sample):
        sample = np.array(sample)
        # blur the image with probability p (50% by default)
        prob = np.random.random_sample()
        if prob < self.p:
            # sigma drawn uniformly from [min, max]
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
        return sample
```
Base encoder f(·): ResNet-50.
Output: a 2048-dimensional vector h (the output of the final average-pooling layer).
The encoder is generic and can be replaced with other architectures. The two encoders (one per branch) share weights and produce the vectors hi and hj.
Projection head g(·): a 2-layer non-linear MLP (fully connected network) that outputs the embedding vector.
The representations hi and hj of the two augmented images are passed through Dense -> ReLU -> Dense layers, which apply a non-linear transformation and project them into the representations zi and zj. This is denoted by g(·) in the paper and called the projection head.
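To make "shared weights" concrete, here is a toy NumPy forward pass through both branches. Assumptions: a single linear layer + ReLU stands in for ResNet-50, inputs are tiny 3×8×8 images, biases are omitted, and the weight names `W_f`, `W_1`, `W_2` are hypothetical. Only the dimensions (h is 2048-d, z is 128-d) follow the text.

```python
import numpy as np

rng = np.random.default_rng(0)
D_IN, D_H, D_HID, D_Z = 3 * 8 * 8, 2048, 2048, 128

# Toy stand-in for the ResNet-50 encoder f(.): one linear layer + ReLU (assumption)
W_f = rng.normal(scale=0.02, size=(D_IN, D_H))
# Projection head g(.): Dense -> ReLU -> Dense
W_1 = rng.normal(scale=0.02, size=(D_H, D_HID))
W_2 = rng.normal(scale=0.02, size=(D_HID, D_Z))

def f(x):
    """Encoder: flattened image batch -> representation h (2048-d)."""
    return np.maximum(x.reshape(len(x), -1) @ W_f, 0.0)

def g(h):
    """Projection head: h -> embedding z (128-d)."""
    return np.maximum(h @ W_1, 0.0) @ W_2

x_i = rng.random((4, 3, 8, 8))  # first augmented view of 4 images
x_j = rng.random((4, 3, 8, 8))  # second augmented view of the same 4 images
# Both branches call the same f and g, so the two "encoders" share weights
h_i, h_j = f(x_i), f(x_j)
z_i, z_j = g(h_i), g(h_j)
print(h_i.shape, z_i.shape)     # (4, 2048) (4, 128)
```

After pre-training, g(·) is discarded and h is used for downstream tasks; the loss is computed only on z.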
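```python
import torch.nn as nn
import torch.nn.functional as F

class Projection(nn.Module):
    """Projection head g(.): maps the 2048-d encoder output h to a 128-d embedding z."""
    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super().__init__()
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # h is already average-pooled by the encoder, so no pooling/flatten is needed here
        self.model = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim, bias=True),
            nn.BatchNorm1d(self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.output_dim, bias=False),
        )

    def forward(self, x):
        # x: encoder output h of shape (batch, input_dim)
        x = self.model(x)
        # L2-normalize so that dot products between z vectors are cosine similarities
        return F.normalize(x, dim=1)
```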
Cosine similarity: for two augmented images xi and xj, the cosine similarity is computed between their projected representations zi and zj, i.e., sim(zi, zj) = zi·zj / (‖zi‖ ‖zj‖).
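Because the `Projection` module ends with `F.normalize`, cosine similarity reduces to a plain dot product of unit vectors. A minimal NumPy sketch of the full 2N × 2N similarity matrix (the helper name `cosine_similarity_matrix` is illustrative):

```python
import numpy as np

def cosine_similarity_matrix(z):
    """Pairwise cosine similarity between the rows of z (shape [2N, d])."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # L2-normalize each row
    return z @ z.T                                     # dot of unit vectors = cosine

z = np.array([[1.0, 0.0], [2.0, 0.0], [0.0, 3.0], [-1.0, 0.0]])
S = cosine_similarity_matrix(z)
print(np.round(S, 3))
```

Rows 0 and 1 point the same way (similarity 1) despite different lengths, row 2 is orthogonal to both (similarity 0), and row 3 is opposite (similarity −1): only direction matters.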
Loss Function: NT-Xent (Normalized Temperature-Scaled Cross-Entropy Loss).
The softmax over these similarities gives the probability that the second augmented view of an image (e.g., a cat image) is the most similar to the first view in the pair, among all views in the batch. All other 2(N−1) augmented images in the batch are treated as dissimilar images (negative pairs). Thus, SimCLR does not need the specialized architecture, memory bank, or queue required by previous approaches such as InstDisc, MoCo, or PIRL.
The loss for a pair is then the negative log of this probability. This formulation is closely related to Noise Contrastive Estimation (NCE) and has the same form as the InfoNCE loss.
We calculate the loss for the same pair a second time as well where the positions of the images are interchanged.
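Written out, with τ the temperature, N the batch size, and the two views of image k indexed 2k−1 and 2k as in the paper (the `nt_xent_loss` code in this section places them at rows k and k+N instead, which is the same loss up to relabelling):

```latex
\operatorname{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert}

\ell_{i,j} = -\log \frac{\exp\big(\operatorname{sim}(z_i, z_j)/\tau\big)}
                       {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\big(\operatorname{sim}(z_i, z_k)/\tau\big)}

\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \big[\ell_{2k-1,\,2k} + \ell_{2k,\,2k-1}\big]
```

Note that the denominator sums over every k ≠ i, so it includes the positive pair itself as well as the 2(N−1) negatives.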
Minimizing this loss trains the encoder and projection head so that, over time, the learned representations place similar images closer together in the embedding space.
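```python
import torch

def nt_xent_loss(out_1, out_2, temperature):
    # out_1, out_2: L2-normalized projections z of the two views, shape (N, d);
    # row k of out_1 and row k of out_2 form a positive pair
    out = torch.cat([out_1, out_2], dim=0)
    n_samples = len(out)  # 2N
    # Full similarity matrix (cosine similarities, since rows are unit vectors)
    cov = torch.mm(out, out.t().contiguous())
    sim = torch.exp(cov / temperature)
    # Denominator: sum over all k != i (drop each view's similarity with itself)
    mask = ~torch.eye(n_samples, device=sim.device).bool()
    neg = sim.masked_select(mask).view(n_samples, -1).sum(dim=-1)
    # Positive similarity (numerator), duplicated so both (i, j) and (j, i) are counted
    pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
    pos = torch.cat([pos, pos], dim=0)
    # Average over all 2N anchors
    loss = -torch.log(pos / neg).mean()
    return loss
```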
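A loop-based NumPy transcription of ℓ(i, j) is handy for unit-testing a vectorized implementation such as `nt_xent_loss`. This is a sketch: `nt_xent_reference` and `nt_xent_vectorized` are hypothetical names, τ = 0.5 is an assumed value, and the reference version uses log-sum-exp for numerical stability.

```python
import numpy as np

def nt_xent_reference(z1, z2, tau=0.5):
    """Literal per-anchor NT-Xent; rows i and i+N of concat(z1, z2) are positives."""
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit vectors -> cosine similarity
    n2 = len(z)
    n = n2 // 2
    total = 0.0
    for i in range(n2):
        j = (i + n) % n2                  # index of the positive partner
        logits = z @ z[i] / tau           # sim(z_i, z_k) / tau for every k
        logits = np.delete(logits, i)     # indicator 1[k != i]
        m = logits.max()                  # log-sum-exp with max subtraction
        log_denom = m + np.log(np.exp(logits - m).sum())
        total += -(z[j] @ z[i] / tau - log_denom)
    return total / n2                     # mean over all 2N anchors

def nt_xent_vectorized(z1, z2, tau=0.5):
    """NumPy port of the vectorized nt_xent_loss logic."""
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    sim = np.exp(z @ z.T / tau)
    denom = sim.sum(axis=1) - np.diag(sim)            # sum over k != i
    n = len(z1)
    pos = np.exp(np.sum(z[:n] * z[n:], axis=1) / tau)
    pos = np.concatenate([pos, pos])
    return -np.log(pos / denom).mean()

rng = np.random.default_rng(0)
z1 = rng.normal(size=(8, 128))                # projections of view 1 (N = 8)
z2 = z1 + 0.1 * rng.normal(size=(8, 128))     # view 2, close to view 1
loop_loss = nt_xent_reference(z1, z2)
vec_loss = nt_xent_vectorized(z1, z2)
print(loop_loss, vec_loss)                    # equal up to float rounding
```

The stabilized form matters for small temperatures: with unit-norm z the largest exponent is 1/τ, so τ = 0.01 gives exp(100) ≈ 2.7e43, which overflows float32 (max ≈ 3.4e38) if exponentiated directly.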
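Use the LARS optimizer, which keeps training stable at the high learning rates that large batches require.
LR = 4.8 with batch size 4096 (linear scaling: LR = 0.3 × BatchSize / 256).
Alternatively, the paper's square-root scaling, suited to smaller batch sizes: LR = 0.075 × sqrt(BatchSize).
Linear warmup for the first 10 epochs, then cosine decay (without restarts).
Weight decay of 10^-6 (excluding batch-normalization and bias parameters).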
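The schedule above can be sketched in a few lines. Assumptions: `peak_lr`, `lr_at`, and `split_weight_decay` are hypothetical helpers; 100 epochs and 312 steps per epoch are illustrative numbers; and "exclude BN and bias" is implemented with the common heuristic that 1-D parameters get no weight decay (in PyTorch you would put the actual parameter tensors, not names, into such optimizer param groups). Both scaling rules happen to give 4.8 at batch size 4096, since sqrt(4096) = 64 and 0.075 × 64 = 0.3 × 16.

```python
import math
import numpy as np

def peak_lr(batch_size, scaling="linear"):
    """Peak learning rate under the linear or square-root scaling rule."""
    if scaling == "linear":
        return 0.3 * batch_size / 256          # 0.3 x BatchSize / 256
    return 0.075 * math.sqrt(batch_size)       # 0.075 x sqrt(BatchSize)

def lr_at(step, total_steps, warmup_steps, peak):
    """Linear warmup to `peak`, then cosine decay to 0 without restarts."""
    if step < warmup_steps:
        return peak * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak * 0.5 * (1.0 + math.cos(math.pi * progress))

def split_weight_decay(params, weight_decay=1e-6):
    """Heuristic: 1-D tensors (biases, BN scale/shift) are excluded from weight decay."""
    decay = [name for name, p in params.items() if p.ndim > 1]
    no_decay = [name for name, p in params.items() if p.ndim <= 1]
    return {"weight_decay": weight_decay, "params": decay}, {"weight_decay": 0.0, "params": no_decay}

steps_per_epoch, epochs = 312, 100             # illustrative numbers
total, warmup = steps_per_epoch * epochs, steps_per_epoch * 10
peak = peak_lr(4096)                           # 4.8 up to float rounding
schedule = [lr_at(s, total, warmup, peak) for s in range(total)]

params = {"conv1.weight": np.zeros((64, 3, 7, 7)), "bn1.weight": np.zeros(64), "fc.bias": np.zeros(10)}
decay_group, no_decay_group = split_weight_decay(params)
print(decay_group["params"], no_decay_group["params"])  # ['conv1.weight'] ['bn1.weight', 'fc.bias']
```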
ImageNet
CIFAR-10
MNIST
Linear Classifier.
Fine-tuning.
Transfer Learning.
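In the linear-classifier evaluation, the pre-trained encoder is frozen and only a linear layer is trained on its features h; fine-tuning instead also updates the encoder. A minimal NumPy sketch of the linear probe, assuming synthetic "frozen" features (three well-separated clusters standing in for encoder outputs) and full-batch gradient descent with an assumed learning rate; `linear_probe` is a hypothetical helper.

```python
import numpy as np

def linear_probe(features, labels, num_classes, lr=0.1, epochs=300):
    """Softmax classifier on frozen features; only W and b are learned."""
    n, d = features.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(epochs):
        logits = features @ W + b
        logits -= logits.max(axis=1, keepdims=True)   # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n                   # d(mean cross-entropy)/d(logits)
        W -= lr * features.T @ grad
        b -= lr * grad.sum(axis=0)
    return W, b

rng = np.random.default_rng(0)
centers = rng.normal(size=(3, 16))
labels = np.arange(300) % 3
features = centers[labels] + 0.1 * rng.normal(size=(300, 16))  # stand-in for frozen h
W, b = linear_probe(features, labels, num_classes=3)
accuracy = np.mean(np.argmax(features @ W + b, axis=1) == labels)
print(f"linear-probe accuracy: {accuracy:.2f}")
```

Because the encoder receives no gradient, linear-probe accuracy measures how linearly separable the learned representation already is.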
On ImageNet ILSVRC-2012, a linear classifier trained on SimCLR representations (ResNet-50 (4×)) achieves 76.5% top-1 accuracy, a 7% relative improvement over the previous state-of-the-art self-supervised method (Contrastive Predictive Coding), matching the performance of a supervised ResNet-50.
When fine-tuned on only 1% of the labels, it achieves 85.8% top-5 accuracy, outperforming AlexNet with 100× fewer labels.
In practice, the performance of the InfoNCE loss depends on the number of negatives, and it needs many negatives per loss term. Hence, SimCLR is trained with very large batch sizes (up to 8192) for best results, which is computationally demanding and requires multi-GPU training. This is considered the main drawback of the SimCLR method.
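A quick calculation shows how the cost grows with batch size. With N images there are 2N views, each anchor has 2N − 2 negatives, and the 2N × 2N similarity matrix alone takes (2N)² float32 entries. The helper name `contrastive_batch_cost` is illustrative, and the encoder activations for 2N images are typically far larger than this matrix.

```python
def contrastive_batch_cost(n, bytes_per_entry=4):
    """Negatives per anchor and size of the 2N x 2N similarity matrix for N images."""
    views = 2 * n
    negatives_per_anchor = views - 2                    # exclude itself and its positive
    sim_matrix_bytes = views * views * bytes_per_entry  # float32 logits
    return negatives_per_anchor, sim_matrix_bytes

for n in (256, 4096, 8192):
    negs, nbytes = contrastive_batch_cost(n)
    print(f"N={n:5d}: {negs:6d} negatives per anchor, similarity matrix {nbytes / 2**20:.0f} MiB")
# N=  256:    510 negatives per anchor, similarity matrix 1 MiB
# N= 4096:   8190 negatives per anchor, similarity matrix 256 MiB
# N= 8192:  16382 negatives per anchor, similarity matrix 1024 MiB
```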
Key findings:
Composition of data augmentations plays a critical role in defining effective predictive tasks.
Introducing a learnable nonlinear transformation between the representation and the contrastive loss substantially improves the quality of the learned representations.
Contrastive learning benefits from larger batch sizes and more training steps compared to supervised learning.
(For NCE, please feel free to read NCE, Negative Sampling, CPC.)
(For temperature parameter, please feel free to read Distillation.)
The loss is named NT-Xent (the normalized temperature-scaled cross entropy loss).