Refactor + add correct CNN model for CIFAR
flydump committed Dec 8, 2024
1 parent 600f1e9 commit e5874c2
Showing 7 changed files with 246 additions and 342 deletions.
77 changes: 42 additions & 35 deletions baselines/fedlc/fedlc/client_app.py
@@ -3,46 +3,56 @@
 import torch
 
 from fedlc.dataset import load_data
-from fedlc.model import get_parameters, initialize_model, set_parameters, train
+from fedlc.model import get_parameters, CNNModel, set_parameters, train
 from flwr.client import ClientApp, NumPyClient
 from flwr.common import Context
 
 
 class LogitCorrectedLoss(torch.nn.CrossEntropyLoss):
-    def __init__(self, logits_correction: torch.Tensor):
+    def __init__(
+        self,
+        num_classes,
+        labels,
+        tau,
+        device,
+    ):
         super().__init__()
-        self.logits_correction = logits_correction
+        class_count = torch.zeros(num_classes).long()
+        labels, counts = labels.unique(
+            sorted=True, return_counts=True, return_inverse=False
+        )
+        class_count[labels] = counts
+        class_count = class_count.to(device)
+        self.correction = tau * class_count.pow(-0.25)
 
     def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         # Modify the logits before cross entropy loss
-        corrected_logits = logits - self.logits_correction
+        corrected_logits = logits - self.correction
        return super().forward(corrected_logits, target)
 
 
-def calc_logit_correction(net, labels, device) -> torch.Tensor:
-    num_classes = net.fc.out_features
-    class_count = torch.zeros(num_classes).long()
-    labels, counts = labels.unique(
-        sorted=True, return_counts=True, return_inverse=False
-    )
-    class_count[labels] = counts
-    class_count = class_count.to(device)
-    return class_count
-
-
 class FlowerClient(NumPyClient):
     def __init__(
-        self, net, trainloader, labels, local_epochs, use_lc, tau, learning_rate, device
+        self,
+        net,
+        trainloader,
+        labels,
+        local_epochs,
+        tau,
+        learning_rate,
+        device
     ):
         self.net = net
         self.trainloader = trainloader
         self.local_epochs = local_epochs
         self.learning_rate = learning_rate
-        self.use_lc = use_lc
         self.device = device
-        if self.use_lc:
-            logits_correction = tau * calc_logit_correction(net, labels, device).pow(-0.25)
-            self.criterion = LogitCorrectedLoss(logits_correction)
+        use_lc = tau > 0.0
+        if use_lc:
+            self.criterion = LogitCorrectedLoss(
+                net.fc.out_features,  # num_classes
+                labels,
+                tau,
+                device,
+            )
         else:
             self.criterion = torch.nn.CrossEntropyLoss()

@@ -69,26 +79,23 @@ def fit(self, parameters, config):

 def client_fn(context: Context):
     """Construct a Client that will be run in a ClientApp."""
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    num_classes = int(context.run_config["num-classes"])
-    num_channels = int(context.run_config["num-channels"])
-    model_name = str(context.run_config["model-name"])
-    net = initialize_model(model_name, num_channels, num_classes)
-
-    trainloader, labels = load_data(context)
+    trainloader, labels, num_classes = load_data(context)
+    net = CNNModel(num_classes)
 
     local_epochs = int(context.run_config["local-epochs"])
     learning_rate = float(context.run_config["learning-rate"])
 
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     tau = float(context.run_config["tau"])
-    use_lc = bool(context.run_config["use-logit-correction"])
 
     # Return Client instance
     return FlowerClient(
-        net, trainloader, labels, local_epochs, use_lc, tau, learning_rate, device
+        net,
+        trainloader,
+        labels,
+        local_epochs,
+        tau,
+        learning_rate,
+        device
     ).to_client()
 
 
 # Flower ClientApp
 app = ClientApp(client_fn)
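
For context: this refactor folds the old calc_logit_correction helper into LogitCorrectedLoss itself, which now builds the FedLC correction tau * n_c^(-1/4) from the client's local label counts (n_c is the number of local samples of class c; classes with fewer samples get a larger correction), and correction is toggled simply by tau > 0 rather than a separate use-logit-correction flag. A minimal standalone sketch of the calibration, using illustrative class counts and tensors that are not part of this commit:

import torch

# Hypothetical skewed client: 4 classes, class 3 has only one local sample.
tau = 1.0
labels = torch.tensor([0, 0, 0, 0, 1, 1, 2, 3])

class_count = torch.zeros(4).long()
present, counts = labels.unique(sorted=True, return_counts=True)
class_count[present] = counts
correction = tau * class_count.pow(-0.25)  # larger for locally rare classes

logits = torch.randn(2, 4)
targets = torch.tensor([3, 0])
# Subtracting the correction lowers rare-class logits, which strengthens
# their gradient signal during training and counteracts local label bias.
loss = torch.nn.functional.cross_entropy(logits - correction, targets)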
86 changes: 55 additions & 31 deletions baselines/fedlc/fedlc/dataset.py
@@ -1,31 +1,46 @@
"""fedlc: A Flower Baseline."""

from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner
from flwr_datasets.partitioner import DirichletPartitioner, ShardPartitioner
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor, RandomCrop, RandomHorizontalFlip
from torchvision.transforms import (
Compose,
Normalize,
RandomCrop,
RandomHorizontalFlip,
ToTensor,
)

from flwr.common import Context

FDS = None # Cache FederatedDataset
from .utils import get_ds_info

FDS = None # Cache FederatedDataset

def get_data_transforms(dataset: str):
if dataset == "cifar10":
tfms = Compose(
[
RandomCrop(32, padding=4),
RandomHorizontalFlip(),
ToTensor(),
Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
]
)
else:
raise ValueError("Only cifar10 is supported!")
return tfms
CIFAR_MEAN_STD = {
"cifar10": ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
"cifar100": ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
}

def get_data_transforms(dataset: str, split: str):
"""Return data transforms for dataset."""
tfms = []
if split == "train":
tfms = [
RandomCrop(32, padding=4),
RandomHorizontalFlip(),
]

tfms.extend(
[
ToTensor(),
Normalize(*CIFAR_MEAN_STD[dataset]),
]
)
return Compose(tfms)


def get_transforms_apply_fn(transforms, partition_by):
def _get_transforms_apply_fn(transforms, partition_by):
def apply_transforms(batch):
batch["img"] = [transforms(img) for img in batch["img"]]
batch["label"] = batch[partition_by]
@@ -34,9 +49,10 @@ def apply_transforms(batch):
     return apply_transforms
 
 
-def get_transformed_ds(ds, dataset_name, partition_by) -> Dataset:
-    tfms = get_data_transforms(dataset_name)
-    transform_fn = get_transforms_apply_fn(tfms, partition_by)
+def get_transformed_ds(ds, dataset_name, partition_by, split) -> Dataset:
+    """Return dataset with transformations applied."""
+    tfms = get_data_transforms(dataset_name, split)
+    transform_fn = _get_transforms_apply_fn(tfms, partition_by)
     return ds.with_transform(transform_fn)


@@ -45,35 +61,43 @@ def load_data(context: Context):
     Only used for client-side training.
     """
-    dirichlet_alpha = float(context.run_config["dirichlet-alpha"])
     partition_id = int(context.node_config["partition-id"])
     num_partitions = int(context.node_config["num-partitions"])
     dataset = str(context.run_config["dataset"])
+    num_classes, partition_by = get_ds_info(dataset)
     batch_size = int(context.run_config["batch-size"])
-    partition_by = str(context.run_config["dataset-partition-by"])
+    dirichlet_alpha = float(context.run_config["dirichlet-alpha"])
+    skew_type = str(context.run_config["skew-type"])
+    num_shards_per_partition = int(context.run_config["num-shards-per-partition"])
 
     # Only initialize `FederatedDataset` once
     global FDS  # pylint: disable=global-statement
 
     if FDS is None:
-        dirichlet_partitioner = DirichletPartitioner(
-            num_partitions=num_partitions,
-            alpha=dirichlet_alpha,
-            partition_by=partition_by,
-            min_partition_size=10,
-        )
+        if skew_type == "distribution":
+            partitioner = DirichletPartitioner(
+                num_partitions=num_partitions,
+                alpha=dirichlet_alpha,
+                partition_by=partition_by,
+            )
+        elif skew_type == "quantity":
+            partitioner = ShardPartitioner(
+                num_partitions=num_partitions,
+                partition_by=partition_by,
+                num_shards_per_partition=num_shards_per_partition,
+            )
         FDS = FederatedDataset(
             dataset=dataset,
-            partitioners={"train": dirichlet_partitioner},
+            partitioners={"train": partitioner},
         )
 
     train_partition = FDS.load_partition(partition_id)
     train_partition.set_format("torch")
 
     trainloader = DataLoader(
-        get_transformed_ds(train_partition, dataset, partition_by),
+        get_transformed_ds(train_partition, dataset, partition_by, split="train"),
         batch_size=batch_size,
         shuffle=True,
         drop_last=False,
     )
-    return trainloader, train_partition["label"]
+    return trainloader, train_partition["label"], num_classes
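
The new skew-type option selects between the two heterogeneity settings studied in the FedLC paper: label distribution skew (DirichletPartitioner; lower alpha means more skew) and label quantity skew (ShardPartitioner; each client receives a fixed number of class-contiguous shards). A rough standalone sketch of the two partitioners, with illustrative partition counts not taken from this commit's config:

from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner, ShardPartitioner

# Distribution skew: per-client class proportions drawn from Dirichlet(alpha).
partitioner = DirichletPartitioner(
    num_partitions=10, alpha=0.5, partition_by="label"
)
# Quantity skew (alternative): fixed number of class-sorted shards per client.
# partitioner = ShardPartitioner(
#     num_partitions=10, partition_by="label", num_shards_per_partition=2
# )

fds = FederatedDataset(dataset="cifar10", partitioners={"train": partitioner})
partition = fds.load_partition(0)  # one client's local training split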
74 changes: 34 additions & 40 deletions baselines/fedlc/fedlc/model.py
@@ -3,49 +3,46 @@
 from collections import OrderedDict
 
 import torch
-import torchvision
-
-
-# Adapted from FedDebug baseline implementation
-# https://github.com/adap/flower/blob/main/baselines/feddebug/feddebug/models.py
-def initialize_model(name, num_channels, num_classes):
-    """Initialize the model with the given name."""
-    model_functions = {
-        "resnet18": lambda: torchvision.models.resnet18(),
-        "resnet34": lambda: torchvision.models.resnet34(),
-        "resnet50": lambda: torchvision.models.resnet50(),
-        "resnet101": lambda: torchvision.models.resnet101(),
-        "resnet152": lambda: torchvision.models.resnet152(),
-        "vgg16": lambda: torchvision.models.vgg16(),
-    }
-    model = model_functions[name]()
-    # Modify model for grayscale input if necessary
-    if num_channels == 1:
-        if name.startswith("resnet"):
-            model.conv1 = torch.nn.Conv2d(
-                1, 64, kernel_size=7, stride=2, padding=3, bias=False
-            )
-        elif name == "vgg16":
-            model.features[0] = torch.nn.Conv2d(
-                1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
-            )
-
-    # Modify final layer to match the number of classes
-    if name.startswith("resnet"):
-        num_ftrs = model.fc.in_features
-        model.fc = torch.nn.Linear(num_ftrs, num_classes)
-    elif name == "vgg16":
-        num_ftrs = model.classifier[-1].in_features
-        model.classifier[-1] = torch.nn.Linear(num_ftrs, num_classes)
-    return model
+import torch.nn as nn
+
+
+class CNNModel(nn.Module):
+    """CNN model as described in the appendix of the FedLC paper."""
+
+    def __init__(self, num_classes=10):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 128, 3)
+        self.bn1 = nn.BatchNorm2d(128)
+        self.conv2 = nn.Conv2d(128, 128, 3)
+        self.bn2 = nn.BatchNorm2d(128)
+        self.conv3 = nn.Conv2d(128, 128, 3)
+        self.bn3 = nn.BatchNorm2d(128)
+
+        self.fc = nn.Linear(512, num_classes)
+
+    def forward(self, x):
+        """Forward pass."""
+        x = nn.functional.relu(self.bn1(self.conv1(x)))
+        x = nn.functional.max_pool2d(x, 2, stride=2)
+
+        x = nn.functional.relu(self.bn2(self.conv2(x)))
+        x = nn.functional.max_pool2d(x, 2, stride=2)
+
+        x = nn.functional.relu(self.bn3(self.conv3(x)))
+        x = nn.functional.max_pool2d(x, 2, stride=2)
+
+        x = x.view(x.shape[0], -1)
+        x = self.fc(x)
+
+        return x


 def train(net, trainloader, epochs, device, learning_rate, criterion):
     """Train the model on the training set."""
     net.to(device)
-    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
+    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
     net.train()
     running_loss = 0.0
 
     for _ in range(epochs):
         for batch in trainloader:
             images = batch["img"]
@@ -88,7 +85,4 @@ def set_parameters(net, parameters):
"""Apply parameters to an existing model."""
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)


# implementation from DASHA paper
net.load_state_dict(state_dict, strict=True)
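
Why the new CNNModel's fc layer takes 512 inputs: each unpadded 3x3 conv trims two pixels and each 2x2 max-pool halves the map, so a 32x32 CIFAR image shrinks 32 -> 30 -> 15 -> 13 -> 6 -> 4 -> 2, and flattening yields 128 * 2 * 2 = 512 features. A quick sanity-check sketch, assuming the module layout above:

import torch
from fedlc.model import CNNModel

net = CNNModel(num_classes=10)
out = net(torch.randn(4, 3, 32, 32))  # a batch of four CIFAR-sized images
assert out.shape == (4, 10)  # flatten produced 128 * 2 * 2 = 512 features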