Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature(wxh): Add FedAMP algo and fix bugs. #25

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

XinghaoWu
Copy link
Collaborator

  1. Add FedAMP algo.
  2. Add the FedCAC paper link in README.md.
  3. Remove the default momentum in default_config.py to fix the bug when setting the optimizer to Adam.

2. Add FedCAC paper link in README.md.
3. Remove the default momentum in default_config.py to fix the bug when set optmizer to adam.
@XinghaoWu XinghaoWu requested a review from kxzxvbk October 25, 2023 15:43
@kxzxvbk kxzxvbk changed the title Add FedAMP algo and fix bugs. Feature(wxh): Add FedAMP algo and fix bugs. Oct 26, 2023
@kxzxvbk kxzxvbk mentioned this pull request Oct 26, 2023
20 tasks
@kxzxvbk kxzxvbk added the algorithm add new algorithm label Oct 26, 2023
class FedAMPClient(BaseClient):
"""
Overview:
This class is the base implementation of client in 'Bold but Cautious: Unlocking the Potential of Personalized
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct this document, not FedCAC.


def __init__(self, args, client_id, train_dataset, test_dataset=None):
"""
Initializing train dataset, test dataset(for personalized settings).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct this document, the purpose is to get a copy of local model.

super(FedAMPClient, self).__init__(args, client_id, train_dataset, test_dataset)
self.client_u = copy.deepcopy(self.model)

def FedAMP_Loss_client(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use the get_model_difference function (defined in fling/utils/torch_utils.py ) for simplification.

from fling.utils.utils import weight_flatten

@CLIENT_REGISTRY.register('fedamp_client')
class FedAMPClient(BaseClient):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether this client is identical to FedProxClient? What's the differences?

coef = torch.zeros(self.args.client.client_num)
for j, mw in enumerate(self.client_ws):
if i == j: continue
sub = weights[i] - weights[j]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rewrite it using fling.utils.get_model_difference

@@ -14,6 +15,13 @@ def client_sampling(client_ids: Iterable, sample_rate: float) -> List:
)
return participated_clients

def weight_flatten(model) -> torch.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose that this function can be removed.

@kxzxvbk
Copy link
Collaborator

kxzxvbk commented Oct 26, 2023

Reformat the code before final merge.

@kxzxvbk
Copy link
Collaborator

kxzxvbk commented Oct 26, 2023

Add example configs for cifar100, mnist and tiny-imagenet.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
algorithm add new algorithm
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants