draft: i64 support (tested on clip model)
EdupugantiAkhil committed Sep 1, 2024
1 parent 991a1b0 commit 71fdc57
Showing 8 changed files with 1,820 additions and 1,140 deletions.
741 changes: 423 additions & 318 deletions candle-core/src/wgpu_backend/cache.rs

Large diffs are not rendered by default.

855 changes: 590 additions & 265 deletions candle-core/src/wgpu_backend/device.rs

Large diffs are not rendered by default.

355 changes: 264 additions & 91 deletions candle-core/src/wgpu_backend/storage.rs

Large diffs are not rendered by default.

58 changes: 44 additions & 14 deletions candle-core/src/wgpu_backend/wgpu_functions/convert.rs
@@ -11,20 +11,18 @@ pub fn queue_convert_u32_to_f32(
let mut meta = get_meta(&dev);
meta.add_layout1(&input_layout);


let pipeline = meta.get_pipeline(Pipelines::Convert(DType::U32, Functions::ConvertToF32));
let bind_group = create_bind_group_input1( buffer_dest, buffer_input);
let bind_group = create_bind_group_input1(buffer_dest, buffer_input);
enqueue(
meta,
pipeline,
bind_group,
input_layout.shape().elem_count() as u32,
input_layout.shape().elem_count()
input_layout.shape().elem_count(),
);
return Ok(());
}


pub fn queue_convert_u8_to_f32(
dev: &WgpuDevice,
buffer_dest: BufferReferenceId,
@@ -35,13 +33,13 @@ pub fn queue_convert_u8_to_f32(
meta.add_layout1(&input_layout);

let pipeline = meta.get_pipeline(Pipelines::Convert(DType::U8, Functions::ConvertU8ToF32));
let bind_group = create_bind_group_input1( buffer_dest, buffer_input);
let bind_group = create_bind_group_input1(buffer_dest, buffer_input);
enqueue(
meta,
pipeline,
bind_group,
input_layout.shape().elem_count() as u32,
input_layout.shape().elem_count()
input_layout.shape().elem_count(),
);
return Ok(());
}
@@ -57,7 +55,7 @@ pub fn queue_convert_f32_to_u32(

let pipeline = meta.get_pipeline(Pipelines::Convert(DType::F32, Functions::ConvertToU32));

let bind_group = create_bind_group_input1( buffer_dest, buffer_input);
let bind_group = create_bind_group_input1(buffer_dest, buffer_input);
enqueue(
meta,
pipeline,
@@ -68,21 +66,20 @@ pub fn queue_convert_u32_to_u8(
return Ok(());
}


pub fn queue_convert_u32_to_u8(
dev: &WgpuDevice,
buffer_dest: BufferReferenceId,
buffer_input: BufferReferenceId,
start_offset: u32,
size : u32
size: u32,
) -> crate::Result<()> {
let mut meta = get_meta(&dev);
meta.add(start_offset);
meta.add(size);

let pipeline = meta.get_pipeline(Pipelines::Convert(DType::U32, Functions::ConvertU32ToU8));
let bind_group = create_bind_group_input1( buffer_dest, buffer_input);

let bind_group = create_bind_group_input1(buffer_dest, buffer_input);
enqueue(
meta,
pipeline,
@@ -93,20 +90,53 @@ pub fn queue_convert_u32_to_u8(
return Ok(());
}

pub fn queue_convert_u32_to_i64(
dev: &WgpuDevice,
buffer_dest: BufferReferenceId,
buffer_input: BufferReferenceId,
start_offset: u32,
size: u32,
) -> crate::Result<()> {
let mut meta = get_meta(&dev);
meta.add(start_offset);
meta.add(size);

// Get the appropriate pipeline for converting u32 to i64
let pipeline = meta.get_pipeline(Pipelines::Convert(DType::U32, Functions::ConvertU32ToI64));

// Create a bind group for the destination and input buffers
let bind_group = create_bind_group_input1_8(buffer_dest, buffer_input);

// Calculate the number of workgroups needed.
// Assuming that each workgroup processes 4 u32 values, and each u32 is converted to an i64
let num_workgroups = ((size + 3) / 4) as u32;

// Enqueue the compute operation
enqueue(
meta,
pipeline,
bind_group,
num_workgroups, // number of workgroups
size as usize, // total size (in elements)
);

Ok(())
}

pub fn queue_convert_f32_to_u8(
dev: &WgpuDevice,
buffer_dest: BufferReferenceId,
buffer_input: BufferReferenceId,
start_offset: u32,
size : u32
size: u32,
) -> crate::Result<()> {
let mut meta = get_meta(&dev);
meta.add(start_offset);
meta.add(size);

let pipeline = meta.get_pipeline(Pipelines::Convert(DType::F32, Functions::ConvertF32ToU8));

let bind_group = create_bind_group_input1( buffer_dest, buffer_input);
let bind_group = create_bind_group_input1(buffer_dest, buffer_input);
enqueue(
meta,
pipeline,
@@ -115,4 +145,4 @@ pub fn queue_convert_f32_to_u8(
size as usize,
);
return Ok(());
}
}
46 changes: 32 additions & 14 deletions candle-core/src/wgpu_backend/wgpu_functions/mod.rs
@@ -44,7 +44,7 @@ pub use cmp::queue_cmp_buffer_from_buffer;
pub use conv2d::{queue_conv1d, queue_conv1d_transpose, queue_conv2d, queue_conv2d_transpose};
pub use convert::{
queue_convert_f32_to_u32, queue_convert_f32_to_u8, queue_convert_u32_to_f32,
queue_convert_u32_to_u8, queue_convert_u8_to_f32,
queue_convert_u32_to_u8, queue_convert_u8_to_f32, queue_convert_u32_to_i64,
};
pub use copy::{queue_copy, queue_copy2d, queue_copy3d,queue_copy3d_padded, queue_copy_strided};
pub use gather::{queue_gather, queue_index_add_inplace, queue_scatter_add_inplace};
@@ -306,7 +306,7 @@ fn get_command_buffer(
let vd = buffers.get_dest();
match buffers.get_input(){
BindgroupInputBase::Bindgroup0 => {},
BindgroupInputBase::Bindgroup1(v1, _) => {
BindgroupInputBase::Bindgroup1(v1, _, _) => {
if v1 == vd{
panic!("B1: output and input are equal");
}
@@ -445,7 +445,7 @@ fn prepare(dev: &WgpuDevice, queue_buffer: &mut QueueBuffer, cache : &mut ModelC

match input {
BindgroupInputBase::Bindgroup0 => {},
BindgroupInputBase::Bindgroup1(v1, _) => {
BindgroupInputBase::Bindgroup1(v1, _, _) => {
check_buffer(v1);
},
BindgroupInputBase::Bindgroup2(v1,v2, _) => {
@@ -548,7 +548,7 @@ fn set_buffers(dev: &WgpuDevice, command_buffer: &mut QueueBuffer, index : &mut
{
if let Pipelines::Unary(dtype, candle_wgpu_kernels::unary::Functions::UnaryFromBufferContiguous) = &q.pipeline.0{
if q.pipeline.2.input1_inplaceable{
if let BindgroupReferenceInput::Bindgroup1(v1_id, _) = bindgroup_reference.get_input()
if let BindgroupReferenceInput::Bindgroup1(v1_id, _, _) = bindgroup_reference.get_input()
{
if optmize_inplace(bindgroup_reference.get_dest(), v1_id)
{
@@ -578,7 +578,7 @@ fn set_buffers(dev: &WgpuDevice, command_buffer: &mut QueueBuffer, index : &mut
q.pipeline.0 = Pipelines::Binary(dtype.clone(), candle_wgpu_kernels::binary::Functions::BinaryBufferInplace1ContiguousBoth);
q.bindgroup =
DispatchedBindgroup::BindgroupReference(
BindGroupReference::new(v1_id.clone(), BindgroupInputBase::Bindgroup1(v2_id.clone(), false)));
BindGroupReference::new(v1_id.clone(), BindgroupInputBase::Bindgroup1(v2_id.clone(), false,false)));
}
}
else if q.pipeline.2.input2_inplaceable{
Expand All @@ -587,14 +587,14 @@ fn set_buffers(dev: &WgpuDevice, command_buffer: &mut QueueBuffer, index : &mut
q.pipeline.0 = Pipelines::Binary(dtype.clone(), candle_wgpu_kernels::binary::Functions::BinaryBufferInplace2ContiguousBoth);
q.bindgroup =
DispatchedBindgroup::BindgroupReference(
BindGroupReference::new(v2_id.clone(), BindgroupInputBase::Bindgroup1(v1_id.clone(), false)));
BindGroupReference::new(v2_id.clone(), BindgroupInputBase::Bindgroup1(v1_id.clone(), false,false)));
}
}
}
}
else if let Pipelines::Copy(_, candle_wgpu_kernels::copy::Functions::Copy) = &q.pipeline.0{
if q.pipeline.2.input1_inplaceable{
if let BindgroupReferenceInput::Bindgroup1(v1_id, _) = bindgroup_reference.get_input()
if let BindgroupReferenceInput::Bindgroup1(v1_id, _,_) = bindgroup_reference.get_input()
{
let v1 = cache.buffer_reference.get(v1_id);
if let Some(v1) = v1{
@@ -659,7 +659,7 @@ fn set_buffers(dev: &WgpuDevice, command_buffer: &mut QueueBuffer, index : &mut
chec_buffer(dest, cache, command_index);
match input {
BindgroupInputBase::Bindgroup0 => {},
BindgroupInputBase::Bindgroup1(v1, _) => {chec_buffer(v1, cache, command_index);},
BindgroupInputBase::Bindgroup1(v1, _, _) => {chec_buffer(v1, cache, command_index);},
BindgroupInputBase::Bindgroup2(v1, v2, _) => {chec_buffer(v1, cache, command_index);chec_buffer(v2, cache, command_index);},
BindgroupInputBase::Bindgroup3(v1, v2, v3) => {chec_buffer(v1, cache, command_index);chec_buffer(v2, cache, command_index);chec_buffer(v3, cache, command_index);},
}
@@ -675,10 +675,13 @@ fn set_buffers(dev: &WgpuDevice, command_buffer: &mut QueueBuffer, index : &mut
BindgroupReferenceInput::Bindgroup0 => {
&dev.bindgroup_layouts.pipeline_layout0
}
BindgroupReferenceInput::Bindgroup1( _,false) => {
BindgroupReferenceInput::Bindgroup1(_, _, true) => {
&dev.bindgroup_layouts.pipeline_layout1_8
}
BindgroupReferenceInput::Bindgroup1( _,false,_) => {
&dev.bindgroup_layouts.pipeline_layout1
}
BindgroupReferenceInput::Bindgroup1( _, true) => {
BindgroupReferenceInput::Bindgroup1( _, true, _) => {
&dev.bindgroup_layouts.pipeline_layout1_16
}
BindgroupReferenceInput::Bindgroup2( _, _, false) => {
@@ -1032,8 +1035,9 @@ pub fn create_bindgroup(dev: &WgpuDevice, bindgroup: CachedBindgroupFull, cache

let bind_group_layout = match bindgroup.get_input() {
CachedBindgroupInput::Bindgroup0 => &dev.bindgroup_layouts.bind_group_layout0,
CachedBindgroupInput::Bindgroup1(_, false) => &dev.bindgroup_layouts.bind_group_layout1,
CachedBindgroupInput::Bindgroup1(_, true) => &dev.bindgroup_layouts.bind_group_layout1_16,
CachedBindgroupInput::Bindgroup1(_, _, true) => &dev.bindgroup_layouts.bind_group_layout1_8,
CachedBindgroupInput::Bindgroup1(_, false, false) => &dev.bindgroup_layouts.bind_group_layout1,
CachedBindgroupInput::Bindgroup1(_, true, false) => &dev.bindgroup_layouts.bind_group_layout1_16,
CachedBindgroupInput::Bindgroup2(_, _, false) => &dev.bindgroup_layouts.bind_group_layout2,
CachedBindgroupInput::Bindgroup2(_, _, true) => &dev.bindgroup_layouts.bind_group_layout2_16,
CachedBindgroupInput::Bindgroup3(_, _, _) => &dev.bindgroup_layouts.bind_group_layout3,
@@ -1058,7 +1062,7 @@ pub fn create_bindgroup(dev: &WgpuDevice, bindgroup: CachedBindgroupFull, cache
entries: entries,
})
}
CachedBindgroupInput::Bindgroup1(buffer_input1, _) => {
CachedBindgroupInput::Bindgroup1(buffer_input1, _,_) => {

if cache.buffers.get_buffer(buffer_input1).is_none(){
panic!("buffer_input_1 : {:?} could not be found(in {:?})", buffer_input1, bindgroup);
@@ -1145,8 +1149,22 @@ fn create_bind_group_input1(
buffer_dest: BufferReferenceId,
buffer_input1: BufferReferenceId,
) -> BindGroupReference {
BindGroupReference::new(buffer_dest, BindgroupInputBase::Bindgroup1(buffer_input1, false))
BindGroupReference::new(buffer_dest, BindgroupInputBase::Bindgroup1(buffer_input1, false,false))
}

fn create_bind_group_input1_8(
buffer_dest: BufferReferenceId,
buffer_input1: BufferReferenceId,
) -> BindGroupReference {
BindGroupReference::new(buffer_dest, BindgroupInputBase::Bindgroup1(buffer_input1, true, true))
}
fn create_bind_group_input1_16(
buffer_dest: BufferReferenceId,
buffer_input1: BufferReferenceId,
) -> BindGroupReference {
BindGroupReference::new(buffer_dest, BindgroupInputBase::Bindgroup1(buffer_input1, true, false))
}


fn create_bind_group_input2(
buffer_dest: BufferReferenceId,

23 comments on commit 71fdc57

@KimHenrikOtte

I reviewed the code and found a few things:

  • You explicitly requested Vulkan instead of Primary (which I think would automatically pick Vulkan, DirectX, or Metal). Was this on purpose?
  • I think you are using wgpu::PowerPreference::LowPower by default. I would expect nn workloads to use the best GPU available (e.g. a dedicated GPU, which I think may not be selected when asking for low power).
  • I added a bool for is_16 for the alignment of the bindgroup. If we add more alignments, we may want to use an enum instead of bools. (In addition, for normal commands (e.g. binary addition with i64) you need a bindgroup with all inputs and outputs 8-byte aligned, while for convert you need different output and input alignments; a rough sketch of such an enum follows below.)
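For reference, a minimal sketch of what that enum could look like (all names here are made up for illustration, not taken from the branch):

    // Hypothetical replacement for the is_16 / is_8 bools on the bindgroup input:
    // an explicit alignment per buffer, so new layouts don't need yet another bool.
    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
    enum BufferAlignment {
        Align4,  // u32 / f32
        Align8,  // i64 / f64
        Align16, // vec4-style wide loads
    }

    // Stand-in for the real BufferReferenceId.
    type BufferId = u32;

    // The input variant could then carry separate alignments for the destination
    // and the input, which is what the convert shaders need.
    enum BindgroupInput {
        Bindgroup0,
        Bindgroup1 {
            input: BufferId,
            dest_align: BufferAlignment,
            input_align: BufferAlignment,
        },
    }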

@EdupugantiAkhil
Owner Author

Thanks for the feedback.
I definitely agree that we should switch to enums; it would also solve the problem of selecting the input/output size of pipelines and reduce repeated code. The bools were added just to check whether I am even capable of fixing it, as I am new to everything (including Rust).

All the changes in device.rs were only for debugging and getting the project running, as I initially was not able to get my GPU recognized.

I also noticed that the clip demo example with 20 images loaded took around 23x as long as my CPU (4s CPU -> 94s GPU) with the default (your) device.rs plus wgpu::Features::SHADER_INT64 added, with my CPU pinned at 100% and very little GPU utilization. That could also be caused by me using an Intel(R) HD Graphics 4400 Haswell GPU that doesn't have complete Vulkan support, but that feels less likely.

What do you think could be the cause of this, and could parallelism or wgpu_core::device::resource: Device::maintain: waiting for submission index x be related?

@KimHenrikOtte commented on 71fdc57 Sep 1, 2024

I just tested your fork on my PC:
cpu: 550-600ms
gpu: about the same (mostly 550-600ms as well, but also a few runs around 700ms)

I dumped the time spent on the GPU with the snippet below (one also needs to enable the timestamp query feature and build with the wgpu_debug feature: --features="wgpu_debug"):

match &device {
    candle::Device::WebGpu(gpu) => {
        gpu.print_bindgroup_reuseinfo2();
        #[cfg(feature = "wgpu_debug")]
        {
            let info = pollster::block_on(gpu.get_debug_info()).unwrap();
            let map2 = candle::wgpu::debug_info::calulate_measurment(&info);
            candle::wgpu::debug_info::save_list(&map2, &format!("wgpu_clip_test_1_b.json")).unwrap();

            let info: Vec<candle::wgpu::debug_info::ShaderInfo> = gpu.get_pipeline_info().unwrap();
            candle::wgpu::debug_info::save_list(&info, &format!("wgpu_clip_test_1_c.json")).unwrap();
        }
    },
    _ => {},
};

and analysed the resulting json with "/auswertung2.ipynb"

I get the following result:

 Total sum for Duration: 0.04121702399999999
Operations sorted by total duration for Duration:
Operation: Pipeline: Matmul64x648x8(F32, Matmul), Batched: 2*(50x768 * 768x768), Duration: 0.008736767999999999, Count: 48 Perc:21.20%
Operation: Pipeline: Matmul64x648x8(F32, Matmul), Batched: 2*(50x3072 * 3072x768), Duration: 0.007944192, Count: 12 Perc:19.27%
Operation: Pipeline: Binary(F32, BinaryBufferFromBuffer), OP: Add, Duration: 0.002603008, Count: 209 Perc:6.32%
Operation: Pipeline: Matmul64x648x8(F32, Matmul), Batched: 2*(50x768 * 768x3072), Duration: 0.0025763839999999997, Count: 12 Perc:6.25%
Operation: Pipeline: Matmul(F32, Matmul7), Batched: 3*(7x512 * 512x512), Duration: 0.002469888, Count: 48 Perc:5.99%
Operation: Pipeline: Matmul(F32, Matmul7), Batched: 3*(7x2048 * 2048x512), Duration: 0.0020602879999999995, Count: 12 Perc:5.00%
Operation: Pipeline: Unary(F32, UnaryInplaceContiguous), OP: Affine, Duration: 0.00188416, Count: 177 Perc:4.57%
Operation: Pipeline: Reduce(F32, Reduce), , Duration: 0.001780736, Count: 154 Perc:4.32%
Operation: Pipeline: Matmul(F32, Matmul7), Batched: 3*(7x512 * 512x2048), Duration: 0.000992256, Count: 12 Perc:2.41%
...

So in my case the GPU is only working for about 41ms; the rest seems to be initialization and communication overhead.
It is quite hard to guess why it is so slow for you.
I guess there might be a shader that is badly compiled or too "big" for your GPU. (I have seen the 64x64-tiled matmul with 8x8 work per thread be much slower (~60 times) in the browser than native, while other shaders (e.g. matmul 64x64 with 4x4 work per thread) have the same performance.)

@KimHenrikOtte

I just analyzed why the overall runtime was about 10 times slower than the actual GPU computation time on my machine: half of the time was spent loading the model into GPU RAM.
If I load the clip model first and then run the same forward pass 2 times in a row, the second call (when everything is fully initialized) is on the order of the GPU computation time:
CPU:

Results for image: examples/stable-diffusion/assets/stable-diffusion-xl.jpg
...
Duration: 278.2339ms
Results for image: examples/stable-diffusion/assets/stable-diffusion-xl.jpg
...
Duration: 273.4635ms

WGPU:

...
Results for image: examples/stable-diffusion/assets/stable-diffusion-xl.jpg
...
Duration: 289.557ms

Results for image: examples/stable-diffusion/assets/stable-diffusion-xl.jpg
...

Duration: 61.3901ms
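In other words: build the model, run one warm-up forward pass, then time a second pass once all buffers and pipelines exist. A rough sketch of that measurement (the forward closure is only a placeholder for the clip example's model call, not code from the branch):

    use std::time::Instant;

    // Hypothetical helper: `forward` stands in for whatever does the work,
    // e.g. a closure calling model.forward(...) on an already-built input.
    fn time_cold_and_warm<T>(mut forward: impl FnMut() -> T) {
        let cold = Instant::now();
        let _ = forward(); // first call: includes weight upload / pipeline creation
        println!("cold pass: {:?}", cold.elapsed());

        let warm = Instant::now();
        let _ = forward(); // second call: closer to the pure GPU compute time
        println!("warm pass: {:?}", warm.elapsed());
    }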

@EdupugantiAkhil
Owner Author

Thanks @KimHenrikOtte, I ran the clip model again and observed that the GPU version is running a lot faster -> 220 images in 10-11s compared to 44s on CPU.
For the GPU version I noticed that both model.forward and model.get_image_features took around 10 seconds when run serially, irrespective of input size, with 100% CPU usage and near 0% GPU usage (the GPU core running at 200MHz), i.e. every model.forward call takes 10 more seconds.
Could it be because I am running a slow CPU, or is it loading the model to memory every time I run model.xyz?

I will take a look at the execution time with the procedure you mentioned above and come back to you.

Also, I am not sure what happened yesterday that made the default clip/main.rs take 94s, but I think I might have put model.forward in a loop.

I would also like to confirm whether running model.forward once keeps the model in GPU memory, and whether there is a better/recommended way to run models for batch processing data.

@EdupugantiAkhil
Owner Author

Also, do you have a development roadmap in mind that we could follow, so that we stay on the same page while implementing things? I will be able to work on this on Saturdays and Sundays.

And do you want to concentrate on performance optimization first, or get all the features implemented and then start optimizing for performance?

I am thinking that if we build out all the features, we will have a bigger picture of the problems that could force us to refactor, but I am a bit too early into this project to know whether that will even be a problem.

@EdupugantiAkhil
Owner Author

@KimHenrikOtte I did a little bit of profiling and the duration percentages looked pretty much identical to yours, +/- 5%. The CPU profiling led me to find out that libvulkan_lvp.so (Vulkan on CPU) was the cause of the CPU usage and the lack of GPU usage, even though it was still faster than the raw CPU 🧙. It looks like something I need to investigate (I am speculating that the duplicate images being used could be the cause of it).

@KimHenrikOtte commented on 71fdc57 Sep 2, 2024

There are a few things that I think are not possible, or too complicated/slow, to implement:
e.g. quantized matrices, f16, and proper u8 support.

While reading u8 input values can be achieved by loading a u32 and using bit shifting (see the sketch below), writing a single u8 to a buffer is not trivial, as there would be race conditions between threads.
For best performance, all shaders with u8 would need a custom impl that processes 4 values at once.
There are proposals to add f16 support to WGSL, but I think these are only for calculations inside the shader; loading from and writing to memory still needs 4-byte alignment.
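As a small illustration of that bit shifting (shown in Rust rather than WGSL; a shader would do the same shifts after loading the containing u32):

    // Read one u8 out of a packed u32 buffer, little-endian.
    fn read_u8(words: &[u32], index: usize) -> u8 {
        let word = words[index / 4];
        ((word >> ((index % 4) * 8)) & 0xFF) as u8
    }

    // Writing is the hard part mentioned above: updating a single byte means a
    // read-modify-write of the whole u32, which races with whichever threads own
    // the other three bytes unless each thread writes all 4 bytes at once.
    fn write_u8(words: &mut [u32], index: usize, value: u8) {
        let shift = (index % 4) * 8;
        let word = &mut words[index / 4];
        *word = (*word & !(0xFFu32 << shift)) | ((value as u32) << shift);
    }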

Things currently missing from the wgpu fork:
- I64 and F64
- Implement all nn ops (and argsort)
- Add the ability for custom shaders
- Add documentation
- Fix the wasm examples (the index.html files are only copies from other examples; the settings panel isn't working)

@KimHenrikOtte

I just committed a first version with i64 and f64 support (KimHenrikOtte@2d8518f) with some of your changes merged in.
But the commit still has the following problem:
the current implementation expects the same alignment for all input and output buffers, but some shaders require different types on different inputs (a sketch of per-buffer alignment metadata follows below):
- index_select may use 4-byte aligned index buffers and 8-byte aligned input and output buffers,
- where_cond can use a different alignment for the condition buffer than for the input/output buffers,
- matmul loads its inputs 16-byte aligned, but currently writes dest_buffer with 4-byte alignment.
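Purely as a sketch of that direction (none of these types exist in the branch), the alignment could be tracked per buffer slot instead of once per bindgroup:

    // Hypothetical per-slot alignment metadata for a bindgroup, so e.g. index_select
    // can combine a 4-byte aligned index buffer with 8-byte aligned data buffers.
    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
    enum Align {
        Bytes4,
        Bytes8,
        Bytes16,
    }

    struct BindgroupAlignments {
        dest: Align,
        inputs: Vec<Align>, // one entry per input buffer, in bind order
    }

    impl BindgroupAlignments {
        // Example: index_select on i64 data with u32 indices.
        fn index_select_i64() -> Self {
            Self {
                dest: Align::Bytes8,
                inputs: vec![Align::Bytes8, Align::Bytes4],
            }
        }
    }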

@EdupugantiAkhil
Owner Author

That's awesome @KimHenrikOtte, I will take a look at it now.

About the missing stuff: I could pick up documentation and fixing the wasm examples as a simple first step until I get the hang of the code base.

Getting the documentation done through me could require a little more effort on your part, but it could help clear up any misunderstandings I have about the code base.

Do you have anything in mind that has to be documented?
Do you also want to list out which of the default example models work through WebGPU?

Can you explain your intentions for "Add ability for custom shader"? Is this supported by Candle, or are any other accelerated runtimes supporting this?

@KimHenrikOtte commented on 71fdc57 Sep 3, 2024

I mean the ability for anyone using Candle to create and use their own shaders.
This was suggested in the WebGPU candle issues:

For any new backend, it is very important to create a way for USERS to create their own kernel/op. (huggingface#344)

But I do not know how easily this could be integrated. You would need a way to specify your own shaders (e.g. as a string).
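Just to make the idea concrete, a string-based entry point might look roughly like this (every name here is hypothetical; nothing like it exists in the fork yet):

    // Hypothetical user-facing handle for a custom WGSL op. The backend would
    // compile wgsl_source into a pipeline, bind the output buffer followed by
    // num_inputs input buffers, and dispatch workgroups(elem_count) groups.
    struct CustomWgslOp {
        wgsl_source: String,
        num_inputs: usize,
        workgroups: fn(elem_count: usize) -> u32,
    }

    impl CustomWgslOp {
        fn new(
            wgsl_source: impl Into<String>,
            num_inputs: usize,
            workgroups: fn(usize) -> u32,
        ) -> Self {
            Self { wgsl_source: wgsl_source.into(), num_inputs, workgroups }
        }
    }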

In general, the code uses a lot of custom structs and enums for various operations; it could be better documented what they do.

@EdupugantiAkhil
Owner Author

Thanks @KimHenrikOtte, I will think about this as well.

@KimHenrikOtte commented on 71fdc57 Sep 10, 2024

@EdupugantiAkhil Yesterday I committed a change to optimise the cache system for growing buffers (the metavoice example showed a huge performance drop because many buffers and bindgroups were created and deleted before).

I got the following results (CPU: AMD Ryzen 7 5800X 8-Core Processor, GPU: NVIDIA GeForce GTX 1080 Ti):

llama2-c:
cpu: 331.49 token/s
gpu: 291.99 token/s,

clip:
cpu elapsed: 250.5971ms
gpu elapsed: 293.1473ms

stable-diffusion:
cpu: 369.8279709s
gpu: 33.4540036s

t5:
cpu: 32 tokens generated (67.43 token/s)
gpu: 32 tokens generated (116.59 token/s)

metavoice:
cpu: 566.7995057s
gpu: 37.9660216s

wuerstchen:
gpu: 148.8636398s
cpu: 986.7038538s

@EdupugantiAkhil
Owner Author

That is super nice @KimHenrikOtte.
Last week I spent time getting my GPU situation fixed and was able to get this running on a Vega 7 (AMD 4650G) and an Adreno 725 (Snapdragon 7+ Gen 2) on the Vulkan runtime.

I also have a modified version of matmul5 that saved about 10s (from 70s -> 58s) off my Vulkan-on-CPU runtime for Whisper, but I have not tested it on a real GPU or double-checked the results yet (will do that next week, along with the other stuff):

@compute @workgroup_size(RTRS, TS1, 1)
fn matmul5(@builtin(workgroup_id) group_id: vec3<u32>, @builtin(local_invocation_id) local_id: vec3<u32>) {
    let lx = local_id.x;
    let ly = local_id.y;

    let gx = TS1 * group_id.x + lx;
    let gy = TS1 * group_id.y + ly;
    let batch = group_id.z;

    let output_size_of_one_batch = select(op_matmul_use_batch, op_matmul_m * op_matmul_n, 0u);

    let input1_offset = op_matmul_input1_offset;
    let input2_offset = op_matmul_input2_offset;

    let input1_stride_b = select(op_matmul_use_batch, op_matmul_input1_stride_b, 0u);
    let input2_stride_b = select(op_matmul_use_batch, op_matmul_input2_stride_b, 0u);

    let m_input1_offset = input1_offset + op_matmul_input1_stride_m * gy + batch * input1_stride_b;
    let m_input2_offset = input2_offset + op_matmul_input2_stride_n * gx + batch * input2_stride_b;

    let max_k = op_matmul_k;

    // Shared memory for the tile (rotated)
    var Asub : array<array<DTYPE, TS1>, TS1>;
    var Bsub : array<array<DTYPE, TS1>, TS1>;

    var acc = array<DTYPE, WPT1>();

    for (var t = 0u; t < max_k; t += TS1) {
        // Rotate and load A and B into shared memory with coalesced accesses
        Asub[ly][lx] = v_input1[m_input1_offset + ly * op_matmul_input1_stride_k + t + lx];
        Bsub[lx][ly] = v_input2[m_input2_offset + lx * op_matmul_input2_stride_k + t + ly];

        workgroupBarrier();

        // Perform the matrix multiplication on the shared memory tiles
        for (var k = 0u; k < TS1; k++) {
            let a = Asub[ly][k];
            for (var w = 0u; w < WPT1; w++) {
                acc[w] += a * Bsub[k][lx + w];
            }
        }
        workgroupBarrier();
    }

Total sum for Duration: 57.10677721699996
Operations sorted by total duration for Duration:
Operation: Pipeline: Matmul32x32(F32, Matmul), (1500x384 * 384x384), Duration: 9.545818492, Count: 24 Perc:16.72%
Operation: Pipeline: Matmul32x32(F32, Matmul), Batched: 6*(1500x64 * 64x1500), Duration: 7.764473577, Count: 4 Perc:13.60%
Operation: Pipeline: Matmul32x32(F32, Matmul), (1500x384 * 384x1536), Duration: 6.312070591, Count: 4 Perc:11.05%
Operation: Pipeline: Matmul32x32(F32, Matmul), (1500x1536 * 1536x384), Duration: 6.057844635, Count: 4 Perc:10.61%
Operation: Pipeline: Matmul32x32(F32, Matmul), Batched: 6*(1500x1500 * 1500x64), Duration: 6.005425084, Count: 4 Perc:10.52%
Operation: Pipeline: Conv1d(F32, Conv1d), , Duration: 4.62934659, Count: 2 Perc:8.11%
Operation: Pipeline: Unary(F32, UnaryFromBuffer), OP: Affine, Duration: 2.879218753, Count: 232 Perc:5.04%
Operation: Pipeline: Matmul(F32, Matmul116), (1x384 * 384x51864), Duration: 1.1686291449999997, Count: 14 Perc:2.05%
Operation: Pipeline: Copy(F32, Copy3d), , Duration: 1.026695981, Count: 524 Perc:1.80%
Operation: Pipeline: Binary(F32, BinaryBufferFromBuffer), OP: Add, Duration: 0.669787452, Count: 666 Perc:1.17%
Operation: Pipeline: Unary(F32, UnaryInplaceContiguous), OP: Gelu, Duration: 0.660622716, Count: 62 Perc:1.16%
Operation: Pipeline: Softmax(F32, Softmax), 1500x9000(0), Duration: 0.647026526, Count: 4 Perc:1.13%
Operation: Pipeline: Copy(F32, Copy3dPadded), , Duration: 0.643478292, Count: 72 Perc:1.13%
Operation: Pipeline: Unary(F32, UnaryFromBufferContiguous), OP: Square, Duration: 0.360263681, Count: 191 Perc:0.63%
Operation: Pipeline: Matmul(F32, Matmul7), (3x384 * 384x51864), Duration: 0.273951141, Count: 1 Perc:0.48%
Operation: Pipeline: Copy(F32, Copy3dPaddedNobatch), , Duration: 0.248316817, Count: 256 Perc:0.43%
Operation: Pipeline: Matmul32x32(F32, Matmul), (4x384 * 384x384), Duration: 0.211906774, Count: 24 Perc:0.37%
Operation: Pipeline: Matmul32x32(F32, Matmul), (12x384 * 384x384), Duration: 0.207058092, Count: 24 Perc:0.36%
Operation: Pipeline: Matmul32x32(F32, Matmul), (16x384 * 384x384), Duration: 0.20681412300000002, Count: 24 Perc:0.36%
Operation: Pipeline: Matmul32x32(F32, Matmul), (10x384 * 384x384), Duration: 0.206364767, Count: 24 Perc:0.36%
Operation: Pipeline: Matmul32x32(F32, Matmul), (6x384 * 384x384), Duration: 0.20555873, Count: 24 Perc:0.36%
Operation: Pipeline: Matmul32x32(F32, Matmul), (14x384 * 384x384), Duration: 0.20517918500000001, Count: 24 Perc:0.36%

Total sum for Duration: 70.45364361499996
Operations sorted by total duration for Duration:
Operation: Pipeline: Matmul32x32(F32, Matmul), (1500x384 * 384x384), Duration: 9.507843548, Count: 24 Perc:13.50%
Operation: Pipeline: Matmul32x32(F32, Matmul), Batched: 6*(1500x64 * 64x1500), Duration: 7.818339483, Count: 4 Perc:11.10%
Operation: Pipeline: Matmul32x32(F32, Matmul), (1500x384 * 384x1536), Duration: 6.366360102, Count: 4 Perc:9.04%
Operation: Pipeline: Matmul32x32(F32, Matmul), (1500x1536 * 1536x384), Duration: 6.118251775, Count: 4 Perc:8.68%
Operation: Pipeline: Matmul32x32(F32, Matmul), Batched: 6*(1500x1500 * 1500x64), Duration: 6.014527578, Count: 4 Perc:8.54%
Operation: Pipeline: Unary(F32, UnaryFromBuffer), OP: Affine, Duration: 4.981581751999999, Count: 408 Perc:7.07%
Operation: Pipeline: Conv1d(F32, Conv1d), , Duration: 4.538942967, Count: 2 Perc:6.44%
Operation: Pipeline: Matmul(F32, Matmul116), (1x384 * 384x51864), Duration: 2.09869979, Count: 25 Perc:2.98%
Operation: Pipeline: Copy(F32, Copy3d), , Duration: 1.376092632, Count: 1064 Perc:1.95%
Operation: Pipeline: Copy(F32, Copy3dPadded), , Duration: 0.8097130600000001, Count: 160 Perc:1.15%
Operation: Pipeline: Binary(F32, BinaryBufferFromBuffer), OP: Add, Duration: 0.80579728, Count: 1161 Perc:1.14%
Operation: Pipeline: Unary(F32, UnaryInplaceContiguous), OP: Gelu, Duration: 0.753349769, Count: 106 Perc:1.07%
Operation: Pipeline: Softmax(F32, Softmax), 1500x9000(0), Duration: 0.73619247, Count: 4 Perc:1.04%
Operation: Pipeline: Unary(F32, UnaryFromBufferContiguous), OP: Square, Duration: 0.45023639000000004, Count: 334 Perc:0.64%
Operation: Pipeline: Copy(F32, Copy3dPaddedNobatch), , Duration: 0.322404952, Count: 576 Perc:0.46%
Operation: Pipeline: Matmul(F32, Matmul7), (3x384 * 384x51864), Duration: 0.27491778, Count: 1 Perc:0.39%
Operation: Pipeline: Matmul32x32(F32, Matmul), (25x384 * 384x384), Duration: 0.252382766, Count: 24 Perc:0.36%
Operation: Pipeline: Matmul32x32(F32, Matmul), (12x384 * 384x384), Duration: 0.22063252100000003, Count: 24 Perc:0.31%
Operation: Pipeline: Matmul32x32(F32, Matmul), (20x384 * 384x384), Duration: 0.21812464899999998, Count: 24 Perc:0.31%
Operation: Pipeline: Matmul32x32(F32, Matmul), (26x384 * 384x384), Duration: 0.21541983000000003, Count: 24 Perc:0.31%

@KimHenrikOtte commented on 71fdc57 Sep 10, 2024

Ok, just for documentation, so we are on the same page:

matmul.pwgsl contains mostly early versions of the matmul algorithm:

  • matmul1: naive impl (may be the fastest for small matrix multiplications)

  • matmul1_16: naive impl with 4 simultaneous loads (only works when k_stride is 1)

  • matmul1_end: artifact of a previous test; the idea was to use a tiled shader for the bulk of the output first, with this version then computing the edge of the matrix multiplication, but other shaders (e.g. matmul7) seem to be faster in this case

  • matmul5: based on https://cnugteren.github.io/tutorial/pages/page7.html, with a WorkPerThread of 8 and 16x16 tiles (but still no rectangular tiles or wider data loads)

  • matmul7: a modified version of matmul5 that supports edge cases (e.g. the input matrices need not be divisible by 16). Performance measurements show that in some cases it is faster to use matmul7 than to pad the input matrices and use a matmul5-based shader.

Under "sgemm/matmul..." there are more sophisticated shader variants:
They all use the same MatmulHelper.pwgsl base functions with different sets of defines.
They use wider data loads (width=4), rectangular tiles, and WorkPerThread in both dimensions (e.g. each thread can compute 4x4 values), and theoretically prefetching, but that did not show any performance gain.

  • matmul1x64b.pwgsl uses a special shader without shared memory, where the k-sum is split across several threads and the partial results are summed afterwards. (This may be faster if the input k-stride is one, so that aligned threads can load coalesced memory.)
Also, your matmul5 impl seems to be wrong:

     // shared memory for tile (rotated)
     var Asub : array<array<DTYPE, TS1>, TS1>;

Asub here is a function-local variable, not shared memory in the workgroup (in WGSL, workgroup-shared arrays have to be declared at module scope as var<workgroup>).
Also, I think whether memory can be coalesced depends on the input stride.

In addition, I think you are not loading enough values:

     Asub[ly][lx] = v_input1[m_input1_offset + ly * op_matmul_input1_stride_k + t + lx];

For example, if you have 16x16 tiles with work per thread x=8 and y=1, then you have 2x16 = 32 threads responsible for the 16x16 tile.
To load the 16x16 = 256 values of the tile, each thread must therefore load 8 values here, not just one.

Perhaps it makes more sense to use an sgemm/matmul16x16 version instead of matmul5 if the input is not divisible by 32 but is divisible by 16 (if that is faster).

@EdupugantiAkhil
Owner Author

Thanks for the feedback @KimHenrikOtte, I hope I didn't scare you too much with my code 😅
Also, I would like to ask what resources I should look at to get my fundamentals up, as I lack the knowledge even to give you a proper reply.
Maybe going through https://sotrh.github.io/learn-wgpu/, https://cnugteren.github.io/tutorial/, or other material?

Can you give me recommendations from the perspective of a beginner, i.e. "if I were to learn it again, how would I do it"?

I also want to let you know that I really appreciate your effort. Thanks again.

@KimHenrikOtte

No problem,

I think the learn-wgpu tutorial is mostly aimed at rendering and not compute shaders.
Traditionally for rendering you would create vertex buffers, then call vertex and fragment shaders to transform vertices and compute pixel data.
Apart from this "rendering" pipeline, there is also a "compute" pipeline to handle arbitrary data (the WGPU fork only uses this architecture).

The "Tutorial: OpenCL SGEMM tuning for Kepler" is a great tutorial for understanding the principles of optimizing Matmul for GPUs (although it uses OpenCL, the principles are 100% the same for WGSL).
For visualizations and understanding how compute shaders divide work into multiple threads, I liked this post: "https://webgpufundamentals.org/webgpu/lessons/webgpu-compute-shaders.html"

For a tutorial on how the WGSL syntax works, "https://google.github.io/tour-of-wgsl/" has short descriptions and examples of most of the operations.

@KimHenrikOtte commented on 71fdc57 Sep 11, 2024

@EdupugantiAkhil I had also looked at this article: "https://siboehm.com/articles/22/CUDA-MMM" (it optimizes matmul in CUDA and has tons of nice visualizations)

@EdupugantiAkhil
Owner Author

Thanks @KimHenrikOtte, I will go through them; also feel free to share any resources you find interesting.

@KimHenrikOtte commented on 71fdc57 Sep 16, 2024

@EdupugantiAkhil I just committed some improvements for the matrix multiplication:
I optimized the shader parameters like WPTN, WPTM, etc. for better performance (I just tested a few different parameter sets with the matmul benchmark and looked for the best ones).

The best selection I found was the following, with a throughput of 3161.4 GiB/s (a little bit better than the best 64x64 (8x8 work per thread) variant I found so far, which only got about 2500 GiB/s):

2048x2048 * 2048x2048:
#define TSM 32u // The tile size in dimension M
#define TSN 64u // The tile size in dimension N
#define TSK 4u // The tile size in dimension K
#define WPTM 32u // The work per thread in M units
#define WPTN 2u // The work per thread in dimension N
#define WIDTHA 4u
#define WIDTHB 2u

Unfortunately, when matrix B is transposed (matrix B k-stride = 1), the performance is much worse (only 336.65 GiB/s).
I don't really understand why, but I found that with TSK = 8 the shader is a little slower overall, yet when B is transposed I again measured around 3000 GiB/s.
So the current impl will use 32x64b (for a transposed B matrix) and 32x64 otherwise.

llama2-c:
gpu: 292.49 token/s

clip:
gpu elapsed: 387.5123ms

stable-diffusion:
gpu 23.9633768s

t5:
gpu: 32 tokens generated (124.99 token/s)

metavoice:
gpu: 30.6853641s

wuerstchen:
gpu: 92.0862131s

@EdupugantiAkhil
Owner Author

Nice work @KimHenrikOtte, I just skimmed through the diffs, and I will check how the code impacts performance on the hardware I have.
It might take some time as I have some personal work that is going to keep me busy till the end of this month. I hope it didn't/doesn't cause too many problems for you. Sorry about the short notice.

https://youtu.be/QQceTDjA4f4?si=3oXrPYOKU2h9wSdK might help (they talk about memory starting at the 7-minute mark, but it's worth a watch if you haven't already seen it, or you might already know about this).

I was also able to go through the resources you gave me. I still feel a little foggy about the WebGPU stuff, but it will get clearer the more I use it.

@EdupugantiAkhil
Owner Author

Hello @KimHenrikOtte,

I apologize for the abrupt pause in communication. I've now wrapped up my personal commitments and am ready to resume my contributions to the project. Do you have anything in mind that I could help you with?

I can go through the new changes you have made, try out some of your WASM examples, and maybe try my luck at fixing the MutexGuard-held-across-an-await-point issues.

If possible, I’d prefer to also contact you through other means, such as Discord or email. You can reach out to me at akhile76@gmail.com.

@KimHenrikOtte commented on 71fdc57 Oct 31, 2024

No problem,
I don't have much time to work on the branch at the moment.

But I think the current branch "KimHenrikOtte/candle at wgpu_cleanup" is at least in a state where it could be merged with candle
(at least as a "beta" feature, so people can try it out).

Still, there are a few things one might want to improve/implement:

  • candle-core/asort impl (I tried it in the beginning with ChatGPT, but I was confused about the implementation; I think this should be doable, e.g. see candle-metal-kernels/src/sort.metal)

    • re-enable the asort test
  • candle-nn/(ropei, rope, rope_thd) impl

    • re-enable the test
  • clippy warnings

  • Flash attention?

    • I do not know if this is possible,
      but I think you would at least need a different bindgroup layout (e.g. more than 3 inputs),
      which would probably result in more than one bindgroup per operation.
  • Performance:

    • You could probably optimise the compute graph further (all calculations are collected and sent to the GPU together).
      Currently there are 3 inplace optimisations (to reuse the buffer for unary, binary or copy operations if this is the last time the buffer is used).
    • There may be optimisations for combinations of copy and binary/unary with the var field (a rough sketch of this copy-elision idea follows after this list):
      new_value = a + b;
      copy new_value to current_value;
      a + b could write directly to current_value instead of new_value.
    • You could rearrange the pipelines to improve buffer usage and reuse.
    • Buffer reuse can be improved (currently buffers are allocated from start to end).
      • In the worst case, each buffer is slightly larger than the previous one, resulting in zero buffer reuse.
      • It may be better to allocate the largest buffer needed first and go from there, but doing this in a performant way is not trivial.
    • Compare matmul algorithms at runtime (at device creation, we could test a few matmul algorithms to find the best one for the hardware).
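As referenced in the list above, a rough sketch of the copy-elision idea, using placeholder types only (the fork's real queue and buffer structs look different):

    type BufferId = usize;

    enum QueuedOp {
        Binary { dest: BufferId, lhs: BufferId, rhs: BufferId },
        Copy { dest: BufferId, src: BufferId },
    }

    /// If op i writes `tmp` and op i+1 copies `tmp` into `var`, and `tmp` is never
    /// read again afterwards, rewrite op i to write `var` directly and drop the copy.
    fn elide_copies(ops: &mut Vec<QueuedOp>) {
        let mut i = 0;
        while i + 1 < ops.len() {
            let mut fused_dest = None;
            if let (QueuedOp::Binary { dest, .. }, QueuedOp::Copy { src, dest: copy_dest }) =
                (&ops[i], &ops[i + 1])
            {
                if src == dest && !read_later(ops, i + 2, *dest) {
                    fused_dest = Some(*copy_dest);
                }
            }
            if let Some(new_dest) = fused_dest {
                if let QueuedOp::Binary { dest, .. } = &mut ops[i] {
                    *dest = new_dest;
                }
                ops.remove(i + 1); // the copy is no longer needed
            } else {
                i += 1;
            }
        }
    }

    fn read_later(ops: &[QueuedOp], from: usize, buf: BufferId) -> bool {
        ops[from..].iter().any(|op| match op {
            QueuedOp::Binary { lhs, rhs, .. } => *lhs == buf || *rhs == buf,
            QueuedOp::Copy { src, .. } => *src == buf,
        })
    }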
