-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 4b16e10
Showing
17 changed files
with
1,350 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# project related | ||
dataset/ | ||
models*/ | ||
runs*/ | ||
figs/ | ||
.vscode/ | ||
|
||
*.DS_Store | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2020 THU Media | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Continual Local Training for Better Initialization of Federated Models | ||
|
||
The implementation of "Continual Local Training for Better Initialization of Federated Models" (ICIP 2020). | ||
[[Conference Version]](#)[[arXiv Version]](https://arxiv.org/abs/2005.12657) | ||
|
||
## Introduction | ||
|
||
Federated learning (FL) refers to the learning paradigm that trains machine learning models directly in the decentralized systems consisting of smart edge devices without transmitting the raw data, which avoids the heavy communication costs and privacy concerns. | ||
Given the typical heterogeneous data distributions in such situations, the popular FL algorithm *Federated Averaging* (FedAvg) suffers from weight divergence and thus cannot achieve a competitive performance for the global model (denoted as the *initial performance* in FL) compared to centralized methods. | ||
|
||
In this paper, we propose the local continual training strategy to address this problem. | ||
Importance weights are evaluated on a small proxy dataset on the central server and then used to constrain the local training. | ||
With this additional term, we alleviate the weight divergence and continually integrate the knowledge on different local clients into the global model, which ensures a better generalization ability. | ||
Experiments on various FL settings demonstrate that our method significantly improves the initial performance of federated models with few extra communication costs. | ||
|
||
<div align="center"> | ||
<img src="./overview.png" width = "70%" height = "70%" alt="overview" /> | ||
</div> | ||
|
||
## Dependency | ||
|
||
``` | ||
python==3.7 | ||
pytorch==1.4 | ||
prefetch_generator | ||
tensorboardx | ||
``` | ||
|
||
## How To Run | ||
|
||
1. Download the `dataset.tar.gz` in the [release page](https://github.com/thu-media/FedCL/releases/tag/v1.0), and unzip it to the root of the repository. | ||
|
||
2. Then you can start with | ||
```shell | ||
python cifar_main.py | ||
``` | ||
or | ||
```shell | ||
python mnist_main.py | ||
``` | ||
The hyperparameters are defined in standalone file `config.py`. | ||
|
||
## Code Structure | ||
|
||
``` | ||
-- mnist_main.py # the main train file for experiments on split MNIST | ||
-- cifar_main.py # the main train file for experiments on split CIFAR10 | ||
-- config.py # the global config file | ||
-- model/ | ||
|_ cifar_model.py # the model file for CIFAR10 | ||
|_ mnist_model.py # the model file for MNIST | ||
-- data/ | ||
|_ cifar_data.py # define data loader and allocator for CIFAR10 | ||
|_ mnist_data.py # define data loader and allocator for MNIST | ||
-- core/ | ||
|_ agent.py # core functions for FL clients, e.g., train/test/estimate importance weights | ||
|_ trainer.py # core functions for FL server, e.g., model aggregation/initialize clients | ||
|_ utils.py # define some utils | ||
``` | ||
## Cite | ||
If you find this work useful to you, please cite [the conference version](#): | ||
``` | ||
To be published | ||
``` | ||
or [the arXiv version](https://arxiv.org/abs/2005.12657): | ||
``` | ||
@article{yao2020continual, | ||
title={Continual Local Training for Better Initialization of Federated Models}, | ||
author={Yao, Xin and Sun, Lifeng}, | ||
journal={arXiv preprint arXiv:2005.12657}, | ||
year={2020} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
import os | ||
import numpy as np | ||
import torch | ||
from torch import nn | ||
import torch.multiprocessing as mp | ||
|
||
import config | ||
from core import Agent, Trainer, train_local_mp | ||
from model import CifarModel | ||
from data import CifarData | ||
|
||
|
||
class CIFARAgent(Agent): | ||
"""CIFARAgent for CIFAR10 and CIFAR100.""" | ||
def __init__(self, global_args, subset=tuple(range(10)), fine='CIFAR10'): | ||
super().__init__(global_args, subset, fine) | ||
self.distr_type = global_args.distr_type | ||
if self.distr_type == 'uniform': | ||
self.distribution = np.array([0.1] * 10) | ||
elif self.distr_type == 'dirichlet': | ||
self.distribution = np.random.dirichlet([global_args.alpha] * 10) | ||
else: | ||
raise ValueError(f'Invalid distribution type: {self.distr_type}.') | ||
|
||
def load_data(self, data_alloc, center=False): | ||
print("=> loading data") | ||
if center: | ||
self.train_loader, self.test_loader, self.num_train = \ | ||
data_alloc.create_dataset_for_center(self.batch_size, self.num_workers) | ||
else: | ||
self.train_loader, self.test_loader, self.num_train = \ | ||
data_alloc.create_dataset_for_client(self.distribution, self.batch_size, | ||
self.num_workers, self.subset) | ||
|
||
def build_model(self): | ||
print("=> building model") | ||
if self.fine == 'CIFAR10': | ||
num_class = 10 | ||
elif self.fine == 'CIFAR100': | ||
num_class = 100 | ||
else: | ||
raise ValueError('Invalid dataset choice.') | ||
self.model = CifarModel(num_class).to(self.device) | ||
self.criterion = nn.CrossEntropyLoss().to(self.device) | ||
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, | ||
momentum=0.9, weight_decay=1e-4) | ||
|
||
|
||
class CIFARTrainer(Trainer): | ||
"""CIFAR Trainer.""" | ||
def __init__(self, global_args): | ||
super().__init__(global_args) | ||
self.data_alloc = CifarData(self.num_locals, self.sample_rate, fine=self.fine) | ||
|
||
# init the global model | ||
self.global_agent = CIFARAgent(global_args, fine=self.fine) | ||
self.global_agent.load_data(self.data_alloc, center=True) | ||
self.global_agent.build_model() | ||
self.global_agent.resume_model(self.resume) | ||
|
||
def build_local_models(self, global_args): | ||
self.nets_pool = list() | ||
for _ in range(self.num_locals): | ||
self.nets_pool.append(CIFARAgent(global_args, fine=self.fine)) | ||
self.init_local_models() | ||
|
||
def train(self): | ||
for rnd in range(self.rounds): | ||
np.random.shuffle(self.nets_pool) | ||
pool = mp.Pool(self.num_per_rnd) | ||
self.q = mp.Manager().Queue() | ||
dict_new = self.global_agent.model.state_dict() | ||
if self.estimate_weights_in_center and rnd % self.interval == 0: | ||
w_d = self.global_agent.estimate_weights(self.policy) | ||
else: | ||
w_d = None | ||
for net in self.nets_pool[:self.num_per_rnd]: | ||
net.model.load_state_dict(dict_new) | ||
net.set_lr(self.global_agent.lr) | ||
pool.apply_async(train_local_mp, (net, self.local_epochs, rnd, self.q, self.policy, w_d)) | ||
pool.close() | ||
pool.join() | ||
self.update_global(rnd) | ||
|
||
def main(): | ||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu | ||
torch.manual_seed(args.seed) | ||
torch.cuda.manual_seed(args.seed) | ||
np.random.seed(args.seed) | ||
mp.set_start_method('forkserver') | ||
|
||
cifar_trainer = CIFARTrainer(args) | ||
|
||
# test | ||
if args.mode == 'test': | ||
cifar_trainer.test() | ||
return | ||
|
||
cifar_trainer.build_local_models(args) | ||
cifar_trainer.train() | ||
|
||
if __name__ == '__main__': | ||
args = config.get_args() | ||
args.fine = 'CIFAR10' | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
#pylint: disable=C0301,C0326 | ||
import argparse | ||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument('--model_file', type=str, default='model.pth.tar', help='File to save model.') | ||
parser.add_argument('--model_dir', type=str, default='models', help='Directory for storing checkpoint file.') | ||
parser.add_argument('--resume', type=str, default='', metavar='PATH', help='path to resume checkpoint (default: none)') | ||
parser.add_argument('--mode', type=str, default='train', choices=('train', 'test'), help='train or test.') | ||
parser.add_argument('--log_dir', type=str, default='runs_attn', help='Directory for logging.') | ||
parser.add_argument('--gpu', type=str, default='0', help='Number of gpu to use') | ||
parser.add_argument('--seed', type=int, default=1234, help='Random seed') | ||
|
||
# hyper parameter for local data and training | ||
parser.add_argument('--distr_type', type=str, default='uniform', choices=('uniform', 'dirichlet'), help='Distribution to construct local data.') | ||
parser.add_argument('--alpha', type=float, default=1., help='alpha for dirichlet distribution. Must > 0 if dirichlet distribution is chosen.') | ||
parser.add_argument('--lr', type=float, default=5e-3, help='learning rate.') | ||
parser.add_argument('--min_lr', type=float, default=1e-4, help='minimum learning rate.') | ||
parser.add_argument('--decay_rate', type=float, default=0.99, help='lr decay rate.') | ||
parser.add_argument('--batch_size', type=int, default=64, help='Batch size. (B)') | ||
parser.add_argument('--local_epochs', type=int, default=2, help='Number of epoch in local. (E)') | ||
parser.add_argument('--num_workers', type=int, default=0, help='number of workers to preprocess data, must be 0 for mp agents.') | ||
|
||
# hyper parameters for central server | ||
parser.add_argument('--num_locals', type=int, default=10, help='number of local agents.') | ||
parser.add_argument('--num_per_rnd', type=int, default=2, help='number of local agents to train per round.') | ||
parser.add_argument('--rounds', type=int, default=500, help='number of communication rounds.') | ||
parser.add_argument('--sample_rate', type=float, default=-1., help='sample rate of central data.') | ||
parser.add_argument('--policy', type=str, default='avg', choices=('avg', 'ewc', 'mas'), help='Policy for estimating parameter importance.') | ||
parser.add_argument('--estimate_weights_in_center', action='store_true', help='Estimate parameter importance in central server.') | ||
|
||
# hyper parameters for ewc train | ||
parser.add_argument('--coe', type=float, default=0.5, help='The coefficient for local additional constraint.') | ||
parser.add_argument('--interval', type=float, default=1, help='The interval for weight estimation.') | ||
|
||
args = parser.parse_args() | ||
return args |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .agent import Agent | ||
from .trainer import Trainer, train_local_mp, test_local_mp |
Oops, something went wrong.