[BYOL] Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning
{Rolling Weight Updates}
Paper: https://arxiv.org/abs/2006.07733
Code:
BYOL
BYOL uses the same main ideas as AMDIM (but with the last feature map only), with two changes.
BYOL uses only positive examples in the loss function.
BYOL builds on the momentum network concept of MoCo, adding an MLP q_θ to predict z’ from z. Rather than using a contrastive loss, BYOL uses the L2 error between the normalized prediction p and target z’.
Using our dog image example, BYOL tries to convert both crops of the dog image into the same representation vector (make p and z’ equal). Because this loss function does not require negative examples, there is no use for a memory bank in BYOL.
Both MLPs in BYOL use batch normalization after the first linear layer only.
From the above description, it appears that BYOL can learn without explicitly contrasting between multiple different images. Surprisingly, however, a follow-up ablation (linked at the end of this post) found that BYOL is not only doing contrastive learning, but that contrastive learning is essential to its success.
Using torchvision.transforms, with crop_size=96 and color jitter strength s=0.5, the transform function for training is:
from torchvision.transforms import (
    ColorJitter, Compose, GaussianBlur, Normalize, RandomApply,
    RandomGrayscale, RandomHorizontalFlip, RandomResizedCrop, ToTensor)

crop_size = 96  # output crop size
s = 0.5         # color jitter strength

train_transform = Compose([
    RandomResizedCrop(crop_size, scale=(0.2, 1.0)),
    RandomApply(
        [ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)], p=0.8
    ),
    RandomGrayscale(p=0.2),
    RandomHorizontalFlip(),
    # GaussianBlur needs a kernel size; 9 is roughly 10% of the 96-px crop.
    RandomApply([GaussianBlur(9, sigma=(0.1, 2.0))], p=0.5),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
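As a quick usage sketch (dog.jpg and img are hypothetical; any RGB PIL image works), calling the transform twice produces the two random views of the same image that BYOL trains on:

from PIL import Image

img = Image.open("dog.jpg").convert("RGB")

# Each call draws fresh random augmentations, so the two results
# are two different views of the same underlying image.
v = train_transform(img)        # tensor of shape (3, 96, 96)
v_prime = train_transform(img)  # tensor of shape (3, 96, 96)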
BYOL uses two encoders instead of one. The second encoder is an exact copy of the first, but instead of having its weights updated by gradients on every pass, it updates them as a rolling average of the first encoder's weights.
BYOL does not use negative samples. Instead, it relies on the rolling weight updates as a way to give a contrastive signal to the training. However, a recent ablation discovered that this may not be necessary, and that it is in fact the added batch normalization that keeps the system from converging to trivial solutions.
BYOL drops the need for the denominator of the contrastive loss (the negative terms) and instead relies on the rolling updates to the second encoder to provide the contrastive signal.
However, as mentioned above, recent ablations show that this may not actually be the driver of the contrastive signal.
Network Architecture:
The architecture of the BYOL network is shown below. θ and ϵ represent the online and target network parameters respectively, and f_θ and f_ϵ are the online and target encoders respectively. The target network weights are a slowly moving average of the online network weights, i.e. ϵ ← τ·ϵ + (1 − τ)·θ, where τ ∈ [0, 1] is the target decay rate.
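A minimal sketch of this update in PyTorch, assuming online and target are two modules with identical architectures (the paper starts the decay rate τ at 0.996 and anneals it towards 1 during training):

import torch

@torch.no_grad()
def update_target_network(online, target, tau=0.996):
    # ϵ ← τ·ϵ + (1 − τ)·θ: the target trails the online network as a
    # slowly moving (exponential) average; no gradients are involved.
    for theta, eps in zip(online.parameters(), target.parameters()):
        eps.mul_(tau).add_(theta, alpha=1.0 - tau)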
The idea is to train the online network f_θ in the first step, use those learned representations for downstream tasks, and fine-tune them further using labelled data in the second step. The first step, i.e. BYOL itself, can be summarized in the following 5 straightforward steps.
Given an input image x, two views of the same image v and v’ are generated by applying two random augmentations to x.
Feeding v and v’ to the online and target encoders respectively, vector representations y_θ and y’_ϵ are obtained.
Next, these representations are projected into another subspace; the projected representations are denoted z_θ and z’_ϵ in the image below.
Since the target network is a slowly moving average of the online network, the online representations should be predictive of the target representations, i.e. z_θ should predict z’_ϵ, and hence a predictor (q_θ) is put on top of z_θ.
The loss between the pair <q_θ(z_θ), z’_ϵ> is minimized.
Mathematically, the loss is computed as the mean squared error between the prediction q_θ(z_θ) and the target z’_ϵ, both L2-normalized beforehand. The equation is:

$$\mathcal{L}_{\theta} = \left\lVert \bar{q}_{\theta}(z_{\theta}) - \bar{z}'_{\epsilon} \right\rVert_{2}^{2} = 2 - 2 \cdot \frac{\langle q_{\theta}(z_{\theta}),\ z'_{\epsilon} \rangle}{\lVert q_{\theta}(z_{\theta}) \rVert_{2} \, \lVert z'_{\epsilon} \rVert_{2}}$$

where $\bar{z}'_{\epsilon}$ is the L2-normalized $z'_{\epsilon}$ and $\bar{q}_{\theta}(z_{\theta})$ is the L2-normalized $q_{\theta}(z_{\theta})$.
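A minimal PyTorch sketch of this loss (byol_loss is a hypothetical helper; p stands for the prediction q_θ(z_θ) and z for the target z’_ϵ, which is assumed to carry no gradient, see the stop-gradient discussion later):

import torch.nn.functional as F

def byol_loss(p, z):
    # L2-normalize the prediction and the target before comparing them.
    p = F.normalize(p, dim=-1)
    z = F.normalize(z, dim=-1)
    # MSE between unit vectors equals 2 - 2 * cosine similarity.
    return (2 - 2 * (p * z).sum(dim=-1)).mean()

In the paper, the loss is symmetrized: v’ is also fed to the online network and v to the target network, and the two terms are summed.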
Why BYOL?
The BYOL method helps in learning useful representations for a variety of downstream computer vision tasks such as object recognition, object detection, semantic segmentation, etc. Once these representations are learned the BYOL way, they can be used with any standard object classification model such as Resnet or VGGnet, any semantic segmentation network such as FCN8s or DeepLabv3, or any other task-specific network, and they yield better results than training these networks from scratch. This is the major reason behind the popularity of BYOL. The graph below shows that BYOL representations learned on Imagenet images beat all previous unsupervised learning methods, achieving a classification accuracy of 74.3% with Resnet50 under the linear evaluation protocol. In case you are not sure about the linear evaluation protocol, it is described in my last post in detail.
The power of BYOL is leveraged even more in dense prediction tasks, where generally only a few labels are available because data labelling is complex and costly. When BYOL is used for one such task, semantic segmentation on the Cityscapes dataset with an FCN8s network and a Resnet50 backbone, it outperforms the same network trained from scratch, i.e. from random weights. The graph below compares the performance of three training setups on the Cityscapes dataset.
Resnet50 initialized with Imagenet weights and fine-tuned using 3k Cityscapes labelled images (dotted red line).
Resnet50 trained from random weights using the 3k Cityscapes images only (dotted black line).
Resnet50 pre-trained with BYOL on 20k unlabelled Cityscapes images, then fine-tuned using the 3k Cityscapes images (solid blue line).
The graph below clearly shows that BYOL significantly helps in learning useful representations for this task, and hints that it should be considered as a pre-training step for other industrial computer vision applications where Imagenet weights cannot be used due to licensing restrictions and lots of unlabelled data is available for unsupervised training.
Results:
BYOL achieves state-of-the-art performance without using any negative samples. Fundamentally, like a siamese network, BYOL uses two identical encoder networks, referred to as the online and target networks, to obtain representations, and minimizes the loss between the two representations.
Although MoCo showed good results, its dependency on negative samples complicates the method. Recently, BYOL[7] was introduced, based on the instance discrimination method, and it showed that with two networks similar to MoCo, better visual representations can be learnt even without negatives. The method achieves 74.3% top-1 classification accuracy on ImageNet under the linear evaluation protocol using Resnet50, and further reduces the gap to its supervised counterpart using wider and deeper Resnets. The results are shown below.
The actual BYOL training method is worthy of a separate post, so I will leave it for a future one.
Implementation:
For image augmentations, the following set of augmentations is used. First, a random crop is selected from the image and resized to 224x224. Then a random horizontal flip is applied, followed by random color distortion and random grayscale conversion. Random color distortion consists of a random sequence of brightness, contrast, saturation, and hue adjustments. The following code snippet implements this augmentation pipeline in PyTorch.
from torchvision import transforms as tfms

# Assumes the input is a PIL image, so no ToPILImage is needed mid-pipeline.
byol_tfms = tfms.Compose([
    tfms.RandomResizedCrop(size=224, scale=(0.3, 1)),
    tfms.RandomHorizontalFlip(),
    tfms.RandomApply([
        tfms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    tfms.RandomGrayscale(p=0.2),
    tfms.ToTensor()
])
In the actual BYOL implementation, Resnet50 is used as the encoder network. For the projection MLP, the 2048-dimensional feature vector is first projected to a 4096-dimensional space, with batch norm followed by a ReLU non-linearity, and then reduced to a 256-dimensional feature vector.
The same architecture is used for the predictor network.
The PyTorch snippet below sketches the Resnet50-based BYOL network; it could also be used in conjunction with any arbitrary encoder network such as VGG, InceptionNet, etc. without any significant change.
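What follows is a minimal sketch of such a network, assuming torchvision's resnet50 and the MLP dimensions described above; the full implementation lives in the repo linked at the end of this post.

import copy

import torch
from torch import nn
from torchvision.models import resnet50

class MLP(nn.Module):
    """Projection/prediction head: in_dim -> 4096 (BN + ReLU) -> 256."""
    def __init__(self, in_dim=2048, hidden_dim=4096, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # batch norm after the first linear layer only
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)

class BYOL(nn.Module):
    def __init__(self):
        super().__init__()
        # Online network: encoder f_θ + projector g_θ + predictor q_θ.
        self.online_encoder = resnet50()
        self.online_encoder.fc = nn.Identity()  # keep the 2048-d pooled features
        self.online_projector = MLP(in_dim=2048)
        self.predictor = MLP(in_dim=256)  # same architecture, 256-d input
        # Target network: a copy updated only by the moving-average rule.
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.target_projector = copy.deepcopy(self.online_projector)
        for p in self.target_encoder.parameters():
            p.requires_grad = False
        for p in self.target_projector.parameters():
            p.requires_grad = False

    def forward(self, v1, v2):
        # Online branch: prediction p = q_θ(g_θ(f_θ(v1))).
        p = self.predictor(self.online_projector(self.online_encoder(v1)))
        # Target branch: z' = g_ϵ(f_ϵ(v2)), no gradients ("stop-gradient").
        with torch.no_grad():
            z = self.target_projector(self.target_encoder(v2))
        return p, z  # plug into byol_loss(p, z) from above

After each optimizer step, the update_target_network helper sketched earlier would be called on the (online_encoder, target_encoder) and (online_projector, target_projector) pairs.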
Why BYOL works the way it does:
Another interesting fact is that, although a collapsed solution exists for the task BYOL is trained on, the model safely avoids it, and the actual reason for this is not fully understood. A collapsed solution means the model could get away with outputting a constant vector for any view of any image and reach zero loss, but this does not happen in practice.
The authors of the original paper[1] conjecture that it might be due to the complex backbone network (a deep Resnet with skip connections) that the model never reaches the straightforward collapsed solution. But in a more recent paper, SimSiam[2], Xinlei Chen and Kaiming He found that it is not the complex network architecture but the “stop-gradient” operation that makes the model avoid collapsed representations. “Stop-gradient” means that the network never updates the weights of the target network directly through gradients, and hence never reaches the collapsed solution. They also show that a momentum target network is not needed to avoid collapsed representations, but it certainly gives better representations for downstream tasks if used.
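As a sketch of this SimSiam-style variant (reusing the hypothetical names from the snippets above), a single shared network is used for both branches and collapse is avoided purely by detaching the target side:

# One shared encoder/projector/predictor; no momentum copy of the network.
z1 = model.online_projector(model.online_encoder(v1))
z2 = model.online_projector(model.online_encoder(v2))
p1, p2 = model.predictor(z1), model.predictor(z2)
# "Stop-gradient": the targets are detached, so no gradients flow into them.
loss = 0.5 * byol_loss(p1, z2.detach()) + 0.5 * byol_loss(p2, z1.detach())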
That was a quick summary of BYOL along with code in PyTorch. For the full implementation, refer to this GitHub repo: https://github.com/nilesh0109/self-supervised-sem-seg.
Effect of Batch Normalization and the Contrastive Mechanism in BYOL: https://imbue.com/research/2020-08-24-understanding-self-supervised-contrastive-learning/