Collaborative Machine Learning without Sharing Data
In December 2021, we published several assets to support Medical Imaging with Azure Machine Learning. The level of interest and the number of inquiries have been a pleasant surprise, and they underline once again that AI applications are becoming increasingly important in medical practice.
Today we are introducing Federated Learning, an exciting addition to our portfolio that is highly relevant both inside and outside healthcare for protecting data and intellectual property.
The following illustration provides an overview of our demo use cases including our latest addition of Federated Learning.
In the demo, you can build a global Federated Learning setup with simulated participating hospitals in the United States, Europe, and Asia to develop a common ML model for detecting pneumonia in X-ray images. In this article, we describe the conceptual basis of Federated Learning and walk through the key elements of the demo.
Check out the repository to try out this use case or the other scenarios and adapt them for your own business scenarios.
A critical success factor for generalizable deep learning models is the availability of extensive and heterogeneous training data. A reliable cancer detection model, for example, should be trained on thousands of medical images showing both healthy tissue and tumors. The training data should also represent the real-world range of gender, age, and other demographic properties of patients, as well as the visual features resulting from different imaging techniques. This variety of representation simply cannot be covered by a single institution. That is especially true for rare diseases, where only a few data points are available in any one organization.
Consequently, collaborative ML development in which many diverse hospitals contribute their own data is an obvious approach. Unfortunately, it often fails in practice due to data protection concerns: patient data is, for good reasons, subject to strict data protection laws in most countries.
Federated Learning addresses this issue by allowing multiple parties to collaboratively train a machine learning model without the need to share data, as the following illustration shows:
While peer-to-peer topologies are possible with Federated Learning, most implementations use a central server for model parameter aggregation and orchestration of the process.
The key characteristic of Federated Learning is that the local data of the participants is never shared. Each client regularly receives a copy of the global model and performs local training with local data. The only information that is shared consists of the “insights” gained from training, namely the model parameter updates. One task of the server is to aggregate the weights of all participating clients into a new model version. The updated model is then redistributed to the clients to initiate the next federated training round. This process is repeated until model convergence is achieved. It can also be used for continuous training by utilizing new data points that are acquired by the participating clients in the production phase.
The following diagram illustrates the workflow of a single training round for the server-based model:
Let’s look at some key considerations along the steps of model distribution, local training, weight aggregation, and model update.
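These four steps can be sketched in a few lines of Python. The following is a minimal, framework-free simulation, not the demo's actual NVFlare code: the one-dimensional "model", the toy datasets, and the local training rule are stand-ins for illustration only.

```python
import random

def local_train(global_weights, local_data, lr=0.1):
    """Stand-in for a client's local training step: nudge each weight
    toward the mean of the client's (toy, scalar) data points."""
    mean = sum(local_data) / len(local_data)
    return [w - lr * (w - mean) for w in global_weights]

def aggregate(client_updates, client_sizes):
    """Server step: average the clients' weights, weighted by local dataset size."""
    total = sum(client_sizes)
    return [sum(u[i] * s for u, s in zip(client_updates, client_sizes)) / total
            for i in range(len(client_updates[0]))]

# Three simulated hospitals with private toy datasets of different sizes.
random.seed(0)
datasets = [[random.gauss(0.5, 0.1) for _ in range(n)] for n in (100, 200, 150)]
global_model = [0.0, 0.0]  # toy "model": two scalar weights

for round_no in range(5):
    # Steps 1+2: distribute the global model and train locally on private data.
    updates = [local_train(global_model, data) for data in datasets]
    # Step 3: aggregate the weight updates on the server.
    global_model = aggregate(updates, [len(d) for d in datasets])
    # Step 4: the updated model is redistributed at the start of the next round.
```

Note that only the weight lists cross the (simulated) network boundary; the datasets never leave their list comprehension.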
Among the first Federated Learning applications were smart keyboard scenarios for sentence completion with millions of participating smartphones (“cross-device”). In contrast, Federated Learning scenarios between organizations (“cross-silo”) have different characteristics and requirements. Often, they involve a moderate number of participants, such as tens of hospitals. Typically, there are strict requirements related to security, privacy, and compliance. Federated Learning provides control options for such organizational use cases. One of them is the ability to define which clients are allowed to participate and to receive copies of the global model.
A big advantage of the cross-silo approach is that participants have an IT infrastructure and can utilize powerful GPU resources to support the large models that are common in medical imaging or other deep learning use cases. Our demo shows the on-demand use of Azure CPU or GPU computational resources for Federated Learning.
Clear agreements on data annotation standards are a key requirement of the multiparty training setting, in particular for advanced tasks like image segmentation. Systematic deviations in labeling practices or quality may negatively affect the resulting model. In addition, Federated Learning permits no central quality assurance of data or annotations, because the data is never shared.
As Federated Learning gains traction in enterprise scenarios, weight aggregation approaches are becoming more sophisticated to address corporate requirements. The most intuitive aggregation method is Federated Averaging (FedAvg): each client's update is weighted by the number of local data points it was trained on. For example, if hospital A's update is based on 100 images and hospital B's on 200 images, then hospital B's contribution carries twice the weight of hospital A's. Several variations of Federated Averaging exist to improve convergence given deviating data distributions or to improve privacy and security. An essential extension is Secure Aggregation, which protects communication against man-in-the-middle attacks and prevents the server from accessing updates of individual clients in clear text. It is based on a cryptographic protocol that ensures that a client's contribution looks like random noise unless combined in aggregate with the other clients' contributions.
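The masking idea behind Secure Aggregation can be illustrated with a toy sketch. The pairwise additive masks below are generated centrally for simplicity; real protocols derive them via cryptographic key exchange and also handle client dropouts, which this illustration omits.

```python
import random

def pairwise_masks(n_clients, n_params, seed=42):
    """Every pair of clients (i, j) with i < j agrees on a shared random
    mask; client i adds it and client j subtracts it, so the masks cancel
    exactly when the server sums all contributions."""
    rng = random.Random(seed)
    masks = [[0.0] * n_params for _ in range(n_clients)]
    for i in range(n_clients):
        for j in range(i + 1, n_clients):
            shared = [rng.uniform(-100.0, 100.0) for _ in range(n_params)]
            for k in range(n_params):
                masks[i][k] += shared[k]
                masks[j][k] -= shared[k]
    return masks

# Toy model updates from three clients (two parameters each).
updates = [[0.1, 0.2], [0.3, 0.1], [0.2, 0.3]]
masks = pairwise_masks(n_clients=3, n_params=2)

# What the server receives from each individual client looks like noise ...
masked = [[u + m for u, m in zip(upd, msk)] for upd, msk in zip(updates, masks)]

# ... but the masks cancel in the sum, revealing only the aggregate.
aggregated = [sum(col) for col in zip(*masked)]  # ≈ [0.6, 0.6]
```

The server learns the sum of all updates but cannot recover any individual hospital's contribution from its masked message alone.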
For an end-to-end privacy-preserving solution, we also need to look at the resulting model. Today's large deep learning models tend to memorize individual data points. For example, researchers have shown that it is possible to reconstruct a recognizable image of a person's face from a trained facial recognition model. The risk of such model inversion attacks is not reduced by Federated Learning, which protects the input data but not the resulting model. Therefore, Federated Learning is often combined with Differential Privacy, which protects the resulting model by adding a defined amount of statistical noise during the training process. This substantially reduces the risk of reconstructing individual data points or of linkage attacks. Check out the Differential Privacy demo in our repository if you want to learn more about the concept.
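As a rough sketch of how such noise is injected during training, the following combines the two standard ingredients of DP-SGD: per-example gradient clipping and calibrated Gaussian noise. The toy gradients and parameter values are illustrative only; a production system would additionally use a privacy accountant to track the cumulative privacy budget.

```python
import math
import random

def privatize(gradients, clip_norm=1.0, noise_multiplier=1.1, seed=0):
    """Sketch of DP-SGD's core step: clip each per-example gradient to a
    maximum L2 norm, add calibrated Gaussian noise to the sum, and average."""
    rng = random.Random(seed)
    clipped = []
    for g in gradients:
        norm = math.sqrt(sum(x * x for x in g))
        scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
        clipped.append([x * scale for x in g])          # bound each example's influence
    summed = [sum(col) for col in zip(*clipped)]
    sigma = noise_multiplier * clip_norm                # noise scaled to the clip norm
    noisy = [s + rng.gauss(0.0, sigma) for s in summed]
    return [x / len(gradients) for x in noisy]          # average over the batch

# Toy per-example gradients (two parameters each).
grads = [[3.0, 4.0], [0.3, -0.4], [1.0, 0.0]]
dp_grad = privatize(grads)
```

Clipping caps how much any single patient's record can shift the model, and the noise masks whatever influence remains.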
The demo in the medical imaging repository comprises a federated learning scenario where multiple hospitals train a central model using their own private datasets. The approach uses a central server for aggregation, which means that the server sends out instructions to the different hospitals and collects the model weights after the local training run is finished. After aggregation on the central server, the updated model is sent to the hospitals, to serve as a starting point for the next round.
In terms of architecture, the full set-up is implemented using Azure components. The hospitals are represented by Azure Machine Learning workspaces, deployed in different regions to create a global set-up. The federated server is implemented using an Azure Data Science Virtual Machine. For configuration and orchestration, the NVIDIA FLARE package is used.
The model is trained on a public Kaggle dataset containing X-ray images of patients with and without pneumonia¹. Using this dataset, a binary classification model is trained to classify X-ray images into the categories ‘pneumonia’ or ‘normal’.
Check out our step-by-step guide if you would like to recreate the demo setup on your own. To make the solution easy to deploy, several GitHub Actions workflows are provided to prepare the environment for running the experiments.
One of the workflows available in the repository creates the Azure resources: a Data Science Virtual Machine with specific configurations and three Azure Machine Learning workspaces. A Compute Instance is also created for every workspace. The Compute Instances run the client packages generated by NVIDIA FLARE. The client packages initiate communication with the server and the admin client, both of which run on the Data Science Virtual Machine.
There is also a workflow that downloads the pneumonia dataset from Kaggle and resamples it. After resampling, the data is split into three parts to create private datasets for the different clients. The last step of the preparation phase uses the Azure ML CLI to upload and register the private datasets as data assets in the different workspaces. This way, the training script run on each client can consume its own training data.
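Conceptually, the split step looks like the following sketch. The helper function and file names are hypothetical stand-ins, not the repository's actual script:

```python
import random

def split_for_clients(samples, n_clients=3, seed=7):
    """Shuffle the pooled dataset and split it into disjoint shares,
    one per simulated hospital (hypothetical helper for illustration)."""
    rng = random.Random(seed)
    shuffled = samples[:]          # copy so the original list stays intact
    rng.shuffle(shuffled)
    # Round-robin slicing yields near-equal, non-overlapping shares.
    return [shuffled[i::n_clients] for i in range(n_clients)]

# Toy stand-in for the resampled pneumonia image list.
all_images = [f"xray_{i:04d}.jpeg" for i in range(300)]
shares = split_for_clients(all_images)
# Each share would then be registered as a data asset in one workspace.
```

Because the shares are disjoint, each simulated hospital ends up with a genuinely private partition of the data.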
After setting up the infrastructure, we can generate the packages that will be uploaded to the clients and the server. To do so, we log in to the virtual machine, for example by establishing an SSH connection from a local terminal, and then clone the Medical Imaging repository on the VM.
The federatedlearning folder in the repository contains a file called project.yml. This file contains the settings for the server, the clients, and further configuration options for the federated setup. For more details, see this page.
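For illustration, a minimal project.yml typically declares the server, the clients, and an admin user along these lines. The structure follows NVFlare's provisioning format, which may differ between versions, and all values below are placeholders modeled on the demo's naming, not the repository's actual file:

```yaml
api_version: 3
name: pneumonia_federated          # placeholder project name
participants:
  - name: fedserver.example.com    # placeholder FQDN of the DSVM
    type: server
    org: contoso
  - name: FL-US-Hospital
    type: client
    org: contoso
  - name: FL-EU-Hospital
    type: client
    org: contoso
  - name: FL-Asia-Hospital
    type: client
    org: contoso
  - name: email@example.com        # admin user for the admin client
    type: admin
    org: contoso
```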
The NVFlare provisioning tool can be run with the following command:
provision -p project.yml
This generates a startup kit for each participant. By default, the kits are created in a directory prefixed with “prod_” inside a folder named after the project, within the workspace folder where the provisioning command is run. The client packages are copied to the different Machine Learning workspaces. Both the server and admin packages stay on the federated server but are moved to the home directory of the VM.
The experiments can be controlled using the admin client. The admin and server packages can run on the same virtual machine. To initiate connections between server and clients, start by running the start command on the server VM: navigate to the startup folder in the server startup kit and run bash start.sh.
When the server is started, the start scripts on the clients can be run. This should lead to a successful connection between the server and the clients. In the terminal of the server, the output should look like the following:
After running the startup scripts for the three clients, we can verify with the admin client whether the connections succeeded. For this, we open another terminal and run the startup script in the admin folder. Sign in using email@example.com as the username.
Typing ‘?’ lists the different commands of the admin client. Some of them are relevant during the training cycle, like ‘deploy_app’ and ‘start_app’. Other commands focus mostly on the IT management part of the solution.
By running the check_status command, we can verify whether our three clients are successfully connected to the federated server.
After connecting the clients to the server, the experiment can be initiated.
The repository contains two applications, as they are called in NVFlare, that can be used to run experiments. Both use the same model architecture and training logic. The main difference is that one application trains a federated model and the other a centralized one. This allows us to run experiments comparing the performance of a centralized versus a federated run.
To prepare our run, we need to set the run number for all the clients and run the check_client command afterwards to verify:
The next command we need to run is upload_app pneumonia-federated. This will upload our app to the transfer folder of the federated server, ready to be copied to the different clients.
To deploy the app to the different clients, we use the command deploy_app pneumonia-federated all. This will copy the training instructions to the different clients and server. If successful, the clients will return an OK-status to the server.
The final command we need is start_app all. This initiates the training run on the different clients:
The progress of the training run can be tracked using the terminal where the server startup script was run, or using the terminals from the separate clients.
An obvious question is how federated training compares to traditional (centralized) training. To answer it, a centralized training run can be started after the federated experiment. In this experiment, the same algorithm is trained on the full pneumonia dataset, but this time using only one client. The GitHub Actions workflow used to prepare the data contains a step that registers the full dataset in the Asia workspace, which is therefore the client we use in this step.
The centralized experiment can be started by running the following commands in sequence in the admin terminal:
deploy_app pneumonia-central client FL-Asia-Hospital
start_app client FL-Asia-Hospital
On the client side, events are created during the training run and streamed to the server. These metrics can be found on the server in a folder whose name is ‘run_’ followed by the run number.
To run TensorBoard on the federated server, run the following command:
tensorboard --logdir=fedserver --bind_all
The --bind_all parameter makes the application accessible from outside the VM, which is needed if you want to access the dashboard from your local machine without using Remote Desktop options. From a network perspective, we also need to open the port TensorBoard uses (6006 by default).
Below is a screenshot of what our TensorBoard looks like. On the left side, we have the option to filter the runs we want to visualize, which is very useful for comparing performance between runs with different parameters. The values visualized are training loss and validation accuracy. Training loss indicates how well each client’s model fits its local dataset, while validation accuracy shows the performance of the local model on a global validation dataset.
Based on the two experiment set-ups, we can compare results between a centralized run and a federated run. Our experiments indicate that a slightly higher accuracy is achieved with the centralized run. In some cases, the same performance can be achieved with both set-ups, although the federated training run takes longer to converge.
The visual below compares a centralized and a federated run; here, the centralized run resulted in a higher accuracy. Further experiments with other parameters (e.g., learning rate, number of epochs) are needed before drawing final conclusions, which in any case should not be generalized across different use cases.
Federated Learning is a compelling concept for co-development of ML solutions where training data cannot be shared. An obvious application area is privacy-preserving ML training. Other examples, like the MELLODDY project for drug discovery, show applications where competing companies develop a joint model without having to worry about their intellectual property being compromised.
As the maturity of the concept increases, including support for secure and robust federated aggregation, Federated Learning is becoming more prevalent among enterprises and public sector organizations with their specific requirements.
The combination of Azure ML and NVIDIA NVFlare enables rapid realization even for global Federated Learning deployments and provides significant benefits in terms of compliance, data protection, manageability and traceability.
An end-to-end privacy preserving solution requires the combination of Federated Learning with technologies such as Differential Privacy.