[PILR] Self-Supervised Learning of Pretext-Invariant Representations
{Image-based Pretext Tasks, Invariance}
Paper: https://arxiv.org/abs/1912.01991
Code: https://github.com/akwasigroch/Pretext-Invariant-Representations (not from Author)
The goal of self-supervised learning from images is to construct image representations that are semantically meaningful via pretext tasks that do not require semantic annotations.
Many pretext tasks lead to representations that are covariant with image transformations.
We argue that, instead, semantic representations ought to be invariant under such transformations.
Pretext-Invariant Representation Learning (PIRL, pronounced as “pearl”)
PIRL belongs to the family of approaches that learn image representations that are invariant to image transformations rather than covariant with them (invariant to both the data-augmentation and the pretext image transformations).
Combine pretext tasks with contrastive learning. (Adapt the “Jigsaw” pretext task [54] to work with PIRL).
Substantially improves the semantic quality of the learned image representations.
Earlier image representations range from hand-crafted features such as SIFT [45] and HOG [8] to learned representations from ConvNets [37, 40, 69].
Practically useful representations are designed to be invariant to ‘nuisance’ factors like translations of pixels, changes in scale, color, and lighting, e.g., by using data augmentation [37] during training.
=> Extend this invariance from data augmentation to the image transformations used in self-supervised ‘pretext tasks’.
Learn feature representations without considering a corresponding (image-conditional) label distribution.
Sparse coding [58], adversarial training [12, 13, 50], autoencoders [49, 63, 76], or probabilistic versions thereof [67].
Many pretext tasks predict some low-level property of an image transformation, which makes the final representations covariant with image transformations.
Contrastive Loss
Approaches that use a contrastive loss [24] in predictive learning [25, 26, 28, 59, 70, 73].
These prior approaches predict missing parts of the data, e.g., future frames in videos [25, 59], or operate on multiple views [73].
PIRL learns invariances rather than predicting missing data.
Construct the representation so that it is invariant to the input transformation, i.e., so that it captures as little information as possible about that transformation.
Feed-forward both the image I and its transformed version It through a ConvNet to get representations and encourage them to be similar.
First, take an original image I, and apply a transformation borrowed from some pretext task (e.g. rotation prediction) to get transformed image It.
Then, both the images are passed through ConvNet θ with shared weights to get representations VI and VIt.
The representation VI of the original image is passed through a projection head f(.) to get representation f(VI).
Similarly, a separate projection head g(.) is used to get representation g(VIt) for the transformed image.
These representations are tuned with a loss function that makes the representations of I and It similar while keeping them dissimilar from the representations of other (negative) images I′ stored in a memory bank.
ResNet-50 (R-50)
Each image passes through the Encoder to generate a 2048-dimensional representation.
Compute the representation f(VI) of the original image I:
Extracting res5 features.
Average pooling.
A linear projection to obtain a 128-dimensional representation.
Compute the representation g(.) of a transformed image It:
Extract nine patches from image I (the Jigsaw transform that produces It).
Compute an image representation for each patch separately by extracting activations from the res5 layer of the ResNet-50 and average pool the activations.
Apply a linear projection to obtain a 128-dimensional patch representation.
Concatenate the patch representations in random order and apply a second linear projection on the result to obtain the final 128-dimensional image representation, g(.).
import torch
import torch.nn as nn
from torchvision.models import resnet50


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        # ResNet-50 backbone without the final classification layer (keeps the avg pool).
        self.network = resnet50()
        self.network = torch.nn.Sequential(*list(self.network.children())[:-1])
        # Projection head f(.): 2048 -> 128 for the original image.
        self.projection_original_features = nn.Linear(2048, 128)
        # Second projection for g(.): 9 patches x 128 = 1152 -> 128.
        self.connect_patches_feature = nn.Linear(1152, 128)

    def forward_once(self, x):
        return self.network(x)

    def return_reduced_image_features(self, original):
        # res5 features -> average pooling (inside the backbone) -> 128-d projection.
        original_features = self.forward_once(original)
        original_features = original_features.view(-1, 2048)
        original_features = self.projection_original_features(original_features)
        return original_features

    def return_reduced_image_patches_features(self, original, patches):
        # f(v_I) for the original image.
        original_features = self.return_reduced_image_features(original)

        # 128-d representation for each of the nine patches, then concatenate
        # and project again to obtain the final 128-d representation g(v_It).
        patches_features = []
        for patch in patches:
            patch_features = self.return_reduced_image_features(patch)
            patches_features.append(patch_features)
        patches_features = torch.cat(patches_features, dim=1)
        patches_features = self.connect_patches_feature(patches_features)
        return original_features, patches_features

    def forward(self, images=None, patches=None, mode=0):
        '''
        mode 0: get the 128-d feature for the image
        mode 1: get the 128-d features for the image and its patches
        '''
        if mode == 0:
            return self.return_reduced_image_features(images)
        if mode == 1:
            return self.return_reduced_image_patches_features(images, patches)
To learn better image representations, it’s better to compare the current image with a large number of negative images.
One common approach is to use larger batches and consider all other images in this batch as negative.
However, loading larger batches of images comes with its set of resource challenges.
=> PIRL uses a memory bank that caches representations of all images and uses it during training. => This allows a large number of negative pairs without increasing the batch size.
We implement the memory bank as described in [72] and use the same hyperparameters for the memory bank.
Specifically, we set the temperature in Equation 3 to τ = 0.07, and use a weight of 0.5 to compute the exponential moving averages in the memory bank.
Unless stated otherwise, we use λ= 0.5 in Equation 5.
import random

import numpy as np
import torch
from tensorflow.keras.utils import Progbar  # progress bar; any equivalent utility works


class Memory(object):
    def __init__(self, device, size=2000, weight=0.5):
        # One 128-d slot per training image, maintained as an exponential moving average.
        self.memory = np.zeros((size, 128))
        self.weighted_sum = np.zeros((size, 128))
        self.weighted_count = 0
        self.weight = weight
        self.device = device

    def initialize(self, net, train_loader):
        # Fill the bank with one forward pass over the whole training set.
        self.update_weighted_count()
        print('Saving representations to memory')
        bar = Progbar(len(train_loader), stateful_metrics=[])
        for step, batch in enumerate(train_loader):
            with torch.no_grad():
                images = batch['original'].to(self.device)
                index = batch['index']
                output = net(images=images, mode=0)
                self.weighted_sum[index, :] = output.cpu().numpy()
                self.memory[index, :] = self.weighted_sum[index, :]
                bar.update(step, values=[])

    def update(self, index, values):
        # Running exponential moving average; weighted_count acts as a bias correction.
        self.weighted_sum[index, :] = values + (1 - self.weight) * self.weighted_sum[index, :]
        self.memory[index, :] = self.weighted_sum[index, :] / self.weighted_count

    def update_weighted_count(self):
        self.weighted_count = 1 + (1 - self.weight) * self.weighted_count

    def return_random(self, size, index):
        # Sample `size` negatives, excluding the current image's own slot.
        if isinstance(index, torch.Tensor):
            index = index.tolist()
        allowed = [x for x in range(len(self.memory)) if x != index[0]]
        index = random.sample(allowed, size)
        return self.memory[index, :]

    def return_representations(self, index):
        if isinstance(index, torch.Tensor):
            index = index.tolist()
        return torch.Tensor(self.memory[index, :])
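A short sketch of how the bank might be set up before training, assuming a train_loader whose batches provide 'original' images and their dataset 'index' as in the code above (train_dataset and train_loader are placeholders):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = Network().to(device)
# One memory slot per training image; size must match the dataset length.
memory = Memory(device=device, size=len(train_dataset), weight=0.5)
memory.initialize(net, train_loader)

During training, the slots of the current batch are then refreshed with memory.update(...) after every optimization step; see the training-step sketch further below.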
Similarity metric: Cosine Similarity.
NCE (Noise Contrastive Estimator).
Workflow:
In a mini-batch, we will have one positive (similar) pair and many negative (dissimilar) pairs.
We then compute the similarity between the transformed image’s feature vector and the rest of the feature vectors in the minibatch (one positive, the rest negative).
We then compute the score of a softmax-like function on the positive pair. Maximizing a softmax score means minimizing the rest of the scores, which is exactly what we want for an energy-based model.
The final loss function, therefore, allows us to build a model that pushes the energy down on similar pairs while pushing it up on dissimilar pairs.
Cosine similarity is used as a similarity measure of any two representations.
For example, we compare a cat image I and its transformed (e.g., rotated) counterpart It; the cosine similarity of two representations is denoted by s(·, ·):

$$s(u, v) = \frac{u \cdot v}{\|u\|\,\|v\|}$$
NCE computes the similarity score of two representations normalized by all negative images.
For a cat image I and its transformed counterpart It, the noise contrastive estimator is denoted by:

$$h(v_I, v_{I^t}) = \frac{\exp\!\left(s(v_I, v_{I^t})/\tau\right)}{\exp\!\left(s(v_I, v_{I^t})/\tau\right) + \sum_{I' \in \mathcal{D}_N} \exp\!\left(s(v_{I^t}, v_{I'})/\tau\right)}$$

where τ is the temperature and D_N is the set of negative images.
The loss for a pair of images is calculated using cross-entropy loss as:

$$L_{NCE}(I, I^t) = -\log\!\left[h\!\left(f(v_I), g(v_{I^t})\right)\right] - \sum_{I' \in \mathcal{D}_N} \log\!\left[1 - h\!\left(g(v_{I^t}), f(v_{I'})\right)\right]$$
Since we already have the representations of the image and of the negative images in the memory bank, we use those (m_I and m_{I'}) instead of the freshly computed representations:

$$L_{NCE}\!\left(m_I, g(v_{I^t})\right) = -\log\!\left[h\!\left(m_I, g(v_{I^t})\right)\right] - \sum_{I' \in \mathcal{D}_N} \log\!\left[1 - h\!\left(g(v_{I^t}), m_{I'}\right)\right]$$
The loss only compares I to It and compares It to I′. It doesn’t compare I and I′.
To do that, we introduce another loss term and combine both losses using the following formulation (with λ = 0.5 by default):

$$L(I, I^t) = \lambda\, L_{NCE}\!\left(m_I, g(v_{I^t})\right) + (1 - \lambda)\, L_{NCE}\!\left(m_I, f(v_I)\right)$$
class NoiseContrastiveEstimator():
    def __init__(self, device):
        self.device = device
        self.temperature = 0.07
        self.cos = torch.nn.CosineSimilarity()
        self.criterion = torch.nn.CrossEntropyLoss()

    def __call__(self, original_features, patch_features, index, memory, negative_nb=1000):
        loss = 0
        for i in range(original_features.shape[0]):
            # Draw `negative_nb` negatives from the memory bank, excluding this image's own slot.
            negative = memory.return_random(size=negative_nb, index=[index[i]])
            negative = torch.Tensor(negative).to(self.device).detach()

            # Positive: similarity between the image representation and its transformed version.
            image_to_modification_similarity = self.cos(original_features[None, i, :],
                                                        patch_features[None, i, :]) / self.temperature
            # Negatives: similarity of the transformed representation to the memory-bank negatives.
            matrix_of_similarity = self.cos(patch_features[None, i, :], negative) / self.temperature

            # Cross-entropy over [positive, negatives] with the positive pair at position 0.
            similarities = torch.cat((image_to_modification_similarity, matrix_of_similarity))
            loss += self.criterion(similarities[None, :], torch.tensor([0]).to(self.device))
        return loss / original_features.shape[0]
loss_1 = noise_contrastive_estimator(representations, output[1], index, memory, negative_nb=negative_nb)
loss_2 = noise_contrastive_estimator(representations, output[0], index, memory, negative_nb=negative_nb)
loss = loss_weight * loss_1 + (1 - loss_weight) * loss_2
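Putting the pieces together, a hypothetical training loop might look like the sketch below. It assumes the data loader also returns the nine 'patches' per image; the optimizer settings, epoch count, and the per-epoch placement of update_weighted_count() are assumptions, not taken from the paper or the repository.

noise_contrastive_estimator = NoiseContrastiveEstimator(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
loss_weight = 0.5      # λ in the combined loss
negative_nb = 1000     # number of memory-bank negatives per image

for epoch in range(10):
    memory.update_weighted_count()   # advance the EMA bias-correction term
    for batch in train_loader:
        images = batch['original'].to(device)
        patches = [p.to(device) for p in batch['patches']]
        index = batch['index']

        # Cached memory-bank representations m_I for the current batch.
        representations = memory.return_representations(index).to(device).detach()

        optimizer.zero_grad()
        output = net(images=images, patches=patches, mode=1)   # (f(v_I), g(v_It))

        loss_1 = noise_contrastive_estimator(representations, output[1], index, memory, negative_nb=negative_nb)
        loss_2 = noise_contrastive_estimator(representations, output[0], index, memory, negative_nb=negative_nb)
        loss = loss_weight * loss_1 + (1 - loss_weight) * loss_2
        loss.backward()
        optimizer.step()

        # Refresh the memory bank with the new f(v_I) of this batch.
        memory.update(index, output[0].detach().cpu().numpy())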
PIRL doesn’t use the direct output of the convolutional feature extractor. It instead defines different heads f and g, which can be thought of as independent layers on top of the base convolutional feature extractor.
Dr. LeCun mentions that to make this work, it requires a large number of negative samples. In SGD, it can be difficult to consistently maintain a large number of these negative samples from mini-batches. => PIRL also uses a cached memory bank.
Question: Why do we use cosine similarity instead of L2 Norm?
With an L2 distance, it is very easy to make two vectors similar by making them “short” (close to the center) or to make two vectors dissimilar by making them very “long” (far from the center).
This is because the L2 distance is just the (square root of the) sum of squared component-wise differences, so it shrinks or grows with the overall scale of the vectors.
Thus, using cosine similarity forces the system to find a good solution without “cheating” by making vectors short or long.
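A tiny numerical illustration (not from the paper): scaling two vectors down shrinks their L2 distance toward zero, but leaves their cosine similarity unchanged.

import torch

u = torch.tensor([1.0, 2.0, 3.0])
v = torch.tensor([3.0, 2.0, 1.0])

for scale in (1.0, 0.01):
    a, b = scale * u, scale * v
    l2 = torch.dist(a, b)                                      # shrinks with the scale
    cos = torch.nn.functional.cosine_similarity(a, b, dim=0)   # unchanged by the scale
    print(f'scale={scale}: L2={l2.item():.4f}, cosine={cos.item():.4f}')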
The authors state two promising areas for improving PIRL and learn better image representations:
Borrow transformations from other pretext tasks beyond Jigsaw and rotation.
Combine PIRL with clustering-based approaches.