Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FedAvg] Balance classes in dataset #1670

Merged
merged 3 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def evaluate(
def gen_client_fn(
device: torch.device,
iid: bool,
balance: bool,
num_clients: int,
num_epochs: int,
batch_size: int,
Expand All @@ -85,6 +86,9 @@ def gen_client_fn(
should be independent and identically distributed between the clients
or if the data should first be sorted by labels and distributed by chunks
to each client (used to test the convergence in a worst case scenario)
balance : bool
Whether the dataset should contain an equal number of samples in each class,
by default True
num_clients : int
The number of clients present in the setup
num_epochs : int
Expand All @@ -102,7 +106,7 @@ def gen_client_fn(
the DataLoader that will be used for testing
"""
trainloaders, valloaders, testloader = load_datasets(
iid=iid, num_clients=num_clients, batch_size=batch_size
iid=iid, balance=balance, num_clients=num_clients, batch_size=batch_size
)

def client_fn(cid: str) -> FlowerClient:
Expand Down
58 changes: 54 additions & 4 deletions baselines/flwr_baselines/publications/fedavg_mnist/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from torchvision.datasets import MNIST


def load_datasets(
def load_datasets( # pylint: disable=too-many-arguments
num_clients: int = 10,
iid: Optional[bool] = True,
balance: Optional[bool] = True,
val_ratio: float = 0.1,
batch_size: Optional[int] = 32,
seed: Optional[int] = 42,
Expand All @@ -27,6 +28,9 @@ def load_datasets(
Whether the data should be independent and identically distributed between the
clients or if the data should first be sorted by labels and distributed by chunks
to each client (used to test the convergence in a worst case scenario), by default True
balance : bool, optional
Whether the dataset should contain an equal number of samples in each class,
by default True
val_ratio : float, optional
The ratio of training data that will be used for validation (between 0 and 1),
by default 0.1
Expand All @@ -40,14 +44,13 @@ def load_datasets(
Tuple[DataLoader, DataLoader, DataLoader]
The DataLoader for training, the DataLoader for validation, the DataLoader for testing.
"""
datasets, testset = _partition_data(num_clients, iid, seed)
datasets, testset = _partition_data(num_clients, iid, balance, seed)
# Split each partition into train/val and create DataLoader
trainloaders = []
valloaders = []
for dataset in datasets:
len_val = int(len(dataset) / (1 / val_ratio))
len_train = len(dataset) - len_val
lengths = [len_train, len_val]
lengths = [len(dataset) - len_val, len_val]
ds_train, ds_val = random_split(
dataset, lengths, torch.Generator().manual_seed(seed)
)
Expand Down Expand Up @@ -75,6 +78,7 @@ def _download_data() -> Tuple[Dataset, Dataset]:
def _partition_data(
num_clients: int = 10,
iid: Optional[bool] = True,
balance: Optional[bool] = True,
seed: Optional[int] = 42,
) -> Tuple[List[Dataset], Dataset]:
"""Split training set into iid or non iid partitions to simulate the
Expand All @@ -88,6 +92,9 @@ def _partition_data(
Whether the data should be independent and identically distributed between
the clients or if the data should first be sorted by labels and distributed by chunks
to each client (used to test the convergence in a worst case scenario), by default True
balance : bool, optional
Whether the dataset should contain an equal number of samples in each class,
by default True
seed : int, optional
Used to set a fix seed to replicate experiments, by default 42

Expand All @@ -102,6 +109,9 @@ def _partition_data(
if iid:
datasets = random_split(trainset, lengths, torch.Generator().manual_seed(seed))
else:
if balance:
trainset = _balance_classes(trainset, seed)
partition_size = int(len(trainset) / num_clients)
shard_size = int(partition_size / 2)
idxs = trainset.targets.argsort()
sorted_data = Subset(trainset, idxs)
Expand All @@ -119,3 +129,43 @@ def _partition_data(
]

return datasets, testset


def _balance_classes(
trainset: Dataset,
seed: Optional[int] = 42,
) -> Dataset:
"""Balance the classes of the trainset.

Trims the dataset so each class contains as many elements as the
class that contained the least elements.

Parameters
----------
trainset : Dataset
The training dataset that needs to be balanced.
seed : int, optional
Used to set a fix seed to replicate experiments, by default 42.

Returns
-------
Dataset
The balanced training dataset.
"""
class_counts = np.bincount(trainset.targets)
smallest = np.min(class_counts)
idxs = trainset.targets.argsort()
tmp = [Subset(trainset, idxs[: int(smallest)])]
tmp_targets = [trainset.targets[idxs[: int(smallest)]]]
for count in class_counts:
tmp.append(Subset(trainset, idxs[int(count) : int(count + smallest)]))
tmp_targets.append(trainset.targets[idxs[int(count) : int(count + smallest)]])
unshuffled = ConcatDataset(tmp)
unshuffled_targets = torch.cat(tmp_targets)
shuffled_idxs = torch.randperm(
len(unshuffled), generator=torch.Generator().manual_seed(seed)
)
shuffled = Subset(unshuffled, shuffled_idxs)
shuffled.targets = unshuffled_targets[shuffled_idxs]

return shuffled
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ num_rounds: 10
num_epochs: 5
batch_size: 10
iid: False
balance: True
client_fraction: 1.0
expected_maximum: 0.9924
learning_rate: 0.1
Expand Down
30 changes: 14 additions & 16 deletions baselines/flwr_baselines/publications/fedavg_mnist/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Runs CNN federated learning for MNST dataset."""
"""Runs CNN federated learning for MNIST dataset."""

from pathlib import Path

Expand Down Expand Up @@ -28,6 +28,7 @@ def main(cfg: DictConfig) -> None:
device=DEVICE,
num_clients=cfg.num_clients,
iid=cfg.iid,
balance=cfg.balance,
learning_rate=cfg.learning_rate,
)

Expand All @@ -50,29 +51,26 @@ def main(cfg: DictConfig) -> None:
config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
strategy=strategy,
)

file_suffix: str = (
f"{'_iid' if cfg.iid else ''}"
f"{'_balanced' if cfg.balance else ''}"
f"_C={cfg.num_clients}"
f"_B={cfg.batch_size}"
f"_E={cfg.num_epochs}"
f"_R={cfg.num_rounds}"
)

np.save(
Path(cfg.save_path)
/ Path(
f"hist_C={cfg.num_clients}"
f"_B={cfg.batch_size}"
f"_E={cfg.num_epochs}"
f"_R={cfg.num_rounds}"
f"_stag={1 - cfg.client_fraction}"
),
Path(cfg.save_path) / Path(f"hist{file_suffix}"),
history, # type: ignore
)

utils.plot_metric_from_history(
history,
cfg.save_path,
cfg.expected_maximum,
(
f"_C={cfg.num_clients}"
f"_B={cfg.batch_size}"
f"_E={cfg.num_epochs}"
f"_R={cfg.num_rounds}"
f"_stag={1 - cfg.client_fraction}"
),
file_suffix,
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
"""Used to test the model and the data partitioning."""


from flwr_baselines.publications.fedavg_mnist import dataset, model


def test_cnn_size_mnist() -> None:
    """Check that the CNN exposes the expected parameter count for MNIST inputs."""
    expected_num_params = 1_663_370

    cnn = model.Net()
    # Sum the element counts of every parameter tensor in the network.
    total = sum(param.numel() for param in cnn.parameters())

    assert total == expected_num_params
from flwr_baselines.publications.fedavg_mnist import dataset


def test_non_iid_partitionning(num_clients: int = 100) -> None:
Expand All @@ -26,7 +13,7 @@ def test_non_iid_partitionning(num_clients: int = 100) -> None:
The number of clients to distribute the data to, by default 100
"""
trainloaders, _, _ = dataset.load_datasets(
batch_size=1, num_clients=num_clients, iid=False
batch_size=1, num_clients=num_clients, iid=False, balance=True
)
for trainloader in trainloaders:
labels = []
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Used to test the model."""


from flwr_baselines.publications.fedavg_mnist import model


def test_cnn_size_mnist() -> None:
    """Check that the CNN exposes the expected parameter count for MNIST inputs."""
    expected_num_params = 1_663_370

    cnn = model.Net()
    # Sum the element counts of every parameter tensor in the network.
    total = sum(param.numel() for param in cnn.parameters())

    assert total == expected_num_params