diff --git a/examples/train-web/readme.md b/examples/train-web/readme.md index 736b912544..9244941dff 100644 --- a/examples/train-web/readme.md +++ b/examples/train-web/readme.md @@ -6,7 +6,9 @@ For example, run `cargo install --version 0.2.88 wasm-bindgen-cli --force`. The Install [PNPM](https://pnpm.io/). -The [`postinstall.sh`](./web/postinstall.sh) script expects the mnist database to be at `~/.cache/burn-dataset/mnist.db`. Running `guide` will generate this file. Alternatively, you can download it from [Hugging Face](https://huggingface.co/datasets/mnist). +Install [cargo-watch](https://crates.io/crates/cargo-watch). + +The [`postinstall.sh`](./web/postinstall.sh) script expects the mnist database to be at `~/.cache/burn-dataset/mnist.db`. Running `burn/examples/guide` will generate this file. Alternatively, you can download it from [Hugging Face](https://huggingface.co/datasets/mnist). Then in separate terminals: diff --git a/examples/train-web/train/src/mnist.rs b/examples/train-web/train/src/mnist.rs index 41fb40b0e8..b6f74ebf93 100644 --- a/examples/train-web/train/src/mnist.rs +++ b/examples/train-web/train/src/mnist.rs @@ -71,6 +71,8 @@ impl Dataset for MNISTDataset { impl MNISTDataset { /// Creates a new dataset. pub fn new(labels: &[u8], images: &[u8], lengths: &[u16]) -> Self { + // Decoding is here. + // Encoding is done at `examples/train-web/web/src/train.ts`. debug_assert!(labels.len() == lengths.len()); let mut start = 0 as usize; let raws = labels diff --git a/examples/train-web/web/src/train.ts b/examples/train-web/web/src/train.ts index 53e3109698..aec82c9d5c 100644 --- a/examples/train-web/web/src/train.ts +++ b/examples/train-web/web/src/train.ts @@ -34,6 +34,12 @@ export function setupTrain(element: HTMLInputElement) { } async function loadSqliteAndRun(ab: ArrayBuffer) { + // Images are an array of arrays. + // We can't send an array of arrays to Wasm. + // So instead we merge images into a single large array and + // use another array, `lengths`, to keep track of the image size. + // Encoding is done here. + // Decoding is done at `burn/examples/train-web/train/src/mnist.rs`. const trainImages: Uint8Array[] = [] const trainLabels: number[] = [] const trainLengths: number[] = []