diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/README.md b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/README.md new file mode 100644 index 0000000000..ee57df63e9 --- /dev/null +++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/README.md @@ -0,0 +1,71 @@ +# MedMNIST 2D Classification Using FedProx Optimizer Tutorial + +![MedMNISTv2_overview](https://raw.githubusercontent.com/MedMNIST/MedMNIST/main/assets/medmnistv2.jpg) + +For more details, please refer to the original paper: +**MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification** ([arXiv](https://arxiv.org/abs/2110.14795)), and [PyPI](https://pypi.org/project/medmnist/). + +This example differs from PyTorch_MedMNIST_2D in that it uses the FedProx Optimizer. For more information on FedProx see: +**Federated Optimization in Heterogeneous Networks** ([arXiv](https://arxiv.org/abs/1812.06127)). + +## I. About model and experiments + +We use a simple convolutional neural network and settings coming from [the experiments](https://github.com/MedMNIST/experiments) repository. +
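In short, FedProx extends FedAvg-style training by adding a proximal term, mu/2 * ||w - w_global||^2, to each collaborator's local objective so that local weights do not drift too far from the global model when the data is heterogeneous. The snippet below is only a minimal illustration of that idea in plain PyTorch (the names `mu`, `model`, and `global_params` are placeholders), not OpenFL's implementation; in this tutorial the same effect is obtained through `openfl.utilities.optimizers.torch.fedprox.FedProxOptimizer`, with the round's starting weights passed in via `set_old_weights` inside the training task.

```python
# Illustrative sketch only: the FedProx local objective written out by hand.
# Assumptions: `model` is the local torch.nn.Module, `global_params` holds the
# parameters received from the aggregator at the start of the round, and `mu`
# is the proximal coefficient (a hyperparameter).
import torch


def fedprox_loss(task_loss, model, global_params, mu=0.01):
    """Return task_loss + (mu / 2) * ||w - w_global||^2."""
    prox_term = sum(
        torch.sum((w - w_g.detach()) ** 2)
        for w, w_g in zip(model.parameters(), global_params)
    )
    return task_loss + 0.5 * mu * prox_term
```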
+ +## II. How to run this tutorial (without TLS and locally as a simulation): +### 0. If you haven't done so already, create a virtual environment, install OpenFL, and upgrade pip: + - For help with this step, visit the "Install the Package" section of the [OpenFL installation instructions](https://openfl.readthedocs.io/en/latest/install.html#install-the-package). +
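For reference, step 0 typically looks like the commands below on Linux (the environment name `venv` matches the activation command used later in this guide; exact paths and versions may differ on your system):

```sh
python3 -m venv venv               # create the virtual environment
source venv/bin/activate           # activate it
python -m pip install --upgrade pip
pip install openfl                 # install OpenFL from PyPI
```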
+ +### 1. Split your terminal into 3 (1 for the director, 1 for the envoy, and 1 for the experiment) +
+ +### 2. Do the following in each terminal: + - Activate the virtual environment from step 0: + + ```sh + source venv/bin/activate + ``` + - If you are in a network environment with a proxy, ensure proxy environment variables are set in each of your terminals (see the example after this step). + - Navigate to the tutorial: + + ```sh + cd openfl/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST + ``` +
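If you do need to go through a proxy, a typical set of exports looks like the following (the proxy URL is a placeholder for your own proxy; adjust it to your network):

```sh
export http_proxy=http://proxy.example.com:3128
export https_proxy=http://proxy.example.com:3128
export no_proxy=localhost,127.0.0.1
```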
+ +### 3. In the first terminal, run the director: + +```sh +cd director +./start_director.sh +``` +
+ +### 4. In the second terminal, install requirements and run the envoy: + +```sh +cd envoy +pip install -r requirements.txt +./start_envoy.sh env_one envoy_config.yaml +``` + +Optional: Run a second envoy in an additional terminal: + - Ensure step 2 is complete for this terminal as well. + - Run the second envoy: +```sh +cd envoy +./start_envoy.sh env_two envoy_config.yaml +``` +
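Note that both envoys above use the same `envoy_config.yaml`, so each one loads the full BloodMNIST dataset. If you would rather give each envoy a distinct, non-overlapping shard, one option (a hypothetical variation, not required by this tutorial) is to start the second envoy with its own config that sets a different rank, since the shard descriptor slices the data by `rank_worldsize`:

```yaml
# envoy_config_two.yaml (hypothetical): this envoy takes shard 2 of 2
params:
  cuda_devices: []

optional_plugin_components: {}

shard_descriptor:
  template: medmnist_shard_descriptor.MedMNISTShardDescriptor
  params:
    rank_worldsize: 2, 2
    datapath: data/.
    dataname: bloodmnist
```

The first envoy's config would then use `rank_worldsize: 1, 2`, and the second envoy would be started with `./start_envoy.sh env_two envoy_config_two.yaml`.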
+ +### 5. In the third terminal (or fourth terminal, if you chose to run two envoys), run the Jupyter Notebook: + +```sh +cd workspace +jupyter lab Pytorch_FedProx_MedMNIST_2D.ipynb +``` +- A Jupyter Server URL will appear in your terminal. In your browser, proceed to that link. Once the webpage loads, click on the Pytorch_FedProx_MedMNIST_2D.ipynb file. +- To run the experiment, select the icon that looks like two triangles to "Restart Kernel and Run All Cells". +- You will notice activity in your terminals as the experiment runs, and when it finishes, the director terminal will display a message that the experiment completed successfully. diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/director/director_config.yaml b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/director/director_config.yaml new file mode 100644 index 0000000000..e51c9c8892 --- /dev/null +++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/director/director_config.yaml @@ -0,0 +1,6 @@ +settings: + listen_host: localhost + listen_port: 50051 + sample_shape: ['28', '28', '3'] + target_shape: ['1','1'] + diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/director/start_director.sh b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/director/start_director.sh new file mode 100755 index 0000000000..5806a6cc0a --- /dev/null +++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/director/start_director.sh @@ -0,0 +1,4 @@ +#!/bin/bash +set -e + +fx director start --disable-tls -c director_config.yaml \ No newline at end of file diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/envoy_config.yaml b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/envoy_config.yaml new file mode 100644 index 0000000000..05ee5cec4c --- /dev/null +++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/envoy_config.yaml @@ -0,0 +1,11 @@ +params: + cuda_devices: [] + +optional_plugin_components: {} + +shard_descriptor: + template: medmnist_shard_descriptor.MedMNISTShardDescriptor + params: + rank_worldsize: 1, 1 + datapath: data/. 
+ dataname: bloodmnist diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/medmnist_shard_descriptor.py b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/medmnist_shard_descriptor.py new file mode 100644 index 0000000000..d5e639fed4 --- /dev/null +++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/medmnist_shard_descriptor.py @@ -0,0 +1,129 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""MedMNIST Shard Descriptor.""" + +import logging +import os +from typing import Any, List, Tuple +from medmnist.info import INFO, HOMEPAGE + +import numpy as np + +from openfl.interface.interactive_api.shard_descriptor import ShardDataset +from openfl.interface.interactive_api.shard_descriptor import ShardDescriptor + +logger = logging.getLogger(__name__) + + +class MedMNISTShardDataset(ShardDataset): + """MedMNIST Shard dataset class.""" + + def __init__(self, x, y, data_type: str = 'train', rank: int = 1, worldsize: int = 1) -> None: + """Initialize MedMNISTDataset.""" + self.data_type = data_type + self.rank = rank + self.worldsize = worldsize + self.x = x[self.rank - 1::self.worldsize] + self.y = y[self.rank - 1::self.worldsize] + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """Return an item by the index.""" + return self.x[index], self.y[index] + + def __len__(self) -> int: + """Return the len of the dataset.""" + return len(self.x) + + +class MedMNISTShardDescriptor(ShardDescriptor): + """MedMNIST Shard descriptor class.""" + + def __init__( + self, + rank_worldsize: str = '1, 1', + datapath: str = '', + dataname: str = 'bloodmnist', + **kwargs + ) -> None: + """Initialize MedMNISTShardDescriptor.""" + self.rank, self.worldsize = tuple(int(num) for num in rank_worldsize.split(',')) + + self.datapath = datapath + self.dataset_name = dataname + self.info = INFO[self.dataset_name] + + (x_train, y_train), (x_test, y_test) = self.load_data() + self.data_by_type = { + 'train': (x_train, y_train), + 'val': (x_test, y_test) + } + + def get_shard_dataset_types(self) -> List[str]: + """Get available shard dataset types.""" + return list(self.data_by_type) + + def get_dataset(self, dataset_type='train') -> MedMNISTShardDataset: + """Return a shard dataset by type.""" + if dataset_type not in self.data_by_type: + raise Exception(f'Wrong dataset type: {dataset_type}') + return MedMNISTShardDataset( + *self.data_by_type[dataset_type], + data_type=dataset_type, + rank=self.rank, + worldsize=self.worldsize + ) + + @property + def sample_shape(self) -> List[str]: + """Return the sample shape info.""" + return ['28', '28', '3'] + + @property + def target_shape(self) -> List[str]: + """Return the target shape info.""" + return ['1', '1'] + + @property + def dataset_description(self) -> str: + """Return the dataset description.""" + return (f'MedMNIST dataset, shard number {self.rank}' + f' out of {self.worldsize}') + + @staticmethod + def download_data(datapath: str = 'data/', + dataname: str = 'bloodmnist', + info: dict = {}) -> None: + + logger.info(f"{datapath}\n{dataname}\n{info}") + try: + from torchvision.datasets.utils import download_url + download_url(url=info["url"], + root=datapath, + filename=dataname, + md5=info["MD5"]) + except Exception: + raise RuntimeError('Something went wrong when downloading! ' + + 'Go to the homepage to download manually. 
' + + HOMEPAGE) + + def load_data(self) -> Tuple[Tuple[Any, Any], Tuple[Any, Any]]: + """Download prepared dataset.""" + + dataname = self.dataset_name + '.npz' + dataset = os.path.join(self.datapath, dataname) + + if not os.path.isfile(dataset): + logger.info(f"Dataset {dataname} not found at:{self.datapath}.\n\tDownloading...") + MedMNISTShardDescriptor.download_data(self.datapath, dataname, self.info) + logger.info("DONE!") + + data = np.load(dataset) + + x_train = data["train_images"] + x_test = data["test_images"] + + y_train = data["train_labels"] + y_test = data["test_labels"] + logger.info('MedMNIST data was loaded!') + return (x_train, y_train), (x_test, y_test) diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/requirements.txt b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/requirements.txt new file mode 100644 index 0000000000..363c0d69f9 --- /dev/null +++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/requirements.txt @@ -0,0 +1,3 @@ +medmnist +setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability +wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/start_envoy.sh b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/start_envoy.sh new file mode 100755 index 0000000000..cdd84e7fb6 --- /dev/null +++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/envoy/start_envoy.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -e +ENVOY_NAME=$1 +ENVOY_CONF=$2 + +fx envoy start -n "$ENVOY_NAME" --disable-tls --envoy-config-path "$ENVOY_CONF" -dh localhost -dp 50051 diff --git a/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/workspace/Pytorch_FedProx_MedMNIST_2D.ipynb b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/workspace/Pytorch_FedProx_MedMNIST_2D.ipynb new file mode 100644 index 0000000000..3bae96ef48 --- /dev/null +++ b/openfl-tutorials/interactive_api/PyTorch_FedProx_MNIST/workspace/Pytorch_FedProx_MedMNIST_2D.ipynb @@ -0,0 +1,628 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "26fdd9ed", + "metadata": {}, + "source": [ + "# Federated MedMNIST2D Using FedProx Aggregation Algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5504ab79", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install medmnist" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0570122", + "metadata": {}, + "outputs": [], + "source": [ + "# Install dependencies if not already installed\n", + "import tqdm\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from torchvision import transforms as T\n", + "import torch.nn.functional as F\n", + "\n", + "import medmnist\n", + "import openfl.utilities.optimizers.torch.fedprox as FP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22ba64da", + "metadata": {}, + "outputs": [], + "source": [ + "from medmnist import INFO, Evaluator\n", + "\n", + "## Change dataflag here to reflect the ones defined in the envoy_conifg_xxx.yaml\n", + "dataname = 'bloodmnist'\n" + ] + }, + { + "cell_type": "markdown", + "id": "246f9c98", + "metadata": {}, + "source": [ + "## Connect to the Federation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d657e463", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a federation\n", + "from 
openfl.interface.interactive_api.federation import Federation\n", + "\n", + "# please use the same identificator that was used in signed certificate\n", + "client_id = 'api'\n", + "director_node_fqdn = 'localhost'\n", + "director_port=50051\n", + "\n", + "# 2) Run with TLS disabled (trusted environment)\n", + "# Federation can also determine local fqdn automatically\n", + "federation = Federation(\n", + " client_id=client_id,\n", + " director_node_fqdn=director_node_fqdn,\n", + " director_port=director_port, \n", + " tls=False\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47dcfab3", + "metadata": {}, + "outputs": [], + "source": [ + "shard_registry = federation.get_shard_registry()\n", + "shard_registry" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2a6c237", + "metadata": {}, + "outputs": [], + "source": [ + "# First, request a dummy_shard_desc that holds information about the federated dataset \n", + "dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)\n", + "dummy_shard_dataset = dummy_shard_desc.get_dataset('train')\n", + "sample, target = dummy_shard_dataset[0]\n", + "f\"Sample shape: {sample.shape}, target shape: {target.shape}\"" + ] + }, + { + "cell_type": "markdown", + "id": "cc0dbdbd", + "metadata": {}, + "source": [ + "## Describing FL experiment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc88700a", + "metadata": {}, + "outputs": [], + "source": [ + "from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment" + ] + }, + { + "cell_type": "markdown", + "id": "9b3081a6", + "metadata": {}, + "source": [ + "## Load MedMNIST INFO" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0377d3a", + "metadata": {}, + "outputs": [], + "source": [ + "num_epochs = 3\n", + "TRAIN_BS, VALID_BS = 64, 128\n", + "\n", + "lr = 0.001\n", + "gamma=0.1\n", + "milestones = [0.5 * num_epochs, 0.75 * num_epochs]\n", + "\n", + "info = INFO[dataname]\n", + "task = info['task']\n", + "n_channels = info['n_channels']\n", + "n_classes = len(info['label'])" + ] + }, + { + "cell_type": "markdown", + "id": "b0979470", + "metadata": {}, + "source": [ + "### Register dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0dc457e", + "metadata": {}, + "outputs": [], + "source": [ + "## Data transformations\n", + "data_transform = T.Compose([T.ToTensor(), \n", + " T.Normalize(mean=[.5], std=[.5])]\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09ba2f64", + "metadata": {}, + "outputs": [], + "source": [ + "from PIL import Image\n", + "\n", + "class TransformedDataset(Dataset):\n", + " \"\"\"Image Person ReID Dataset.\"\"\"\n", + "\n", + "\n", + " def __init__(self, dataset, transform=None, target_transform=None):\n", + " \"\"\"Initialize Dataset.\"\"\"\n", + " self.dataset = dataset\n", + " self.transform = transform\n", + " self.target_transform = target_transform\n", + "\n", + " def __len__(self):\n", + " \"\"\"Length of dataset.\"\"\"\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, index):\n", + " \n", + " img, label = self.dataset[index]\n", + " \n", + " if self.target_transform:\n", + " label = self.target_transform(label) \n", + " else:\n", + " label = label.astype(int)\n", + " \n", + " if self.transform:\n", + " img = Image.fromarray(img)\n", + " img = self.transform(img)\n", + " else:\n", + " base_transform = T.PILToTensor()\n", + " img = 
Image.fromarray(img)\n", + " img = base_transform(img) \n", + "\n", + " return img, label\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "db2d563e", + "metadata": {}, + "outputs": [], + "source": [ + "class MedMnistFedDataset(DataInterface):\n", + " def __init__(self, **kwargs):\n", + " self.kwargs = kwargs\n", + " \n", + " @property\n", + " def shard_descriptor(self):\n", + " return self._shard_descriptor\n", + " \n", + " @shard_descriptor.setter\n", + " def shard_descriptor(self, shard_descriptor):\n", + " \"\"\"\n", + " Describe per-collaborator procedures or sharding.\n", + "\n", + " This method will be called during a collaborator initialization.\n", + " Local shard_descriptor will be set by Envoy.\n", + " \"\"\"\n", + " self._shard_descriptor = shard_descriptor\n", + "\n", + " self.train_set = TransformedDataset(\n", + " self._shard_descriptor.get_dataset('train'),\n", + " transform=data_transform\n", + " ) \n", + " \n", + " self.valid_set = TransformedDataset(\n", + " self._shard_descriptor.get_dataset('val'),\n", + " transform=data_transform\n", + " )\n", + " \n", + " def get_train_loader(self, **kwargs):\n", + " \"\"\"\n", + " Output of this method will be provided to tasks with optimizer in contract\n", + " \"\"\"\n", + " return DataLoader(\n", + " self.train_set, num_workers=8, batch_size=self.kwargs['train_bs'], shuffle=True)\n", + "\n", + " def get_valid_loader(self, **kwargs):\n", + " \"\"\"\n", + " Output of this method will be provided to tasks without optimizer in contract\n", + " \"\"\"\n", + " return DataLoader(self.valid_set, num_workers=8, batch_size=self.kwargs['valid_bs'])\n", + "\n", + " def get_train_data_size(self):\n", + " \"\"\"\n", + " Information for aggregation\n", + " \"\"\"\n", + " return len(self.train_set)\n", + "\n", + " def get_valid_data_size(self):\n", + " \"\"\"\n", + " Information for aggregation\n", + " \"\"\"\n", + " return len(self.valid_set)\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "b0dfb459", + "metadata": {}, + "source": [ + "### Create Mnist federated dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4af5c4c2", + "metadata": {}, + "outputs": [], + "source": [ + "fed_dataset = MedMnistFedDataset(train_bs=TRAIN_BS, valid_bs=VALID_BS)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f63908e", + "metadata": {}, + "outputs": [], + "source": [ + "fed_dataset.shard_descriptor = dummy_shard_desc\n", + "for i, (sample, target) in enumerate(fed_dataset.get_train_loader()):\n", + " print(sample.shape, target.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "075d1d6c", + "metadata": {}, + "source": [ + "### Describe the model and optimizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8477a001", + "metadata": {}, + "outputs": [], + "source": [ + "# define a simple CNN model\n", + "class Net(nn.Module):\n", + " def __init__(self, in_channels, num_classes):\n", + " super(Net, self).__init__()\n", + "\n", + " self.layer1 = nn.Sequential(\n", + " nn.Conv2d(in_channels, 16, kernel_size=3),\n", + " nn.BatchNorm2d(16),\n", + " nn.ReLU())\n", + "\n", + " self.layer2 = nn.Sequential(\n", + " nn.Conv2d(16, 16, kernel_size=3),\n", + " nn.BatchNorm2d(16),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(kernel_size=2, stride=2))\n", + "\n", + " self.layer3 = nn.Sequential(\n", + " nn.Conv2d(16, 64, kernel_size=3),\n", + " nn.BatchNorm2d(64),\n", + " nn.ReLU())\n", + " \n", + " self.layer4 = nn.Sequential(\n", + " nn.Conv2d(64, 64, 
kernel_size=3),\n", + " nn.BatchNorm2d(64),\n", + " nn.ReLU())\n", + "\n", + " self.layer5 = nn.Sequential(\n", + " nn.Conv2d(64, 64, kernel_size=3, padding=1),\n", + " nn.BatchNorm2d(64),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(kernel_size=2, stride=2))\n", + "\n", + " self.fc = nn.Sequential(\n", + " nn.Linear(64 * 4 * 4, 128),\n", + " nn.ReLU(),\n", + " nn.Linear(128, 128),\n", + " nn.ReLU(),\n", + " nn.Linear(128, num_classes))\n", + "\n", + " def forward(self, x):\n", + " x = self.layer1(x)\n", + " x = self.layer2(x)\n", + " x = self.layer3(x)\n", + " x = self.layer4(x)\n", + " x = self.layer5(x)\n", + " x = x.view(x.size(0), -1)\n", + " x = self.fc(x)\n", + " return x\n", + "\n", + "model = Net(in_channels=n_channels, num_classes=n_classes)\n", + " \n", + "# define loss function and optimizer\n", + "if task == \"multi-label, binary-class\":\n", + " criterion = nn.BCEWithLogitsLoss()\n", + "else:\n", + " criterion = nn.CrossEntropyLoss()\n", + " \n", + "optimizer = FP.FedProxOptimizer(params = model.parameters(), lr=lr, momentum=0.9)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2154486", + "metadata": {}, + "outputs": [], + "source": [ + "print(model)" + ] + }, + { + "cell_type": "markdown", + "id": "8d1c78ee", + "metadata": {}, + "source": [ + "### Register model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59831bcd", + "metadata": {}, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "\n", + "framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'\n", + "MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)\n", + "\n", + "# Save the initial model state\n", + "initial_model = deepcopy(model)" + ] + }, + { + "cell_type": "markdown", + "id": "849c165b", + "metadata": {}, + "source": [ + "## Define and register FL tasks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ff463bd", + "metadata": {}, + "outputs": [], + "source": [ + "TI = TaskInterface()\n", + "\n", + "train_custom_params={'criterion':criterion,'task':task}\n", + "\n", + "# Task interface currently supports only standalone functions.\n", + "@TI.add_kwargs(**train_custom_params)\n", + "@TI.register_fl_task(model='model', data_loader='train_loader',\n", + " device='device', optimizer='optimizer')\n", + "def train(model, train_loader, device, optimizer, criterion, task):\n", + " total_loss = []\n", + " \n", + " train_loader = tqdm.tqdm(train_loader, desc=\"train\")\n", + " model.train()\n", + " model.to(device)\n", + " \n", + " for inputs, targets in train_loader:\n", + " \n", + " optimizer.set_old_weights(list(model.parameters()))\n", + " optimizer.zero_grad()\n", + " \n", + " outputs = model(inputs.to(device))\n", + " \n", + " if task == 'multi-label, binary-class':\n", + " targets = targets.to(torch.float32).to(device)\n", + " loss = criterion(outputs, targets)\n", + " else:\n", + " targets = torch.squeeze(targets, 1).long().to(device)\n", + " loss = criterion(outputs, targets)\n", + " \n", + " total_loss.append(loss.item())\n", + " \n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " return {'train_loss': np.mean(total_loss),}\n", + "\n", + "\n", + "val_custom_params={'criterion':criterion, \n", + " 'task':task}\n", + "\n", + "@TI.add_kwargs(**val_custom_params)\n", + "@TI.register_fl_task(model='model', data_loader='val_loader', device='device')\n", + "def validate(model, val_loader, device, criterion, task):\n", + "\n", + " val_loader = 
tqdm.tqdm(val_loader, desc=\"validate\")\n", + " model.eval()\n", + " model.to(device)\n", + "\n", + " val_score = 0\n", + " total_samples = 0\n", + " total_loss = []\n", + " y_score = torch.tensor([]).to(device)\n", + "\n", + " with torch.no_grad():\n", + " for inputs, targets in val_loader:\n", + " outputs = model(inputs.to(device))\n", + " \n", + " if task == 'multi-label, binary-class':\n", + " targets = targets.to(torch.float32).to(device)\n", + " loss = criterion(outputs, targets)\n", + " m = nn.Sigmoid()\n", + " outputs = m(outputs).to(device)\n", + " else:\n", + " targets = torch.squeeze(targets, 1).long().to(device)\n", + " loss = criterion(outputs, targets)\n", + " m = nn.Softmax(dim=1)\n", + " outputs = m(outputs).to(device)\n", + " targets = targets.float().resize_(len(targets), 1)\n", + "\n", + " total_loss.append(loss.item())\n", + " \n", + " total_samples += targets.shape[0]\n", + " pred = outputs.argmax(dim=1)\n", + " val_score += pred.eq(targets).sum().cpu().numpy()\n", + " \n", + " acc = val_score / total_samples \n", + " test_loss = sum(total_loss) / len(total_loss)\n", + "\n", + " return {'acc': acc,\n", + " 'test_loss': test_loss,\n", + " }" + ] + }, + { + "cell_type": "markdown", + "id": "8f0ebf2d", + "metadata": {}, + "source": [ + "## Time to start a federated learning experiment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d41b7896", + "metadata": {}, + "outputs": [], + "source": [ + "# create an experimnet in federation\n", + "experiment_name = 'medmnist_exp'\n", + "fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41b44de9", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# The following command zips the workspace and python requirements to be transfered to collaborator nodes\n", + "fl_experiment.start(model_provider=MI, \n", + " task_keeper=TI,\n", + " data_loader=fed_dataset,\n", + " rounds_to_train=3,\n", + " opt_treatment='RESET',\n", + " device_assignment_policy='CUDA_PREFERRED')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01fa7cea", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# If user want to stop IPython session, then reconnect and check how experiment is going\n", + "# fl_experiment.restore_experiment_state(model_interface)\n", + "\n", + "fl_experiment.stream_metrics(tensorboard_logs=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92940763", + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1690ea49", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10d7d5a2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "vscode": { + "interpreter": { + "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}