diff --git a/src/py/flwr/cli/new/templates/app/code/task.hf.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.hf.py.tpl index eb43acfce97..51a21dd1741 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.hf.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.hf.py.tpl @@ -10,15 +10,27 @@ from torch.utils.data import DataLoader from transformers import AutoTokenizer, DataCollatorWithPadding from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner + warnings.filterwarnings("ignore", category=UserWarning) DEVICE = torch.device("cpu") CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint +fds = None # Cache FederatedDataset + + def load_data(partition_id: int, num_partitions: int): """Load IMDB data (training and eval)""" - fds = FederatedDataset(dataset="imdb", partitioners={"train": num_partitions}) + # Only initialize `FederatedDataset` once + global fds + if fds is None: + partitioner = IidPartitioner(num_partitions=num_partitions) + fds = FederatedDataset( + dataset="stanfordnlp/imdb", + partitioners={"train": partitioner}, + ) partition = fds.load_partition(partition_id) # Divide data: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2, seed=42) diff --git a/src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl index 88053b0cd59..1759fe8c0b4 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl @@ -5,10 +5,12 @@ import mlx.nn as nn import numpy as np from datasets.utils.logging import disable_progress_bar from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner disable_progress_bar() + class MLP(nn.Module): """A simple MLP.""" @@ -43,8 +45,19 @@ def batch_iterate(batch_size, X, y): yield X[ids], y[ids] +fds = None # Cache FederatedDataset + + def load_data(partition_id: int, num_partitions: int): - fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions}) + # Only initialize `FederatedDataset` once + global fds + if fds is None: + partitioner = IidPartitioner(num_partitions=num_partitions) + fds = FederatedDataset( + dataset="ylecun/mnist", + partitioners={"train": partitioner}, + trust_remote_code=True, + ) partition = fds.load_partition(partition_id) partition_splits = partition.train_test_split(test_size=0.2, seed=42) diff --git a/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl index d5971ffb6ce..bd2fad5be58 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl @@ -6,9 +6,10 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader -from torchvision.datasets import CIFAR10 from torchvision.transforms import Compose, Normalize, ToTensor from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner + DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -34,9 +35,19 @@ class Net(nn.Module): return self.fc3(x) +fds = None # Cache FederatedDataset + + def load_data(partition_id: int, num_partitions: int): """Load partition CIFAR10 data.""" - fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions}) + # Only initialize `FederatedDataset` once + global fds + if fds is None: + partitioner = IidPartitioner(num_partitions=num_partitions) + fds = FederatedDataset( + dataset="uoft-cs/cifar10", + partitioners={"train": partitioner}, + ) partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2, seed=42) diff --git a/src/py/flwr/cli/new/templates/app/code/task.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.tensorflow.py.tpl index fa07f93713e..c495774ffeb 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.tensorflow.py.tpl @@ -4,11 +4,13 @@ import os import tensorflow as tf from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner # Make TensorFlow log less verbose os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + def load_model(): # Load model and data (MobileNetV2, CIFAR-10) model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None) @@ -16,9 +18,19 @@ def load_model(): return model +fds = None # Cache FederatedDataset + + def load_data(partition_id, num_partitions): # Download and partition dataset - fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions}) + # Only initialize `FederatedDataset` once + global fds + if fds is None: + partitioner = IidPartitioner(num_partitions=num_partitions) + fds = FederatedDataset( + dataset="uoft-cs/cifar10", + partitioners={"train": partitioner}, + ) partition = fds.load_partition(partition_id, "train") partition.set_format("numpy")