WIP: burn-train in the browser #938

Closed

AlexErrant wants to merge 87 commits into main from train-browser

Commits (87)
5a030b9
pnpm create vite train-web --template vanilla-ts
AlexErrant Nov 2, 2023
2f09310
pnpm i
AlexErrant Nov 2, 2023
408d740
pnpm up -rL
AlexErrant Nov 2, 2023
8eab3a6
pnpm i -D prettier
AlexErrant Nov 2, 2023
74b9b57
add .prettierrc
AlexErrant Nov 2, 2023
86c6a06
pnpm exec prettier . --write
AlexErrant Nov 2, 2023
bddd331
move everything to web folder
AlexErrant Nov 2, 2023
501b360
update workspace
AlexErrant Nov 2, 2023
9ff1615
cargo new train --lib
AlexErrant Nov 2, 2023
0a279de
pnpm add ../train/pkg
AlexErrant Nov 3, 2023
0af481a
add vite config
AlexErrant Nov 3, 2023
2740d4c
can call train's run from web
AlexErrant Nov 4, 2023
cc423e6
add train.rs
AlexErrant Nov 4, 2023
40f4140
make dev friendly
AlexErrant Nov 4, 2023
a77829d
implemented spawn using webworkers
AlexErrant Nov 9, 2023
2e074e7
update notices
AlexErrant Nov 9, 2023
1cf8dd1
Arc => Rc
AlexErrant Nov 9, 2023
f688a3d
clippy
AlexErrant Nov 9, 2023
13f26c9
fix CI
AlexErrant Nov 9, 2023
7fab3a0
Merge branch 'main' into train-browser
AlexErrant Nov 12, 2023
11cb46b
pnpm i sql.js
AlexErrant Nov 13, 2023
5fa4463
add postinstall to cp sql-wasm.wasm into assets
AlexErrant Nov 13, 2023
94d156d
can load mnist data in js and send to rust
AlexErrant Nov 13, 2023
66a7ffd
extract init
AlexErrant Nov 14, 2023
7735e47
autoload/run mnist on pageload
AlexErrant Nov 14, 2023
2d5f8c0
clean up resources
AlexErrant Nov 14, 2023
3d863e7
can create a DataLoader
AlexErrant Nov 15, 2023
0b0fd86
can load train and test
AlexErrant Nov 15, 2023
6331ede
copied over model
AlexErrant Nov 15, 2023
469e270
copy over guide's `training.rs`, more or less
AlexErrant Nov 15, 2023
939f3d0
docs
AlexErrant Nov 15, 2023
437cf00
assert is valid png
AlexErrant Nov 15, 2023
e3c84a6
add pool from https://github.com/rustwasm/wasm-bindgen/blob/main/exam…
AlexErrant Nov 22, 2023
9a355f2
pool works
AlexErrant Nov 24, 2023
ea1cfe1
Merge branch 'main' into train-browser
AlexErrant Nov 29, 2023
70e1360
fix conflict
AlexErrant Nov 29, 2023
a1f5828
it builds
AlexErrant Nov 30, 2023
9f2285e
add rayon and wasm-bindgen-rayon
AlexErrant Dec 1, 2023
6d926e8
replace with rayon
AlexErrant Dec 1, 2023
d2672b4
Merge branch 'main' into train-browser
AlexErrant Dec 1, 2023
8035870
fix conflicts
AlexErrant Dec 1, 2023
7b27871
nix license notice
AlexErrant Dec 1, 2023
6a76157
Merge branch 'main' into train-browser
AlexErrant Dec 1, 2023
da36016
fix conflicts
AlexErrant Dec 1, 2023
9aa0439
fix
AlexErrant Dec 1, 2023
051b41f
fix?
AlexErrant Dec 1, 2023
34d6b3a
Merge branch 'main' into train-browser
AlexErrant Dec 1, 2023
7a8372c
fix??
AlexErrant Dec 1, 2023
2c9dc06
runchecks runs each package separately
AlexErrant Dec 2, 2023
1ee5cef
skip train-web
AlexErrant Dec 2, 2023
3f52a78
fix deps
AlexErrant Dec 3, 2023
f188b7a
converted train to a worker (as to not block main thread)
AlexErrant Nov 16, 2023
37a748e
add autotrain checkbox
AlexErrant Nov 16, 2023
fa63f6d
MetricsRenderer logs
AlexErrant Nov 16, 2023
f8ab1d4
add to_bytes
AlexErrant Dec 3, 2023
9e9dfb6
comment out off by one check
AlexErrant Nov 16, 2023
0a9a467
don't hardcode dirs
AlexErrant Dec 3, 2023
cfffb93
oops
AlexErrant Dec 3, 2023
d4106bb
fmt
AlexErrant Dec 3, 2023
c84d171
exclude `examples/train-web/train` less
AlexErrant Dec 4, 2023
3541503
I regret everything (in this file)
AlexErrant Dec 4, 2023
65724c8
Merge branch 'main' into train-browser
AlexErrant Dec 4, 2023
15af6fd
fix ci
AlexErrant Dec 4, 2023
49e2c38
rustup component add rust-src --toolchain nightly-2023-07-01-x86_64-u…
AlexErrant Dec 4, 2023
519b752
0
AlexErrant Dec 4, 2023
6d8c721
Revert "0"
AlexErrant Dec 4, 2023
c025d35
Revert "rustup component add rust-src --toolchain nightly-2023-07-01-…
AlexErrant Dec 4, 2023
b0fbeef
can I just get a ✔ please
AlexErrant Dec 5, 2023
57472fe
Revert "Revert "rustup component add rust-src --toolchain nightly-202…
AlexErrant Dec 4, 2023
8fa5a03
not always linux
AlexErrant Dec 5, 2023
1b6d619
1
AlexErrant Dec 5, 2023
c8575e3
2
AlexErrant Dec 5, 2023
de55ae8
Revert "can I just get a ✔ please"
AlexErrant Dec 5, 2023
bebab5e
3
AlexErrant Dec 5, 2023
64a11b5
4
AlexErrant Dec 5, 2023
944e53a
5
AlexErrant Dec 5, 2023
1bddd63
6
AlexErrant Dec 5, 2023
e56d489
Merge branch 'main' into train-browser
AlexErrant Dec 23, 2023
218e17d
fix breaking changes
AlexErrant Dec 24, 2023
b96df97
make windows work?
AlexErrant Dec 24, 2023
a92392f
Merge branch 'main' into train-browser
AlexErrant Jan 28, 2024
d3cabca
reject CI changes
AlexErrant Jan 28, 2024
2d0a488
less CI changes
AlexErrant Jan 28, 2024
f053b4d
bump wasm-bindgen
AlexErrant Jan 29, 2024
116d3f5
add generic
AlexErrant Jan 29, 2024
3b53c32
it builds
AlexErrant Jan 29, 2024
a527fae
add AsyncTask(Boxed)
AlexErrant Jan 29, 2024
5 changes: 3 additions & 2 deletions Cargo.toml
@@ -24,10 +24,11 @@ members = [
"burn-train",
"xtask",
"examples/*",
"examples/train-web/train",
"backend-comparison",
]

exclude = ["examples/notebook"]
exclude = ["examples/notebook", "examples/train-web"]

[workspace.dependencies]
async-trait = "0.1.73"
@@ -66,7 +67,7 @@ thiserror = "1.0.49"
tracing-appender = "0.2.2"
tracing-core = "0.1.31"
tracing-subscriber = "0.3.17"
wasm-bindgen = "0.2.87"
wasm-bindgen = "=0.2.88"
wasm-bindgen-futures = "0.4.37"
wasm-logger = "0.2.0"

24 changes: 24 additions & 0 deletions NOTICES.md
@@ -245,4 +245,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

## Rust threads in the browser

**Source**: https://github.com/tweag/rust-wasm-threads/tree/main/shared-memory

MIT License

Copyright (c) 2023 Matthew Toohey and Modus Create LLC

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
30 changes: 16 additions & 14 deletions burn-train/Cargo.toml
@@ -12,20 +12,14 @@ version = "0.11.0"

[features]
default = ["metrics", "tui"]
metrics = [
"nvml-wrapper",
"sysinfo",
"systemstat"
]
tui = [
"ratatui",
"crossterm"
]
metrics = ["nvml-wrapper", "sysinfo", "systemstat"]
tui = ["ratatui", "crossterm"]
browser = ["js-sys", "web-sys", "wasm-bindgen"]

[dependencies]
burn-core = {path = "../burn-core", version = "0.11.0" }
burn-core = { path = "../burn-core", version = "0.11.0" }

log = {workspace = true}
log = { workspace = true }
tracing-subscriber.workspace = true
tracing-appender.workspace = true
tracing-core.workspace = true
@@ -40,8 +34,16 @@ ratatui = { version = "0.23", optional = true, features = ["all-widgets"] }
crossterm = { version = "0.27", optional = true }

# Utilities
derive-new = {workspace = true}
serde = {workspace = true, features = ["std", "derive"]}
derive-new = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }

js-sys = { version = "0.3.64", optional = true }
web-sys = { version = "0.3.65", optional = true, features = [
"Worker",
"WorkerOptions",
"WorkerType",
] }
wasm-bindgen = { workspace = true, optional = true }

[dev-dependencies]
burn-ndarray = {path = "../burn-ndarray", version = "0.11.0" }
burn-ndarray = { path = "../burn-ndarray", version = "0.11.0" }
4 changes: 2 additions & 2 deletions burn-train/src/checkpoint/strategy/metric.rs
@@ -76,7 +76,7 @@ mod tests {
},
TestBackend,
};
use std::sync::Arc;
use std::rc::Rc;

use super::*;

@@ -93,7 +93,7 @@
store.register_logger_train(InMemoryMetricLogger::default());
// Register the loss metric.
metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());
let store = Arc::new(EventStoreClient::new(store));
let store = Rc::new(EventStoreClient::new(store));
let mut processor = MinimalEventProcessor::new(metrics, store.clone());

// Two points for the first epoch. Mean 0.75
3 changes: 2 additions & 1 deletion burn-train/src/learner/base.rs
@@ -6,6 +6,7 @@ use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::Module;
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::Backend;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

@@ -24,7 +25,7 @@ pub struct Learner<LC: LearnerComponents> {
pub(crate) interrupter: TrainingInterrupter,
pub(crate) early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
pub(crate) event_processor: LC::EventProcessor,
pub(crate) event_store: Arc<EventStoreClient>,
pub(crate) event_store: Rc<EventStoreClient>,
}

#[derive(new)]
4 changes: 2 additions & 2 deletions burn-train/src/learner/builder.rs
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::rc::Rc;

use super::log::install_file_logger;
use super::Learner;
@@ -312,7 +312,7 @@
));
}

let event_store = Arc::new(EventStoreClient::new(self.event_store));
let event_store = Rc::new(EventStoreClient::new(self.event_store));

let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone());

let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
Expand Down
4 changes: 2 additions & 2 deletions burn-train/src/learner/early_stopping.rs
@@ -104,7 +104,7 @@ impl MetricEarlyStoppingStrategy {

#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::{rc::Rc, sync::Arc};

use crate::{
logger::InMemoryMetricLogger,
@@ -188,7 +188,7 @@ mod tests {
store.register_logger_train(InMemoryMetricLogger::default());
metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());

let store = Arc::new(EventStoreClient::new(store));
let store = Rc::new(EventStoreClient::new(store));
let mut processor = MinimalEventProcessor::new(metrics, store.clone());

let mut epoch = 1;
2 changes: 2 additions & 0 deletions burn-train/src/lib.rs
@@ -2,6 +2,8 @@

//! A library for training neural networks using the burn crate.

pub mod util;

#[macro_use]
extern crate derive_new;

6 changes: 3 additions & 3 deletions burn-train/src/metric/processor/full.rs
@@ -1,22 +1,22 @@
use super::{Event, EventProcessor, Metrics};
use crate::metric::store::EventStoreClient;
use crate::renderer::{MetricState, MetricsRenderer};
use std::sync::Arc;
use std::rc::Rc;

/// An [event processor](EventProcessor) that handles:
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
/// - Render metrics using a [metrics renderer](MetricsRenderer).
pub struct FullEventProcessor<T, V> {
metrics: Metrics<T, V>,
renderer: Box<dyn MetricsRenderer>,
store: Arc<EventStoreClient>,
store: Rc<EventStoreClient>,
}

impl<T, V> FullEventProcessor<T, V> {
pub(crate) fn new(
metrics: Metrics<T, V>,
renderer: Box<dyn MetricsRenderer>,
store: Arc<EventStoreClient>,
store: Rc<EventStoreClient>,
) -> Self {
Self {
metrics,
4 changes: 2 additions & 2 deletions burn-train/src/metric/processor/minimal.rs
@@ -1,13 +1,13 @@
use super::{Event, EventProcessor, Metrics};
use crate::metric::store::EventStoreClient;
use std::sync::Arc;
use std::rc::Rc;

/// An [event processor](EventProcessor) that handles:
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
#[derive(new)]
pub(crate) struct MinimalEventProcessor<T, V> {
metrics: Metrics<T, V>,
store: Arc<EventStoreClient>,
store: Rc<EventStoreClient>,
}

impl<T, V> EventProcessor for MinimalEventProcessor<T, V> {
9 changes: 5 additions & 4 deletions burn-train/src/metric/store/client.rs
@@ -1,11 +1,12 @@
use super::EventStore;
use super::{Aggregate, Direction, Event, Split};
use std::{sync::mpsc, thread::JoinHandle};
use crate::util;
use std::sync::mpsc;

/// Type that allows to communicate with an [event store](EventStore).
pub struct EventStoreClient {
sender: mpsc::Sender<Message>,
handler: Option<JoinHandle<()>>,
handler: Option<Box<dyn FnOnce() -> Result<(), ()>>>,
}

impl EventStoreClient {
@@ -17,7 +18,7 @@ impl EventStoreClient {
let (sender, receiver) = mpsc::channel();
let thread = WorkerThread::new(store, receiver);

let handler = std::thread::spawn(move || thread.run());
let handler = util::spawn(move || thread.run());
let handler = Some(handler);

Self { sender, handler }
@@ -153,7 +154,7 @@ impl Drop for EventStoreClient {
let handler = self.handler.take();

if let Some(handler) = handler {
handler.join().expect("The event store thread should stop.");
handler().expect("The event store thread should stop.");
}
}
}
64 changes: 64 additions & 0 deletions burn-train/src/util.rs
@@ -0,0 +1,64 @@
#![allow(missing_docs)]

#[cfg(feature = "browser")]
use wasm_bindgen::prelude::*;

#[cfg(not(feature = "browser"))]
pub fn spawn<F>(f: F) -> Box<dyn FnOnce() -> Result<(), ()>>
where
F: FnOnce(),
F: Send + 'static,
{
let handle = std::thread::spawn(f);
Box::new(move || handle.join().map_err(|_| ()))
}

// High level description at https://www.tweag.io/blog/2022-11-24-wasm-threads-and-messages/
// Mostly copied from https://github.com/tweag/rust-wasm-threads/blob/main/shared-memory/src/lib.rs
#[cfg(feature = "browser")]
pub fn spawn<F>(f: F) -> Box<dyn FnOnce() -> Result<(), ()>>
where
F: FnOnce(),
F: Send + 'static,
{
let mut worker_options = web_sys::WorkerOptions::new();
worker_options.type_(web_sys::WorkerType::Module);
let w = web_sys::Worker::new_with_options(
WORKER_URL
.get()
.expect("You must first call `init` with the worker's url."),
&worker_options,
)
.unwrap_or_else(|_| panic!("Error initializing worker at {:?}", WORKER_URL));
// Double-boxing because `dyn FnOnce` is unsized and so `Box<dyn FnOnce()>` has
// an undefined layout (although I think in practice its a pointer and a length?).
let ptr = Box::into_raw(Box::new(Box::new(f) as Box<dyn FnOnce()>));

// See `worker.js` for the format of this message.
let msg: js_sys::Array = [
&wasm_bindgen::module(),
&wasm_bindgen::memory(),
&JsValue::from(ptr as u32),
]
.into_iter()
.collect();
if let Err(e) = w.post_message(&msg) {
// We expect the worker to deallocate the box, but if there was an error then
// we'll do it ourselves.
let _ = unsafe { Box::from_raw(ptr) };
panic!("Error initializing worker during post_message: {:?}", e)
} else {
Box::new(move || {
w.terminate();
Ok(())
})
}
}

#[cfg(feature = "browser")]
static WORKER_URL: std::sync::OnceLock<String> = std::sync::OnceLock::new();

#[cfg(feature = "browser")]
pub fn init(worker_url: String) -> Result<(), String> {
WORKER_URL.set(worker_url)
}
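
The worker script referenced by `WORKER_URL` is not part of this hunk. In the tweag `shared-memory` example this `spawn` is copied from, the worker re-instantiates the same wasm module with the shared memory it receives in the `[module, memory, ptr]` message, then calls an exported entry point that runs the leaked closure. A minimal sketch of that wasm-side entry point follows; the name `child_entry_point` and its exact signature are assumptions, not code from this PR.

```rust
// Hypothetical worker-side counterpart to `spawn` (name/signature are assumptions).
// The worker script initializes wasm with the received module + shared memory,
// then calls this export with the pointer taken from the message.
#[cfg(feature = "browser")]
#[wasm_bindgen]
pub fn child_entry_point(ptr: u32) {
    // Reconstruct the double-boxed closure that `spawn` leaked via `Box::into_raw`,
    // run it on this worker's thread, and drop it so the allocation is freed.
    let work = unsafe { Box::from_raw(ptr as *mut Box<dyn FnOnce()>) };
    work();
}
```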

12 changes: 11 additions & 1 deletion burn/Cargo.toml
@@ -21,6 +21,8 @@ train = ["burn-train/default", "autodiff", "dataset"]
# Useful when targeting WASM and not using WGPU.
wasm-sync = ["burn-core/wasm-sync"]

browser = ["burn-train/browser"]

## Include nothing
train-minimal = ["burn-train"]

@@ -62,4 +64,12 @@ burn-core = { path = "../burn-core", version = "0.11.0", default-features = fals
burn-train = { path = "../burn-train", version = "0.11.0", optional = true, default-features = false }

[package.metadata.docs.rs]
features = ["dataset", "default", "std", "train", "train-tui", "train-metrics", "dataset-sqlite"]
features = [
"dataset",
"default",
"std",
"train",
"train-tui",
"train-metrics",
"dataset-sqlite",
]
14 changes: 14 additions & 0 deletions examples/train-web/readme.md
@@ -0,0 +1,14 @@
## Getting Started

Your `wasm-bindgen-cli` version must *exactly* match the `wasm-bindgen` version in [Cargo.toml](../../Cargo.toml) since `wasm-bindgen-cli` is implicitly used by `wasm-pack`.

For example, run `cargo install --version 0.2.88 wasm-bindgen-cli --force`. The version in this example command is not guaranteed to be up to date!

Install [PNPM](https://pnpm.io/).

Then in separate terminals:

1. `cd train && dev.sh`
2. `cd web && pnpm i && pnpm dev`

Any changes to `/train` or `burn` should trigger a recompilation. When a new binary is generated, `web` will automatically refresh the page.
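
The wasm-bindgen glue that wires `burn-train`'s `util::init` into this example is not shown in the diff. As a rough sketch of how the pieces fit together (the function name `start`, the `burn::train::util` re-export path, and the call sequence are assumptions, not code from this PR), the web side would register the worker URL before anything calls `spawn`:

```rust
use wasm_bindgen::prelude::*;

// Hypothetical entry point in examples/train-web/train/src/lib.rs.
#[wasm_bindgen]
pub fn start(worker_url: String) {
    console_error_panic_hook::set_once();
    // `spawn` panics unless the worker URL was registered first, so do this before
    // constructing the Learner (whose EventStoreClient spawns a worker thread).
    burn::train::util::init(worker_url).expect("worker url should only be set once");
    // ...build the dataloaders, model, and Learner here, then start training...
}
```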
6 changes: 6 additions & 0 deletions examples/train-web/train/.cargo/config.toml
@@ -0,0 +1,6 @@
[unstable]
# Rebuild std for the wasm target so it is compiled with the target features below.
# This requires a nightly toolchain with the `rust-src` component installed.
build-std = ['std', 'panic_abort']

[build]
target = "wasm32-unknown-unknown"
# Shared-memory threading on wasm needs atomics, bulk memory, and mutable globals.
rustflags = '-Ctarget-feature=+atomics,+bulk-memory,+mutable-globals'
1 change: 1 addition & 0 deletions examples/train-web/train/.gitignore
@@ -0,0 +1 @@
pkg
21 changes: 21 additions & 0 deletions examples/train-web/train/Cargo.toml
@@ -0,0 +1,21 @@
[package]
name = "train"
version = "0.1.0"
edition = "2021"

[lib]
crate-type = ["cdylib"]

[dependencies]
wasm-bindgen = { workspace = true }
log = { workspace = true }
console_error_panic_hook = "0.1.7"
console_log = { version = "1", features = ["color"] }
burn = { path = "../../../burn", default-features = false, features = [
"autodiff",
"ndarray-no-std",
"train-minimal",
"wasm-sync",
"browser",
] }
serde = { workspace = true }