Do minor fixes to FedBN #2615

Merged: 7 commits, Nov 17, 2023
Changes from 4 commits
2 changes: 1 addition & 1 deletion baselines/fedbn/README.md
@@ -59,7 +59,7 @@ A more detailed explanation of the datasets is given in the following table.
| strategy.fraction_fit | 1.0 |
| strategy.fraction_evaluate | 0.0 |
| training samples per client| 743 |
| lr | 10E-2 |
| client.l_r | 10E-2 |
| local epochs | 1 |
| loss | cross entropy loss |
| optimizer | SGD |
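Since the table now refers to the Hydra config key `client.l_r`, here is a small hedged sketch (with illustrative values only, not the baseline's full config) of how such a dotted key can be set and overridden through OmegaConf, the config library Hydra builds on:

```python
from omegaconf import OmegaConf

# Illustrative defaults mirroring two rows of the table above (not the real config file).
cfg = OmegaConf.create({"client": {"l_r": 0.01}, "strategy": {"fraction_fit": 1.0}})

# A Hydra-style dotted override, i.e. what a CLI flag such as `client.l_r=0.1` would do.
cfg.merge_with(OmegaConf.from_dotlist(["client.l_r=0.1"]))
print(cfg.client.l_r)  # 0.1
```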
36 changes: 23 additions & 13 deletions baselines/fedbn/fedbn/client.py
@@ -16,24 +16,26 @@


class FlowerClient(fl.client.NumPyClient):
"""A standar FlowerClient. This base class.
"""A standar FlowerClient.

is what plain FedAvg clients do.
This base class is what plain FedAvg clients do.
"""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
model: CNNModel,
trainloader: DataLoader,
testloader: DataLoader,
dataset_name: str,
l_r: float,
**kwargs, # pylint: disable=unused-argument
) -> None:
self.trainloader = trainloader
self.testloader = testloader
self.dataset_name = dataset_name
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model = model.to(self.device)
self.l_r = l_r

def get_parameters(self, config) -> NDArrays:
"""Return model parameters as a list of NumPy ndarrays w or w/o.
@@ -58,22 +60,23 @@ def fit(
"""Set model parameters, train model, return updated model parameters."""
self.set_parameters(parameters)

# evaluate the state of the global model on the train set; the loss returned
# Evaluate the state of the global model on the train set; the loss returned
# is what's reported in Fig3 in the FedBN paper (what this baseline focuses
# in reproducing)
pre_train_loss, pre_train_acc = test(
self.model, self.trainloader, device=self.device
)

# train model on local dataset
# Train model on local dataset
loss, acc = train(
self.model,
self.trainloader,
epochs=1,
l_r=self.l_r,
device=self.device,
)

# construct metrics to return to server
# Construct metrics to return to server
fl_round = config["round"]
metrics = {
"dataset_name": self.dataset_name,
@@ -107,9 +110,16 @@ def evaluate(
class FedBNFlowerClient(FlowerClient):
"""Similar to FlowerClient but this is used by FedBN clients."""

def __init__(self, bn_state_dir: Path, client_id: int, *args, **kwargs) -> None:
def __init__(self, save_path: Path, client_id: int, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.bn_state_dir = bn_state_dir
# For FedBN clients we need to persist the state of the BN
# layers across rounds. In Simulation, clients are stateless,
# so everything not communicated to the server (as is the
# case with the params in BN layers of FedBN clients) is lost
# once a client completes its training. An upcoming version of
# Flower supports stateful clients.
bn_state_dir = save_path / "bn_states"
bn_state_dir.mkdir(exist_ok=True)
self.bn_state_pkl = bn_state_dir / f"client_{client_id}.pkl"

def _save_bn_statedict(self) -> None:
@@ -135,7 +145,7 @@ def get_parameters(self, config) -> NDArrays:

layers.
"""
# first update bn_state_dir
# First update bn_state_dir
self._save_bn_statedict()
# Excluding parameters of BN layers when using FedBN
return [
@@ -154,8 +164,8 @@ def set_parameters(self, parameters: NDArrays) -> None:
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
self.model.load_state_dict(state_dict, strict=False)

# now also load from bn_state_dir
if self.bn_state_pkl.exists(): # it won't exist in the first round
# Now also load from bn_state_dir
if self.bn_state_pkl.exists(): # It won't exist in the first round
bn_state_dict = self._load_bn_statedict()
self.model.load_state_dict(bn_state_dict, strict=False)

@@ -164,7 +174,7 @@ def gen_client_fn(
client_data: List[Tuple[DataLoader, DataLoader, str]],
client_cfg: DictConfig,
model_cfg: DictConfig,
bn_state_dir: Path,
save_path: Path,
) -> Callable[[str], FlowerClient]:
"""Return a function that will be called to instantiate the cid-th client."""

@@ -182,7 +192,7 @@ def client_fn(cid: str) -> FlowerClient:
trainloader=trainloader,
testloader=valloader,
dataset_name=dataset_name,
bn_state_dir=bn_state_dir,
save_path=save_path,
client_id=int(cid),
)

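As context for the diff above, the BN-state persistence can be pictured as pickling only the BatchNorm entries of the model's `state_dict` to `bn_states/client_<id>.pkl` and reading them back at the next round. The following is a minimal standalone sketch; the helper names echo `_save_bn_statedict`/`_load_bn_statedict`, but the bodies (and the assumption that BN layers carry `bn` in their parameter names) are illustrative, not the exact merged code.

```python
import pickle
from collections import OrderedDict
from pathlib import Path

import torch
import torch.nn as nn


def save_bn_statedict(model: nn.Module, bn_state_pkl: Path) -> None:
    """Persist only the BatchNorm entries of the model's state_dict."""
    bn_state = {
        name: val.cpu().numpy()
        for name, val in model.state_dict().items()
        if "bn" in name  # assumes BN layers are named with a 'bn' prefix
    }
    with open(bn_state_pkl, "wb") as handle:
        pickle.dump(bn_state, handle, protocol=pickle.HIGHEST_PROTOCOL)


def load_bn_statedict(bn_state_pkl: Path) -> "OrderedDict[str, torch.Tensor]":
    """Load the pickled BN state back as tensors."""
    with open(bn_state_pkl, "rb") as handle:
        bn_state = pickle.load(handle)
    return OrderedDict({k: torch.tensor(v) for k, v in bn_state.items()})
```

Loading the result with `model.load_state_dict(..., strict=False)`, as the diff does, merges this partial BN-only state into the full model without touching the federated parameters.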
3 changes: 2 additions & 1 deletion baselines/fedbn/fedbn/conf/client/fedavg.yaml
@@ -1,4 +1,5 @@
---
# standard FedAvg Flower Client
_target_: fedbn.client.FlowerClient
client_label: FedAvg
client_label: FedAvg
l_r: 0.01
3 changes: 2 additions & 1 deletion baselines/fedbn/fedbn/conf/client/fedbn.yaml
@@ -1,4 +1,5 @@
---
# standard FedBN Flower Client
_target_: fedbn.client.FedBNFlowerClient
client_label: FedBN
client_label: FedBN
l_r: 0.01
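For reference, both YAML files are consumed through Hydra's `_target_` mechanism: `gen_client_fn` calls `instantiate(client_cfg, model=..., trainloader=..., ...)` and every other key (`client_label`, `l_r`) is forwarded to the chosen client's `__init__`. The hedged sketch below keeps the snippet runnable without the baseline installed by pointing `_target_` at a standard-library class instead of `fedbn.client.FedBNFlowerClient`:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Stand-in for conf/client/fedbn.yaml: _target_ names the class to build;
# every other key becomes a keyword argument of its constructor.
client_cfg = OmegaConf.create(
    {
        "_target_": "collections.OrderedDict",  # the real config targets fedbn.client.FedBNFlowerClient
        "client_label": "FedBN",
        "l_r": 0.01,
    }
)

client = instantiate(client_cfg)
print(client)  # OrderedDict([('client_label', 'FedBN'), ('l_r', 0.01)])
```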
17 changes: 12 additions & 5 deletions baselines/fedbn/fedbn/dataset.py
@@ -24,7 +24,7 @@ def __init__(  # pylint: disable=too-many-arguments
transform=None,
):
if train and partitions is not None:
# construct dataset by loading one or more partitions
# Construct dataset by loading one or more partitions
self.images, self.labels = np.load(
os.path.join(
data_path,
@@ -212,11 +212,18 @@ def get_data(dataset_cfg: DictConfig) -> List[Tuple[DataLoader, DataLoader, str]
client_data = []
d_cfg = dataset_cfg

allowed_percent = (np.arange(1, 11) / 10).tolist()
total_partitions = (
10 # each dataset was pre-processed by the authors and split into 10 partitions
)
# First check that percent used is allowed
allowed_percent = (np.arange(1, total_partitions + 1) / total_partitions).tolist()
assert d_cfg.percent in allowed_percent, (
f"'dataset.percent' should be in {allowed_percent}."
"\nThis is because the trainset is pre-partitioned into 10 disjoint sets."
)

# Then check that with the percent selected, the desired number of clients (and
# therefore dataloaders) can be created.
max_expected_clients = len(d_cfg.to_include) * 1 / d_cfg.percent

num_clients_step = len(d_cfg.to_include)
@@ -234,10 +241,10 @@ def get_data(dataset_cfg: DictConfig) -> List[Tuple[DataLoader, DataLoader, str]
"'dataset.percent' you chose."
)

# All good, then create as many dataloaders as clients in the experiment.
# Each dataloader might contain one or more partitions (depends on 'percent').
# Each dataloader contains data of the same dataset.
num_clients_per_dataset = d_cfg.num_clients // num_clients_step
total_partitions = (
10 # each dataset was pre-processed by the authors and split into 10 partitions
)
num_parts = int(d_cfg.percent * total_partitions)

for dataset_name in dataset_cfg.to_include:
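The new checks above reduce to simple arithmetic over the 10 pre-made partitions each dataset ships with. A standalone hedged sketch of that reasoning, using made-up config values rather than the baseline defaults:

```python
import numpy as np

# Assumed example values, not the baseline's defaults.
total_partitions = 10                    # each dataset is pre-split into 10 partitions
percent = 0.2                            # fraction of a dataset's partitions per client
to_include = ["MNIST", "SVHN", "USPS"]   # hypothetical list of datasets
num_clients = 6                          # must be a multiple of len(to_include)

# 'percent' must correspond to a whole number of partitions.
allowed_percent = (np.arange(1, total_partitions + 1) / total_partitions).tolist()
assert percent in allowed_percent

# Each dataset can supply at most 1 / percent clients.
max_expected_clients = len(to_include) * 1 / percent  # 3 * 5 = 15 here
assert num_clients <= max_expected_clients

# Every client receives this many of its dataset's 10 partitions.
num_parts = int(percent * total_partitions)               # 2 partitions per client
num_clients_per_dataset = num_clients // len(to_include)  # 2 clients per dataset
print(num_parts, num_clients_per_dataset)
```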
20 changes: 5 additions & 15 deletions baselines/fedbn/fedbn/main.py
@@ -3,8 +3,7 @@
It includes processing the dataset, instantiating the strategy, specifying how the global
model is going to be evaluated, etc. At the end, this script saves the results.
"""
# these are the basic packages you'll need here
# feel free to remove some if aren't needed

import pickle
from pathlib import Path

@@ -33,24 +32,15 @@ def main(cfg: DictConfig) -> None:

# Hydra automatically creates an output directory
# Let's retrieve it and save some results there
save_path = HydraConfig.get().runtime.output_dir

# For FedBN clients we need to persist the state of the BN
# layers across rounds. In Simulation clients are statess
# so everything not communicated to the server (as it is the
# case as with params in BN layers of FedBN clients) is lost
# once a client completes its training. An upcoming version of
# Flower suports stateful clients
bn_states = Path(save_path) / "bn_states"
bn_states.mkdir()
save_path = Path(HydraConfig.get().runtime.output_dir)

# 2. Prepare your dataset
# please ensure you followed the README.md and you downloaded the
# pre-processed dataset supplied by the authors of the FedBN paper
client_data_loaders = get_data(cfg.dataset)

# 3. Define your client generation function
client_fn = gen_client_fn(client_data_loaders, cfg.client, cfg.model, bn_states)
client_fn = gen_client_fn(client_data_loaders, cfg.client, cfg.model, save_path)

# 4. Define your strategy
strategy = instantiate(cfg.strategy)
@@ -71,14 +61,14 @@
print("................")
print(history)

# save results as a Python pickle using a file_path
# Save results as a Python pickle using a file_path in
# the directory created by Hydra for each run
data = {"history": history}
history_path = f"{str(save_path)}/history.pkl"
with open(history_path, "wb") as handle:
pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

# simple plot
# Simple plot
quick_plot(history_path)


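The result-saving step at the end of `main` boils down to pickling a dict that wraps the `History` returned by the simulation into Hydra's run directory. Below is a hedged, standalone sketch; `DummyHistory` and the temporary directory stand in for the real `History` object and `HydraConfig.get().runtime.output_dir`:

```python
import pickle
from pathlib import Path
from tempfile import mkdtemp


class DummyHistory:
    """Stand-in for the History object returned by the simulation (illustrative)."""

    losses_distributed = [(1, 0.9), (2, 0.7)]


save_path = Path(mkdtemp())  # in the baseline: Path(HydraConfig.get().runtime.output_dir)

data = {"history": DummyHistory()}
history_path = f"{str(save_path)}/history.pkl"
with open(history_path, "wb") as handle:
    pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Later (e.g. inside quick_plot) the results can be read back:
with open(history_path, "rb") as handle:
    history = pickle.load(handle)["history"]
print(history.losses_distributed)
```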
4 changes: 2 additions & 2 deletions baselines/fedbn/fedbn/models.py
@@ -56,11 +56,11 @@ def forward(self, x):


# pylint: disable=too-many-locals
def train(model, traindata, epochs, device) -> Tuple[float, float]:
def train(model, traindata, epochs, l_r, device) -> Tuple[float, float]:
"""Train the network."""
# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
optimizer = torch.optim.SGD(model.parameters(), lr=l_r)

# Train the network
model.to(device)
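To illustrate the updated `train(model, traindata, epochs, l_r, device)` signature, here is a hedged, self-contained sketch: the loop body and the toy linear model are assumptions (the real loop lives in the collapsed part of `models.py`), but the optimizer line shows the point of the change, namely that the SGD learning rate now comes from the config instead of being hard-coded to `1e-2`.

```python
from typing import Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train(model, traindata, epochs, l_r, device) -> Tuple[float, float]:
    """Minimal training loop matching the updated signature (sketch only)."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=l_r)  # lr now configurable
    model.to(device)
    model.train()
    total, correct, loss_sum = 0, 0, 0.0
    for _ in range(epochs):
        for images, labels in traindata:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return loss_sum / total, correct / total


# Tiny smoke test with random data and a linear stand-in for CNNModel.
model = nn.Linear(8, 3)
data = TensorDataset(torch.randn(32, 8), torch.randint(0, 3, (32,)))
loss, acc = train(model, DataLoader(data, batch_size=8), epochs=1, l_r=0.01, device="cpu")
print(loss, acc)
```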
2 changes: 1 addition & 1 deletion doc/source/ref-changelog.md
@@ -46,7 +46,7 @@

- niid-Bench [#2428](https://github.com/adap/flower/pull/2428)

- FedBN [#2608](https://github.com/adap/flower/pull/2608)
- FedBN ([#2608](https://github.com/adap/flower/pull/2608), [#2615](https://github.com/adap/flower/pull/2615))

- **Update Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384), [#2425](https://github.com/adap/flower/pull/2425), [#2526](https://github.com/adap/flower/pull/2526))
