Skip to content

Commit

Permalink
Do minor fixes to FedBN (#2615)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Nov 17, 2023
1 parent 77b715a commit b43515c
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 39 deletions.
2 changes: 1 addition & 1 deletion baselines/fedbn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ A more detailed explanation of the datasets is given in the following table.
| strategy_fraction_fit | 1.0 |
| strategy.fraction_evaluate | 0.0 |
| training samples per client| 743 |
| lr | 10E-2 |
| client.l_r | 10E-2 |
| local epochs | 1 |
| loss | cross entropy loss |
| optimizer | SGD |
Expand Down
36 changes: 23 additions & 13 deletions baselines/fedbn/fedbn/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,26 @@


class FlowerClient(fl.client.NumPyClient):
"""A standar FlowerClient. This base class.
"""A standard FlowerClient.
is what plain FedAvg clients do.
This base class is what plain FedAvg clients do.
"""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
model: CNNModel,
trainloader: DataLoader,
testloader: DataLoader,
dataset_name: str,
l_r: float,
**kwargs, # pylint: disable=unused-argument
) -> None:
self.trainloader = trainloader
self.testloader = testloader
self.dataset_name = dataset_name
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model = model.to(self.device)
self.l_r = l_r

def get_parameters(self, config) -> NDArrays:
"""Return model parameters as a list of NumPy ndarrays w or w/o.
Expand All @@ -58,22 +60,23 @@ def fit(
"""Set model parameters, train model, return updated model parameters."""
self.set_parameters(parameters)

# evaluate the state of the global model on the train set; the loss returned
# Evaluate the state of the global model on the train set; the loss returned
# is what's reported in Fig3 in the FedBN paper (what this baseline focuses
# in reproducing)
pre_train_loss, pre_train_acc = test(
self.model, self.trainloader, device=self.device
)

# train model on local dataset
# Train model on local dataset
loss, acc = train(
self.model,
self.trainloader,
epochs=1,
l_r=self.l_r,
device=self.device,
)

# construct metrics to return to server
# Construct metrics to return to server
fl_round = config["round"]
metrics = {
"dataset_name": self.dataset_name,
Expand Down Expand Up @@ -107,9 +110,16 @@ def evaluate(
class FedBNFlowerClient(FlowerClient):
"""Similar to FlowerClient but this is used by FedBN clients."""

def __init__(self, bn_state_dir: Path, client_id: int, *args, **kwargs) -> None:
def __init__(self, save_path: Path, client_id: int, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.bn_state_dir = bn_state_dir
# For FedBN clients we need to persist the state of the BN
# layers across rounds. In Simulation clients are statess
# so everything not communicated to the server (as it is the
# case as with params in BN layers of FedBN clients) is lost
# once a client completes its training. An upcoming version of
# Flower suports stateful clients
bn_state_dir = save_path / "bn_states"
bn_state_dir.mkdir(exist_ok=True)
self.bn_state_pkl = bn_state_dir / f"client_{client_id}.pkl"

def _save_bn_statedict(self) -> None:
Expand All @@ -135,7 +145,7 @@ def get_parameters(self, config) -> NDArrays:
layers.
"""
# first update bn_state_dir
# First update bn_state_dir
self._save_bn_statedict()
# Excluding parameters of BN layers when using FedBN
return [
Expand All @@ -154,8 +164,8 @@ def set_parameters(self, parameters: NDArrays) -> None:
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
self.model.load_state_dict(state_dict, strict=False)

# now also load from bn_state_dir
if self.bn_state_pkl.exists(): # it won't exist in the first round
# Now also load from bn_state_dir
if self.bn_state_pkl.exists(): # It won't exist in the first round
bn_state_dict = self._load_bn_statedict()
self.model.load_state_dict(bn_state_dict, strict=False)

Expand All @@ -164,7 +174,7 @@ def gen_client_fn(
client_data: List[Tuple[DataLoader, DataLoader, str]],
client_cfg: DictConfig,
model_cfg: DictConfig,
bn_state_dir: Path,
save_path: Path,
) -> Callable[[str], FlowerClient]:
"""Return a function that will be called to instantiate the cid-th client."""

Expand All @@ -182,7 +192,7 @@ def client_fn(cid: str) -> FlowerClient:
trainloader=trainloader,
testloader=valloader,
dataset_name=dataset_name,
bn_state_dir=bn_state_dir,
save_path=save_path,
client_id=int(cid),
)

Expand Down
3 changes: 2 additions & 1 deletion baselines/fedbn/fedbn/conf/client/fedavg.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
---
# standard FedAvg Flower Client
_target_: fedbn.client.FlowerClient
client_label: FedAvg
client_label: FedAvg
l_r: 0.01
3 changes: 2 additions & 1 deletion baselines/fedbn/fedbn/conf/client/fedbn.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
---
# standard FedBN Flower Client
_target_: fedbn.client.FedBNFlowerClient
client_label: FedBN
client_label: FedBN
l_r: 0.01
17 changes: 12 additions & 5 deletions baselines/fedbn/fedbn/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__( # pylint: disable=too-many-arguments
transform=None,
):
if train and partitions is not None:
# construct dataset by loading one or more partitions
# Construct dataset by loading one or more partitions
self.images, self.labels = np.load(
os.path.join(
data_path,
Expand Down Expand Up @@ -212,11 +212,18 @@ def get_data(dataset_cfg: DictConfig) -> List[Tuple[DataLoader, DataLoader, str]
client_data = []
d_cfg = dataset_cfg

allowed_percent = (np.arange(1, 11) / 10).tolist()
total_partitions = (
10 # each dataset was pre-processed by the authors and split into 10 partitions
)
# First check that percent used is allowed
allowed_percent = (np.arange(1, total_partitions + 1) / total_partitions).tolist()
assert d_cfg.percent in allowed_percent, (
f"'dataset.percent' should be in {allowed_percent}."
"\nThis is because the trainset is pre-partitioned into 10 disjoint sets."
)

# Then check that with the percent selected, the desired number of clients (and
# therefore dataloaders) can be created.
max_expected_clients = len(d_cfg.to_include) * 1 / d_cfg.percent

num_clients_step = len(d_cfg.to_include)
Expand All @@ -234,10 +241,10 @@ def get_data(dataset_cfg: DictConfig) -> List[Tuple[DataLoader, DataLoader, str]
"'dataset.percent' you chose."
)

# All good, then create as many dataloaders as clients in the experiment.
# Each dataloader might containe one or more partitions (depends on 'percent')
# Each dataloader contains data of the same dataset.
num_clients_per_dataset = d_cfg.num_clients // num_clients_step
total_partitions = (
10 # each dataset was pre-processed by the authors and split into 10 partitions
)
num_parts = int(d_cfg.percent * total_partitions)

for dataset_name in dataset_cfg.to_include:
Expand Down
20 changes: 5 additions & 15 deletions baselines/fedbn/fedbn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
It includes processioning the dataset, instantiate strategy, specify how the global
model is going to be evaluated, etc. At the end, this script saves the results.
"""
# these are the basic packages you'll need here
# feel free to remove some if aren't needed

import pickle
from pathlib import Path

Expand Down Expand Up @@ -33,24 +32,15 @@ def main(cfg: DictConfig) -> None:

# Hydra automatically creates an output directory
# Let's retrieve it and save some results there
save_path = HydraConfig.get().runtime.output_dir

# For FedBN clients we need to persist the state of the BN
# layers across rounds. In Simulation clients are statess
# so everything not communicated to the server (as it is the
# case as with params in BN layers of FedBN clients) is lost
# once a client completes its training. An upcoming version of
# Flower suports stateful clients
bn_states = Path(save_path) / "bn_states"
bn_states.mkdir()
save_path = Path(HydraConfig.get().runtime.output_dir)

# 2. Prepare your dataset
# please ensure you followed the README.md and you downloaded the
# pre-processed dataset suplied by the authors of the FedBN paper
client_data_loaders = get_data(cfg.dataset)

# 3. Define your client generation function
client_fn = gen_client_fn(client_data_loaders, cfg.client, cfg.model, bn_states)
client_fn = gen_client_fn(client_data_loaders, cfg.client, cfg.model, save_path)

# 4. Define your strategy
strategy = instantiate(cfg.strategy)
Expand All @@ -71,14 +61,14 @@ def main(cfg: DictConfig) -> None:
print("................")
print(history)

# save results as a Python pickle using a file_path
# Save results as a Python pickle using a file_path
# the directory created by Hydra for each run
data = {"history": history}
history_path = f"{str(save_path)}/history.pkl"
with open(history_path, "wb") as handle:
pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

# simple plot
# Simple plot
quick_plot(history_path)


Expand Down
4 changes: 2 additions & 2 deletions baselines/fedbn/fedbn/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def forward(self, x):


# pylint: disable=too-many-locals
def train(model, traindata, epochs, device) -> Tuple[float, float]:
def train(model, traindata, epochs, l_r, device) -> Tuple[float, float]:
"""Train the network."""
# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
optimizer = torch.optim.SGD(model.parameters(), lr=l_r)

# Train the network
model.to(device)
Expand Down
2 changes: 1 addition & 1 deletion doc/source/ref-changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

- niid-Bench [#2428](https://github.com/adap/flower/pull/2428)

- FedBN [#2608](https://github.com/adap/flower/pull/2608)
- FedBN ([#2608](https://github.com/adap/flower/pull/2608), [#2615](https://github.com/adap/flower/pull/2615))

- **Update Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384),[#2425](https://github.com/adap/flower/pull/2425), [#2526](https://github.com/adap/flower/pull/2526))

Expand Down

0 comments on commit b43515c

Please sign in to comment.