Training on the device

22 November 2017


As you’ve no doubt noticed, machine learning on mobile is a thing now. 😃 Apple mentioned it about 100 times in the WWDC 2017 keynote. With all the hype, it’s no surprise that developers are scrambling to add ML to their apps (I can help out).

However… most of these machine learning models are only used for inference: they make predictions using a fixed set of knowledge. Despite the term “machine learning”, no actual learning happens on the device; the knowledge inside the model never improves from using it.

A big reason is that training a model takes a lot of computational power and mobile phones are just not fast enough. It’s more practical to train offline on a server farm and include any model improvements in an app update.

That said, training on the device does make sense for certain apps — and I believe that it’s only a matter of time before training models on the device becomes just as normal as using them for inference.

In this blog post I want to explore the possibilities. Let’s see if we can put “learning” into machine learning.

[Image: The Rock telling his iPhone how it is]

Today: inference only

The most common use of machine/deep learning in apps right now is probably computer vision for analyzing photos and videos. That makes sense because the iPhone is the most popular camera in the world.

But ML is not limited to images; it’s also used for audio, language, time series, and many other types of data. A modern phone has a dozen different sensors plus fast internet access, so there is plenty of data to feed into our models.

iOS itself uses several kinds of on-device deep learning models, such as the face detection in the Camera and Photos apps, listening for the “Hey Siri” phrase, and handwriting recognition for Chinese characters.

But none of these models learn from the user.

Pretty much all mobile machine learning APIs (MPSCNN, TensorFlow Lite, Caffe2) only support inference. That is, you can use models to make predictions based on the user’s data or behavior, but you cannot make these models learn new things from that data.

At the moment, training typically happens on a massive server with lots of GPUs. It’s a slow process that needs a lot of data. A convolutional neural network, for example, is trained on thousands or even millions of images. Training a modern CNN from scratch takes days on a powerful multi-GPU server, weeks on a desktop computer, and an eternity on a mobile device: definitely not something you can do on a single battery charge.

Training on a high-end server is a perfectly fine strategy when model updates happen infrequently and every user gets the exact same model. The app only receives model updates when a new version ships in the App Store, or when it periodically downloads new parameters from the cloud.

But just because training large models on the device isn’t feasible today doesn’t mean it will be impossible forever. Also, not all models need to be large. And most importantly: one-model-for-everyone may not be the best we can do.

Why learn on the device?

There are advantages to training models on the device:

It’s not appropriate for all situations, but there are definitely applications of on-device learning that make sense. I think the main benefit is that it allows you to tailor the model to the individual user.

The following apps already do learning on iOS devices:

Some of these tasks are simpler than others. Often “learning” is just a matter of remembering the last thing the user did. For many apps this is good enough and it doesn’t require any fancy machine learning algorithms.

The model for the predictive keyboard is simple enough that training can happen on the device in real-time. The Photos app’s people learning task is much slower and uses a lot of power, which is why this only works when the device is plugged in. Most practical uses of on-device learning will sit somewhere between these two extremes.

Other examples of existing software that learns from you are spam detection (your email client refines its idea of what spam is based on which emails you mark as junk), spelling and grammar correction (it learns the common mistakes you make while typing and fixes them), and smart calendars such as Google Now (it learns to recognize repeated actions that you perform).
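To make the spam-detection case concrete, here is a minimal sketch (in Python, for brevity) of a multinomial naive Bayes filter that updates its word counts every time you mark a message as junk or not junk. The class name, the tokenizer, and the add-one smoothing are illustrative assumptions, not how any particular email client works.

```python
import math
import re
from collections import Counter


class PersonalSpamFilter:
    """Multinomial naive Bayes that learns from the user's own junk/not-junk
    decisions (hypothetical example, not a real mail client's code)."""

    def __init__(self):
        self.word_counts = {"spam": Counter(), "ham": Counter()}
        self.doc_counts = Counter()

    @staticmethod
    def tokenize(text):
        return re.findall(r"[a-z']+", text.lower())

    def learn(self, text, label):
        # "Training" is just updating counts, cheap enough to run on the
        # device immediately whenever the user classifies an email.
        self.word_counts[label].update(self.tokenize(text))
        self.doc_counts[label] += 1

    def spam_probability(self, text):
        vocab = set(self.word_counts["spam"]) | set(self.word_counts["ham"])
        total_docs = sum(self.doc_counts.values())
        log_scores = {}
        for label in ("spam", "ham"):
            # Smoothed log prior, so a class with no examples yet is not zero.
            log_p = math.log((self.doc_counts[label] + 1) / (total_docs + 2))
            total_words = sum(self.word_counts[label].values())
            for word in self.tokenize(text):
                count = self.word_counts[label][word]
                # Add-one (Laplace) smoothing; the extra +1 covers unseen words.
                log_p += math.log((count + 1) / (total_words + len(vocab) + 1))
            log_scores[label] = log_p
        # Turn the two log scores into a normalized probability.
        m = max(log_scores.values())
        exp = {k: math.exp(v - m) for k, v in log_scores.items()}
        return exp["spam"] / (exp["spam"] + exp["ham"])
```

Because learning only increments counters, each update is nearly instant, which is why this kind of personalization has been practical on devices for a long time.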

How far can we take this?

If the goal of training on the device is to adapt the machine learning model to the needs or usage patterns of specific users, then what sort of things can we do with this?

Here is a fun toy example: a neural net that turns gestures into emoji. It asks you to draw a few different shapes and then trains the model to detect these strokes.

[Image: Teaching a neural net to detect gestures]

This is implemented as a Swift Playground, and playgrounds are not exactly known for being speedy. Even so, it doesn’t take very long to train this neural network: on a device it only takes a few seconds. (If you’re curious, here is a great description of how this model works.)

So if your model is not too complex, like this 2-layer neural net, training on the device is already within reach.

Note: On iPhone X, developers have access to a low-resolution 3D model of the user’s face from the Face ID sensors. You could use this data to train a similar model that chooses an emoji — or some action in your app — based on the user’s facial expressions.

Here are some future possibilities that are a bit more advanced:

These are just some ideas. Since everyone is different, it makes sense that the machine learning models we use will get tweaked to suit our specific needs and desires. Instead of building one-size-fits-all models, training on the device lets us build a unique model for every unique user.

Different scenarios for training models

Before you can deploy a machine learning model you first need to train it. Afterwards, you can keep training to refine the model. I believe the big benefit of training on the device is that you can customize the model to each user, and the key idea there is to train the model on that user’s data rather than on a generic dataset (or at least in addition to it).

These are the different options for learning from the user:

Don’t learn from user data at all. Collect your own data or use a publicly available dataset and build a one-size-fits-all model. Whenever you improve the model, release an app update or make the app download the new parameters. This is what the current crop of ML-enabled apps does: train offline, use the model for inference only. The point behind this blog post is to move on from that.

Central learning. If your app or service already requires that the data from the user (or about the user) is stored on your servers (non-encrypted, so you can read it), then it makes sense to do training on the server as well. Send the user’s data to the server and keep it there to learn things from it, possibly specific to that user or for all users in general. This is what platforms like Facebook do.

This setup has issues with privacy (there is none), security (all the data is in one place), scaling (more users means you need a bigger server), and so on. These are the problems we want to avoid by training on the device.

Note: There are other ways to avoid the privacy issue, such as Apple’s “differential privacy” approach to gathering user data, but that has its own shortcomings.
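For readers unfamiliar with the idea, the textbook building block of local differential privacy is randomized response: each device adds noise to its own answer before sending it, yet the server can still estimate aggregate statistics. The Python sketch below shows that classic mechanism for a single yes/no bit. It is not Apple’s actual algorithm, which is more elaborate, and the function names and epsilon values are assumptions.

```python
import math
import random


def randomized_response(true_bit, epsilon, rng=random):
    """Report the true bit with probability e^eps / (e^eps + 1), else flip it.
    For a single bit this satisfies epsilon-local differential privacy."""
    p_truth = math.exp(epsilon) / (math.exp(epsilon) + 1)
    return true_bit if rng.random() < p_truth else 1 - true_bit


def estimate_fraction(reports, epsilon):
    """Unbiased estimate of the true fraction of 1s from the noisy reports."""
    p = math.exp(epsilon) / (math.exp(epsilon) + 1)
    observed = sum(reports) / len(reports)
    # E[observed] = p * f + (1 - p) * (1 - f), solved for f.
    return (observed - (1 - p)) / (2 * p - 1)
```

One trade-off is visible here: the noisier each report (smaller epsilon), the more users you need before the aggregate estimate becomes accurate.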

Collaborative learning. This is mostly just a way to move the cost of training to the users instead. Training happens on the device and each user trains a small part of the model. These partial model updates are shared with other users so they can also learn from your data, and you from theirs. But it’s still a one-size-fits-all model, as everyone still ends up with the same learned parameters.

The main benefit is that training is decentralized: it gets distributed over the users’ devices. In theory this is better for privacy, but research shows it may actually be worse.
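The collaborative setup is closely related to what is usually called federated learning. As a rough illustration, here is a Python sketch of federated averaging, assuming a simple one-feature linear model: each simulated device runs a few epochs of SGD on its own data, and a coordinator averages the resulting weights, weighted by how many examples each device holds. The model, learning rate, and number of rounds are placeholder assumptions. Note that in this common variant the updates pass through a coordinating server rather than going directly to other users.

```python
def local_sgd(weights, data, lr=0.05, epochs=5):
    """A few epochs of SGD for y = w*x + b on one device's own data."""
    w, b = weights
    for _ in range(epochs):
        for x, y in data:
            err = (w * x + b) - y
            # Gradient of the squared error 0.5 * err^2.
            w -= lr * err * x
            b -= lr * err
    return w, b


def federated_average(global_weights, device_datasets, rounds=20):
    """Each round, every device trains locally starting from the shared
    weights; only the resulting weights (not the data) are averaged,
    weighted by each device's number of examples."""
    w, b = global_weights
    for _ in range(rounds):
        updates = [(local_sgd((w, b), data), len(data)) for data in device_datasets]
        total = sum(n for _, n in updates)
        w = sum(uw * n for (uw, _), n in updates) / total
        b = sum(ub * n for (_, ub), n in updates) / total
    return w, b
```

Everyone ends up with the same `(w, b)`, which is exactly the one-size-fits-all property described above.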

Each user trains their own model. This is the option that I’m personally most interested in, as it lets us customize machine learning for each individual user. The model can be learned from scratch (such as in the gesture-to-emoji example) or it can be a pre-trained model that is fine-tuned on your own data. In both cases, we can keep refining the model over time. For example, the predictive keyboard starts with a generic model trained on a specific language, but over time it learns to predict the kinds of sentences that you write.

The downside of this train-your-own-model approach is that other users cannot benefit from the things the app has learned from you. So this really only makes sense for apps that use data that is relatively unique for each user.
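As an illustration of the fine-tuning idea for the keyboard, here is a hypothetical Python sketch of a next-word predictor that ships with bigram counts from a generic corpus and blends them with counts learned from the user’s own sentences. Real predictive keyboards use far more sophisticated models; the class name, the interpolation weight `user_weight`, and the bigram approach are all assumptions made for the example.

```python
from collections import Counter, defaultdict


class PersonalizedBigramKeyboard:
    """Blends a generic bigram model with one trained only on this user's
    typing (hypothetical example)."""

    def __init__(self, generic_text, user_weight=0.7):
        self.generic = self._count(generic_text.lower().split())
        self.user = defaultdict(Counter)
        self.user_weight = user_weight  # how much to trust the user's own data

    @staticmethod
    def _count(words):
        counts = defaultdict(Counter)
        for prev, nxt in zip(words, words[1:]):
            counts[prev][nxt] += 1
        return counts

    def learn(self, sentence):
        # Personalization here is simply adding the user's bigram counts.
        words = sentence.lower().split()
        for prev, nxt in zip(words, words[1:]):
            self.user[prev][nxt] += 1

    @staticmethod
    def _prob(counts, prev, nxt):
        total = sum(counts[prev].values())
        return counts[prev][nxt] / total if total else 0.0

    def predict(self, prev):
        prev = prev.lower()
        candidates = set(self.generic[prev]) | set(self.user[prev])
        if not candidates:
            return None
        lam = self.user_weight
        # Linear interpolation of the user model and the generic model.
        return max(
            candidates,
            key=lambda w: lam * self._prob(self.user, prev, w)
            + (1 - lam) * self._prob(self.generic, prev, w),
        )
```

Setting `user_weight` closer to 1 makes the keyboard adapt faster to your writing style, at the risk of overreacting to a handful of sentences.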

How to actually do on-device training?

One thing to keep in mind is that learning from an individual user’s data is different from training on a large dataset. The initial model for the predictive keyboard may have been trained on a standard corpus (such as all text from Wikipedia) but a text message or email may not have the same writing style as a typical Wikipedia article. And this writing style will be different from one user to the next. The model must allow for these kinds of variations.

There is also the problem that our best training methods for (deep) models are brute-force and rather inefficient. As I’ve pointed out, training an image classifier can take days or weeks. The bigger the computer the better — with 1024 GPUs you can train an ImageNet classifier in 11 minutes.

It takes so long because the training process, Stochastic Gradient Descent (SGD), needs to take small steps. There are typically a million or so images in the dataset and the neural network looks at each image about 100 times.
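To put those numbers in perspective, this is the basic mini-batch SGD update, where η is a small learning rate and B the mini-batch size, followed by a back-of-the-envelope count of the work for a dataset of about a million images seen about 100 times each. The batch size of 256 is an assumption (a common choice), not a figure from this post.

```latex
\begin{aligned}
\theta_{t+1} &= \theta_t - \eta \, \nabla_{\theta} \frac{1}{B} \sum_{i \in \mathcal{B}_t} \mathcal{L}(x_i, y_i; \theta_t) \\
N_{\text{examples}} &= 10^{6}\ \text{images} \times 100\ \text{passes} = 10^{8} \\
N_{\text{steps}} &= \frac{N_{\text{examples}}}{B} = \frac{10^{8}}{256} \approx 3.9 \times 10^{5}
\end{aligned}
```

Each of those roughly 390,000 small steps requires a forward and backward pass over 256 images.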

Obviously, this training method is not feasible on a mobile device.

To be fair, you often don’t need to train a model from scratch. Most people take a pre-trained model and then use transfer learning to make it fit their own dataset. But these smaller datasets typically still consist of thousands of images, and so even transfer learning is rather slow.

It’s safe to say that with our current training methods, fine-tuning deep learning models on mobile is still a ways off.

However, not all is lost. Simple models can already be trained on the device. The gestures-to-emoji neural network we’ve seen is a basic feed-forward network with one hidden layer. It is trained using SGD with momentum, and training only takes a few seconds.
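A network this small is easy to train by hand. Here is a minimal Python sketch of a feed-forward net with one tanh hidden layer and a softmax output, trained with SGD plus momentum. It is not the Playground’s actual code; the input features (for example, a stroke resampled to a few coordinates), the layer sizes, the learning rate, and the momentum value are all assumptions.

```python
import math
import random


class TinyGestureNet:
    """Inputs -> tanh hidden layer -> softmax output, trained with SGD + momentum."""

    def __init__(self, n_in, n_hidden, n_out, seed=0):
        rng = random.Random(seed)

        def init(rows, cols):
            # Small random weights break the symmetry between hidden units.
            return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

        self.W1, self.b1 = init(n_hidden, n_in), [0.0] * n_hidden
        self.W2, self.b2 = init(n_out, n_hidden), [0.0] * n_out
        # Momentum "velocity" buffers, one per parameter, same shapes.
        self.vW1 = [[0.0] * n_in for _ in range(n_hidden)]
        self.vb1 = [0.0] * n_hidden
        self.vW2 = [[0.0] * n_hidden for _ in range(n_out)]
        self.vb2 = [0.0] * n_out

    def forward(self, x):
        h = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
             for row, b in zip(self.W1, self.b1)]
        z = [sum(w * hi for w, hi in zip(row, h)) + b
             for row, b in zip(self.W2, self.b2)]
        m = max(z)  # subtract the max logit for a numerically stable softmax
        e = [math.exp(zi - m) for zi in z]
        total = sum(e)
        return h, [ei / total for ei in e]

    def predict(self, x):
        _, p = self.forward(x)
        return max(range(len(p)), key=p.__getitem__)

    def train_step(self, x, label, lr=0.05, momentum=0.9):
        h, p = self.forward(x)
        # Gradient of softmax + cross-entropy with respect to the logits.
        dz = [pk - (1.0 if k == label else 0.0) for k, pk in enumerate(p)]
        # Backpropagate through W2 (before updating it) and the tanh derivative.
        dh = [(1.0 - h[j] ** 2) * sum(dz[k] * self.W2[k][j] for k in range(len(dz)))
              for j in range(len(h))]
        self._apply(self.W2, self.vW2, self.b2, self.vb2, dz, h, lr, momentum)
        self._apply(self.W1, self.vW1, self.b1, self.vb1, dh, x, lr, momentum)

    @staticmethod
    def _apply(W, vW, b, vb, delta, inputs, lr, momentum):
        # Classical momentum: v = momentum * v - lr * grad, then param += v.
        for k, d in enumerate(delta):
            for j, xj in enumerate(inputs):
                vW[k][j] = momentum * vW[k][j] - lr * d * xj
                W[k][j] += vW[k][j]
            vb[k] = momentum * vb[k] - lr * d
            b[k] += vb[k]

    def fit(self, samples, epochs=200, seed=0):
        rng = random.Random(seed)
        samples = list(samples)
        for _ in range(epochs):
            rng.shuffle(samples)  # visit the examples in a new order each epoch
            for x, label in samples:
                self.train_step(x, label)
```

With only the handful of examples a user draws, each epoch costs very little, which is why training finishes quickly even in a slow environment.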

Classical machine learning models, such as logistic regression, decision trees, or naive Bayes, are typically very quick to train as well. For a model like logistic regression this is especially true when using optimization methods such as L-BFGS or conjugate gradient, which typically converge in far fewer iterations than plain gradient descent. Even a basic recurrent neural network should be within the realm of possibility.
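To show why such models train quickly, here is a Python sketch that fits a one-feature logistic regression with plain Newton’s method, the second-order method that L-BFGS approximates. Newton’s method is swapped in here because, with a single feature plus a bias, the 2×2 Hessian can be inverted by hand. The small L2 penalty `l2` and the iteration count are assumptions.

```python
import math


def fit_logistic_newton(xs, ys, iterations=25, l2=1e-3):
    """Fit p(y=1 | x) = sigmoid(w*x + b) with Newton's method.
    The small L2 penalty on w keeps the Hessian well conditioned."""
    w, b = 0.0, 0.0
    for _ in range(iterations):
        gw, gb = l2 * w, 0.0          # gradient of the penalized NLL
        hww, hwb, hbb = l2, 0.0, 0.0  # entries of the symmetric 2x2 Hessian
        for x, y in zip(xs, ys):
            p = 1.0 / (1.0 + math.exp(-(w * x + b)))
            gw += (p - y) * x
            gb += p - y
            s = p * (1.0 - p)
            hww += s * x * x
            hwb += s * x
            hbb += s
        # Solve H * step = g with the closed-form 2x2 inverse.
        det = hww * hbb - hwb * hwb
        step_w = (hbb * gw - hwb * gb) / det
        step_b = (hww * gb - hwb * gw) / det
        w -= step_w
        b -= step_b
    return w, b
```

Each Newton iteration is one pass over the data, and on small problems like this it typically converges in a handful of iterations, whereas plain gradient descent would need many more passes.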

For models like the predictive keyboard, some kind of “online” training method might work. Here, you do a single training pass after every X characters or words that the user types. Likewise for models using the accelerometer and motion data, where the data comes in as a constant stream of numbers. Since these models are trained on only a small piece of data at a time, each training update is fast.
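Here is a minimal Python sketch of that kind of online learning for motion data, assuming each incoming chunk of accelerometer magnitudes comes with an activity label (for example, because the user confirmed what they were doing). Each class keeps a running mean and variance using Welford’s online algorithm, so learning from a new chunk touches only those samples and never revisits old data. The class names and the one-feature Gaussian model are hypothetical simplifications.

```python
import math


class RunningGaussian:
    """Welford's online algorithm for the mean and variance of a stream."""

    def __init__(self):
        self.n, self.mean, self.m2 = 0, 0.0, 0.0

    def update(self, x):
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (x - self.mean)

    def variance(self):
        # Sample variance; fall back to 1.0 until there are two points.
        return self.m2 / (self.n - 1) if self.n > 1 else 1.0

    def log_likelihood(self, x):
        var = max(self.variance(), 1e-6)
        return -0.5 * (math.log(2 * math.pi * var) + (x - self.mean) ** 2 / var)


class OnlineActivityModel:
    """Hypothetical activity detector that learns from labeled chunks of
    accelerometer magnitudes as they stream in."""

    def __init__(self, labels):
        self.stats = {label: RunningGaussian() for label in labels}

    def learn_chunk(self, magnitudes, label):
        # One cheap O(1) update per incoming sample.
        for m in magnitudes:
            self.stats[label].update(m)

    def predict(self, magnitude):
        return max(self.stats, key=lambda lbl: self.stats[lbl].log_likelihood(magnitude))
```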

So if your model is small and you have relatively little data, then the training time will be on the order of seconds on a modern device. If you do it in a background thread, no one will notice.
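On iOS you would typically dispatch this work to a background queue, for example with `DispatchQueue.global(qos: .background).async`. The same pattern is sketched below in Python: training runs on a worker thread while the app keeps using the current weights, and the new weights are swapped in under a lock once training finishes. `train_model` is a hypothetical stand-in for a real training job. In CPython the global interpreter lock means the worker shares CPU time with the main thread, so this only illustrates the structure.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def train_model(samples, lr=0.05, epochs=200):
    """Stand-in training job: fits y = w*x by SGD (hypothetical)."""
    w = 0.0
    for _ in range(epochs):
        for x, y in samples:
            w -= lr * (w * x - y) * x
    return w


class ModelStore:
    """Holds the current weights; readers keep getting the old model until
    background training publishes a new one."""

    def __init__(self, w=0.0):
        self._lock = threading.Lock()
        self._w = w

    def get(self):
        with self._lock:
            return self._w

    def set(self, w):
        with self._lock:
            self._w = w


def retrain_in_background(executor, store, samples):
    future = executor.submit(train_model, samples)
    # Publish the new weights when training completes; the caller never blocks.
    future.add_done_callback(lambda f: store.set(f.result()))
    return future
```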

But if your model is not small, or you have a lot of data to process, then you need to get creative. A model that wants to learn the faces of the people in your photo library simply has a lot of data to go through, and so you’ll need to find a balance between the speed and accuracy of your learning algorithm.
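One simple way to trade speed for accuracy, sketched here in Python under the assumption that the training work can be split into independent batches, is to give each training session a fixed time budget and resume where it left off next time (for example, only while the device is plugged in, like the Photos app’s people-learning task). The function name and the `train_batch` callback are hypothetical.

```python
import time


def train_with_budget(state, batches, train_batch, budget_seconds):
    """Process batches until the time budget runs out and return how many
    were completed, so the caller can resume from that point later.
    `train_batch(state, batch)` is any function that updates `state` in place."""
    start = time.monotonic()
    done = 0
    for batch in batches:
        if time.monotonic() - start >= budget_seconds:
            break  # out of time; the remaining batches wait for the next session
        train_batch(state, batch)
        done += 1
    return done
```

The caller would store the returned count and pass `batches[done:]` in the next session; a smaller budget per session means less battery use but slower progress toward an accurate model.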

These are some of the issues to overcome:

It’s still early days for training on the device, but in my opinion it’s an inevitable technology — and one that will become important in the design of software.

I wrote this blog post as a way for me to think through what is already possible and where things may be heading. I hope you found it a useful exercise. 🏋 As always, I look forward to hearing your thoughts!

Written by Matthijs Hollemans. First published on Wednesday, 22 November 2017.

I hope you found this post useful! Let me know on Twitter @mhollemans or email me at matt@machinethink.net.
