This repository has been archived by the owner on Aug 13, 2024. It is now read-only.

Reinforcement Learning for Medical Device Control Made Easy


SPARC-FAIR-Codeathon/2024-team-8


Stable Baselines 3 Gymnasium SB3-Contrib TensorFlow

The control of medical devices, particularly in applications like neuromodulation, is increasingly focused on the development of closed-loop systems. These systems are gaining attention because they allow for real-time adjustments based on continuous feedback, leading to more precise and personalized treatments. However, developing the adaptive intelligence required for closed-loop control often demands specialized knowledge of advanced control methods such as reinforcement learning, which poses a significant barrier for many researchers.

SPARC.RL addresses this challenge by offering a proof-of-concept toolchain that simplifies the training of state-of-the-art reinforcement learning agents, even for those without deep expertise in the field. By leveraging the Stable Baselines3 framework and integrating data from the SPARC platform with models from oSPARC, SPARC.RL democratizes access to advanced reinforcement learning techniques. This toolchain empowers researchers to explore and implement sophisticated control strategies, accelerating the development of more effective and personalized medical interventions through closed-loop systems.

To demonstrate how the SPARC.RL toolchain enables researchers without specific expertise in reinforcement learning (RL) to train state-of-the-art RL agents for robust medical device control, we implemented a pipeline that trains a closed-loop vagus nerve stimulation controller for heart rate control. Because users can integrate other data from the SPARC platform and other models from oSPARC into a reinforcement learning pipeline running on the oSPARC platform, SPARC.RL provides a framework that can be applied to a wide range of control problems in the biomedical field.

This toolchain and oSPARC pipeline were developed during the 2024 SPARC FAIR Codeathon by Max Haberbusch and John Bentley.

The framework resulting from this codeathon can be accessed on oSPARC.

Note: While this project offers powerful capabilities, it is an initial prototype serving as a proof of concept. No guarantees are made that it is free of bugs or that it works with datasets other than those used during development.

Key Features

The SPARC.RL toolchain consists of two components (Figure 1): a stand-alone client that obtains and preprocesses datasets from the SPARC platform for reinforcement learning, and an oSPARC pipeline for developing reinforcement learning-based control algorithms with Stable Baselines3. The client uses a large language model (LLM) to suggest how the selected dataset could be used for reinforcement learning. It also helps the user preprocess the data, design a suitable neural network architecture, and train a surrogate deep neural network model, which is then saved to disk.

The second component is an oSPARC application that enables the use of Stable Baselines3 on oSPARC (Figure 2). Here the user can load the trained surrogate model, select the reinforcement learning policy, and set the training hyperparameters. This final stage of the pipeline produces a trained control policy that can then be used in the closed-loop system.

Additionally, SPARC.RL provides a fully integrated reinforcement learning pipeline running on the oSPARC platform that also enables the training of surrogate models to efficiently represent system dynamics (Figure 2).


Figure 1. Overview of the SPARC.RL toolchain.



Figure 2. Fully integrated SPARC.RL reinforcement learning pipeline on oSPARC.


Dataset and Model Integration

SPARC.RL supports the selection and use of time-series datasets directly loaded from the SPARC platform using the SPARC Python client. Users can also work with select oSPARC models, enabling the training of RL agents in a highly flexible and customizable manner.

Note: During development of our toolchain in the 2024 SPARC FAIR Codeathon, we used the dataset "Influence of vagus nerve stimulation on vagal and cardiac activity in freely moving pigs" by Oliver Armitage et al., available on SPARC (DOI: 10.26275/aw2z-a49z).

Customizable Inputs and Outputs

The ultimate goal is to allow users to choose from multiple available model inputs (actions) and model outputs (observables) to tailor the reinforcement learning process to their specific needs. However, in the current version, this can only be done by modifying the source code. Later versions will allow the user to pick appropriate actions and observables directly from the graphical user interface.

Data-driven Modelling

SPARC.RL offers multiple deep learning architectures for creating surrogate models of experimental data available on SPARC or of oSPARC models. Users can select from several recurrent neural network (RNN) architectures suited to time-series modeling: vanilla RNNs, Long Short-Term Memory (LSTM) networks, bidirectional LSTM (BiLSTM) networks, and Gated Recurrent Units (GRUs). Users can configure the network and training parameters to their needs, including the number of layers, the number of units per layer, the optimizer, learning rate, batch size, number of epochs, and early stopping policy.
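Early stopping typically halts training once the validation loss has stopped improving for a set number of epochs (the "patience"). The sketch below shows that logic in plain Python for illustration only; the class name `EarlyStopping` and the `patience`/`min_delta` parameters are assumptions and not taken from the SPARC.RL source. In TensorFlow, the Keras `EarlyStopping` callback exposes comparable `patience` and `min_delta` arguments.

```python
import math
from dataclasses import dataclass, field


@dataclass
class EarlyStopping:
    """Patience-based early stopping on a monitored validation loss (illustrative sketch)."""
    patience: int = 5        # epochs to wait without improvement before stopping
    min_delta: float = 0.0   # minimum decrease in loss that counts as an improvement
    best_loss: float = field(default=math.inf, init=False)
    best_epoch: int = field(default=-1, init=False)
    wait: int = field(default=0, init=False)

    def should_stop(self, epoch: int, val_loss: float) -> bool:
        """Record this epoch's validation loss and return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # Improvement: remember the best epoch and reset the patience counter.
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.wait = 0
            return False
        # No sufficient improvement this epoch: use up one unit of patience.
        self.wait += 1
        return self.wait >= self.patience


if __name__ == "__main__":
    stopper = EarlyStopping(patience=2)
    for epoch, loss in enumerate([1.0, 0.8, 0.7, 0.72, 0.71, 0.69]):
        if stopper.should_stop(epoch, loss):
            print(f"Stopping at epoch {epoch}; best epoch was {stopper.best_epoch}")
            break
```

A larger `patience` tolerates noisy validation curves at the cost of extra epochs; typically the weights from `best_epoch` are the ones kept.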

Steps to Generate the Surrogate Model

Installation

Clone the Repository
git clone https://github.com/SPARC-FAIR-Codeathon/2024-team-8.git
cd 2024-team-8/sparcrl_surrogate

Create and Activate a Conda Environment

If you don't have Conda installed, install Miniconda or Anaconda, both of which include Conda.

Create and activate the environment using the provided environment.yml file:

conda env create -f environment.yml
conda activate sparcrl

Run the Surrogate Modeling Tool

Now you are all set to run the surrogate modeling tool. To do so, run the following command in your terminal:

python sparcrl_surrogate.py

In the first step, select a dataset from the dropdown menu, which is automatically populated with the datasets available on the SPARC platform. Currently, the datasets are limited to time-series data. Once you have selected a dataset, you can inspect its metadata, including the description, creator, creation date, and version. Additionally, a large language model generates suggestions on how the dataset could be used for reinforcement learning. Once you have chosen your dataset, download and extract the data from SPARC by clicking the 'Get Dataset!' button; you will be asked for the file path where the data should be saved. Then click the 'Next' button to proceed to the next step, selecting the file(s) to use for training the surrogate model.


Figure 3. Select dataset from SPARC platform to train surrogate model.


Once you have chosen and downloaded an appropriate dataset, select one of the available files containing experimental data from the dropdown menu. The list is automatically filtered to supported file types; currently, only the .hdf5 format is supported. After you select a file, the data is preprocessed into the format required for training the model. You can display the preprocessed data by clicking the 'Plot Data!' button. If you are satisfied with the preprocessed data, move on to the next step by clicking the 'Next' button.

Figure 4. Select a file from the dataset for model training and inspect preprocessed data.
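The README does not specify what the preprocessing step produces. A common way to prepare a time series for a recurrent network is to cut it into overlapping windows of past inputs, each paired with the output at the window's last time step. The sketch below illustrates that idea under this assumption; the function name `make_windows`, the `window` parameter, and the example values are hypothetical.

```python
from typing import List, Sequence, Tuple


def make_windows(inputs: Sequence[float], outputs: Sequence[float],
                 window: int) -> Tuple[List[List[float]], List[float]]:
    """Pair each run of `window` consecutive inputs with the output at the run's last time step."""
    if len(inputs) != len(outputs):
        raise ValueError("inputs and outputs must have the same length")
    if not 1 <= window <= len(inputs):
        raise ValueError("window must be between 1 and the series length")
    X: List[List[float]] = []
    y: List[float] = []
    for end in range(window, len(inputs) + 1):
        X.append(list(inputs[end - window:end]))  # input history up to and including step end-1
        y.append(outputs[end - 1])                # output observed at step end-1
    return X, y


if __name__ == "__main__":
    stim = [0.0, 0.0, 1.0, 1.0, 0.0]     # illustrative stimulation amplitude per step
    hr = [80.0, 80.0, 78.0, 76.5, 77.5]  # illustrative heart rate per step
    X, y = make_windows(stim, hr, window=3)
    print(len(X), X[0], y[0])
```

Before being passed to a Keras recurrent layer, such windows would be arranged as an array of shape (samples, timesteps, features), here (samples, window, 1).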


After loading the training data, you can define the model architecture. Currently, the tool supports several types of recurrent neural networks: LSTM, BiLSTM, GRU, and vanilla RNNs. Adjust the number of layers and units per layer to match the complexity of the dynamics you are trying to capture. You can also specify training parameters such as batch size, learning rate, optimizer, and number of epochs, and enable early stopping to prevent overfitting. Once you have defined the parameters, click the 'Train!' button. This prints the final model architecture and starts the training. A fixed 8:1:1 ratio is used to split the data into training, validation, and test sets; to change this ratio, you currently have to modify the source code.

Note: The status messages about the training are written to the console and not passed on to the graphical user interface. If you want to observe the training progress, please check the terminal that you used to start the graphical user interface. Also, during the training, the user interface might become unresponsive. Do not worry, just wait until the training is finished.

Figure 5. Define model architecture and set training parameters.
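An 8:1:1 split can be sketched as follows. Whether SPARC.RL shuffles samples before splitting is not stated; this sketch assumes the original chronological order is kept, which avoids leaking later samples of a time series into training. The function name and the `ratios` parameter are hypothetical; exposing the ratio as an argument is simply one way to avoid editing the code when a different split is needed.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_val_test(samples: Sequence[T],
                         ratios: Tuple[int, int, int] = (8, 1, 1)
                         ) -> Tuple[List[T], List[T], List[T]]:
    """Split samples in order into training, validation, and test sets by integer ratio."""
    total = sum(ratios)
    n = len(samples)
    n_train = n * ratios[0] // total
    n_val = n * ratios[1] // total
    train = list(samples[:n_train])
    val = list(samples[n_train:n_train + n_val])
    test = list(samples[n_train + n_val:])  # any rounding remainder goes to the test set
    return train, val, test


if __name__ == "__main__":
    train, val, test = split_train_val_test(list(range(100)))
    print(len(train), len(val), len(test))
```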


Now you can sit back while TensorFlow trains your surrogate model. The trained surrogate model is saved, along with the training data, to the `training_data` directory in your project folder.

Figure 6. Observe training progress.


After the training is completed, you can access the training data (.csv files) and the trained model (.h5 file) that was saved to your hard disk from your project directory.

Figure 7. Training data and trained model saved to hard disk.


Reinforcement Learning Using SPARC.RL on oSPARC

Using SPARC.RL Train Surrogate Model Node on oSPARC

Training of the surrogate model can also be done on the oSPARC platform using the SPARC.RL Train Surrogate Model node. However, unlike the stand-alone client, this approach does not allow you to select data directly from the SPARC platform. The node learns to approximate the underlying dynamical system from the relationship between the inputs and outputs in the CSV files passed to it (input.csv and output.csv), and saves the trained deep neural network to a .h5 file (model.h5).
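The README does not document the column layout of input.csv and output.csv. Assuming each file holds one numeric value per row, one row per time step, with the two files aligned in time, a loader could look like the sketch below. The function names and the single-column layout are assumptions, not the node's actual code.

```python
import csv
from typing import Iterable, List, Tuple


def read_column(lines: Iterable[str], has_header: bool = False) -> List[float]:
    """Parse the first column of CSV text into floats (assumed single-column layout)."""
    reader = csv.reader(lines)
    if has_header:
        next(reader, None)  # skip the header row if present
    return [float(row[0]) for row in reader if row]  # blank lines yield empty rows; skip them


def load_io_pair(input_path: str, output_path: str,
                 has_header: bool = False) -> Tuple[List[float], List[float]]:
    """Load time-aligned stimulation inputs and system outputs from two CSV files."""
    with open(input_path, newline="") as f_in, open(output_path, newline="") as f_out:
        u = read_column(f_in, has_header)
        y = read_column(f_out, has_header)
    if len(u) != len(y):
        raise ValueError(f"length mismatch: {len(u)} inputs vs {len(y)} outputs")
    return u, y
```

Checking that both files have the same length catches the most common alignment mistake before any training time is spent.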

The trained surrogate model then serves as an input to the SPARC.RL Train Agent node, which is used to train the reinforcement learning agent. The output of this node is a .zip file containing the trained reinforcement learning agent (ppo_cardiovascular.zip), which then can be used as a controller.


Figure 8. Overview of connected SPARC.RL nodes in a Study on oSPARC.
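Conceptually, the Train Agent node uses the surrogate model as the reinforcement learning environment: at each step, the agent's stimulation action is fed to the surrogate, which predicts the next heart rate. The sketch below illustrates that loop with a stand-alone class that mimics the Gymnasium `reset`/`step` return values without depending on Gymnasium. The first-order toy dynamics standing in for the trained model.h5, the reward (negative absolute tracking error), the resting heart rate, and all names are assumptions for illustration only.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical surrogate signature: (heart rate [bpm], amplitude [mA]) -> next heart rate [bpm].
Surrogate = Callable[[float, float], float]

RESTING_HR = 85.0  # assumed resting heart rate of the toy model [bpm]


def toy_surrogate(hr: float, amp_ma: float) -> float:
    """Stand-in for the trained network: HR relaxes toward a baseline that stimulation lowers."""
    baseline = RESTING_HR - 3.0 * amp_ma  # assumed steady-state HR at this amplitude
    return hr + 0.2 * (baseline - hr)     # first-order relaxation toward the baseline


class HeartRateEnv:
    """Heart-rate tracking environment mimicking Gymnasium's reset/step return values."""

    def __init__(self, surrogate: Surrogate, target_hrs: List[float],
                 amplitudes_ma: List[float], max_steps: int = 200) -> None:
        self.surrogate = surrogate
        self.target_hrs = list(target_hrs)
        self.amplitudes_ma = list(amplitudes_ma)  # discrete action i -> amplitudes_ma[i]
        self.max_steps = max_steps
        self.rng = random.Random()
        self.hr = RESTING_HR
        self.target = self.target_hrs[0]
        self.t = 0

    def _obs(self) -> Tuple[float, float]:
        # Observation: current heart rate and the setpoint the agent should track.
        return (self.hr, self.target)

    def reset(self, seed: Optional[int] = None) -> Tuple[Tuple[float, float], Dict]:
        if seed is not None:
            self.rng.seed(seed)
        self.hr = RESTING_HR
        self.target = self.rng.choice(self.target_hrs)
        self.t = 0
        return self._obs(), {}

    def step(self, action: int) -> Tuple[Tuple[float, float], float, bool, bool, Dict]:
        amp = self.amplitudes_ma[action]
        self.hr = self.surrogate(self.hr, amp)  # surrogate predicts the next heart rate
        self.t += 1
        reward = -abs(self.hr - self.target)   # assumed reward: negative tracking error
        terminated = False                      # tracking has no natural terminal state
        truncated = self.t >= self.max_steps    # cut the episode off after max_steps
        return self._obs(), reward, terminated, truncated, {"amplitude_ma": amp}


if __name__ == "__main__":
    env = HeartRateEnv(toy_surrogate, [72.0, 76.0, 80.0], [0.0, 1.0, 2.0, 3.0, 4.0])
    obs, _ = env.reset(seed=42)
    total = 0.0
    for _ in range(20):
        # Naive hand-written policy: stimulate hard when HR is above the setpoint.
        action = 4 if obs[0] > obs[1] else 0
        obs, reward, terminated, truncated, _ = env.step(action)
        total += reward
    print(f"final HR {obs[0]:.1f} bpm, setpoint {obs[1]} bpm, return {total:.1f}")
```

For training with Stable Baselines3, the same logic would live in a subclass of `gymnasium.Env` that declares `action_space = gymnasium.spaces.Discrete(len(amplitudes_ma))` and a `gymnasium.spaces.Box` observation space, with the surrogate call wrapping the trained Keras model's prediction.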


Below, in Figure 9, you can see an example output of training a surrogate model using the SPARC.RL Train Surrogate Model node on a synthetic dataset generated with the model from [Haberbusch et al.](https://sparc.science/datasets/335). Here the stimulation intensity was varied from 0 mA to 5 mA in steps of 0.1 mA, and the respective heart rate change was calculated. The model included only a single Long Short-Term Memory (LSTM) layer. The data was split into training, validation, and test sets with a ratio of 8:1:1.

Figure 9. Surrogate training on oSPARC using SPARC.RL Train Surrogate Model node.
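The stimulation sweep described above (0 mA to 5 mA in 0.1 mA steps) yields 51 amplitude levels. Generating the grid from integer step counts, rather than by repeatedly adding 0.1, avoids floating-point drift. The function below is a hypothetical helper; the synthetic cardiovascular model itself is not reproduced here.

```python
from typing import List


def amplitude_grid(start_ma: float = 0.0, stop_ma: float = 5.0,
                   step_ma: float = 0.1) -> List[float]:
    """Return amplitudes from start to stop (inclusive) in fixed steps, rounded to avoid float drift."""
    n_steps = round((stop_ma - start_ma) / step_ma)
    return [round(start_ma + i * step_ma, 10) for i in range(n_steps + 1)]


if __name__ == "__main__":
    grid = amplitude_grid()
    print(len(grid), grid[:3], grid[-1])
```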


The trained surrogate model closely reproduced the dynamics of the full in silico model, as illustrated for one example from the test dataset in Figure 10.

Figure 10. Surrogate model predictions compared to ground truth running in SPARC.RL Train Surrogate Model node on oSPARC.


After training the surrogate model, users can parameterize the RL process by selecting from a range of popular RL algorithms, such as A2C, DDPG, DQN, PPO, SAC, and TD3 (optionally combined with Hindsight Experience Replay, HER), along with their respective policies. The tool supports detailed customization, including choosing the type of action space (discrete or continuous), specifying value ranges, and setting the number of actions for discrete spaces.
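For a discrete action space, a value range and a number of actions define a set of evenly spaced stimulation levels; for a continuous action space, the agent's output is rescaled into the allowed range. The helpers below sketch both mappings. They are hypothetical and do not reproduce SPARC.RL's internal code; the continuous case assumes the agent acts in a normalized [-1, 1] range, as the Stable Baselines3 documentation recommends for continuous actions.

```python
from typing import List


def discrete_levels(low: float, high: float, n_actions: int) -> List[float]:
    """Evenly spaced action values for a discrete space with n_actions choices (inclusive bounds)."""
    if n_actions < 2:
        raise ValueError("n_actions must be at least 2")
    step = (high - low) / (n_actions - 1)
    return [low + i * step for i in range(n_actions)]


def discrete_to_value(action: int, low: float, high: float, n_actions: int) -> float:
    """Map a discrete action index (0 .. n_actions-1) to a physical value, e.g. amplitude in mA."""
    if not 0 <= action < n_actions:
        raise ValueError("action index out of range")
    return discrete_levels(low, high, n_actions)[action]


def continuous_to_value(action: float, low: float, high: float) -> float:
    """Rescale a continuous action from [-1, 1] to [low, high], clipping out-of-range actions first."""
    a = min(max(action, -1.0), 1.0)
    return low + (a + 1.0) * 0.5 * (high - low)


if __name__ == "__main__":
    print(discrete_levels(0.0, 5.0, 6))           # six amplitude levels in mA
    print(continuous_to_value(0.0, 0.0, 5.0))     # middle of the range
```

The choice of action space also constrains the algorithm: in Stable Baselines3, DQN supports only discrete action spaces, DDPG, SAC, and TD3 require continuous (Box) spaces, and A2C and PPO handle both.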

SPARC.RL Train Agent Node

Parameterize Reinforcement Learning

The SPARC.RL Train Agent node is designed to allow various aspects of the reinforcement learning setup and testing process to be parameterized. Environment settings, such as the choice between discrete or continuous action spaces and the number of parallel environments for training, can be adjusted. The path to the surrogate model and the specific heart rate targets used during testing are also configurable. PPO model parameters, including the policy type, number of training steps, batch size, and total timesteps, can be defined to optimize the training process. Additionally, testing parameters, such as the number of iterations and the interval for changing heart rate targets, can be customized. Finally, the paths for saving and loading trained models are configurable, enabling the script to be flexible and adaptable to different experimental needs.

The reinforcement learning setup can be adjusted to your needs by modifying the rl_config.ini file that is passed as input to the SPARC.RL Train Agent node. As an example, the parameters used during the codeathon are shown below:

[Environment]
discrete = True
model_path = model.h5
env_id = 1337
parallel_envs = 6
target_hrs = 72.0, 74.0, 76.0, 78.0, 80.0, 82.0

[PPO]
policy = MlpPolicy
n_steps = 256
batch_size = 32
total_timesteps = 50000
tensorboard_log = ./tensorboard_logs/
tb_log_name = first_run

[Testing]
test_iterations = 1000
steps_per_target_change = 50
render_mode = human

[SavedModel]
save_path = ppo_cardiovascular
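Because rl_config.ini uses standard INI syntax, it can be read with Python's built-in configparser module. The sketch below shows one way to load it into typed values; the comma-separated target_hrs list must be split manually. The function name `load_rl_config` and the returned dictionary keys are hypothetical, not the node's actual code.

```python
import configparser
from typing import Any, Dict


def load_rl_config(text: str) -> Dict[str, Any]:
    """Parse rl_config.ini content into typed settings (illustrative sketch)."""
    cfg = configparser.ConfigParser()
    cfg.read_string(text)
    env, ppo, test = cfg["Environment"], cfg["PPO"], cfg["Testing"]
    return {
        "discrete": env.getboolean("discrete"),
        "model_path": env.get("model_path"),
        "parallel_envs": env.getint("parallel_envs"),
        # Comma-separated list -> floats, e.g. "72.0, 74.0" -> [72.0, 74.0]
        "target_hrs": [float(v) for v in env.get("target_hrs").split(",")],
        "policy": ppo.get("policy"),
        "n_steps": ppo.getint("n_steps"),
        "batch_size": ppo.getint("batch_size"),
        "total_timesteps": ppo.getint("total_timesteps"),
        "test_iterations": test.getint("test_iterations"),
        "steps_per_target_change": test.getint("steps_per_target_change"),
        "save_path": cfg.get("SavedModel", "save_path"),
    }


if __name__ == "__main__":
    with open("rl_config.ini") as f:
        print(load_rl_config(f.read()))
```

Assuming parallel_envs is used as the number of vectorized environments, each PPO rollout collects n_steps × parallel_envs = 256 × 6 = 1,536 transitions, which PPO splits into 1,536 / 32 = 48 minibatches per optimization epoch. Stable Baselines3 keeps collecting rollouts until total_timesteps is reached, so 50,000 timesteps correspond to 33 rollouts (33 × 1,536 = 50,688).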

An example of the training output in the JupyterLab notebook of the SPARC.RL Train Agent node is shown in Figure 11 below.


Figure 11. Running Proximal Policy Optimization (PPO) with the previously trained surrogate model in SPARC.RL Train Agent node on oSPARC.


Figure 12 below shows the same run in TensorBoard, depicting the loss, policy gradient loss, and value loss over 50,000 total timesteps of training with the configuration from the rl_config.ini file shown above. The losses converge, which indicates stable training; the controller's performance itself is evaluated in the tracking test that follows.

Figure 12. Loss for reinforcement learning agent training.


After training, the agent was tested on the surrogate model in a heart rate tracking task lasting 1000 seconds. The agent had to track setpoint heart rates that changed randomly every 50 seconds, drawn from the targets specified in the .ini file above (72.0 bpm, 74.0 bpm, ..., 82.0 bpm). The controller performed well: its steady-state error, quantified by the mean squared error between setpoint and measured heart rate, was only 1.75 bpm² (Figure 13).

Figure 13. Testing the trained reinforcement learning agent on the surrogate model in the SPARC.RL Train Agent node on oSPARC. Running 1000 seconds of heart rate tracking with random setpoint heart rates resulted in a steady-state error, quantified by the mean squared error between setpoint and measured heart rate, of only 1.75 bpm².
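The test protocol can be sketched as follows, assuming one control step per second so that the 1000 test iterations correspond to 1000 seconds: a setpoint is drawn at random from target_hrs every 50 steps, and tracking quality is summarized by the mean squared error. Whether a setpoint may be drawn twice in a row and how the random draw is seeded are not documented; the function names are hypothetical. Because the errors are squared, the MSE is expressed in bpm², and its square root (the RMSE) is in bpm.

```python
import random
from typing import List, Sequence


def setpoint_schedule(targets: Sequence[float], n_steps: int, steps_per_change: int,
                      seed: int = 0) -> List[float]:
    """Draw a new random setpoint from `targets` every `steps_per_change` steps."""
    rng = random.Random(seed)
    schedule: List[float] = []
    current = rng.choice(targets)
    for t in range(n_steps):
        if t > 0 and t % steps_per_change == 0:
            current = rng.choice(targets)  # may repeat the previous setpoint
        schedule.append(current)
    return schedule


def mean_squared_error(setpoints: Sequence[float], measured: Sequence[float]) -> float:
    """MSE between setpoint and measured heart rate (units: bpm squared)."""
    if len(setpoints) != len(measured) or not setpoints:
        raise ValueError("sequences must be non-empty and of equal length")
    return sum((s - m) ** 2 for s, m in zip(setpoints, measured)) / len(setpoints)


if __name__ == "__main__":
    sched = setpoint_schedule([72.0, 74.0, 76.0, 78.0, 80.0, 82.0], 1000, 50)
    print(sched[0], sched[50], sched[999])
```

In a full test loop, `sched[t]` would be placed in the agent's observation at step t, and the heart rates returned by the surrogate would be collected as `measured`.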


SPARC.RL Controller

The controller can be deployed together with the full cardiovascular system model from Haberbusch et al. to form a complete control loop for evaluation. It is available on oSPARC.

License

This project is distributed under the terms of the MIT License.
