[MixMatch] MixMatch: A Holistic Approach to Semi-Supervised Learning
Paper: https://arxiv.org/abs/1905.02249
Code:
Semi-supervised learning has proven to be a powerful paradigm for leveraging unlabeled data to mitigate the reliance on large labeled datasets.
MixMatch works by guessing low-entropy labels for data-augmented unlabeled examples and mixing labeled and unlabeled data using MixUp.
MixMatch can help achieve a dramatically better accuracy-privacy trade-off for differential privacy.
Figure. MixMatch: Augmentation, Averaging and Sharpening
MixMatch is a “holistic” approach that incorporates ideas and components from the dominant paradigms of prior SSL methods.
For each labeled image, we create a single augmentation.
For each unlabeled image, we create K augmentations and get the model’s predictions on all K augmented versions.
Then, the predictions are averaged and temperature sharpening is applied to get a final pseudo-label.
This pseudo-label is used for all K augmentations.
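A minimal Python sketch of this label-guessing step is shown below; augment(x) and model_predict(x) are hypothetical helpers (a stochastic augmentation and a function returning class probabilities) and are not names from the paper’s released code.

```python
import numpy as np

def guess_label(u, model_predict, augment, K=2, T=0.5):
    """Guess a pseudo-label for one unlabeled example u."""
    # K stochastic augmentations of the same unlabeled image
    u_aug = [augment(u) for _ in range(K)]
    # Average the model's predicted class distributions over the K versions
    q_bar = np.mean([model_predict(x) for x in u_aug], axis=0)
    # Temperature sharpening: raise to 1/T and renormalize
    q = q_bar ** (1.0 / T)
    q = q / q.sum()
    # The same sharpened pseudo-label is used for all K augmentations
    return u_aug, q
```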
The batches of augmented labeled and unlabeled images are combined and the whole group is shuffled.
Then, the first N images of this shuffled group are taken as WL, and the remaining M images are taken as WU.
Mixup is applied between the augmented labeled batch and the group WL.
Similarly, mixup is applied between the M augmented unlabeled images and the group WU.
Thus, we get the final labeled and unlabeled groups (a code sketch of this mixing step appears after the formal description below).
Now, for the labeled group, we take the model predictions and compute the cross-entropy loss against the mixed-up ground-truth labels.
Similarly, for the unlabeled group, we take the model predictions and compute the mean squared error (MSE) loss against the mixed-up pseudo-labels.
A weighted sum of these two terms is taken, with λU weighting the MSE loss.
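As a rough illustration of how the two terms are combined, here is a PyTorch-style sketch; logits_x, targets_x, logits_u, targets_u, and lambda_u are hypothetical tensors/values standing in for the model outputs, the mixed-up (pseudo-)labels, and λU, not objects from the paper’s code.

```python
import torch
import torch.nn.functional as F

def mixmatch_loss(logits_x, targets_x, logits_u, targets_u, lambda_u):
    """Weighted sum of the labeled (cross-entropy) and unlabeled (MSE) terms."""
    # Cross-entropy with soft (mixed-up) targets for the labeled batch
    loss_x = -(targets_x * F.log_softmax(logits_x, dim=1)).sum(dim=1).mean()
    # Squared L2 distance between predicted distribution and pseudo-labels,
    # averaged over batch and classes (the 1/(L*|U'|) normalization)
    loss_u = F.mse_loss(torch.softmax(logits_u, dim=1), targets_u)
    return loss_x + lambda_u * loss_u
```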
Given a batch X of labeled examples with one-hot targets (representing one of L possible labels) and an equally-sized batch U of unlabeled examples:
MixMatch produces a processed batch of augmented labeled examples X' and a batch of augmented unlabeled examples with “guessed” labels U'.
X' and U' are then used in computing separate labeled and unlabeled loss terms.
More formally, the combined loss L for semi-supervised learning is defined as:
LX = (1 / |X'|) · Σ(x,p)∈X' H(p, pmodel(y | x; θ))
LU = (1 / (L·|U'|)) · Σ(u,q)∈U' || q − pmodel(y | u; θ) ||²
L = LX + λU·LU
where H(p, q) is the cross-entropy between distributions p and q.
Thus, cross-entropy loss is used for the labeled set, and the squared L2 loss is used between the model’s predictions and the guessed labels on the unlabeled set.
T=0.5, K=2, α=0.75, and λU are hyperparameters described below.
λU has different values for different datasets.
For each xb in the batch of labeled data X, a transformed version is generated: ^xb = Augment(xb) (Algorithm 1, line 3).
For each ub in the batch of unlabeled data U, K augmentations are generated: ^ub,k = Augment(ub), for k = 1, …, K (Algorithm 1, line 5). These individual augmentations are used to generate a “guessed label” qb for each ub.
For each unlabeled example in U, MixMatch produces a “guess” for the example’s label using the model’s predictions. This guess is later used in the unsupervised loss term.
To do so, the average of the model’s predicted class distributions across all K augmentations of ub is computed by:
q̄b = (1/K) · Σ(k=1..K) pmodel(y | ^ub,k; θ)
Using data augmentation to obtain an artificial target for an unlabeled example is common in consistency regularization methods.
Given the average prediction over augmentations q̄b, a sharpening function is applied to reduce the entropy of the label distribution. In practice, the sharpening function adjusts the “temperature” T of this categorical distribution:
Sharpen(p, T)i = pi^(1/T) / Σ(j=1..L) pj^(1/T)
where p is some input categorical distribution (specifically, in MixMatch, p is the average class prediction over augmentations q̄b), and T is a temperature hyperparameter. The guessed label is then qb = Sharpen(q̄b, T).
Lowering the temperature encourages the model to produce lower-entropy predictions; as T → 0, the output approaches a one-hot distribution.
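For concreteness, here is a minimal NumPy sketch of the sharpening operation; the function name and the example numbers are illustrative only.

```python
import numpy as np

def sharpen(p, T=0.5):
    """Raise each probability to the power 1/T and renormalize."""
    p = np.asarray(p, dtype=np.float64)
    p_sharp = p ** (1.0 / T)
    return p_sharp / p_sharp.sum()

# With T = 0.5, an averaged prediction becomes noticeably more peaked:
print(sharpen([0.6, 0.3, 0.1]))  # approx. [0.78, 0.20, 0.02]
```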
A slightly modified version of mixup is used. mixup is a data augmentation technique originally proposed for supervised learning, where a model is trained on convex combinations of pairs of examples and their labels.
For a pair of examples with their corresponding label probabilities (x1, p1) and (x2, p2), the mixed pair (x', p') is computed by:
λ ~ Beta(α, α)
λ' = max(λ, 1 − λ)
x' = λ'·x1 + (1 − λ')·x2
p' = λ'·p1 + (1 − λ')·p2
where α is the hyperparameter of the Beta distribution. Taking λ' = max(λ, 1 − λ) is the modification: it ensures that x' is closer to x1 than to x2, so mixed labeled examples can still be used in the labeled loss term and mixed unlabeled examples in the unlabeled loss term.
(Please feel free to read the mixup paper if interested.)
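Below is a small NumPy sketch of this modified mixup for a single pair, assuming x1 and x2 are arrays of the same shape and p1 and p2 are probability vectors; the function name and argument layout are illustrative, not the paper’s code.

```python
import numpy as np

def mixup(x1, p1, x2, p2, alpha=0.75, rng=None):
    """MixMatch's modified mixup for one (example, label) pair."""
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)
    # The modification: keep the mixed example closer to (x1, p1)
    lam = max(lam, 1.0 - lam)
    x_mixed = lam * np.asarray(x1) + (1.0 - lam) * np.asarray(x2)
    p_mixed = lam * np.asarray(p1) + (1.0 - lam) * np.asarray(p2)
    return x_mixed, p_mixed
```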
To apply mixup, all augmented labeled examples with their labels and all unlabeled examples with their guessed labels are first collected into:
^X = ((^xb, pb); b ∈ (1, …, B))
^U = ((^ub,k, qb); b ∈ (1, …, B), k ∈ (1, …, K))
Then, these collections are concatenated and shuffled to form W, which will serve as a data source for mixup:
W = Shuffle(Concat(^X, ^U))
For the i-th example-label pair in ^X, mixup is applied with the i-th entry of W, and the result is added to the collection X'. The remainder of W is used for ^U: mixup is applied between each entry of ^U and the corresponding remaining entry of W, and the results are added to the collection U'.
Thus, MixMatch transforms X into X’, a collection of labeled examples which have had data augmentation and mixup (potentially mixed with an unlabeled example) applied.
Similarly, U is transformed into U’, a collection of multiple augmentations of each unlabeled example with corresponding label guesses.
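Putting these steps together, the construction of W and the application of mixup could look roughly like the sketch below; it reuses the illustrative mixup helper from above, and xs_aug/ps and us_aug/qs are hypothetical flat lists of augmented labeled/unlabeled examples with their (guessed) labels.

```python
import random

def mix_and_match(xs_aug, ps, us_aug, qs):
    """Form W = Shuffle(Concat(X_hat, U_hat)) and mix each entry of
    X_hat and U_hat with the corresponding entry of W."""
    x_hat = list(zip(xs_aug, ps))  # augmented labeled examples with labels
    u_hat = list(zip(us_aug, qs))  # augmented unlabeled examples with guessed labels
    w = x_hat + u_hat
    random.shuffle(w)              # W = Shuffle(Concat(X_hat, U_hat))
    # The first |X_hat| entries of W are mixed with X_hat, the rest with U_hat
    x_prime = [mixup(x, p, wx, wp)
               for (x, p), (wx, wp) in zip(x_hat, w[:len(x_hat)])]
    u_prime = [mixup(u, q, wu, wq)
               for (u, q), (wu, wq) in zip(u_hat, w[len(x_hat):])]
    return x_prime, u_prime
```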
A Wide ResNet (WRN-28) is used as the network model.
CIFAR-10 (Left): MixMatch outperforms all other methods by a significant margin, for example reaching an error rate of 6.24% with 4000 labels.
SVHN (Right): MixMatch’s performance is relatively constant (and better than all other methods) across all amounts of labeled data.
Figure. Error rate (%) on CIFAR-10 (left) and SVHN (right)