Introduction to Federated Learning
- Federated Learning (FL) is a relatively new learning paradigm that avoids centralized data collection and centralized model training
- Each mobile device (or edge node) trains a model on its local data, and only the trained model is shared with a central server
- Even though the raw data is never sent to the server, FL does not guarantee full user privacy: the shared model parameters can be used to infer the data that was used to train the model
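The leakage risk above can be seen in a toy case. The sketch below uses a hypothetical linear model with squared-error loss on a single private example; in that setting the gradient a client would send is a scalar multiple of the input, so the server can recover the input's direction from the update alone. All values here are made up for illustration.

```python
# Toy illustration (hypothetical linear model, single example) of why shared
# updates can leak training data: for squared-error loss, the gradient with
# respect to the weights is a scalar multiple of the private input vector.
import numpy as np

x = np.array([3.0, 1.0, 4.0])   # private client example (never sent)
y = 2.0                          # private label
w = np.array([0.5, -0.2, 0.1])  # current shared model weights

# Gradient of (x.w - y)^2 with respect to w -- the "update" a client sends.
grad = 2 * (x @ w - y) * x

# Normalizing the update reveals the direction of the private input x.
recovered = np.abs(grad / np.linalg.norm(grad))
assert np.allclose(recovered, x / np.linalg.norm(x))
```

Real attacks on deep networks are more involved, but this captures the core issue: updates are functions of the data, so they carry information about it.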
The high-level steps involved in FL:
- A generic model (e.g. a neural network) is created on the server, but it is not trained there (training happens on the mobile devices)
- The model is sent to the mobile devices, and each user trains it in parallel on their local data, typically with an optimization algorithm such as stochastic gradient descent (SGD)
- A summary of the changes made to the model (i.e. the model weights or weight updates) is sent back to the server
- The server aggregates the summaries from all clients to improve the shared model, using the federated averaging (FedAvg) algorithm
- The last three steps are repeated until convergence
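The loop above can be sketched in a few lines. This is a minimal single-machine simulation, not a real deployment: clients, data sizes, the linear model, and hyperparameters are all made-up assumptions. Each simulated client runs local SGD on its own data, and the server then takes a data-size-weighted average of the returned weights (the federated averaging step), repeating for several rounds.

```python
# Minimal FedAvg simulation (hypothetical setup): clients train a linear
# model locally with SGD; the server averages the returned weights,
# weighted by each client's data size.
import numpy as np

rng = np.random.default_rng(0)

def local_sgd(w, X, y, lr=0.1, epochs=5):
    """Client-side step: plain per-sample SGD on squared error; data stays local."""
    w = w.copy()
    for _ in range(epochs):
        for xi, yi in zip(X, y):
            grad = 2 * (xi @ w - yi) * xi   # gradient of (x.w - y)^2
            w -= lr * grad
    return w

def fed_avg(client_weights, client_sizes):
    """Server-side step: data-size-weighted mean of the client weights."""
    return np.average(np.stack(client_weights), axis=0, weights=client_sizes)

# Server creates a generic model (here just 3 weights, initialized to zero).
global_w = np.zeros(3)
true_w = np.array([1.0, -2.0, 0.5])         # ground truth for the simulation

# Repeat the send / train / aggregate loop for a few communication rounds.
for _ in range(10):
    updates, sizes = [], []
    for n in (20, 40, 60):                  # three clients with unequal data
        X = rng.normal(size=(n, 3))
        y = X @ true_w
        updates.append(local_sgd(global_w, X, y))
        sizes.append(n)
    global_w = fed_avg(updates, sizes)      # aggregate, then start next round

print(np.round(global_w, 2))
```

The weighting by client data size is the defining choice in FedAvg: clients with more examples contribute proportionally more to the shared model.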
For supervised learning, it is assumed that labeled data can be generated automatically from user interactions.
Some example tasks:
- NLP — voice recognition, next-word prediction, whole-sentence completion
- Image classification — predicting which photos are likely to be viewed frequently
A key goal of FL is to preserve user privacy, but there have been attacks showing that FL does not fully provide it, since input data can be inferred from the shared model parameters. We will look into privacy attacks on FL in another post.