Skip to content

Commit

Permalink
Merge pull request #27 from FL33TW00D/fix/unit-fails
Browse files Browse the repository at this point in the history
Fix/unit fails
  • Loading branch information
FL33TW00D authored Jan 22, 2024
2 parents 50f7b5b + 15c3893 commit 5022af4
Show file tree
Hide file tree
Showing 22 changed files with 390 additions and 266 deletions.
17 changes: 11 additions & 6 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ env:
CARGO_TERM_COLOR: always
WGPU_DX12_COMPILER: dxc
RUSTFLAGS: --cfg=web_sys_unstable_apis
RUST_BACKTRACE: 1

jobs:
build:
Expand All @@ -33,12 +34,16 @@ jobs:
sudo apt-get update
sudo apt install -y libegl1-mesa libgl1-mesa-dri libxcb-xfixes0-dev vulkan-sdk mesa-vulkan-drivers pkg-config libasound2-dev
- name: Setup
run: |
cargo install wasm-pack
- name: Build
run: cargo build
- name: Setup python
uses: actions/setup-python@v5
with:
python-version: '3.10.6'
cache: 'pip'
- run: pip install -r requirements.txt
- name: Run tests
run: cargo test
run: cargo test tensor -- --test-threads=1 --nocapture
- name: Install wasm-pack
run: |
curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
- name: Run integration tests
run: (cd crates/ratchet-integration-tests;sh run-tests.sh)
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ Cargo.lock

# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
.python-version
13 changes: 6 additions & 7 deletions ARCHITECTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@ Ratchet is designed for 1 thing only: **Inference on WebGPU**.

This leads us to a few design decisions:
1. Ratchet is **lazy**, no computation is done until the entire computation graph is built and executed. This aligns closely with CUDAGraphs & Command buffers.
2. Ratchet supports **BOTH** static & dynamic graphs, this is key.
- The graph is implicitly defined through tensor operations. If any of the tensors are defined with a *symbolic dimension* (i.e a dimension not known until runtime, e.g sequence_len), the graph is dynamic. When the graph is dynamic, the graph is recompiled on inference pass (because runtime information is required).
- If no tensors contain a symbolic dimension, the graph is static. This means the graph is compiled into a single command buffer, and is repeatedly called with different input data (brrr).

By exposing symbolic dimensions to the user, they can code their models with the CG in mind.
2. Ratchet supports **BOTH** static & dynamic graphs, see [Unified Graph Execution by Jittor](http://scis.scichina.com/en/2020/222103.pdf) for more details.
3. Memory planning is crucial. Creation and first bind of a buffer is *expensive* in WebGPU. Therefore, Ratchet uses a greedy algorithm to pool buffers for intermediate results of the CFG.

Why do this?

Take for example Whisper from OpenAI. This is an encoder-decoder model, where the encoder is completely static (i.e everything is known at compile time), and the decoder is very dynamic (KV caching, seq_len increments every step). By allowing both paradigms, we can maximise performance.

## Memory Management

Ratchets top level `Tensor` is just an `Arc` around the `Inner`. Tensors should be cheaply cloneable.
`Inner` contains a struct `Storage`, this is an enum around our 2 managed structures for CPU & GPU: `CpuStorage` & `GpuStorage`.
`CpuStorage` is an `Arc<RwLock<RawCPUBuffer>>`, and `GpuStorage` is an `Arc<RwLock<Buffer>>`.


## Quantization
Expand Down
7 changes: 2 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,12 @@ strip = true
#debug = 2

[workspace.dependencies]
wgpu = { version = "0.19.0", features = ["fragile-send-sync-non-atomic-wasm"] }
wgpu = { version = "0.18.0", features = ["fragile-send-sync-non-atomic-wasm", "expose-ids"] }
anyhow = "1.0.40"
bytemuck = "1.14.0"
bytemuck = { version = "1.14.0", features=["wasm_simd", "aarch64_simd", "extern_crate_alloc"] }
num-traits = "0.2.17"
half = { version = "2.3.1", features = ["num-traits", "bytemuck"] }
derive-new = "0.6.0"
log = "0.4.20"
thiserror = "1.0.56"
byteorder = "1.5.0"

[workspace.dev-dependencies]
hf-hub = "0.3.0"
14 changes: 11 additions & 3 deletions crates/ratchet-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ version = "0.1.0"
edition = "2021"

[features]
default = ["rand"]
default = ["rand", "pyo3"]
pyo3 = ["dep:pyo3", "dep:numpy", "dep:ndarray"]
gpu_profiling = []
rand = ["dep:rand", "dep:rand_distr"]

Expand All @@ -27,14 +28,21 @@ slotmap = "1.0.7"
parking_lot = "0.12.1"
smallvec = "1.11.2"
encase = "0.7.0"
glam = "0.25.0"
pollster = "0.3.0"
futures-intrusive = "0.5.0"
anyhow = "1.0.79"
num = "0.4.1"
rand_distr = { version = "0.4.3", optional = true }
rand = { version = "0.8.4", optional = true }
lazy_static = "1.4.0"

# Python bindings
pyo3 = { version = "0.20.2", features=["auto-initialize"], optional = true }
numpy = { version = "0.20.0", optional = true }
ndarray = { version = "0.15.6", optional = true }

[dev-dependencies]
rand = "0.8.4"
pyo3 = { version = "0.20.2", features=["auto-initialize"] }
numpy = { version = "0.20.0" }
ndarray = { version = "0.15.6" }

5 changes: 3 additions & 2 deletions crates/ratchet-core/src/compiled_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ impl CompiledOp {
let mut bind_group_entries = drvec![];

for tensor in srcs.iter().chain(std::iter::once(&dst)) {
let buf = tensor.storage().try_read().unwrap();
let gpu_buf = &buf.try_gpu().unwrap().inner;
let storage_guard = tensor.storage();
let storage = storage_guard.as_ref().unwrap();
let gpu_buf = &storage.try_gpu().unwrap().inner;
bind_group_entries.push(BindGroupEntry {
handle: gpu_buf.handle,
offset: 0,
Expand Down
6 changes: 5 additions & 1 deletion crates/ratchet-core/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ pub enum DeviceError {
BufferAllocationFailed(#[from] AllocatorError),
#[error("Invalid GPU Buffer Usage, current: {0:?}, required: {1:?}")]
InvalidBufferUsage(wgpu::BufferUsages, wgpu::BufferUsages),
#[error("Failed to transfer buffer with error: {0:?}")]
BufferTransferFailed(#[from] wgpu::BufferAsyncError),
}

pub enum DeviceRequest {
Expand Down Expand Up @@ -49,7 +51,9 @@ impl Device {
pub fn request_device(request: DeviceRequest) -> Result<Self, DeviceError> {
match request {
DeviceRequest::CPU => Ok(Device::CPU),
DeviceRequest::GPU => Ok(Device::GPU(pollster::block_on(WgpuDevice::new())?)),
DeviceRequest::GPU => Ok(Device::GPU(pollster::block_on(async {
WgpuDevice::new().await
})?)),
}
}

Expand Down
27 changes: 13 additions & 14 deletions crates/ratchet-core/src/gpu/buffer_allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use rustc_hash::FxHashMap;
use wgpu::BufferUsages;

use crate::{
gpu::{BufferDescriptor, BufferPool, GPUBuffer, GpuBufferHandle},
gpu::{BufferDescriptor, BufferPool, GpuBufferHandle, PooledGPUBuffer},
DeviceError, Tensor, TensorId,
};
use std::cell::{Ref, RefCell, RefMut};
Expand Down Expand Up @@ -31,7 +31,7 @@ impl BufferAllocator {
self.pool.borrow_mut().begin_pass(pass_index);
}

pub fn get(&self, handle: GpuBufferHandle) -> GPUBuffer {
pub fn get(&self, handle: GpuBufferHandle) -> PooledGPUBuffer {
self.pool.borrow().get(handle).unwrap()
}

Expand All @@ -43,7 +43,7 @@ impl BufferAllocator {
self.pool.borrow_mut()
}

pub fn create_buffer(&self, desc: &BufferDescriptor, device: &WgpuDevice) -> GPUBuffer {
pub fn create_buffer(&self, desc: &BufferDescriptor, device: &WgpuDevice) -> PooledGPUBuffer {
self.pool.borrow_mut().get_or_create(desc, device)
}

Expand All @@ -52,13 +52,13 @@ impl BufferAllocator {
desc: &BufferDescriptor,
contents: &[u8],
device: &WgpuDevice,
) -> GPUBuffer {
) -> PooledGPUBuffer {
let buf = self.pool.borrow_mut().get_or_create(desc, device);
device.queue().write_buffer(&buf.inner, 0, contents);
buf
}

pub fn create_uniform_init(&self, uniform: CpuUniform, device: &WgpuDevice) -> GPUBuffer {
pub fn create_uniform_init(&self, uniform: CpuUniform, device: &WgpuDevice) -> PooledGPUBuffer {
let mut uniform = uniform.into_inner();
uniform.resize(
uniform.len() + UNIFORM_ALIGN - uniform.len() % UNIFORM_ALIGN,
Expand All @@ -85,9 +85,9 @@ impl BufferAllocator {
fn graph_allocate(
&self,
descriptor: BufferDescriptor,
free: &mut Vec<GPUBuffer>,
free: &mut Vec<PooledGPUBuffer>,
device: &WgpuDevice,
) -> GPUBuffer {
) -> PooledGPUBuffer {
let required_size = descriptor.size as _;
let mut closest_index = None;
let mut closest_size_diff: Option<usize> = None;
Expand Down Expand Up @@ -121,17 +121,16 @@ impl BufferAllocator {
&self,
execution_order: &[Tensor],
device: &WgpuDevice,
) -> Result<FxHashMap<TensorId, GPUBuffer>, DeviceError> {
) -> Result<FxHashMap<TensorId, PooledGPUBuffer>, DeviceError> {
let mut free = Vec::new(); //TODO: switch to BTreeMap
let mut assignments = FxHashMap::default();

for t in execution_order {
if t.resolved() {
let storage_resource = t
.storage()
.try_read()
.ok_or(AllocatorError::BufferNotFound)?;
assignments.insert(t.id(), storage_resource.try_gpu()?.inner.clone());
assignments.insert(
t.id(),
t.storage().as_ref().unwrap().try_gpu()?.inner.clone(),
);
continue;
}

Expand Down Expand Up @@ -159,7 +158,7 @@ impl BufferAllocator {
let output = execution_order.last().unwrap();
assignments.insert(
output.id(),
device.allocate_buffer(&BufferDescriptor {
device.get_or_create_buffer(&BufferDescriptor {
size: output.num_bytes() as _,
usage: BufferUsages::standard(),
mapped_at_creation: false,
Expand Down
25 changes: 14 additions & 11 deletions crates/ratchet-core/src/gpu/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use wgpu::{Adapter, DeviceType, Limits};

use crate::DeviceError;

use super::{BufferDescriptor, GPUBuffer, PoolError};
use super::{BufferDescriptor, PoolError, PooledGPUBuffer};

pub const MAX_BUFFER_SIZE: u64 = (2 << 29) - 1;

Expand Down Expand Up @@ -56,16 +56,16 @@ impl WgpuDevice {
let adapter = Self::select_adapter()?;

#[allow(unused_mut)]
let mut required_features = wgpu::Features::default();
let mut features = wgpu::Features::default();
#[cfg(feature = "gpu-profiling")]
{
features |= wgpu::Features::TIMESTAMP_QUERY;
}

let mut device_descriptor = wgpu::DeviceDescriptor {
label: Some("ratchet"),
required_features,
required_limits: Limits {
features,
limits: Limits {
max_buffer_size: MAX_BUFFER_SIZE,
max_storage_buffer_binding_size: MAX_BUFFER_SIZE as u32,
..Default::default()
Expand All @@ -77,7 +77,7 @@ impl WgpuDevice {
"Failed to acq. device, trying again with reduced limits: {:?}",
e
);
device_descriptor.required_limits = adapter.limits();
device_descriptor.limits = adapter.limits();
adapter.request_device(&device_descriptor, None).await
} else {
device_request
Expand Down Expand Up @@ -147,25 +147,28 @@ impl WgpuDevice {
}

impl WgpuDevice {
pub fn create_buffer_init(
pub fn get_or_create_buffer_init(
&self,
desc: &BufferDescriptor,
contents: &[u8],
) -> Result<GPUBuffer, DeviceError> {
) -> Result<PooledGPUBuffer, DeviceError> {
Ok(self
.buffer_allocator
.create_buffer_init(desc, contents, self))
}

pub fn create_uniform_init(&self, cpu_uniform: CpuUniform) -> GPUBuffer {
pub fn create_uniform_init(&self, cpu_uniform: CpuUniform) -> PooledGPUBuffer {
self.buffer_allocator.create_uniform_init(cpu_uniform, self)
}

pub fn allocate_buffer(&self, desc: &BufferDescriptor) -> Result<GPUBuffer, DeviceError> {
pub fn get_or_create_buffer(
&self,
desc: &BufferDescriptor,
) -> Result<PooledGPUBuffer, DeviceError> {
Ok(self.buffer_allocator.create_buffer(desc, self))
}

pub fn get_buffer(&self, handle: GpuBufferHandle) -> Result<GPUBuffer, DeviceError> {
pub fn get_buffer(&self, handle: GpuBufferHandle) -> Result<PooledGPUBuffer, DeviceError> {
Ok(self.buffer_allocator.get(handle))
}

Expand Down Expand Up @@ -221,7 +224,7 @@ impl WgpuDevice {
&self,
execution_order: &[Tensor],
device: &WgpuDevice,
) -> Result<FxHashMap<TensorId, GPUBuffer>, DeviceError> {
) -> Result<FxHashMap<TensorId, PooledGPUBuffer>, DeviceError> {
self.buffer_allocator.allocate_cfg(execution_order, device)
}
}
4 changes: 2 additions & 2 deletions crates/ratchet-core/src/gpu/pools/bind_group_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ slotmap::new_key_type! { pub struct GpuBindGroupHandle; }
#[derive(Clone)]
pub struct GpuBindGroup {
resource: Arc<DynamicResource<GpuBindGroupHandle, BindGroupDescriptor, wgpu::BindGroup>>,
_owned_buffers: RVec<GPUBuffer>,
_owned_buffers: RVec<PooledGPUBuffer>,
}

impl std::fmt::Debug for GpuBindGroup {
Expand Down Expand Up @@ -98,7 +98,7 @@ impl BindGroupPool {
pub fn get_or_create(&self, desc: &BindGroupDescriptor, device: &WgpuDevice) -> GpuBindGroup {
// Retrieve strong handles to buffers and textures.
// This way, an owner of a bind group handle keeps buffers & textures alive!.
let owned_buffers: RVec<GPUBuffer> = {
let owned_buffers: RVec<PooledGPUBuffer> = {
desc.entries
.iter()
.map(|e| device.get_buffer(e.handle).unwrap())
Expand Down
12 changes: 6 additions & 6 deletions crates/ratchet-core/src/gpu/pools/buffer_pool.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Adapted from https://github.com/rerun-io/rerun MIT licensed
use super::{DynamicResource, DynamicResourcePool, DynamicResourcesDesc, PoolError};
use crate::gpu::WgpuDevice;
use crate::{gpu::WgpuDevice, RawGPUBuffer};

#[derive(Clone, Hash, PartialEq, Eq, Debug, derive_new::new)]
pub struct BufferDescriptor {
Expand All @@ -19,8 +19,8 @@ slotmap::new_key_type! { pub struct GpuBufferHandle; }

/// A reference-counter baked buffer.
/// Once all instances are dropped, the buffer will be marked for reclamation in the following pass.
pub type GPUBuffer =
std::sync::Arc<DynamicResource<GpuBufferHandle, BufferDescriptor, wgpu::Buffer>>;
pub type PooledGPUBuffer =
std::sync::Arc<DynamicResource<GpuBufferHandle, BufferDescriptor, RawGPUBuffer>>;

impl DynamicResourcesDesc for BufferDescriptor {
fn resource_size_in_bytes(&self) -> u64 {
Expand All @@ -37,7 +37,7 @@ impl DynamicResourcesDesc for BufferDescriptor {
}

pub struct BufferPool {
inner: DynamicResourcePool<GpuBufferHandle, BufferDescriptor, wgpu::Buffer>,
inner: DynamicResourcePool<GpuBufferHandle, BufferDescriptor, RawGPUBuffer>,
}

impl BufferPool {
Expand All @@ -47,7 +47,7 @@ impl BufferPool {
}
}

pub fn get_or_create(&self, desc: &BufferDescriptor, device: &WgpuDevice) -> GPUBuffer {
pub fn get_or_create(&self, desc: &BufferDescriptor, device: &WgpuDevice) -> PooledGPUBuffer {
self.inner.get_or_create(desc, |desc| {
let (size, usage, mapped_at_creation) = desc.fields();
device.create_buffer(&wgpu::BufferDescriptor {
Expand All @@ -64,7 +64,7 @@ impl BufferPool {
}

/// Method to retrieve a resource from a weak handle (used by [`super::GpuBindGroupPool`])
pub fn get(&self, handle: GpuBufferHandle) -> Result<GPUBuffer, PoolError> {
pub fn get(&self, handle: GpuBufferHandle) -> Result<PooledGPUBuffer, PoolError> {
self.inner.get_from_handle(handle)
}

Expand Down
4 changes: 2 additions & 2 deletions crates/ratchet-core/src/gpu/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
rvec,
};

use super::{BindGroupDescriptor, GPUBuffer, GpuBindGroup, WgpuDevice};
use super::{BindGroupDescriptor, GpuBindGroup, PooledGPUBuffer, WgpuDevice};
use encase::DynamicUniformBuffer;

///We use a single uniform buffer for all operations to hold their parameters.
Expand Down Expand Up @@ -56,7 +56,7 @@ impl CpuUniform {
}

pub struct GpuUniform {
buf: GPUBuffer,
buf: PooledGPUBuffer,
bind_group: GpuBindGroup,
}

Expand Down
Loading

0 comments on commit 5022af4

Please sign in to comment.