An argument in favor of strong scaling for deep neural networks with small datasets

In this post, I'm writing about my most recent paper, which is joint work with Eduardo R. Rodrigues, Matheus Palhares Viana, and Dario Augusto Borges Oliveira. If you prefer videos, there's a pretty comprehensive version below. You can read the full paper on the arXiv.

Apart from all the hype that is certainly involved in Machine Learning, Deep Learning, and Artificial Intelligence these days, there is a lot of very good research being done. Since much of this research comes from big cloud players such as Amazon, Google, IBM, and Microsoft, it is natural for such research to favor big clusters with humongous datasets. This is also good marketing: we all want to have access to the biggest machines possible, and it is good business for a cloud provider to know how to handle big machine learning workloads.

Anyway, Deep Learning itself is data hungry, so it only makes sense to apply it to big datasets, right? Otherwise, the models would tend to overfit, given their large numbers of parameters, right? Right?! 🤔

Well, not quite! 😅

It turns out we can use smaller datasets to train models that augment such small datasets. The key idea here is that we can train a generative network to learn to approximate the distribution that generated the data we're seeing, but I digress.

The thing is: there's space for models that handle small data, although we have to be a bit more clever about how we handle them. Once we realize that, the question now becomes: "how do we scale such models?" One might argue that models with small data don't need to be scaled, since they are fast.

That's not the case, though. Our colleagues have developed a model capable of learning the distribution of lung nodules in the LUNA dataset, but it takes a whole week to train. That's a bit too much, since tuning hyperparameters takes forever, and we might end up with a poorly tuned model due to this limitation.

HPC to the rescue!

I hope I've convinced you we have a problem. Let's solve it! But first, we have to make sure we understand how these models are trained.

Minibatch Stochastic Gradient Descent

Most Deep Learning models are trained by backpropagation and gradient descent. In this setting, we sample some data points from the dataset and evaluate the loss of the model, with its current weights (the model's parameters), on the points we sampled. Then, we compute the gradient of each individual loss and take a step in the direction opposite to the mean of those gradients.

Slightly more formally, we:

  • Update weights $\color{blue}{w_{t}}$ *iteratively* by
  • Sampling a mini-batch $\color{red}{\mathcal{B}}$ of entries $x_i$ from the dataset
  • Taking a small step $\propto\color{gray}\eta$ (the learning rate) against
  • The mean of the gradient $\nabla$ of the loss $l$ evaluated at each $x_i$ and $\color{blue}{w_t}$

In other words, we repeat the update below until convergence or our time runs out.

$${\color{blue}w_{t+1}} = {\color{blue}w_{t}} - {\color{gray}\eta}\frac{1}{|{\color{red}\mathcal{B}}|}\sum_{x\in{\color{red}\mathcal{B}}}\nabla l(x, {\color{blue}w_t})$$

From the above description, we see that $\nabla l(x_i, w_t)$ can be computed independently for each $x_i$ and, therefore, can be run in parallel.
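To make the update concrete, here is a minimal sketch of minibatch SGD in plain Python. It is not the code from the paper: the one-parameter linear model, the squared-error loss, the toy dataset, and the names `grad_loss`, `sgd_step`, and `train` are all assumptions made for illustration.

```python
import random


def grad_loss(x, w):
    """Gradient w.r.t. w of the assumed loss l(x, w) = 0.5 * (w * a - b)**2, with x = (a, b)."""
    a, b = x
    return (w * a - b) * a


def sgd_step(w, batch, eta):
    """One minibatch update: w minus eta times the mean per-example gradient."""
    # Each gradient depends only on its own example and on w, so the
    # entries of this list could be computed in parallel.
    grads = [grad_loss(x, w) for x in batch]
    return w - eta * sum(grads) / len(batch)


def train(dataset, w0=0.0, eta=0.1, batch_size=4, steps=200, seed=0):
    """Repeat the update for a fixed number of steps (our 'time runs out')."""
    rng = random.Random(seed)
    w = w0
    for _ in range(steps):
        batch = rng.sample(dataset, batch_size)  # sample the minibatch B
        w = sgd_step(w, batch, eta)
    return w


# Toy data generated by b = 3 * a, so the optimal weight is w = 3.
dataset = [(a / 10.0, 3.0 * a / 10.0) for a in range(1, 21)]
w_final = train(dataset)
print(w_final)
```

The list comprehension inside `sgd_step` is where the parallelism lives: no gradient in it needs any other gradient's result.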

Parallelizing training

There are at least two general frameworks able to parallelize the training of neural networks: IBM DDL and Horovod. Individual frameworks such as TensorFlow and PyTorch also have their own parallelization strategies.

We won't go into the specifics of these frameworks, since our point is that how you define your minibatch can greatly affect your model's ability to learn!

For example, if you open one of the Horovod examples, you will notice two things:

  1. It scales the learning rate $\eta$ according to the number of GPUs.
  2. It keeps the batch size per GPU constant, which increases the effective batch size proportionally to the number of GPUs.

This seems to be a good idea, since the work each processor does stays constant, and it appears to achieve good training performance as well.

BUT this assumes the size of your dataset is much greater than your effective batch size. If you have a small dataset, that might not be the case, and that's when this scaling strategy seems to break down.
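A bit of arithmetic shows how quickly that assumption can stop holding. The sketch below applies the scaling described above (learning rate and effective batch size both multiplied by the number of GPUs) to made-up numbers: the dataset size, base learning rate, and per-GPU batch size are hypothetical and not taken from our experiments, and `weak_scaling` is just an illustrative helper.

```python
def weak_scaling(base_lr, per_gpu_batch, n_gpus):
    """Scale the learning rate and the effective batch size with the number of GPUs."""
    return base_lr * n_gpus, per_gpu_batch * n_gpus


# Hypothetical numbers, chosen only for illustration.
DATASET_SIZE = 1000
BASE_LR = 0.01
PER_GPU_BATCH = 32

rows = []
for n_gpus in (1, 2, 4, 8, 16, 32):
    lr, effective_batch = weak_scaling(BASE_LR, PER_GPU_BATCH, n_gpus)
    fraction = effective_batch / DATASET_SIZE  # share of the dataset used per step
    rows.append((n_gpus, lr, effective_batch, fraction))
    print(f"{n_gpus:>2} GPUs: lr={lr:.2f}, effective batch={effective_batch:>4} "
          f"({fraction:.0%} of the dataset)")
```

With these made-up numbers, 32 GPUs would need an effective batch larger than the whole dataset, so "dataset much greater than the effective batch" is clearly no longer true.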

In our experiments, this particular strategy failed to reach the target loss every single time.

The alternative

Our proposed fix for this problem is simple: keep the effective batch size constant and split it among the workers. This obviously places an upper bound on the number of parallel workers: there can be at most as many workers as there are entries in the original batch. But this has another consequence: apart from random fluctuations, optimization proceeds exactly as in the single processor case. Therefore, if the single processor version converges, the multi-processor version is guaranteed to converge as well.

In the multi-processor case, our update then becomes

$$w_{t+1} = w_t - \eta\frac{1}{|\mathcal{B}|}{\color{red}\sum_{i=0}^{N-1}}\left(\sum_{x\in B_i}\nabla l(x, w_t)\right)$$

where $N$ is the number of parallel workers, $\cup_{i=0}^{N-1}B_i=\mathcal{B}$, and $B_i\cap B_j=\varnothing$ for $i\neq j$, which means the minibatches processed by the workers do not overlap and together cover $\mathcal{B}$.
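The update above can be sketched in a few lines of Python. This is only a simulation under stated assumptions: the workers run one after another in a single process, a plain `sum` stands in for the all-reduce a real framework would perform, and the toy loss and the helper names (`split_batch`, `strong_scaling_step`, `single_worker_step`) are hypothetical.

```python
def grad_loss(x, w):
    """Gradient w.r.t. w of the assumed loss l(x, w) = 0.5 * (w * a - b)**2, with x = (a, b)."""
    a, b = x
    return (w * a - b) * a


def single_worker_step(w, batch, eta):
    """Reference update computed by one worker over the whole batch."""
    return w - eta * sum(grad_loss(x, w) for x in batch) / len(batch)


def split_batch(batch, n_workers):
    """Partition the batch into n_workers disjoint shards B_0, ..., B_{N-1}."""
    if not 1 <= n_workers <= len(batch):
        raise ValueError("need between 1 and len(batch) workers")
    return [batch[i::n_workers] for i in range(n_workers)]


def strong_scaling_step(w, batch, eta, n_workers):
    """Same effective batch, split among workers; partial sums are then combined."""
    shards = split_batch(batch, n_workers)
    # Each worker sums the gradients of its own shard (simulated sequentially here).
    partial_sums = [sum(grad_loss(x, w) for x in shard) for shard in shards]
    # Stand-in for an all-reduce: add up the per-worker sums.
    total = sum(partial_sums)
    # Divide by |B|, the size of the full batch, not by the shard size.
    return w - eta * total / len(batch)


batch = [(0.5, 1.5), (1.0, 3.0), (1.5, 4.5), (2.0, 6.0), (2.5, 7.5), (3.0, 9.0)]
reference = single_worker_step(1.0, batch, 0.05)
for n in range(1, len(batch) + 1):
    print(n, strong_scaling_step(1.0, batch, 0.05, n), reference)
```

Up to floating-point rounding, every worker count gives the same new weight as the single-worker step, which is the sense in which optimization proceeds as in the single processor case. `split_batch` rejects more workers than batch entries, mirroring the upper bound on parallelism.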

On a medical imaging application, we compared our proposed approach with the strategies proposed in the literature. Our update not only converged in all experiments, but it also improved performance in all instances.

How good is it, anyway?

Some results are shown below. The full discussion can be found in the paper. In the figure below, we see that the strong scaling approach is not only the implementation that scales the best, it also converges in all cases! (Missing points indicate implementations that fail to converge.)

[Figure: Timing comparison of the various scaling strategies. Time to reach a given loss for the various algorithms and numbers of GPUs. Our proposed update is shown as the blue line.]

As can be seen in the graph above, scaling strategies that use the linear scaling rule fail to reach the target loss.

[Figure: Weak scaling loss]

[Figure: Strong scaling loss]

Quick analysis

As can be seen in the above graphs, for the particular application we tested (an autoencoder for 3D lung nodule imagery), strong scaling is the only strategy that converges in all tests. Not only that, but it was also the fastest one, particularly when more GPUs were used.

Hence, we have shown that the general advice (scale the learning rate and the effective batch size with the number of GPUs) doesn't seem to hold for applications with small datasets, even though such applications can benefit from parallel implementations.

If you liked this blog post, I encourage you to read the full paper on the arXiv.

Renato Luiz de Freitas Cunha
Principal Research Software Engineer

My research interests include reinforcement learning and distributed systems