-
Notifications
You must be signed in to change notification settings - Fork 9
/
unit_test.py
93 lines (83 loc) · 3.5 KB
/
unit_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from client.trainer.fedproxTrainer import fedproxTrainer
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import torch.nn as nn
import torch
from collections import defaultdict
if __name__ == "__main__":
    # Ad-hoc unit-test driver: each section below exercises one component of
    # the FL framework. Earlier sections are kept commented out and are meant
    # to be re-enabled by hand when testing that component.
    #
    # -----------Unit test for serverSimulator.py----------------
    # This test contains
    # 1. The functional of the interact with the server (upload and download model)
    # 2. The Factory of aggregator for user to choose the fl algorithm (TODO)
    # 3. Args passed into the server.
    # 4.
    # from server.serverSimulator import serverSimulator
    # from server.base.baseAggregator import ServerAggregator
    # import torch.nn as nn
    # class LinearModel(nn.Module):
    #     def __init__(self, h_dims):
    #         super(LinearModel,self).__init__()
    #         models = []
    #         for i in range(len(h_dims) - 1):
    #             models.append(nn.Linear(h_dims[i], h_dims[i + 1]))
    #             if i != len(h_dims) - 2:
    #                 models.append(nn.ReLU())
    #         self.models = nn.Sequential(*models)
    #     def forward(self, X):
    #         return self.models(X)
    # test_sample_pool = [LinearModel([10,10]) for i in range(10)]
    # a = ServerAggregator()
    # server = serverSimulator(a)
    # ab = server.download_model()
    # print(ab)
    # for i in test_sample_pool:
    #     upload_param = {'state_dict': i.state_dict()}
    #     server.upload_model(upload_param)
    #
    # -----------Unit test for client.py----------------
    # config = {
    #     'client_id': 1,
    #     'device' : 'cuda',
    #     'lr': 0.001,
    #     'weight_decay' : 0,
    #     'batch_size': 64
    # }
    # train_dataset = FashionMNIST('./data',train=True,download=True,transform=ToTensor())
    # test_dataset = FashionMNIST('./data',train=False,download=True,transform=ToTensor())
    # train_dataloader = DataLoader(train_dataset,batch_size=config['batch_size'])
    # class GNN_model(nn.Module):
    #     def __init__(self):
    #         super(GNN_model,self).__init__()
    #         self.flatten = nn.Flatten()
    #         self.linear_relu_stack = nn.Sequential(
    #             nn.Conv2d(in_channels=1,out_channels=4,kernel_size=4,padding=2),
    #             nn.ReLU(),
    #             nn.MaxPool2d(kernel_size=2,stride=2),
    #             nn.Conv2d(in_channels=4,out_channels=8,kernel_size=4,padding=2),
    #             nn.ReLU(),
    #             nn.MaxPool2d(kernel_size=2,stride=2),
    #             nn.Flatten(),
    #             nn.Linear(8*7*7,10),
    #         )
    #     def forward(self,x):
    #         logits = self.linear_relu_stack(x)
    #         return logits
    # model = GNN_model().to(config['device'])
    # print(model)
    # loss_fn = nn.CrossEntropyLoss()
    # optimizer = torch.optim.SGD
    # trainer = fedproxTrainer(model=model,
    #                          dataloader=train_dataloader,
    #                          criterion=loss_fn,
    #                          optimizer=optimizer,
    #                          config=config)
    #
    # -----------Unit test for DatasetSpliter.py----------------
    # Fetch FashionMNIST through the project's dataset factory, then split it
    # across four clients with a Dirichlet (non-IID) partition.
    from dataset.DatasetFactory import DatasetFactory
    from dataset.DatasetSpliter import DatasetSpliter

    dataset = DatasetFactory().get_dataset("FashionMNIST")
    # Client id -> value mapping expected by dirichlet_split. Built in one
    # expression instead of four manual assignments; still a defaultdict(int)
    # so lookups of unknown client ids keep returning 0, as before.
    client_list = defaultdict(int, {str(i): i for i in range(1, 5)})
    dataloaders = DatasetSpliter().dirichlet_split(dataset, client_list)