Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Dec 18, 2023
1 parent 18b692e commit 10899f2
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 7 deletions.
1 change: 1 addition & 0 deletions examples/sklearn-logreg-mnist/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ python3 client.py --node-id 1 # or any integer in {0-9}
```

Alternatively, you can run all of it in one shell as follows:

```bash
bash run.sh
```
Expand Down
6 changes: 2 additions & 4 deletions examples/sklearn-logreg-mnist/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
dataset = fds.load_partition(partition_id, "train").with_format("numpy")
X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
# Split the on edge data: 80% train, 20% test
X_train, X_test = X[:int(0.8 * len(X))], X[int(0.8 * len(X)):]
y_train, y_test = y[:int(0.8 * len(y))], y[int(0.8 * len(y)):]
X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]

# Create LogisticRegression Model
model = LogisticRegression(
Expand All @@ -41,7 +41,6 @@
# Setting initial parameters, akin to model.compile for keras models
utils.set_initial_params(model)


# Define Flower client
class MnistClient(fl.client.NumPyClient):
def get_parameters(self, config): # type: ignore
Expand All @@ -62,6 +61,5 @@ def evaluate(self, parameters, config): # type: ignore
accuracy = model.score(X_test, y_test)
return loss, len(X_test), {"accuracy": accuracy}


# Start Flower client
fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=MnistClient())
4 changes: 1 addition & 3 deletions examples/sklearn-logreg-mnist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ def get_model_parameters(model: LogisticRegression) -> NDArrays:
return params


def set_model_params(
model: LogisticRegression, params: NDArrays
) -> LogisticRegression:
def set_model_params(model: LogisticRegression, params: NDArrays) -> LogisticRegression:
"""Sets the parameters of a sklean LogisticRegression model."""
model.coef_ = params[0]
if model.fit_intercept:
Expand Down

0 comments on commit 10899f2

Please sign in to comment.