Skip to content

Commit

Permalink
Update Opacus example to use latest version (#2112)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll committed Jul 13, 2023
1 parent 344c242 commit d706abe
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 21 deletions.
2 changes: 1 addition & 1 deletion e2e/opacus/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ authors = ["The Flower Authors <hello@flower.dev>"]
[tool.poetry.dependencies]
python = "^3.8"
flwr = { path = "../../", develop = true, extras = ["simulation"] }
opacus = "1.4.0"
opacus = "^1.4.0"
torch = "^1.13.1"
torchvision = "^0.14.0"
2 changes: 1 addition & 1 deletion examples/opacus/dp_cifar_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ def load_data():
model = Net()
trainloader, testloader, sample_rate = load_data()
fl.client.start_numpy_client(
"127.0.0.1:8080", client=DPCifarClient(model, trainloader, testloader, sample_rate)
server_address="127.0.0.1:8080", client=DPCifarClient(model, trainloader, testloader)
)
24 changes: 11 additions & 13 deletions examples/opacus/dp_cifar_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


def train(net, trainloader, privacy_engine, epochs):
def train(net, trainloader, privacy_engine, optimizer, epochs):
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
privacy_engine.attach(optimizer)
for _ in range(epochs):
for images, labels in trainloader:
images, labels = images.to(DEVICE), labels.to(DEVICE)
optimizer.zero_grad()
loss = criterion(net(images), labels)
loss.backward()
optimizer.step()
epsilon, _ = optimizer.privacy_engine.get_privacy_spent(
PRIVACY_PARAMS["target_delta"]
epsilon = privacy_engine.get_epsilon(
delta=PRIVACY_PARAMS["target_delta"]
)
return epsilon

Expand All @@ -79,16 +77,16 @@ def test(net, testloader):

# Define Flower client.
class DPCifarClient(fl.client.NumPyClient):
def __init__(self, model, trainloader, testloader, sample_rate) -> None:
def __init__(self, model, trainloader, testloader) -> None:
super().__init__()
self.model = model
self.trainloader = trainloader
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
self.testloader = testloader
# Create a privacy engine which will add DP and keep track of the privacy budget.
self.privacy_engine = PrivacyEngine(
self.model,
sample_rate=sample_rate,
target_delta=PRIVACY_PARAMS["target_delta"],
self.privacy_engine = PrivacyEngine()
self.model, self.optimizer, self.trainloader = self.privacy_engine.make_private(
module=model,
optimizer=optimizer,
data_loader=trainloader,
max_grad_norm=PRIVACY_PARAMS["max_grad_norm"],
noise_multiplier=PRIVACY_PARAMS["noise_multiplier"],
)
Expand All @@ -104,7 +102,7 @@ def set_parameters(self, parameters):
def fit(self, parameters, config):
self.set_parameters(parameters)
epsilon = train(
self.model, self.trainloader, self.privacy_engine, PARAMS["local_epochs"]
self.model, self.trainloader, self.privacy_engine, self.optimizer, PARAMS["local_epochs"]
)
print(f"epsilon = {epsilon:.2f}")
return (
Expand Down
4 changes: 1 addition & 3 deletions examples/opacus/dp_cifar_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def client_fn(cid: str) -> fl.client.Client:
client_trainloader = DataLoader(client_trainset, PARAMS["batch_size"])
client_testloader = DataLoader(client_testset, PARAMS["batch_size"])

sample_rate = PARAMS["batch_size"] / len(client_trainset)

return DPCifarClient(model, client_trainloader, client_testloader, sample_rate)
return DPCifarClient(model, client_trainloader, client_testloader)


# Define an evaluation function for centralized evaluation (using whole CIFAR10 testset).
Expand Down
6 changes: 3 additions & 3 deletions examples/opacus/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ authors = ["The Flower Authors <hello@flower.dev>"]

[tool.poetry.dependencies]
python = "^3.8"
flwr = "^0.17.0"
flwr = "^1.0.0"
# flwr = { path = "../../", develop = true } # Development
opacus = "~0.14.0"
opacus = "^1.4.0"
torch = "^1.13.1"
torchvision = "~0.8.0"
torchvision = "^0.14.0"

0 comments on commit d706abe

Please sign in to comment.