[FixMatch]
Paper:
Code:
Fast Read: https://amitness.com/2020/03/fixmatch-semi-supervised/
1) Motivation, Objectives and Related Works:
Motivation:
Objectives:
FixMatch is a simpler combination of previous methods such as UDA and ReMixMatch.
In this post, we will understand FixMatch and see how it achieved 78% median accuracy and 84% maximum accuracy on CIFAR-10 with just 10 labeled images.
Related Works:
Contribution:
This method was proposed by Sohn et al. and combines pseudo-labeling and consistency regularization while vastly simplifying the overall method. It achieved state-of-the-art results on a wide range of benchmarks.
As seen, we train a supervised model on our labeled images with cross-entropy loss. For each unlabeled image, a weak augmentation and a strong augmentation are applied to get two images. The weakly augmented image is passed to our model and we get a prediction over classes. The probability for the most confident class is compared to a threshold. If it is above the threshold, we take that class as the ground-truth label, i.e. the pseudo-label. Then, the strongly augmented image is passed through our model to get a prediction over classes. This prediction is compared to the pseudo-label using cross-entropy loss.
Both the losses are combined and the model is optimized.
2) Methodology:
Training Data and Augmentation:
FixMatch borrows from UDA and ReMixMatch the idea of applying two different augmentations: weak augmentation on the unlabeled image for pseudo-label generation, and strong augmentation on the same image for the prediction.
Weak augmentation
For weak augmentation, the paper uses a standard flip-and-shift strategy. It includes two simple augmentations:
Random Horizontal Flip: This augmentation is applied with a probability of 50%. This is skipped for the SVHN dataset since those images contain digits for which horizontal flip is not relevant.
Random Vertical and Horizontal Translation: The image is translated by up to 12.5% of its height and width.
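The two weak augmentations above can be sketched with NumPy. The function name `weak_augment` is illustrative, and the translation is implemented as a simple zero-padded shift:

```python
import numpy as np

def weak_augment(img, rng, max_shift_frac=0.125, allow_hflip=True):
    """Weakly augment an HxWxC image: random horizontal flip (p=0.5)
    plus a random translation of up to 12.5% of each spatial dimension."""
    h, w = img.shape[:2]
    out = img.copy()
    # Random horizontal flip with 50% probability (skipped for SVHN).
    if allow_hflip and rng.random() < 0.5:
        out = out[:, ::-1]
    # Random vertical/horizontal shift, filling the vacated area with zeros.
    dy = rng.integers(-int(h * max_shift_frac), int(h * max_shift_frac) + 1)
    dx = rng.integers(-int(w * max_shift_frac), int(w * max_shift_frac) + 1)
    shifted = np.zeros_like(out)
    src_y = slice(max(0, -dy), h - max(0, dy))
    src_x = slice(max(0, -dx), w - max(0, dx))
    dst_y = slice(max(0, dy), h - max(0, -dy))
    dst_x = slice(max(0, dx), w - max(0, -dx))
    shifted[dst_y, dst_x] = out[src_y, src_x]
    return shifted
```

For SVHN, `allow_hflip=False` disables the flip as the paper describes.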
Strong Augmentation:
These include augmentations that output heavily distorted versions of the input images. FixMatch applies either RandAugment or CTAugment and then applies CutOut augmentation:
Cutout: This augmentation randomly removes a square part of the image and fills it with gray or black color.
RandAugment:
First, you have a list of 14 possible augmentations with a range of their possible magnitudes.
You select N random augmentations from this list. Here, we are selecting any two from the list.
Then you select a random magnitude M ranging from 1 to 10. Suppose we select a magnitude of 5; since the maximum possible M is 10, this corresponds to 5/10 = 50% of the maximum magnitude.
Now, the selected augmentations are applied to an image in sequence. Each augmentation has a 50% probability of being applied.
The values of N and M can be found by hyper-parameter optimization on a validation set with a grid search. In the paper, they use random magnitude from a pre-defined range at each training step instead of a fixed magnitude.
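The selection logic in the steps above can be sketched in plain Python. The three ops below are numeric stand-ins for real image transformations (the actual pool has 14, e.g. rotate, shear, posterize); the names are illustrative:

```python
import random

# Illustrative op pool; each op maps (image, strength) -> image.
OPS = {
    "identity": lambda x, s: x,
    "add":      lambda x, s: x + s,        # stand-ins for real image ops
    "scale":    lambda x, s: x * (1 + s),
}

def rand_augment(x, n=2, m=5, m_max=10, rng=random):
    """Apply N randomly chosen ops at magnitude M (as a fraction of m_max).
    Each selected op fires with probability 0.5."""
    strength = m / m_max          # e.g. M=5 -> 50% of the maximum magnitude
    for name in rng.choices(list(OPS), k=n):
        if rng.random() < 0.5:    # each op applied with 50% probability
            x = OPS[name](x, strength)
    return x
```

Sampling M afresh at each training step, as the paper does, just means passing a random `m` per call instead of a fixed one.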
CTAugment: CTAugment was an augmentation technique introduced in the ReMixMatch paper.
We have a set of 18 possible transformations, similar to RandAugment.
Magnitude values for transformations are divided into bins and each bin is assigned a weight. Initially, all bins weigh 1.
Now two transformations are selected at random with equal chances from this set and their sequence forms a pipeline. This is similar to RandAugment.
For each transformation, a magnitude bin is selected randomly with probability proportional to the normalized bin weights.
Now, a labeled example is augmented with these two transformations and passed to the model to get a prediction.
Based on how close the model's prediction was to the actual label, the magnitude bin weights for these transformations are updated.
Thus, CTAugment learns to choose magnitudes for which the model has a high chance of predicting the correct label, i.e. augmentations that fall within the network's tolerance.
Thus, we see that unlike RandAugment, CTAugment can learn the magnitude for each transformation dynamically during training. So we don't need to optimize it on a supervised proxy task, and it has no sensitive hyperparameters to optimize.
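The bin bookkeeping described above can be sketched as follows. The class name, bin count, decay constant, and the exact update rule are illustrative; the key ideas from the text are weights starting at 1, weight-proportional sampling, and match-based updates:

```python
import random

class MagnitudeBins:
    """Per-transformation magnitude bins with learned weights (CTAugment-style)."""
    def __init__(self, n_bins=17, decay=0.99):
        self.weights = [1.0] * n_bins   # all bins start with weight 1
        self.decay = decay

    def sample(self, rng=random):
        # Pick a bin index with probability proportional to normalized weights.
        total = sum(self.weights)
        r = rng.random() * total
        acc = 0.0
        for i, w in enumerate(self.weights):
            acc += w
            if r <= acc:
                return i
        return len(self.weights) - 1

    def update(self, bin_idx, match):
        # `match` in [0, 1]: how close the model's prediction on the augmented
        # labeled example was to the true label. Move the bin's weight toward it.
        w = self.weights[bin_idx]
        self.weights[bin_idx] = self.decay * w + (1 - self.decay) * match
```

Bins that distort images beyond the network's tolerance receive low `match` scores, so their weights decay and they are sampled less often.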
[Figure: Weak Augmentations]
[Figure: RandAugment Steps]
Model Architecture:
FixMatch uses wider and shallower variants of ResNet called Wide Residual Networks as the base architecture.
The exact variant used is Wide-ResNet-28-2, with a depth of 28 and a widening factor of 2, i.e. two times wider than the corresponding ResNet. It has a total of 1.5 million parameters. The model is topped with an output layer with nodes equal to the number of classes needed (e.g. 2 classes for cat/dog classification).
Model Training and Loss Function:
Step 1: Preparing batches
We prepare batches of labeled images of size B and unlabeled images of size mB. Here m is a hyperparameter that decides the relative size of labeled to unlabeled images in a batch. For example, m=2 means that we use twice the number of unlabeled images compared to labeled images.
The paper tried increasing values of m and found that the error rate decreases as the number of unlabeled images increases. The paper uses m=7 for the evaluation datasets.
Step 2: Supervised Learning
For the supervised part of the pipeline, which is trained on labeled images, we use the regular cross-entropy loss H() for the classification task. The total loss for a batch, denoted ls, is calculated by taking the average of the cross-entropy loss over the images in the batch.
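The batch-averaged cross-entropy ls can be sketched in NumPy; function names here are illustrative:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)   # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def supervised_loss(logits, labels):
    """ls: mean cross-entropy H(y, p) over the labeled batch."""
    probs = softmax(logits)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))
```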
Step 3: Pseudolabeling
For the unlabeled images, we first apply weak augmentation and take the highest-probability class via argmax. This is the pseudo-label that will be compared with the model's output on the strongly augmented image.
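Pseudo-label generation is just an argmax plus a confidence check. A minimal sketch, assuming softmax probabilities as input and using 0.95 as an example threshold value:

```python
import numpy as np

def pseudo_labels(probs_weak, threshold=0.95):
    """Return (labels, mask): the argmax class per unlabeled image, and a
    boolean mask marking images whose max confidence clears the threshold."""
    labels = probs_weak.argmax(axis=1)
    mask = probs_weak.max(axis=1) >= threshold
    return labels, mask
```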
Step 4: Consistency Regularization
Now, the same unlabeled image is strongly augmented and its output is compared to our pseudo-label to compute the cross-entropy loss H(). The total unlabeled batch loss is denoted by lu.
Here t denotes the threshold above which we accept a pseudo-label. This loss is similar to the pseudo-labeling loss. The difference is that we're using weak augmentation to generate labels and strong augmentation for the loss.
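Step 4 can be sketched as a masked cross-entropy: images below the threshold contribute zero, and the sum is averaged over the whole unlabeled batch. Function names are illustrative:

```python
import numpy as np

def unsupervised_loss(probs_strong, pseudo, mask):
    """lu: cross-entropy of strong-augmentation predictions against
    pseudo-labels, zeroed for images below the confidence threshold,
    averaged over the full unlabeled batch."""
    ce = -np.log(probs_strong[np.arange(len(pseudo)), pseudo])
    return np.mean(ce * mask)
```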
Step 5: Curriculum Learning
We finally combine these two losses to get a total loss that we optimize to improve our model. λu is a fixed scalar hyperparameter that decides how much the unlabeled-image loss contributes relative to the labeled loss.
loss = ls + λu * lu
An interesting result comes from λu. Previous works have shown that ramping up the unlabeled-loss weight during training is helpful. But in FixMatch, this ramp-up is built into the algorithm itself.
Initially, the model is not confident, so its predictions on unlabeled data fall below the threshold and the model is effectively trained only on labeled data. As training progresses, the model becomes more confident, and its predictions on unlabeled data start to cross the threshold. The loss then begins to incorporate predictions on unlabeled images as well. This gives us a free form of curriculum learning.
Intuitively, this is similar to how we’re taught in childhood. In the early years, we learn easy concepts such as alphabets and what they represent before moving on to complex topics like word formation, sentence formation, and then essays.
3) Experimental Results:
Ablations: