From 653bab5d048f4378a947e90ce2540f862588fe6d Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 16 Nov 2023 22:22:43 +0000 Subject: [PATCH 1/5] tweaks --- baselines/fedbn/fedbn/client.py | 34 +++++++++++++++++++++----------- baselines/fedbn/fedbn/dataset.py | 17 +++++++++++----- baselines/fedbn/fedbn/main.py | 20 +++++-------------- baselines/fedbn/fedbn/models.py | 4 ++-- 4 files changed, 41 insertions(+), 34 deletions(-) diff --git a/baselines/fedbn/fedbn/client.py b/baselines/fedbn/fedbn/client.py index be3b4b8f54a..479ffb99b4f 100644 --- a/baselines/fedbn/fedbn/client.py +++ b/baselines/fedbn/fedbn/client.py @@ -16,9 +16,9 @@ class FlowerClient(fl.client.NumPyClient): - """A standar FlowerClient. This base class. + """A standar FlowerClient. - is what plain FedAvg clients do. + This base class is what plain FedAvg clients do. """ def __init__( @@ -27,6 +27,7 @@ def __init__( trainloader: DataLoader, testloader: DataLoader, dataset_name: str, + lr: float, **kwargs, # pylint: disable=unused-argument ) -> None: self.trainloader = trainloader @@ -34,6 +35,7 @@ def __init__( self.dataset_name = dataset_name self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.model = model.to(self.device) + self.lr = lr def get_parameters(self, config) -> NDArrays: """Return model parameters as a list of NumPy ndarrays w or w/o. @@ -58,22 +60,23 @@ def fit( """Set model parameters, train model, return updated model parameters.""" self.set_parameters(parameters) - # evaluate the state of the global model on the train set; the loss returned + # Evaluate the state of the global model on the train set; the loss returned # is what's reported in Fig3 in the FedBN paper (what this baseline focuses # in reproducing) pre_train_loss, pre_train_acc = test( self.model, self.trainloader, device=self.device ) - # train model on local dataset + # Train model on local dataset loss, acc = train( self.model, self.trainloader, epochs=1, + lr=self.lr, device=self.device, ) - # construct metrics to return to server + # Construct metrics to return to server fl_round = config["round"] metrics = { "dataset_name": self.dataset_name, @@ -107,9 +110,16 @@ def evaluate( class FedBNFlowerClient(FlowerClient): """Similar to FlowerClient but this is used by FedBN clients.""" - def __init__(self, bn_state_dir: Path, client_id: int, *args, **kwargs) -> None: + def __init__(self, save_path: Path, client_id: int, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.bn_state_dir = bn_state_dir + # For FedBN clients we need to persist the state of the BN + # layers across rounds. In Simulation clients are statess + # so everything not communicated to the server (as it is the + # case as with params in BN layers of FedBN clients) is lost + # once a client completes its training. An upcoming version of + # Flower suports stateful clients + bn_state_dir = save_path / "bn_states" + bn_state_dir.mkdir(exist_ok=True) self.bn_state_pkl = bn_state_dir / f"client_{client_id}.pkl" def _save_bn_statedict(self) -> None: @@ -135,7 +145,7 @@ def get_parameters(self, config) -> NDArrays: layers. """ - # first update bn_state_dir + # First update bn_state_dir self._save_bn_statedict() # Excluding parameters of BN layers when using FedBN return [ @@ -154,8 +164,8 @@ def set_parameters(self, parameters: NDArrays) -> None: state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) self.model.load_state_dict(state_dict, strict=False) - # now also load from bn_state_dir - if self.bn_state_pkl.exists(): # it won't exist in the first round + # Now also load from bn_state_dir + if self.bn_state_pkl.exists(): # It won't exist in the first round bn_state_dict = self._load_bn_statedict() self.model.load_state_dict(bn_state_dict, strict=False) @@ -164,7 +174,7 @@ def gen_client_fn( client_data: List[Tuple[DataLoader, DataLoader, str]], client_cfg: DictConfig, model_cfg: DictConfig, - bn_state_dir: Path, + save_path: Path, ) -> Callable[[str], FlowerClient]: """Return a function that will be called to instantiate the cid-th client.""" @@ -182,7 +192,7 @@ def client_fn(cid: str) -> FlowerClient: trainloader=trainloader, testloader=valloader, dataset_name=dataset_name, - bn_state_dir=bn_state_dir, + save_path=save_path, client_id=int(cid), ) diff --git a/baselines/fedbn/fedbn/dataset.py b/baselines/fedbn/fedbn/dataset.py index 49fdad16643..b6fb7b223f8 100644 --- a/baselines/fedbn/fedbn/dataset.py +++ b/baselines/fedbn/fedbn/dataset.py @@ -24,7 +24,7 @@ def __init__( # pylint: disable=too-many-arguments transform=None, ): if train and partitions is not None: - # construct dataset by loading one or more partitions + # Construct dataset by loading one or more partitions self.images, self.labels = np.load( os.path.join( data_path, @@ -212,11 +212,18 @@ def get_data(dataset_cfg: DictConfig) -> List[Tuple[DataLoader, DataLoader, str] client_data = [] d_cfg = dataset_cfg - allowed_percent = (np.arange(1, 11) / 10).tolist() + total_partitions = ( + 10 # each dataset was pre-processed by the authors and split into 10 partitions + ) + # First check that percent used is allowed + allowed_percent = (np.arange(1, total_partitions+1) / total_partitions).tolist() assert d_cfg.percent in allowed_percent, ( f"'dataset.percent' should be in {allowed_percent}." "\nThis is because the trainset is pre-partitioned into 10 disjoint sets." ) + + # Then check that with the percent selected, the desired number of clients (and there + # fore dataloaders) can be created. max_expected_clients = len(d_cfg.to_include) * 1 / d_cfg.percent num_clients_step = len(d_cfg.to_include) @@ -234,10 +241,10 @@ def get_data(dataset_cfg: DictConfig) -> List[Tuple[DataLoader, DataLoader, str] "'dataset.percent' you chose." ) + # All good, then create as many dataloaders as clients in the experiment. + # Each dataloader might containe one or more partitions (depends on 'percent') + # Each dataloader contains data of the same dataset. num_clients_per_dataset = d_cfg.num_clients // num_clients_step - total_partitions = ( - 10 # each dataset was pre-processed by the authors and split into 10 partitions - ) num_parts = int(d_cfg.percent * total_partitions) for dataset_name in dataset_cfg.to_include: diff --git a/baselines/fedbn/fedbn/main.py b/baselines/fedbn/fedbn/main.py index 8942a175052..15fd4a436bc 100644 --- a/baselines/fedbn/fedbn/main.py +++ b/baselines/fedbn/fedbn/main.py @@ -3,8 +3,7 @@ It includes processioning the dataset, instantiate strategy, specify how the global model is going to be evaluated, etc. At the end, this script saves the results. """ -# these are the basic packages you'll need here -# feel free to remove some if aren't needed + import pickle from pathlib import Path @@ -33,16 +32,7 @@ def main(cfg: DictConfig) -> None: # Hydra automatically creates an output directory # Let's retrieve it and save some results there - save_path = HydraConfig.get().runtime.output_dir - - # For FedBN clients we need to persist the state of the BN - # layers across rounds. In Simulation clients are statess - # so everything not communicated to the server (as it is the - # case as with params in BN layers of FedBN clients) is lost - # once a client completes its training. An upcoming version of - # Flower suports stateful clients - bn_states = Path(save_path) / "bn_states" - bn_states.mkdir() + save_path = Path(HydraConfig.get().runtime.output_dir) # 2. Prepare your dataset # please ensure you followed the README.md and you downloaded the @@ -50,7 +40,7 @@ def main(cfg: DictConfig) -> None: client_data_loaders = get_data(cfg.dataset) # 3. Define your client generation function - client_fn = gen_client_fn(client_data_loaders, cfg.client, cfg.model, bn_states) + client_fn = gen_client_fn(client_data_loaders, cfg.client, cfg.model, save_path) # 4. Define your strategy strategy = instantiate(cfg.strategy) @@ -71,14 +61,14 @@ def main(cfg: DictConfig) -> None: print("................") print(history) - # save results as a Python pickle using a file_path + # Save results as a Python pickle using a file_path # the directory created by Hydra for each run data = {"history": history} history_path = f"{str(save_path)}/history.pkl" with open(history_path, "wb") as handle: pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) - # simple plot + # Simple plot quick_plot(history_path) diff --git a/baselines/fedbn/fedbn/models.py b/baselines/fedbn/fedbn/models.py index efe98290512..0ec9fba6c4a 100644 --- a/baselines/fedbn/fedbn/models.py +++ b/baselines/fedbn/fedbn/models.py @@ -56,11 +56,11 @@ def forward(self, x): # pylint: disable=too-many-locals -def train(model, traindata, epochs, device) -> Tuple[float, float]: +def train(model, traindata, epochs, lr, device) -> Tuple[float, float]: """Train the network.""" # Define loss and optimizer criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) + optimizer = torch.optim.SGD(model.parameters(), lr=lr) # Train the network model.to(device) From 151c73489a1361ce6a5bc60aca7e3591cd81d479 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 16 Nov 2023 22:26:15 +0000 Subject: [PATCH 2/5] w/ previous --- baselines/fedbn/fedbn/conf/client/fedavg.yaml | 3 ++- baselines/fedbn/fedbn/conf/client/fedbn.yaml | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/baselines/fedbn/fedbn/conf/client/fedavg.yaml b/baselines/fedbn/fedbn/conf/client/fedavg.yaml index a08fccc6d24..a1766c659b9 100644 --- a/baselines/fedbn/fedbn/conf/client/fedavg.yaml +++ b/baselines/fedbn/fedbn/conf/client/fedavg.yaml @@ -1,4 +1,5 @@ --- # standard FedAvg Flower Client _target_: fedbn.client.FlowerClient -client_label: FedAvg \ No newline at end of file +client_label: FedAvg +lr: 0.01 \ No newline at end of file diff --git a/baselines/fedbn/fedbn/conf/client/fedbn.yaml b/baselines/fedbn/fedbn/conf/client/fedbn.yaml index 0f0e2a05cbf..11377d4ca3e 100644 --- a/baselines/fedbn/fedbn/conf/client/fedbn.yaml +++ b/baselines/fedbn/fedbn/conf/client/fedbn.yaml @@ -1,4 +1,5 @@ --- # standard FedBN Flower Client _target_: fedbn.client.FedBNFlowerClient -client_label: FedBN \ No newline at end of file +client_label: FedBN +lr: 0.01 \ No newline at end of file From ef4c7b0ae74a193076bb11087ad91bb3edef62fb Mon Sep 17 00:00:00 2001 From: javier Date: Thu, 16 Nov 2023 22:43:53 +0000 Subject: [PATCH 3/5] updates --- baselines/fedbn/README.md | 2 +- baselines/fedbn/fedbn/client.py | 8 ++++---- baselines/fedbn/fedbn/conf/client/fedavg.yaml | 2 +- baselines/fedbn/fedbn/conf/client/fedbn.yaml | 2 +- baselines/fedbn/fedbn/dataset.py | 6 +++--- baselines/fedbn/fedbn/models.py | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/baselines/fedbn/README.md b/baselines/fedbn/README.md index cc4f68b90e9..4b271bd4985 100644 --- a/baselines/fedbn/README.md +++ b/baselines/fedbn/README.md @@ -59,7 +59,7 @@ A more detailed explanation of the datasets is given in the following table. | strategy_fraction_fit | 1.0 | | strategy.fraction_evaluate | 0.0 | | training samples per client| 743 | -| lr | 10E-2 | +| client.l_r | 10E-2 | | local epochs | 1 | | loss | cross entropy loss | | optimizer | SGD | diff --git a/baselines/fedbn/fedbn/client.py b/baselines/fedbn/fedbn/client.py index 479ffb99b4f..9e8b79d24cc 100644 --- a/baselines/fedbn/fedbn/client.py +++ b/baselines/fedbn/fedbn/client.py @@ -21,13 +21,13 @@ class FlowerClient(fl.client.NumPyClient): This base class is what plain FedAvg clients do. """ - def __init__( + def __init__( # pylint: disable=too-many-arguments self, model: CNNModel, trainloader: DataLoader, testloader: DataLoader, dataset_name: str, - lr: float, + l_r: float, **kwargs, # pylint: disable=unused-argument ) -> None: self.trainloader = trainloader @@ -35,7 +35,7 @@ def __init__( self.dataset_name = dataset_name self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.model = model.to(self.device) - self.lr = lr + self.l_r = l_r def get_parameters(self, config) -> NDArrays: """Return model parameters as a list of NumPy ndarrays w or w/o. @@ -72,7 +72,7 @@ def fit( self.model, self.trainloader, epochs=1, - lr=self.lr, + l_r=self.l_r, device=self.device, ) diff --git a/baselines/fedbn/fedbn/conf/client/fedavg.yaml b/baselines/fedbn/fedbn/conf/client/fedavg.yaml index a1766c659b9..993fcb7b4d3 100644 --- a/baselines/fedbn/fedbn/conf/client/fedavg.yaml +++ b/baselines/fedbn/fedbn/conf/client/fedavg.yaml @@ -2,4 +2,4 @@ # standard FedAvg Flower Client _target_: fedbn.client.FlowerClient client_label: FedAvg -lr: 0.01 \ No newline at end of file +l_r: 0.01 \ No newline at end of file diff --git a/baselines/fedbn/fedbn/conf/client/fedbn.yaml b/baselines/fedbn/fedbn/conf/client/fedbn.yaml index 11377d4ca3e..26b6833062a 100644 --- a/baselines/fedbn/fedbn/conf/client/fedbn.yaml +++ b/baselines/fedbn/fedbn/conf/client/fedbn.yaml @@ -2,4 +2,4 @@ # standard FedBN Flower Client _target_: fedbn.client.FedBNFlowerClient client_label: FedBN -lr: 0.01 \ No newline at end of file +l_r: 0.01 \ No newline at end of file diff --git a/baselines/fedbn/fedbn/dataset.py b/baselines/fedbn/fedbn/dataset.py index b6fb7b223f8..8ce3eb520f8 100644 --- a/baselines/fedbn/fedbn/dataset.py +++ b/baselines/fedbn/fedbn/dataset.py @@ -216,14 +216,14 @@ def get_data(dataset_cfg: DictConfig) -> List[Tuple[DataLoader, DataLoader, str] 10 # each dataset was pre-processed by the authors and split into 10 partitions ) # First check that percent used is allowed - allowed_percent = (np.arange(1, total_partitions+1) / total_partitions).tolist() + allowed_percent = (np.arange(1, total_partitions + 1) / total_partitions).tolist() assert d_cfg.percent in allowed_percent, ( f"'dataset.percent' should be in {allowed_percent}." "\nThis is because the trainset is pre-partitioned into 10 disjoint sets." ) - # Then check that with the percent selected, the desired number of clients (and there - # fore dataloaders) can be created. + # Then check that with the percent selected, the desired number of clients (and + # therefore dataloaders) can be created. max_expected_clients = len(d_cfg.to_include) * 1 / d_cfg.percent num_clients_step = len(d_cfg.to_include) diff --git a/baselines/fedbn/fedbn/models.py b/baselines/fedbn/fedbn/models.py index 0ec9fba6c4a..c15a8e1815e 100644 --- a/baselines/fedbn/fedbn/models.py +++ b/baselines/fedbn/fedbn/models.py @@ -56,11 +56,11 @@ def forward(self, x): # pylint: disable=too-many-locals -def train(model, traindata, epochs, lr, device) -> Tuple[float, float]: +def train(model, traindata, epochs, l_r, device) -> Tuple[float, float]: """Train the network.""" # Define loss and optimizer criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD(model.parameters(), lr=lr) + optimizer = torch.optim.SGD(model.parameters(), lr=l_r) # Train the network model.to(device) From 981c9d17169c0118854b6338d1a9181898886a9b Mon Sep 17 00:00:00 2001 From: javier Date: Thu, 16 Nov 2023 22:45:56 +0000 Subject: [PATCH 4/5] in changelog --- doc/source/ref-changelog.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 10112088b88..dd8a6e22540 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -46,7 +46,7 @@ - niid-Bench [#2428](https://github.com/adap/flower/pull/2428) - - FedBN [#2608](https://github.com/adap/flower/pull/2608) + - FedBN ([#2608](https://github.com/adap/flower/pull/2608), [#2615](https://github.com/adap/flower/pull/2615)) - **Update Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384),[#2425](https://github.com/adap/flower/pull/2425), [#2526](https://github.com/adap/flower/pull/2526)) From 9ce55ff8741aaf2143e5ae7229d1827bf67d1b32 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Fri, 17 Nov 2023 10:39:20 +0100 Subject: [PATCH 5/5] Update baselines/fedbn/fedbn/client.py --- baselines/fedbn/fedbn/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baselines/fedbn/fedbn/client.py b/baselines/fedbn/fedbn/client.py index 9e8b79d24cc..ce498b8cebc 100644 --- a/baselines/fedbn/fedbn/client.py +++ b/baselines/fedbn/fedbn/client.py @@ -16,7 +16,7 @@ class FlowerClient(fl.client.NumPyClient): - """A standar FlowerClient. + """A standard FlowerClient. This base class is what plain FedAvg clients do. """