Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Worker API #101

Merged
merged 5 commits into from
Mar 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion .babelrc
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
{
"presets": ["@babel/preset-env"],
"plugins": ["@babel/plugin-proposal-class-properties"]
"plugins": [
"@babel/plugin-proposal-class-properties"
],
"env": {
"test": {
"plugins": [
"@babel/plugin-transform-runtime",
"@babel/plugin-proposal-class-properties"
]
}
}
}
3 changes: 2 additions & 1 deletion .eslintignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
node_modules
dist
dist
examples
Binary file added examples/with-grid/data/model_params.pb
Binary file not shown.
Binary file added examples/with-grid/data/tp_ops.pb
Binary file not shown.
19 changes: 17 additions & 2 deletions examples/with-grid/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
<!-- NOTE: the TFJS version must match the one in package-lock.json -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.2.5/dist/tf.min.js"></script>
<script src="https://webrtc.github.io/adapter/adapter-latest.js"></script>
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
</head>
<body>
<img
Expand Down Expand Up @@ -72,8 +73,22 @@ <h1>syft.js/grid.js testing</h1>
>.
</p>
<input type="text" id="grid-server" value="ws://localhost:3000" />
<input type="text" id="protocol" value="10000000013" />
<button id="connect">Connect to grid.js server</button>
<!-- <input type="text" id="protocol" value="10000000013" />-->
<!-- <button id="connect">Connect to grid.js server</button>-->
<button id="start">Start FL Worker</button>

<div id="fl-training" style="display: none">
<div style="display: table-row">
<div style="display: table-cell">
<div id="loss_graph"></div>
</div>

<div style="display: table-cell">
<div id="acc_graph"></div>
</div>
</div>
</div>

<div id="app">
<button id="disconnect">Disconnect</button>
<p id="identity"></p>
Expand Down
148 changes: 147 additions & 1 deletion examples/with-grid/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,21 @@ import {
import * as tf from '@tensorflow/tfjs-core';

// In the real world: import syft from 'syft.js';
import Syft from '../../src';
import { Syft } from '../../src';
import { MnistData } from './mnist';

// Handles to the demo page controls, looked up once at module load.
const gridServer = document.getElementById('grid-server');
// NOTE(review): the '#protocol' input and '#connect' button appear to be
// commented out in index.html, so these two lookups presumably return null.
// They are only referenced by the disabled legacy connect flow — confirm.
const protocol = document.getElementById('protocol');
const connectButton = document.getElementById('connect');
// Button that launches the federated-learning worker.
const startButton = document.getElementById('start');
const disconnectButton = document.getElementById('disconnect');
const appContainer = document.getElementById('app');
const textarea = document.getElementById('message');
const submitButton = document.getElementById('message-send');

// The '#app' container stays hidden when the page first loads.
appContainer.style.display = 'none';

/*
connectButton.onclick = () => {
appContainer.style.display = 'block';
gridServer.style.display = 'none';
Expand All @@ -46,6 +49,149 @@ connectButton.onclick = () => {

startSyft(gridServer.value, protocol.value);
};
*/

// Clicking the start button reveals the training charts and launches the
// federated-learning worker against the address typed into '#grid-server'.
startButton.onclick = () => {
  setFLUI();
  // startFL is async: report connection / job-creation failures here instead
  // of leaving an unhandled promise rejection behind.
  startFL(gridServer.value, 'model-id').catch(err => {
    console.log('Error', err);
  });
};

/**
 * Starts a federated-learning worker for one model.
 *
 * Creates a Syft worker for `url`, requests a job for `modelName`, registers
 * the job's `accepted`, `rejected` and `error` handlers and only then starts
 * the job, so no event can be emitted before a listener exists.
 *
 * When the job is accepted, the MNIST training set is loaded, the
 * `training_plan` is executed batch by batch, and the difference between the
 * received and the trained model parameters is reported back.
 *
 * @param {string} url - Grid server address handed to the Syft worker.
 * @param {string} modelName - Name of the model to request a job for.
 * @returns {Promise<void>} Resolves once the job has been started; training
 *   itself runs later inside the `accepted` handler.
 */
const startFL = async (url, modelName) => {
  const worker = new Syft({ url, verbose: true });
  const job = await worker.newJob({ modelName });

  job.on('accepted', async ({ model, clientConfig }) => {
    // Long-lived tensors created by this handler. They are declared outside
    // the `try` so the `finally` block can release them on failure as well.
    let data;
    let targets;
    const modelParams = [];
    const modelDiff = [];

    try {
      // Load data
      console.log('Loading data...');
      const mnist = new MnistData();
      await mnist.load();
      ({ xs: data, labels: targets } = mnist.getTrainData());
      console.log('Data loaded');

      // Prepare train parameters.
      const batchSize = clientConfig.batch_size;
      const lr = clientConfig.lr;
      const numBatches = Math.ceil(data.shape[0] / batchSize);

      // Calculate the total number of model updates. If neither option is
      // specified, we fall back to one loop through all batches.
      const maxEpochs = clientConfig.max_epochs || 1;
      const maxUpdates = clientConfig.max_updates || maxEpochs * numBatches;
      const numUpdates = Math.min(maxUpdates, maxEpochs * numBatches);

      // Work on copies so the received model params stay intact for the diff.
      for (const param of model.params) {
        modelParams.push(param.clone());
      }

      // Main training loop.
      for (let update = 0, batch = 0, epoch = 0; update < numUpdates; update++) {
        // Slice a batch; the last batch of an epoch may be shorter.
        const offset = batch * batchSize;
        const chunkSize = Math.min(batchSize, data.shape[0] - offset);
        const dataBatch = data.slice(offset, chunkSize);
        const targetBatch = targets.slice(offset, chunkSize);

        try {
          // Execute the plan and get updated model params back.
          const [loss, acc, ...updatedModelParams] = await job.plans[
            'training_plan'
          ].execute(
            job.worker,
            dataBatch,
            targetBatch,
            chunkSize,
            lr,
            ...modelParams
          );

          // Use updated model params in the next cycle.
          for (let i = 0; i < modelParams.length; i++) {
            modelParams[i].dispose();
            modelParams[i] = updatedModelParams[i];
          }

          // Read the metrics, then free their tensors right away.
          const accuracy = (await acc.data())[0];
          const lossValue = (await loss.data())[0];
          acc.dispose();
          loss.dispose();

          await updateUIAfterBatch({
            epoch,
            batch,
            accuracy,
            loss: lossValue
          });
        } finally {
          // Free the batch tensors even if the plan or the UI update failed.
          dataBatch.dispose();
          targetBatch.dispose();
        }

        batch++;

        // Check if we're out of batches (end of epoch).
        if (batch === numBatches) {
          batch = 0;
          epoch++;
        }
      }

      // TODO protocol execution
      // job.protocols['secure_aggregation'].execute();

      // Calc model diff.
      for (let i = 0; i < modelParams.length; i++) {
        modelDiff.push(model.params[i].sub(modelParams[i]));
      }

      // Report diff.
      await job.report(modelDiff);
      console.log('Done!');
    } catch (err) {
      // An async event handler has nobody awaiting it; log instead of
      // producing an unhandled rejection.
      console.log('Error', err);
    } finally {
      // NOTE(review): assumes job.report() no longer needs the diff tensors
      // once its promise has settled — confirm against the Job implementation.
      for (const tensor of [data, targets, ...modelParams, ...modelDiff]) {
        if (tensor) {
          tensor.dispose();
        }
      }
    }
  });

  job.on('rejected', ({ timeout }) => {
    // Handle the job rejection
    console.log('We have been rejected by PyGrid to participate in the job.');

    // Without a usable timeout, `timeout * 1000` would be NaN (or 0) and
    // setTimeout would fire immediately, retrying in a tight loop.
    const msUntilRetry = timeout == null ? NaN : Number(timeout) * 1000;
    if (!Number.isFinite(msUntilRetry) || msUntilRetry < 0) {
      console.log('No retry timeout was provided, not retrying.');
      return;
    }

    // Try to join the job again in "msUntilRetry" milliseconds
    setTimeout(job.start.bind(job), msUntilRetry);
  });

  job.on('error', err => {
    console.log('Error', err);
  });

  // Start only after every listener is attached.
  job.start();
};

/**
 * Prepares the training dashboard: creates one empty, static line chart for
 * the train loss and one for the train accuracy, then reveals the
 * '#fl-training' container that hosts them.
 */
const setFLUI = () => {
  // Both charts share the same look; only the target element and title differ.
  const charts = [
    ['loss_graph', 'Train Loss'],
    ['acc_graph', 'Train Accuracy']
  ];

  charts.forEach(([elementId, title]) => {
    // Fresh objects per chart so the two plots never share state.
    const traces = [{ y: [], mode: 'lines', line: { color: '#80CAF6' } }];
    const layout = { title, showlegend: false };
    const config = { staticPlot: true };
    Plotly.newPlot(elementId, traces, layout, config);
  });

  const panel = document.getElementById('fl-training');
  panel.style.display = 'table';
};

/**
 * Logs the metrics of a finished batch, appends them to the loss and accuracy
 * charts, and waits for the next frame via `tf.nextFrame()`.
 *
 * @param {Object} stats
 * @param {number} stats.epoch - Epoch counter shown in the log line.
 * @param {number} stats.batch - Batch counter shown in the log line.
 * @param {number} stats.accuracy - Value appended to the accuracy chart.
 * @param {number} stats.loss - Value appended to the loss chart.
 * @returns {Promise<void>}
 */
const updateUIAfterBatch = async ({ epoch, batch, accuracy, loss }) => {
  const summary = [
    `Epoch: ${epoch}`,
    `Batch: ${batch}`,
    `Accuracy: ${accuracy}`,
    `Loss: ${loss}`
  ].join(', ');
  console.log(summary);

  // Each chart has a single trace (index 0) that receives one new point.
  const points = [
    ['loss_graph', loss],
    ['acc_graph', accuracy]
  ];
  for (const [graphId, value] of points) {
    Plotly.extendTraces(graphId, { y: [[value]] }, [0]);
  }

  await tf.nextFrame();
};

// ---------------------------- OLD -------------------------------

const startSyft = (url, protocolId) => {
const workerId = getQueryVariable('worker_id');
Expand Down
163 changes: 163 additions & 0 deletions examples/with-grid/mnist.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs-core';

// Height and width of a single image, in pixels.
export const IMAGE_H = 28;
export const IMAGE_W = 28;
// Number of values per flattened image (one per pixel).
const IMAGE_SIZE = IMAGE_H * IMAGE_W;
// Width of one label row; labels are consumed in groups of NUM_CLASSES values.
const NUM_CLASSES = 10;
// Total number of examples the sprite image is read as.
const NUM_DATASET_ELEMENTS = 65000;

// The first NUM_TRAIN_ELEMENTS examples form the train split; the rest is
// the test split.
const NUM_TRAIN_ELEMENTS = 55000;
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

// Remote locations of the image sprite and of the raw uint8 labels.
const MNIST_IMAGES_SPRITE_PATH =
'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

/**
 * A class that fetches the sprited MNIST dataset and provides its data as
 * tf.Tensors.
 *
 * Call `load()` once before `getTrainData()` / `getTestData()`. `load()`
 * relies on `Image`, a canvas element and `fetch`.
 */
export class MnistData {
  constructor() {}

  /**
   * Downloads the image sprite and the labels, decodes the sprite into
   * normalized pixel values and splits both into train and test sets
   * (`trainImages`, `testImages`, `trainLabels`, `testLabels`).
   *
   * @returns {Promise<void>}
   * @throws {Error} If the sprite image cannot be loaded or decoded, or if
   *   the labels request fails.
   */
  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      // Without an error handler a failed download would leave load()
      // pending forever.
      img.onerror = () => {
        reject(
          new Error(
            `Failed to load MNIST images from ${MNIST_IMAGES_SPRITE_PATH}`
          )
        );
      };
      img.onload = () => {
        // Exceptions thrown inside an event handler would otherwise be lost;
        // forward them to the promise instead.
        try {
          img.width = img.naturalWidth;
          img.height = img.naturalHeight;

          // 4 bytes per Float32 pixel value.
          const datasetBytesBuffer = new ArrayBuffer(
            NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4
          );

          // The sprite is drawn and read back in horizontal bands of
          // `chunkSize` rows.
          const chunkSize = 5000;
          canvas.width = img.width;
          canvas.height = chunkSize;

          for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
            const datasetBytesView = new Float32Array(
              datasetBytesBuffer,
              i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize
            );
            ctx.drawImage(
              img,
              0,
              i * chunkSize,
              img.width,
              chunkSize,
              0,
              0,
              img.width,
              chunkSize
            );

            const imageData = ctx.getImageData(
              0,
              0,
              canvas.width,
              canvas.height
            );

            for (let j = 0; j < imageData.data.length / 4; j++) {
              // All channels hold an equal value since the image is
              // grayscale, so just read the red channel.
              datasetBytesView[j] = imageData.data[j * 4] / 255;
            }
          }
          this.datasetImages = new Float32Array(datasetBytesBuffer);

          resolve();
        } catch (err) {
          reject(err);
        }
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    // The image promise resolves with no value; only the labels response is
    // needed here.
    const [, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);

    if (!labelsResponse.ok) {
      throw new Error(
        `Failed to load MNIST labels from ${MNIST_LABELS_PATH} (HTTP ${labelsResponse.status})`
      );
    }

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Slice the images and labels into train and test sets.
    this.trainImages = this.datasetImages.slice(
      0,
      IMAGE_SIZE * NUM_TRAIN_ELEMENTS
    );
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels = this.datasetLabels.slice(
      0,
      NUM_CLASSES * NUM_TRAIN_ELEMENTS
    );
    this.testLabels = this.datasetLabels.slice(
      NUM_CLASSES * NUM_TRAIN_ELEMENTS
    );
  }

  /**
   * Get all training data as a data tensor and a labels tensor.
   *
   * @returns
   *   xs: The data tensor, of shape `[numTrainExamples, 784]`.
   *   labels: The one-hot encoded labels tensor, of shape
   *     `[numTrainExamples, 10]`.
   */
  getTrainData() {
    const xs = tf.tensor2d(this.trainImages, [
      this.trainImages.length / IMAGE_SIZE,
      IMAGE_H * IMAGE_W
    ]);
    const labels = tf.tensor2d(this.trainLabels, [
      this.trainLabels.length / NUM_CLASSES,
      NUM_CLASSES
    ]);
    return { xs, labels };
  }

  /**
   * Get all test data as a data tensor and a labels tensor.
   *
   * @param {number} numExamples Optional number of examples to get. Defaults
   *   to the size of the test set; pass `null` to get the un-sliced tensors.
   * @returns
   *   xs: The data tensor, of shape `[numTestExamples, 784]`.
   *   labels: The one-hot encoded labels tensor, of shape
   *     `[numTestExamples, 10]`.
   */
  getTestData(numExamples = NUM_TEST_ELEMENTS) {
    const allXs = tf.tensor2d(this.testImages, [
      this.testImages.length / IMAGE_SIZE,
      IMAGE_H * IMAGE_W
    ]);
    const allLabels = tf.tensor2d(this.testLabels, [
      this.testLabels.length / NUM_CLASSES,
      NUM_CLASSES
    ]);

    if (numExamples == null) {
      return { xs: allXs, labels: allLabels };
    }

    try {
      // Both tensors are 2-D, so `begin` needs exactly two entries; the
      // previous 4-entry `begin` for `xs` did not match the tensor's rank.
      const xs = allXs.slice([0, 0], [numExamples, IMAGE_H * IMAGE_W]);
      const labels = allLabels.slice([0, 0], [numExamples, NUM_CLASSES]);
      return { xs, labels };
    } finally {
      // The un-sliced tensors are no longer reachable by the caller.
      allXs.dispose();
      allLabels.dispose();
    }
  }
}
Loading