diff --git a/candle-core/src/wgpu_backend/cache.rs b/candle-core/src/wgpu_backend/cache.rs index 2a5318fd38..25e09dbe52 100644 --- a/candle-core/src/wgpu_backend/cache.rs +++ b/candle-core/src/wgpu_backend/cache.rs @@ -1,22 +1,28 @@ -use std::{collections::{BTreeSet, HashMap}, num::NonZeroU64, u32}; +use std::{ + collections::{BTreeSet, HashMap}, + num::NonZeroU64, + u32, +}; use wgpu::BindGroupLayoutDescriptor; use crate::wgpu_backend::util::StorageTrait; -use super::{util::{Reference, ReferenceTrait, Storage, StorageOptional, ToU64}, WgpuDevice}; use super::{ device::PipelineType, util::{FixedSizeQueue, HashMapMulti}, wgpu_functions, }; +use super::{ + util::{Reference, ReferenceTrait, Storage, StorageOptional, ToU64}, + WgpuDevice, +}; //time = 0 is undefined // pub type BufferReferenceId = Reference; // pub type CachedBufferId = Reference; // pub type CachedBindgroupId = Reference; - #[derive(Debug, PartialEq, Eq, Hash, Clone, std::marker::Copy)] pub struct BufferReferenceId(Reference); #[derive(Debug, PartialEq, Eq, Hash, Clone, std::marker::Copy)] @@ -24,9 +30,8 @@ pub struct CachedBufferId(Reference); #[derive(Debug, PartialEq, Eq, Hash, Clone, std::marker::Copy)] pub struct CachedBindgroupId(Reference); - -impl ReferenceTrait for BufferReferenceId{ - fn new(id : u32, time : u32) -> Self { +impl ReferenceTrait for BufferReferenceId { + fn new(id: u32, time: u32) -> Self { Self(Reference::new(id, time)) } @@ -39,8 +44,8 @@ impl ReferenceTrait for BufferReferenceId{ } } -impl ReferenceTrait for CachedBufferId{ - fn new(id : u32, time : u32) -> Self { +impl ReferenceTrait for CachedBufferId { + fn new(id: u32, time: u32) -> Self { Self(Reference::new(id, time)) } @@ -53,8 +58,8 @@ impl ReferenceTrait for CachedBufferId{ } } -impl ReferenceTrait for CachedBindgroupId{ - fn new(id : u32, time : u32) -> Self { +impl ReferenceTrait for CachedBindgroupId { + fn new(id: u32, time: u32) -> Self { Self(Reference::new(id, time)) } @@ -67,18 +72,18 @@ impl ReferenceTrait for CachedBindgroupId{ } } - - #[derive(Debug)] pub(crate) struct BindgroupLayouts { pub bind_group_layout0: wgpu::BindGroupLayout, pub bind_group_layout1: wgpu::BindGroupLayout, + pub bind_group_layout1_8: wgpu::BindGroupLayout, pub bind_group_layout1_16: wgpu::BindGroupLayout, pub bind_group_layout2: wgpu::BindGroupLayout, pub bind_group_layout2_16: wgpu::BindGroupLayout, //for matmul, input buffer may be vec4 pub bind_group_layout3: wgpu::BindGroupLayout, pub pipeline_layout0: wgpu::PipelineLayout, pub pipeline_layout1: wgpu::PipelineLayout, + pub pipeline_layout1_8: wgpu::PipelineLayout, pub pipeline_layout1_16: wgpu::PipelineLayout, pub pipeline_layout2: wgpu::PipelineLayout, pub pipeline_layout2_16: wgpu::PipelineLayout, //for matmul, input buffer may be vec4 @@ -98,6 +103,17 @@ impl BindgroupLayouts { count: None, }; + let dest_entry_8 = wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: Some(NonZeroU64::new(8).unwrap()), + }, + count: None, + }; + let dest_entry_16 = wgpu::BindGroupLayoutEntry { binding: 0, visibility: wgpu::ShaderStages::COMPUTE, @@ -130,7 +146,18 @@ impl BindgroupLayouts { }, count: None, }; - + + let input1_entry_8 = wgpu::BindGroupLayoutEntry { + binding: 2, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + 
min_binding_size: Some(NonZeroU64::new(8).unwrap()), + }, + count: None, + }; + let input1_entry_16 = wgpu::BindGroupLayoutEntry { binding: 2, visibility: wgpu::ShaderStages::COMPUTE, @@ -164,8 +191,6 @@ impl BindgroupLayouts { count: None, }; - - let input3_entry = wgpu::BindGroupLayoutEntry { binding: 4, visibility: wgpu::ShaderStages::COMPUTE, @@ -177,9 +202,6 @@ impl BindgroupLayouts { count: None, }; - - - let bind_group_layout0 = dev.create_bind_group_layout(&BindGroupLayoutDescriptor { label: None, entries: &[dest_entry, meta_entry], @@ -188,6 +210,10 @@ impl BindgroupLayouts { label: None, entries: &[dest_entry, meta_entry, input1_entry], }); + let bind_group_layout1_8 = dev.create_bind_group_layout(&BindGroupLayoutDescriptor { + label: None, + entries: &[dest_entry_8, meta_entry, input1_entry], + }); let bind_group_layout1_16 = dev.create_bind_group_layout(&BindGroupLayoutDescriptor { label: None, entries: &[dest_entry_16, meta_entry, input1_entry_16], @@ -221,6 +247,11 @@ impl BindgroupLayouts { bind_group_layouts: &[&bind_group_layout1], push_constant_ranges: &[], }); + let pipeline_layout1_8 = dev.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: &[&bind_group_layout1_8], + push_constant_ranges: &[], + }); let pipeline_layout1_16 = dev.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { label: None, bind_group_layouts: &[&bind_group_layout1_16], @@ -245,12 +276,14 @@ impl BindgroupLayouts { Self { bind_group_layout0, bind_group_layout1, + bind_group_layout1_8, bind_group_layout1_16, bind_group_layout2, bind_group_layout2_16, bind_group_layout3, pipeline_layout0, pipeline_layout1, + pipeline_layout1_8, pipeline_layout1_16, pipeline_layout2, pipeline_layout2_16, @@ -259,105 +292,101 @@ impl BindgroupLayouts { } } - - - - - - - - - - - ////////////////// BUFFER REFERENCE: - - /// Virtual Buffer, used in Compute Graph #[derive(Debug)] pub struct BufferReference { size: u64, - referenced_by_candle_storage : bool, - cached_buffer_id : CachedBufferId, - first_used : u32, - last_used : u32, //u32::max means indefitly + referenced_by_candle_storage: bool, + cached_buffer_id: CachedBufferId, + first_used: u32, + last_used: u32, //u32::max means indefitly } - -impl BufferReference{ - pub fn new(size: u64, referenced_by_candle_storage : bool) -> Self { - Self { size, cached_buffer_id : CachedBufferId::new(0,0),referenced_by_candle_storage, first_used : 0, last_used : if referenced_by_candle_storage {u32::MAX} else {0} } +impl BufferReference { + pub fn new(size: u64, referenced_by_candle_storage: bool) -> Self { + Self { + size, + cached_buffer_id: CachedBufferId::new(0, 0), + referenced_by_candle_storage, + first_used: 0, + last_used: if referenced_by_candle_storage { + u32::MAX + } else { + 0 + }, + } } - pub fn new_with_storage(size: u64, cached_buffer_id: CachedBufferId, referenced_by_candle_storage : bool) -> Self { - Self { size, cached_buffer_id, referenced_by_candle_storage, first_used : 0, last_used : if referenced_by_candle_storage {u32::MAX} else {0}} + pub fn new_with_storage( + size: u64, + cached_buffer_id: CachedBufferId, + referenced_by_candle_storage: bool, + ) -> Self { + Self { + size, + cached_buffer_id, + referenced_by_candle_storage, + first_used: 0, + last_used: if referenced_by_candle_storage { + u32::MAX + } else { + 0 + }, + } } - + pub fn size(&self) -> u64 { self.size } - + pub fn set_cached_buffer_id(&mut self, cached_buffer_id: CachedBufferId) { self.cached_buffer_id = cached_buffer_id; } - + pub fn 
cached_buffer_id(&self) -> &CachedBufferId { &self.cached_buffer_id } - + pub fn referenced_by_candle_storage(&self) -> bool { self.referenced_by_candle_storage } - + pub fn set_referenced_by_candle_storage(&mut self, referenced_by_candle_storage: bool) { self.referenced_by_candle_storage = referenced_by_candle_storage; } - + pub fn first_used(&self) -> u32 { self.first_used } - + pub fn set_first_used(&mut self, first_used: u32) { self.first_used = first_used; } - + pub fn last_used(&self) -> u32 { self.last_used } - + pub fn set_last_used(&mut self, last_used: u32) { self.last_used = last_used; } } - - - - - - - - - - #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum BindgroupInputBase { - Bindgroup0, // - Bindgroup1(T, bool), //input1 - Bindgroup2(T, T, bool), //input1, input2, is_16 - Bindgroup3(T, T, T), //input1, input2, input3 + Bindgroup0, // + Bindgroup1(T, bool, bool), //input1, is_16, is_8 + Bindgroup2(T, T, bool), //input1, input2, is_16 + Bindgroup3(T, T, T), //input1, input2, input3 } - #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct BindgroupFullBase(T, BindgroupInputBase); - impl BindgroupFullBase { - - pub(crate) fn new(dest : T, input : BindgroupInputBase) -> Self{ + pub(crate) fn new(dest: T, input: BindgroupInputBase) -> Self { return BindgroupFullBase(dest, input); } @@ -375,99 +404,109 @@ pub type CachedBindgroupFull = BindgroupFullBase; pub type BindgroupReferenceInput = BindgroupInputBase; pub type BindgroupReferenceFull = BindgroupFullBase; - - - - ////////////////// CACHED BUFFER: - #[derive(Debug)] -pub struct CachedBuffer{ - buffer : wgpu::Buffer, +pub struct CachedBuffer { + buffer: wgpu::Buffer, //stored_free : bool, //whether this buffer was free at the beginning of the queue - is_free : bool, //wheter this buffer is currently free - last_used_counter : u32, + is_free: bool, //whether this buffer is currently free + last_used_counter: u32, //used_memory : u64, //the total memory this buffer was used for.
Together with usage_counter we get the average buffer size, this buffer is used for } - - impl CachedBuffer { pub fn new(buffer: wgpu::Buffer) -> Self { - Self { buffer, is_free : false, last_used_counter: 0}//stored_free : false, used_memory : 0 } + Self { + buffer, + is_free: false, + last_used_counter: 0, + } //stored_free : false, used_memory : 0 } } - + pub fn buffer(&self) -> &wgpu::Buffer { &self.buffer } - + pub fn is_free(&self) -> bool { self.is_free } } #[derive(Debug)] -pub struct CachedBindgroup{ - bindgroup : wgpu::BindGroup, - buffer : CachedBindgroupFull +pub struct CachedBindgroup { + bindgroup: wgpu::BindGroup, + buffer: CachedBindgroupFull, } impl CachedBindgroup { pub fn new(bindgroup: wgpu::BindGroup, buffer: CachedBindgroupFull) -> Self { Self { bindgroup, buffer } } - + pub fn bindgroup(&self) -> &wgpu::BindGroup { &self.bindgroup } - + pub(crate) fn buffer(&self) -> &CachedBindgroupFull { &self.buffer } } - - - #[derive(Debug)] pub struct ModelCache { - pub(crate) buffer_reference : BufferReferenceStorage, + pub(crate) buffer_reference: BufferReferenceStorage, pub(crate) buffers: BufferCacheStorage, pub(crate) bindgroups: BindgroupCacheStorage, pub(crate) mappings: BufferMappingCache, } impl ModelCache { - pub fn new(mapping_size : u32) -> Self { + pub fn new(mapping_size: u32) -> Self { Self { - buffer_reference : BufferReferenceStorage::new(), + buffer_reference: BufferReferenceStorage::new(), buffers: BufferCacheStorage::new(), bindgroups: BindgroupCacheStorage::new(), mappings: BufferMappingCache::new(mapping_size), } } - pub fn create_buffer_reference(&mut self, size: T, referenced_by_candle_storage : bool) -> BufferReferenceId{ - let buffer_reference = BufferReference::new(size.to_u64(),referenced_by_candle_storage); + pub fn create_buffer_reference( + &mut self, + size: T, + referenced_by_candle_storage: bool, + ) -> BufferReferenceId { + let buffer_reference = BufferReference::new(size.to_u64(), referenced_by_candle_storage); return self.buffer_reference.insert(buffer_reference); } - pub fn create_buffer_reference_init(&mut self,dev: &WgpuDevice, data: &[T], referenced_by_candle_storage : bool) -> BufferReferenceId{ + pub fn create_buffer_reference_init( + &mut self, + dev: &WgpuDevice, + data: &[T], + referenced_by_candle_storage: bool, + ) -> BufferReferenceId { let data = bytemuck::cast_slice(data); - let buffer = self.buffers.search_buffer(dev, data.len() as u64, 0,u32::MAX - 1); //TODO use exact size? - dev.queue.write_buffer(&self.buffers.get_buffer(&buffer).unwrap().buffer, 0, data); - - let buffer_reference = BufferReference::new_with_storage(data.len() as u64, buffer, referenced_by_candle_storage); + let buffer = self + .buffers + .search_buffer(dev, data.len() as u64, 0, u32::MAX - 1); //TODO use exact size? 
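// A minimal sketch, not part of this change, of the free-buffer lookup that the
// `search_buffer` call above performs. `find_reusable` and the plain `BTreeSet<u64>`
// are hypothetical stand-ins for the real `BTreeSet<OrderedIndex>` of free buffers.
fn find_reusable(free_sizes: &std::collections::BTreeSet<u64>, size: u64, max_size: u64) -> Option<u64> {
    let mut found = None;
    // scan the free buffers from the requested size upward, ordered by size
    for &candidate in free_sizes.range(size..) {
        if candidate > max_size {
            break; // anything larger would waste too much memory
        }
        // like the loop in `search_buffer`, the last candidate in range wins
        found = Some(candidate);
    }
    found
}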
+ dev.queue + .write_buffer(&self.buffers.get_buffer(&buffer).unwrap().buffer, 0, data); + + let buffer_reference = BufferReference::new_with_storage( + data.len() as u64, + buffer, + referenced_by_candle_storage, + ); return self.buffer_reference.insert(buffer_reference); } - /// returns, wheter we should stop the command_queue and delete not used buffers + /// Returns whether we should stop the command_queue and delete unused buffers pub fn should_delete_unused(&mut self) -> bool { let current_memory = self.buffers.buffer_memory; let memory_margin = self.buffers.max_memory_allowed; - if current_memory > memory_margin{ + if current_memory > memory_margin { return !self.buffers.order.is_empty(); } return false; @@ -478,57 +517,76 @@ impl ModelCache { let memory_margin = self.buffers.max_memory_allowed; let delete_until_margin = (self.buffers.max_memory_allowed * 4) / 5; - //remove buffers, that + //remove buffers that // 1. were not used for a long time // 2. have a big memory diff (the actual buffer size vs the average size the buffer is used with) let mut check_bindgroups = false; - if current_memory > memory_margin{ - log::debug!("deleting buffers: ({}) current {current_memory}/{memory_margin}",self.buffers.storage.len()); - - //every entry in self.buffers.order will be free and can be potentially deleted - //this is ordered from small to big. - let buffers : Vec<_> = self.buffers.order.iter().map(|entry|{ - let (id, val) = self.buffers.storage.get_reference(entry.index).expect("item in order, that could ne be found in storage"); - return (id, val.last_used_counter); - }).collect(); + if current_memory > memory_margin { + log::debug!( + "deleting buffers: ({}) current {current_memory}/{memory_margin}", + self.buffers.storage.len() + ); - for (id, _) in buffers{ + //every entry in self.buffers.order will be free and can be potentially deleted + //this is ordered from small to big. + let buffers: Vec<_> = self + .buffers + .order + .iter() + .map(|entry| { + let (id, val) = self + .buffers + .storage + .get_reference(entry.index) + .expect("item in order that could not be found in storage"); + return (id, val.last_used_counter); + }) + .collect(); + + for (id, _) in buffers { check_bindgroups = true; self.buffers.delete_buffer(&id); - if self.buffers.buffer_memory <= delete_until_margin{ + if self.buffers.buffer_memory <= delete_until_margin { break; //deleted enough } } let current_memory = self.buffers.buffer_memory; - log::debug!("after deleting: ({}) current {current_memory}/{}",self.buffers.storage.len(),self.buffers.max_memory_allowed); + log::debug!( + "after deleting: ({}) current {current_memory}/{}", + self.buffers.storage.len(), + self.buffers.max_memory_allowed + ); } //remove bindgroups: //1. if we removed a buffer, we should also remove the bindgroup //2.
bindgroups that weren't used for a long time may be deleted - if check_bindgroups{ - self.bindgroups.retain_bindgroups(|bindgroup | - { + if check_bindgroups { + self.bindgroups.retain_bindgroups(|bindgroup| { let check_buffer = |buffer_reference| { return self.buffers.get_buffer(buffer_reference).is_some(); }; - - let is_valid = check_buffer(bindgroup.buffer.get_dest()) && - match &bindgroup.buffer.get_input() { - BindgroupInputBase::Bindgroup0 => true, - BindgroupInputBase::Bindgroup1(v1, _) => check_buffer(v1), - BindgroupInputBase::Bindgroup2(v1, v2, _) => check_buffer(v1) && check_buffer(v2), - BindgroupInputBase::Bindgroup3(v1, v2, v3) => check_buffer(v1) && check_buffer(v2) && check_buffer(v3), - }; - + + let is_valid = check_buffer(bindgroup.buffer.get_dest()) + && match &bindgroup.buffer.get_input() { + BindgroupInputBase::Bindgroup0 => true, + BindgroupInputBase::Bindgroup1(v1, _, _) => check_buffer(v1), + BindgroupInputBase::Bindgroup2(v1, v2, _) => { + check_buffer(v1) && check_buffer(v2) + } + BindgroupInputBase::Bindgroup3(v1, v2, v3) => { + check_buffer(v1) && check_buffer(v2) && check_buffer(v3) + } + }; + //check if all buffers for this bindgroup still exist! - if !is_valid{ + if !is_valid { return false; } - return true; + return true; }); } return false; @@ -539,61 +597,95 @@ impl ModelCache { dev: &WgpuDevice, bindgroup_reference: &BindgroupReferenceFull, pipeline: PipelineType, - command_id : u32 + command_id: u32, ) -> CachedBindgroupId { - - fn check_buffer_reference(cache : &mut ModelCache, bindgroup_reference: &BindgroupReferenceFull, pipeline: PipelineType,){ - + fn check_buffer_reference( + cache: &mut ModelCache, + bindgroup_reference: &BindgroupReferenceFull, + pipeline: PipelineType, + ) { let check_buffer = |buffer_reference_id| { - if let Some(buffer_reference) = cache.buffer_reference.get(buffer_reference_id){ - if !buffer_reference.cached_buffer_id.is_valid(){ + if let Some(buffer_reference) = cache.buffer_reference.get(buffer_reference_id) { + if !buffer_reference.cached_buffer_id.is_valid() { panic!("input buffer {:?}({:?}) in {:?} had no cached_storage set for pipeline {:?}", buffer_reference,buffer_reference_id, bindgroup_reference, pipeline); - } - else{ - if cache.buffers.get_buffer(&buffer_reference.cached_buffer_id).is_none(){ - if let Some(buffer_reference) = cache.buffer_reference.get(buffer_reference_id){ + } else { + if cache + .buffers + .get_buffer(&buffer_reference.cached_buffer_id) + .is_none() + { + if let Some(buffer_reference) = + cache.buffer_reference.get(buffer_reference_id) + { panic!("input buffer {:?}({:?}) in {:?} had no cached_storage set to {:?} which could not be found for pipeline {:?}", buffer_reference,buffer_reference_id, bindgroup_reference,buffer_reference.cached_buffer_id, pipeline); } } } - } - else{ - if let Some(val) = cache.buffer_reference.get_reference(buffer_reference_id.id()){ + } else { + if let Some(val) = cache + .buffer_reference + .get_reference(buffer_reference_id.id()) + { panic!("Reference {:?} inside Bindgroup {:?} invalid for pipeline {:?}, Reference was replaced, current: {:?}", buffer_reference_id, bindgroup_reference, pipeline, val.0); - } - else{ + } else { panic!("Reference {:?} inside Bindgroup {:?} invalid for pipeline {:?} (Reference was deleted)", buffer_reference_id, bindgroup_reference, pipeline); - } + } } - }; match &bindgroup_reference.1 { BindgroupInputBase::Bindgroup0 => BindgroupInputBase::Bindgroup0, - BindgroupInputBase::Bindgroup1(v1, is_16) =>
BindgroupInputBase::Bindgroup1(check_buffer(v1), *is_16), - BindgroupInputBase::Bindgroup2(v1, v2, is_16) => BindgroupInputBase::Bindgroup2(check_buffer(v1),check_buffer(v2), *is_16), - BindgroupInputBase::Bindgroup3(v1, v2, v3) => BindgroupInputBase::Bindgroup3(check_buffer(v1),check_buffer(v2),check_buffer(v3)), + BindgroupInputBase::Bindgroup1(v1, is_16, is_8) => { + BindgroupInputBase::Bindgroup1(check_buffer(v1), *is_16, *is_8) + } + BindgroupInputBase::Bindgroup2(v1, v2, is_16) => { + BindgroupInputBase::Bindgroup2(check_buffer(v1), check_buffer(v2), *is_16) + } + BindgroupInputBase::Bindgroup3(v1, v2, v3) => BindgroupInputBase::Bindgroup3( + check_buffer(v1), + check_buffer(v2), + check_buffer(v3), + ), }; } check_buffer_reference(self, bindgroup_reference, pipeline.clone()); - fn get_storage(cache : &ModelCache, id: &BufferReferenceId) -> CachedBufferId { - cache.buffer_reference.get(&id).unwrap().cached_buffer_id.clone() + fn get_storage(cache: &ModelCache, id: &BufferReferenceId) -> CachedBufferId { + cache + .buffer_reference + .get(&id) + .unwrap() + .cached_buffer_id + .clone() } fn get_buffer_referece_key( - cache : &ModelCache, + cache: &ModelCache, dest_buffer: CachedBufferId, bindgroup_reference: &BindgroupReferenceFull, ) -> CachedBindgroupFull { - return BindgroupFullBase(dest_buffer, - match &bindgroup_reference.1 { - BindgroupInputBase::Bindgroup0 => BindgroupInputBase::Bindgroup0, - BindgroupInputBase::Bindgroup1(v1, is_16) => BindgroupInputBase::Bindgroup1(get_storage(cache, v1), *is_16), - BindgroupInputBase::Bindgroup2(v1, v2, is_16) => BindgroupInputBase::Bindgroup2(get_storage(cache, v1),get_storage(cache, v2), *is_16), - BindgroupInputBase::Bindgroup3(v1, v2, v3) => BindgroupInputBase::Bindgroup3(get_storage(cache, v1),get_storage(cache, v2),get_storage(cache, v3)), - }); + return BindgroupFullBase( + dest_buffer, + match &bindgroup_reference.1 { + BindgroupInputBase::Bindgroup0 => BindgroupInputBase::Bindgroup0, + BindgroupInputBase::Bindgroup1(v1, is_16, is_8) => { + BindgroupInputBase::Bindgroup1(get_storage(cache, v1), *is_16, *is_8) + } + BindgroupInputBase::Bindgroup2(v1, v2, is_16) => { + BindgroupInputBase::Bindgroup2( + get_storage(cache, v1), + get_storage(cache, v2), + *is_16, + ) + } + BindgroupInputBase::Bindgroup3(v1, v2, v3) => BindgroupInputBase::Bindgroup3( + get_storage(cache, v1), + get_storage(cache, v2), + get_storage(cache, v3), + ), + }, + ); } let buf_dest_id = bindgroup_reference.get_dest(); @@ -605,41 +697,52 @@ impl ModelCache { buf_dest_cached_id = buf_dest_reference.cached_buffer_id.clone(); required_size = buf_dest_reference.size; buf_dest_length = buf_dest_reference.last_used - buf_dest_reference.first_used; - if buf_dest_reference.last_used < buf_dest_reference.first_used{ + if buf_dest_reference.last_used < buf_dest_reference.first_used { panic!("buffer {:?}({:?})", buf_dest_reference, buf_dest_id); } } if dev.configuration.use_cache { //the destination buffer of this bindgroup already has a buffer set - if buf_dest_cached_id.is_valid(){ - let bindgroup_inputs = get_buffer_referece_key(self, buf_dest_cached_id, &bindgroup_reference); - if let Some(bg) = self.bindgroups.get_bindgroup_reference_by_description(&bindgroup_inputs).cloned() { + if buf_dest_cached_id.is_valid() { + let bindgroup_inputs = + get_buffer_referece_key(self, buf_dest_cached_id, &bindgroup_reference); + if let Some(bg) = self + .bindgroups + .get_bindgroup_reference_by_description(&bindgroup_inputs) + .cloned() + { self.bindgroups.cached_bindgroup_use_counter 
+= 1; self.mappings.add_buffer(buf_dest_cached_id, pipeline); return bg; } } //reference storage is not set -> search a free buffer or create new one - else{ + else { if let Some(buffer_id) = self.mappings.get_buffer(pipeline.clone()) { - let buffer : Option<&CachedBuffer> = self.buffers.get_buffer(&buffer_id); - if let Some(buffer) = buffer{ - if buffer.is_free(){ + let buffer: Option<&CachedBuffer> = self.buffers.get_buffer(&buffer_id); + if let Some(buffer) = buffer { + if buffer.is_free() { if buffer.buffer.size() >= required_size { - let buf_dest_reference = self.buffer_reference.get_mut(buf_dest_id).unwrap(); + let buf_dest_reference = + self.buffer_reference.get_mut(buf_dest_id).unwrap(); //use this buffer for the buffer reference: buf_dest_reference.cached_buffer_id = buffer_id; buf_dest_cached_id = buffer_id; self.buffers.use_buffer(&buffer_id, command_id); - + //reuse a bindgroup, if we could find one: - let bindgroup_inputs = get_buffer_referece_key(self, buffer_id, &bindgroup_reference); - if let Some(bg) = self.bindgroups.get_bindgroup_reference_by_description(&bindgroup_inputs).cloned() { + let bindgroup_inputs = + get_buffer_referece_key(self, buffer_id, &bindgroup_reference); + if let Some(bg) = self + .bindgroups + .get_bindgroup_reference_by_description(&bindgroup_inputs) + .cloned() + { self.bindgroups.cached_bindgroup_use_counter += 1; - self.mappings.add_buffer(buffer_id,pipeline); + self.mappings.add_buffer(buffer_id, pipeline); return bg.clone(); - } + } } else { //the required size increased -> also request a little bit more required_size *= 2; @@ -648,9 +751,11 @@ impl ModelCache { } } - let bindgroup_inputs = get_buffer_referece_key(self, CachedBufferId::new(0, 0), &bindgroup_reference); + let bindgroup_inputs = + get_buffer_referece_key(self, CachedBufferId::new(0, 0), &bindgroup_reference); // let bindgroup_inputs = &bindgroup_reference.1; - let max_size : u64 = BufferCacheStorage::max_cached_size(required_size as u64, buf_dest_length); + let max_size: u64 = + BufferCacheStorage::max_cached_size(required_size as u64, buf_dest_length); let candidates_to_process = self .bindgroups @@ -658,22 +763,25 @@ impl ModelCache { .filter_map(|(id, bindgroup)| { let cbuf_dest_id = bindgroup.buffer.get_dest(); - if buf_dest_cached_id.is_valid(){ - if let Some(bindgroup) = self.bindgroups.get_bindgroup(&id){ - if buf_dest_cached_id == *bindgroup.buffer.get_dest(){ - if let Some(c_buf_dest) = self.buffers.get_buffer(cbuf_dest_id){ - return Some((id, bindgroup,c_buf_dest.buffer.size())); + if buf_dest_cached_id.is_valid() { + if let Some(bindgroup) = self.bindgroups.get_bindgroup(&id) { + if buf_dest_cached_id == *bindgroup.buffer.get_dest() { + if let Some(c_buf_dest) = self.buffers.get_buffer(cbuf_dest_id) + { + return Some((id, bindgroup, c_buf_dest.buffer.size())); } } } - } - else{ - if let Some(c_buf_dest) = self.buffers.get_buffer(cbuf_dest_id){ - if c_buf_dest.buffer.size() >= required_size && c_buf_dest.is_free() && c_buf_dest.buffer.size() <= max_size{ - return Some((id, bindgroup, c_buf_dest.buffer.size())) + } else { + if let Some(c_buf_dest) = self.buffers.get_buffer(cbuf_dest_id) { + if c_buf_dest.buffer.size() >= required_size + && c_buf_dest.is_free() + && c_buf_dest.buffer.size() <= max_size + { + return Some((id, bindgroup, c_buf_dest.buffer.size())); } } - } + } return None; }); @@ -695,7 +803,7 @@ impl ModelCache { let buf_dest_reference = self.buffer_reference.get_mut(buf_dest_id).unwrap(); buf_dest_reference.cached_buffer_id = *cached_dest_buffer_id; 
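// A worked sketch, not part of this change, of the `max_cached_size` bound used
// for `max_size` above (the real function is defined further down in this file);
// `max_cached_size_sketch` is a hypothetical copy for illustration only.
fn max_cached_size_sketch(size: u64, length: u32) -> u64 {
    let length = (length + 1).min(100);
    // short-lived buffers (small `length`) get a large oversize factor, up to 64
    let i = (300 / (length * length * length)).min(64).max(1) as u64;
    const TRANSITION_POINT: u64 = 1000 * 1024;
    // the slack shrinks again as the requested size grows past ~1 MiB
    size + (i * size * TRANSITION_POINT / (TRANSITION_POINT + size))
}
// e.g. max_cached_size_sketch(1024, 0) is roughly 65 KiB, so a temporary 1 KiB
// request may reuse a buffer up to ~65x larger, while a long-lived 100 MiB
// request (max_cached_size_sketch(100 << 20, 10)) gets only about 1 MiB of slack.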
self.buffers.use_buffer(&cached_dest_buffer_id, command_id); - self.mappings.add_buffer(*cached_dest_buffer_id,pipeline); + self.mappings.add_buffer(*cached_dest_buffer_id, pipeline); self.bindgroups.cached_bindgroup_use_counter += 1; return cached_bindgroup_id; } } let buf_dest_reference = self.buffer_reference.get_mut(buf_dest_id).unwrap(); - + //create new buffer, if buffer was not already set: let dest_buffer_id; - if buf_dest_reference.cached_buffer_id.is_valid(){ //this buffer reference already has a buffer connected,use this buffer + if buf_dest_reference.cached_buffer_id.is_valid() { + //this buffer reference already has a buffer connected, use this buffer dest_buffer_id = buf_dest_reference.cached_buffer_id; - } - else{//create a new buffer - dest_buffer_id = self.buffers.search_buffer(dev, required_size, command_id, buf_dest_length); + } else { + //create a new buffer + dest_buffer_id = + self.buffers + .search_buffer(dev, required_size, command_id, buf_dest_length); //use this buffer for the buffer reference: buf_dest_reference.cached_buffer_id = dest_buffer_id; } //create new bindgroup: - let bindgroup_reference = get_buffer_referece_key(self, dest_buffer_id, bindgroup_reference); + let bindgroup_reference = + get_buffer_referece_key(self, dest_buffer_id, bindgroup_reference); let bindgroup_id = self.create_bindgroup(dev, bindgroup_reference); if dev.configuration.use_cache { @@ -726,66 +838,72 @@ impl ModelCache { } //creates a Bindgroup - fn create_bindgroup(&mut self, dev : &WgpuDevice, bindgroup_d : CachedBindgroupFull) -> CachedBindgroupId { + fn create_bindgroup( + &mut self, + dev: &WgpuDevice, + bindgroup_d: CachedBindgroupFull, + ) -> CachedBindgroupId { let bindgroup = wgpu_functions::create_bindgroup(dev, bindgroup_d.clone(), self); let bindgroup = CachedBindgroup::new(bindgroup, bindgroup_d.clone()); let id = self.bindgroups.storage.insert(bindgroup); - self.bindgroups.bindgroups.add_mapping(bindgroup_d.1.clone(), id); + self.bindgroups + .bindgroups + .add_mapping(bindgroup_d.1.clone(), id); self.bindgroups.bindgroups_full.insert(bindgroup_d, id); self.bindgroups.bindgroup_counter += 1; return id; } - } #[derive(Debug)] -pub (crate) struct BufferReferenceStorage{ - storage : Storage, - deletion_queue : Vec //entires that are marked for deletion +pub(crate) struct BufferReferenceStorage { + storage: Storage, + deletion_queue: Vec, //entries that are marked for deletion } -impl BufferReferenceStorage{ +impl BufferReferenceStorage { fn new() -> Self { - Self { storage : Storage::new(), deletion_queue : vec![] } + Self { + storage: Storage::new(), + deletion_queue: vec![], + } } - fn insert(&mut self, referece : BufferReference) -> BufferReferenceId - { + fn insert(&mut self, referece: BufferReference) -> BufferReferenceId { let id = self.storage.insert(referece); //println!("create new buffer Reference: {:?}", id); return id; } - pub fn get(&self, id : &BufferReferenceId) -> Option<&BufferReference>{ + pub fn get(&self, id: &BufferReferenceId) -> Option<&BufferReference> { self.storage.get(id) } - pub fn get_mut(&mut self, id : &BufferReferenceId) -> Option<&mut BufferReference>{ + pub fn get_mut(&mut self, id: &BufferReferenceId) -> Option<&mut BufferReference> { self.storage.get_mut(id) } - pub fn queue_for_deletion(&mut self, id : &BufferReferenceId){ + pub fn queue_for_deletion(&mut self, id: &BufferReferenceId) { self.deletion_queue.push(*id); } - pub fn get_deletion_entries(&mut self) -> Vec{ + pub fn
get_deletion_entries(&mut self) -> Vec { std::mem::take(&mut self.deletion_queue) } - pub fn delete(&mut self, id : &BufferReferenceId) -> bool{ + pub fn delete(&mut self, id: &BufferReferenceId) -> bool { //println!("deleting buffer Reference: {:?}", id); self.storage.delete(id) } - pub fn get_reference(&self, id : u32) -> Option<(BufferReferenceId, &BufferReference)>{ + pub fn get_reference(&self, id: u32) -> Option<(BufferReferenceId, &BufferReference)> { self.storage.get_reference(id) } } - // Struct used for ordering by size #[derive(Debug, Eq, PartialEq)] struct OrderedIndex { @@ -800,38 +918,38 @@ impl OrderedIndex { } // Implementing Ord and PartialOrd for OrderedIndex so it can be stored in BTreeSet -impl Ord for OrderedIndex { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.value.cmp(&other.value).then_with(|| self.index.cmp(&other.index)) + self.value + .cmp(&other.value) + .then_with(|| self.index.cmp(&other.index)) } } -impl PartialOrd for OrderedIndex { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } - - /// Cache of all free CachedBuffers #[derive(Debug)] pub(crate) struct BufferCacheStorage { - storage : StorageOptional, - order : BTreeSet>, //contains a ordered list of the currently free buffers in the storage - //(a buffer may be free if it is currently not used, this does not mean that it was deleted. a deleted buffer was complely removed from the storage and droped) - buffer_counter: u32, //total number of buffers created + storage: StorageOptional, + order: BTreeSet>, //contains an ordered list of the currently free buffers in the storage + //(a buffer may be free if it is currently not used, this does not mean that it was deleted.
a deleted buffer was completely removed from the storage and dropped) + buffer_counter: u32, //total number of buffers created buffer_reuse_counter: u32, //total number of buffers reused - buffer_memory: u64, //total memory allocated - buffer_memory_free: u64, //total memory in buffers btree map + buffer_memory: u64, //total memory allocated + buffer_memory_free: u64, //total memory in buffers btree map max_memory_allowed: u64, } -impl BufferCacheStorage{ +impl BufferCacheStorage { pub fn new() -> Self { return Self { - storage : StorageOptional::new(), - order : BTreeSet::new(), + storage: StorageOptional::new(), + order: BTreeSet::new(), buffer_counter: 0, buffer_reuse_counter: 0, buffer_memory: 0, @@ -841,7 +959,7 @@ impl BufferCacheStorage{ } //creates a Buffer, expecting that it will be used and not be part of free memory - fn create_buffer(&mut self, dev : &WgpuDevice, size : u64, command_id : u32) -> CachedBufferId { + fn create_buffer(&mut self, dev: &WgpuDevice, size: u64, command_id: u32) -> CachedBufferId { let buffer = wgpu_functions::create_buffer(dev, size); let mut buffer = CachedBuffer::new(buffer); buffer.last_used_counter = command_id; @@ -852,28 +970,31 @@ impl BufferCacheStorage{ return id; } - - pub fn delete_buffer(&mut self, id : &CachedBufferId){ + pub fn delete_buffer(&mut self, id: &CachedBufferId) { let value = self.storage.delete_move(id); - if let Some(val) = value{ - if self.order.remove(&OrderedIndex::new(id.id(), val.buffer.size())){ + if let Some(val) = value { + if self + .order + .remove(&OrderedIndex::new(id.id(), val.buffer.size())) + { self.buffer_memory_free -= val.buffer.size(); } self.buffer_memory -= val.buffer.size(); - } - + } } - pub fn get_buffer(&self, id : &CachedBufferId) -> Option<&CachedBuffer>{ + pub fn get_buffer(&self, id: &CachedBufferId) -> Option<&CachedBuffer> { self.storage.get(id) } //will not delete the buffer, but mark it free - pub fn free_buffer(&mut self, id : &CachedBufferId){ - let buffer : Option<&mut CachedBuffer> = self.storage.get_mut(id); - if let Some(buffer) = buffer{ - if buffer.is_free == false{ //the buffer is currently not free -> add it into the free order list - self.order.insert(OrderedIndex::new(id.id(), buffer.buffer.size())); + pub fn free_buffer(&mut self, id: &CachedBufferId) { + let buffer: Option<&mut CachedBuffer> = self.storage.get_mut(id); + if let Some(buffer) = buffer { + if buffer.is_free == false { + //the buffer is currently not free -> add it into the free order list + self.order + .insert(OrderedIndex::new(id.id(), buffer.buffer.size())); buffer.is_free = true; self.buffer_memory_free += buffer.buffer.size() } @@ -881,11 +1002,13 @@ impl BufferCacheStorage{ } //will not create a buffer, but mark the buffer as used - pub fn use_buffer(&mut self, id : &CachedBufferId, command_id : u32){ - let buffer : Option<&mut CachedBuffer> = self.storage.get_mut(id); - if let Some(buffer) = buffer{ - if buffer.is_free == true{ //the buffer is currently free -> remove it from the free order list - self.order.remove(&OrderedIndex::new(id.id(), buffer.buffer.size())); + pub fn use_buffer(&mut self, id: &CachedBufferId, command_id: u32) { + let buffer: Option<&mut CachedBuffer> = self.storage.get_mut(id); + if let Some(buffer) = buffer { + if buffer.is_free == true { + //the buffer is currently free -> remove it from the free order list + self.order + .remove(&OrderedIndex::new(id.id(), buffer.buffer.size())); buffer.is_free = false; buffer.last_used_counter = command_id; self.buffer_reuse_counter += 1; @@ -919,40
+1042,45 @@ impl BufferCacheStorage{ // } //the length this buffer should be used for (if a buffer is only used temporarily, we may use a much bigger buffer for just one command) - fn max_cached_size(size : u64, length : u32) -> u64{ + fn max_cached_size(size: u64, length: u32) -> u64 { let length = (length + 1).min(100); let i = (300 / (length * length * length)).min(64).max(1) as u64; - const TRANSITION_POINT : u64 = 1000*1024; - return size + (i * size * TRANSITION_POINT / (TRANSITION_POINT + size)); + const TRANSITION_POINT: u64 = 1000 * 1024; + return size + (i * size * TRANSITION_POINT / (TRANSITION_POINT + size)); } //will try to find a free buffer in the cache, or create a new one - pub fn search_buffer(&mut self, dev : &WgpuDevice, size : u64, command_id : u32, length : u32) -> CachedBufferId { + pub fn search_buffer( + &mut self, + dev: &WgpuDevice, + size: u64, + command_id: u32, + length: u32, + ) -> CachedBufferId { //println!("search buffer: size: {size}"); let max_size = BufferCacheStorage::max_cached_size(size, length); - if dev.configuration.use_cache{ + if dev.configuration.use_cache { let mut buffer_found = None; - for id in self.order.range(OrderedIndex::new(0, size)..){ - if id.value < size{ + for id in self.order.range(OrderedIndex::new(0, size)..) { + if id.value < size { panic!("Did not expect size to be smaller than key"); } - if id.value > max_size{ + if id.value > max_size { break; } buffer_found = Some(id); } //remove this buffer from free memory: - if let Some(buffer_found) = buffer_found{ - if let Some((reference, _)) = self.storage.get_reference(buffer_found.index){ + if let Some(buffer_found) = buffer_found { + if let Some((reference, _)) = self.storage.get_reference(buffer_found.index) { //println!("search buffer: found free buffer, using: {:?}, size: {:?}", reference, buffer.buffer.size()); self.use_buffer(&reference, command_id); return reference; } - } } return self.create_buffer(dev, size, command_id); @@ -961,34 +1089,28 @@ impl BufferCacheStorage{ pub fn max_memory_allowed(&self) -> u64 { self.max_memory_allowed } - + pub fn set_max_memory_allowed(&mut self, max_memory_allowed: u64) { self.max_memory_allowed = max_memory_allowed; } - + pub(crate) fn buffer_memory(&self) -> u64 { self.buffer_memory } - + pub(crate) fn buffer_reuse_counter(&self) -> u32 { self.buffer_reuse_counter } - + pub(crate) fn buffer_counter(&self) -> u32 { self.buffer_counter } - } - - - - - /// Cache of all available CachedBindGroups #[derive(Debug)] pub(crate) struct BindgroupCacheStorage { - storage : StorageOptional, + storage: StorageOptional, bindgroups: HashMapMulti, //all bindgroups based on input buffers bindgroups_full: HashMap, //all bindgroups based on input and dest buffers bindgroup_counter: u32, @@ -998,7 +1120,7 @@ pub(crate) struct BindgroupCacheStorage { impl BindgroupCacheStorage { fn new() -> Self { return Self { - storage : StorageOptional::new(), + storage: StorageOptional::new(), bindgroups: HashMapMulti::new(), bindgroups_full: HashMap::new(), bindgroup_counter: 0, @@ -1006,77 +1128,57 @@ impl BindgroupCacheStorage { }; } - fn retain_bindgroups(&mut self, mut keep : impl FnMut(&CachedBindgroup) -> bool){ + fn retain_bindgroups(&mut self, mut keep: impl FnMut(&CachedBindgroup) -> bool) { self.storage.retain_mut(|(id, bg)| { let keep = keep(bg); - if !keep{ + if !keep { let id = id.clone(); let buf_reference_input_full = bg.buffer.clone(); - self.bindgroups.remove_mapping(buf_reference_input_full.1.clone(), &id); + self.bindgroups
.remove_mapping(buf_reference_input_full.1.clone(), &id); self.bindgroups_full.remove(&buf_reference_input_full); } return keep; }); - } + } - pub fn get_bindgroup(&self, id : &CachedBindgroupId) -> Option<&CachedBindgroup>{ + pub fn get_bindgroup(&self, id: &CachedBindgroupId) -> Option<&CachedBindgroup> { self.storage.get(id) } - fn get_bindgroup_reference_by_description(&self, bindgroup_d : &CachedBindgroupFull) -> Option<&CachedBindgroupId>{ + fn get_bindgroup_reference_by_description( + &self, + bindgroup_d: &CachedBindgroupFull, + ) -> Option<&CachedBindgroupId> { self.bindgroups_full.get(bindgroup_d) } - - fn get_bindgroup_reference_by_description_input(&self, bindgroup_d : &CachedBindgroupInput) -> &Vec{ + fn get_bindgroup_reference_by_description_input( + &self, + bindgroup_d: &CachedBindgroupInput, + ) -> &Vec { self.bindgroups.get(bindgroup_d) } - - - fn enumerate_bindgroup_by_description_input(&self, bindgroup_d : &CachedBindgroupInput) -> impl Iterator{ - self.get_bindgroup_reference_by_description_input(bindgroup_d).iter().filter_map(|c| - Some((c.clone(), self.get_bindgroup(c)?)) - ) + fn enumerate_bindgroup_by_description_input( + &self, + bindgroup_d: &CachedBindgroupInput, + ) -> impl Iterator { + self.get_bindgroup_reference_by_description_input(bindgroup_d) + .iter() + .filter_map(|c| Some((c.clone(), self.get_bindgroup(c)?))) } - + pub(crate) fn bindgroup_counter(&self) -> u32 { self.bindgroup_counter } - + pub(crate) fn cached_bindgroup_use_counter(&self) -> u32 { self.cached_bindgroup_use_counter } - } - - - - - - - - - - - - - - - - - - - - - - - - - - ///Cache that stores previously flushed gpu commands; we try to use the same buffers as the last time #[derive(Debug)] pub(crate) struct BufferMappingCache { @@ -1086,7 +1188,7 @@ pub(crate) struct BufferMappingCache { } impl BufferMappingCache { - fn new(size : u32) -> Self { + fn new(size: u32) -> Self { Self { last_buffer_mappings: FixedSizeQueue::new(size as usize), current_buffer_mapping: None, @@ -1100,10 +1202,16 @@ impl BufferMappingCache { .iter() .position(|b| b.hash == hash); if let Some(index) = index { - log::debug!("reuse mapping: {index}, hash: {hash}, mappings: {}", self.last_buffer_mappings.deque.len()); + log::debug!( + "reuse mapping: {index}, hash: {hash}, mappings: {}", + self.last_buffer_mappings.deque.len() + ); self.current_buffer_mapping = self.last_buffer_mappings.deque.remove(index); } else { - log::debug!("create new mapping: hash: {hash}, mappings: {}", self.last_buffer_mappings.deque.len()); + log::debug!( + "create new mapping: hash: {hash}, mappings: {}", + self.last_buffer_mappings.deque.len() + ); self.current_buffer_mapping = Some(CachedBufferMappings::new(hash)); } } @@ -1122,11 +1230,7 @@ impl BufferMappingCache { } ///Stores that the provided buffer was used - pub(crate) fn add_buffer( - &mut self, - buffer: CachedBufferId, - pipeline: PipelineType - ) { + pub(crate) fn add_buffer(&mut self, buffer: CachedBufferId, pipeline: PipelineType) { if let Some(mapping) = &mut self.current_buffer_mapping { let data = CachedBufferMapping::new(pipeline, buffer); if (self.current_index as usize) < mapping.data.len() { @@ -1174,4 +1278,5 @@ impl CachedBufferMappings { fn new(hash: u64) -> Self { Self { data: vec![], hash } } -} \ No newline at end of file +} + diff --git a/candle-core/src/wgpu_backend/device.rs b/candle-core/src/wgpu_backend/device.rs index df1302c91a..b39730aa20 100644 --- a/candle-core/src/wgpu_backend/device.rs +++ b/candle-core/src/wgpu_backend/device.rs @@ -16,45
+16,44 @@ use crate::backend::BackendStorage; use crate::{notImplemented, wrongType, Layout}; #[cfg(feature = "wgpu_debug")] -use super::debug_info::{DebugInfo, Measurements,MInfo, ShaderInfo}; +use super::debug_info::{DebugInfo, MInfo, Measurements, ShaderInfo}; use super::cache::{BindgroupLayouts, CachedBindgroupId, CachedBufferId, ModelCache}; use super::storage::{create_wgpu_storage, create_wgpu_storage_init}; use super::util::{Counter, ObjectToIdMapper, ToF64, ToU32}; -use super::wgpu_functions::{ConstArray, KernelParameterMeta}; use super::wgpu_functions::{self, unary::UnaryOperation, MetaArray}; +use super::wgpu_functions::{ConstArray, KernelParameterMeta}; use super::WgpuStorage; - #[derive(Debug)] -pub struct WgpuDeviceConfig{ - pub meta_buffer_size : u32, //the size of the buffer used for storing meta information (e.g. input layouts) - pub max_workload_size : u64, //specifys how much max floating point operations will be queued in one single command. (e.g. a matrix multiplication of 1000x1000 * 1000x1000 would be about 1gb operations, so only 2 of theses may be queued in one command buffer) - pub buffer_cached_max_allowed_size : u64,//maximum size for cached wgpu::buffers. When this size is reached, free buffers will be deleted until only 75% of this max size is used. - //if this value is to low for the desired model, the performance may drop significatly(e.g. model needs at least 2gb of data, if this value would be e.g. only 100mb all free buffers would be deleted after each command) - pub use_cache : bool, - pub queue_delay_miliseconds : u32, //specifys the amout of time to wait after each command (may be usefull for debuging purposes if one expect, that the impl causes to much stress on the gpu) - pub flush_gpu_before_buffer_init : bool, //when data is copied from cpu to the wgpu device, all previous commands may be flushed, to allow other buffers to be freed and reused. - //But on webGpu this may not be optimal, as we can not wait for commands to finish (as this functin is not asyny) - pub buffer_mapping_size : u32, +pub struct WgpuDeviceConfig { + pub meta_buffer_size: u32, //the size of the buffer used for storing meta information (e.g. input layouts) + pub max_workload_size: u64, //specifies the maximum number of floating point operations that will be queued in one single command. (e.g. a matrix multiplication of 1000x1000 * 1000x1000 would be about 1gb operations, so only 2 of these may be queued in one command buffer) + pub buffer_cached_max_allowed_size: u64, //maximum size for cached wgpu::buffers. When this size is reached, free buffers will be deleted until only 75% of this max size is used. + //if this value is too low for the desired model, the performance may drop significantly (e.g. if a model needs at least 2gb of data and this value were only 100mb, all free buffers would be deleted after each command) + pub use_cache: bool, + pub queue_delay_miliseconds: u32, //specifies the amount of time to wait after each command (may be useful for debugging purposes if one expects that the impl causes too much stress on the gpu) + pub flush_gpu_before_buffer_init: bool, //when data is copied from the cpu to the wgpu device, all previous commands may be flushed, to allow other buffers to be freed and reused.
+ //But on webGpu this may not be optimal, as we cannot wait for commands to finish (as this function is not async) + pub buffer_mapping_size: u32, } impl Default for WgpuDeviceConfig { fn default() -> WgpuDeviceConfig { WgpuDeviceConfig { - meta_buffer_size : 10*1024*1024, - max_workload_size : 1024u64*1024*1024*2, - buffer_cached_max_allowed_size : 1024*1024*1024*8, - use_cache : true, - queue_delay_miliseconds : 0, - flush_gpu_before_buffer_init : true, - buffer_mapping_size : 3, + meta_buffer_size: 10 * 1024 * 1024, + max_workload_size: 1024u64 * 1024 * 1024 * 2, + buffer_cached_max_allowed_size: 1024 * 1024 * 1024 * 8, + use_cache: true, + queue_delay_miliseconds: 0, + flush_gpu_before_buffer_init: true, + buffer_mapping_size: 3, } } } #[derive(Debug)] -pub (crate) enum MlQueue{ +pub(crate) enum MlQueue { Dispatch(MlQueueDispatch), } @@ -76,93 +75,116 @@ impl Hash for OrderedFloat { } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct OpIsInplaceable{ - pub input1_inplaceable : bool, - pub input2_inplaceable : bool, +pub struct OpIsInplaceable { + pub input1_inplaceable: bool, + pub input2_inplaceable: bool, } impl OpIsInplaceable { pub fn new() -> Self { - Self { input1_inplaceable : false, input2_inplaceable : false } + Self { + input1_inplaceable: false, + input2_inplaceable: false, + } } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub (crate) struct PipelineType(pub candle_wgpu_kernels::Pipelines, pub usize, pub OpIsInplaceable); +pub(crate) struct PipelineType( + pub candle_wgpu_kernels::Pipelines, + pub usize, + pub OpIsInplaceable, +); //TODO: use BindgroupReferenceFull instead of BindgroupReference -pub (crate) type BindGroupReference = crate::wgpu_backend::cache::BindgroupReferenceFull; +pub(crate) type BindGroupReference = crate::wgpu_backend::cache::BindgroupReferenceFull; #[derive(Debug)] -pub (crate) enum DispatchedBindgroup { +pub(crate) enum DispatchedBindgroup { BindgroupReference(BindGroupReference), CachedBindgroup(CachedBindgroupId), - None //optimized away - } + None, //optimized away +} #[derive(Debug)] -pub (crate) struct MlQueueDispatch{ - pub (crate) x : u32, - pub (crate) y : u32, - pub (crate) z : u32, - pub (crate) pipeline : PipelineType, - pub (crate) bindgroup : DispatchedBindgroup, - pub (crate) pipeline_cached : Option>, - pub (crate) meta : u32, - pub (crate) workload_size : usize, //the total size needed to calculate. Needed so we do not queue to many operations at once. +pub(crate) struct MlQueueDispatch { + pub(crate) x: u32, + pub(crate) y: u32, + pub(crate) z: u32, + pub(crate) pipeline: PipelineType, + pub(crate) bindgroup: DispatchedBindgroup, + pub(crate) pipeline_cached: Option>, + pub(crate) meta: u32, + pub(crate) workload_size: usize, //the total size needed to calculate. Needed so we do not queue too many operations at once.
#[cfg(feature = "wgpu_debug")] - pub (crate) debug : Option, + pub(crate) debug: Option, } #[derive(Debug)] -pub struct ShaderModuleComputePipelines{ - shader : Arc, - pipelines : Mutex>> +pub struct ShaderModuleComputePipelines { + shader: Arc, + pipelines: Mutex>>, } -//a struct, where all operations are chunked +//a struct, where all operations are chunked #[derive(Debug)] -pub struct QueueBuffer{ - pub (crate) command_queue : Vec, - meta_array : MetaArray, - const_array : ConstArray, - const_id_map : ObjectToIdMapper, - global_command_index : u32, - pub (crate) id_to_const_array : Vec>, - pub (crate) current_meta : u32, - pub (crate) last_buffer : Option //will be used to wait for the last command queue +pub struct QueueBuffer { + pub(crate) command_queue: Vec, + meta_array: MetaArray, + const_array: ConstArray, + const_id_map: ObjectToIdMapper, + global_command_index: u32, + pub(crate) id_to_const_array: Vec>, + pub(crate) current_meta: u32, + pub(crate) last_buffer: Option, //will be used to wait for the last command queue } impl QueueBuffer { - pub fn new(size : u32) -> Self { - Self { command_queue: vec![], meta_array :MetaArray::new(size), current_meta : 0 , const_array: ConstArray::new(), const_id_map : ObjectToIdMapper::new() , id_to_const_array : Vec::new(), last_buffer : None, global_command_index : 1} + pub fn new(size: u32) -> Self { + Self { + command_queue: vec![], + meta_array: MetaArray::new(size), + current_meta: 0, + const_array: ConstArray::new(), + const_id_map: ObjectToIdMapper::new(), + id_to_const_array: Vec::new(), + last_buffer: None, + global_command_index: 1, + } } - pub fn init(&mut self){ + pub fn init(&mut self) { self.const_array.0.clear(); } - pub fn clear(&mut self){ + pub fn clear(&mut self) { self.command_queue.clear(); self.meta_array.0.clear(); self.init(); self.current_meta = 0; } - pub fn get_meta(&self) -> &Vec{ + pub fn get_meta(&self) -> &Vec { return &self.meta_array.0; } - pub fn get_meta_mut(&mut self) -> &mut Vec{ + pub fn get_meta_mut(&mut self) -> &mut Vec { return &mut self.meta_array.0; } - fn add_layout(&mut self, layout: &Layout, is_contiguous : bool, constant_dims : Constants, constant_is_startofsset_zero : Constants, constant_is_contiguous : Constants){ + fn add_layout( + &mut self, + layout: &Layout, + is_contiguous: bool, + constant_dims: Constants, + constant_is_startofsset_zero: Constants, + constant_is_contiguous: Constants, + ) { let shape = layout.shape().dims(); let stride = layout.stride(); self.add_const(constant_dims, shape.len()); - if layout.start_offset() != 0{ + if layout.start_offset() != 0 { self.add_const(constant_is_startofsset_zero, false); self.add(layout.start_offset()); } @@ -171,104 +193,176 @@ impl QueueBuffer { self.add(layout.shape().elem_count()); } else { self.add_const(constant_is_contiguous, false); - + self.get_meta_mut().extend(shape.iter().map(|&x| x as u32)); self.get_meta_mut().extend(stride.iter().map(|&x| x as u32)); - } + } } - pub(crate) fn add_layout1(&mut self, layout: &Layout) { - self.add_layout(layout, layout.is_contiguous(), Constants::ConstDims1, Constants::ConstIsStartoffsetZero1, Constants::ConstIsContiguous1); + pub(crate) fn add_layout1(&mut self, layout: &Layout) { + self.add_layout( + layout, + layout.is_contiguous(), + Constants::ConstDims1, + Constants::ConstIsStartoffsetZero1, + Constants::ConstIsContiguous1, + ); } - pub(crate) fn add_layout2(&mut self, layout: &Layout) { - self.add_layout(layout, layout.is_contiguous(), Constants::ConstDims2, 
Constants::ConstIsStartoffsetZero2, Constants::ConstIsContiguous2); + pub(crate) fn add_layout2(&mut self, layout: &Layout) { + self.add_layout( + layout, + layout.is_contiguous(), + Constants::ConstDims2, + Constants::ConstIsStartoffsetZero2, + Constants::ConstIsContiguous2, + ); } - pub(crate) fn add_layout3(&mut self, layout: &Layout) { - self.add_layout(layout, layout.is_contiguous(), Constants::ConstDims3, Constants::ConstIsStartoffsetZero3, Constants::ConstIsContiguous3); + pub(crate) fn add_layout3(&mut self, layout: &Layout) { + self.add_layout( + layout, + layout.is_contiguous(), + Constants::ConstDims3, + Constants::ConstIsStartoffsetZero3, + Constants::ConstIsContiguous3, + ); } //forces to write the shapes and strides - pub(crate) fn add_layout1_non_contiguous(&mut self, layout: &Layout) { - self.add_layout(layout, false, Constants::ConstDims1, Constants::ConstIsStartoffsetZero1, Constants::ConstIsContiguous1); + pub(crate) fn add_layout1_non_contiguous(&mut self, layout: &Layout) { + self.add_layout( + layout, + false, + Constants::ConstDims1, + Constants::ConstIsStartoffsetZero1, + Constants::ConstIsContiguous1, + ); } - pub(crate) fn add_layout2_non_contiguous(&mut self, layout: &Layout) { - self.add_layout(layout, false, Constants::ConstDims2, Constants::ConstIsStartoffsetZero2, Constants::ConstIsContiguous2); + pub(crate) fn add_layout2_non_contiguous(&mut self, layout: &Layout) { + self.add_layout( + layout, + false, + Constants::ConstDims2, + Constants::ConstIsStartoffsetZero2, + Constants::ConstIsContiguous2, + ); } - pub(crate) fn add_layout3_non_contiguous(&mut self, layout: &Layout) { - self.add_layout(layout, false, Constants::ConstDims3, Constants::ConstIsStartoffsetZero3, Constants::ConstIsContiguous3); + pub(crate) fn add_layout3_non_contiguous(&mut self, layout: &Layout) { + self.add_layout( + layout, + false, + Constants::ConstDims3, + Constants::ConstIsStartoffsetZero3, + Constants::ConstIsContiguous3, + ); } - pub (crate) fn get_pipeline(&mut self, pipeline: Pipelines) -> PipelineType { - let (index, is_new) = self.const_id_map.get_or_insert( &self.const_array); + pub(crate) fn get_pipeline(&mut self, pipeline: Pipelines) -> PipelineType { + let (index, is_new) = self.const_id_map.get_or_insert(&self.const_array); if is_new { - let hmap = HashMap::from_iter(self.const_array.0.iter().map(|(k,v)| (k.get_entry_point().to_owned(), v.to_f64()))); + let hmap = HashMap::from_iter( + self.const_array + .0 + .iter() + .map(|(k, v)| (k.get_entry_point().to_owned(), v.to_f64())), + ); self.id_to_const_array.push(hmap) } self.init(); return PipelineType(pipeline, index, OpIsInplaceable::new()); } - pub (crate) fn get_pipeline_inplaceable(&mut self, pipeline: Pipelines, inplaceable : OpIsInplaceable) -> PipelineType { - let (index, is_new) = self.const_id_map.get_or_insert( &self.const_array); + pub(crate) fn get_pipeline_inplaceable( + &mut self, + pipeline: Pipelines, + inplaceable: OpIsInplaceable, + ) -> PipelineType { + let (index, is_new) = self.const_id_map.get_or_insert(&self.const_array); if is_new { - let hmap = HashMap::from_iter(self.const_array.0.iter().map(|(k,v)| (k.get_entry_point().to_owned(), v.to_f64()))); + let hmap = HashMap::from_iter( + self.const_array + .0 + .iter() + .map(|(k, v)| (k.get_entry_point().to_owned(), v.to_f64())), + ); self.id_to_const_array.push(hmap) } self.init(); return PipelineType(pipeline, index, inplaceable); } - pub (crate) fn get_pipeline_const(&mut self, pipeline: Pipelines, const_vec : Vec) -> PipelineType { - for (index, 
v) in const_vec.into_iter().enumerate(){ - self.const_array.0.push(( candle_wgpu_kernels::Constants::get_const(index), v.to_u32())); + pub(crate) fn get_pipeline_const( + &mut self, + pipeline: Pipelines, + const_vec: Vec, + ) -> PipelineType { + for (index, v) in const_vec.into_iter().enumerate() { + self.const_array + .0 + .push((candle_wgpu_kernels::Constants::get_const(index), v.to_u32())); } - let (index, is_new) = self.const_id_map.get_or_insert( &self.const_array); + let (index, is_new) = self.const_id_map.get_or_insert(&self.const_array); if is_new { - let hmap = HashMap::from_iter(self.const_array.0.iter().map(|(k,v)| (k.get_entry_point().to_owned(), v.to_f64()))); + let hmap = HashMap::from_iter( + self.const_array + .0 + .iter() + .map(|(k, v)| (k.get_entry_point().to_owned(), v.to_f64())), + ); self.id_to_const_array.push(hmap) } self.init(); return PipelineType(pipeline, index, OpIsInplaceable::new()); } - pub (crate) fn get_pipeline_const_inplace(&mut self, pipeline: Pipelines, const_vec : Vec, inplaceable : OpIsInplaceable) -> PipelineType { - for (index, v) in const_vec.into_iter().enumerate(){ - self.const_array.0.push(( candle_wgpu_kernels::Constants::get_const(index), v.to_u32())); + pub(crate) fn get_pipeline_const_inplace( + &mut self, + pipeline: Pipelines, + const_vec: Vec, + inplaceable: OpIsInplaceable, + ) -> PipelineType { + for (index, v) in const_vec.into_iter().enumerate() { + self.const_array + .0 + .push((candle_wgpu_kernels::Constants::get_const(index), v.to_u32())); } - let (index, is_new) = self.const_id_map.get_or_insert( &self.const_array); + let (index, is_new) = self.const_id_map.get_or_insert(&self.const_array); if is_new { - let hmap = HashMap::from_iter(self.const_array.0.iter().map(|(k,v)| (k.get_entry_point().to_owned(), v.to_f64()))); + let hmap = HashMap::from_iter( + self.const_array + .0 + .iter() + .map(|(k, v)| (k.get_entry_point().to_owned(), v.to_f64())), + ); self.id_to_const_array.push(hmap) } self.init(); return PipelineType(pipeline, index, inplaceable); - } + } - pub (crate) fn add(&mut self, value : T){ + pub(crate) fn add(&mut self, value: T) { self.meta_array.add(value); } - pub (crate) fn add_const(&mut self, key : candle_wgpu_kernels::Constants, value : T){ + pub(crate) fn add_const(&mut self, key: candle_wgpu_kernels::Constants, value: T) { self.const_array.insert(key, value); } - + pub fn global_command_index(&self) -> u32 { self.global_command_index } - + pub fn set_global_command_index(&mut self, global_command_index: u32) { self.global_command_index = global_command_index; } - } #[derive(Clone)] -pub enum MatmulAlgorithm{ +pub enum MatmulAlgorithm { MatmulX, Matmul7, Matmul1, @@ -284,10 +378,10 @@ pub enum MatmulAlgorithm{ Matmul1_128(bool, bool, bool), Matmul1_256(bool, bool, bool), Matmul24_24(bool, bool, bool, bool), - Matmul24_48(bool, bool, bool, bool) + Matmul24_48(bool, bool, bool, bool), } -impl fmt::Debug for MatmulAlgorithm{ +impl fmt::Debug for MatmulAlgorithm { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::MatmulX => write!(f, "MatmulX"), @@ -295,57 +389,122 @@ impl fmt::Debug for MatmulAlgorithm{ Self::Matmul1 => write!(f, "Matmul1"), Self::Matmul1_4 => write!(f, "Matmul1_4"), Self::Matmul16_16 => write!(f, "Matmul5_16_16"), - Self::Matmul32_32(prefatch, no_padded, loada, loadb) => write!(f, "Matmul5_32_32({}{}{}{})", if *prefatch {"_Prefetch"} else {""}, if *no_padded {"_NoPadded"} else {""}, if !*loada {"_LoadA"} else {""}, if !*loadb {"_LoadB"} else {""}), - 
Self::Matmul64_64(prefatch, no_padded) => write!(f, "Matuml5_64_64({}{})", if *prefatch {"_Prefetch"} else {""}, if *no_padded {"_NoPadded"} else {""}), - Self::Matmul64_64_8_8(prefatch, no_padded) => write!(f, "Matmul5_64_64_8_8({}{})", if *prefatch {"_Prefetch"} else {""}, if *no_padded {"_NoPadded"} else {""}), - Self::Matmul64_128(prefatch, no_padded) => write!(f, "Matuml5_64_128({}{})", if *prefatch {"_Prefetch"} else {""}, if *no_padded {"_NoPadded"} else {""}), - Self::Matmul64_128_8_8(prefatch, no_padded) => write!(f, "Matmul5_64_128_8_8({}{})", if *prefatch {"_Prefetch"} else {""}, if *no_padded {"_NoPadded"} else {""}), - Self::Matmul128_128(prefatch, no_padded) => write!(f, "Matmul5_128_128({}{})", if *prefatch {"_Prefetch"} else {""}, if *no_padded {"_NoPadded"} else {""}), - Self::Matmul16_64(prefatch, no_padded, loada, loadb) => write!(f, "Matmul5_16_64({}{}{}{})", if *prefatch {"_Prefetch"} else {""}, if *no_padded {"_NoPadded"} else {""}, if !*loada {"_LoadA"} else {""}, if !*loadb {"_LoadB"} else {""}), - Self::Matmul1_128(prefatch, no_padded, loada) => write!(f, "Matmul5_1_128({}{}{})", if *prefatch {"_Prefetch"} else {""}, if *no_padded {"_NoPadded"} else {""}, if !*loada {"_LoadA"} else {""}), - Self::Matmul1_256(prefatch, no_padded, loada) => write!(f, "Matmul5_1_256({}{}{})", if *prefatch {"_Prefetch"} else {""}, if *no_padded {"_NoPadded"} else {""}, if !*loada {"_LoadA"} else {""}), - - Self::Matmul24_24(prefatch, no_padded, loada, loadb) => write!(f, "Matmul5_24_24({}{}{}{})", if *prefatch {"_Prefetch"} else {""}, if *no_padded {"_NoPadded"} else {""}, if !*loada {"_LoadA"} else {""}, if !*loadb {"_LoadB"} else {""}), - Self::Matmul24_48(prefatch, no_padded, loada, loadb) => write!(f, "Matmul5_24_48({}{}{}{})", if *prefatch {"_Prefetch"} else {""}, if *no_padded {"_NoPadded"} else {""}, if !*loada {"_LoadA"} else {""}, if !*loadb {"_LoadB"} else {""}), + Self::Matmul32_32(prefatch, no_padded, loada, loadb) => write!( + f, + "Matmul5_32_32({}{}{}{})", + if *prefatch { "_Prefetch" } else { "" }, + if *no_padded { "_NoPadded" } else { "" }, + if !*loada { "_LoadA" } else { "" }, + if !*loadb { "_LoadB" } else { "" } + ), + Self::Matmul64_64(prefatch, no_padded) => write!( + f, + "Matmul5_64_64({}{})", + if *prefatch { "_Prefetch" } else { "" }, + if *no_padded { "_NoPadded" } else { "" } + ), + Self::Matmul64_64_8_8(prefatch, no_padded) => write!( + f, + "Matmul5_64_64_8_8({}{})", + if *prefatch { "_Prefetch" } else { "" }, + if *no_padded { "_NoPadded" } else { "" } + ), + Self::Matmul64_128(prefatch, no_padded) => write!( + f, + "Matmul5_64_128({}{})", + if *prefatch { "_Prefetch" } else { "" }, + if *no_padded { "_NoPadded" } else { "" } + ), + Self::Matmul64_128_8_8(prefatch, no_padded) => write!( + f, + "Matmul5_64_128_8_8({}{})", + if *prefatch { "_Prefetch" } else { "" }, + if *no_padded { "_NoPadded" } else { "" } + ), + Self::Matmul128_128(prefatch, no_padded) => write!( + f, + "Matmul5_128_128({}{})", + if *prefatch { "_Prefetch" } else { "" }, + if *no_padded { "_NoPadded" } else { "" } + ), + Self::Matmul16_64(prefatch, no_padded, loada, loadb) => write!( + f, + "Matmul5_16_64({}{}{}{})", + if *prefatch { "_Prefetch" } else { "" }, + if *no_padded { "_NoPadded" } else { "" }, + if !*loada { "_LoadA" } else { "" }, + if !*loadb { "_LoadB" } else { "" } + ), + Self::Matmul1_128(prefatch, no_padded, loada) => write!( + f, + "Matmul5_1_128({}{}{})", + if *prefatch { "_Prefetch" } else { "" }, + if *no_padded { "_NoPadded" } else { "" }, + if !*loada { "_LoadA"
} else { "" } + ), + Self::Matmul1_256(prefatch, no_padded, loada) => write!( + f, + "Matmul5_1_256({}{}{})", + if *prefatch { "_Prefetch" } else { "" }, + if *no_padded { "_NoPadded" } else { "" }, + if !*loada { "_LoadA" } else { "" } + ), + + Self::Matmul24_24(prefatch, no_padded, loada, loadb) => write!( + f, + "Matmul5_24_24({}{}{}{})", + if *prefatch { "_Prefetch" } else { "" }, + if *no_padded { "_NoPadded" } else { "" }, + if !*loada { "_LoadA" } else { "" }, + if !*loadb { "_LoadB" } else { "" } + ), + Self::Matmul24_48(prefatch, no_padded, loada, loadb) => write!( + f, + "Matmul5_24_48({}{}{}{})", + if *prefatch { "_Prefetch" } else { "" }, + if *no_padded { "_NoPadded" } else { "" }, + if !*loada { "_LoadA" } else { "" }, + if !*loadb { "_LoadB" } else { "" } + ), } } } #[derive(Debug)] -pub struct WgpuDeviceInner{ - pub device : wgpu::Device, - pub device_limits : wgpu::Limits, //we cache the limits here, because device.limit() was relatively slow on the browser +pub struct WgpuDeviceInner { + pub device: wgpu::Device, + pub device_limits: wgpu::Limits, //we cache the limits here, because device.limit() was relatively slow on the browser - pub queue : wgpu::Queue, - pub (crate) shader : Mutex>, - pub (crate) rand_state : Mutex, + pub queue: wgpu::Queue, + pub(crate) shader: Mutex>, + pub(crate) rand_state: Mutex, - pub (crate) command_queue : Mutex, - pub (crate) meta_buffer : wgpu::Buffer, //buffer for storing meta information + pub(crate) command_queue: Mutex, + pub(crate) meta_buffer: wgpu::Buffer, //buffer for storing meta information - pub (crate) bindgroup_layouts : BindgroupLayouts, + pub(crate) bindgroup_layouts: BindgroupLayouts, - pub (crate) staging_probe_buffer : wgpu::Buffer, //wait for submission is not supported on wgpu, we use a mapping to a staging buffer as a work around. + pub(crate) staging_probe_buffer: wgpu::Buffer, //wait for submission is not supported on wgpu, we use a mapping to a staging buffer as a work around. 
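// The field above relies on a small buffer-mapping trick. A minimal sketch of
// that probe pattern, assuming a `device` and a MAP_READ staging buffer like
// `staging_probe_buffer` (names and the exact call site are illustrative, not
// this crate's actual API):
//
// fn wait_for_submitted_work(device: &wgpu::Device, probe: &wgpu::Buffer) {
//     let (tx, rx) = std::sync::mpsc::channel();
//     // map_async completes only after previously submitted GPU work is done,
//     // so mapping a tiny dummy buffer doubles as a completion signal.
//     probe.slice(..).map_async(wgpu::MapMode::Read, move |res| {
//         tx.send(res).ok();
//     });
//     let _ = device.poll(wgpu::Maintain::Wait); // drive callbacks on native
//     rx.recv().unwrap().unwrap();
//     probe.unmap();
// }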
- pub (crate) cache : Mutex, //if cache is set, all commands are not queued to the gpu, but are cached inside ModelCache, so there can be reused later on + pub(crate) cache: Mutex, //if cache is set, all commands are not queued to the gpu, but are cached inside ModelCache, so there can be reused later on //debug counter - pub (crate) unary_inplace_counter : Counter, - pub (crate) binary_inplace_counter : Counter, - pub (crate) copy_inplace_counter : Counter, + pub(crate) unary_inplace_counter: Counter, + pub(crate) binary_inplace_counter: Counter, + pub(crate) copy_inplace_counter: Counter, #[cfg(feature = "wgpu_debug")] - pub debug : DebugInfo, + pub debug: DebugInfo, - pub configuration : WgpuDeviceConfig, + pub configuration: WgpuDeviceConfig, - pub matmul_alg : Mutex + pub matmul_alg: Mutex, } #[derive(Debug, Clone)] -pub struct WgpuDevice { - pub inner : Arc, +pub struct WgpuDevice { + pub inner: Arc, } -impl std::ops::Deref for WgpuDevice{ +impl std::ops::Deref for WgpuDevice { type Target = WgpuDeviceInner; fn deref(&self) -> &Self::Target { @@ -353,30 +512,56 @@ impl std::ops::Deref for WgpuDevice{ } } -impl WgpuDevice{ - pub (crate) async fn create(_: usize, configuration : WgpuDeviceConfig) -> crate::Result{ - let instance = wgpu::Instance::new(InstanceDescriptor{ backends: Backends::PRIMARY, flags:InstanceFlags::default() , dx12_shader_compiler: wgpu::Dx12Compiler::Fxc, gles_minor_version: wgpu::Gles3MinorVersion::Automatic }); - +impl WgpuDevice { + pub(crate) async fn create(_: usize, configuration: WgpuDeviceConfig) -> crate::Result { + // let instance = wgpu::Instance::default();//::new(InstanceDescriptor{ backends: Backends::PRIMARY, flags:InstanceFlags::default() , dx12_shader_compiler: wgpu::Dx12Compiler::Fxc, gles_minor_version: wgpu::Gles3MinorVersion::Automatic }); + // let instance = wgpu::Instance::new(InstanceDescriptor { + // backends: Backends::VULKAN, + // flags: InstanceFlags::default(), + // dx12_shader_compiler: wgpu::Dx12Compiler::default(), + // gles_minor_version: wgpu::Gles3MinorVersion::Automatic, + // }); + let instance = wgpu::Instance::new(wgpu::InstanceDescriptor { + backends: wgpu::Backends::VULKAN, + flags: wgpu::InstanceFlags::VALIDATION, + dx12_shader_compiler: wgpu::Dx12Compiler::Fxc, + gles_minor_version: wgpu::Gles3MinorVersion::default(), + }); + // `request_adapter` instantiates the general connection to the GPU + // let adapter = instance + // .request_adapter(&wgpu::RequestAdapterOptions { + // power_preference: wgpu::PowerPreference::LowPower, + // force_fallback_adapter: false, + // compatible_surface: None, + // }) + // .await + // .unwrap(); let adapter = instance - .request_adapter(&wgpu::RequestAdapterOptions{ power_preference: wgpu::PowerPreference::HighPerformance, force_fallback_adapter: false, compatible_surface: None }).await.unwrap(); + .request_adapter(&wgpu::RequestAdapterOptions::default()) + .await + .unwrap(); let mut limits = wgpu::Limits::downlevel_defaults(); - #[cfg(feature = "wgpu_debug")] - let features = wgpu::Features::TIMESTAMP_QUERY | wgpu::Features::TIMESTAMP_QUERY_INSIDE_PASSES | wgpu::Features::TIMESTAMP_QUERY_INSIDE_ENCODERS; - #[cfg(not(feature = "wgpu_debug"))] - let features = wgpu::Features::empty(); + // #[cfg(feature = "wgpu_debug")] + let features = wgpu::Features::SHADER_INT64; + // | wgpu::Features::TIMESTAMP_QUERY + // | wgpu::Features::TIMESTAMP_QUERY_INSIDE_PASSES + // | wgpu::Features::TIMESTAMP_QUERY_INSIDE_ENCODERS; + // #[cfg(not(feature = "wgpu_debug"))] + // let features = 
wgpu::Features::empty(); let adatper_limits = adapter.limits(); - limits.min_storage_buffer_offset_alignment = adatper_limits.min_storage_buffer_offset_alignment; + limits.min_storage_buffer_offset_alignment = + adatper_limits.min_storage_buffer_offset_alignment; limits.max_storage_buffers_per_shader_stage = 5; limits.max_storage_buffer_binding_size = adatper_limits.max_storage_buffer_binding_size; //use as much as possible limits.max_buffer_size = adatper_limits.max_buffer_size; //use as much as possible - + // `request_device` instantiates the feature specific connection to the GPU, defining some parameters, // `features` being the available features. - log::debug !("Request Device"); + log::debug!("Request Device"); log::debug!("Features: {:?}", features); log::debug!("Limits: {:?}", limits); let (device, queue) = adapter @@ -385,15 +570,17 @@ impl WgpuDevice{ label: None, required_features: features, required_limits: limits, - memory_hints : wgpu::MemoryHints::Performance + memory_hints: wgpu::MemoryHints::MemoryUsage, }, None, - ).await.map_err(|err| crate::Error::WebGpu(err.to_string().into()))?; - log::info!("Device Requested"); - + ) + .await + .map_err(|err| crate::Error::WebGpu(err.to_string().into()))?; + log::info!("Device Requested"); + #[cfg(feature = "wgpu_debug")] let debug_info = super::debug_info::DebugInfo::new(&device); - + let meta_buffer = device.create_buffer(&wgpu::BufferDescriptor { label: None, size: configuration.meta_buffer_size as u64, @@ -403,7 +590,7 @@ impl WgpuDevice{ let device_limits = device.limits(); let bindgroup_layouts = BindgroupLayouts::new(&device); - + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { label: None, size: 16, @@ -412,61 +599,91 @@ impl WgpuDevice{ }); Ok(WgpuDevice { - inner : Arc::new(WgpuDeviceInner{ + inner: Arc::new(WgpuDeviceInner { device: device, device_limits: device_limits, queue: queue, - shader : Mutex::new(HashMap::new()), + shader: Mutex::new(HashMap::new()), rand_state: Mutex::new(rand::rngs::StdRng::from_entropy()), #[cfg(feature = "wgpu_debug")] - debug : debug_info, + debug: debug_info, command_queue: Mutex::new(QueueBuffer::new(configuration.meta_buffer_size)), - meta_buffer : meta_buffer, - cache : Mutex::new(ModelCache::new(configuration.buffer_mapping_size)), + meta_buffer: meta_buffer, + cache: Mutex::new(ModelCache::new(configuration.buffer_mapping_size)), bindgroup_layouts, - staging_probe_buffer : staging_buffer, - unary_inplace_counter : Counter::new(0), - binary_inplace_counter : Counter::new(0), - copy_inplace_counter : Counter::new(0), - matmul_alg : Mutex::new(MatmulAlgorithm::MatmulX), - configuration : configuration - }) + staging_probe_buffer: staging_buffer, + unary_inplace_counter: Counter::new(0), + binary_inplace_counter: Counter::new(0), + copy_inplace_counter: Counter::new(0), + matmul_alg: Mutex::new(MatmulAlgorithm::MatmulX), + configuration: configuration, + }), }) } - pub fn flush_gpu_command(&self) -> crate::Result<()>{ + pub fn flush_gpu_command(&self) -> crate::Result<()> { let mut queue = self.command_queue.lock().unwrap(); wgpu_functions::flush_gpu_command(self, &mut queue) } - pub fn print_bindgroup_reuseinfo(&self){ + pub fn print_bindgroup_reuseinfo(&self) { let cache = self.cache.lock().unwrap(); - log::warn!("Buffer: created: {}, resued : {}", cache.buffers.buffer_counter(), cache.buffers.buffer_reuse_counter()); - log::warn!("Bindgroup: created: {}, resued : {}", cache.bindgroups.bindgroup_counter(), cache.bindgroups.cached_bindgroup_use_counter()); - 
log::warn!("Inplace used: unary: {}, binary {}, copy: {}", self.unary_inplace_counter.get(), self.binary_inplace_counter.get(), self.copy_inplace_counter.get()); + log::warn!( + "Buffer: created: {}, reused: {}", + cache.buffers.buffer_counter(), + cache.buffers.buffer_reuse_counter() + ); + log::warn!( + "Bindgroup: created: {}, reused: {}", + cache.bindgroups.bindgroup_counter(), + cache.bindgroups.cached_bindgroup_use_counter() + ); + log::warn!( + "Inplace used: unary: {}, binary: {}, copy: {}", + self.unary_inplace_counter.get(), + self.binary_inplace_counter.get(), + self.copy_inplace_counter.get() + ); } - pub fn print_bindgroup_reuseinfo2(&self){ + pub fn print_bindgroup_reuseinfo2(&self) { let cache = self.cache.lock().unwrap(); - - println!("Buffer: created: {}, resued : {}", cache.buffers.buffer_counter(), cache.buffers.buffer_reuse_counter()); - println!("Bindgroup: created: {}, resued : {}", cache.bindgroups.bindgroup_counter(), cache.bindgroups.cached_bindgroup_use_counter()); - println!("Inplace used: unary: {}, binary {}, copy: {}", self.unary_inplace_counter.get(), self.binary_inplace_counter.get(), self.copy_inplace_counter.get()); + + println!( + "Buffer: created: {}, reused: {}", + cache.buffers.buffer_counter(), + cache.buffers.buffer_reuse_counter() + ); + println!( + "Bindgroup: created: {}, reused: {}", + cache.bindgroups.bindgroup_counter(), + cache.bindgroups.cached_bindgroup_use_counter() + ); + println!( + "Inplace used: unary: {}, binary: {}, copy: {}", + self.unary_inplace_counter.get(), + self.binary_inplace_counter.get(), + self.copy_inplace_counter.get() + ); } #[cfg(feature = "wgpu_debug")] - pub async fn get_debug_info_full(&self) -> crate::Result{ + pub async fn get_debug_info_full(&self) -> crate::Result { use super::wgpu_functions::synchronize_async; synchronize_async(self).await?; - let data = wgpu_functions::read_data_from_gpu_async_buffer::(self, &self.debug.query_set_buffer).await; - + let data = wgpu_functions::read_data_from_gpu_async_buffer::( + self, + &self.debug.query_set_buffer, + ) + .await; + let period = self.queue.get_timestamp_period(); let mut result = Measurements::new(period); let mut last_end_time = 0u64; let mut i = 0; let mut shader_pipeline2 = self.debug.shader_pipeline.lock().unwrap(); let shader_pipeline = shader_pipeline2.clone(); - let mut indexes : Vec<_> = shader_pipeline.into_iter().collect(); + let mut indexes: Vec<_> = shader_pipeline.into_iter().collect(); indexes.sort_by_key(|f| f.0); for p in indexes { let start_time = data[(p.0 / 8) as usize]; @@ -488,204 +705,313 @@ impl WgpuDevice{ } last_end_time = end_time; i += 1; - result.data.push(MInfo::new(p.1.0.to_owned(), start_time, end_time, p.1.1, p.1.2, p.1.3, p.1.4)); + result.data.push(MInfo::new( + p.1 .0.to_owned(), + start_time, + end_time, + p.1 .1, + p.1 .2, + p.1 .3, + p.1 .4, + )); } - self.debug.counter.store(0u32, std::sync::atomic::Ordering::Relaxed); + self.debug + .counter + .store(0u32, std::sync::atomic::Ordering::Relaxed); shader_pipeline2.clear(); Ok(result) } #[cfg(feature = "wgpu_debug")] - pub async fn get_debug_info(&self) -> crate::Result>>{ + pub async fn get_debug_info( + &self, + ) -> crate::Result>> { let info = self.get_debug_info_full().await?; - let mut map: std::collections::HashMap> = std::collections::HashMap::new(); + let mut map: std::collections::HashMap> = + std::collections::HashMap::new(); for item in info.data.iter() { - map.entry(item.label.clone()).or_insert_with(Vec::new).push((item.end_time - item.start_time,
item.output_size, item.x, item.y, item.z)); + map.entry(item.label.clone()) + .or_insert_with(Vec::new) + .push(( + item.end_time - item.start_time, + item.output_size, + item.x, + item.y, + item.z, + )); } return Ok(map); } #[cfg(feature = "wgpu_debug")] - pub fn get_pipeline_info(&self) -> crate::Result>{ + pub fn get_pipeline_info(&self) -> crate::Result> { use super::debug_info; let shaders = self.shader.lock().unwrap(); let queue = self.command_queue.lock().unwrap(); - - return Ok(shaders.iter().map(|(k, v)|{ - let pipelines = v.pipelines.lock().unwrap(); - let s = debug_info::ShaderInfo{ - name: format!("{:?}", k).to_owned(), - pipelines: pipelines.iter().map(|(pk, _)|{ - return debug_info::PipelineInfo { - name: format!("{:?}", pk.0).to_owned(), - consts : queue.id_to_const_array[pk.1].clone() - } - }).collect() - - }; - return s; - } - ).collect()); + + return Ok(shaders + .iter() + .map(|(k, v)| { + let pipelines = v.pipelines.lock().unwrap(); + let s = debug_info::ShaderInfo { + name: format!("{:?}", k).to_owned(), + pipelines: pipelines + .iter() + .map(|(pk, _)| { + return debug_info::PipelineInfo { + name: format!("{:?}", pk.0).to_owned(), + consts: queue.id_to_const_array[pk.1].clone(), + }; + }) + .collect(), + }; + return s; + }) + .collect()); } #[instrument] - fn load_pipeline(device : &wgpu::Device, shader : Arc, pipeline : &PipelineType, pipeline_layout : &wgpu::PipelineLayout, consts : &HashMap) -> wgpu::ComputePipeline{ + fn load_pipeline( + device: &wgpu::Device, + shader: Arc, + pipeline: &PipelineType, + pipeline_layout: &wgpu::PipelineLayout, + consts: &HashMap, + ) -> wgpu::ComputePipeline { let entry_point = pipeline.0.get_entry_point(); - if consts.is_empty(){ - return device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + if consts.is_empty() { + return device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { label: None, layout: Some(pipeline_layout), module: &shader, entry_point: entry_point, - compilation_options : wgpu::PipelineCompilationOptions::default(), - cache : None + compilation_options: wgpu::PipelineCompilationOptions::default(), + cache: None, }); - } - else{ - return device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + } else { + return device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { label: None, layout: Some(pipeline_layout), module: &shader, entry_point: entry_point, - compilation_options : wgpu::PipelineCompilationOptions{ + compilation_options: wgpu::PipelineCompilationOptions { constants: &consts, zero_initialize_workgroup_memory: true, vertex_pulling_transform: false, }, - cache : None + cache: None, }); } } #[instrument] - pub (crate) fn get_pipeline(&self, pipeline: &PipelineType, pipeline_layout : &wgpu::PipelineLayout, consts : &HashMap) -> crate::Result> { + pub(crate) fn get_pipeline( + &self, + pipeline: &PipelineType, + pipeline_layout: &wgpu::PipelineLayout, + consts: &HashMap, + ) -> crate::Result> { let shader = pipeline.0.get_shader(); let mut shaders = self.shader.lock().unwrap(); - if !shaders.contains_key(&shader){ + if !shaders.contains_key(&shader) { let s = wgpu_functions::get_shader(&self.device, shader.load_shader()); - shaders.insert(shader.clone(), ShaderModuleComputePipelines{ shader: Arc::new(s), pipelines: Mutex::new(HashMap::new())}); + shaders.insert( + shader.clone(), + ShaderModuleComputePipelines { + shader: Arc::new(s), + pipelines: Mutex::new(HashMap::new()), + }, + ); } - - if let Some(s) = shaders.get(&shader){ + + if let Some(s) = 
shaders.get(&shader) { let mut pipelines = s.pipelines.lock().unwrap(); - if !pipelines.contains_key(&pipeline){ - let p = crate::WgpuDevice::load_pipeline(&self.device, s.shader.clone(), pipeline ,pipeline_layout, consts); + if !pipelines.contains_key(&pipeline) { + let p = crate::WgpuDevice::load_pipeline( + &self.device, + s.shader.clone(), + pipeline, + pipeline_layout, + consts, + ); pipelines.insert(pipeline.clone(), Arc::new(p)); } - - if let Some(p) = pipelines.get(&pipeline){ + + if let Some(p) = pipelines.get(&pipeline) { return Ok(p.clone()); - } - else{ + } else { panic!("Not expected") } - } - else{ + } else { panic!("Not expected") } } - pub (crate) async fn synchronize_async(&self) -> crate::Result<()> { + pub(crate) async fn synchronize_async(&self) -> crate::Result<()> { wgpu_functions::synchronize_async(self).await } } - -impl crate::backend::BackendDevice for WgpuDevice{ +impl crate::backend::BackendDevice for WgpuDevice { type Storage = WgpuStorage; fn new(_: usize) -> crate::Result { - return Err(crate::Error::WebGpu("A WgpuDevice must be created using the asynchronous create method".to_owned().into())); + return Err(crate::Error::WebGpu( + "A WgpuDevice must be created using the asynchronous create method" + .to_owned() + .into(), + )); } fn location(&self) -> crate::DeviceLocation { - return crate::DeviceLocation::Wgpu { gpu_id: 0 }; + return crate::DeviceLocation::Wgpu { gpu_id: 0 }; } fn same_device(&self, other: &Self) -> bool { return self.device.global_id() == other.device.global_id(); } - fn zeros_impl(&self, shape: &crate::Shape, dtype: crate::DType) -> crate::Result { + fn zeros_impl( + &self, + shape: &crate::Shape, + dtype: crate::DType, + ) -> crate::Result { let buffer = create_wgpu_storage(self, dtype, shape.elem_count() * 4); - if shape.elem_count() > 0{ - wgpu_functions::queue_unary_inplace_op(self, buffer.buffer.clone(), UnaryOperation::SetZero, 0.0, 0.0,dtype, &Layout::contiguous(shape))?; + if shape.elem_count() > 0 { + wgpu_functions::queue_unary_inplace_op( + self, + buffer.buffer.clone(), + UnaryOperation::SetZero, + 0.0, + 0.0, + dtype, + &Layout::contiguous(shape), + )?; } - + return Ok(buffer); } fn ones_impl(&self, shape: &crate::Shape, dtype: crate::DType) -> crate::Result { let buffer = create_wgpu_storage(self, dtype, shape.elem_count() * 4); - - if shape.elem_count() > 0{ - wgpu_functions::queue_unary_inplace_op(self, buffer.buffer.clone(), UnaryOperation::SetOne, 0.0, 0.0,dtype,&Layout::contiguous(shape))?; + + if shape.elem_count() > 0 { + wgpu_functions::queue_unary_inplace_op( + self, + buffer.buffer.clone(), + UnaryOperation::SetOne, + 0.0, + 0.0, + dtype, + &Layout::contiguous(shape), + )?; } return Ok(buffer); } - unsafe fn alloc_uninit(&self, shape: &crate::Shape, dtype: crate::DType) -> crate::Result { - if dtype == crate::DType::F32 || dtype == crate::DType::U32{ + unsafe fn alloc_uninit( + &self, + shape: &crate::Shape, + dtype: crate::DType, + ) -> crate::Result { + if dtype == crate::DType::F32 || dtype == crate::DType::U32 { return Ok(create_wgpu_storage(self, dtype, shape.elem_count() * 4)); - } - else{ + } else { wrongType!(alloc_uninit, dtype); } } - fn storage_from_slice(&self, data : &[T]) -> crate::Result { + fn storage_from_slice(&self, data: &[T]) -> crate::Result { let buffer; - if T::DTYPE == crate::DType::F32{ - let data = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, data.len()) }; - buffer = create_wgpu_storage_init(self,T::DTYPE, &data)?; - } - else if T::DTYPE == crate::DType::U32{ - let 
data = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u32, data.len()) }; - buffer = create_wgpu_storage_init(self,T::DTYPE, &data)?; - } - else{ + if T::DTYPE == crate::DType::F32 { + let data = + unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, data.len()) }; + buffer = create_wgpu_storage_init(self, T::DTYPE, &data)?; + } else if T::DTYPE == crate::DType::U32 { + let data = + unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u32, data.len()) }; + buffer = create_wgpu_storage_init(self, T::DTYPE, &data)?; + } else { // Panic if T is not f32 or u32 wrongType!(storage_from_slice, T::DTYPE); } return Ok(buffer); } - fn storage_from_cpu_storage(&self, storage: &crate::CpuStorage) -> crate::Result { - match storage{ + fn storage_from_cpu_storage( + &self, + storage: &crate::CpuStorage, + ) -> crate::Result { + match storage { crate::CpuStorage::F32(data) => { return create_wgpu_storage_init(self, crate::DType::F32, data); - }, + } crate::CpuStorage::U32(data) => { - return create_wgpu_storage_init(self, crate::DType::U32, data); - }, - _ => wrongType!(storage_from_cpu_storage, storage.dtype()), + return create_wgpu_storage_init(self, crate::DType::U32, data); + } + _ => wrongType!(storage_from_cpu_storage, storage.dtype()), } } - fn storage_from_cpu_storage_owned(&self, storage: crate::CpuStorage) -> crate::Result { - match storage{ + fn storage_from_cpu_storage_owned( + &self, + storage: crate::CpuStorage, + ) -> crate::Result { + match storage { crate::CpuStorage::F32(data) => { - return create_wgpu_storage_init(self,crate::DType::F32, &data); - }, + return create_wgpu_storage_init(self, crate::DType::F32, &data); + } crate::CpuStorage::U32(data) => { - return create_wgpu_storage_init(self,crate::DType::U32, &data); - }, - _ => wrongType!(storage_from_cpu_storage_owned, storage.dtype()), + return create_wgpu_storage_init(self, crate::DType::U32, &data); + } + // i64 + crate::CpuStorage::I64(data) => { + return create_wgpu_storage_init(self, crate::DType::I64, &data); + } + _ => wrongType!(storage_from_cpu_storage_owned, storage.dtype()), } } - fn rand_uniform(&self, shape: &crate::Shape, dtype: crate::DType, lo: f64, up: f64) -> crate::Result { + fn rand_uniform( + &self, + shape: &crate::Shape, + dtype: crate::DType, + lo: f64, + up: f64, + ) -> crate::Result { let buffer = create_wgpu_storage(self, dtype, shape.elem_count() * 4); - wgpu_functions::queue_unary_inplace_op(self, buffer.buffer.clone(), UnaryOperation::RandUniform, lo as f32, up as f32,dtype,&Layout::contiguous(shape))?; + wgpu_functions::queue_unary_inplace_op( + self, + buffer.buffer.clone(), + UnaryOperation::RandUniform, + lo as f32, + up as f32, + dtype, + &Layout::contiguous(shape), + )?; return Ok(buffer); } - fn rand_normal(&self, shape: &crate::Shape, dtype: crate::DType, mean: f64, std: f64) -> crate::Result { + fn rand_normal( + &self, + shape: &crate::Shape, + dtype: crate::DType, + mean: f64, + std: f64, + ) -> crate::Result { let buffer = create_wgpu_storage(self, dtype, shape.elem_count() * 4); - wgpu_functions::queue_unary_inplace_op(self, buffer.buffer.clone(), UnaryOperation::RandNormal, mean as f32, std as f32, dtype,&Layout::contiguous(shape))?; + wgpu_functions::queue_unary_inplace_op( + self, + buffer.buffer.clone(), + UnaryOperation::RandNormal, + mean as f32, + std as f32, + dtype, + &Layout::contiguous(shape), + )?; return Ok(buffer); } @@ -693,7 +1019,6 @@ impl crate::backend::BackendDevice for WgpuDevice{ notImplemented!(set_seed) } - #[cfg(target_arch = 
"wasm32")] fn synchronize(&self) -> crate::Result<()> { panic!("Synchronize is not possible on wasm. (on_submitted_work_done is currently not implemented in wgpu). In addition synchronize can only be handled async"); diff --git a/candle-core/src/wgpu_backend/storage.rs b/candle-core/src/wgpu_backend/storage.rs index a52715880a..464474b605 100644 --- a/candle-core/src/wgpu_backend/storage.rs +++ b/candle-core/src/wgpu_backend/storage.rs @@ -1,49 +1,67 @@ use crate::{backend::BackendStorage, DType, Layout, Shape}; use super::{ - cache::BufferReferenceId, device::WgpuDevice, util::ToU64, wgpu_functions::{self, binary::BinaryOperation, cmp::CmpOperation, matmul::SGEMMParams, read_data_from_gpu_async, reduce::ReduceOperations, unary::UnaryOperation} + cache::BufferReferenceId, + device::WgpuDevice, + util::ToU64, + wgpu_functions::{ + self, binary::BinaryOperation, cmp::CmpOperation, matmul::SGEMMParams, + read_data_from_gpu_async, reduce::ReduceOperations, unary::UnaryOperation, + }, }; #[derive(Debug)] pub struct WgpuStorage { pub buffer: BufferReferenceId, - pub size : u64, + pub size: u64, pub wgpu_device: WgpuDevice, pub dtype: crate::DType, } -pub fn create_wgpu_storage(dev : &WgpuDevice, dtype : crate::DType, size : T) -> WgpuStorage{ +pub fn create_wgpu_storage( + dev: &WgpuDevice, + dtype: crate::DType, + size: T, +) -> WgpuStorage { let size = size.to_u64(); let buffer; { let mut cache = dev.cache.lock().unwrap(); buffer = cache.create_buffer_reference(size, true); - } - return WgpuStorage::new(buffer,dev.clone(), dtype, size); + } + return WgpuStorage::new(buffer, dev.clone(), dtype, size); } -pub fn create_wgpu_storage_init(dev : &WgpuDevice, dtype : crate::DType, data : &[T]) -> crate::Result{ - let data : &[u8] = bytemuck::cast_slice(data); +pub fn create_wgpu_storage_init( + dev: &WgpuDevice, + dtype: crate::DType, + data: &[T], +) -> crate::Result { + let data: &[u8] = bytemuck::cast_slice(data); let size = data.len(); let buffer; { - if dev.configuration.flush_gpu_before_buffer_init{ + if dev.configuration.flush_gpu_before_buffer_init { dev.flush_gpu_command()?; } let mut cache = dev.cache.lock().unwrap(); - buffer = cache.create_buffer_reference_init(dev, data,true); - } - return Ok(WgpuStorage::new(buffer,dev.clone(), dtype, size as u64)); + buffer = cache.create_buffer_reference_init(dev, data, true); + } + return Ok(WgpuStorage::new(buffer, dev.clone(), dtype, size as u64)); } - impl WgpuStorage { - pub fn new(buffer: BufferReferenceId, wgpu_device: WgpuDevice, dtype: crate::DType, size : u64) -> Self { + pub fn new( + buffer: BufferReferenceId, + wgpu_device: WgpuDevice, + dtype: crate::DType, + size: u64, + ) -> Self { Self { buffer, wgpu_device, dtype: dtype, - size + size, } } @@ -60,21 +78,27 @@ impl WgpuStorage { )) } crate::DType::U8 => { - return Ok( - crate::CpuStorage::U8(read_data_from_gpu_async(&self.wgpu_device, self.buffer.clone()).await?) 
- ) + return Ok(crate::CpuStorage::U8( + read_data_from_gpu_async(&self.wgpu_device, self.buffer.clone()).await?, + )) + } + // i64 + crate::DType::I64 => { + return Ok(crate::CpuStorage::I64( + read_data_from_gpu_async(&self.wgpu_device, self.buffer.clone()).await?, + )) } _ => todo!(), } } pub fn get_length(&self) -> usize { - return (self.size / 4) as usize; //f32 } fn try_clone_layout(&self, layout: &crate::Layout) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, layout.shape().elem_count() * 4); + let buffer_dest = + create_wgpu_storage(self.device(), self.dtype, layout.shape().elem_count() * 4); self.copy_strided_src(&buffer_dest, 0, layout)?; return Ok(buffer_dest); } @@ -96,16 +120,22 @@ impl WgpuStorage { dst_offset, start, to_copy, - self.dtype + self.dtype, )?; } None => { - wgpu_functions::queue_copy_strided(self.device(), dst.buffer, self.buffer.clone(), self.dtype, src_l, dst_offset as u32)?; + wgpu_functions::queue_copy_strided( + self.device(), + dst.buffer, + self.buffer.clone(), + self.dtype, + src_l, + dst_offset as u32, + )?; } } return Ok(()); } - } impl crate::backend::BackendStorage for WgpuStorage { @@ -120,7 +150,7 @@ impl crate::backend::BackendStorage for WgpuStorage { 0, 0, (self.size / 4) as usize, - self.dtype + self.dtype, )?; return Ok(buffer_dest); @@ -136,7 +166,8 @@ impl crate::backend::BackendStorage for WgpuStorage { #[cfg(target_arch = "wasm32")] fn to_cpu_storage(&self) -> crate::Result { - panic!("Sync copy to CpuStorage is not allowed for WebGpu device in WebAssembly. First copy the date asynchronously to a CpuStorage"); //panic, so we get a stacktrace and see where we wanted to copy + panic!("Sync copy to CpuStorage is not allowed for WebGpu device in WebAssembly. First copy the data asynchronously to a CpuStorage"); + //panic, so we get a stacktrace and see where we wanted to copy //return Err(crate::Error::WebGpu("Sync copy to CpuStorage is not allowed for WebGpu device in WebAssembly.
First copy the date asynchronously to a CpuStorage".to_owned().into())); } @@ -146,7 +177,8 @@ impl crate::backend::BackendStorage for WgpuStorage { } fn affine(&self, layout: &crate::Layout, mul: f64, add: f64) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, layout.shape().elem_count() * 4); + let buffer_dest = + create_wgpu_storage(self.device(), self.dtype, layout.shape().elem_count() * 4); wgpu_functions::queue_unary_from_buffer_op( self.device(), buffer_dest.buffer.clone(), @@ -161,7 +193,8 @@ impl crate::backend::BackendStorage for WgpuStorage { } fn powf(&self, layout: &crate::Layout, e: f64) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, layout.shape().elem_count() * 4); + let buffer_dest = + create_wgpu_storage(self.device(), self.dtype, layout.shape().elem_count() * 4); wgpu_functions::queue_unary_from_buffer_op( self.device(), buffer_dest.buffer.clone(), @@ -176,7 +209,8 @@ impl crate::backend::BackendStorage for WgpuStorage { } fn elu(&self, layout: &crate::Layout, alpha: f64) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, layout.shape().elem_count() * 4); + let buffer_dest = + create_wgpu_storage(self.device(), self.dtype, layout.shape().elem_count() * 4); wgpu_functions::queue_unary_from_buffer_op( self.device(), buffer_dest.buffer.clone(), @@ -219,7 +253,8 @@ impl crate::backend::BackendStorage for WgpuStorage { strides.reverse(); strides } - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, dst_shape.elem_count() * 4); + let buffer_dest = + create_wgpu_storage(self.device(), self.dtype, dst_shape.elem_count() * 4); let op = match reduce_op { crate::op::ReduceOp::Sum => ReduceOperations::Sum, @@ -302,7 +337,6 @@ impl crate::backend::BackendStorage for WgpuStorage { } let output_count = current_shape.iter().fold(1, |prev, c| prev * c); - let mut cache = self.device().cache.lock().unwrap(); let buffer_temp = cache.create_buffer_reference(output_count * 4, false); @@ -369,7 +403,7 @@ impl crate::backend::BackendStorage for WgpuStorage { ) -> crate::Result { let buffer_size = ((lhs_l.shape().elem_count() + 3) / 4) * 4; //TODO: get next divisible by 4 let buffer_dest = create_wgpu_storage(self.device(), crate::DType::U8, buffer_size); - + let op2 = match op { crate::op::CmpOp::Eq => CmpOperation::Eq, crate::op::CmpOp::Ne => CmpOperation::Ne, @@ -397,7 +431,8 @@ impl crate::backend::BackendStorage for WgpuStorage { (DType::F32, DType::F32) => self.try_clone_layout(layout), (DType::U32, DType::U32) => self.try_clone_layout(layout), (DType::U32, DType::F32) => { - let buffer_dest = create_wgpu_storage(self.device(), DType::F32, layout.shape().elem_count() * 4); + let buffer_dest = + create_wgpu_storage(self.device(), DType::F32, layout.shape().elem_count() * 4); wgpu_functions::queue_convert_u32_to_f32( self.device(), buffer_dest.buffer.clone(), @@ -407,7 +442,8 @@ impl crate::backend::BackendStorage for WgpuStorage { Ok(buffer_dest) } (DType::U8, DType::F32) => { - let buffer_dest = create_wgpu_storage(self.device(), DType::F32, layout.shape().elem_count() * 4); + let buffer_dest = + create_wgpu_storage(self.device(), DType::F32, layout.shape().elem_count() * 4); wgpu_functions::queue_convert_u8_to_f32( self.device(), buffer_dest.buffer.clone(), @@ -417,7 +453,8 @@ impl crate::backend::BackendStorage for WgpuStorage { Ok(buffer_dest) } (DType::F32, DType::U32) => { - let buffer_dest = create_wgpu_storage(self.device(), DType::U32, 
layout.shape().elem_count() * 4); + let buffer_dest = + create_wgpu_storage(self.device(), DType::U32, layout.shape().elem_count() * 4); wgpu_functions::queue_convert_f32_to_u32( self.device(), buffer_dest.buffer.clone(), @@ -425,41 +462,66 @@ impl crate::backend::BackendStorage for WgpuStorage { layout, )?; Ok(buffer_dest) - }, + } (DType::F32, DType::U8) => { - if !layout.is_contiguous(){ - panic!("conversion from {:?} to {:?} not suported for non contiguous matrix", self.dtype, dtype); + if !layout.is_contiguous() { + panic!( + "conversion from {:?} to {:?} not supported for non-contiguous matrix", + self.dtype, dtype + ); } - let buffer_dest = create_wgpu_storage(self.device(), DType::U8, layout.shape().elem_count() * 4); + let buffer_dest = + create_wgpu_storage(self.device(), DType::U8, layout.shape().elem_count() * 4); wgpu_functions::queue_convert_f32_to_u8( self.device(), buffer_dest.buffer.clone(), self.buffer.clone(), layout.start_offset() as u32, - layout.shape().elem_count() as u32 + layout.shape().elem_count() as u32, )?; Ok(buffer_dest) - }, + } (DType::U32, DType::U8) => { - if !layout.is_contiguous(){ - panic!("conversion from {:?} to {:?} not suported for non contiguous matrix", self.dtype, dtype); + if !layout.is_contiguous() { + panic!( + "conversion from {:?} to {:?} not supported for non-contiguous matrix", + self.dtype, dtype + ); } - let buffer_dest = create_wgpu_storage(self.device(), DType::U8, layout.shape().elem_count() * 4); + let buffer_dest = + create_wgpu_storage(self.device(), DType::U8, layout.shape().elem_count() * 4); wgpu_functions::queue_convert_u32_to_u8( self.device(), buffer_dest.buffer.clone(), self.buffer.clone(), layout.start_offset() as u32, - layout.shape().elem_count() as u32 + layout.shape().elem_count() as u32, + )?; + Ok(buffer_dest) + } + // U32 to I64 + (DType::U32, DType::I64) => { + let buffer_dest = + create_wgpu_storage(self.device(), DType::I64, layout.shape().elem_count() * 8); + wgpu_functions::queue_convert_u32_to_i64( + self.device(), + buffer_dest.buffer.clone(), + self.buffer.clone(), + layout.start_offset() as u32, + layout.shape().elem_count() as u32, )?; Ok(buffer_dest) } - _ => panic!("conversion from {:?} to {:?} not suported on wgpu", self.dtype, dtype), + _ => panic!( + "conversion from {:?} to {:?} not supported on wgpu", + self.dtype, dtype + ), } } fn unary_impl(&self, layout: &crate::Layout) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, layout.shape().elem_count() * 4); + let buffer_dest = + create_wgpu_storage(self.device(), self.dtype, layout.shape().elem_count() * 4); let op = match B::NAME { "gelu" => UnaryOperation::Gelu, @@ -506,7 +568,11 @@ impl crate::backend::BackendStorage for WgpuStorage { lhs_layout: &crate::Layout, rhs_layout: &crate::Layout, ) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, lhs_layout.shape().elem_count() * 4); + let buffer_dest = create_wgpu_storage( + self.device(), + self.dtype, + lhs_layout.shape().elem_count() * 4, + ); let op = match B::NAME { "add" => BinaryOperation::Add, @@ -535,15 +601,29 @@ impl crate::backend::BackendStorage for WgpuStorage { fn where_cond( &self, - input_layout : &crate::Layout, + input_layout: &crate::Layout, t: &Self, //true values t_layout: &crate::Layout, f: &Self, //false values f_layout: &crate::Layout, ) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), t.dtype, input_layout.shape().elem_count() * 4); + let buffer_dest = create_wgpu_storage( +
self.device(), + t.dtype, + input_layout.shape().elem_count() * 4, + ); - wgpu_functions::where_cond::queue_where_cond_u32(self.device(), buffer_dest.buffer.clone(), self.buffer.clone(), t.buffer.clone(), f.buffer.clone(), input_layout, t_layout, f_layout, t.dtype)?; + wgpu_functions::where_cond::queue_where_cond_u32( + self.device(), + buffer_dest.buffer.clone(), + self.buffer.clone(), + t.buffer.clone(), + f.buffer.clone(), + input_layout, + t_layout, + f_layout, + t.dtype, + )?; return Ok(buffer_dest); } @@ -554,7 +634,11 @@ impl crate::backend::BackendStorage for WgpuStorage { kernel_l: &crate::Layout, params: &crate::conv::ParamsConv1D, ) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (params.b_size * params.c_out * params.l_out()) * 4); + let buffer_dest = create_wgpu_storage( + self.device(), + self.dtype, + (params.b_size * params.c_out * params.l_out()) * 4, + ); wgpu_functions::queue_conv1d( self.device(), @@ -576,7 +660,11 @@ impl crate::backend::BackendStorage for WgpuStorage { kernel_l: &crate::Layout, params: &crate::conv::ParamsConvTranspose1D, ) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (params.b_size * params.c_out * params.l_out()) * 4); + let buffer_dest = create_wgpu_storage( + self.device(), + self.dtype, + (params.b_size * params.c_out * params.l_out()) * 4, + ); wgpu_functions::queue_conv1d_transpose( self.device(), buffer_dest.buffer.clone(), @@ -597,7 +685,11 @@ impl crate::backend::BackendStorage for WgpuStorage { kernel_l: &crate::Layout, params: &crate::conv::ParamsConv2D, ) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (params.b_size * params.c_out * params.out_h() * params.out_w()) * 4); + let buffer_dest = create_wgpu_storage( + self.device(), + self.dtype, + (params.b_size * params.c_out * params.out_h() * params.out_w()) * 4, + ); wgpu_functions::queue_conv2d( self.device(), buffer_dest.buffer.clone(), @@ -618,7 +710,11 @@ impl crate::backend::BackendStorage for WgpuStorage { kernel_l: &crate::Layout, params: &crate::conv::ParamsConvTranspose2D, ) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (params.b_size * params.c_out * params.out_h() * params.out_w()) * 4); + let buffer_dest = create_wgpu_storage( + self.device(), + self.dtype, + (params.b_size * params.c_out * params.out_h() * params.out_w()) * 4, + ); wgpu_functions::queue_conv2d_transpose( self.device(), buffer_dest.buffer.clone(), @@ -638,15 +734,23 @@ impl crate::backend::BackendStorage for WgpuStorage { kernel_size: (usize, usize), stride: (usize, usize), ) -> crate::Result { - let (b, c, h, w) = layout.shape().dims4()?; let h_out = (h - kernel_size.1) / stride.1 + 1; let w_out = (w - kernel_size.0) / stride.0 + 1; - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (b * c * h_out * w_out) * 4); - - wgpu_functions::queue_avg_pool2d(self.device(), buffer_dest.buffer.clone(), self.buffer.clone(),layout, self.dtype(), kernel_size, stride)?; - + let buffer_dest = + create_wgpu_storage(self.device(), self.dtype, (b * c * h_out * w_out) * 4); + + wgpu_functions::queue_avg_pool2d( + self.device(), + buffer_dest.buffer.clone(), + self.buffer.clone(), + layout, + self.dtype(), + kernel_size, + stride, + )?; + return Ok(buffer_dest); } @@ -656,35 +760,70 @@ impl crate::backend::BackendStorage for WgpuStorage { kernel_size: (usize, usize), stride: (usize, usize), ) -> crate::Result { - let (b, c, h, w) = layout.shape().dims4()?; let 
h_out = (h - kernel_size.1) / stride.1 + 1; let w_out = (w - kernel_size.0) / stride.0 + 1; - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (b * c * h_out * w_out) * 4); - - wgpu_functions::queue_max_pool2d(self.device(), buffer_dest.buffer.clone(), self.buffer.clone(),layout, self.dtype(), kernel_size, stride)?; - + let buffer_dest = + create_wgpu_storage(self.device(), self.dtype, (b * c * h_out * w_out) * 4); + + wgpu_functions::queue_max_pool2d( + self.device(), + buffer_dest.buffer.clone(), + self.buffer.clone(), + layout, + self.dtype(), + kernel_size, + stride, + )?; + return Ok(buffer_dest); } - fn upsample_nearest1d(&self, layout: &crate::Layout, target_size: usize) -> crate::Result { + fn upsample_nearest1d( + &self, + layout: &crate::Layout, + target_size: usize, + ) -> crate::Result { let (b, c, _) = layout.shape().dims3()?; - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (b * c * target_size) * 4); - - wgpu_functions::queue_upsample1d(self.device(), buffer_dest.buffer.clone(), self.buffer.clone(),layout, self.dtype(), target_size)?; - + let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (b * c * target_size) * 4); + + wgpu_functions::queue_upsample1d( + self.device(), + buffer_dest.buffer.clone(), + self.buffer.clone(), + layout, + self.dtype(), + target_size, + )?; + return Ok(buffer_dest); } - fn upsample_nearest2d(&self, layout: &crate::Layout, target_size_y: usize, target_size_x: usize) -> crate::Result { + fn upsample_nearest2d( + &self, + layout: &crate::Layout, + target_size_y: usize, + target_size_x: usize, + ) -> crate::Result { let (b, c, _, _) = layout.shape().dims4()?; - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (b * c * target_size_x * target_size_y) * 4); - - wgpu_functions::queue_upsample2d(self.device(), buffer_dest.buffer.clone(), self.buffer.clone(),layout, self.dtype(), (target_size_y, target_size_x))?; - + let buffer_dest = create_wgpu_storage( + self.device(), + self.dtype, + (b * c * target_size_x * target_size_y) * 4, + ); + + wgpu_functions::queue_upsample2d( + self.device(), + buffer_dest.buffer.clone(), + self.buffer.clone(), + layout, + self.dtype(), + (target_size_y, target_size_x), + )?; + return Ok(buffer_dest); } @@ -695,10 +834,23 @@ impl crate::backend::BackendStorage for WgpuStorage { indexes_l: &Layout, d: usize, ) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (indexes_l.shape().elem_count()) * 4); + let buffer_dest = create_wgpu_storage( + self.device(), + self.dtype, + (indexes_l.shape().elem_count()) * 4, + ); + + wgpu_functions::queue_gather( + self.device(), + buffer_dest.buffer.clone(), + self.buffer.clone(), + indexes.buffer.clone(), + self.dtype(), + l, + indexes_l, + d, + )?; - wgpu_functions::queue_gather(self.device(), buffer_dest.buffer.clone(), self.buffer.clone(),indexes.buffer.clone(), self.dtype(), l, indexes_l, d)?; - return Ok(buffer_dest); } @@ -711,12 +863,23 @@ impl crate::backend::BackendStorage for WgpuStorage { source_l: &Layout, d: usize, ) -> crate::Result { - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (l.shape().elem_count()) * 4); + let buffer_dest = + create_wgpu_storage(self.device(), self.dtype, (l.shape().elem_count()) * 4); - self.copy_strided_src( &buffer_dest, 0, l)?; + self.copy_strided_src(&buffer_dest, 0, l)?; + + wgpu_functions::queue_scatter_add_inplace( + self.device(), + buffer_dest.buffer.clone(), + indexes.buffer.clone(), + source.buffer.clone(), + self.dtype(), + 
&Layout::contiguous(l.shape().clone()), + indexes_l, + source_l, + d, + )?; - wgpu_functions::queue_scatter_add_inplace(self.device(), buffer_dest.buffer.clone(),indexes.buffer.clone(), source.buffer.clone(), self.dtype(), &Layout::contiguous(l.shape().clone()), indexes_l, source_l, d)?; - return Ok(buffer_dest); } @@ -731,7 +894,8 @@ impl crate::backend::BackendStorage for WgpuStorage { new_shape[d] = rhs_l.shape().elem_count(); let new_shape = Shape::from_dims(&new_shape[..]); - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (new_shape.elem_count()) * 4); + let buffer_dest = + create_wgpu_storage(self.device(), self.dtype, (new_shape.elem_count()) * 4); wgpu_functions::queue_index_select( self.device(), @@ -755,13 +919,23 @@ impl crate::backend::BackendStorage for WgpuStorage { source_l: &Layout, d: usize, ) -> crate::Result { - - let buffer_dest = create_wgpu_storage(self.device(), self.dtype, (l.shape().elem_count()) * 4); + let buffer_dest = + create_wgpu_storage(self.device(), self.dtype, (l.shape().elem_count()) * 4); + + self.copy_strided_src(&buffer_dest, 0, l)?; - self.copy_strided_src( &buffer_dest, 0, l)?; + wgpu_functions::queue_index_add_inplace( + self.device(), + buffer_dest.buffer.clone(), + indexes.buffer.clone(), + source.buffer.clone(), + self.dtype(), + &Layout::contiguous(l.shape().clone()), + indexes_l, + source_l, + d, + )?; - wgpu_functions::queue_index_add_inplace(self.device(), buffer_dest.buffer.clone(),indexes.buffer.clone(), source.buffer.clone(), self.dtype(), &Layout::contiguous(l.shape().clone()), indexes_l, source_l, d)?; - return Ok(buffer_dest); } @@ -817,19 +991,18 @@ impl crate::backend::BackendStorage for WgpuStorage { src_stride1 as u32, dst_stride1 as u32, src_offset as u32, - dst_offset as u32 + dst_offset as u32, )?; Ok(()) } } - -impl Drop for WgpuStorage{ +impl Drop for WgpuStorage { fn drop(&mut self) { let mut cache = self.device().cache.lock().unwrap(); - cache.buffer_reference.queue_for_deletion(&self. 
buffer); + cache.buffer_reference.queue_for_deletion(&self.buffer); // if let Some(reference) = cache.buffer_reference.get_mut(&self.buffer){ // reference.set_referenced_by_candle_storage(false); // } } -} \ No newline at end of file +} diff --git a/candle-core/src/wgpu_backend/wgpu_functions/convert.rs b/candle-core/src/wgpu_backend/wgpu_functions/convert.rs index 03895e01e6..15d896967b 100644 --- a/candle-core/src/wgpu_backend/wgpu_functions/convert.rs +++ b/candle-core/src/wgpu_backend/wgpu_functions/convert.rs @@ -11,20 +11,18 @@ pub fn queue_convert_u32_to_f32( let mut meta = get_meta(&dev); meta.add_layout1(&input_layout); - let pipeline = meta.get_pipeline(Pipelines::Convert(DType::U32, Functions::ConvertToF32)); - let bind_group = create_bind_group_input1( buffer_dest, buffer_input); + let bind_group = create_bind_group_input1(buffer_dest, buffer_input); enqueue( meta, pipeline, bind_group, input_layout.shape().elem_count() as u32, - input_layout.shape().elem_count() + input_layout.shape().elem_count(), ); return Ok(()); } - pub fn queue_convert_u8_to_f32( dev: &WgpuDevice, buffer_dest: BufferReferenceId, @@ -35,13 +33,13 @@ pub fn queue_convert_u8_to_f32( meta.add_layout1(&input_layout); let pipeline = meta.get_pipeline(Pipelines::Convert(DType::U8, Functions::ConvertU8ToF32)); - let bind_group = create_bind_group_input1( buffer_dest, buffer_input); + let bind_group = create_bind_group_input1(buffer_dest, buffer_input); enqueue( meta, pipeline, bind_group, input_layout.shape().elem_count() as u32, - input_layout.shape().elem_count() + input_layout.shape().elem_count(), ); return Ok(()); } @@ -57,7 +55,7 @@ pub fn queue_convert_f32_to_u32( let pipeline = meta.get_pipeline(Pipelines::Convert(DType::F32, Functions::ConvertToU32)); - let bind_group = create_bind_group_input1( buffer_dest, buffer_input); + let bind_group = create_bind_group_input1(buffer_dest, buffer_input); enqueue( meta, pipeline, @@ -68,21 +66,20 @@ pub fn queue_convert_f32_to_u32( return Ok(()); } - pub fn queue_convert_u32_to_u8( dev: &WgpuDevice, buffer_dest: BufferReferenceId, buffer_input: BufferReferenceId, start_offset: u32, - size : u32 + size: u32, ) -> crate::Result<()> { let mut meta = get_meta(&dev); meta.add(start_offset); meta.add(size); let pipeline = meta.get_pipeline(Pipelines::Convert(DType::U32, Functions::ConvertU32ToU8)); - - let bind_group = create_bind_group_input1( buffer_dest, buffer_input); + + let bind_group = create_bind_group_input1(buffer_dest, buffer_input); enqueue( meta, pipeline, @@ -93,12 +90,45 @@ pub fn queue_convert_u32_to_u8( return Ok(()); } +pub fn queue_convert_u32_to_i64( + dev: &WgpuDevice, + buffer_dest: BufferReferenceId, + buffer_input: BufferReferenceId, + start_offset: u32, + size: u32, +) -> crate::Result<()> { + let mut meta = get_meta(&dev); + meta.add(start_offset); + meta.add(size); + + // Get the appropriate pipeline for converting u32 to i64 + let pipeline = meta.get_pipeline(Pipelines::Convert(DType::U32, Functions::ConvertU32ToI64)); + + // Create a bind group for the destination and input buffers + let bind_group = create_bind_group_input1_8(buffer_dest, buffer_input); + + // Calculate the number of workgroups needed. 
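// (A quick sketch of the arithmetic below, assuming the same
// 4-elements-per-workgroup convention as the other convert kernels here:
// converting size = 10 u32 values dispatches (10 + 3) / 4 = 3 workgroups,
// and each u32 widens to an 8-byte i64, which is why the destination is
// bound through the new `bind_group_layout1_8` with min_binding_size = 8.)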
+ // Assuming that each workgroup processes 4 u32 values, and each u32 is converted to an i64 + let num_workgroups = ((size + 3) / 4) as u32; + + // Enqueue the compute operation + enqueue( + meta, + pipeline, + bind_group, + num_workgroups, // number of workgroups + size as usize, // total size (in elements) + ); + + Ok(()) +} + pub fn queue_convert_f32_to_u8( dev: &WgpuDevice, buffer_dest: BufferReferenceId, buffer_input: BufferReferenceId, start_offset: u32, - size : u32 + size: u32, ) -> crate::Result<()> { let mut meta = get_meta(&dev); meta.add(start_offset); @@ -106,7 +136,7 @@ pub fn queue_convert_f32_to_u8( let pipeline = meta.get_pipeline(Pipelines::Convert(DType::F32, Functions::ConvertF32ToU8)); - let bind_group = create_bind_group_input1( buffer_dest, buffer_input); + let bind_group = create_bind_group_input1(buffer_dest, buffer_input); enqueue( meta, pipeline, @@ -115,4 +145,4 @@ pub fn queue_convert_f32_to_u8( size as usize, ); return Ok(()); -} \ No newline at end of file +} diff --git a/candle-core/src/wgpu_backend/wgpu_functions/mod.rs b/candle-core/src/wgpu_backend/wgpu_functions/mod.rs index ad940f6626..c88cba574e 100644 --- a/candle-core/src/wgpu_backend/wgpu_functions/mod.rs +++ b/candle-core/src/wgpu_backend/wgpu_functions/mod.rs @@ -44,7 +44,7 @@ pub use cmp::queue_cmp_buffer_from_buffer; pub use conv2d::{queue_conv1d, queue_conv1d_transpose, queue_conv2d, queue_conv2d_transpose}; pub use convert::{ queue_convert_f32_to_u32, queue_convert_f32_to_u8, queue_convert_u32_to_f32, - queue_convert_u32_to_u8, queue_convert_u8_to_f32, + queue_convert_u32_to_u8, queue_convert_u8_to_f32, queue_convert_u32_to_i64, }; pub use copy::{queue_copy, queue_copy2d, queue_copy3d,queue_copy3d_padded, queue_copy_strided}; pub use gather::{queue_gather, queue_index_add_inplace, queue_scatter_add_inplace}; @@ -306,7 +306,7 @@ fn get_command_buffer( let vd = buffers.get_dest(); match buffers.get_input(){ BindgroupInputBase::Bindgroup0 => {}, - BindgroupInputBase::Bindgroup1(v1, _) => { + BindgroupInputBase::Bindgroup1(v1, _, _) => { if v1 == vd{ panic!("B1: output and input are equal"); } @@ -445,7 +445,7 @@ fn prepare(dev: &WgpuDevice, queue_buffer: &mut QueueBuffer, cache : &mut ModelC match input { BindgroupInputBase::Bindgroup0 => {}, - BindgroupInputBase::Bindgroup1(v1, _) => { + BindgroupInputBase::Bindgroup1(v1, _, _) => { check_buffer(v1); }, BindgroupInputBase::Bindgroup2(v1,v2, _) => { @@ -548,7 +548,7 @@ fn set_buffers(dev: &WgpuDevice, command_buffer: &mut QueueBuffer, index : &mut { if let Pipelines::Unary(dtype, candle_wgpu_kernels::unary::Functions::UnaryFromBufferContiguous) = &q.pipeline.0{ if q.pipeline.2.input1_inplaceable{ - if let BindgroupReferenceInput::Bindgroup1(v1_id, _) = bindgroup_reference.get_input() + if let BindgroupReferenceInput::Bindgroup1(v1_id, _, _) = bindgroup_reference.get_input() { if optmize_inplace(bindgroup_reference.get_dest(), v1_id) { @@ -578,7 +578,7 @@ fn set_buffers(dev: &WgpuDevice, command_buffer: &mut QueueBuffer, index : &mut q.pipeline.0 = Pipelines::Binary(dtype.clone(), candle_wgpu_kernels::binary::Functions::BinaryBufferInplace1ContiguousBoth); q.bindgroup = DispatchedBindgroup::BindgroupReference( - BindGroupReference::new(v1_id.clone(), BindgroupInputBase::Bindgroup1(v2_id.clone(), false))); + BindGroupReference::new(v1_id.clone(), BindgroupInputBase::Bindgroup1(v2_id.clone(), false,false))); } } else if q.pipeline.2.input2_inplaceable{ @@ -587,14 +587,14 @@ fn set_buffers(dev: &WgpuDevice, command_buffer: &mut QueueBuffer, index : 
&mut q.pipeline.0 = Pipelines::Binary(dtype.clone(), candle_wgpu_kernels::binary::Functions::BinaryBufferInplace2ContiguousBoth); q.bindgroup = DispatchedBindgroup::BindgroupReference( - BindGroupReference::new(v2_id.clone(), BindgroupInputBase::Bindgroup1(v1_id.clone(), false))); + BindGroupReference::new(v2_id.clone(), BindgroupInputBase::Bindgroup1(v1_id.clone(), false,false))); } } } } else if let Pipelines::Copy(_, candle_wgpu_kernels::copy::Functions::Copy) = &q.pipeline.0{ if q.pipeline.2.input1_inplaceable{ - if let BindgroupReferenceInput::Bindgroup1(v1_id, _) = bindgroup_reference.get_input() + if let BindgroupReferenceInput::Bindgroup1(v1_id, _,_) = bindgroup_reference.get_input() { let v1 = cache.buffer_reference.get(v1_id); if let Some(v1) = v1{ @@ -659,7 +659,7 @@ fn set_buffers(dev: &WgpuDevice, command_buffer: &mut QueueBuffer, index : &mut chec_buffer(dest, cache, command_index); match input { BindgroupInputBase::Bindgroup0 => {}, - BindgroupInputBase::Bindgroup1(v1, _) => {chec_buffer(v1, cache, command_index);}, + BindgroupInputBase::Bindgroup1(v1, _, _) => {chec_buffer(v1, cache, command_index);}, BindgroupInputBase::Bindgroup2(v1, v2, _) => {chec_buffer(v1, cache, command_index);chec_buffer(v2, cache, command_index);}, BindgroupInputBase::Bindgroup3(v1, v2, v3) => {chec_buffer(v1, cache, command_index);chec_buffer(v2, cache, command_index);chec_buffer(v3, cache, command_index);}, } @@ -675,10 +675,13 @@ fn set_buffers(dev: &WgpuDevice, command_buffer: &mut QueueBuffer, index : &mut BindgroupReferenceInput::Bindgroup0 => { &dev.bindgroup_layouts.pipeline_layout0 } - BindgroupReferenceInput::Bindgroup1( _,false) => { + BindgroupReferenceInput::Bindgroup1(_, _, true) => { + &dev.bindgroup_layouts.pipeline_layout1_8 + } + BindgroupReferenceInput::Bindgroup1( _,false,_) => { &dev.bindgroup_layouts.pipeline_layout1 } - BindgroupReferenceInput::Bindgroup1( _, true) => { + BindgroupReferenceInput::Bindgroup1( _, true, _) => { &dev.bindgroup_layouts.pipeline_layout1_16 } BindgroupReferenceInput::Bindgroup2( _, _, false) => { @@ -1032,8 +1035,9 @@ pub fn create_bindgroup(dev: &WgpuDevice, bindgroup: CachedBindgroupFull, cache let bind_group_layout = match bindgroup.get_input() { CachedBindgroupInput::Bindgroup0 => &dev.bindgroup_layouts.bind_group_layout0, - CachedBindgroupInput::Bindgroup1(_, false) => &dev.bindgroup_layouts.bind_group_layout1, - CachedBindgroupInput::Bindgroup1(_, true) => &dev.bindgroup_layouts.bind_group_layout1_16, + CachedBindgroupInput::Bindgroup1(_, _, true) => &dev.bindgroup_layouts.bind_group_layout1_8, + CachedBindgroupInput::Bindgroup1(_, false, false) => &dev.bindgroup_layouts.bind_group_layout1, + CachedBindgroupInput::Bindgroup1(_, true, false) => &dev.bindgroup_layouts.bind_group_layout1_16, CachedBindgroupInput::Bindgroup2(_, _, false) => &dev.bindgroup_layouts.bind_group_layout2, CachedBindgroupInput::Bindgroup2(_, _, true) => &dev.bindgroup_layouts.bind_group_layout2_16, CachedBindgroupInput::Bindgroup3(_, _, _) => &dev.bindgroup_layouts.bind_group_layout3, @@ -1058,7 +1062,7 @@ pub fn create_bindgroup(dev: &WgpuDevice, bindgroup: CachedBindgroupFull, cache entries: entries, }) } - CachedBindgroupInput::Bindgroup1(buffer_input1, _) => { + CachedBindgroupInput::Bindgroup1(buffer_input1, _,_) => { if cache.buffers.get_buffer(buffer_input1).is_none(){ panic!("buffer_input_1 : {:?} could not be found(in {:?})", buffer_input1, bindgroup); @@ -1145,8 +1149,22 @@ fn create_bind_group_input1( buffer_dest: BufferReferenceId, buffer_input1: 
BufferReferenceId, ) -> BindGroupReference { - BindGroupReference::new(buffer_dest, BindgroupInputBase::Bindgroup1(buffer_input1, false)) + BindGroupReference::new(buffer_dest, BindgroupInputBase::Bindgroup1(buffer_input1, false,false)) +} + +fn create_bind_group_input1_8( + buffer_dest: BufferReferenceId, + buffer_input1: BufferReferenceId, +) -> BindGroupReference { + BindGroupReference::new(buffer_dest, BindgroupInputBase::Bindgroup1(buffer_input1, true, true)) } +fn create_bind_group_input1_16( + buffer_dest: BufferReferenceId, + buffer_input1: BufferReferenceId, +) -> BindGroupReference { + BindGroupReference::new(buffer_dest, BindgroupInputBase::Bindgroup1(buffer_input1, true, false)) +} + fn create_bind_group_input2( buffer_dest: BufferReferenceId, diff --git a/candle-wgpu-kernels/src/generated.rs b/candle-wgpu-kernels/src/generated.rs index 864e580903..6276dcc159 100644 --- a/candle-wgpu-kernels/src/generated.rs +++ b/candle-wgpu-kernels/src/generated.rs @@ -2,214 +2,214 @@ use crate::*; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Pipelines{ - Binary(DType, kernels::binary::Functions), - Cmp(DType, kernels::cmp::Functions), - Conv1d(DType, kernels::conv1d::Functions), - Conv2d(DType, kernels::conv2d::Functions), Convert(DType, kernels::convert::Functions), - Copy(DType, kernels::copy::Functions), - Gather(DType, kernels::gather::Functions), - IndexSelect(DType, kernels::index_select::Functions), Matmul(DType, kernels::matmul::Functions), - Pool2d(DType, kernels::pool2d::Functions), - Reduce(DType, kernels::reduce::Functions), - RmsNorm(DType, kernels::rms_norm::Functions), - Matmul128x128(DType, kernels::sgemm::matmul128x128::Functions), - Matmul128x128Prefetch(DType, kernels::sgemm::matmul128x128_prefetch::Functions), - Matmul16x64(DType, kernels::sgemm::matmul16x64::Functions), + Copy(DType, kernels::copy::Functions), + Unary(DType, kernels::unary::Functions), + Upsample(DType, kernels::upsample::Functions), + Softmax(DType, kernels::softmax::Functions), + Matmul32x32Prefetch(DType, kernels::sgemm::matmul32x32_prefetch::Functions), + Matmul64x648x8(DType, kernels::sgemm::matmul64x64_8x8::Functions), Matmul16x64Prefetch(DType, kernels::sgemm::matmul16x64_prefetch::Functions), - Matmul1x128(DType, kernels::sgemm::matmul1x128::Functions), - Matmul1x128Prefetch(DType, kernels::sgemm::matmul1x128_prefetch::Functions), - Matmul1x256(DType, kernels::sgemm::matmul1x256::Functions), - Matmul1x256Prefetch(DType, kernels::sgemm::matmul1x256_prefetch::Functions), Matmul24x24(DType, kernels::sgemm::matmul24x24::Functions), - Matmul24x24Prefetch(DType, kernels::sgemm::matmul24x24_prefetch::Functions), - Matmul24x48(DType, kernels::sgemm::matmul24x48::Functions), - Matmul24x48Prefetch(DType, kernels::sgemm::matmul24x48_prefetch::Functions), + Matmul16x64(DType, kernels::sgemm::matmul16x64::Functions), Matmul32x32(DType, kernels::sgemm::matmul32x32::Functions), - Matmul32x32Prefetch(DType, kernels::sgemm::matmul32x32_prefetch::Functions), - Matmul64x1284x8(DType, kernels::sgemm::matmul64x128_4x8::Functions), - Matmul64x1284x8Prefetch(DType, kernels::sgemm::matmul64x128_4x8_prefetch::Functions), + Matmul128x128Prefetch(DType, kernels::sgemm::matmul128x128_prefetch::Functions), + Matmul1x256(DType, kernels::sgemm::matmul1x256::Functions), + Matmul1x256Prefetch(DType, kernels::sgemm::matmul1x256_prefetch::Functions), + Matmul64x64Prefetch(DType, kernels::sgemm::matmul64x64_prefetch::Functions), + Matmul64x64(DType, kernels::sgemm::matmul64x64::Functions), Matmul64x1288x8(DType, 
kernels::sgemm::matmul64x128_8x8::Functions), + Matmul128x128(DType, kernels::sgemm::matmul128x128::Functions), Matmul64x1288x8Prefetch(DType, kernels::sgemm::matmul64x128_8x8_prefetch::Functions), - Matmul64x64(DType, kernels::sgemm::matmul64x64::Functions), - Matmul64x648x8(DType, kernels::sgemm::matmul64x64_8x8::Functions), + Matmul64x1284x8(DType, kernels::sgemm::matmul64x128_4x8::Functions), + Matmul24x24Prefetch(DType, kernels::sgemm::matmul24x24_prefetch::Functions), + Matmul1x128(DType, kernels::sgemm::matmul1x128::Functions), Matmul64x648x8Prefetch(DType, kernels::sgemm::matmul64x64_8x8_prefetch::Functions), - Matmul64x64Prefetch(DType, kernels::sgemm::matmul64x64_prefetch::Functions), - Softmax(DType, kernels::softmax::Functions), - Unary(DType, kernels::unary::Functions), - Upsample(DType, kernels::upsample::Functions), + Matmul64x1284x8Prefetch(DType, kernels::sgemm::matmul64x128_4x8_prefetch::Functions), + Matmul1x128Prefetch(DType, kernels::sgemm::matmul1x128_prefetch::Functions), + Matmul24x48(DType, kernels::sgemm::matmul24x48::Functions), + Matmul24x48Prefetch(DType, kernels::sgemm::matmul24x48_prefetch::Functions), + Gather(DType, kernels::gather::Functions), + Conv2d(DType, kernels::conv2d::Functions), + Cmp(DType, kernels::cmp::Functions), WhereCond(DType, kernels::where_cond::Functions), + IndexSelect(DType, kernels::index_select::Functions), + Pool2d(DType, kernels::pool2d::Functions), + Binary(DType, kernels::binary::Functions), + Reduce(DType, kernels::reduce::Functions), + Conv1d(DType, kernels::conv1d::Functions), + RmsNorm(DType, kernels::rms_norm::Functions), } impl crate::EntryPoint for Pipelines{ fn get_entry_point(&self) -> &'static str{ match self{ - Pipelines::Binary(_, f) => f.get_entry_point(), - Pipelines::Cmp(_, f) => f.get_entry_point(), - Pipelines::Conv1d(_, f) => f.get_entry_point(), - Pipelines::Conv2d(_, f) => f.get_entry_point(), Pipelines::Convert(_, f) => f.get_entry_point(), - Pipelines::Copy(_, f) => f.get_entry_point(), - Pipelines::Gather(_, f) => f.get_entry_point(), - Pipelines::IndexSelect(_, f) => f.get_entry_point(), Pipelines::Matmul(_, f) => f.get_entry_point(), - Pipelines::Pool2d(_, f) => f.get_entry_point(), - Pipelines::Reduce(_, f) => f.get_entry_point(), - Pipelines::RmsNorm(_, f) => f.get_entry_point(), - Pipelines::Matmul128x128(_, f) => f.get_entry_point(), - Pipelines::Matmul128x128Prefetch(_, f) => f.get_entry_point(), - Pipelines::Matmul16x64(_, f) => f.get_entry_point(), + Pipelines::Copy(_, f) => f.get_entry_point(), + Pipelines::Unary(_, f) => f.get_entry_point(), + Pipelines::Upsample(_, f) => f.get_entry_point(), + Pipelines::Softmax(_, f) => f.get_entry_point(), + Pipelines::Matmul32x32Prefetch(_, f) => f.get_entry_point(), + Pipelines::Matmul64x648x8(_, f) => f.get_entry_point(), Pipelines::Matmul16x64Prefetch(_, f) => f.get_entry_point(), - Pipelines::Matmul1x128(_, f) => f.get_entry_point(), - Pipelines::Matmul1x128Prefetch(_, f) => f.get_entry_point(), - Pipelines::Matmul1x256(_, f) => f.get_entry_point(), - Pipelines::Matmul1x256Prefetch(_, f) => f.get_entry_point(), Pipelines::Matmul24x24(_, f) => f.get_entry_point(), - Pipelines::Matmul24x24Prefetch(_, f) => f.get_entry_point(), - Pipelines::Matmul24x48(_, f) => f.get_entry_point(), - Pipelines::Matmul24x48Prefetch(_, f) => f.get_entry_point(), + Pipelines::Matmul16x64(_, f) => f.get_entry_point(), Pipelines::Matmul32x32(_, f) => f.get_entry_point(), - Pipelines::Matmul32x32Prefetch(_, f) => f.get_entry_point(), - Pipelines::Matmul64x1284x8(_, f) => 
f.get_entry_point(), - Pipelines::Matmul64x1284x8Prefetch(_, f) => f.get_entry_point(), + Pipelines::Matmul128x128Prefetch(_, f) => f.get_entry_point(), + Pipelines::Matmul1x256(_, f) => f.get_entry_point(), + Pipelines::Matmul1x256Prefetch(_, f) => f.get_entry_point(), + Pipelines::Matmul64x64Prefetch(_, f) => f.get_entry_point(), + Pipelines::Matmul64x64(_, f) => f.get_entry_point(), Pipelines::Matmul64x1288x8(_, f) => f.get_entry_point(), + Pipelines::Matmul128x128(_, f) => f.get_entry_point(), Pipelines::Matmul64x1288x8Prefetch(_, f) => f.get_entry_point(), - Pipelines::Matmul64x64(_, f) => f.get_entry_point(), - Pipelines::Matmul64x648x8(_, f) => f.get_entry_point(), + Pipelines::Matmul64x1284x8(_, f) => f.get_entry_point(), + Pipelines::Matmul24x24Prefetch(_, f) => f.get_entry_point(), + Pipelines::Matmul1x128(_, f) => f.get_entry_point(), Pipelines::Matmul64x648x8Prefetch(_, f) => f.get_entry_point(), - Pipelines::Matmul64x64Prefetch(_, f) => f.get_entry_point(), - Pipelines::Softmax(_, f) => f.get_entry_point(), - Pipelines::Unary(_, f) => f.get_entry_point(), - Pipelines::Upsample(_, f) => f.get_entry_point(), - Pipelines::WhereCond(_, f) => f.get_entry_point() + Pipelines::Matmul64x1284x8Prefetch(_, f) => f.get_entry_point(), + Pipelines::Matmul1x128Prefetch(_, f) => f.get_entry_point(), + Pipelines::Matmul24x48(_, f) => f.get_entry_point(), + Pipelines::Matmul24x48Prefetch(_, f) => f.get_entry_point(), + Pipelines::Gather(_, f) => f.get_entry_point(), + Pipelines::Conv2d(_, f) => f.get_entry_point(), + Pipelines::Cmp(_, f) => f.get_entry_point(), + Pipelines::WhereCond(_, f) => f.get_entry_point(), + Pipelines::IndexSelect(_, f) => f.get_entry_point(), + Pipelines::Pool2d(_, f) => f.get_entry_point(), + Pipelines::Binary(_, f) => f.get_entry_point(), + Pipelines::Reduce(_, f) => f.get_entry_point(), + Pipelines::Conv1d(_, f) => f.get_entry_point(), + Pipelines::RmsNorm(_, f) => f.get_entry_point() } } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Shaders{ - Binary(DType), - Cmp(DType), - Conv1d(DType), - Conv2d(DType), Convert(DType), - Copy(DType), - Gather(DType), - IndexSelect(DType), Matmul(DType), - Pool2d(DType), - Reduce(DType), - RmsNorm(DType), - Matmul128x128(DType), - Matmul128x128Prefetch(DType), - Matmul16x64(DType), + Copy(DType), + Unary(DType), + Upsample(DType), + Softmax(DType), + Matmul32x32Prefetch(DType), + Matmul64x648x8(DType), Matmul16x64Prefetch(DType), - Matmul1x128(DType), - Matmul1x128Prefetch(DType), - Matmul1x256(DType), - Matmul1x256Prefetch(DType), Matmul24x24(DType), - Matmul24x24Prefetch(DType), - Matmul24x48(DType), - Matmul24x48Prefetch(DType), + Matmul16x64(DType), Matmul32x32(DType), - Matmul32x32Prefetch(DType), - Matmul64x1284x8(DType), - Matmul64x1284x8Prefetch(DType), + Matmul128x128Prefetch(DType), + Matmul1x256(DType), + Matmul1x256Prefetch(DType), + Matmul64x64Prefetch(DType), + Matmul64x64(DType), Matmul64x1288x8(DType), + Matmul128x128(DType), Matmul64x1288x8Prefetch(DType), - Matmul64x64(DType), - Matmul64x648x8(DType), + Matmul64x1284x8(DType), + Matmul24x24Prefetch(DType), + Matmul1x128(DType), Matmul64x648x8Prefetch(DType), - Matmul64x64Prefetch(DType), - Softmax(DType), - Unary(DType), - Upsample(DType), + Matmul64x1284x8Prefetch(DType), + Matmul1x128Prefetch(DType), + Matmul24x48(DType), + Matmul24x48Prefetch(DType), + Gather(DType), + Conv2d(DType), + Cmp(DType), WhereCond(DType), + IndexSelect(DType), + Pool2d(DType), + Binary(DType), + Reduce(DType), + Conv1d(DType), + RmsNorm(DType), } impl Pipelines { pub fn 
get_shader(&self) -> Shaders{ match self{ - Pipelines::Binary(typ, _) => Shaders::Binary(typ.clone()), - Pipelines::Cmp(typ, _) => Shaders::Cmp(typ.clone()), - Pipelines::Conv1d(typ, _) => Shaders::Conv1d(typ.clone()), - Pipelines::Conv2d(typ, _) => Shaders::Conv2d(typ.clone()), Pipelines::Convert(typ, _) => Shaders::Convert(typ.clone()), - Pipelines::Copy(typ, _) => Shaders::Copy(typ.clone()), - Pipelines::Gather(typ, _) => Shaders::Gather(typ.clone()), - Pipelines::IndexSelect(typ, _) => Shaders::IndexSelect(typ.clone()), Pipelines::Matmul(typ, _) => Shaders::Matmul(typ.clone()), - Pipelines::Pool2d(typ, _) => Shaders::Pool2d(typ.clone()), - Pipelines::Reduce(typ, _) => Shaders::Reduce(typ.clone()), - Pipelines::RmsNorm(typ, _) => Shaders::RmsNorm(typ.clone()), - Pipelines::Matmul128x128(typ, _) => Shaders::Matmul128x128(typ.clone()), - Pipelines::Matmul128x128Prefetch(typ, _) => Shaders::Matmul128x128Prefetch(typ.clone()), - Pipelines::Matmul16x64(typ, _) => Shaders::Matmul16x64(typ.clone()), + Pipelines::Copy(typ, _) => Shaders::Copy(typ.clone()), + Pipelines::Unary(typ, _) => Shaders::Unary(typ.clone()), + Pipelines::Upsample(typ, _) => Shaders::Upsample(typ.clone()), + Pipelines::Softmax(typ, _) => Shaders::Softmax(typ.clone()), + Pipelines::Matmul32x32Prefetch(typ, _) => Shaders::Matmul32x32Prefetch(typ.clone()), + Pipelines::Matmul64x648x8(typ, _) => Shaders::Matmul64x648x8(typ.clone()), Pipelines::Matmul16x64Prefetch(typ, _) => Shaders::Matmul16x64Prefetch(typ.clone()), - Pipelines::Matmul1x128(typ, _) => Shaders::Matmul1x128(typ.clone()), - Pipelines::Matmul1x128Prefetch(typ, _) => Shaders::Matmul1x128Prefetch(typ.clone()), - Pipelines::Matmul1x256(typ, _) => Shaders::Matmul1x256(typ.clone()), - Pipelines::Matmul1x256Prefetch(typ, _) => Shaders::Matmul1x256Prefetch(typ.clone()), Pipelines::Matmul24x24(typ, _) => Shaders::Matmul24x24(typ.clone()), - Pipelines::Matmul24x24Prefetch(typ, _) => Shaders::Matmul24x24Prefetch(typ.clone()), - Pipelines::Matmul24x48(typ, _) => Shaders::Matmul24x48(typ.clone()), - Pipelines::Matmul24x48Prefetch(typ, _) => Shaders::Matmul24x48Prefetch(typ.clone()), + Pipelines::Matmul16x64(typ, _) => Shaders::Matmul16x64(typ.clone()), Pipelines::Matmul32x32(typ, _) => Shaders::Matmul32x32(typ.clone()), - Pipelines::Matmul32x32Prefetch(typ, _) => Shaders::Matmul32x32Prefetch(typ.clone()), - Pipelines::Matmul64x1284x8(typ, _) => Shaders::Matmul64x1284x8(typ.clone()), - Pipelines::Matmul64x1284x8Prefetch(typ, _) => Shaders::Matmul64x1284x8Prefetch(typ.clone()), + Pipelines::Matmul128x128Prefetch(typ, _) => Shaders::Matmul128x128Prefetch(typ.clone()), + Pipelines::Matmul1x256(typ, _) => Shaders::Matmul1x256(typ.clone()), + Pipelines::Matmul1x256Prefetch(typ, _) => Shaders::Matmul1x256Prefetch(typ.clone()), + Pipelines::Matmul64x64Prefetch(typ, _) => Shaders::Matmul64x64Prefetch(typ.clone()), + Pipelines::Matmul64x64(typ, _) => Shaders::Matmul64x64(typ.clone()), Pipelines::Matmul64x1288x8(typ, _) => Shaders::Matmul64x1288x8(typ.clone()), + Pipelines::Matmul128x128(typ, _) => Shaders::Matmul128x128(typ.clone()), Pipelines::Matmul64x1288x8Prefetch(typ, _) => Shaders::Matmul64x1288x8Prefetch(typ.clone()), - Pipelines::Matmul64x64(typ, _) => Shaders::Matmul64x64(typ.clone()), - Pipelines::Matmul64x648x8(typ, _) => Shaders::Matmul64x648x8(typ.clone()), + Pipelines::Matmul64x1284x8(typ, _) => Shaders::Matmul64x1284x8(typ.clone()), + Pipelines::Matmul24x24Prefetch(typ, _) => Shaders::Matmul24x24Prefetch(typ.clone()), + Pipelines::Matmul1x128(typ, _) => 
Shaders::Matmul1x128(typ.clone()), Pipelines::Matmul64x648x8Prefetch(typ, _) => Shaders::Matmul64x648x8Prefetch(typ.clone()), - Pipelines::Matmul64x64Prefetch(typ, _) => Shaders::Matmul64x64Prefetch(typ.clone()), - Pipelines::Softmax(typ, _) => Shaders::Softmax(typ.clone()), - Pipelines::Unary(typ, _) => Shaders::Unary(typ.clone()), - Pipelines::Upsample(typ, _) => Shaders::Upsample(typ.clone()), - Pipelines::WhereCond(typ, _) => Shaders::WhereCond(typ.clone()) + Pipelines::Matmul64x1284x8Prefetch(typ, _) => Shaders::Matmul64x1284x8Prefetch(typ.clone()), + Pipelines::Matmul1x128Prefetch(typ, _) => Shaders::Matmul1x128Prefetch(typ.clone()), + Pipelines::Matmul24x48(typ, _) => Shaders::Matmul24x48(typ.clone()), + Pipelines::Matmul24x48Prefetch(typ, _) => Shaders::Matmul24x48Prefetch(typ.clone()), + Pipelines::Gather(typ, _) => Shaders::Gather(typ.clone()), + Pipelines::Conv2d(typ, _) => Shaders::Conv2d(typ.clone()), + Pipelines::Cmp(typ, _) => Shaders::Cmp(typ.clone()), + Pipelines::WhereCond(typ, _) => Shaders::WhereCond(typ.clone()), + Pipelines::IndexSelect(typ, _) => Shaders::IndexSelect(typ.clone()), + Pipelines::Pool2d(typ, _) => Shaders::Pool2d(typ.clone()), + Pipelines::Binary(typ, _) => Shaders::Binary(typ.clone()), + Pipelines::Reduce(typ, _) => Shaders::Reduce(typ.clone()), + Pipelines::Conv1d(typ, _) => Shaders::Conv1d(typ.clone()), + Pipelines::RmsNorm(typ, _) => Shaders::RmsNorm(typ.clone()) } } pub fn load_shader(&self) -> &'static str{ match self{ - Pipelines::Binary(typ, _) => kernels::binary::load_shader(typ.clone()), - Pipelines::Cmp(typ, _) => kernels::cmp::load_shader(typ.clone()), - Pipelines::Conv1d(typ, _) => kernels::conv1d::load_shader(typ.clone()), - Pipelines::Conv2d(typ, _) => kernels::conv2d::load_shader(typ.clone()), Pipelines::Convert(typ, _) => kernels::convert::load_shader(typ.clone()), - Pipelines::Copy(typ, _) => kernels::copy::load_shader(typ.clone()), - Pipelines::Gather(typ, _) => kernels::gather::load_shader(typ.clone()), - Pipelines::IndexSelect(typ, _) => kernels::index_select::load_shader(typ.clone()), Pipelines::Matmul(typ, _) => kernels::matmul::load_shader(typ.clone()), - Pipelines::Pool2d(typ, _) => kernels::pool2d::load_shader(typ.clone()), - Pipelines::Reduce(typ, _) => kernels::reduce::load_shader(typ.clone()), - Pipelines::RmsNorm(typ, _) => kernels::rms_norm::load_shader(typ.clone()), - Pipelines::Matmul128x128(typ, _) => kernels::sgemm::matmul128x128::load_shader(typ.clone()), - Pipelines::Matmul128x128Prefetch(typ, _) => kernels::sgemm::matmul128x128_prefetch::load_shader(typ.clone()), - Pipelines::Matmul16x64(typ, _) => kernels::sgemm::matmul16x64::load_shader(typ.clone()), + Pipelines::Copy(typ, _) => kernels::copy::load_shader(typ.clone()), + Pipelines::Unary(typ, _) => kernels::unary::load_shader(typ.clone()), + Pipelines::Upsample(typ, _) => kernels::upsample::load_shader(typ.clone()), + Pipelines::Softmax(typ, _) => kernels::softmax::load_shader(typ.clone()), + Pipelines::Matmul32x32Prefetch(typ, _) => kernels::sgemm::matmul32x32_prefetch::load_shader(typ.clone()), + Pipelines::Matmul64x648x8(typ, _) => kernels::sgemm::matmul64x64_8x8::load_shader(typ.clone()), Pipelines::Matmul16x64Prefetch(typ, _) => kernels::sgemm::matmul16x64_prefetch::load_shader(typ.clone()), - Pipelines::Matmul1x128(typ, _) => kernels::sgemm::matmul1x128::load_shader(typ.clone()), - Pipelines::Matmul1x128Prefetch(typ, _) => kernels::sgemm::matmul1x128_prefetch::load_shader(typ.clone()), - Pipelines::Matmul1x256(typ, _) => 
kernels::sgemm::matmul1x256::load_shader(typ.clone()), - Pipelines::Matmul1x256Prefetch(typ, _) => kernels::sgemm::matmul1x256_prefetch::load_shader(typ.clone()), Pipelines::Matmul24x24(typ, _) => kernels::sgemm::matmul24x24::load_shader(typ.clone()), - Pipelines::Matmul24x24Prefetch(typ, _) => kernels::sgemm::matmul24x24_prefetch::load_shader(typ.clone()), - Pipelines::Matmul24x48(typ, _) => kernels::sgemm::matmul24x48::load_shader(typ.clone()), - Pipelines::Matmul24x48Prefetch(typ, _) => kernels::sgemm::matmul24x48_prefetch::load_shader(typ.clone()), + Pipelines::Matmul16x64(typ, _) => kernels::sgemm::matmul16x64::load_shader(typ.clone()), Pipelines::Matmul32x32(typ, _) => kernels::sgemm::matmul32x32::load_shader(typ.clone()), - Pipelines::Matmul32x32Prefetch(typ, _) => kernels::sgemm::matmul32x32_prefetch::load_shader(typ.clone()), - Pipelines::Matmul64x1284x8(typ, _) => kernels::sgemm::matmul64x128_4x8::load_shader(typ.clone()), - Pipelines::Matmul64x1284x8Prefetch(typ, _) => kernels::sgemm::matmul64x128_4x8_prefetch::load_shader(typ.clone()), + Pipelines::Matmul128x128Prefetch(typ, _) => kernels::sgemm::matmul128x128_prefetch::load_shader(typ.clone()), + Pipelines::Matmul1x256(typ, _) => kernels::sgemm::matmul1x256::load_shader(typ.clone()), + Pipelines::Matmul1x256Prefetch(typ, _) => kernels::sgemm::matmul1x256_prefetch::load_shader(typ.clone()), + Pipelines::Matmul64x64Prefetch(typ, _) => kernels::sgemm::matmul64x64_prefetch::load_shader(typ.clone()), + Pipelines::Matmul64x64(typ, _) => kernels::sgemm::matmul64x64::load_shader(typ.clone()), Pipelines::Matmul64x1288x8(typ, _) => kernels::sgemm::matmul64x128_8x8::load_shader(typ.clone()), + Pipelines::Matmul128x128(typ, _) => kernels::sgemm::matmul128x128::load_shader(typ.clone()), Pipelines::Matmul64x1288x8Prefetch(typ, _) => kernels::sgemm::matmul64x128_8x8_prefetch::load_shader(typ.clone()), - Pipelines::Matmul64x64(typ, _) => kernels::sgemm::matmul64x64::load_shader(typ.clone()), - Pipelines::Matmul64x648x8(typ, _) => kernels::sgemm::matmul64x64_8x8::load_shader(typ.clone()), + Pipelines::Matmul64x1284x8(typ, _) => kernels::sgemm::matmul64x128_4x8::load_shader(typ.clone()), + Pipelines::Matmul24x24Prefetch(typ, _) => kernels::sgemm::matmul24x24_prefetch::load_shader(typ.clone()), + Pipelines::Matmul1x128(typ, _) => kernels::sgemm::matmul1x128::load_shader(typ.clone()), Pipelines::Matmul64x648x8Prefetch(typ, _) => kernels::sgemm::matmul64x64_8x8_prefetch::load_shader(typ.clone()), - Pipelines::Matmul64x64Prefetch(typ, _) => kernels::sgemm::matmul64x64_prefetch::load_shader(typ.clone()), - Pipelines::Softmax(typ, _) => kernels::softmax::load_shader(typ.clone()), - Pipelines::Unary(typ, _) => kernels::unary::load_shader(typ.clone()), - Pipelines::Upsample(typ, _) => kernels::upsample::load_shader(typ.clone()), - Pipelines::WhereCond(typ, _) => kernels::where_cond::load_shader(typ.clone()) + Pipelines::Matmul64x1284x8Prefetch(typ, _) => kernels::sgemm::matmul64x128_4x8_prefetch::load_shader(typ.clone()), + Pipelines::Matmul1x128Prefetch(typ, _) => kernels::sgemm::matmul1x128_prefetch::load_shader(typ.clone()), + Pipelines::Matmul24x48(typ, _) => kernels::sgemm::matmul24x48::load_shader(typ.clone()), + Pipelines::Matmul24x48Prefetch(typ, _) => kernels::sgemm::matmul24x48_prefetch::load_shader(typ.clone()), + Pipelines::Gather(typ, _) => kernels::gather::load_shader(typ.clone()), + Pipelines::Conv2d(typ, _) => kernels::conv2d::load_shader(typ.clone()), + Pipelines::Cmp(typ, _) => kernels::cmp::load_shader(typ.clone()), + 
Pipelines::WhereCond(typ, _) => kernels::where_cond::load_shader(typ.clone()), + Pipelines::IndexSelect(typ, _) => kernels::index_select::load_shader(typ.clone()), + Pipelines::Pool2d(typ, _) => kernels::pool2d::load_shader(typ.clone()), + Pipelines::Binary(typ, _) => kernels::binary::load_shader(typ.clone()), + Pipelines::Reduce(typ, _) => kernels::reduce::load_shader(typ.clone()), + Pipelines::Conv1d(typ, _) => kernels::conv1d::load_shader(typ.clone()), + Pipelines::RmsNorm(typ, _) => kernels::rms_norm::load_shader(typ.clone()) } } } @@ -217,87 +217,87 @@ impl Pipelines { impl Shaders { pub fn get_shader(&self) -> Shaders{ match self{ - Shaders::Binary(typ) => Shaders::Binary(typ.clone()), - Shaders::Cmp(typ) => Shaders::Cmp(typ.clone()), - Shaders::Conv1d(typ) => Shaders::Conv1d(typ.clone()), - Shaders::Conv2d(typ) => Shaders::Conv2d(typ.clone()), Shaders::Convert(typ) => Shaders::Convert(typ.clone()), - Shaders::Copy(typ) => Shaders::Copy(typ.clone()), - Shaders::Gather(typ) => Shaders::Gather(typ.clone()), - Shaders::IndexSelect(typ) => Shaders::IndexSelect(typ.clone()), Shaders::Matmul(typ) => Shaders::Matmul(typ.clone()), - Shaders::Pool2d(typ) => Shaders::Pool2d(typ.clone()), - Shaders::Reduce(typ) => Shaders::Reduce(typ.clone()), - Shaders::RmsNorm(typ) => Shaders::RmsNorm(typ.clone()), - Shaders::Matmul128x128(typ) => Shaders::Matmul128x128(typ.clone()), - Shaders::Matmul128x128Prefetch(typ) => Shaders::Matmul128x128Prefetch(typ.clone()), - Shaders::Matmul16x64(typ) => Shaders::Matmul16x64(typ.clone()), + Shaders::Copy(typ) => Shaders::Copy(typ.clone()), + Shaders::Unary(typ) => Shaders::Unary(typ.clone()), + Shaders::Upsample(typ) => Shaders::Upsample(typ.clone()), + Shaders::Softmax(typ) => Shaders::Softmax(typ.clone()), + Shaders::Matmul32x32Prefetch(typ) => Shaders::Matmul32x32Prefetch(typ.clone()), + Shaders::Matmul64x648x8(typ) => Shaders::Matmul64x648x8(typ.clone()), Shaders::Matmul16x64Prefetch(typ) => Shaders::Matmul16x64Prefetch(typ.clone()), + Shaders::Matmul24x24(typ) => Shaders::Matmul24x24(typ.clone()), + Shaders::Matmul16x64(typ) => Shaders::Matmul16x64(typ.clone()), + Shaders::Matmul32x32(typ) => Shaders::Matmul32x32(typ.clone()), + Shaders::Matmul128x128Prefetch(typ) => Shaders::Matmul128x128Prefetch(typ.clone()), + Shaders::Matmul1x256(typ) => Shaders::Matmul1x256(typ.clone()), + Shaders::Matmul1x256Prefetch(typ) => Shaders::Matmul1x256Prefetch(typ.clone()), + Shaders::Matmul64x64Prefetch(typ) => Shaders::Matmul64x64Prefetch(typ.clone()), + Shaders::Matmul64x64(typ) => Shaders::Matmul64x64(typ.clone()), + Shaders::Matmul64x1288x8(typ) => Shaders::Matmul64x1288x8(typ.clone()), + Shaders::Matmul128x128(typ) => Shaders::Matmul128x128(typ.clone()), + Shaders::Matmul64x1288x8Prefetch(typ) => Shaders::Matmul64x1288x8Prefetch(typ.clone()), + Shaders::Matmul64x1284x8(typ) => Shaders::Matmul64x1284x8(typ.clone()), + Shaders::Matmul24x24Prefetch(typ) => Shaders::Matmul24x24Prefetch(typ.clone()), Shaders::Matmul1x128(typ) => Shaders::Matmul1x128(typ.clone()), + Shaders::Matmul64x648x8Prefetch(typ) => Shaders::Matmul64x648x8Prefetch(typ.clone()), + Shaders::Matmul64x1284x8Prefetch(typ) => Shaders::Matmul64x1284x8Prefetch(typ.clone()), Shaders::Matmul1x128Prefetch(typ) => Shaders::Matmul1x128Prefetch(typ.clone()), - Shaders::Matmul1x256(typ) => Shaders::Matmul1x256(typ.clone()), - Shaders::Matmul1x256Prefetch(typ) => Shaders::Matmul1x256Prefetch(typ.clone()), - Shaders::Matmul24x24(typ) => Shaders::Matmul24x24(typ.clone()), - Shaders::Matmul24x24Prefetch(typ) => 
Shaders::Matmul24x24Prefetch(typ.clone()), Shaders::Matmul24x48(typ) => Shaders::Matmul24x48(typ.clone()), Shaders::Matmul24x48Prefetch(typ) => Shaders::Matmul24x48Prefetch(typ.clone()), - Shaders::Matmul32x32(typ) => Shaders::Matmul32x32(typ.clone()), - Shaders::Matmul32x32Prefetch(typ) => Shaders::Matmul32x32Prefetch(typ.clone()), - Shaders::Matmul64x1284x8(typ) => Shaders::Matmul64x1284x8(typ.clone()), - Shaders::Matmul64x1284x8Prefetch(typ) => Shaders::Matmul64x1284x8Prefetch(typ.clone()), - Shaders::Matmul64x1288x8(typ) => Shaders::Matmul64x1288x8(typ.clone()), - Shaders::Matmul64x1288x8Prefetch(typ) => Shaders::Matmul64x1288x8Prefetch(typ.clone()), - Shaders::Matmul64x64(typ) => Shaders::Matmul64x64(typ.clone()), - Shaders::Matmul64x648x8(typ) => Shaders::Matmul64x648x8(typ.clone()), - Shaders::Matmul64x648x8Prefetch(typ) => Shaders::Matmul64x648x8Prefetch(typ.clone()), - Shaders::Matmul64x64Prefetch(typ) => Shaders::Matmul64x64Prefetch(typ.clone()), - Shaders::Softmax(typ) => Shaders::Softmax(typ.clone()), - Shaders::Unary(typ) => Shaders::Unary(typ.clone()), - Shaders::Upsample(typ) => Shaders::Upsample(typ.clone()), - Shaders::WhereCond(typ) => Shaders::WhereCond(typ.clone()) + Shaders::Gather(typ) => Shaders::Gather(typ.clone()), + Shaders::Conv2d(typ) => Shaders::Conv2d(typ.clone()), + Shaders::Cmp(typ) => Shaders::Cmp(typ.clone()), + Shaders::WhereCond(typ) => Shaders::WhereCond(typ.clone()), + Shaders::IndexSelect(typ) => Shaders::IndexSelect(typ.clone()), + Shaders::Pool2d(typ) => Shaders::Pool2d(typ.clone()), + Shaders::Binary(typ) => Shaders::Binary(typ.clone()), + Shaders::Reduce(typ) => Shaders::Reduce(typ.clone()), + Shaders::Conv1d(typ) => Shaders::Conv1d(typ.clone()), + Shaders::RmsNorm(typ) => Shaders::RmsNorm(typ.clone()) } } pub fn load_shader(&self) -> &'static str{ match self{ - Shaders::Binary(typ) => kernels::binary::load_shader(typ.clone()), - Shaders::Cmp(typ) => kernels::cmp::load_shader(typ.clone()), - Shaders::Conv1d(typ) => kernels::conv1d::load_shader(typ.clone()), - Shaders::Conv2d(typ) => kernels::conv2d::load_shader(typ.clone()), Shaders::Convert(typ) => kernels::convert::load_shader(typ.clone()), - Shaders::Copy(typ) => kernels::copy::load_shader(typ.clone()), - Shaders::Gather(typ) => kernels::gather::load_shader(typ.clone()), - Shaders::IndexSelect(typ) => kernels::index_select::load_shader(typ.clone()), Shaders::Matmul(typ) => kernels::matmul::load_shader(typ.clone()), - Shaders::Pool2d(typ) => kernels::pool2d::load_shader(typ.clone()), - Shaders::Reduce(typ) => kernels::reduce::load_shader(typ.clone()), - Shaders::RmsNorm(typ) => kernels::rms_norm::load_shader(typ.clone()), - Shaders::Matmul128x128(typ) => kernels::sgemm::matmul128x128::load_shader(typ.clone()), - Shaders::Matmul128x128Prefetch(typ) => kernels::sgemm::matmul128x128_prefetch::load_shader(typ.clone()), - Shaders::Matmul16x64(typ) => kernels::sgemm::matmul16x64::load_shader(typ.clone()), + Shaders::Copy(typ) => kernels::copy::load_shader(typ.clone()), + Shaders::Unary(typ) => kernels::unary::load_shader(typ.clone()), + Shaders::Upsample(typ) => kernels::upsample::load_shader(typ.clone()), + Shaders::Softmax(typ) => kernels::softmax::load_shader(typ.clone()), + Shaders::Matmul32x32Prefetch(typ) => kernels::sgemm::matmul32x32_prefetch::load_shader(typ.clone()), + Shaders::Matmul64x648x8(typ) => kernels::sgemm::matmul64x64_8x8::load_shader(typ.clone()), Shaders::Matmul16x64Prefetch(typ) => kernels::sgemm::matmul16x64_prefetch::load_shader(typ.clone()), - Shaders::Matmul1x128(typ) => 
kernels::sgemm::matmul1x128::load_shader(typ.clone()), - Shaders::Matmul1x128Prefetch(typ) => kernels::sgemm::matmul1x128_prefetch::load_shader(typ.clone()), - Shaders::Matmul1x256(typ) => kernels::sgemm::matmul1x256::load_shader(typ.clone()), - Shaders::Matmul1x256Prefetch(typ) => kernels::sgemm::matmul1x256_prefetch::load_shader(typ.clone()), Shaders::Matmul24x24(typ) => kernels::sgemm::matmul24x24::load_shader(typ.clone()), - Shaders::Matmul24x24Prefetch(typ) => kernels::sgemm::matmul24x24_prefetch::load_shader(typ.clone()), - Shaders::Matmul24x48(typ) => kernels::sgemm::matmul24x48::load_shader(typ.clone()), - Shaders::Matmul24x48Prefetch(typ) => kernels::sgemm::matmul24x48_prefetch::load_shader(typ.clone()), + Shaders::Matmul16x64(typ) => kernels::sgemm::matmul16x64::load_shader(typ.clone()), Shaders::Matmul32x32(typ) => kernels::sgemm::matmul32x32::load_shader(typ.clone()), - Shaders::Matmul32x32Prefetch(typ) => kernels::sgemm::matmul32x32_prefetch::load_shader(typ.clone()), - Shaders::Matmul64x1284x8(typ) => kernels::sgemm::matmul64x128_4x8::load_shader(typ.clone()), - Shaders::Matmul64x1284x8Prefetch(typ) => kernels::sgemm::matmul64x128_4x8_prefetch::load_shader(typ.clone()), + Shaders::Matmul128x128Prefetch(typ) => kernels::sgemm::matmul128x128_prefetch::load_shader(typ.clone()), + Shaders::Matmul1x256(typ) => kernels::sgemm::matmul1x256::load_shader(typ.clone()), + Shaders::Matmul1x256Prefetch(typ) => kernels::sgemm::matmul1x256_prefetch::load_shader(typ.clone()), + Shaders::Matmul64x64Prefetch(typ) => kernels::sgemm::matmul64x64_prefetch::load_shader(typ.clone()), + Shaders::Matmul64x64(typ) => kernels::sgemm::matmul64x64::load_shader(typ.clone()), Shaders::Matmul64x1288x8(typ) => kernels::sgemm::matmul64x128_8x8::load_shader(typ.clone()), + Shaders::Matmul128x128(typ) => kernels::sgemm::matmul128x128::load_shader(typ.clone()), Shaders::Matmul64x1288x8Prefetch(typ) => kernels::sgemm::matmul64x128_8x8_prefetch::load_shader(typ.clone()), - Shaders::Matmul64x64(typ) => kernels::sgemm::matmul64x64::load_shader(typ.clone()), - Shaders::Matmul64x648x8(typ) => kernels::sgemm::matmul64x64_8x8::load_shader(typ.clone()), + Shaders::Matmul64x1284x8(typ) => kernels::sgemm::matmul64x128_4x8::load_shader(typ.clone()), + Shaders::Matmul24x24Prefetch(typ) => kernels::sgemm::matmul24x24_prefetch::load_shader(typ.clone()), + Shaders::Matmul1x128(typ) => kernels::sgemm::matmul1x128::load_shader(typ.clone()), Shaders::Matmul64x648x8Prefetch(typ) => kernels::sgemm::matmul64x64_8x8_prefetch::load_shader(typ.clone()), - Shaders::Matmul64x64Prefetch(typ) => kernels::sgemm::matmul64x64_prefetch::load_shader(typ.clone()), - Shaders::Softmax(typ) => kernels::softmax::load_shader(typ.clone()), - Shaders::Unary(typ) => kernels::unary::load_shader(typ.clone()), - Shaders::Upsample(typ) => kernels::upsample::load_shader(typ.clone()), - Shaders::WhereCond(typ) => kernels::where_cond::load_shader(typ.clone()) + Shaders::Matmul64x1284x8Prefetch(typ) => kernels::sgemm::matmul64x128_4x8_prefetch::load_shader(typ.clone()), + Shaders::Matmul1x128Prefetch(typ) => kernels::sgemm::matmul1x128_prefetch::load_shader(typ.clone()), + Shaders::Matmul24x48(typ) => kernels::sgemm::matmul24x48::load_shader(typ.clone()), + Shaders::Matmul24x48Prefetch(typ) => kernels::sgemm::matmul24x48_prefetch::load_shader(typ.clone()), + Shaders::Gather(typ) => kernels::gather::load_shader(typ.clone()), + Shaders::Conv2d(typ) => kernels::conv2d::load_shader(typ.clone()), + Shaders::Cmp(typ) => kernels::cmp::load_shader(typ.clone()), + 
Shaders::WhereCond(typ) => kernels::where_cond::load_shader(typ.clone()), + Shaders::IndexSelect(typ) => kernels::index_select::load_shader(typ.clone()), + Shaders::Pool2d(typ) => kernels::pool2d::load_shader(typ.clone()), + Shaders::Binary(typ) => kernels::binary::load_shader(typ.clone()), + Shaders::Reduce(typ) => kernels::reduce::load_shader(typ.clone()), + Shaders::Conv1d(typ) => kernels::conv1d::load_shader(typ.clone()), + Shaders::RmsNorm(typ) => kernels::rms_norm::load_shader(typ.clone()) } } } @@ -305,57 +305,57 @@ impl Shaders { #[derive(Debug, Clone, PartialEq, Eq, Hash, std::marker::Copy)] pub enum Constants { None, - Constv8, + UseZ, + Constv1, + ConstDims1, + ConstIsStartoffsetZero3, Constv7, Preloada, - ConstIsStartoffsetZero1, + ConstDims3, + Constv3, + ConstIsContiguous3, + ConstIsContiguous2, + Constv6, + Constv8, Preloadb, - ConstIsStartoffsetZero3, - ConstIsContiguous1, - Constv2, - ConstDims2, + ConstIsStartoffsetZero1, ConstIsStartoffsetZero2, - Constv0, - UseZ, + ConstDims2, Isoutputpadded, - ConstIsContiguous3, - Constv3, - Constv6, + Constv0, Constv9, - Constv1, - Constv4, - ConstDims1, - ConstIsContiguous2, Constv5, - ConstDims3 + Constv2, + Constv4, + ConstIsContiguous1 } impl crate::EntryPoint for Constants{ fn get_entry_point(&self) -> &'static str{ match self{ - Constants::Constv8 => "CONSTV_8", + Constants::UseZ => "USE_Z", + Constants::Constv1 => "CONSTV_1", + Constants::ConstDims1 => "CONST_DIMS1", + Constants::ConstIsStartoffsetZero3 => "CONST_IS_STARTOFFSET_ZERO3", Constants::Constv7 => "CONSTV_7", Constants::Preloada => "PreLoadA", - Constants::ConstIsStartoffsetZero1 => "CONST_IS_STARTOFFSET_ZERO1", + Constants::ConstDims3 => "CONST_DIMS3", + Constants::Constv3 => "CONSTV_3", + Constants::ConstIsContiguous3 => "CONST_IS_CONTIGUOUS3", + Constants::ConstIsContiguous2 => "CONST_IS_CONTIGUOUS2", + Constants::Constv6 => "CONSTV_6", + Constants::Constv8 => "CONSTV_8", Constants::Preloadb => "PreLoadB", - Constants::ConstIsStartoffsetZero3 => "CONST_IS_STARTOFFSET_ZERO3", - Constants::ConstIsContiguous1 => "CONST_IS_CONTIGUOUS1", - Constants::Constv2 => "CONSTV_2", - Constants::ConstDims2 => "CONST_DIMS2", + Constants::ConstIsStartoffsetZero1 => "CONST_IS_STARTOFFSET_ZERO1", Constants::ConstIsStartoffsetZero2 => "CONST_IS_STARTOFFSET_ZERO2", - Constants::Constv0 => "CONSTV_0", - Constants::UseZ => "USE_Z", + Constants::ConstDims2 => "CONST_DIMS2", Constants::Isoutputpadded => "IsOutputPadded", - Constants::ConstIsContiguous3 => "CONST_IS_CONTIGUOUS3", - Constants::Constv3 => "CONSTV_3", - Constants::Constv6 => "CONSTV_6", + Constants::Constv0 => "CONSTV_0", Constants::Constv9 => "CONSTV_9", - Constants::Constv1 => "CONSTV_1", - Constants::Constv4 => "CONSTV_4", - Constants::ConstDims1 => "CONST_DIMS1", - Constants::ConstIsContiguous2 => "CONST_IS_CONTIGUOUS2", Constants::Constv5 => "CONSTV_5", - Constants::ConstDims3 => "CONST_DIMS3", + Constants::Constv2 => "CONSTV_2", + Constants::Constv4 => "CONSTV_4", + Constants::ConstIsContiguous1 => "CONST_IS_CONTIGUOUS1", Constants::None => panic!("not expected") } } @@ -367,116 +367,116 @@ impl Default for Constants { } } pub mod kernels { - pub mod binary { + pub mod convert { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{BinaryBufferFromBuffer,BinaryBufferInplace2ContiguousBoth,BinaryBufferFromBufferContiguousBoth,BinaryBufferInplace1ContiguousBoth} + pub enum Functions{ConvertU8ToF32,ConvertU32ToU8,ConvertToU32,ConvertU32ToI64,ConvertToF32,ConvertF32ToU8} impl crate::EntryPoint for Functions{ fn 
get_entry_point(&self) -> &'static str{ match self{ - Functions::BinaryBufferFromBuffer => "binary_buffer_from_buffer",Functions::BinaryBufferInplace2ContiguousBoth => "binary_buffer_inplace2_contiguous_both",Functions::BinaryBufferFromBufferContiguousBoth => "binary_buffer_from_buffer_contiguous_both",Functions::BinaryBufferInplace1ContiguousBoth => "binary_buffer_inplace1_contiguous_both" + Functions::ConvertU8ToF32 => "convert_u8_to_f32",Functions::ConvertU32ToU8 => "convert_u32_to_u8",Functions::ConvertToU32 => "convert_to_u32",Functions::ConvertU32ToI64 => "convert_u32_to_i64",Functions::ConvertToF32 => "convert_to_f32",Functions::ConvertF32ToU8 => "convert_f32_to_u8" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/binary.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/binary.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/binary.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/convert.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/convert.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/convert.pwgsl_generated_u8.wgsl"), } } } - pub mod cmp { + pub mod matmul { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{CmpBufferFromBuffer} + pub enum Functions{Matmul5,Matmul7,Matmul1,Matmul116,Matmul1End} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::CmpBufferFromBuffer => "cmp_buffer_from_buffer" + Functions::Matmul5 => "matmul5",Functions::Matmul7 => "matmul7",Functions::Matmul1 => "matmul1",Functions::Matmul116 => "matmul1_16",Functions::Matmul1End => "matmul1_end" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/cmp.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/cmp.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/cmp.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/matmul.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/matmul.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/matmul.pwgsl_generated_u8.wgsl"), } } } - pub mod conv1d { + pub mod copy { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{Conv1d,Conv1dTranspose} + pub enum Functions{Copy2dTranspose,Copy,Copy2dTranspose2,Copy3d,Copy2d2,CopyStrided,Copy4,Copy2d,Copy3dPaddedNobatch,Copy3dPadded,Copy4dPadded} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::Conv1d => "conv1d",Functions::Conv1dTranspose => "conv1d_transpose" + Functions::Copy2dTranspose => "copy2d_transpose",Functions::Copy => "copy",Functions::Copy2dTranspose2 => "copy2d_transpose2",Functions::Copy3d => "copy3d",Functions::Copy2d2 => "copy2d2",Functions::CopyStrided => "copy_strided",Functions::Copy4 => "copy_4",Functions::Copy2d => "copy2d",Functions::Copy3dPaddedNobatch => "copy3d_padded_nobatch",Functions::Copy3dPadded => "copy3d_padded",Functions::Copy4dPadded => "copy4d_padded" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/conv1d.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => 
include_str!("kernels//generated/conv1d.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/conv1d.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/copy.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/copy.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/copy.pwgsl_generated_u8.wgsl"), } } } - pub mod conv2d { + pub mod unary { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{Conv2dTranspose,Conv2d,Conv2d2} + pub enum Functions{UnaryFromBuffer,UnaryFromBufferContiguous,UnaryInplaceContiguous} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::Conv2dTranspose => "conv2d_transpose",Functions::Conv2d => "conv2d",Functions::Conv2d2 => "conv2d_2" + Functions::UnaryFromBuffer => "unary_from_buffer",Functions::UnaryFromBufferContiguous => "unary_from_buffer_contiguous",Functions::UnaryInplaceContiguous => "unary_inplace_contiguous" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/conv2d.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/conv2d.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/conv2d.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/unary.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/unary.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/unary.pwgsl_generated_u8.wgsl"), } } } - pub mod convert { + pub mod upsample { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{ConvertToF32,ConvertToU32,ConvertU8ToF32,ConvertF32ToU8,ConvertU32ToU8} + pub enum Functions{Upsample2d,Upsample1d} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::ConvertToF32 => "convert_to_f32",Functions::ConvertToU32 => "convert_to_u32",Functions::ConvertU8ToF32 => "convert_u8_to_f32",Functions::ConvertF32ToU8 => "convert_f32_to_u8",Functions::ConvertU32ToU8 => "convert_u32_to_u8" + Functions::Upsample2d => "upsample2d",Functions::Upsample1d => "upsample1d" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/convert.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/convert.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/convert.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/upsample.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/upsample.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/upsample.pwgsl_generated_u8.wgsl"), } } } - pub mod copy { + pub mod softmax { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{Copy3dPadded,CopyStrided,Copy,Copy2dTranspose,Copy4,Copy3dPaddedNobatch,Copy2d,Copy2d2,Copy3d,Copy2dTranspose2} + pub enum Functions{Softmax} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::Copy3dPadded => "copy3d_padded",Functions::CopyStrided => "copy_strided",Functions::Copy => "copy",Functions::Copy2dTranspose => "copy2d_transpose",Functions::Copy4 => "copy_4",Functions::Copy3dPaddedNobatch => "copy3d_padded_nobatch",Functions::Copy2d => 
"copy2d",Functions::Copy2d2 => "copy2d2",Functions::Copy3d => "copy3d",Functions::Copy2dTranspose2 => "copy2d_transpose2" + Functions::Softmax => "softmax" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/copy.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/copy.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/copy.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/softmax.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/softmax.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/softmax.pwgsl_generated_u8.wgsl"), } } } @@ -500,198 +500,217 @@ pub mod kernels { } } - pub mod index_select { + pub mod conv2d { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{IndexSelect} + pub enum Functions{Conv2d7,Conv2d5,Conv2dKernelSize1Nopadding,Conv2d,Conv2dLongchannels2,Conv2dLongchannels2Nopadding,Conv2dKernelSize1,Im2col,Conv2dLongchannel,Conv2dLongchannelNopadding,Conv2dNopadding,Conv2d2,Conv2dTranspose} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::IndexSelect => "index_select" + Functions::Conv2d7 => "conv2d7",Functions::Conv2d5 => "conv2d5",Functions::Conv2dKernelSize1Nopadding => "conv2d_kernel_size_1_nopadding",Functions::Conv2d => "conv2d",Functions::Conv2dLongchannels2 => "conv2d_longchannels2",Functions::Conv2dLongchannels2Nopadding => "conv2d_longchannels2_nopadding",Functions::Conv2dKernelSize1 => "conv2d_kernel_size_1",Functions::Im2col => "im2col",Functions::Conv2dLongchannel => "conv2d_longchannel",Functions::Conv2dLongchannelNopadding => "conv2d_longchannel_nopadding",Functions::Conv2dNopadding => "conv2d_nopadding",Functions::Conv2d2 => "conv2d_2",Functions::Conv2dTranspose => "conv2d_transpose" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/index_select.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/index_select.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/index_select.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/conv2d.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/conv2d.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/conv2d.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul { + pub mod cmp { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{Matmul1End,Matmul116,Matmul1,Matmul5,Matmul7} + pub enum Functions{CmpBufferFromBuffer} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::Matmul1End => "matmul1_end",Functions::Matmul116 => "matmul1_16",Functions::Matmul1 => "matmul1",Functions::Matmul5 => "matmul5",Functions::Matmul7 => "matmul7" + Functions::CmpBufferFromBuffer => "cmp_buffer_from_buffer" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/matmul.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/matmul.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/matmul.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/cmp.pwgsl_generated_f32.wgsl"), + 
crate::DType::U32 => include_str!("kernels//generated/cmp.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/cmp.pwgsl_generated_u8.wgsl"), } } } - pub mod pool2d { + pub mod where_cond { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{MaxPool2d,AvgPool2d} + pub enum Functions{WhereCondIndexU32} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::MaxPool2d => "max_pool2d",Functions::AvgPool2d => "avg_pool2d" + Functions::WhereCondIndexU32 => "where_cond_index_u32" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/pool2d.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/pool2d.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/pool2d.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/where_cond.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/where_cond.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/where_cond.pwgsl_generated_u8.wgsl"), } } } - pub mod reduce { + pub mod index_select { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{Reduce,ReduceIndex} + pub enum Functions{IndexSelect} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::Reduce => "reduce",Functions::ReduceIndex => "reduce_index" + Functions::IndexSelect => "index_select" + } + } + } + pub fn load_shader(typ : crate::DType) -> &'static str { + match typ{ + crate::DType::F32 => include_str!("kernels//generated/index_select.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/index_select.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/index_select.pwgsl_generated_u8.wgsl"), + } + } + } + + pub mod pool2d { + #[derive(Debug, Clone, PartialEq, Eq, Hash)] + pub enum Functions{MaxPool2d,AvgPool2d} + impl crate::EntryPoint for Functions{ + fn get_entry_point(&self) -> &'static str{ + match self{ + Functions::MaxPool2d => "max_pool2d",Functions::AvgPool2d => "avg_pool2d" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/reduce.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/reduce.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/reduce.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/pool2d.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/pool2d.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/pool2d.pwgsl_generated_u8.wgsl"), } } } - pub mod rms_norm { + pub mod binary { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{RmsNorm} + pub enum Functions{BinaryBufferFromBufferContiguousBoth,BinaryBufferInplace1ContiguousBoth,BinaryBufferFromBuffer,BinaryBufferInplace2ContiguousBoth} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::RmsNorm => "rms_norm" + Functions::BinaryBufferFromBufferContiguousBoth => "binary_buffer_from_buffer_contiguous_both",Functions::BinaryBufferInplace1ContiguousBoth => "binary_buffer_inplace1_contiguous_both",Functions::BinaryBufferFromBuffer => 
"binary_buffer_from_buffer",Functions::BinaryBufferInplace2ContiguousBoth => "binary_buffer_inplace2_contiguous_both" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/rms_norm.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/rms_norm.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/rms_norm.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/binary.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/binary.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/binary.pwgsl_generated_u8.wgsl"), } } } - pub mod softmax { + pub mod reduce { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{Softmax} + pub enum Functions{Reduce,ReduceIndex} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::Softmax => "softmax" + Functions::Reduce => "reduce",Functions::ReduceIndex => "reduce_index" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/softmax.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/softmax.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/softmax.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/reduce.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/reduce.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/reduce.pwgsl_generated_u8.wgsl"), } } } - pub mod unary { + pub mod conv1d { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{UnaryInplaceContiguous,UnaryFromBufferContiguous,UnaryFromBuffer} + pub enum Functions{Conv1d,Conv1dTranspose} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::UnaryInplaceContiguous => "unary_inplace_contiguous",Functions::UnaryFromBufferContiguous => "unary_from_buffer_contiguous",Functions::UnaryFromBuffer => "unary_from_buffer" + Functions::Conv1d => "conv1d",Functions::Conv1dTranspose => "conv1d_transpose" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/unary.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/unary.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/unary.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/conv1d.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/conv1d.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/conv1d.pwgsl_generated_u8.wgsl"), } } } - pub mod upsample { + pub mod rms_norm { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{Upsample2d,Upsample1d} + pub enum Functions{RmsNorm} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::Upsample2d => "upsample2d",Functions::Upsample1d => "upsample1d" + Functions::RmsNorm => "rms_norm" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/upsample.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/upsample.pwgsl_generated_u32.wgsl"), - 
crate::DType::U8 => include_str!("kernels//generated/upsample.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels//generated/rms_norm.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels//generated/rms_norm.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels//generated/rms_norm.pwgsl_generated_u8.wgsl"), } } } - pub mod where_cond { + pub mod sgemm { + pub mod matmul32x32_prefetch { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{WhereCondIndexU32} + pub enum Functions{Matmul,MatmulNoPadded} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::WhereCondIndexU32 => "where_cond_index_u32" + Functions::Matmul => "matmul",Functions::MatmulNoPadded => "matmul_no_padded" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels//generated/where_cond.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels//generated/where_cond.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels//generated/where_cond.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul32x32_prefetch.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul32x32_prefetch.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul32x32_prefetch.pwgsl_generated_u8.wgsl"), } } } - pub mod sgemm { - pub mod matmul128x128 { + pub mod matmul64x64_8x8 { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{Matmul,MatmulNoPadded} + pub enum Functions{MatmulNoPadded,Matmul} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::Matmul => "matmul",Functions::MatmulNoPadded => "matmul_no_padded" + Functions::MatmulNoPadded => "matmul_no_padded",Functions::Matmul => "matmul" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul128x128.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul128x128.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul128x128.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x64_8x8.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x64_8x8.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x64_8x8.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul128x128_prefetch { + pub mod matmul16x64_prefetch { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Functions{Matmul,MatmulNoPadded} impl crate::EntryPoint for Functions{ @@ -703,14 +722,14 @@ pub mod kernels { } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul128x128_prefetch.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul128x128_prefetch.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul128x128_prefetch.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul16x64_prefetch.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul16x64_prefetch.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => 
include_str!("kernels/sgemm//generated/matmul16x64_prefetch.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul16x64 { + pub mod matmul24x24 { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Functions{MatmulNoPadded,Matmul} impl crate::EntryPoint for Functions{ @@ -722,14 +741,14 @@ pub mod kernels { } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul16x64.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul16x64.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul16x64.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul24x24.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul24x24.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul24x24.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul16x64_prefetch { + pub mod matmul16x64 { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Functions{Matmul,MatmulNoPadded} impl crate::EntryPoint for Functions{ @@ -741,33 +760,33 @@ pub mod kernels { } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul16x64_prefetch.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul16x64_prefetch.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul16x64_prefetch.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul16x64.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul16x64.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul16x64.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul1x128 { + pub mod matmul32x32 { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{MatmulNoPadded,Matmul} + pub enum Functions{Matmul,MatmulNoPadded} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::MatmulNoPadded => "matmul_no_padded",Functions::Matmul => "matmul" + Functions::Matmul => "matmul",Functions::MatmulNoPadded => "matmul_no_padded" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul1x128.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul1x128.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul1x128.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul32x32.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul32x32.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul32x32.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul1x128_prefetch { + pub mod matmul128x128_prefetch { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Functions{Matmul,MatmulNoPadded} impl crate::EntryPoint for Functions{ @@ -779,9 +798,9 @@ pub mod kernels { } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul1x128_prefetch.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => 
include_str!("kernels/sgemm//generated/matmul1x128_prefetch.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul1x128_prefetch.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul128x128_prefetch.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul128x128_prefetch.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul128x128_prefetch.pwgsl_generated_u8.wgsl"), } } } @@ -807,11 +826,11 @@ pub mod kernels { pub mod matmul1x256_prefetch { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{MatmulNoPadded,Matmul} + pub enum Functions{Matmul,MatmulNoPadded} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::MatmulNoPadded => "matmul_no_padded",Functions::Matmul => "matmul" + Functions::Matmul => "matmul",Functions::MatmulNoPadded => "matmul_no_padded" } } } @@ -824,7 +843,7 @@ pub mod kernels { } } - pub mod matmul24x24 { + pub mod matmul64x64_prefetch { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Functions{MatmulNoPadded,Matmul} impl crate::EntryPoint for Functions{ @@ -836,33 +855,33 @@ pub mod kernels { } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul24x24.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul24x24.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul24x24.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x64_prefetch.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x64_prefetch.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x64_prefetch.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul24x24_prefetch { + pub mod matmul64x64 { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{Matmul,MatmulNoPadded} + pub enum Functions{MatmulNoPadded,Matmul} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::Matmul => "matmul",Functions::MatmulNoPadded => "matmul_no_padded" + Functions::MatmulNoPadded => "matmul_no_padded",Functions::Matmul => "matmul" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul24x24_prefetch.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul24x24_prefetch.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul24x24_prefetch.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x64.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x64.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x64.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul24x48 { + pub mod matmul64x128_8x8 { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Functions{Matmul,MatmulNoPadded} impl crate::EntryPoint for Functions{ @@ -874,14 +893,14 @@ pub mod kernels { } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul24x48.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => 
include_str!("kernels/sgemm//generated/matmul24x48.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul24x48.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x128_8x8.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x128_8x8.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x128_8x8.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul24x48_prefetch { + pub mod matmul128x128 { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Functions{MatmulNoPadded,Matmul} impl crate::EntryPoint for Functions{ @@ -893,33 +912,14 @@ pub mod kernels { } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul24x48_prefetch.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul24x48_prefetch.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul24x48_prefetch.pwgsl_generated_u8.wgsl"), - } - } - } - - pub mod matmul32x32 { - #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{Matmul,MatmulNoPadded} - impl crate::EntryPoint for Functions{ - fn get_entry_point(&self) -> &'static str{ - match self{ - Functions::Matmul => "matmul",Functions::MatmulNoPadded => "matmul_no_padded" - } - } - } - pub fn load_shader(typ : crate::DType) -> &'static str { - match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul32x32.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul32x32.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul32x32.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul128x128.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul128x128.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul128x128.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul32x32_prefetch { + pub mod matmul64x128_8x8_prefetch { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Functions{Matmul,MatmulNoPadded} impl crate::EntryPoint for Functions{ @@ -931,9 +931,9 @@ pub mod kernels { } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul32x32_prefetch.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul32x32_prefetch.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul32x32_prefetch.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x128_8x8_prefetch.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x128_8x8_prefetch.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x128_8x8_prefetch.pwgsl_generated_u8.wgsl"), } } } @@ -957,7 +957,7 @@ pub mod kernels { } } - pub mod matmul64x128_4x8_prefetch { + pub mod matmul24x24_prefetch { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Functions{Matmul,MatmulNoPadded} impl crate::EntryPoint for Functions{ @@ -969,52 +969,52 @@ pub mod kernels { } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => 
include_str!("kernels/sgemm//generated/matmul64x128_4x8_prefetch.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x128_4x8_prefetch.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x128_4x8_prefetch.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul24x24_prefetch.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul24x24_prefetch.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul24x24_prefetch.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul64x128_8x8 { + pub mod matmul1x128 { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{Matmul,MatmulNoPadded} + pub enum Functions{MatmulNoPadded,Matmul} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::Matmul => "matmul",Functions::MatmulNoPadded => "matmul_no_padded" + Functions::MatmulNoPadded => "matmul_no_padded",Functions::Matmul => "matmul" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x128_8x8.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x128_8x8.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x128_8x8.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul1x128.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul1x128.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul1x128.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul64x128_8x8_prefetch { + pub mod matmul64x64_8x8_prefetch { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{Matmul,MatmulNoPadded} + pub enum Functions{MatmulNoPadded,Matmul} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::Matmul => "matmul",Functions::MatmulNoPadded => "matmul_no_padded" + Functions::MatmulNoPadded => "matmul_no_padded",Functions::Matmul => "matmul" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x128_8x8_prefetch.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x128_8x8_prefetch.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x128_8x8_prefetch.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x64_8x8_prefetch.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x64_8x8_prefetch.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x64_8x8_prefetch.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul64x64 { + pub mod matmul64x128_4x8_prefetch { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Functions{Matmul,MatmulNoPadded} impl crate::EntryPoint for Functions{ @@ -1026,14 +1026,14 @@ pub mod kernels { } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x64.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x64.pwgsl_generated_u32.wgsl"), - crate::DType::U8 
=> include_str!("kernels/sgemm//generated/matmul64x64.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x128_4x8_prefetch.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x128_4x8_prefetch.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x128_4x8_prefetch.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul64x64_8x8 { + pub mod matmul1x128_prefetch { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Functions{MatmulNoPadded,Matmul} impl crate::EntryPoint for Functions{ @@ -1045,14 +1045,14 @@ pub mod kernels { } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x64_8x8.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x64_8x8.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x64_8x8.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul1x128_prefetch.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul1x128_prefetch.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul1x128_prefetch.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul64x64_8x8_prefetch { + pub mod matmul24x48 { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Functions{Matmul,MatmulNoPadded} impl crate::EntryPoint for Functions{ @@ -1064,28 +1064,28 @@ pub mod kernels { } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x64_8x8_prefetch.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x64_8x8_prefetch.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x64_8x8_prefetch.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul24x48.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul24x48.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul24x48.pwgsl_generated_u8.wgsl"), } } } - pub mod matmul64x64_prefetch { + pub mod matmul24x48_prefetch { #[derive(Debug, Clone, PartialEq, Eq, Hash)] - pub enum Functions{MatmulNoPadded,Matmul} + pub enum Functions{Matmul,MatmulNoPadded} impl crate::EntryPoint for Functions{ fn get_entry_point(&self) -> &'static str{ match self{ - Functions::MatmulNoPadded => "matmul_no_padded",Functions::Matmul => "matmul" + Functions::Matmul => "matmul",Functions::MatmulNoPadded => "matmul_no_padded" } } } pub fn load_shader(typ : crate::DType) -> &'static str { match typ{ - crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul64x64_prefetch.pwgsl_generated_f32.wgsl"), - crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul64x64_prefetch.pwgsl_generated_u32.wgsl"), - crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul64x64_prefetch.pwgsl_generated_u8.wgsl"), + crate::DType::F32 => include_str!("kernels/sgemm//generated/matmul24x48_prefetch.pwgsl_generated_f32.wgsl"), + crate::DType::U32 => include_str!("kernels/sgemm//generated/matmul24x48_prefetch.pwgsl_generated_u32.wgsl"), + crate::DType::U8 => include_str!("kernels/sgemm//generated/matmul24x48_prefetch.pwgsl_generated_u8.wgsl"), } } } diff --git 
diff --git a/candle-wgpu-kernels/src/kernels/convert.pwgsl b/candle-wgpu-kernels/src/kernels/convert.pwgsl
index 373dd2e382..4d0486ea58 100644
--- a/candle-wgpu-kernels/src/kernels/convert.pwgsl
+++ b/candle-wgpu-kernels/src/kernels/convert.pwgsl
@@ -91,6 +91,24 @@ fn convert_u32_to_u8(@builtin(global_invocation_id) global_id: vec3<u32>) {

 #endif

+#define convert_start_offset op_meta[0]
+#define convert_size op_meta[1]
+
+@compute
+@workgroup_size(64, 1, 1)
+fn convert_u32_to_i64(@builtin(global_invocation_id) global_id: vec3<u32>) {
+    let id_dest = global_id.x;
+    let id_source = global_id.x;
+
+    if (id_source >= convert_size) {
+        return;
+    }
+
+    let x: u32 = v_input1[convert_start_offset + id_source];
+
+    v_dest_i64[id_dest] = i64(x);
+}
+
 #ifdef f32

 @compute
diff --git a/candle-wgpu-kernels/src/kernels/util.pwgsl b/candle-wgpu-kernels/src/kernels/util.pwgsl
index 16399f6435..05f23a2574 100644
--- a/candle-wgpu-kernels/src/kernels/util.pwgsl
+++ b/candle-wgpu-kernels/src/kernels/util.pwgsl
@@ -10,6 +10,10 @@
 #define DTYPE u32
 #endif

+#ifdef i64
+#define DTYPE i64
+#endif
+
 @group(0) @binding(0)
 var<storage, read_write> v_dest: array<DTYPE>;

@@ -22,6 +26,8 @@ var<storage, read_write> v_dest_u32: array<u32>;

 @group(0) @binding(0)
 var<storage, read_write> v_dest_f32: array<f32>;
+@group(0) @binding(0)
+var<storage, read_write> v_dest_i64: array<i64>;

 @group(0) @binding(1)
@@ -73,6 +79,11 @@ const MAXVALUE : u32 = 4294967295;

 const MINVALUE : i32 = -2147483648;
 const MAXVALUE : i32 = 2147483647;

+#elifdef i64
+
+const MINVALUE : i64 = -9223372036854775808;
+const MAXVALUE : i64 = 9223372036854775807;
+
 #elifdef u8

 const MINVALUE : u32 = 0;
@@ -189,4 +200,4 @@ fn get_size3() -> u32{

 // fn get_size2(dims : u32) -> u32{
 //     return dims * 2 + 2;
-// }
\ No newline at end of file
+// }
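Note (editorial, not part of the patch): the new `convert_u32_to_i64` entry point reads its parameters from the shared `op_meta` buffer (binding 1): word 0 is the source start offset and word 1 the element count, with one invocation per element at `@workgroup_size(64, 1, 1)`. A host-side sketch of the matching meta layout and dispatch arithmetic, with hypothetical helper names (the real encoding lives in the wgpu backend's command queue):

    /// Hypothetical helper: append the op_meta words for convert_u32_to_i64 and
    /// return the number of workgroups to dispatch along x. Assumes `op_meta`
    /// starts at the slot this op will read as index 0.
    fn convert_u32_to_i64_dispatch(op_meta: &mut Vec<u32>, start_offset: u32, size: u32) -> u32 {
        op_meta.push(start_offset); // op_meta[0] = convert_start_offset
        op_meta.push(size);         // op_meta[1] = convert_size
        // One thread per element, rounded up to whole 64-wide workgroups; the
        // kernel's own `id_source >= convert_size` guard discards the overhang.
        (size + 63) / 64
    }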