diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/README.md b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/README.md
new file mode 100644
index 0000000000..ee57df63e9
--- /dev/null
+++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/README.md
@@ -0,0 +1,71 @@
+# MedMNIST 2D Classification Using FedProx Optimizer Tutorial
+
+![MedMNISTv2_overview](https://raw.githubusercontent.com/MedMNIST/MedMNIST/main/assets/medmnistv2.jpg)
+
+For more details, please refer to the original paper:
+**MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification** ([arXiv](https://arxiv.org/abs/2110.14795)), and [PyPI](https://pypi.org/project/medmnist/).
+
+This example differs from PyTorch_MedMNIST_2D in that it uses the FedProx Optimizer. For more information on FedProx see:
+**Federated Optimization in Heterogeneous Networks** ([arXiv](https://arxiv.org/abs/1812.06127)).
+
+## I. About model and experiments
+
+We use a simple convolutional neural network and settings coming from [the experiments](https://github.com/MedMNIST/experiments) repository.
+
+
+## II. How to run this tutorial (without TLC and locally as a simulation):
+### 0. If you haven't done so already, create a virtual environment, install OpenFL, and upgrade pip:
+ - For help with this step, visit the "Install the Package" section of the [OpenFL installation instructions](https://openfl.readthedocs.io/en/latest/install.html#install-the-package).
+
+
+### 1. Split terminal into 3 (1 terminal for the director, 1 for the envoy, and 1 for the experiment)
+
+
+### 2. Do the following in each terminal:
+ - Activate the virtual environment from step 0:
+
+ ```sh
+ source venv/bin/activate
+ ```
+ - If you are in a network environment with a proxy, ensure proxy environment variables are set in each of your terminals.
+ - Navigate to the tutorial:
+
+ ```sh
+ cd openfl/openfl-tutorials/interactive_api/PyTorch_FedProx_MedMNIST
+ ```
+
+
+### 3. In the first terminal, run the director:
+
+```sh
+cd director
+./start_director.sh
+```
+
+
+### 4. In the second terminal, install requirements and run the envoy:
+
+```sh
+cd envoy
+pip install -r requirements.txt
+./start_envoy.sh env_one envoy_config.yaml
+```
+
+Optional: Run a second envoy in an additional terminal:
+ - Ensure step 2 is complete for this terminal as well.
+ - Run the second envoy:
+```sh
+cd envoy
+./start_envoy.sh env_two envoy_config.yaml
+```
+
+
+### 5. In the third terminal (or forth terminal, if you chose to do two envoys) run the Jupyter Notebook:
+
+```sh
+cd workspace
+jupyter lab Pytorch_FedProx_MedMNIST_2D.ipynb
+```
+- A Jupyter Server URL will appear in your terminal. In your browser, proceed to that link. Once the webpage loads, click on the Pytorch_FedProx_MedMNIST_2D.ipynb file.
+- To run the experiment, select the icon that looks like two triangles to "Restart Kernel and Run All Cells".
+- You will notice activity in your terminals as the experiments runs, and when the experiment is finished the director terminal will display a message that the experiment was finished successfully.
diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/director/director_config.yaml b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/director/director_config.yaml
new file mode 100644
index 0000000000..e51c9c8892
--- /dev/null
+++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/director/director_config.yaml
@@ -0,0 +1,6 @@
+settings:
+ listen_host: localhost
+ listen_port: 50051
+ sample_shape: ['28', '28', '3']
+ target_shape: ['1','1']
+
diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/director/start_director.sh b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/director/start_director.sh
new file mode 100755
index 0000000000..5806a6cc0a
--- /dev/null
+++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/director/start_director.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+set -e
+
+fx director start --disable-tls -c director_config.yaml
\ No newline at end of file
diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/envoy_config.yaml b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/envoy_config.yaml
new file mode 100644
index 0000000000..05ee5cec4c
--- /dev/null
+++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/envoy_config.yaml
@@ -0,0 +1,11 @@
+params:
+ cuda_devices: []
+
+optional_plugin_components: {}
+
+shard_descriptor:
+ template: medmnist_shard_descriptor.MedMNISTShardDescriptor
+ params:
+ rank_worldsize: 1, 1
+ datapath: data/.
+ dataname: bloodmnist
diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/medmnist_shard_descriptor.py b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/medmnist_shard_descriptor.py
new file mode 100644
index 0000000000..d5e639fed4
--- /dev/null
+++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/medmnist_shard_descriptor.py
@@ -0,0 +1,129 @@
+# Copyright (C) 2020-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""MedMNIST Shard Descriptor."""
+
+import logging
+import os
+from typing import Any, List, Tuple
+from medmnist.info import INFO, HOMEPAGE
+
+import numpy as np
+
+from openfl.interface.interactive_api.shard_descriptor import ShardDataset
+from openfl.interface.interactive_api.shard_descriptor import ShardDescriptor
+
+logger = logging.getLogger(__name__)
+
+
+class MedMNISTShardDataset(ShardDataset):
+ """MedMNIST Shard dataset class."""
+
+ def __init__(self, x, y, data_type: str = 'train', rank: int = 1, worldsize: int = 1) -> None:
+ """Initialize MedMNISTDataset."""
+ self.data_type = data_type
+ self.rank = rank
+ self.worldsize = worldsize
+ self.x = x[self.rank - 1::self.worldsize]
+ self.y = y[self.rank - 1::self.worldsize]
+
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
+ """Return an item by the index."""
+ return self.x[index], self.y[index]
+
+ def __len__(self) -> int:
+ """Return the len of the dataset."""
+ return len(self.x)
+
+
+class MedMNISTShardDescriptor(ShardDescriptor):
+ """MedMNIST Shard descriptor class."""
+
+ def __init__(
+ self,
+ rank_worldsize: str = '1, 1',
+ datapath: str = '',
+ dataname: str = 'bloodmnist',
+ **kwargs
+ ) -> None:
+ """Initialize MedMNISTShardDescriptor."""
+ self.rank, self.worldsize = tuple(int(num) for num in rank_worldsize.split(','))
+
+ self.datapath = datapath
+ self.dataset_name = dataname
+ self.info = INFO[self.dataset_name]
+
+ (x_train, y_train), (x_test, y_test) = self.load_data()
+ self.data_by_type = {
+ 'train': (x_train, y_train),
+ 'val': (x_test, y_test)
+ }
+
+ def get_shard_dataset_types(self) -> List[str]:
+ """Get available shard dataset types."""
+ return list(self.data_by_type)
+
+ def get_dataset(self, dataset_type='train') -> MedMNISTShardDataset:
+ """Return a shard dataset by type."""
+ if dataset_type not in self.data_by_type:
+ raise Exception(f'Wrong dataset type: {dataset_type}')
+ return MedMNISTShardDataset(
+ *self.data_by_type[dataset_type],
+ data_type=dataset_type,
+ rank=self.rank,
+ worldsize=self.worldsize
+ )
+
+ @property
+ def sample_shape(self) -> List[str]:
+ """Return the sample shape info."""
+ return ['28', '28', '3']
+
+ @property
+ def target_shape(self) -> List[str]:
+ """Return the target shape info."""
+ return ['1', '1']
+
+ @property
+ def dataset_description(self) -> str:
+ """Return the dataset description."""
+ return (f'MedMNIST dataset, shard number {self.rank}'
+ f' out of {self.worldsize}')
+
+ @staticmethod
+ def download_data(datapath: str = 'data/',
+ dataname: str = 'bloodmnist',
+ info: dict = {}) -> None:
+
+ logger.info(f"{datapath}\n{dataname}\n{info}")
+ try:
+ from torchvision.datasets.utils import download_url
+ download_url(url=info["url"],
+ root=datapath,
+ filename=dataname,
+ md5=info["MD5"])
+ except Exception:
+ raise RuntimeError('Something went wrong when downloading! '
+ + 'Go to the homepage to download manually. '
+ + HOMEPAGE)
+
+ def load_data(self) -> Tuple[Tuple[Any, Any], Tuple[Any, Any]]:
+ """Download prepared dataset."""
+
+ dataname = self.dataset_name + '.npz'
+ dataset = os.path.join(self.datapath, dataname)
+
+ if not os.path.isfile(dataset):
+ logger.info(f"Dataset {dataname} not found at:{self.datapath}.\n\tDownloading...")
+ MedMNISTShardDescriptor.download_data(self.datapath, dataname, self.info)
+ logger.info("DONE!")
+
+ data = np.load(dataset)
+
+ x_train = data["train_images"]
+ x_test = data["test_images"]
+
+ y_train = data["train_labels"]
+ y_test = data["test_labels"]
+ logger.info('MedMNIST data was loaded!')
+ return (x_train, y_train), (x_test, y_test)
diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/requirements.txt b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/requirements.txt
new file mode 100644
index 0000000000..363c0d69f9
--- /dev/null
+++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/requirements.txt
@@ -0,0 +1,3 @@
+medmnist
+setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability
+wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability
diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/start_envoy.sh b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/start_envoy.sh
new file mode 100755
index 0000000000..cdd84e7fb6
--- /dev/null
+++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/start_envoy.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+set -e
+ENVOY_NAME=$1
+ENVOY_CONF=$2
+
+fx envoy start -n "$ENVOY_NAME" --disable-tls --envoy-config-path "$ENVOY_CONF" -dh localhost -dp 50051
diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/workspace/Pytorch_FedProx_MedMNIST_2D.ipynb b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/workspace/Pytorch_FedProx_MedMNIST_2D.ipynb
new file mode 100644
index 0000000000..3bae96ef48
--- /dev/null
+++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/workspace/Pytorch_FedProx_MedMNIST_2D.ipynb
@@ -0,0 +1,628 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "26fdd9ed",
+ "metadata": {},
+ "source": [
+ "# Federated MedMNIST2D Using FedProx Aggregation Algorithm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5504ab79",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install medmnist"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d0570122",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install dependencies if not already installed\n",
+ "import tqdm\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "from torchvision import transforms as T\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "import medmnist\n",
+ "import openfl.utilities.optimizers.torch.fedprox as FP"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "22ba64da",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from medmnist import INFO, Evaluator\n",
+ "\n",
+ "## Change dataflag here to reflect the ones defined in the envoy_conifg_xxx.yaml\n",
+ "dataname = 'bloodmnist'\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "246f9c98",
+ "metadata": {},
+ "source": [
+ "## Connect to the Federation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d657e463",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a federation\n",
+ "from openfl.interface.interactive_api.federation import Federation\n",
+ "\n",
+ "# please use the same identificator that was used in signed certificate\n",
+ "client_id = 'api'\n",
+ "director_node_fqdn = 'localhost'\n",
+ "director_port=50051\n",
+ "\n",
+ "# 2) Run with TLS disabled (trusted environment)\n",
+ "# Federation can also determine local fqdn automatically\n",
+ "federation = Federation(\n",
+ " client_id=client_id,\n",
+ " director_node_fqdn=director_node_fqdn,\n",
+ " director_port=director_port, \n",
+ " tls=False\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "47dcfab3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shard_registry = federation.get_shard_registry()\n",
+ "shard_registry"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a2a6c237",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# First, request a dummy_shard_desc that holds information about the federated dataset \n",
+ "dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)\n",
+ "dummy_shard_dataset = dummy_shard_desc.get_dataset('train')\n",
+ "sample, target = dummy_shard_dataset[0]\n",
+ "f\"Sample shape: {sample.shape}, target shape: {target.shape}\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cc0dbdbd",
+ "metadata": {},
+ "source": [
+ "## Describing FL experiment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fc88700a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9b3081a6",
+ "metadata": {},
+ "source": [
+ "## Load MedMNIST INFO"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e0377d3a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "num_epochs = 3\n",
+ "TRAIN_BS, VALID_BS = 64, 128\n",
+ "\n",
+ "lr = 0.001\n",
+ "gamma=0.1\n",
+ "milestones = [0.5 * num_epochs, 0.75 * num_epochs]\n",
+ "\n",
+ "info = INFO[dataname]\n",
+ "task = info['task']\n",
+ "n_channels = info['n_channels']\n",
+ "n_classes = len(info['label'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b0979470",
+ "metadata": {},
+ "source": [
+ "### Register dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f0dc457e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "## Data transformations\n",
+ "data_transform = T.Compose([T.ToTensor(), \n",
+ " T.Normalize(mean=[.5], std=[.5])]\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "09ba2f64",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from PIL import Image\n",
+ "\n",
+ "class TransformedDataset(Dataset):\n",
+ " \"\"\"Image Person ReID Dataset.\"\"\"\n",
+ "\n",
+ "\n",
+ " def __init__(self, dataset, transform=None, target_transform=None):\n",
+ " \"\"\"Initialize Dataset.\"\"\"\n",
+ " self.dataset = dataset\n",
+ " self.transform = transform\n",
+ " self.target_transform = target_transform\n",
+ "\n",
+ " def __len__(self):\n",
+ " \"\"\"Length of dataset.\"\"\"\n",
+ " return len(self.dataset)\n",
+ "\n",
+ " def __getitem__(self, index):\n",
+ " \n",
+ " img, label = self.dataset[index]\n",
+ " \n",
+ " if self.target_transform:\n",
+ " label = self.target_transform(label) \n",
+ " else:\n",
+ " label = label.astype(int)\n",
+ " \n",
+ " if self.transform:\n",
+ " img = Image.fromarray(img)\n",
+ " img = self.transform(img)\n",
+ " else:\n",
+ " base_transform = T.PILToTensor()\n",
+ " img = Image.fromarray(img)\n",
+ " img = base_transform(img) \n",
+ "\n",
+ " return img, label\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "db2d563e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class MedMnistFedDataset(DataInterface):\n",
+ " def __init__(self, **kwargs):\n",
+ " self.kwargs = kwargs\n",
+ " \n",
+ " @property\n",
+ " def shard_descriptor(self):\n",
+ " return self._shard_descriptor\n",
+ " \n",
+ " @shard_descriptor.setter\n",
+ " def shard_descriptor(self, shard_descriptor):\n",
+ " \"\"\"\n",
+ " Describe per-collaborator procedures or sharding.\n",
+ "\n",
+ " This method will be called during a collaborator initialization.\n",
+ " Local shard_descriptor will be set by Envoy.\n",
+ " \"\"\"\n",
+ " self._shard_descriptor = shard_descriptor\n",
+ "\n",
+ " self.train_set = TransformedDataset(\n",
+ " self._shard_descriptor.get_dataset('train'),\n",
+ " transform=data_transform\n",
+ " ) \n",
+ " \n",
+ " self.valid_set = TransformedDataset(\n",
+ " self._shard_descriptor.get_dataset('val'),\n",
+ " transform=data_transform\n",
+ " )\n",
+ " \n",
+ " def get_train_loader(self, **kwargs):\n",
+ " \"\"\"\n",
+ " Output of this method will be provided to tasks with optimizer in contract\n",
+ " \"\"\"\n",
+ " return DataLoader(\n",
+ " self.train_set, num_workers=8, batch_size=self.kwargs['train_bs'], shuffle=True)\n",
+ "\n",
+ " def get_valid_loader(self, **kwargs):\n",
+ " \"\"\"\n",
+ " Output of this method will be provided to tasks without optimizer in contract\n",
+ " \"\"\"\n",
+ " return DataLoader(self.valid_set, num_workers=8, batch_size=self.kwargs['valid_bs'])\n",
+ "\n",
+ " def get_train_data_size(self):\n",
+ " \"\"\"\n",
+ " Information for aggregation\n",
+ " \"\"\"\n",
+ " return len(self.train_set)\n",
+ "\n",
+ " def get_valid_data_size(self):\n",
+ " \"\"\"\n",
+ " Information for aggregation\n",
+ " \"\"\"\n",
+ " return len(self.valid_set)\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b0dfb459",
+ "metadata": {},
+ "source": [
+ "### Create Mnist federated dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4af5c4c2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fed_dataset = MedMnistFedDataset(train_bs=TRAIN_BS, valid_bs=VALID_BS)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7f63908e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fed_dataset.shard_descriptor = dummy_shard_desc\n",
+ "for i, (sample, target) in enumerate(fed_dataset.get_train_loader()):\n",
+ " print(sample.shape, target.shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "075d1d6c",
+ "metadata": {},
+ "source": [
+ "### Describe the model and optimizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8477a001",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define a simple CNN model\n",
+ "class Net(nn.Module):\n",
+ " def __init__(self, in_channels, num_classes):\n",
+ " super(Net, self).__init__()\n",
+ "\n",
+ " self.layer1 = nn.Sequential(\n",
+ " nn.Conv2d(in_channels, 16, kernel_size=3),\n",
+ " nn.BatchNorm2d(16),\n",
+ " nn.ReLU())\n",
+ "\n",
+ " self.layer2 = nn.Sequential(\n",
+ " nn.Conv2d(16, 16, kernel_size=3),\n",
+ " nn.BatchNorm2d(16),\n",
+ " nn.ReLU(),\n",
+ " nn.MaxPool2d(kernel_size=2, stride=2))\n",
+ "\n",
+ " self.layer3 = nn.Sequential(\n",
+ " nn.Conv2d(16, 64, kernel_size=3),\n",
+ " nn.BatchNorm2d(64),\n",
+ " nn.ReLU())\n",
+ " \n",
+ " self.layer4 = nn.Sequential(\n",
+ " nn.Conv2d(64, 64, kernel_size=3),\n",
+ " nn.BatchNorm2d(64),\n",
+ " nn.ReLU())\n",
+ "\n",
+ " self.layer5 = nn.Sequential(\n",
+ " nn.Conv2d(64, 64, kernel_size=3, padding=1),\n",
+ " nn.BatchNorm2d(64),\n",
+ " nn.ReLU(),\n",
+ " nn.MaxPool2d(kernel_size=2, stride=2))\n",
+ "\n",
+ " self.fc = nn.Sequential(\n",
+ " nn.Linear(64 * 4 * 4, 128),\n",
+ " nn.ReLU(),\n",
+ " nn.Linear(128, 128),\n",
+ " nn.ReLU(),\n",
+ " nn.Linear(128, num_classes))\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.layer1(x)\n",
+ " x = self.layer2(x)\n",
+ " x = self.layer3(x)\n",
+ " x = self.layer4(x)\n",
+ " x = self.layer5(x)\n",
+ " x = x.view(x.size(0), -1)\n",
+ " x = self.fc(x)\n",
+ " return x\n",
+ "\n",
+ "model = Net(in_channels=n_channels, num_classes=n_classes)\n",
+ " \n",
+ "# define loss function and optimizer\n",
+ "if task == \"multi-label, binary-class\":\n",
+ " criterion = nn.BCEWithLogitsLoss()\n",
+ "else:\n",
+ " criterion = nn.CrossEntropyLoss()\n",
+ " \n",
+ "optimizer = FP.FedProxOptimizer(params = model.parameters(), lr=lr, momentum=0.9)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f2154486",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8d1c78ee",
+ "metadata": {},
+ "source": [
+ "### Register model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "59831bcd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from copy import deepcopy\n",
+ "\n",
+ "framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'\n",
+ "MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)\n",
+ "\n",
+ "# Save the initial model state\n",
+ "initial_model = deepcopy(model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "849c165b",
+ "metadata": {},
+ "source": [
+ "## Define and register FL tasks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4ff463bd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "TI = TaskInterface()\n",
+ "\n",
+ "train_custom_params={'criterion':criterion,'task':task}\n",
+ "\n",
+ "# Task interface currently supports only standalone functions.\n",
+ "@TI.add_kwargs(**train_custom_params)\n",
+ "@TI.register_fl_task(model='model', data_loader='train_loader',\n",
+ " device='device', optimizer='optimizer')\n",
+ "def train(model, train_loader, device, optimizer, criterion, task):\n",
+ " total_loss = []\n",
+ " \n",
+ " train_loader = tqdm.tqdm(train_loader, desc=\"train\")\n",
+ " model.train()\n",
+ " model.to(device)\n",
+ " \n",
+ " for inputs, targets in train_loader:\n",
+ " \n",
+ " optimizer.set_old_weights(list(model.parameters()))\n",
+ " optimizer.zero_grad()\n",
+ " \n",
+ " outputs = model(inputs.to(device))\n",
+ " \n",
+ " if task == 'multi-label, binary-class':\n",
+ " targets = targets.to(torch.float32).to(device)\n",
+ " loss = criterion(outputs, targets)\n",
+ " else:\n",
+ " targets = torch.squeeze(targets, 1).long().to(device)\n",
+ " loss = criterion(outputs, targets)\n",
+ " \n",
+ " total_loss.append(loss.item())\n",
+ " \n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ " return {'train_loss': np.mean(total_loss),}\n",
+ "\n",
+ "\n",
+ "val_custom_params={'criterion':criterion, \n",
+ " 'task':task}\n",
+ "\n",
+ "@TI.add_kwargs(**val_custom_params)\n",
+ "@TI.register_fl_task(model='model', data_loader='val_loader', device='device')\n",
+ "def validate(model, val_loader, device, criterion, task):\n",
+ "\n",
+ " val_loader = tqdm.tqdm(val_loader, desc=\"validate\")\n",
+ " model.eval()\n",
+ " model.to(device)\n",
+ "\n",
+ " val_score = 0\n",
+ " total_samples = 0\n",
+ " total_loss = []\n",
+ " y_score = torch.tensor([]).to(device)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " for inputs, targets in val_loader:\n",
+ " outputs = model(inputs.to(device))\n",
+ " \n",
+ " if task == 'multi-label, binary-class':\n",
+ " targets = targets.to(torch.float32).to(device)\n",
+ " loss = criterion(outputs, targets)\n",
+ " m = nn.Sigmoid()\n",
+ " outputs = m(outputs).to(device)\n",
+ " else:\n",
+ " targets = torch.squeeze(targets, 1).long().to(device)\n",
+ " loss = criterion(outputs, targets)\n",
+ " m = nn.Softmax(dim=1)\n",
+ " outputs = m(outputs).to(device)\n",
+ " targets = targets.float().resize_(len(targets), 1)\n",
+ "\n",
+ " total_loss.append(loss.item())\n",
+ " \n",
+ " total_samples += targets.shape[0]\n",
+ " pred = outputs.argmax(dim=1)\n",
+ " val_score += pred.eq(targets).sum().cpu().numpy()\n",
+ " \n",
+ " acc = val_score / total_samples \n",
+ " test_loss = sum(total_loss) / len(total_loss)\n",
+ "\n",
+ " return {'acc': acc,\n",
+ " 'test_loss': test_loss,\n",
+ " }"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8f0ebf2d",
+ "metadata": {},
+ "source": [
+ "## Time to start a federated learning experiment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d41b7896",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create an experimnet in federation\n",
+ "experiment_name = 'medmnist_exp'\n",
+ "fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "41b44de9",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# The following command zips the workspace and python requirements to be transfered to collaborator nodes\n",
+ "fl_experiment.start(model_provider=MI, \n",
+ " task_keeper=TI,\n",
+ " data_loader=fed_dataset,\n",
+ " rounds_to_train=3,\n",
+ " opt_treatment='RESET',\n",
+ " device_assignment_policy='CUDA_PREFERRED')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "01fa7cea",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# If user want to stop IPython session, then reconnect and check how experiment is going\n",
+ "# fl_experiment.restore_experiment_state(model_interface)\n",
+ "\n",
+ "fl_experiment.stream_metrics(tensorboard_logs=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "92940763",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1690ea49",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "10d7d5a2",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}