feat(framework) Use dataset caching in flwr new templates (#3877)
jafermarq authored Jul 23, 2024
1 parent 7ce9855 commit 54b2682
Showing 4 changed files with 53 additions and 5 deletions.
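The pattern applied in all four templates is the same: keep a module-level `FederatedDataset` and create it lazily on the first call to `load_data`, so later calls reuse the already downloaded and partitioned dataset instead of rebuilding it. A minimal standalone sketch of the idea, condensed from the diffs below (the dataset name and return value are illustrative; each template uses its own dataset and split logic):

from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

fds = None  # Cache FederatedDataset across calls


def load_data(partition_id: int, num_partitions: int):
    # Only initialize `FederatedDataset` once per process
    global fds
    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(
            dataset="uoft-cs/cifar10",  # illustrative; each template uses its own dataset
            partitioners={"train": partitioner},
        )
    return fds.load_partition(partition_id)

Note that `num_partitions` only takes effect on the first call; subsequent calls reuse the cached partitioner.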
14 changes: 13 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/task.hf.py.tpl
@@ -10,15 +10,27 @@ from torch.utils.data import DataLoader
 from transformers import AutoTokenizer, DataCollatorWithPadding

 from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner


 warnings.filterwarnings("ignore", category=UserWarning)
 DEVICE = torch.device("cpu")
 CHECKPOINT = "distilbert-base-uncased"  # transformer model checkpoint


+fds = None  # Cache FederatedDataset
+
+
 def load_data(partition_id: int, num_partitions: int):
     """Load IMDB data (training and eval)"""
-    fds = FederatedDataset(dataset="imdb", partitioners={"train": num_partitions})
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="stanfordnlp/imdb",
+            partitioners={"train": partitioner},
+        )
     partition = fds.load_partition(partition_id)
     # Divide data: 80% train, 20% test
     partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
15 changes: 14 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl
@@ -5,10 +5,12 @@ import mlx.nn as nn
 import numpy as np
 from datasets.utils.logging import disable_progress_bar
 from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner


 disable_progress_bar()


 class MLP(nn.Module):
     """A simple MLP."""

@@ -43,8 +45,19 @@ def batch_iterate(batch_size, X, y):
         yield X[ids], y[ids]


+fds = None  # Cache FederatedDataset
+
+
 def load_data(partition_id: int, num_partitions: int):
-    fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="ylecun/mnist",
+            partitioners={"train": partitioner},
+            trust_remote_code=True,
+        )
     partition = fds.load_partition(partition_id)
     partition_splits = partition.train_test_split(test_size=0.2, seed=42)

15 changes: 13 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl
@@ -6,9 +6,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
-from torchvision.datasets import CIFAR10
 from torchvision.transforms import Compose, Normalize, ToTensor
 from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner


 DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@@ -34,9 +35,19 @@ class Net(nn.Module):
         return self.fc3(x)


+fds = None  # Cache FederatedDataset
+
+
 def load_data(partition_id: int, num_partitions: int):
     """Load partition CIFAR10 data."""
-    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="uoft-cs/cifar10",
+            partitioners={"train": partitioner},
+        )
     partition = fds.load_partition(partition_id)
     # Divide data on each node: 80% train, 20% test
     partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
14 changes: 13 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/task.tensorflow.py.tpl
@@ -4,21 +4,33 @@ import os

 import tensorflow as tf
 from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner


 # Make TensorFlow log less verbose
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


 def load_model():
     # Load model and data (MobileNetV2, CIFAR-10)
     model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
     model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
     return model


+fds = None  # Cache FederatedDataset
+
+
 def load_data(partition_id, num_partitions):
     # Download and partition dataset
-    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="uoft-cs/cifar10",
+            partitioners={"train": partitioner},
+        )
     partition = fds.load_partition(partition_id, "train")
     partition.set_format("numpy")

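For context on why the cache matters: in the generated apps, `load_data` is called from the `ClientApp`'s `client_fn`, which runs each time a client is instantiated (for example, every round in simulation). A rough sketch of that call site, assuming the usual template `client_fn` shape; `FlowerClient` stands in for the template's `NumPyClient` subclass and the `node_config` keys follow the template convention, neither is part of this diff:

from flwr.client import ClientApp
from flwr.common import Context


def client_fn(context: Context):
    # Runtime-provided values identifying this client's partition
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    # Second and later calls reuse the module-level `fds` cache in the task module
    trainloader, valloader = load_data(partition_id, num_partitions)
    return FlowerClient(trainloader, valloader).to_client()


app = ClientApp(client_fn=client_fn)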
