Stanford DAWN


YellowFin: An automatic tuner for momentum SGD

SGD with momentum

Hand-tuned momentum SGD is competitive with state-of-the-art adaptive methods, like Adam. We introduce YellowFin, an automatic tuner for the hyperparameters of momentum SGD. YellowFin can train models such as large LSTMs and certain ResNets in fewer iterations than the state of the art. It performs even better in asynchronous settings via an on-the-fly momentum adaptation scheme that uses a novel momentum measurement component along with a negative-feedback loop mechanism.

Summary of results
Comparing YellowFin to Adam on training a ResNet on CIFAR100 (left) synchronously; (right) asynchronously, using 16 workers.

Intro

Hyperparameter tuning is one of the most painful parts of any deep learning pipeline. The literature covers a wide spectrum of amazing results that seem to be powered by black magic: they work because we spent an obscene amount of time exploring the hyperparameter space. Finding a working configuration can be a very frustrating affair.

Too much tuning.

Methods like Adam and RMSProp tune learning rates for individual variables and make life easier.

Our experimental results show that, on ResNets and LSTMs, those adaptive methods may not perform better than carefully tuned, good ol’ momentum SGD.

This understanding is supported by recent theoretical results, which also suggest adaptive methods can suffer from bad generalization. The hypothesis there is that variable-level adaptation can lead to completely different minima. Here we point out another important, overlooked factor: momentum.

Momentum tuning is critical for efficiently training deep learning models.

Classic convex results and recent papers study momentum and emphasize its importance. Asynchronous dynamics are another reason to carefully tune momentum. Our recent paper shows that training asynchronously introduces momentum-like dynamics in the gradient descent update. Those added dynamics make momentum tuning even more important. Sometimes even negative momentum values can be optimal!

Despite these good reasons, the state of the art does not automatically tune momentum!

The majority of deep learning literature sticks to the standard momentum 0.9, leaving significant performance improvements on the table. It is no accident that the most successful GAN papers hand-tune the momentum parameter to a small positive or zero value.

YellowFin’s momentum

We revisit SGD with Polyak’s momentum, study some of its robustness properties and extract the design principles for a tuner, YellowFin. YellowFin automatically tunes a single learning rate and momentum value for SGD. The end result is fast!

Weeeeeee

YellowFin is momentum SGD. This means that if you already have a good hyperparameter schedule for momentum SGD, carefully-tuned on your favorite model, you do not need our tuner. YellowFin is geared towards workflows that involve fast iterations over changing models. The rest of the post gives a high level review of results in our paper:

  • Robustness: Momentum is robust to curvature variation. Empirically, this can yield a constant rate of convergence for some non-convex objectives. This observation powers YellowFin’s momentum tuning.
  • On-the-fly asynchrony compensation: We can measure the total momentum in a system, including any asynchrony-induced momentum. We use a simple negative-feedback loop to control the algorithmic momentum value.
  • Flying fish: Experimental results in synchronous and asynchronous settings show that YellowFin often requires fewer iterations to train models such as ResNets and large LSTMs compared to Adam. People have reported great results on NLP tasks!

To try our tuner, get your fins on here for TensorFlow and here for PyTorch.

Robustness properties of momentum

At the heart of YellowFin, there is a very simple technical nugget; we describe here the basic idea. Let us focus, for a moment, on quadratic objectives. Classic results by Polyak and Nesterov prescribe the optimal value of momentum as a function of the dynamic range of curvatures: i.e. the condition number, κ.
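For concreteness, Polyak’s prescription can be written down in a few lines. The closed forms below are the classic ones; the function name and the (h_min, h_max) parameterization of the curvature range are our own choices:

```python
import math

def polyak_hyperparams(h_min, h_max):
    """Classic optimal momentum and learning rate for a quadratic whose
    curvatures (Hessian eigenvalues) lie in [h_min, h_max]."""
    kappa = h_max / h_min                                       # condition number
    mu = ((math.sqrt(kappa) - 1) / (math.sqrt(kappa) + 1)) ** 2
    alpha = (1 + math.sqrt(mu)) ** 2 / h_max                    # = (1 - sqrt(mu))**2 / h_min
    return alpha, mu
```

Note that a well-conditioned problem (κ = 1) yields zero momentum, and μ approaches 1 as κ grows.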

μ = ((√κ − 1) / (√κ + 1))²

Assuming the learning rate is also selected wisely, the above momentum value ensures that all variables converge to their final values at the same linear rate.

ρ = √μ

Specifically, this comes from the fact that the 2x2 linear operator describing momentum dynamics along any scalar slice of the objective has the exact same spectral radius, ρ, given above. We call the family of hyperparameter configurations that satisfy this condition the robust region.
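This spectral-radius fact is easy to check numerically. The sketch below is our own illustration: it builds the 2x2 operator for a scalar quadratic with curvature h; whenever (1 − √μ)² ≤ αh ≤ (1 + √μ)² (the robust region for that slice), the eigenvalues form a complex pair of magnitude exactly √μ:

```python
import numpy as np

def momentum_operator_radius(h, alpha, mu):
    """Spectral radius of the 2x2 linear operator governing momentum
    dynamics on a scalar quadratic with curvature h."""
    A = np.array([[1 + mu - alpha * h, -mu],
                  [1.0, 0.0]])
    return max(abs(np.linalg.eigvals(A)))
```

For example, with α = 1 and μ = 0.5, a curvature h = 1 lies inside the robust region and the radius is √0.5; a curvature h = 5 lies outside it and the radius is larger.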

We give a simple generalization of this result for certain non-convex functions. In particular, we can define a generalized condition number that captures curvature variations along a scalar function. For example, the function shown in the following figure (left) is composed of two different quadratics with curvatures c and 1000c, and thus has generalized condition number 1000. Our analysis implies that tuning momentum according to that generalized condition number ensures a constant spectral radius for the momentum operator on all parts of the function.

While this result on the spectral radius does not necessarily imply a convergence guarantee for non-quadratic objectives, we observe this behavior in practice on the next figure (right).

Constant rate achieved for non-convex scalar objective
Constant convergence rate on a toy non-convex objective, shown on the left. The right plot shows the progress of the momentum algorithm in black, and the √μ rate expected from theory in red.
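The behavior in the figure can be reproduced in a few lines. This is our own toy reconstruction, not the paper’s exact setup: a scalar slice stitched from quadratics with curvatures c and 1000c, with momentum set slightly above the generalized-condition-number bound so that both pieces sit strictly inside the robust region:

```python
import math

c = 1e-3

def grad(x):
    # scalar slice stitched from two quadratics: curvature 1000*c for x < 0,
    # curvature c for x >= 0 (generalized condition number 1000)
    return (1000.0 * c if x < 0 else c) * x

mu = 0.9          # slightly above the bound ((sqrt(1000)-1)/(sqrt(1000)+1))**2 ~ 0.88
alpha = 0.003 / c  # alpha*h is in [(1-sqrt(mu))**2, (1+sqrt(mu))**2] for both curvatures

x, x_prev = 1.0, 1.0
for _ in range(1000):
    x, x_prev = x - alpha * grad(x) + mu * (x - x_prev), x
# |x| shrinks toward the minimum at 0, at roughly the sqrt(mu) ~ 0.95 rate per step
```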

We validate this on real models, like the LSTM in the following figure. We observe that for large values of momentum, most variables (grey lines) follow the √μ convergence rate (red line) from our quadratic model.

Constant rate achieved training LSTM
Constant convergence rate when training a real model (LSTM).

This observation informs YellowFin’s design principles.

YellowFin’s design and results

Design principle 1: Stay in the robust region.

We tune the momentum value to keep all variables in the robust region. On a quadratic approximation, this guarantees convergence of all model variables at a common rate, though it empirically extends to certain non-convex objectives.

To tune momentum, YellowFin keeps a running estimate of curvatures along the way, which yields an estimate of the generalized condition number. This estimate doesn’t need to be accurate. We see in practice that rough measurements from noisy gradients give good results. This design principle gives a lower bound on the value of momentum.
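As an illustration of this principle, the sketch below keeps running extremes of a rough curvature proxy and derives the momentum lower bound from them. The class name, window size, smoothing constant, and the choice of squared gradient norm as the curvature proxy are all our own illustrative choices, not the paper’s exact estimator:

```python
import math
from collections import deque

class CurvatureRange:
    """Track rough h_min / h_max estimates over a sliding window and derive
    the robust-region lower bound on momentum (illustrative sketch)."""
    def __init__(self, window=20, beta=0.9):
        self.h_hist = deque(maxlen=window)
        self.beta = beta
        self.h_min = self.h_max = None

    def update(self, grad_sq_norm):
        # crude curvature proxy; rough measurements from noisy gradients suffice
        self.h_hist.append(grad_sq_norm)
        lo, hi = min(self.h_hist), max(self.h_hist)
        # smooth the extremes so single noisy steps don't dominate
        self.h_min = lo if self.h_min is None else self.beta * self.h_min + (1 - self.beta) * lo
        self.h_max = hi if self.h_max is None else self.beta * self.h_max + (1 - self.beta) * hi

    def momentum_lower_bound(self):
        ratio = self.h_max / self.h_min  # generalized condition number estimate
        return ((math.sqrt(ratio) - 1) / (math.sqrt(ratio) + 1)) ** 2
```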

Design principle 2: Optimize hyperparameters at each step to minimize a local quadratic approximation.

Given the constraint from principle 1, we simply tune the learning rate and momentum to minimize the expected squared distance to the minimum of a local quadratic approximation. Please refer to our paper for full implementation details.
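To make the idea concrete: with rough estimates of the curvature range [h_min, h_max], the gradient variance C, and the distance to the minimum D, the per-step problem is to minimize the expected squared distance μD² + α²C subject to the robust-region constraint. The paper solves this in closed form; the sketch below discretizes the one-dimensional problem instead, purely for illustration:

```python
import math

def tune_step(h_min, h_max, C, D, grid=1000):
    """Pick (alpha, mu) minimizing mu*D**2 + alpha**2*C subject to the
    robust-region constraint (discretized sketch, not the closed form)."""
    ratio = h_max / h_min
    mu_lb = ((math.sqrt(ratio) - 1) / (math.sqrt(ratio) + 1)) ** 2
    best = None
    for i in range(grid):
        mu = mu_lb + (1.0 - mu_lb) * i / grid
        alpha = (1 - math.sqrt(mu)) ** 2 / h_min  # keep h_min on the robust-region boundary
        obj = mu * D ** 2 + alpha ** 2 * C
        if best is None or obj < best[0]:
            best = (obj, alpha, mu)
    return best[1], best[2]
```

With no gradient noise (C = 0), the tuner picks the smallest feasible momentum; as noise grows, it trades a smaller learning rate for a larger momentum.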

Results

Our experiments show that YellowFin, without any hand-tuning, often needs fewer iterations to train ResNets and LSTMs than either Adam with a hand-tuned base learning rate or hand-tuned momentum SGD with constant hyperparameters or simple schedules.

Tuning results on two different ResNet models
Training loss for tuned momentum SGD, tuned Adam, and YellowFin on (left) 110-layer ResNet for CIFAR10 and (right) 164-layer ResNet for CIFAR100.
Tuning results on three different models
LSTM test metrics (reporting best value so far) for tuned momentum SGD, tuned Adam, and YellowFin on (left) word-level language modeling; (middle) character-level language modeling; (right) constituency parsing.

Asynchronous dynamics and closed-loop YellowFin

Our recent work suggests that asynchrony induces momentum. This result means that when we run asynchronously, the total momentum present in the system is strictly more than the algorithmic momentum value we feed into the optimizer, because it includes added asynchrony-induced momentum.

In our new paper we demonstrate for the first time that it is possible to measure total momentum. The next figure shows that our measurement exactly matches the algorithmic value when training synchronously (left). On asynchronous systems, however, the measured total momentum is strictly more than the algorithmic value (right).
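One simple way to see that total momentum is measurable: fit the momentum coefficient of the update recurrence to an observed trajectory by least squares. The estimator below is our own illustration on a scalar variable; the paper defines the exact measurement used by YellowFin:

```python
def estimate_total_momentum(xs, grads, alpha):
    """Fit mu_hat in  x_{t+1} - x_t = -alpha * g_t + mu * (x_t - x_{t-1})
    by least squares over an observed scalar trajectory (illustration)."""
    num = den = 0.0
    for t in range(1, len(xs) - 1):
        resid = (xs[t + 1] - xs[t]) + alpha * grads[t]  # what momentum must explain
        prev_step = xs[t] - xs[t - 1]
        num += resid * prev_step
        den += prev_step * prev_step
    return num / den
```

On a trajectory actually generated by the momentum recurrence, the fit recovers the true momentum value; on an asynchronous trajectory it would pick up the extra, asynchrony-induced momentum as well.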

Measuring total momentum and closing the loop
Measured total momentum (left) matches the algorithmic value in the synchronous case; (right) is higher than the algorithmic value, when using 16 asynchronous workers.

This momentum excess can be bad for statistical efficiency. Our ability to measure total momentum allows us to compensate for asynchrony on the fly. Specifically, we use a negative feedback loop to make sure the measured total momentum tracks the target momentum decided by YellowFin.
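A minimal sketch of such a loop follows; the gain and the clipping range are illustrative choices of ours, not the paper’s exact update rule:

```python
def close_the_loop(mu_alg, mu_measured, mu_target, gain=0.1):
    """One step of the negative-feedback loop: if total (measured) momentum
    overshoots the target, lower the algorithmic momentum, and vice versa."""
    mu_alg -= gain * (mu_measured - mu_target)
    # asynchrony can warrant even negative algorithmic momentum
    return max(-1.0, min(mu_alg, 1.0))
```

Iterating this update drives the measured total momentum toward the target value chosen by YellowFin, which is exactly the behavior shown in the next figure.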

Measuring total momentum and closing the loop
Closing the momentum loop on 16 asynchronous workers: the negative feedback loop uses the total momentum measurement to reduce the algorithm momentum value. The end result is that total momentum closely follows the target value.

Results

Closing the momentum loop results in less algorithmic momentum, sometimes even negative! Here we see that this adaptation is very beneficial in an asynchronous setting. (Open-loop) YellowFin already performs better than Adam, mostly due to its ability to reach lower losses. However, when we close the loop, the result is about 2x faster (and almost 3x faster to reach Adam’s lowest losses).

Tuning results on three different models
Asynchronous training with 16 workers: closed-loop YellowFin is about 2x faster than open-loop YellowFin, and reaches Adam’s lowest losses almost 3x faster.

Conclusion

YellowFin is an automatic tuner for momentum SGD that is competitive with state-of-the-art adaptive methods that use individual learning rates for each variable.

In asynchronous settings, it uses a novel closed-loop design that significantly reduces the number of iterations needed to reach a given loss.


What’s next

  • YellowFin assumes some measurement oracles for quantities like curvature variation and gradient variance to tune its hyperparameters. For our first implementation we used some simple approximations based solely on gradients that give us good results. It would be interesting to study different implementations for those oracles, e.g. using a backprop-based estimate for the Hessian diagonal, or smart line-searches to estimate the distance from a local minimum.
  • Computational optimizations: so far we focused on showing that it is possible to engineer a smart momentum tuner for plain momentum SGD that is competitive with the state of the art in terms of statistical efficiency (the number of iterations to solution). Next we will be working on minimizing the computational overhead of our tuner.
  • In our paper we show that hand-tuning Adam’s momentum can improve its asynchronous performance. We believe that designing and implementing our momentum measurement and negative feedback loop for other methods can make them better under asynchrony.
