Commit 4b16e10: init

yaox12 committed May 27, 2020 · 0 parents
Showing 17 changed files with 1,350 additions and 0 deletions.
113 changes: 113 additions & 0 deletions .gitignore
@@ -0,0 +1,113 @@
# project related
dataset/
models*/
runs*/
figs/
.vscode/

*.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 THU Media

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
76 changes: 76 additions & 0 deletions README.md
@@ -0,0 +1,76 @@
# Continual Local Training for Better Initialization of Federated Models

The implementation of "Continual Local Training for Better Initialization of Federated Models" (ICIP 2020).
[[Conference Version]](#)[[arXiv Version]](https://arxiv.org/abs/2005.12657)

## Introduction

Federated learning (FL) is a learning paradigm that trains machine learning models directly on decentralized systems of smart edge devices without transmitting the raw data, thereby avoiding heavy communication costs and privacy concerns.
Given the typically heterogeneous data distributions in such settings, the popular FL algorithm *Federated Averaging* (FedAvg) suffers from weight divergence, so the global model cannot reach a performance (denoted as the *initial performance* in FL) competitive with centralized training.
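
For reference, FedAvg builds the global model by a sample-weighted average of the locally trained parameters; a minimal sketch (not the aggregation code in `core/trainer.py`, and the weighting-by-sample-count detail is the standard FedAvg formulation rather than something specific to this repository):

```python
def fedavg_aggregate(client_states, client_num_samples):
    """Average client state dicts, weighted by local sample counts (sketch)."""
    total = sum(client_num_samples)
    global_state = {}
    for key in client_states[0]:
        # each parameter is the weighted average of the clients' parameters
        global_state[key] = sum(
            (n / total) * state[key].float()
            for state, n in zip(client_states, client_num_samples)
        )
    return global_state
```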

In this paper, we propose a local continual training strategy to address this problem.
Importance weights are evaluated on a small proxy dataset on the central server and then used to constrain the local training.
With this additional term, we alleviate the weight divergence and continually integrate the knowledge from different local clients into the global model, which ensures better generalization.
Experiments on various FL settings demonstrate that our method significantly improves the initial performance of federated models with little extra communication cost.
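
At a high level, each client minimizes its task loss plus an importance-weighted penalty that keeps its parameters close to the global model. A rough sketch (not the exact code in `core/agent.py`; the quadratic, EWC-style form of the penalty and the variable names are assumptions based on the description above, with `coe` playing the role of the coefficient defined in `config.py`):

```python
def constrained_local_loss(task_loss, model, global_params, importance, coe=0.5):
    """Task loss plus an importance-weighted proximal penalty (sketch)."""
    penalty = 0.0
    for name, param in model.named_parameters():
        # importance[name] is estimated on the server's small proxy dataset
        penalty = penalty + (importance[name] * (param - global_params[name]) ** 2).sum()
    return task_loss + coe * penalty
```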

<div align="center">
<img src="./overview.png" width = "70%" height = "70%" alt="overview" />
</div>

## Dependencies

```
python==3.7
pytorch==1.4
prefetch_generator
tensorboardx
```
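
If you are setting up the environment from scratch, the Python packages above can be installed with pip, for example (package versions may need to be adjusted for your CUDA setup; this command is illustrative only):

```shell
pip install torch==1.4.0 prefetch_generator tensorboardX
```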

## How To Run

1. Download `dataset.tar.gz` from the [release page](https://github.com/thu-media/FedCL/releases/tag/v1.0) and unzip it to the root of the repository.

2. Then you can start with
```shell
python cifar_main.py
```
or
```shell
python mnist_main.py
```
The hyperparameters are defined in the standalone file `config.py`; an example invocation with explicit flags is shown below.
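
For example (all flags below are defined in `config.py`; the values are illustrative only):

```shell
python cifar_main.py --policy ewc --coe 0.5 --num_locals 10 --num_per_rnd 2 --rounds 500 --gpu 0
```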

## Code Structure

```
-- mnist_main.py # the main train file for experiments on split MNIST
-- cifar_main.py # the main train file for experiments on split CIFAR10
-- config.py # the global config file
-- model/
|_ cifar_model.py # the model file for CIFAR10
|_ mnist_model.py # the model file for MNIST
-- data/
|_ cifar_data.py # define data loader and allocator for CIFAR10
|_ mnist_data.py # define data loader and allocator for MNIST
-- core/
|_ agent.py # core functions for FL clients, e.g., train/test/estimate importance weights
|_ trainer.py # core functions for FL server, e.g., model aggregation/initialize clients
|_ utils.py # define some utils
```

## Cite

If you find this work useful, please cite [the conference version](#):
```
To be published
```
or [the arXiv version](https://arxiv.org/abs/2005.12657):
```
@article{yao2020continual,
title={Continual Local Training for Better Initialization of Federated Models},
author={Yao, Xin and Sun, Lifeng},
journal={arXiv preprint arXiv:2005.12657},
year={2020}
}
```
107 changes: 107 additions & 0 deletions cifar_main.py
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import numpy as np
import torch
from torch import nn
import torch.multiprocessing as mp

import config
from core import Agent, Trainer, train_local_mp
from model import CifarModel
from data import CifarData


class CIFARAgent(Agent):
    """CIFARAgent for CIFAR10 and CIFAR100."""
    def __init__(self, global_args, subset=tuple(range(10)), fine='CIFAR10'):
        super().__init__(global_args, subset, fine)
        self.distr_type = global_args.distr_type
        if self.distr_type == 'uniform':
            self.distribution = np.array([0.1] * 10)
        elif self.distr_type == 'dirichlet':
            self.distribution = np.random.dirichlet([global_args.alpha] * 10)
        else:
            raise ValueError(f'Invalid distribution type: {self.distr_type}.')

    def load_data(self, data_alloc, center=False):
        print("=> loading data")
        if center:
            self.train_loader, self.test_loader, self.num_train = \
                data_alloc.create_dataset_for_center(self.batch_size, self.num_workers)
        else:
            self.train_loader, self.test_loader, self.num_train = \
                data_alloc.create_dataset_for_client(self.distribution, self.batch_size,
                                                     self.num_workers, self.subset)

    def build_model(self):
        print("=> building model")
        if self.fine == 'CIFAR10':
            num_class = 10
        elif self.fine == 'CIFAR100':
            num_class = 100
        else:
            raise ValueError('Invalid dataset choice.')
        self.model = CifarModel(num_class).to(self.device)
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr,
                                         momentum=0.9, weight_decay=1e-4)


class CIFARTrainer(Trainer):
    """CIFAR Trainer."""
    def __init__(self, global_args):
        super().__init__(global_args)
        self.data_alloc = CifarData(self.num_locals, self.sample_rate, fine=self.fine)

        # init the global model
        self.global_agent = CIFARAgent(global_args, fine=self.fine)
        self.global_agent.load_data(self.data_alloc, center=True)
        self.global_agent.build_model()
        self.global_agent.resume_model(self.resume)

    def build_local_models(self, global_args):
        self.nets_pool = list()
        for _ in range(self.num_locals):
            self.nets_pool.append(CIFARAgent(global_args, fine=self.fine))
        self.init_local_models()

    def train(self):
        for rnd in range(self.rounds):
            # pick a random subset of clients for this round
            np.random.shuffle(self.nets_pool)
            pool = mp.Pool(self.num_per_rnd)
            self.q = mp.Manager().Queue()
            dict_new = self.global_agent.model.state_dict()
            # periodically estimate parameter importance on the central proxy data
            if self.estimate_weights_in_center and rnd % self.interval == 0:
                w_d = self.global_agent.estimate_weights(self.policy)
            else:
                w_d = None
            # broadcast the global weights and launch local training in parallel
            for net in self.nets_pool[:self.num_per_rnd]:
                net.model.load_state_dict(dict_new)
                net.set_lr(self.global_agent.lr)
                pool.apply_async(train_local_mp, (net, self.local_epochs, rnd, self.q, self.policy, w_d))
            pool.close()
            pool.join()
            # aggregate the local updates into the global model
            self.update_global(rnd)

def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    mp.set_start_method('forkserver')

    cifar_trainer = CIFARTrainer(args)

    # test
    if args.mode == 'test':
        cifar_trainer.test()
        return

    cifar_trainer.build_local_models(args)
    cifar_trainer.train()


if __name__ == '__main__':
    args = config.get_args()
    args.fine = 'CIFAR10'
    main()
40 changes: 40 additions & 0 deletions config.py
@@ -0,0 +1,40 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#pylint: disable=C0301,C0326
import argparse

def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_file', type=str, default='model.pth.tar', help='File to save the model.')
    parser.add_argument('--model_dir', type=str, default='models', help='Directory for storing checkpoint files.')
    parser.add_argument('--resume', type=str, default='', metavar='PATH', help='Path to a checkpoint to resume from (default: none).')
    parser.add_argument('--mode', type=str, default='train', choices=('train', 'test'), help='Train or test.')
    parser.add_argument('--log_dir', type=str, default='runs_attn', help='Directory for logging.')
    parser.add_argument('--gpu', type=str, default='0', help='GPU id(s) to use (sets CUDA_VISIBLE_DEVICES).')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed.')

    # hyper-parameters for local data and training
    parser.add_argument('--distr_type', type=str, default='uniform', choices=('uniform', 'dirichlet'), help='Distribution used to construct local data.')
    parser.add_argument('--alpha', type=float, default=1., help='Alpha for the dirichlet distribution. Must be > 0 if the dirichlet distribution is chosen.')
    parser.add_argument('--lr', type=float, default=5e-3, help='Learning rate.')
    parser.add_argument('--min_lr', type=float, default=1e-4, help='Minimum learning rate.')
    parser.add_argument('--decay_rate', type=float, default=0.99, help='Learning rate decay rate.')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size. (B)')
    parser.add_argument('--local_epochs', type=int, default=2, help='Number of local epochs. (E)')
    parser.add_argument('--num_workers', type=int, default=0, help='Number of workers to preprocess data; must be 0 for mp agents.')

    # hyper-parameters for the central server
    parser.add_argument('--num_locals', type=int, default=10, help='Number of local agents.')
    parser.add_argument('--num_per_rnd', type=int, default=2, help='Number of local agents to train per round.')
    parser.add_argument('--rounds', type=int, default=500, help='Number of communication rounds.')
    parser.add_argument('--sample_rate', type=float, default=-1., help='Sample rate of central data.')
    parser.add_argument('--policy', type=str, default='avg', choices=('avg', 'ewc', 'mas'), help='Policy for estimating parameter importance.')
    parser.add_argument('--estimate_weights_in_center', action='store_true', help='Estimate parameter importance on the central server.')

    # hyper-parameters for ewc training
    parser.add_argument('--coe', type=float, default=0.5, help='Coefficient for the additional local constraint.')
    parser.add_argument('--interval', type=int, default=1, help='Interval (in rounds) for weight estimation.')

    args = parser.parse_args()
    return args
2 changes: 2 additions & 0 deletions core/__init__.py
@@ -0,0 +1,2 @@
from .agent import Agent
from .trainer import Trainer, train_local_mp, test_local_mp