How to Train Your MAML

November 24, 2018 - 2 minute read -
Tags: Paper Summary, Machine Learning, Meta Learning

This paper presents empirical problems with MAML and proposes fixes to address them. Overall, I did not find the solutions especially novel, but the proposed fixes are simple and seem effective. When I was training MetaGAN, I ran into these problems a lot. A good read for anyone wanting to apply MAML.

Problems

Training Instability from Exploding Gradients

MAML works by minimizing the target loss, computed after completing all of the inner-loop steps for a task, with respect to the meta prior. Optimizing this outer loop means backpropagating gradients through the unrolled inner loop, which causes exploding gradients since the gradients are multiplied by the same sets of parameters multiple times.
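A toy illustration of this (my own sketch, not from the paper): the meta-gradient accumulates one Jacobian factor per inner step, so any factor with magnitude greater than 1 compounds exponentially with the number of unrolled steps. The factor value here is hypothetical.

```python
jacobian_factor = 1.5  # hypothetical magnitude of a single per-step Jacobian

def meta_grad_scale(num_steps):
    # After num_steps inner updates, the outer gradient is scaled by the
    # product of all per-step Jacobians -- exponential growth in num_steps.
    return jacobian_factor ** num_steps

scales = [meta_grad_scale(k) for k in (1, 5, 10)]  # grows quickly with k
```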

I personally ran into this exploding gradient problem a lot in MetaGAN, and was unable to find a "true" fix aside from playing with the learning rate and the number of inner steps.

The paper addresses this by computing the target loss after every step of the inner loop, instead of only once after the full unrolled loop. The final meta objective is then a weighted sum of these per-step losses, so gradient signal reaches the meta prior from every step rather than only through the full unrolled chain.
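Here is a minimal sketch of that multi-step loss idea on a 1-D toy task whose gradient we can write by hand. The function name and weighting scheme (later steps weighted more, weights summing to 1) are my assumptions, not the authors' implementation.

```python
import numpy as np

# Toy task: loss(theta) = 0.5 * (theta - target)**2, so grad = theta - target.
def inner_loop_msl(theta0, target, inner_lr=0.1, num_steps=5):
    theta = theta0
    per_step_losses = []
    for _ in range(num_steps):
        grad = theta - target              # analytic gradient of the toy loss
        theta = theta - inner_lr * grad    # one inner-loop SGD step
        # Record the target loss after *every* step, not just the last one.
        per_step_losses.append(0.5 * (theta - target) ** 2)
    # Weighted sum of per-step losses (assumed weighting: later steps count more).
    weights = np.arange(1, num_steps + 1, dtype=float)
    weights /= weights.sum()
    return float(np.dot(weights, per_step_losses))

loss = inner_loop_msl(theta0=2.0, target=0.0)
```

In a real implementation each per-step loss would be backpropagated to the meta prior, so every inner step contributes a shorter, better-conditioned gradient path.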

Second Order Derivative Cost

First-order approximations are more efficient, and suffer less from exploding gradients than second-order derivatives; however, they are less accurate. The authors propose using first-order derivatives for the first few epochs of training and then switching to second-order derivatives to fine-tune the meta prior.
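The difference between the two orders can be seen exactly on a 1-D quadratic task (my own sketch): the second-order meta-gradient carries the Jacobian of the inner update, which the first-order approximation simply drops.

```python
def meta_grads(theta, target, inner_lr=0.1):
    """Toy task: loss = 0.5*(theta - target)**2, grad = theta - target."""
    theta_prime = theta - inner_lr * (theta - target)  # one inner SGD step
    outer_grad = theta_prime - target                  # dL/d(theta')
    first_order = outer_grad                           # FOMAML: treat theta' as if it were theta
    # Exact (second-order): chain rule through the inner step,
    # d(theta')/d(theta) = 1 - inner_lr for this quadratic loss.
    second_order = outer_grad * (1.0 - inner_lr)
    return first_order, second_order

fo, so = meta_grads(theta=2.0, target=0.0)
```

Switching between the two during training amounts to toggling whether this Jacobian term is kept (in PyTorch terms, whether the inner-loop gradients are computed with `create_graph=True`).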

How to use Batch Normalization effectively with MAML

The original implementation of MAML did not use running statistics for batch normalization; it used current-batch statistics instead, which reduces the effectiveness of BN. However, simply switching to running statistics does not work with MAML, because MAML learns only a single set of weights and biases for each BN layer. Learning a single set of biases assumes that the feature distributions do not shift much per inner-loop step, but this is false: fast adaptation drastically changes the model at every step, so the feature distributions diverge quickly during the inner-loop update. To address this, the authors introduce per-step biases, which expand the number of bias vectors learned for a BN layer from 1 to K, where K is the number of inner-loop steps. Biases are now indexed by their inner-update step instead of a single bias being shared across the entire inner loop.
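A minimal sketch of a per-step bias table for a BN layer, assuming current-batch statistics for normalization (class and parameter names are hypothetical, not the authors' code):

```python
import numpy as np

class PerStepBNBias:
    """BN-style layer with one learned bias vector per inner-loop step."""

    def __init__(self, num_features, num_steps):
        # K bias vectors instead of one: shape (num_steps, num_features).
        self.biases = np.zeros((num_steps, num_features))

    def __call__(self, x, step):
        # Normalize with current-batch statistics, then add the bias
        # belonging to this particular inner-loop step.
        mean = x.mean(axis=0)
        std = x.std(axis=0) + 1e-5
        return (x - mean) / std + self.biases[step]

bn = PerStepBNBias(num_features=3, num_steps=5)
x = np.arange(12.0).reshape(4, 3)
out0 = bn(x, step=0)   # step-0 bias is all zeros, so output mean is ~0
```

Each of the K bias vectors would be meta-learned in the outer loop, letting the layer compensate for the distribution shift at its own step.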

Fixed Inner Loop Learning Rate

To address this, they learn a separate learning rate, including its sign (so the update direction can flip), for each layer in the network, and they learn it per inner-loop step as well.
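The resulting parameterization can be sketched as a small table of learning rates indexed by (step, layer); in the real method these would be updated by the outer loop along with the meta prior, and the names and initialization here are my assumptions:

```python
import numpy as np

num_steps, num_layers = 5, 4
# One learnable scalar per (inner step, layer), initialized to a shared value.
# During meta-training these can become negative, which is how the update
# direction is learned.
inner_lrs = np.full((num_steps, num_layers), 0.01)

def inner_update(params, grads, step):
    # Each layer uses its own learning rate for this particular inner step.
    return [p - inner_lrs[step, l] * g
            for l, (p, g) in enumerate(zip(params, grads))]

params = [1.0, 1.0, 1.0, 1.0]
grads = [1.0, 1.0, 1.0, 1.0]
new_params = inner_update(params, grads, step=0)
```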