The Illustrated FixMatch for Semi-Supervised Learning


Deep Learning has shown very promising results in the field of Computer Vision. But when applying it to practical domains such as medical imaging, the lack of labeled data is a major challenge.

In practical settings, labeling data is a time-consuming and expensive process. Even if you have a lot of images, only a small portion of them can be labeled due to resource constraints. In such settings, how can we leverage the remaining unlabeled images along with the labeled images to improve the performance of our model? The answer is semi-supervised learning.

FixMatch is a recent semi-supervised approach by Sohn et al. from Google Brain that improved the state of the art in semi-supervised learning (SSL). It is a simpler combination of previous methods such as UDA and ReMixMatch. In this post, we will understand the concept of FixMatch and also see how it got 78% median accuracy and 84% maximum accuracy on CIFAR-10 with just 10 labeled images.

Intuition behind FixMatch

Let’s say we’re doing a cat vs dog classification where we have limited labeled data and a lot of unlabeled images of cats and dogs.

Example of Labeled vs Unlabeled Images

Our usual supervised learning approach would be to just train a classifier on labeled images and ignore the unlabeled images.

Usual Supervised Learning Approach

Instead of ignoring unlabeled images, we could apply the approach below. We know that a model should be robust to perturbations of an image if it is to generalize well.

What if we create augmented versions of the unlabeled images and make the supervised model predict labels for them? Since both versions come from the same image, the predicted labels should be the same for both.

Concept of FixMatch

Thus, even without knowing their correct labels, we can use the unlabeled images as a part of our training pipeline. This is the core idea behind FixMatch and many preceding papers it builds upon.

The FixMatch Pipeline

With the intuition clear, let’s see how FixMatch is actually applied in practice. The overall pipeline is summarized by the following figure:

End to End Pipeline of FixMatch paper

Synopsis:

As seen, we train a supervised model on our labeled images with cross-entropy loss. For each unlabeled image, weak augmentation and strong augmentation are applied to get two images. The weakly augmented image is passed to our model, and we get a prediction over classes. The probability of the most confident class is compared to a threshold. If it is at or above the threshold, we take that class as the ground-truth label, i.e. the pseudo-label; otherwise, the image does not contribute to the unlabeled loss. Then, the strongly augmented image is passed through our model to get a prediction over classes. This probability distribution is compared to the pseudo-label using cross-entropy loss. The two losses are combined, and the model is optimized on the total loss.

Pipeline Components

1. Training Data and Augmentation

FixMatch borrows from UDA and ReMixMatch the idea of applying different augmentations to the same unlabeled image: a weak augmentation for generating the pseudo-label and a strong augmentation for the prediction that is trained against it.

a. Weak Augmentation
For weak augmentation, the paper uses a standard flip-and-shift strategy. It includes two simple augmentations:

  • Random Horizontal Flip
    Example of Random Horizontal Flip

This augmentation is applied with a probability of 50%. It is skipped for the SVHN dataset, since those images contain digits, and a horizontally flipped digit is usually not a valid digit. In PyTorch, this can be performed using transforms as:

from PIL import Image
import torchvision.transforms as transforms

im = Image.open('dog.png')
weak_im = transforms.RandomHorizontalFlip(p=0.5)(im)
  • Random Vertical and Horizontal Translation
    Example of Random Vertical and Horizontal Translation

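The image is randomly translated horizontally and vertically by up to 12.5% of its size. In PyTorch, this can be implemented by reflect-padding the image by 12.5% on each side and taking a random crop of the original size, where 32 is the target image size: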

import torchvision.transforms as transforms
from PIL import Image

im = Image.open('dog.png')
resized_im = transforms.Resize(32)(im)
translated = transforms.RandomCrop(size=32, 
                                   padding=int(32*0.125), 
                                   padding_mode='reflect')(resized_im)

b. Strong Augmentation
Strong augmentations produce heavily distorted versions of the input images. FixMatch applies either RandAugment or CTAugment, followed by Cutout.

1. Cutout
Example of Cutout Augmentation
This augmentation randomly removes a square part of the image and fills it with gray or black color. PyTorch doesn’t have a built-in Cutout transform, but we can reuse its RandomErasing transformation to get a Cutout effect.

import torch
import torchvision.transforms as transforms

# Image of 520*520
im = torch.rand(3, 520, 520)

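# Fill cutout with gray color.
# torch.rand gives values in [0, 1], so mid-gray is 0.5
# (use 127 instead for uint8 images with values in [0, 255])
gray_code = 0.5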

# ratio=(1, 1) to set aspect ratio of square
# p=1 means probability is 1, so always apply cutout
# scale=(0.01, 0.01) means we want to get cutout of 1% of image area
# Hence: Cuts out gray square of 52*52
cutout_im = transforms.RandomErasing(p=1, 
                                     ratio=(1, 1), 
                                     scale=(0.01, 0.01), 
                                     value=gray_code)(im)

2. AutoAugment Variants

Previous SSL work used AutoAugment, which trains a Reinforcement Learning agent to find the augmentations that lead to the best accuracy on a proxy task (e.g. CIFAR-10). This is problematic for SSL: learning the augmentation policy requires a labeled dataset, and the RL search is computationally expensive.

So, FixMatch uses one of two variants of AutoAugment:

a. RandAugment
The idea of Random Augmentation (RandAugment) is very simple.

  • First, you have a list of 14 possible augmentations, each with a range of possible magnitudes.
    Pool of Augmentations in RandAugment

  • You randomly select N augmentations from this list. Here, we select two of them (N = 2).
    Random Selection of Augmentations in RandAugment

  • Then you select a random magnitude M from 1 to 10. For example, M = 5 corresponds to 50% of the maximum strength, since the maximum possible M is 10 (5/10 = 50%).
    Random Magnitude Selection in RandAugment
  • The selected augmentations are then applied to the image in sequence, each with a 50% probability of being applied.
    Applying RandAugment to Images
  • In the original RandAugment, the values of N and M are found by a grid search on a validation set. In FixMatch, instead of a fixed magnitude, a random magnitude from a pre-defined range is sampled at each training step.
    Grid Search to Find Optimal Configuration in RandAugment
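The RandAugment variant described above can be sketched in plain PIL. This is a minimal illustration, not the paper’s implementation: the pool is a hypothetical 7-op subset of the 14 augmentations, and the mapping from magnitude to each op’s strength is an assumption. torchvision also provides `transforms.RandAugment`, which uses a fixed magnitude instead of one resampled at each step.

```python
import random
from PIL import Image, ImageEnhance, ImageOps

def sign():
    return random.choice([-1, 1])

# Hypothetical subset of the pool. Each op takes a PIL image and a strength
# v in (0, 1]; the strength-to-parameter mappings are illustrative assumptions.
AUG_POOL = {
    'autocontrast': lambda img, v: ImageOps.autocontrast(img),
    'equalize':     lambda img, v: ImageOps.equalize(img),
    'brightness':   lambda img, v: ImageEnhance.Brightness(img).enhance(1 + sign() * 0.9 * v),
    'contrast':     lambda img, v: ImageEnhance.Contrast(img).enhance(1 + sign() * 0.9 * v),
    'rotate':       lambda img, v: img.rotate(sign() * 30 * v),
    'posterize':    lambda img, v: ImageOps.posterize(img, 8 - int(4 * v)),
    'solarize':     lambda img, v: ImageOps.solarize(img, 256 - int(256 * v)),
}

def rand_augment(img, n=2, max_m=10):
    names = random.sample(list(AUG_POOL), n)  # pick N distinct augmentations
    m = random.randint(1, max_m)              # magnitude M, resampled on every call
    v = m / max_m                             # e.g. M = 5 -> 50% of maximum strength
    for name in names:
        if random.random() < 0.5:             # each op applied with probability 0.5
            img = AUG_POOL[name](img, v)
    return img

random.seed(0)
im = Image.new('RGB', (32, 32), color=(200, 120, 40))
strong_im = rand_augment(im)
```

In FixMatch’s strong augmentation, Cutout would then be applied to the output of `rand_augment`.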

b. CTAugment
CTAugment is an augmentation technique introduced in the ReMixMatch paper. It uses ideas from control theory to remove the need for Reinforcement Learning in AutoAugment. Here’s how it works:

  • We have a set of 18 possible transformations, similar to RandAugment.
  • The magnitude range of each transformation is divided into bins, and each bin is assigned a weight. Initially, all bins have a weight of 1.
  • Two transformations are selected uniformly at random from this set, and their sequence forms a pipeline. This is similar to RandAugment.
  • For each transformation, a magnitude bin is sampled with probability given by the normalized bin weights.
  • A labeled example is augmented with these two transformations and passed to the model to get a prediction.
  • The weights of the chosen magnitude bins are then updated based on how close the model’s prediction was to the true label.
  • Over time, CTAugment learns to favor augmentations for which the model still has a high chance of predicting the correct label, i.e. augmentations that fall within the network’s tolerance.

Unlike RandAugment, CTAugment learns the magnitude for each transformation dynamically during training. So, we don’t need to optimize it on a supervised proxy task, and it has no sensitive hyperparameters to tune. This makes it well suited to the semi-supervised setting, where labeled data is scarce.
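The bin-weight bookkeeping described above can be sketched in a few lines of numpy. This is a simplified, hypothetical sketch rather than the official code. The number of bins (10), the decay of 0.99 and the closeness score (1 for a perfect prediction, 0 when all probability sits on a wrong class) are assumptions modeled on the ReMixMatch description, and the published method includes further details that are omitted here. Applying the sampled transformations to the image is also left out.

```python
import numpy as np

class CTAugmentSketch:
    """Hypothetical, simplified CTAugment bookkeeping (not the official code)."""

    def __init__(self, op_names, n_bins=10, decay=0.99, seed=0):
        self.decay = decay
        self.rng = np.random.default_rng(seed)
        # One weight per magnitude bin per transformation; all bins start at 1.
        self.weights = {name: np.ones(n_bins) for name in op_names}

    def sample(self, n_ops=2):
        """Pick n_ops distinct transformations uniformly, then one magnitude bin each."""
        names = self.rng.choice(list(self.weights), size=n_ops, replace=False)
        policy = []
        for name in names:
            w = self.weights[str(name)]
            probs = w / w.sum()  # normalized bin weights
            policy.append((str(name), int(self.rng.choice(len(w), p=probs))))
        return policy

    def update(self, policy, probs, label):
        """Move the used bins' weights toward how well the model did on the augmented labeled image."""
        onehot = np.eye(len(probs))[label]
        # 1 for a perfect prediction, 0 when all mass is on a wrong class.
        closeness = 1.0 - 0.5 * np.abs(np.asarray(probs) - onehot).sum()
        for name, bin_idx in policy:
            w = self.weights[name]
            w[bin_idx] = self.decay * w[bin_idx] + (1 - self.decay) * closeness

cta = CTAugmentSketch(['rotate', 'shear_x', 'solarize', 'posterize'])
policy = cta.sample()
# ... augment a labeled image with `policy` and run the model (omitted) ...
cta.update(policy, probs=[0.1, 0.9], label=0)  # a confident, wrong prediction
```

Bins that keep producing wrong predictions slowly lose weight and are sampled less often, while bins the model handles well stay close to 1.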

2. Model Architecture

The paper uses a wider and shallower variant of ResNet called Wide Residual Networks (WideResNet) as the base architecture. The exact variant is WideResNet-28-2, with a depth of 28 and a widening factor of 2, i.e. its convolutional layers have twice as many channels as those of the corresponding ResNet. It has 1.5 million parameters in total. The network ends in an output layer with one node per class (e.g. 2 classes for cat/dog classification).
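A compact PyTorch sketch of WideResNet-28-2 is shown below. It follows the standard pre-activation Wide ResNet layout: depth 28 gives (28 - 4) / 6 = 4 blocks per group, and the three groups have 16k, 32k and 64k channels (32, 64 and 128 for k = 2). This is an illustration, not the authors’ code; implementations used with FixMatch may differ in details such as the activation function and batch-norm settings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Pre-activation residual block: BN-ReLU-Conv, BN-ReLU-Conv."""
    def __init__(self, in_ch, out_ch, stride):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        # 1x1 projection on the skip path when the shape changes.
        self.shortcut = None
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride, 0, bias=False)

    def forward(self, x):
        out = F.relu(self.bn1(x))
        skip = x if self.shortcut is None else self.shortcut(out)
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        return out + skip

class WideResNet(nn.Module):
    def __init__(self, depth=28, widen_factor=2, num_classes=10):
        super().__init__()
        assert (depth - 4) % 6 == 0
        n = (depth - 4) // 6  # residual blocks per group (4 for depth 28)
        widths = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        self.conv1 = nn.Conv2d(3, widths[0], 3, 1, 1, bias=False)
        layers, in_ch = [], widths[0]
        for group, stride in enumerate([1, 2, 2]):
            for j in range(n):
                layers.append(BasicBlock(in_ch, widths[group + 1], stride if j == 0 else 1))
                in_ch = widths[group + 1]
        self.blocks = nn.Sequential(*layers)
        self.bn = nn.BatchNorm2d(in_ch)
        self.fc = nn.Linear(in_ch, num_classes)  # output layer: one node per class

    def forward(self, x):
        out = F.relu(self.bn(self.blocks(self.conv1(x))))
        out = F.adaptive_avg_pool2d(out, 1).flatten(1)
        return self.fc(out)

model = WideResNet(depth=28, widen_factor=2, num_classes=10)
logits = model(torch.randn(2, 3, 32, 32))
n_params = sum(p.numel() for p in model.parameters())
```

For 10 classes, this sketch has roughly 1.47 million parameters, which is consistent with the 1.5 million quoted above.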

3. Model Training and Loss Function

  • Step 1: Preparing batches
    We prepare batches of labeled images of size B and batches of unlabeled images of size \(\color{#774cc3}{\mu} B\). Here \(\color{#774cc3}{\mu}\) is a hyperparameter that sets the ratio of unlabeled to labeled images in a batch. For example, \(\color{#774cc3}{\mu}=2\) means that we use twice as many unlabeled images as labeled images.
    Ratio of Labeled to Unlabeled Images
    The paper tried increasing values of \(\color{#774cc3}{\mu}\) and found that the error rate decreases as the number of unlabeled images increases. The paper uses \(\color{#774cc3}{\mu} = 7\) on the evaluation datasets.
    Impact of increasing unlabeled data on error rate

Source: FixMatch paper

  • Step 2: Supervised Learning
    For the supervised part of the pipeline, which is trained on labeled images, we use the regular cross-entropy loss H() for classification. As the formula below shows, the labeled images are also weakly augmented (\(\alpha\)). The supervised loss for a batch, \(l_s\), is the average cross-entropy loss over the images in the batch.
    Supervised Part of FixMatch
\[l_s = \frac{1}{B} \sum_{b=1}^{B} \color{#9A0007}{H(}\ \color{#7ead16}{p_{b}}, \color{#5CABFD}{p_{m}(}\ y\ | \color{#FF8A50}{\alpha(} \color{#8c5914}{x_b} \color{#FF8A50}{)}\ \color{#5CABFD}{)} \color{#9A0007}{)}\]
  • Step 3: Pseudolabeling
    For the unlabeled images, we first apply weak augmentation to each image and get the model’s predicted class distribution \(q_b\). Applying argmax to it gives the most likely class, which is the pseudo-label that will be compared with the model’s output on the strongly augmented image.
    Generating Pseudolabels in FixMatch
\[\color{#5CABFD}{q_b} = p_m(y | \color{#8C5914}{\alpha(} \color{#007C91}{u_b} \color{#8C5914}{)} )\] \[\color{#866694}{\hat{q_b}} = \color{#48A999}{\operatorname{argmax}(}\color{#5CABFD}{q_b} \color{#48A999}{)}\]
  • Step 4: Consistency Regularization
    Now, the same unlabeled image is strongly augmented (\(A\)), and the model’s output on it is compared with the pseudo-label using the cross-entropy loss H(). The total loss on an unlabeled batch is denoted by \(l_u\) and given by:
    Consistency Regularization in FixMatch
\[l_u = \frac{1}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}(\max(q_b) \geq \color{#d11e77}{\tau})\ \color{#9A0007}{H(} \color{#866694}{\hat{q_b}}, p_m(y | \color{#25561F}{A(} \color{#007C91}{u_b} \color{#25561F}{)} ) \color{#9A0007}{)}\]

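Here \(\color{#d11e77}{\tau}\) denotes the confidence threshold: an unlabeled image contributes to \(l_u\) only if the highest class probability on its weakly augmented view is at least \(\color{#d11e77}{\tau}\). This loss is similar to the classic pseudo-labeling loss. The difference is that the pseudo-label is generated from the weakly augmented view, while the loss is computed on the strongly augmented view.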

  • Step 5: Combining the Losses
    We finally combine these two losses to get the total loss that we optimize to improve our model. \(\lambda_u\) is a fixed scalar hyperparameter that decides how much the unlabeled loss contributes relative to the labeled loss.
\[loss = l_s + \lambda_u l_u\]
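Steps 1 to 5 can be put together in a short PyTorch sketch. It assumes the weakly and strongly augmented batches are already available as tensors (random tensors and a toy linear model stand in for them in the usage part). The defaults \(\tau = 0.95\) and \(\lambda_u = 1\) are the values reported in the FixMatch paper, and `fixmatch_loss` is a hypothetical helper, not an API from the official code. Implementations often concatenate the labeled and unlabeled batches into a single forward pass; here the passes are kept separate for readability.

```python
import torch
import torch.nn.functional as F

def fixmatch_loss(model, x_lab, y_lab, u_weak, u_strong, tau=0.95, lambda_u=1.0):
    """Sketch of the FixMatch loss for one batch.

    x_lab:    weakly augmented labeled images, shape (B, C, H, W)
    y_lab:    integer labels, shape (B,)
    u_weak:   weakly augmented unlabeled images, shape (mu*B, C, H, W)
    u_strong: strongly augmented views of the same unlabeled images
    """
    # Step 2: supervised cross-entropy, averaged over the B labeled images.
    l_s = F.cross_entropy(model(x_lab), y_lab)

    # Step 3: pseudo-labels from the weak view; no gradient flows through them.
    with torch.no_grad():
        q = torch.softmax(model(u_weak), dim=1)
        max_prob, pseudo_label = q.max(dim=1)
        mask = (max_prob >= tau).float()  # the indicator 1(max(q_b) >= tau)

    # Step 4: cross-entropy of strong-view predictions against the pseudo-labels,
    # divided by mu*B (masked-out images contribute 0 to the sum).
    per_image = F.cross_entropy(model(u_strong), pseudo_label, reduction='none')
    l_u = (per_image * mask).mean()

    # Step 5: total loss.
    return l_s + lambda_u * l_u, l_s, l_u, mask.mean()

# Step 1: batches of size B (labeled) and mu*B (unlabeled).
torch.manual_seed(0)
B, mu, num_classes = 4, 7, 10
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, num_classes))
x_lab = torch.randn(B, 3, 32, 32)
y_lab = torch.randint(0, num_classes, (B,))
u_weak = torch.randn(mu * B, 3, 32, 32)
u_strong = u_weak + 0.5 * torch.randn_like(u_weak)  # stand-in for strong augmentation
loss, l_s, l_u, frac_used = fixmatch_loss(model, x_lab, y_lab, u_weak, u_strong)
loss.backward()
```

The returned `frac_used` is the fraction of unlabeled images that passed the threshold, which is a handy quantity to log during training.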

Free Curriculum Learning in FixMatch

An interesting result comes from the confidence threshold. Previous works have shown that gradually increasing the weight \(\lambda_u\) of the unlabeled loss during training is beneficial. In FixMatch, this happens automatically. Initially, the model is not yet confident, so its predictions on unlabeled data fall below the threshold, and the model is effectively trained only on labeled data. As training progresses, the model becomes more confident, and its predictions on unlabeled data start to cross the threshold. The loss then starts incorporating the unlabeled images as well. This gives us a free form of curriculum learning.
Intuitively, this is similar to how we’re taught in childhood. In early years, we learn easy concepts such as the alphabet and what each letter represents before moving on to complex topics like word formation, sentence formation, and then essays.
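The effect of the threshold can be seen in a toy example. The logits below are hypothetical values for four unlabeled images over two classes, and multiplying them by `scale` stands in for a model that grows more confident during training. As `scale` increases, more images pass \(\tau = 0.95\) and start contributing to \(l_u\).

```python
import torch

# Hypothetical fixed logits for 4 unlabeled images over 2 classes.
logits = torch.tensor([[0.5, 0.0], [1.0, 0.0], [2.0, 0.0], [4.0, 0.0]])
tau = 0.95

def fraction_above_threshold(logits, scale, tau=tau):
    # Scaling the logits sharpens the softmax, mimicking growing confidence.
    probs = torch.softmax(scale * logits, dim=1)
    return (probs.max(dim=1).values >= tau).float().mean().item()

fractions = [fraction_above_threshold(logits, s) for s in (1, 2, 4)]
print(fractions)  # [0.25, 0.5, 0.75]
```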

Paper Insights

1. Can we learn with just one image per class?

The authors performed a really interesting experiment on the CIFAR-10 dataset. They trained a model on CIFAR-10 using only 10 labeled images i.e. 1 labeled example of each class.

  • They created 4 datasets by randomly selecting 1 example per class and trained on each dataset 4 times. The test accuracy ranged from 48.58% to 85.32%, with a median of 64.28%. The authors attributed this variability to the quality of the labeled examples: it is difficult for a model to learn each class effectively from low-quality examples.
    Learning with just 1 image per class
  • To test this, they created 8 training datasets with examples ranging from the most representative to the least representative. They followed the ordering from this paper and divided it into 8 buckets: the first bucket contains the most representative images, while the last bucket contains outliers. They then randomly took one example of each class from each bucket to create 8 labeled training sets and trained FixMatch on each. The results were:
    • Most representative bucket: 78% median accuracy with a maximum accuracy of 84%
    • Middle bucket: 65% accuracy
    • Outlier bucket: fails to converge, with only 10% accuracy (chance level for 10 classes)

Evaluation and Results

The authors ran evaluations on datasets commonly used for SSL such as CIFAR-10, CIFAR-100, SVHN, STL-10, and ImageNet.

  • CIFAR-10 and SVHN:
    FixMatch achieves state-of-the-art results on the CIFAR-10 and SVHN benchmarks. The authors use 5 different folds for each dataset.
    FixMatch SOTA on CIFAR-10 and SVHN

  • CIFAR-100:
    On CIFAR-100, ReMixMatch is slightly better than FixMatch. To understand why, the authors added various components of ReMixMatch to FixMatch and measured their impact on performance. They found that the cause was the Distribution Alignment (DA) component, which encourages the model to predict all classes with equal probability. When they combined FixMatch with DA, they achieved a 40.14% error rate, compared to 44.28% for ReMixMatch.
    ReMixMatch is better than FixMatch on CIFAR-100

  • STL-10:
    STL-10 consists of 100,000 unlabeled images and 5,000 labeled images. We need to predict 10 classes (airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck). It is a more representative evaluation for semi-supervised learning because its unlabeled set contains out-of-distribution images.
    Evaluated on 5 folds of 1,000 labeled images each, FixMatch with CTAugment achieves the lowest error rate among all methods.
    FixMatch gets SOTA on STL-10 dataset

  • ImageNet:
    The authors also evaluate the model on ImageNet to verify that it works on large and complex datasets. They take 10% of the training data as labeled images and use the remaining 90% as unlabeled images. The architecture is ResNet-50 instead of WideResNet, and RandAugment is used as the strong augmentation. They achieve a top-1 error rate of \(28.54\pm0.52\%\), which is \(2.68\%\) better than UDA. The top-5 error rate is \(10.87\pm0.28\%\).

Code Implementation

The official TensorFlow implementation of FixMatch by the paper authors is available here. Unofficial PyTorch implementations of FixMatch are available on GitHub (first, second, third). They use RandAugment and are evaluated on CIFAR-10 and CIFAR-100.

The paper is available here: FixMatch on Arxiv.

Citation Info (BibTex)

If you found this blog post useful, please consider citing it as:

@misc{chaudhary2020fixmatch,
  title   = {The Illustrated FixMatch for Semi-Supervised Learning},
  author  = {Amit Chaudhary},
  year    = 2020,
  note    = {\url{https://amitness.com/2020/03/illustrated-fixmatch}}
}
