diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 15b5e42b1f2e2..22df42f2a24ef 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -148,6 +148,7 @@ option(onnxruntime_TVM_USE_HASH "Build ipp-crypto library for support hash algor option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF) option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF) option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C++ interface." OFF) +option(onnxruntime_USE_EXTERNAL_DAWN "Build with treating Dawn as external dependency. Will not link Dawn at build time." OFF) # Options related to reducing the binary size produced by the build # XNNPACK EP requires the internal NHWC contrib ops to be available, so this option must be OFF when onnxruntime_USE_XNNPACK is ON @@ -948,6 +949,9 @@ if (onnxruntime_USE_WEBGPU) list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBGPU=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu) + if (onnxruntime_USE_EXTERNAL_DAWN) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_EXTERNAL_DAWN=1) + endif() endif() if (onnxruntime_USE_CANN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index a69d2649ad832..d9e833a2d8cd4 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -656,11 +656,16 @@ if (onnxruntime_USE_WEBGPU) # Vulkan may optionally be included in a Windows build. Exclude until we have an explicit use case that requires it. set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE) + # We are currently always using the D3D12 backend. + set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE) endif() onnxruntime_fetchcontent_makeavailable(dawn) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native dawn::dawn_proc) + if (NOT onnxruntime_USE_EXTERNAL_DAWN) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native) + endif() + list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc) endif() set(onnxruntime_LINK_DIRS) diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index eb25c55ab23e0..02c2a5aee481c 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -22,6 +22,9 @@ onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common dawn::dawncpp_headers dawn::dawn_headers onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native dawn::dawn_proc) + if (NOT onnxruntime_USE_EXTERNAL_DAWN) + target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native) + endif() + target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_proc) set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 67e5a9c0aa08b..561f65a33b89c 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -523,6 +523,9 @@ set (onnxruntime_global_thread_pools_test_SRC ${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_main.cc ${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_inference.cc) +set (onnxruntime_webgpu_external_dawn_test_SRC + ${TEST_SRC_DIR}/webgpu/external_dawn/main.cc) + # tests from lowest level library up. # the order of libraries should be maintained, with higher libraries being added first in the list @@ -1884,4 +1887,13 @@ if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD endif() endif() +if (onnxruntime_USE_WEBGPU AND onnxruntime_USE_EXTERNAL_DAWN) + AddTest(TARGET onnxruntime_webgpu_external_dawn_test + SOURCES ${onnxruntime_webgpu_external_dawn_test_SRC} + LIBS dawn::dawn_native ${onnxruntime_test_providers_libs} + DEPENDS ${all_dependencies} + ) + onnxruntime_add_include_to_target(onnxruntime_webgpu_external_dawn_test dawn::dawncpp_headers dawn::dawn_headers) +endif() + include(onnxruntime_fuzz_test.cmake) diff --git a/cmake/patches/dawn/dawn.patch b/cmake/patches/dawn/dawn.patch index d696d386452e8..7a2a01d55be46 100644 --- a/cmake/patches/dawn/dawn.patch +++ b/cmake/patches/dawn/dawn.patch @@ -15,40 +15,55 @@ index 9c0bd6fa4e..bf8a57aeac 100644 ############################################################################### # Do the 'complete_lib' build. diff --git a/src/dawn/native/Surface_metal.mm b/src/dawn/native/Surface_metal.mm -index ce55acbd43..baa4835362 100644 +index ce55acbd43..2cfd363479 100644 --- a/src/dawn/native/Surface_metal.mm +++ b/src/dawn/native/Surface_metal.mm -@@ -36,7 +36,13 @@ +@@ -33,10 +33,18 @@ + + #import + ++#include "dawn/common/Platform.h" ++ namespace dawn::native { bool InheritsFromCAMetalLayer(void* obj) { - id object = static_cast(obj); + id object = -+#if TARGET_OS_IOS ++#if DAWN_PLATFORM_IS(IOS) + (__bridge id)obj; -+#else ++#else // DAWN_PLATFORM_IS(IOS) + static_cast(obj); -+#endif ++#endif // DAWN_PLATFORM_IS(IOS) + return [object isKindOfClass:[CAMetalLayer class]]; } diff --git a/src/dawn/native/metal/SharedFenceMTL.mm b/src/dawn/native/metal/SharedFenceMTL.mm -index bde8bfea07..f2f6459e91 100644 +index bde8bfea07..8906185d6f 100644 --- a/src/dawn/native/metal/SharedFenceMTL.mm +++ b/src/dawn/native/metal/SharedFenceMTL.mm -@@ -40,7 +40,13 @@ ResultOrError> SharedFence::Create( +@@ -25,6 +25,8 @@ + // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ++#include "dawn/common/Platform.h" ++ + #include "dawn/native/metal/SharedFenceMTL.h" + + #include "dawn/native/ChainUtils.h" +@@ -39,8 +41,13 @@ ResultOrError> SharedFence::Create( + const SharedFenceMTLSharedEventDescriptor* descriptor) { DAWN_INVALID_IF(descriptor->sharedEvent == nullptr, "MTLSharedEvent is missing."); if (@available(macOS 10.14, iOS 12.0, *)) { - return AcquireRef(new SharedFence( +- return AcquireRef(new SharedFence( - device, label, static_cast>(descriptor->sharedEvent))); -+ device, label, -+#if TARGET_OS_IOS -+ (__bridge id)(descriptor->sharedEvent) -+#else -+ static_cast>(descriptor->sharedEvent) -+#endif -+ )); ++ return AcquireRef(new SharedFence(device, label, ++#if DAWN_PLATFORM_IS(IOS) ++ (__bridge id)(descriptor->sharedEvent) ++#else // DAWN_PLATFORM_IS(IOS) ++ static_cast>(descriptor->sharedEvent) ++#endif // DAWN_PLATFORM_IS(IOS) ++ )); } else { return DAWN_INTERNAL_ERROR("MTLSharedEvent not supported."); } diff --git a/js/node/CMakeLists.txt b/js/node/CMakeLists.txt index 1ce6d66881c3e..5d83790dc2737 100644 --- a/js/node/CMakeLists.txt +++ b/js/node/CMakeLists.txt @@ -34,6 +34,7 @@ include_directories(${CMAKE_SOURCE_DIR}/node_modules/node-addon-api) # optional providers option(USE_DML "Build with DirectML support" OFF) +option(USE_WEBGPU "Build with WebGPU support" OFF) option(USE_CUDA "Build with CUDA support" OFF) option(USE_TENSORRT "Build with TensorRT support" OFF) option(USE_COREML "Build with CoreML support" OFF) @@ -42,6 +43,9 @@ option(USE_QNN "Build with QNN support" OFF) if(USE_DML) add_compile_definitions(USE_DML=1) endif() +if(USE_WEBGPU) + add_compile_definitions(USE_WEBGPU=1) +endif() if(USE_CUDA) add_compile_definitions(USE_CUDA=1) endif() diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index 46f8b83b0c5c2..004a3c890a7e4 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -3,12 +3,14 @@ import { Backend, InferenceSession, InferenceSessionHandler, SessionHandler } from 'onnxruntime-common'; -import { Binding, binding } from './binding'; +import { Binding, binding, initOrt } from './binding'; class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; constructor(pathOrBuffer: string | Uint8Array, options: InferenceSession.SessionOptions) { + initOrt(); + this.#inferenceSession = new binding.InferenceSession(); if (typeof pathOrBuffer === 'string') { this.#inferenceSession.loadModel(pathOrBuffer, options); @@ -27,10 +29,12 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { readonly outputNames: string[]; startProfiling(): void { - // TODO: implement profiling + // startProfiling is a no-op. + // + // if sessionOptions.enableProfiling is true, profiling will be enabled when the model is loaded. } endProfiling(): void { - // TODO: implement profiling + this.#inferenceSession.endProfiling(); } async run( diff --git a/js/node/lib/binding.ts b/js/node/lib/binding.ts index d6d592a1665b3..56203f5a5ca02 100644 --- a/js/node/lib/binding.ts +++ b/js/node/lib/binding.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { InferenceSession, OnnxValue } from 'onnxruntime-common'; +import { InferenceSession, OnnxValue, Tensor, TensorConstructor, env } from 'onnxruntime-common'; type SessionOptions = InferenceSession.SessionOptions; type FeedsType = { @@ -28,6 +28,8 @@ export declare namespace Binding { run(feeds: FeedsType, fetches: FetchesType, options: RunOptions): ReturnType; + endProfiling(): void; + dispose(): void; } @@ -48,4 +50,35 @@ export const binding = // eslint-disable-next-line @typescript-eslint/naming-convention InferenceSession: Binding.InferenceSessionConstructor; listSupportedBackends: () => Binding.SupportedBackend[]; + initOrtOnce: (logLevel: number, tensorConstructor: TensorConstructor) => void; }; + +let ortInitialized = false; +export const initOrt = (): void => { + if (!ortInitialized) { + ortInitialized = true; + let logLevel = 2; + if (env.logLevel) { + switch (env.logLevel) { + case 'verbose': + logLevel = 0; + break; + case 'info': + logLevel = 1; + break; + case 'warning': + logLevel = 2; + break; + case 'error': + logLevel = 3; + break; + case 'fatal': + logLevel = 4; + break; + default: + throw new Error(`Unsupported log level: ${env.logLevel}`); + } + } + binding.initOrtOnce(logLevel, Tensor); + } +}; diff --git a/js/node/script/build.ts b/js/node/script/build.ts index 133d1a0d981a0..dcdcb93377b4c 100644 --- a/js/node/script/build.ts +++ b/js/node/script/build.ts @@ -29,6 +29,8 @@ const ONNXRUNTIME_GENERATOR = buildArgs['onnxruntime-generator']; const REBUILD = !!buildArgs.rebuild; // --use_dml const USE_DML = !!buildArgs.use_dml; +// --use_webgpu +const USE_WEBGPU = !!buildArgs.use_webgpu; // --use_cuda const USE_CUDA = !!buildArgs.use_cuda; // --use_tensorrt @@ -65,6 +67,9 @@ if (ONNXRUNTIME_GENERATOR && typeof ONNXRUNTIME_GENERATOR === 'string') { if (USE_DML) { args.push('--CDUSE_DML=ON'); } +if (USE_WEBGPU) { + args.push('--CDUSE_WEBGPU=ON'); +} if (USE_CUDA) { args.push('--CDUSE_CUDA=ON'); } diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index 057066507621b..23d859351f426 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -11,7 +11,12 @@ #include "tensor_helper.h" #include -Napi::FunctionReference InferenceSessionWrap::constructor; +Napi::FunctionReference InferenceSessionWrap::wrappedSessionConstructor; +Napi::FunctionReference InferenceSessionWrap::ortTensorConstructor; + +Napi::FunctionReference& InferenceSessionWrap::GetTensorConstructor() { + return InferenceSessionWrap::ortTensorConstructor; +} Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { #if defined(USE_DML) && defined(_WIN32) @@ -23,28 +28,51 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { Ort::Global::api_ == nullptr, env, "Failed to initialize ONNX Runtime API. It could happen when this nodejs binding was built with a higher version " "ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library)."); - auto ortEnv = new Ort::Env{ORT_LOGGING_LEVEL_WARNING, "onnxruntime-node"}; - env.SetInstanceData(ortEnv); + // initialize binding Napi::HandleScope scope(env); Napi::Function func = DefineClass( env, "InferenceSession", - {InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), InstanceMethod("run", &InferenceSessionWrap::Run), + {InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), + InstanceMethod("run", &InferenceSessionWrap::Run), InstanceMethod("dispose", &InferenceSessionWrap::Dispose), + InstanceMethod("endProfiling", &InferenceSessionWrap::EndProfiling), InstanceAccessor("inputNames", &InferenceSessionWrap::GetInputNames, nullptr, napi_default, nullptr), InstanceAccessor("outputNames", &InferenceSessionWrap::GetOutputNames, nullptr, napi_default, nullptr)}); - constructor = Napi::Persistent(func); - constructor.SuppressDestruct(); + wrappedSessionConstructor = Napi::Persistent(func); + wrappedSessionConstructor.SuppressDestruct(); exports.Set("InferenceSession", func); Napi::Function listSupportedBackends = Napi::Function::New(env, InferenceSessionWrap::ListSupportedBackends); exports.Set("listSupportedBackends", listSupportedBackends); + Napi::Function initOrtOnce = Napi::Function::New(env, InferenceSessionWrap::InitOrtOnce); + exports.Set("initOrtOnce", initOrtOnce); + return exports; } +Napi::Value InferenceSessionWrap::InitOrtOnce(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + Napi::HandleScope scope(env); + + int log_level = info[0].As().Int32Value(); + + Ort::Env* ortEnv = env.GetInstanceData(); + if (ortEnv == nullptr) { + ortEnv = new Ort::Env{OrtLoggingLevel(log_level), "onnxruntime-node"}; + env.SetInstanceData(ortEnv); + } + + Napi::Function tensorConstructor = info[1].As(); + ortTensorConstructor = Napi::Persistent(tensorConstructor); + ortTensorConstructor.SuppressDestruct(); + + return env.Undefined(); +} + InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info), initialized_(false), disposed_(false), session_(nullptr), defaultRunOptions_(nullptr) {} @@ -118,6 +146,12 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo& info) { ? typeInfo.GetTensorTypeAndShapeInfo().GetElementType() : ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); } + + // cache preferred output locations + ParsePreferredOutputLocations(info[argsLength - 1].As(), outputNames_, preferredOutputLocations_); + if (preferredOutputLocations_.size() > 0) { + ioBinding_ = std::make_unique(*session_); + } } catch (Napi::Error const& e) { throw e; } catch (std::exception const& e) { @@ -167,7 +201,8 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { std::vector reuseOutput; size_t inputIndex = 0; size_t outputIndex = 0; - OrtMemoryInfo* memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault).release(); + Ort::MemoryInfo cpuMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + Ort::MemoryInfo gpuBufferMemoryInfo{"WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault}; try { for (auto& name : inputNames_) { @@ -175,7 +210,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { inputIndex++; inputNames_cstr.push_back(name.c_str()); auto value = feed.Get(name); - inputValues.push_back(NapiValueToOrtValue(env, value, memory_info)); + inputValues.push_back(NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo)); } } for (auto& name : outputNames_) { @@ -184,7 +219,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { outputNames_cstr.push_back(name.c_str()); auto value = fetch.Get(name); reuseOutput.push_back(!value.IsNull()); - outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, memory_info)); + outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo)); } } @@ -193,19 +228,47 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { runOptions = Ort::RunOptions{}; ParseRunOptions(info[2].As(), runOptions); } + if (preferredOutputLocations_.size() == 0) { + session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, + inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0], + inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0], + outputIndex == 0 ? nullptr : &outputValues[0], outputIndex); - session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, - inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0], - inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0], - outputIndex == 0 ? nullptr : &outputValues[0], outputIndex); + Napi::Object result = Napi::Object::New(env); - Napi::Object result = Napi::Object::New(env); + for (size_t i = 0; i < outputIndex; i++) { + result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputValues[i]))); + } + return scope.Escape(result); + } else { + // IO binding + ORT_NAPI_THROW_ERROR_IF(preferredOutputLocations_.size() != outputNames_.size(), env, + "Preferred output locations must have the same size as output names."); - for (size_t i = 0; i < outputIndex; i++) { - result.Set(outputNames_[i], OrtValueToNapiValue(env, outputValues[i])); - } + for (size_t i = 0; i < inputIndex; i++) { + ioBinding_->BindInput(inputNames_cstr[i], inputValues[i]); + } + for (size_t i = 0; i < outputIndex; i++) { + // TODO: support preallocated output tensor (outputValues[i]) + + if (preferredOutputLocations_[i] == DATA_LOCATION_GPU_BUFFER) { + ioBinding_->BindOutput(outputNames_cstr[i], gpuBufferMemoryInfo); + } else { + ioBinding_->BindOutput(outputNames_cstr[i], cpuMemoryInfo); + } + } + + session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, *ioBinding_); + + auto outputs = ioBinding_->GetOutputValues(); + ORT_NAPI_THROW_ERROR_IF(outputs.size() != outputIndex, env, "Output count mismatch."); - return scope.Escape(result); + Napi::Object result = Napi::Object::New(env); + for (size_t i = 0; i < outputIndex; i++) { + result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputs[i]))); + } + return scope.Escape(result); + } } catch (Napi::Error const& e) { throw e; } catch (std::exception const& e) { @@ -218,6 +281,8 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) { ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); + this->ioBinding_.reset(nullptr); + this->defaultRunOptions_.reset(nullptr); this->session_.reset(nullptr); @@ -225,6 +290,20 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) { return env.Undefined(); } +Napi::Value InferenceSessionWrap::EndProfiling(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); + + Napi::EscapableHandleScope scope(env); + + Ort::AllocatorWithDefaultOptions allocator; + + auto filename = session_->EndProfilingAllocated(allocator); + Napi::String filenameValue = Napi::String::From(env, filename.get()); + return scope.Escape(filenameValue); +} + Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); Napi::EscapableHandleScope scope(env); @@ -242,6 +321,9 @@ Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo #ifdef USE_DML result.Set(result.Length(), createObject("dml", true)); #endif +#ifdef USE_WEBGPU + result.Set(result.Length(), createObject("webgpu", true)); +#endif #ifdef USE_CUDA result.Set(result.Length(), createObject("cuda", false)); #endif diff --git a/js/node/src/inference_session_wrap.h b/js/node/src/inference_session_wrap.h index effdd83e3aa02..0b3dd1178c807 100644 --- a/js/node/src/inference_session_wrap.h +++ b/js/node/src/inference_session_wrap.h @@ -12,9 +12,22 @@ class InferenceSessionWrap : public Napi::ObjectWrap { public: static Napi::Object Init(Napi::Env env, Napi::Object exports); + static Napi::FunctionReference& GetTensorConstructor(); + InferenceSessionWrap(const Napi::CallbackInfo& info); private: + /** + * [sync] initialize ONNX Runtime once. + * + * This function must be called before any other functions. + * + * @param arg0 a number specifying the log level. + * + * @returns undefined + */ + static Napi::Value InitOrtOnce(const Napi::CallbackInfo& info); + /** * [sync] list supported backend list * @returns array with objects { "name": "cpu", requirementsInstalled: true } @@ -63,10 +76,19 @@ class InferenceSessionWrap : public Napi::ObjectWrap { */ Napi::Value Dispose(const Napi::CallbackInfo& info); + /** + * [sync] end the profiling. + * @param nothing + * @returns nothing + * @throw nothing + */ + Napi::Value EndProfiling(const Napi::CallbackInfo& info); + // private members // persistent constructor - static Napi::FunctionReference constructor; + static Napi::FunctionReference wrappedSessionConstructor; + static Napi::FunctionReference ortTensorConstructor; // session objects bool initialized_; @@ -81,4 +103,8 @@ class InferenceSessionWrap : public Napi::ObjectWrap { std::vector outputNames_; std::vector outputTypes_; std::vector outputTensorElementDataTypes_; + + // preferred output locations + std::vector preferredOutputLocations_; + std::unique_ptr ioBinding_; }; diff --git a/js/node/src/session_options_helper.cc b/js/node/src/session_options_helper.cc index 0ed1ba08e6bf7..2e9614dc8e25e 100644 --- a/js/node/src/session_options_helper.cc +++ b/js/node/src/session_options_helper.cc @@ -9,12 +9,16 @@ #include "common.h" #include "session_options_helper.h" +#include "tensor_helper.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" #endif #ifdef USE_DML #include "core/providers/dml/dml_provider_factory.h" #endif +#ifdef USE_WEBGPU +#include "core/providers/webgpu/webgpu_provider_factory.h" +#endif #ifdef USE_TENSORRT #include "core/providers/tensorrt/tensorrt_provider_options.h" #endif @@ -36,7 +40,12 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess Napi::Value epValue = epList[i]; std::string name; int deviceId = 0; +#ifdef USE_COREML int coreMlFlags = 0; +#endif +#ifdef USE_WEBGPU + std::unordered_map webgpu_options; +#endif if (epValue.IsString()) { name = epValue.As().Utf8Value(); } else if (!epValue.IsObject() || epValue.IsNull() || !epValue.As().Has("name") || @@ -49,9 +58,23 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess if (obj.Has("deviceId")) { deviceId = obj.Get("deviceId").As(); } +#ifdef USE_COREML if (obj.Has("coreMlFlags")) { coreMlFlags = obj.Get("coreMlFlags").As(); } +#endif +#ifdef USE_WEBGPU + for (const auto& nameIter : obj.GetPropertyNames()) { + Napi::Value nameVar = nameIter.second; + std::string name = nameVar.As().Utf8Value(); + if (name != "name") { + Napi::Value valueVar = obj.Get(nameVar); + ORT_NAPI_THROW_TYPEERROR_IF(!valueVar.IsString(), epList.Env(), "Invalid argument: sessionOptions.executionProviders must be a string or an object with property 'name'."); + std::string value = valueVar.As().Utf8Value(); + webgpu_options[name] = value; + } + } +#endif } // CPU execution provider @@ -77,6 +100,10 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess } else if (name == "dml") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(sessionOptions, deviceId)); #endif +#ifdef USE_WEBGPU + } else if (name == "webgpu") { + sessionOptions.AppendExecutionProvider("WebGPU", webgpu_options); +#endif #ifdef USE_COREML } else if (name == "coreml") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions, coreMlFlags)); @@ -95,6 +122,22 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess } } +void IterateExtraOptions(const std::string& prefix, const Napi::Object& obj, Ort::SessionOptions& sessionOptions) { + for (const auto& kvp : obj) { + auto key = kvp.first.As().Utf8Value(); + Napi::Value value = kvp.second; + if (value.IsObject()) { + IterateExtraOptions(prefix + key + ".", value.As(), sessionOptions); + } else { + ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString(), obj.Env(), + "Invalid argument: sessionOptions.extra value must be a string in Node.js binding."); + std::string entry = prefix + key; + auto val = value.As().Utf8Value(); + sessionOptions.AddConfigEntry(entry.c_str(), val.c_str()); + } + } +} + void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions) { // Execution provider if (options.Has("executionProviders")) { @@ -162,6 +205,28 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio } } + // optimizedModelFilePath + if (options.Has("optimizedModelFilePath")) { + auto optimizedModelFilePathValue = options.Get("optimizedModelFilePath"); + ORT_NAPI_THROW_TYPEERROR_IF(!optimizedModelFilePathValue.IsString(), options.Env(), + "Invalid argument: sessionOptions.optimizedModelFilePath must be a string."); +#ifdef _WIN32 + auto str = optimizedModelFilePathValue.As().Utf16Value(); + std::basic_string optimizedModelFilePath = std::wstring{str.begin(), str.end()}; +#else + std::basic_string optimizedModelFilePath = optimizedModelFilePathValue.As().Utf8Value(); +#endif + sessionOptions.SetOptimizedModelFilePath(optimizedModelFilePath.c_str()); + } + + // extra + if (options.Has("extra")) { + auto extraValue = options.Get("extra"); + ORT_NAPI_THROW_TYPEERROR_IF(!extraValue.IsObject(), options.Env(), + "Invalid argument: sessionOptions.extra must be an object."); + IterateExtraOptions("", extraValue.As(), sessionOptions); + } + // execution mode if (options.Has("executionMode")) { auto executionModeValue = options.Get("executionMode"); @@ -195,4 +260,118 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio sessionOptions.SetLogSeverityLevel(static_cast(logLevelNumber)); } + + // Profiling + if (options.Has("enableProfiling")) { + auto enableProfilingValue = options.Get("enableProfiling"); + ORT_NAPI_THROW_TYPEERROR_IF(!enableProfilingValue.IsBoolean(), options.Env(), + "Invalid argument: sessionOptions.enableProfiling must be a boolean value."); + + if (enableProfilingValue.As().Value()) { + ORT_NAPI_THROW_TYPEERROR_IF(!options.Has("profileFilePrefix"), options.Env(), + "Invalid argument: sessionOptions.profileFilePrefix is required" + " when sessionOptions.enableProfiling is set to true."); + auto profileFilePrefixValue = options.Get("profileFilePrefix"); + ORT_NAPI_THROW_TYPEERROR_IF(!profileFilePrefixValue.IsString(), options.Env(), + "Invalid argument: sessionOptions.profileFilePrefix must be a string." + " when sessionOptions.enableProfiling is set to true."); +#ifdef _WIN32 + auto str = profileFilePrefixValue.As().Utf16Value(); + std::basic_string profileFilePrefix = std::wstring{str.begin(), str.end()}; +#else + std::basic_string profileFilePrefix = profileFilePrefixValue.As().Utf8Value(); +#endif + sessionOptions.EnableProfiling(profileFilePrefix.c_str()); + } else { + sessionOptions.DisableProfiling(); + } + } + + // external data + if (options.Has("externalData")) { + auto externalDataValue = options.Get("externalData"); + ORT_NAPI_THROW_TYPEERROR_IF(!externalDataValue.IsArray(), options.Env(), + "Invalid argument: sessionOptions.externalData must be an array."); + auto externalData = externalDataValue.As(); + std::vector> paths; + std::vector buffs; + std::vector sizes; + + for (const auto& kvp : externalData) { + Napi::Value value = kvp.second; + ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), options.Env(), + "Invalid argument: sessionOptions.externalData value must be an object in Node.js binding."); + Napi::Object obj = value.As(); + ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("path") || !obj.Get("path").IsString(), options.Env(), + "Invalid argument: sessionOptions.externalData value must have a 'path' property of type string in Node.js binding."); +#ifdef _WIN32 + auto path = obj.Get("path").As().Utf16Value(); + paths.push_back(std::wstring{path.begin(), path.end()}); +#else + auto path = obj.Get("path").As().Utf8Value(); + paths.push_back(path); +#endif + ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("data") || + !obj.Get("data").IsBuffer() || + !(obj.Get("data").IsTypedArray() && obj.Get("data").As().TypedArrayType() == napi_uint8_array), + options.Env(), + "Invalid argument: sessionOptions.externalData value must have an 'data' property of type buffer or typed array in Node.js binding."); + + auto data = obj.Get("data"); + if (data.IsBuffer()) { + buffs.push_back(data.As>().Data()); + sizes.push_back(data.As>().Length()); + } else { + auto typedArray = data.As(); + buffs.push_back(reinterpret_cast(typedArray.ArrayBuffer().Data()) + typedArray.ByteOffset()); + sizes.push_back(typedArray.ByteLength()); + } + } + sessionOptions.AddExternalInitializersFromFilesInMemory(paths, buffs, sizes); + } +} + +void ParsePreferredOutputLocations(const Napi::Object options, const std::vector& outputNames, std::vector& preferredOutputLocations) { + if (options.Has("preferredOutputLocation")) { + auto polValue = options.Get("preferredOutputLocation"); + if (polValue.IsNull() || polValue.IsUndefined()) { + return; + } + if (polValue.IsString()) { + DataLocation location = ParseDataLocation(polValue.As().Utf8Value()); + ORT_NAPI_THROW_TYPEERROR_IF(location == DATA_LOCATION_NONE, options.Env(), + "Invalid argument: preferredOutputLocation must be an array or a valid string."); + + if (location == DATA_LOCATION_GPU_BUFFER || location == DATA_LOCATION_ML_TENSOR) { + preferredOutputLocations.resize(outputNames.size(), location); + } + } else if (polValue.IsObject()) { + preferredOutputLocations.resize(outputNames.size(), DATA_LOCATION_CPU); + + auto pol = polValue.As(); + for (const auto& nameIter : pol.GetPropertyNames()) { + Napi::Value nameVar = nameIter.second; + std::string name = nameVar.As().Utf8Value(); + // find the name in outputNames + auto it = std::find(outputNames.begin(), outputNames.end(), name); + ORT_NAPI_THROW_TYPEERROR_IF(it == outputNames.end(), options.Env(), + "Invalid argument: \"", name, "\" is not a valid output name."); + + Napi::Value value = pol.Get(nameVar); + DataLocation location = DATA_LOCATION_NONE; + ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString() || (location = ParseDataLocation(value.As().Utf8Value())) == DATA_LOCATION_NONE, + options.Env(), + "Invalid argument: preferredOutputLocation[\"", name, "\"] must be a valid string."); + + size_t index = it - outputNames.begin(); + preferredOutputLocations[index] = location; + } + + if (std::all_of(preferredOutputLocations.begin(), preferredOutputLocations.end(), [](int loc) { return loc == DATA_LOCATION_CPU; })) { + preferredOutputLocations.clear(); + } + } else { + ORT_NAPI_THROW_TYPEERROR(options.Env(), "Invalid argument: preferredOutputLocation must be an array or a valid string."); + } + } } diff --git a/js/node/src/session_options_helper.h b/js/node/src/session_options_helper.h index c0a9ae0d683e9..c6338f28e03c6 100644 --- a/js/node/src/session_options_helper.h +++ b/js/node/src/session_options_helper.h @@ -11,3 +11,6 @@ struct SessionOptions; // parse a Javascript session options object and fill the native SessionOptions object. void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions); + +// parse a Javascript session options object and prepare the preferred output locations. +void ParsePreferredOutputLocations(const Napi::Object options, const std::vector& outputNames, std::vector& preferredOutputLocations); \ No newline at end of file diff --git a/js/node/src/tensor_helper.cc b/js/node/src/tensor_helper.cc index 54f1c5a09906e..637b977c7fee6 100644 --- a/js/node/src/tensor_helper.cc +++ b/js/node/src/tensor_helper.cc @@ -8,6 +8,7 @@ #include "common.h" #include "tensor_helper.h" +#include "inference_session_wrap.h" // make sure consistent with origin definition static_assert(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == 0, "definition not consistent with OnnxRuntime"); @@ -100,7 +101,7 @@ const std::unordered_map DATA_TYPE_NAME_ {"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, {"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, {"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, {"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16}, {"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, {"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, {"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, {"string", ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING}, {"bool", ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL}, {"float16", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16}, {"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, {"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32}, {"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64}}; // currently only support tensor -Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* memory_info) { +Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* cpu_memory_info, OrtMemoryInfo* webgpu_memory_info) { ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), env, "Tensor must be an object."); // check 'dims' @@ -110,6 +111,7 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* auto dimsArray = dimsValue.As(); auto len = dimsArray.Length(); + size_t elementSize = 1; std::vector dims; if (len > 0) { dims.reserve(len); @@ -122,17 +124,26 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* "Tensor.dims[", i, "] is invalid: ", dimDouble); int64_t dim = static_cast(dimDouble); dims.push_back(dim); + elementSize *= dim; } } + // check 'location' + auto tensorLocationValue = tensorObject.Get("location"); + ORT_NAPI_THROW_TYPEERROR_IF(!tensorLocationValue.IsString(), env, "Tensor.location must be a string."); + DataLocation tensorLocation = ParseDataLocation(tensorLocationValue.As().Utf8Value()); + ORT_NAPI_THROW_RANGEERROR_IF(tensorLocation == DATA_LOCATION_NONE, env, "Tensor.location is not supported."); + // check 'data' and 'type' - auto tensorDataValue = tensorObject.Get("data"); auto tensorTypeValue = tensorObject.Get("type"); ORT_NAPI_THROW_TYPEERROR_IF(!tensorTypeValue.IsString(), env, "Tensor.type must be a string."); auto tensorTypeString = tensorTypeValue.As().Utf8Value(); if (tensorTypeString == "string") { + auto tensorDataValue = tensorObject.Get("data"); + + ORT_NAPI_THROW_TYPEERROR_IF(tensorLocation != DATA_LOCATION_CPU, env, "Tensor.location must be 'cpu' for string tensors."); ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsArray(), env, "Tensor.data must be an array for string tensors."); auto tensorDataArray = tensorDataValue.As(); @@ -162,29 +173,42 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* auto v = DATA_TYPE_NAME_TO_ID_MAP.find(tensorTypeString); ORT_NAPI_THROW_TYPEERROR_IF(v == DATA_TYPE_NAME_TO_ID_MAP.end(), env, "Tensor.type is not supported: ", tensorTypeString); - ONNXTensorElementDataType elemType = v->second; - ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsTypedArray(), env, - "Tensor.data must be a typed array for numeric tensor."); + if (tensorLocation == DATA_LOCATION_CPU) { + auto tensorDataValue = tensorObject.Get("data"); + ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsTypedArray(), env, + "Tensor.data must be a typed array for numeric tensor."); + + auto tensorDataTypedArray = tensorDataValue.As(); + auto typedArrayType = tensorDataValue.As().TypedArrayType(); + ORT_NAPI_THROW_TYPEERROR_IF(DATA_TYPE_TYPEDARRAY_MAP[elemType] != typedArrayType, env, + "Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ", + tensorTypeString, " tensors, but got typed array (", typedArrayType, ")."); - auto tensorDataTypedArray = tensorDataValue.As(); - auto typedArrayType = tensorDataValue.As().TypedArrayType(); - ORT_NAPI_THROW_TYPEERROR_IF(DATA_TYPE_TYPEDARRAY_MAP[elemType] != typedArrayType, env, - "Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ", - tensorTypeString, " tensors, but got typed array (", typedArrayType, ")."); + char* buffer = reinterpret_cast(tensorDataTypedArray.ArrayBuffer().Data()); + size_t bufferByteOffset = tensorDataTypedArray.ByteOffset(); + size_t bufferByteLength = tensorDataTypedArray.ByteLength(); + return Ort::Value::CreateTensor(cpu_memory_info, buffer + bufferByteOffset, bufferByteLength, + dims.empty() ? nullptr : &dims[0], dims.size(), elemType); + } else { + ORT_NAPI_THROW_TYPEERROR_IF(tensorLocation != DATA_LOCATION_GPU_BUFFER, env, "Tensor.location must be 'gpu-buffer' for IO binding."); - char* buffer = reinterpret_cast(tensorDataTypedArray.ArrayBuffer().Data()); - size_t bufferByteOffset = tensorDataTypedArray.ByteOffset(); - size_t bufferByteLength = tensorDataTypedArray.ByteLength(); - return Ort::Value::CreateTensor(memory_info, buffer + bufferByteOffset, bufferByteLength, - dims.empty() ? nullptr : &dims[0], dims.size(), elemType); + auto gpuBufferValue = tensorObject.Get("gpuBuffer"); + // nodejs: tensor.gpuBuffer is no longer a GPUBuffer in nodejs. we assume it is an external object (bind the OrtValue pointer). + ORT_NAPI_THROW_TYPEERROR_IF(!gpuBufferValue.IsExternal(), env, "Tensor.gpuBuffer must be an external object."); + Ort::Value dataValue(gpuBufferValue.As>().Data()); + void* gpuBuffer = dataValue.GetTensorMutableRawData(); + dataValue.release(); + + size_t dataByteLength = DATA_TYPE_ELEMENT_SIZE_MAP[elemType] * elementSize; + return Ort::Value::CreateTensor(webgpu_memory_info, gpuBuffer, dataByteLength, dims.empty() ? nullptr : &dims[0], dims.size(), elemType); + } } } -Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) { +Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value&& value) { Napi::EscapableHandleScope scope(env); - auto returnValue = Napi::Object::New(env); auto typeInfo = value.GetTypeInfo(); auto onnxType = typeInfo.GetONNXType(); @@ -197,24 +221,26 @@ Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) { // type auto typeCstr = DATA_TYPE_ID_TO_NAME_MAP[elemType]; ORT_NAPI_THROW_ERROR_IF(typeCstr == nullptr, env, "Tensor type (", elemType, ") is not supported."); - - returnValue.Set("type", Napi::String::New(env, typeCstr)); + auto type = Napi::String::New(env, typeCstr); // dims const size_t dimsCount = tensorTypeAndShapeInfo.GetDimensionsCount(); - std::vector dims; + std::vector dimsVector; if (dimsCount > 0) { - dims = tensorTypeAndShapeInfo.GetShape(); + dimsVector = tensorTypeAndShapeInfo.GetShape(); } - auto dimsArray = Napi::Array::New(env, dimsCount); + auto dims = Napi::Array::New(env, dimsCount); for (uint32_t i = 0; i < dimsCount; i++) { - dimsArray[i] = dims[i]; + dims[i] = dimsVector[i]; } - returnValue.Set("dims", dimsArray); + + // location + auto memoryInfo = value.GetTensorMemoryInfo(); + bool isGpuBuffer = memoryInfo.GetDeviceType() == OrtMemoryInfoDeviceType_GPU && + memoryInfo.GetAllocatorName() == "WebGPU_Buffer"; // size auto size = tensorTypeAndShapeInfo.GetElementCount(); - returnValue.Set("size", Napi::Number::From(env, size)); // data if (elemType == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { @@ -234,20 +260,48 @@ Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) { i == size - 1 ? tempBufferLength - tempOffsets[i] : tempOffsets[i + 1] - tempOffsets[i]); } } - returnValue.Set("data", Napi::Value(env, stringArray)); + + // new Tensor(stringArray /* string[] */, dims /* number[] */) + return scope.Escape(InferenceSessionWrap::GetTensorConstructor().New({stringArray, dims})); } else { // number data - // TODO: optimize memory - auto arrayBuffer = Napi::ArrayBuffer::New(env, size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]); - if (size > 0) { - memcpy(arrayBuffer.Data(), value.GetTensorRawData(), size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]); + if (isGpuBuffer) { + // Tensor.fromGpuBuffer(buffer, options) + Napi::Function tensorFromGpuBuffer = InferenceSessionWrap::GetTensorConstructor().Value().Get("fromGpuBuffer").As(); + OrtValue* underlyingOrtValue = value.release(); + + auto options = Napi::Object::New(env); + options.Set("dataType", type); + options.Set("dims", dims); + options.Set("dispose", Napi::Function::New( + env, [](const Napi::CallbackInfo& info) { + Ort::GetApi().ReleaseValue(reinterpret_cast(info.Data())); + return info.Env().Undefined(); + }, + "dispose", underlyingOrtValue)); + options.Set("download", Napi::Function::New( + env, [](const Napi::CallbackInfo& info) { + NAPI_THROW("not implemented"); + }, + "download", underlyingOrtValue)); + + return scope.Escape(tensorFromGpuBuffer.Call({Napi::External::New(env, underlyingOrtValue), options})); + } else { + // TODO: optimize memory + auto arrayBuffer = Napi::ArrayBuffer::New(env, size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]); + if (size > 0) { + memcpy(arrayBuffer.Data(), value.GetTensorRawData(), size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]); + } + napi_value typedArrayData; + napi_status status = + napi_create_typedarray(env, DATA_TYPE_TYPEDARRAY_MAP[elemType], size, arrayBuffer, 0, &typedArrayData); + NAPI_THROW_IF_FAILED(env, status, Napi::Value); + + // new Tensor(type, typedArrayData, dims) + return scope.Escape(InferenceSessionWrap::GetTensorConstructor().New( + {type, + Napi::Value(env, typedArrayData), + dims})); } - napi_value typedArrayData; - napi_status status = - napi_create_typedarray(env, DATA_TYPE_TYPEDARRAY_MAP[elemType], size, arrayBuffer, 0, &typedArrayData); - NAPI_THROW_IF_FAILED(env, status, Napi::Value); - returnValue.Set("data", Napi::Value(env, typedArrayData)); } - - return scope.Escape(returnValue); } diff --git a/js/node/src/tensor_helper.h b/js/node/src/tensor_helper.h index 56b399ccc24ee..4a51e5240602a 100644 --- a/js/node/src/tensor_helper.h +++ b/js/node/src/tensor_helper.h @@ -9,7 +9,32 @@ #include "onnxruntime_cxx_api.h" // convert a Javascript OnnxValue object to an OrtValue object -Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* memory_info); +Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* cpu_memory_info, OrtMemoryInfo* webgpu_memory_info); // convert an OrtValue object to a Javascript OnnxValue object -Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value); +Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value&& value); + +enum DataLocation { + DATA_LOCATION_NONE = 0, + DATA_LOCATION_CPU = 1, + DATA_LOCATION_CPU_PINNED = 2, + DATA_LOCATION_TEXTURE = 3, + DATA_LOCATION_GPU_BUFFER = 4, + DATA_LOCATION_ML_TENSOR = 5 +}; + +inline DataLocation ParseDataLocation(const std::string& location) { + if (location == "cpu") { + return DATA_LOCATION_CPU; + } else if (location == "cpu-pinned") { + return DATA_LOCATION_CPU_PINNED; + } else if (location == "texture") { + return DATA_LOCATION_TEXTURE; + } else if (location == "gpu-buffer") { + return DATA_LOCATION_GPU_BUFFER; + } else if (location == "ml-tensor") { + return DATA_LOCATION_ML_TENSOR; + } else { + return DATA_LOCATION_NONE; + } +} diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc new file mode 100644 index 0000000000000..d1e5f53d7f637 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/math/unary_elementwise_ops.h" +#include "contrib_ops/webgpu/bert/fast_gelu.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + FastGelu, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + FastGelu); + +Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& y = shader.AddOutput("y", ShaderUsage::UseUniform); + + shader.AdditionalImplementation() << TanhImpl; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " var a = " << x.GetByOffset("global_idx") << ";\n"; + if (Inputs().size() > 1) { + const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride); + if (bias_components_ == 1) { + shader.MainFunctionBody() << " let bias_offset = global_idx * 4;\n" + " a += x_value_t(" + << bias.GetByOffset("bias_offset % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") << ");\n"; + } else { + shader.MainFunctionBody() << " a += " << bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; + } + } + shader.MainFunctionBody() << y.SetByOffset("global_idx", onnxruntime::webgpu::FastGeluExpr); + + return Status::OK(); +} + +Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto* bias = context.Input(1); + auto* output = context.Output(0, input->Shape()); + + uint32_t data_size = SafeInt(output->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + const auto vec_size = (data_size + 3) / 4; + uint32_t bias_size = 0; + int bias_components = 1; + + if (bias != nullptr) { + bias_size = SafeInt(bias->Shape().Size()); + if (bias_size % 4 == 0) { + bias_components = 4; + bias_size = bias_size / 4; + } + } + + FastGeluProgram program{bias_components}; + program.AddInput({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddOutput({output, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariable({vec_size}); + + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}); + } + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h new file mode 100644 index 0000000000000..fa40d52bf301f --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class FastGeluProgram final : public Program { + public: + FastGeluProgram(int bias_components) : Program{"FastGelu"}, bias_components_{bias_components} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + int bias_components_; +}; + +class FastGelu final : public WebGpuKernel { + public: + FastGelu(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc new file mode 100644 index 0000000000000..8997e8698d96d --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc @@ -0,0 +1,36 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/layer_norm.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + LayerNormalization, + kOnnxDomain, + 1, + 16, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + onnxruntime::webgpu::LayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + onnxruntime::webgpu::LayerNorm); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc new file mode 100644 index 0000000000000..d836c1ddf8675 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -0,0 +1,493 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/multihead_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/providers/webgpu/webgpu_supported_types.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::multihead_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + MultiHeadAttention, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + MultiHeadAttention); + +Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("qkv_input", ShaderUsage::UseUniform); + const auto& qkv_output = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices); + + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << "let output_indices = " << qkv_output.OffsetToIndices("global_idx") << ";\n" + << "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *" + << " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n"; + if (has_bias_) { + shader.MainFunctionBody() << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n"; + } + shader.MainFunctionBody() << "qkv_output[global_idx] = qkv_input[input_offset_idx]"; + if (has_bias_) { + shader.MainFunctionBody() << " + bias[bias_offset_idx];\n"; + } else { + shader.MainFunctionBody() << ";\n"; + } + + return Status::OK(); +} + +Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, + int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) { + assert(input_tensor->Shape().GetDims().size() == 3); + assert(output_tensor->Shape().GetDims().size() == 4); + + uint32_t data_size = SafeInt(output_tensor->Shape().Size()); + const int batch_offset = num_heads * sequence_length * head_size; + const int sequence_offset = num_heads * head_size; + const int head_offset = head_size; + bool has_bias = bias != nullptr; + + TransferBSDToBNSHProgram program{has_bias}; + program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{data_size}, + {static_cast(batch_offset)}, + {static_cast(sequence_offset)}, + {static_cast(head_offset)}, + {static_cast(bias_offset)}}); + + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + + return context.RunProgram(program); +}; + +Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (feed_past_key_) { + shader.AddInput("past_key", ShaderUsage::UseUniform); + } + if (has_attention_bias_) { + shader.AddInput("attention_bias", ShaderUsage::UseUniform); + } + + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (has_present_key_) { + shader.AddOutput("present_key", ShaderUsage::UseUniform); + } + + shader.AdditionalImplementation() << "var tileQ: array;\n" + << "var tileK: array;\n" + << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; + + shader.MainFunctionBody() << "// x holds the N and y holds the M\n" + "let headIdx = workgroup_id.z;\n" + "let m = workgroup_id.y * TILE_SIZE;\n" + "let n = workgroup_id.x * TILE_SIZE;\n" + "let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;\n"; + + if (feed_past_key_ && has_present_key_) { + shader.MainFunctionBody() << "let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;\n" + << "let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;\n"; + } else { + shader.MainFunctionBody() << "let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;\n"; + } + + if (has_present_key_) { + shader.MainFunctionBody() << "let presentKeyOffset = headIdx * uniforms.N * uniforms.K;\n"; + } + + shader.MainFunctionBody() << "var value = f32_val_t(0);\n" + "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" + " }\n" + " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" + " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if (feed_past_key_ && has_present_key_) { + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.past_sequence_length) {\n" + " tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + " } else {\n" + " tileK[idx] = key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];\n" + " }\n"; + } else { + shader.MainFunctionBody() << " tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];\n"; + } + + if (has_present_key_) { + shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"; + } + + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + shader.MainFunctionBody() << "let headOffset = headIdx * uniforms.M * uniforms.N;\n" + << "if (global_id.y < uniforms.M && global_id.x < uniforms.N) {\n" + << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" + << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; + + shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; + if (has_attention_bias_) { + shader.MainFunctionBody() << " + attention_bias[outputIdx]"; + } + shader.MainFunctionBody() << ";\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, + const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, + AttentionParameters& parameters, int past_sequence_length, int total_sequence_length) { + const float alpha = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) + : parameters.scale; + + const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0; + const bool has_present_key = output_count > 1 && past_key; + const bool has_attention_bias = attention_bias != nullptr; + const int tile_size = 12; + const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1); + + AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, + components}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, + {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); + if (feed_past_key) { + program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components}); + } + if (has_attention_bias) { + program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); + if (has_present_key) { + program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); + } + + const uint32_t vectorized_head_size = parameters.head_size / components; + program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, + (parameters.sequence_length + tile_size - 1) / tile_size, + parameters.batch_size * parameters.num_heads) + .SetWorkgroupSize(tile_size, tile_size) + .CacheHint(std::to_string(tile_size)) + .AddUniformVariables({{static_cast(parameters.sequence_length)}, + {static_cast(vectorized_head_size)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.num_heads)}, + {static_cast(alpha)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); + + return context.RunProgram(program); +} + +Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AdditionalImplementation() << "var thread_max: array;\n" + << "var thread_sum: array;\n" + << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; + + shader.MainFunctionBody() << "let local_offset = local_idx * uniforms.elements_per_thread;\n" + << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.d_comp + local_offset;\n" + << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" + << "}\n" + << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var max_value = f32(-3.402823e+38f);\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << " max_value = max(thread_max[i], max_value);\n" + << "}\n" + << "var sum_vector = f32_val_t(0);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n" + << "}\n" + << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var sum: f32 = 0;\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << " sum += thread_sum[i]\n;" + << "}\n" + << "if (sum == 0) {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << " x[offset + i] = x_value_t(x_element_t(uniforms.d_inv));\n" + << " }\n" + << "} else {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << " var f32input = f32_val_t(x[offset + i]);\n" + << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" + << " }\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int n, int d) { + const int components = d % 4 == 0 ? 4 : (d % 2 == 0 ? 2 : 1); + int work_group_size = 64; + const int d_comp = d / components; + if (d_comp < work_group_size) { + work_group_size = 32; + } + const int elementsPerThread = (d_comp + work_group_size - 1) / work_group_size; + + InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components}; + program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) + .SetDispatchGroupSize(n) + .SetWorkgroupSize(work_group_size) + .AddUniformVariables({{static_cast(1.f / static_cast(d))}, + {static_cast(d_comp)}, + {static_cast(elementsPerThread)}}); + + return context.RunProgram(program); +} + +Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (feed_past_value_) { + shader.AddInput("past_value", ShaderUsage::UseUniform); + } + + shader.AddOutput("output", ShaderUsage::UseUniform); + if (has_present_value_) { + shader.AddOutput("present_value", ShaderUsage::UseUniform); + } + + shader.AdditionalImplementation() << "var tileQ: array;\n" + << "var tileK: array;\n"; + + shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n" + << "let m = global_id.y;\n" + << "let n = global_id.x;\n" + << "let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;\n"; + + if (feed_past_value_ && has_present_value_) { + shader.MainFunctionBody() << "let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;\n" + << "let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;\n"; + } else { + shader.MainFunctionBody() << "let offsetB = headIdx * uniforms.N * uniforms.K + n;\n"; + } + + if (has_present_value_) { + shader.MainFunctionBody() << "let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;\n"; + } + + shader.MainFunctionBody() << "var value = probs_element_t(0);\n" + << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n" + << " tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n" + << " }\n" + << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" + << " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if (feed_past_value_ && has_present_value_) { + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.past_sequence_length) {\n" + << " tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];\n" + << " } else {\n" + << " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n" + << " }\n"; + } else { + shader.MainFunctionBody() << " tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];\n"; + } + + if (has_present_value_) { + shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"; + } + + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << " value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n" + << "let batchIdx = workgroup_id.z / uniforms.num_heads;\n" + << "let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;\n" + << "if (m < uniforms.M && n < uniforms.N) {\n" + << " let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + " + << " m * uniforms.v_hidden_size + currentBatchHeadNumber * uniforms.N + n;\n" + << " output[outputIdx] = value;\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count, + const Tensor* probs, + const Tensor* V, + const Tensor* past_value, + Tensor* output, + Tensor* present_value, + AttentionParameters& parameters, + int past_sequence_length, + int total_sequence_length) { + const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0; + const bool has_present_value = output_count > 1 && past_value != nullptr; + const int tile_size = 12; + + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size}; + program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, + {V, ProgramTensorMetadataDependency::TypeAndRank}}); + if (feed_past_value) { + program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); + if (has_present_value) { + program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.SetDispatchGroupSize((parameters.v_head_size + tile_size - 1) / tile_size, + (parameters.sequence_length + tile_size - 1) / tile_size, + parameters.batch_size * parameters.num_heads) + .SetWorkgroupSize(tile_size, tile_size) + .AddUniformVariables({{static_cast(parameters.sequence_length)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.v_head_size)}, + {static_cast(parameters.num_heads)}, + {static_cast(parameters.v_hidden_size)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); + ; + + return context.RunProgram(program); +} + +Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, + AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); + const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length : 0; + const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length; + + const TensorShapeVector probs_dims({parameters.batch_size, parameters.num_heads, + parameters.sequence_length, total_sequence_length}); + const TensorShape probs_shape(probs_dims); + Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape); + ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key, + parameters, past_sequence_length, total_sequence_length)); + + ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, + parameters.batch_size * parameters.num_heads * parameters.sequence_length, total_sequence_length)); + + ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, + parameters, past_sequence_length, total_sequence_length)); + + return Status::OK(); +} + +MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) + : WebGpuKernel(info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + scale_ = info.GetAttrOrDefault("scale", 0.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support webgpu kernel"); +} + +Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* query = context.Input(0); + const Tensor* key = context.Input(1); + const Tensor* value = context.Input(2); + const Tensor* bias = context.Input(3); + const Tensor* key_padding_mask = context.Input(4); + const Tensor* attention_bias = context.Input(5); + const Tensor* past_key = context.Input(6); + const Tensor* past_value = context.Input(7); + + if (query->Shape().GetDims().size() == 5) { + ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu"); + } + if (key != nullptr && key->Shape().GetDims().size() == 5) { + ORT_NOT_IMPLEMENTED("Packed KV not implemented for webgpu"); + } + if (key_padding_mask) { + ORT_NOT_IMPLEMENTED("input `key_padding_mask` not implemented for webgpu"); + } + + AttentionParameters parameters; + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, + bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, ¶meters, + num_heads_, mask_filter_value_, scale_, is_unidirectional_, false, kMultiHeadAttention, + context.DeviceLimits().maxComputeInvocationsPerWorkgroup)); + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(parameters.batch_size); + output_shape[1] = static_cast(parameters.sequence_length); + output_shape[2] = static_cast(parameters.v_hidden_size); + Tensor* output = context.Output(0, output_shape); + + // If optional outputs aren't needed, present_key and present_value will be null + std::vector present_dims{ + parameters.batch_size, + parameters.num_heads, + parameters.total_sequence_length, + parameters.head_size, + }; + TensorShape present_shape(present_dims); + Tensor* present_key = context.Output(1, present_shape); + Tensor* present_value = context.Output(2, present_shape); + + TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads, + parameters.sequence_length, parameters.head_size}); + TensorShape q_new_shape(q_new_dims); + Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH( + context, parameters.num_heads, parameters.sequence_length, parameters.head_size, query, bias, 0, &Q)); + + if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format + return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context); + } + + TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads, + parameters.kv_sequence_length, parameters.head_size}); + TensorShape k_new_shape(k_new_dims); + Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length, + parameters.head_size, key, bias, parameters.hidden_size, &K)); + + TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads, + parameters.kv_sequence_length, parameters.v_head_size}); + TensorShape v_new_shape(v_new_dims); + Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length, + parameters.v_head_size, value, bias, 2 * parameters.hidden_size, &V)); + + // Compute the attention score and apply the score to V + return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h new file mode 100644 index 0000000000000..36803e3027b4c --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class TransferBSDToBNSHProgram final : public Program { + public: + TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}, + {"batch_offset", ProgramUniformVariableDataType::Uint32}, + {"sequence_offset", ProgramUniformVariableDataType::Uint32}, + {"head_offset", ProgramUniformVariableDataType::Uint32}, + {"bias_offset", ProgramUniformVariableDataType::Uint32}); + + private: + bool has_bias_; +}; + +class AttentionProbsProgram final : public Program { + public: + AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, + bool has_attention_bias, int tile_size, int components) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"alpha", ProgramUniformVariableDataType::Float32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}); + + WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); + + private: + bool feed_past_key_; + bool has_present_key_; + bool has_attention_bias_; + int tile_size_; + int components_; +}; + +class InPlaceSoftmaxProgram final : public Program { + public: + InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components) + : Program{kernel_name}, work_group_size_(work_group_size), components_(components) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"d_inv", ProgramUniformVariableDataType::Float32}, + {"d_comp", ProgramUniformVariableDataType::Uint32}, + {"elements_per_thread", ProgramUniformVariableDataType::Uint32}); + + private: + int work_group_size_; + int components_; +}; + +class VxAttentionScoreProgram final : public Program { + public: + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"v_hidden_size", ProgramUniformVariableDataType::Uint32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}); + + WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); + + private: + bool feed_past_value_; + bool has_present_value_; + int tile_size_; +}; + +class MultiHeadAttention final : public WebGpuKernel { + public: + MultiHeadAttention(const OpKernelInfo& info); + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + protected: + int num_heads_; + float mask_filter_value_; + float scale_; + bool is_unidirectional_{false}; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc new file mode 100644 index 0000000000000..85ab94706b149 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/rotary_embedding.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + RotaryEmbedding, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); + const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform); + const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform); + const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + // TODO: remove output_indices. + const auto& output_indices = shader.AddIndices("output_indices", false); + const auto interleaved_str = interleaved_ ? "true" : "false"; + shader.MainFunctionBody() << " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n" + " let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n" + " let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n" + " if (global_idx >= size) { return; }\n" + " if (bsnh[3] < half_rotary_emb_dim) {\n" + << " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) << ";\n" + << " let position_id = u32(" << position_ids.GetByOffset("position_ids_idx") << ") + select(0, bsnh[1], position_ids_idx == 0);\n" + << " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << ");\n" + << " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n" + << " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " " << output.SetByOffset("i", "re") << "\n" + << " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << " + " << input.GetByOffset("j") + " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " " << output.SetByOffset("j", "im") << "\n" + << " } else { \n" + " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" + << " " << output.SetByOffset("k", input.GetByOffset("k")) << "\n" + << " }"; + + return Status::OK(); +} + +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) { + scale_ = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim_ = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads_ = static_cast(info.GetAttrOrDefault("num_heads", 0)); + interleaved_ = (info.GetAttrOrDefault("interleaved", 0) == 1); + is_packed_batching_ = (info.GetAttrOrDefault("is_packed_batching", 0) == 1); +} + +Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto input_shape = input->Shape(); + const auto* position_ids = context.Input(1); + const auto* cos_cache = context.Input(2); + const auto* sin_cache = context.Input(3); + auto* output = context.Output(0, input_shape); + + const auto batch_size = gsl::narrow_cast(input->Shape()[0]); + const auto batch_stride = gsl::narrow_cast(input_shape.SizeFromDimension(1)); + const auto sequence_length = gsl::narrow_cast(input_shape[input_shape.NumDimensions() - 2]); + const auto hidden_size = batch_stride / sequence_length; + const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); + const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_; + + // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape + // [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy] + // to unfold the global index in shader. + const TensorShape global_shape({batch_size, + sequence_length, + hidden_size / head_size, + head_size - half_rotary_embedding_dim}); + + const auto rank = global_shape.NumDimensions(); + std::vector global_dims(rank); + std::vector global_strides(rank); + for (size_t j = 0; j < rank; ++j) { + global_dims[j] = gsl::narrow_cast(global_shape[j]); + global_strides[j] = gsl::narrow_cast(global_shape.SizeFromDimension(j + 1)); + } + + const auto output_size = gsl::narrow_cast(global_shape.Size()); + RotaryEmbeddingProgram program{interleaved_}; + const auto input_output_strides = + input_shape.NumDimensions() == 3 + ? std::vector({batch_stride, hidden_size, head_size, 1}) + : (input_shape.NumDimensions() == 4 + ? std::vector({batch_stride, head_size, sequence_length * head_size, 1}) + : std::vector({})); + + program + .CacheHint(interleaved_) + .AddInputs({{input, ProgramTensorMetadataDependency::Rank}, + {position_ids, ProgramTensorMetadataDependency::Rank}, + {cos_cache, ProgramTensorMetadataDependency::Rank}, + {sin_cache, ProgramTensorMetadataDependency::Rank}}) + .AddOutput({output, ProgramTensorMetadataDependency::None}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{scale_}, + {gsl::make_span(global_dims)}, + {gsl::make_span(global_strides)}, + {gsl::make_span(input_output_strides)}}) + .AddIndices(TensorShape{1, 1}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h new file mode 100644 index 0000000000000..0d73b89fb62df --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class RotaryEmbeddingProgram final : public Program { + public: + RotaryEmbeddingProgram(bool interleaved) : Program{"RotaryEmbedding"}, interleaved_{interleaved} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"scale", ProgramUniformVariableDataType::Float32}, + {"global_shape", ProgramUniformVariableDataType::Uint32}, + {"global_stride", ProgramUniformVariableDataType::Uint32}, + {"input_output_stride", ProgramUniformVariableDataType::Uint32}); + + private: + const bool interleaved_; +}; + +class RotaryEmbedding final : public WebGpuKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(ComputeContext& context) const override; + + private: + float scale_; + int num_heads_; + int rotary_embedding_dim_; + bool interleaved_; + bool is_packed_batching_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc new file mode 100644 index 0000000000000..254dd26b8a142 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/skip_layer_norm.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +static uint32_t GetMaxComponents(int size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +static std::string SumVector(std::string x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("skip", ShaderUsage::UseUniform); + shader.AddInput("gamma", ShaderUsage::UseUniform); + if (hasBeta_) { + shader.AddInput("beta", ShaderUsage::UseUniform); + } + if (hasBias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + if (has_input_skip_bias_sum_) { + shader.AddOutput("input_skip_bias_sum", ShaderUsage::UseUniform); + } + + int components = x.NumComponents(); + + std::string bias = (hasBias_) ? " + bias[offset1d + i] " : ""; + std::string simpl1 = (simplified_) ? "" : "- mean * mean "; + std::string simpl2 = (simplified_) ? "" : "- element_t(mean) "; + std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : ""; + std::string input_skip_bias_sum = (has_input_skip_bias_sum_) ? "input_skip_bias_sum[offset + i] = value;\n" : ""; + + shader.AdditionalImplementation() + << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") + << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? "vec2" : "f32")) << ";\n" + << "var sum_shared : array;\n" + << "var sum_squared_shared : array;\n"; + + shader.MainFunctionBody() + << "let ix = local_idx;\n" + << "let iy = global_idx / workgroup_size_x;\n" + << "let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;\n" + << "var stride = hidden_size_vectorized / workgroup_size_x;\n" + << "let offset = ix * stride + iy * hidden_size_vectorized;\n" + << "let offset1d = stride * ix;\n" + << "if (ix == workgroup_size_x - 1) {\n" + << " stride = hidden_size_vectorized - stride * ix;\n" + << "}\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " let skip_value = skip[offset + i];\n" + << " let input_value = x[offset + i];\n" + << " let value = input_value + skip_value" << bias << ";\n" + << " output[offset + i] = value;\n" + << input_skip_bias_sum + << " let f32_value = f32_val_t(value);\n" + << " sum_shared[ix] += f32_value;\n" + << " sum_squared_shared[ix] += f32_value * f32_value;\n" + << "}\n" + << "workgroupBarrier();\n" + << "var reduce_size : u32 = workgroup_size_x;\n" + << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (ix < curr_size) {\n" + << " sum_shared[ix] += sum_shared[ix + reduce_size];\n" + << " sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n" + << "let sum = sum_shared[0];\n" + << "let square_sum = sum_squared_shared[0];\n" + << "let mean = " << SumVector("sum", components) << " / f32(uniforms.hidden_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.hidden_size) " << simpl1 << "+ uniforms.epsilon);\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " output[offset + i] = (output[offset + i] " << simpl2 << ") * element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n" + << "};\n"; + + return Status::OK(); +} + +template +Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* x = context.Input(0); + const Tensor* skip = context.Input(1); + const Tensor* gamma = context.Input(2); + // optional + const Tensor* beta = context.Input(3); + const Tensor* bias = context.Input(4); + + const auto x_shape = x->Shape(); + + auto* output = context.Output(0, x_shape); + auto* input_skip_bias_sum = context.Output(3, x_shape); + + size_t data_size = x_shape.Size(); + if (data_size == 0) { + return Status::OK(); + } + + const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + const uint32_t hidden_size = SafeInt(x_shape[x_shape.NumDimensions() - 1]); + const int components = GetMaxComponents(hidden_size); + const bool has_input_skip_bias_sum = input_skip_bias_sum != nullptr; + + SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, is_fp16, simplified}; + program + .CacheHint(simplified, has_input_skip_bias_sum) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}}) + .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) + .SetDispatchGroupSize(SafeInt(ceil(1.0 * data_size / hidden_size))) + .AddUniformVariables({ + {static_cast(components)}, + }) + .AddUniformVariables({ + {static_cast(hidden_size)}, + }) + .AddUniformVariables({ + {static_cast(epsilon_)}, + }); + + if (beta != nullptr) { + program.AddInput({beta, ProgramTensorMetadataDependency::Type, components}); + } + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::Type, components}); + } + if (has_input_skip_bias_sum) { + program.AddOutputs({{input_skip_bias_sum, ProgramTensorMetadataDependency::None, components}}); + } + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + SkipLayerNormalization, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + SkipLayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SkipSimplifiedLayerNormalization, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + SkipLayerNorm); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h new file mode 100644 index 0000000000000..03de1a4b568b9 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class SkipLayerNormProgram final : public Program { + public: + SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, uint32_t hidden_size, bool has_input_skip_bias_sum, bool is_fp16, bool simplified) : Program{"SkipLayerNorm"} { + epsilon_ = epsilon; + hasBeta_ = hasBeta; + hasBias_ = hasBias; + epsilon_ = epsilon; + hidden_size_ = hidden_size; + has_input_skip_bias_sum_ = has_input_skip_bias_sum; + simplified_ = simplified; + is_fp16_ = is_fp16; + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"components", ProgramUniformVariableDataType::Uint32}, + {"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + bool hasBeta_; + bool hasBias_; + float epsilon_; + uint32_t hidden_size_; + bool has_input_skip_bias_sum_; + bool is_fp16_; + bool simplified_; +}; + +template +class SkipLayerNorm final : public WebGpuKernel { + public: + SkipLayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05f); + } + + Status ComputeInternal(ComputeContext& context) const override; + + protected: + std::string cache_hint; + + private: + float epsilon_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..ad3b3ff662b79 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -0,0 +1,394 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/quantization/matmul_nbits.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +namespace { +// Put it to a common place? +uint32_t GetMaxComponents(uint32_t size) { + // we cannot use vec3 type since it has alignment of 16 bytes + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + + return 1; +} + +std::string QuantizedDataType(int components) { + switch (components) { + case 1: + return "array"; + case 2: + return "mat4x2"; + case 4: + return "mat2x4"; + default: + return "array"; + } +} + +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", WebGpuSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + MatMulNBits); + +Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); + const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + + if (use_block32_) { + const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); + const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8; // each uint32 has 8 data. + const uint32_t a_length_per_tile = tile_size / a.NumComponents(); + const uint32_t block_size = 32; + const uint32_t blocks_per_tile = tile_size / block_size; + shader.AdditionalImplementation() << "var sub_a: array;\n" + << "var inter_results: array, " << WorkgroupSizeY() << ">;\n"; + std::string offset = "workgroup_idx * " + std::to_string(WorkgroupSizeY()); + shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" + << " let col = output_indices[2];\n" + " let row = output_indices[1];\n" + " let batch = output_indices[0];\n" + " let n_blocks_per_col = uniforms.input_b_shape[1];\n" + << " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n" + // Loop over shared dimension. + << " for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n" + << " let a_col_start = tile * " << a_length_per_tile << ";\n" + << " // load one tile A data into shared memory.\n" + << " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n" + << " let a_col = a_col_start + a_offset;\n" + " if (a_col < uniforms.input_a_shape[2]) {\n" + << " sub_a[a_offset] = " << a.GetByIndices("input_a_indices_t(batch, row, a_col)") << ";\n" + << " } else {\n" + " sub_a[a_offset] = input_a_value_t(0);\n" + " }\n" + " }\n" + " workgroupBarrier();\n" + // Each thread processes one block. + " let b_row = col + local_id.y;\n" + << " let block = tile * " << blocks_per_tile << " + local_id.x;\n"; + if (has_zero_points_) { + const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); + shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" + " let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);\n" + " let zero_point_word_index = zero_point_byte_count >> 0x2u;\n" + " let zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" + " let zero_point_nibble_offset: u32 = block & 0x1u;\n" + " let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" + << " let zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" + << " let zero_point = output_element_t((zero_point_word) & 0xFu);\n"; + } else { + // The default zero point is 8 for unsigned 4-bit quantization. + shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; + } + shader.MainFunctionBody() << " var scale = output_element_t(0);\n" + " var b_data = input_b_value_t(0);\n" + << " if (block < n_blocks_per_col) {\n" + << " scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n" + << " b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n" + << " }\n" + << " var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n" + << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; + switch (a.NumComponents()) { + case 1: + shader.MainFunctionBody() << " let a_data0 = vec4(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]);\n" + " let a_data1 = vec4(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]);\n"; + break; + case 2: + shader.MainFunctionBody() << " let a_data0 = vec4(sub_a[word_offset], sub_a[word_offset + 1]);\n" + " let a_data1 = vec4(sub_a[word_offset + 2], sub_a[word_offset + 3]);\n"; + break; + case 4: + shader.MainFunctionBody() << " let a_data0 = sub_a[word_offset];\n" + " let a_data1 = sub_a[word_offset + 1];\n"; + break; + default: + break; + } + shader.MainFunctionBody() << " let b_value = b_data"; + if (components_b_ > 1) { + shader.MainFunctionBody() << "[i]"; + } + shader.MainFunctionBody() << ";\n" + " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" + " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" + " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" + " let b_dequantized_values = (b_quantized_values - mat2x4("; + for (int i = 0; i < 8; i++) { + shader.MainFunctionBody() << "zero_point"; + if (i < 7) { + shader.MainFunctionBody() << ", "; + } + } + shader.MainFunctionBody() << ")) * scale;\n" + " inter_results[local_id.y][local_id.x] += dot(a_data0, b_dequantized_values[0]) + dot(a_data1, b_dequantized_values[1]);\n" + << " word_offset += " << 8 / a.NumComponents() << ";\n" + << " }\n" + " workgroupBarrier();\n" + " }\n" + << " if (local_idx < " << WorkgroupSizeY() << ") {\n" + << " var output_value = output_value_t(0);\n" + << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" + << " output_value += inter_results[local_idx][b];\n" + " }\n" + " if (col + local_idx < uniforms.output_shape[2]) {\n" + << " " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n" + << " }\n" + " }\n"; + } else { + const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); + const int output_element_number = y.NumComponents() * SafeInt(output_number_); + + const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; + std::string offset = "workgroup_idx * " + std::to_string(output_number_); + shader.AdditionalImplementation() << "var workgroup_shared : array;\n"; + shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" + << " let col = output_indices[2];\n" + " let row = output_indices[1];\n" + " let batch = output_indices[0];\n" + " let n_blocks_per_col = uniforms.input_b_shape[1];\n" + " let blob_size = uniforms.input_b_shape[2];\n" + " for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size_x) {\n" + << " var word_offset = block * uniforms.block_size / " << a.NumComponents() << ";\n"; + + // prepare scale and zero point + shader.MainFunctionBody() << " var col_index = col * " << y.NumComponents() << ";\n"; + if (has_zero_points_) { + const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); + shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" + " var zero_point_byte_count: u32;\n" + " var zero_point_word_index: u32;\n" + " var zero_point_byte_offset: u32;\n" + " let zero_point_nibble_offset: u32 = block & 0x1u;\n" + " var zero_point_bits_offset: u32;\n" + " var zero_point_word: u32;\n"; + for (int c = 0; c < output_element_number; c++) { + shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" + << " zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);\n" + " zero_point_word_index = zero_point_byte_count >> 0x2u;\n" + " zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" + " zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" + << " zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" + << " let zero_point" << c << " = output_element_t((zero_point_word) & 0xFu);\n" + << " col_index += 1;\n"; + } + } else { + shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; + for (int c = 0; c < output_element_number; c++) { + shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" + << " col_index += 1;\n"; + } + } + + shader.MainFunctionBody() << " for (var word: u32 = 0; word < blob_size; word += 1) {\n"; + + // prepare b data + shader.MainFunctionBody() << " col_index = col * " << y.NumComponents() << ";\n"; + for (int c = 0; c < output_element_number; c++) { + shader.MainFunctionBody() << " let b" << c << "_data = " << b.GetByIndices("input_b_indices_t(col_index, block, word)") << ";\n" + << " col_index += 1;\n"; + } + shader.MainFunctionBody() << " var b_value : u32;\n" + " let b_mask : u32 = 0x0F0F0F0Fu;\n" + " var b_value_lower : vec4;\n" + " var b_value_upper : vec4;\n" + << " var b_quantized_values : " << quantized_data_type << ";\n" + << " var b_dequantized_values : " << quantized_data_type << ";\n"; + + shader.MainFunctionBody() << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; + + // process one word + shader.MainFunctionBody() << " var input_offset = " << a.IndicesToOffset("input_a_indices_t(batch, row, word_offset)") << ";\n" + << " var a_data: " << quantized_data_type << ";\n" + << " for (var j: u32 = 0; j < " << (8 / a.NumComponents()) << "; j++) {\n" + << " if (word_offset + j < uniforms.input_a_shape[2]) {\n" + << " a_data[j] = " << a.GetByOffset("input_offset") << ";\n" + << " input_offset++;\n" + " } else {\n" + " a_data[j] = input_a_value_t(0);\n" + " }\n" + " }\n"; + for (int c = 0; c < output_element_number; c++) { + shader.MainFunctionBody() << " b_value = b" << c << "_data"; + if (components_b_ > 1) { + shader.MainFunctionBody() << "[i]"; + } + shader.MainFunctionBody() << ";\n" + " b_value_lower = unpack4xU8(b_value & b_mask);\n" + " b_value_upper = unpack4xU8((b_value >> 4) & b_mask);\n" + << " b_quantized_values = " << quantized_data_type << "(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" + << " b_dequantized_values = "; + if (a.NumComponents() == 1) { + if (has_zero_points_) { + shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[1] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[2] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[3] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[4] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[5] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[6] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[7] - zero_point" << c << ") * scale" << c << ");\n"; + } else { + shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point) * scale" << c << ", " + << "(b_quantized_values[1] - zero_point) * scale" << c << "," + << "(b_quantized_values[2] - zero_point) * scale" << c << "," + << "(b_quantized_values[3] - zero_point) * scale" << c << "," + << "(b_quantized_values[4] - zero_point) * scale" << c << "," + << "(b_quantized_values[5] - zero_point) * scale" << c << "," + << "(b_quantized_values[6] - zero_point) * scale" << c << "," + << "(b_quantized_values[7] - zero_point) * scale" << c << ");\n"; + } + } else { + shader.MainFunctionBody() << "(b_quantized_values - " << quantized_data_type << "("; + for (int i = 0; i < 8; i++) { + if (has_zero_points_) { + shader.MainFunctionBody() << "zero_point" << c; + } else { + shader.MainFunctionBody() << "zero_point"; + } + if (i < 7) { + shader.MainFunctionBody() << ", "; + } + } + shader.MainFunctionBody() << ")) * scale" << c << ";\n"; + } + + shader.MainFunctionBody() << " workgroup_shared[local_id.x * " << output_number_ << " + " << c / y.NumComponents() << "]"; + if (y.NumComponents() > 1) { + shader.MainFunctionBody() << "[" << c % y.NumComponents() << "]"; + } + shader.MainFunctionBody() << " += "; + if (a.NumComponents() == 1) { + shader.MainFunctionBody() << "a_data[0] * b_dequantized_values[0] + " + "a_data[1] * b_dequantized_values[1] + " + "a_data[2] * b_dequantized_values[2] + " + "a_data[3] * b_dequantized_values[3] + " + "a_data[4] * b_dequantized_values[4] + " + "a_data[5] * b_dequantized_values[5] + " + "a_data[6] * b_dequantized_values[6] + " + "a_data[7] * b_dequantized_values[7];\n"; + } else if (a.NumComponents() == 2) { + shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " + "dot(a_data[1], b_dequantized_values[1]) + " + "dot(a_data[2], b_dequantized_values[2]) + " + "dot(a_data[3], b_dequantized_values[3]);\n"; + } else if (a.NumComponents() == 4) { + shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " + "dot(a_data[1], b_dequantized_values[1]);\n"; + } + } + + shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" + << " }\n" + " }\n" + " }\n" + " workgroupBarrier();\n" + << " if (local_id.x < " << output_number_ << ") {\n" + << " var output_value = output_value_t(0);\n" + " var workgroup_shared_offset = local_id.x;\n" + << " let blocks_num = min(" << shared_memory_size << ", n_blocks_per_col);\n" + << " for (var b = 0u; b < blocks_num; b++) {\n" + " output_value += workgroup_shared[workgroup_shared_offset];\n" + << " workgroup_shared_offset += " << output_number_ << ";\n" + << " }\n" + << " " << y.SetByIndices("output_indices_t(batch, row, col + local_id.x)", "output_value") << "\n" + << " }\n"; + } + + return Status::OK(); +} + +Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* a = context.Input(0); + const Tensor* b = context.Input(1); + const Tensor* scales = context.Input(2); + const Tensor* zero_points = context.Input(3); + const Tensor* g_idx = context.Input(4); + const Tensor* bias = context.Input(5); + + ORT_ENFORCE(g_idx == nullptr, "group_idx as input is not supported yet."); + ORT_ENFORCE(bias == nullptr, "bias as input is not supported yet."); + + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + auto* y = context.Output(0, helper.OutputShape()); + const uint32_t data_size = SafeInt(y->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + const uint32_t batch_count = SafeInt(helper.OutputOffsets().size()); + const uint32_t M = SafeInt(helper.M()); + const uint32_t N = SafeInt(helper.N()); + const uint32_t K = SafeInt(helper.K()); + const uint32_t block_size = SafeInt(block_size_); + const uint32_t nbits = 4; + + const uint32_t n_blocks_per_col = (K + block_size - 1) / block_size; + const uint32_t blob_size = (block_size / 8) * nbits; + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_a = GetMaxComponents(K); + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + uint32_t components = GetMaxComponents(N); + const bool is_intel = !std::strcmp(context.AdapterInfo().vendor, "intel") && !std::strcmp(context.AdapterInfo().architecture, "gen-12lp"); + const bool use_block32 = is_intel && block_size == 32; + const bool has_zero_points = zero_points != nullptr; + // TODO: Support output_number > 1. Some cases are failed when output_number > 1. + // const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1; + const uint32_t output_number = 1; + MatMulNBitsProgram program{output_number, SafeInt(components_b), has_zero_points, use_block32}; + + if (use_block32) { + components = 1; + const uint32_t workgroup_size = 128; + const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4 + : 1; + const uint32_t workgroup_x = workgroup_size / workgroup_y; + program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); + program.SetDispatchGroupSize(data_size / components / workgroup_y); + } else { + program.SetDispatchGroupSize(data_size / components / output_number); + } + + TensorShape reshaped_a_shape{batch_count, M, K / components_a}; + TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; + TensorShape reshaped_y_shape{batch_count, M, N / components}; + + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, SafeInt(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, SafeInt(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, + {scales, ProgramTensorMetadataDependency::None}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, SafeInt(components)}) + .AddUniformVariable({block_size}) + .CacheHint(std::to_string(output_number)); + if (has_zero_points) { + program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + } + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..c0d6b3e6379cd --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class MatMulNBitsProgram final : public Program { + public: + MatMulNBitsProgram(uint32_t output_number, int components_b, bool has_zero_points, bool use_block32) : Program{"MatMulNBits"}, + output_number_{output_number}, + components_b_{components_b}, + has_zero_points_{has_zero_points}, + use_block32_{use_block32} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t output_number_; + int components_b_; + bool has_zero_points_; + bool use_block32_; +}; + +class MatMulNBits final : public WebGpuKernel { + public: + MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) { + K_ = info.GetAttr("K"); + N_ = info.GetAttr("N"); + block_size_ = info.GetAttr("block_size"); + int64_t bits = info.GetAttr("bits"); + ORT_ENFORCE(bits == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + } + + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 8ed1372cd0e62..4006006a76ba8 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -9,6 +9,24 @@ namespace onnxruntime { namespace contrib { namespace webgpu { +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention); +// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -18,7 +36,22 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - }; + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/core/providers/webgpu/README.md b/onnxruntime/core/providers/webgpu/README.md new file mode 100644 index 0000000000000..fe0d99b1d602b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/README.md @@ -0,0 +1,39 @@ +# WebGPU Execution Provider + +This folder is for the WebGPU execution provider(WebGPU EP). Currently, WebGPU EP is working in progress. + +## Build WebGPU EP + +Just append `--use_webgpu` to the `build.bat`/`build.sh` command line. + +For linux, a few dependencies need to be installed: +```sh +apt-get install libx11-dev libxrandr-dev libxinerama-dev libxcursor-dev libxi-dev libx11-dev libx11-xcb-dev +``` + +## Troubleshooting + +TODO: add solutions to common problems. + +## Development Guide + +See [How to write WebGPU EP kernel](./docs/How_to_Write_WebGPU_EP_Kernel.md) for more information. + +## Conventions + +See [Conventions](./docs/Conventions.md) for more information. + +## Best Practices + +See [Best Practices](./docs/Best_Practices.md) for more information. + +## TODO items + +The following items are not yet implemented: + +- [ ] Validation Switch (allows to change the behavior of whether perform specific validation checks) +- [ ] pushErrorScope/popErrorScope +- [ ] Graph Capture +- [ ] Profiling supported by WebGPU Query Buffer +- [ ] WebGPU resources tracking (mainly for buffers) +- [ ] Global hanlders( unhandled exceptions and device lost ) diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc new file mode 100644 index 0000000000000..8e27acdc285d4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "core/framework/session_state.h" +#include "core/providers/webgpu/allocator.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +void* GpuBufferAllocator::Alloc(size_t size) { + if (size == 0) { + return nullptr; + } + + auto buffer = context_.BufferManager().Create(size); + + stats_.num_allocs++; + return buffer; +} + +void GpuBufferAllocator::Free(void* p) { + if (p != nullptr) { + context_.BufferManager().Release(static_cast(p)); + stats_.num_allocs--; + } +} + +void GpuBufferAllocator::GetStats(AllocatorStats* stats) { + *stats = stats_; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h new file mode 100644 index 0000000000000..51ca65a8b4822 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocator.h" +#include "core/framework/ortdevice.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +class GpuBufferAllocator : public IAllocator { + public: + GpuBufferAllocator(const WebGpuContext& context) + : IAllocator( + OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), + 0, OrtMemTypeDefault)), + context_{context} { + } + + virtual void* Alloc(size_t size) override; + virtual void Free(void* p) override; + void GetStats(AllocatorStats* stats) override; + + private: + AllocatorStats stats_; + const WebGpuContext& context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc new file mode 100644 index 0000000000000..8751338d24178 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -0,0 +1,362 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +size_t NormalizeBufferSize(size_t size) { + return (size + 15) / 16 * 16; +} + +class DisabledCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/) override { + // always return empty buffer + return nullptr; + } + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + void ReleaseBuffer(WGPUBuffer buffer) override { + wgpuBufferRelease(buffer); + } + + void OnRefresh() override { + // no-op + } +}; + +class LazyReleaseCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/) override { + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + pending_buffers_.clear(); + } + + std::vector pending_buffers_; +}; + +class SimpleCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buffers_.find(buffer_size); + if (it != buffers_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + for (auto& buffer : pending_buffers_) { + buffers_[wgpuBufferGetSize(buffer)].push_back(buffer); + } + pending_buffers_.clear(); + } + + std::map> buffers_; + std::vector pending_buffers_; +}; + +// TODO: maybe use different bucket size for storage and uniform buffers? +constexpr std::initializer_list> BUCKET_DEFAULT_LIMIT_TABLE = { + {64, 250}, + {128, 200}, + {256, 200}, + {512, 200}, + {2048, 230}, + {4096, 200}, + {8192, 50}, + {16384, 50}, + {32768, 50}, + {65536, 50}, + {131072, 50}, + {262144, 50}, + {524288, 50}, + {1048576, 50}, + {2097152, 30}, + {4194304, 20}, + {8388608, 10}, + {12582912, 10}, + {16777216, 10}, + {26214400, 15}, + {33554432, 22}, + {44236800, 2}, + {58982400, 6}, + // we don't want to cache the bucket sizes below but not caching them + // results in some major performance hits for models like sd-turbo. + {67108864, 6}, + {134217728, 6}, + {167772160, 6}, +}; + +class BucketCacheManager : public IBufferCacheManager { + public: + BucketCacheManager() : buckets_limit_{BUCKET_DEFAULT_LIMIT_TABLE} { + Initialize(); + } + BucketCacheManager(std::unordered_map&& buckets_limit) : buckets_limit_{buckets_limit} { + Initialize(); + } + + size_t CalculateBufferSize(size_t request_size) override { + // binary serch size + auto it = std::lower_bound(buckets_keys_.begin(), buckets_keys_.end(), request_size); + if (it == buckets_keys_.end()) { + return NormalizeBufferSize(request_size); + } else { + return *it; + } + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + // TODO: consider graph capture. currently not supported + + for (auto& buffer : pending_buffers_) { + auto buffer_size = wgpuBufferGetSize(buffer); + + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { + it->second.push_back(buffer); + } else { + wgpuBufferRelease(buffer); + } + } + + pending_buffers_.clear(); + } + + protected: + void Initialize() { + buckets_keys_.reserve(buckets_limit_.size()); + buckets_.reserve(buckets_limit_.size()); + for (const auto& pair : buckets_limit_) { + buckets_keys_.push_back(pair.first); + buckets_.emplace(pair.first, std::vector()); + } + std::sort(buckets_keys_.begin(), buckets_keys_.end()); + +#ifndef NDEBUG // if debug build + for (size_t i = 0; i < buckets_keys_.size(); ++i) { + ORT_ENFORCE(buckets_keys_[i] % 16 == 0, "Bucket sizes must be multiples of 16."); + } + + for (size_t i = 1; i < buckets_keys_.size(); ++i) { + ORT_ENFORCE(buckets_keys_[i] > buckets_keys_[i - 1], "Bucket sizes must be in increasing order."); + } +#endif + } + std::unordered_map buckets_limit_; + std::unordered_map> buckets_; + std::vector pending_buffers_; + std::vector buckets_keys_; +}; + +std::unique_ptr CreateBufferCacheManager(BufferCacheMode cache_mode) { + switch (cache_mode) { + case BufferCacheMode::Disabled: + return std::make_unique(); + case BufferCacheMode::LazyRelease: + return std::make_unique(); + case BufferCacheMode::Simple: + return std::make_unique(); + case BufferCacheMode::Bucket: + return std::make_unique(); + default: + ORT_NOT_IMPLEMENTED("Unsupported buffer cache mode"); + } +} + +std::ostream& operator<<(std::ostream& os, BufferCacheMode mode) { + switch (mode) { + case BufferCacheMode::Disabled: + os << "Disabled"; + break; + case BufferCacheMode::LazyRelease: + os << "LazyRelease"; + break; + case BufferCacheMode::Simple: + os << "Simple"; + break; + case BufferCacheMode::Bucket: + os << "Bucket"; + break; + default: + os << "Unknown(" << static_cast(mode) << ")"; + } + return os; +} + +BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) + : context_{context}, + storage_cache_{CreateBufferCacheManager(storage_buffer_cache_mode)}, + uniform_cache_{CreateBufferCacheManager(uniform_buffer_cache_mode)}, + query_resolve_cache_{CreateBufferCacheManager(query_resolve_buffer_cache_mode)}, + default_cache_{CreateBufferCacheManager(BufferCacheMode::Disabled)} { +} + +void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { + auto buffer_size = NormalizeBufferSize(size); + + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite; + desc.mappedAtCreation = true; + + auto staging_buffer = context_.Device().CreateBuffer(&desc); + auto mapped_data = staging_buffer.GetMappedRange(); + memcpy(mapped_data, src, size); + staging_buffer.Unmap(); + + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(staging_buffer, 0, dst, 0, buffer_size); + pending_staging_buffers_.push_back(staging_buffer); +} + +void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { + ORT_ENFORCE(src != dst, "Source and destination buffers must be different."); + + auto buffer_size = NormalizeBufferSize(size); + ORT_ENFORCE(buffer_size <= wgpuBufferGetSize(src) && buffer_size <= wgpuBufferGetSize(dst), + "Source and destination buffers must have enough space for the copy operation. src_size=", + wgpuBufferGetSize(src), ", dst_size=", wgpuBufferGetSize(dst), ", copy_size=", buffer_size, "."); + + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(src, 0, dst, 0, buffer_size); +} + +WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { + auto& cache = GetCacheManager(static_cast(usage)); + auto buffer_size = cache.CalculateBufferSize(size); + + auto buffer = cache.TryAcquireCachedBuffer(buffer_size); + if (buffer) { + return buffer; + } + + // cache miss, create a new buffer + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = usage; + // desc.label = std::to_string(xx++).c_str(); + buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle(); + + ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), "."); + + cache.RegisterBuffer(buffer, size); + return buffer; +} + +void BufferManager::Release(WGPUBuffer buffer) { + GetCacheManager(buffer).ReleaseBuffer(buffer); +} + +void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) { + auto buffer_size = NormalizeBufferSize(size); + + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead; + + auto staging_buffer = context_.Device().CreateBuffer(&desc); + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(src, 0, staging_buffer, 0, buffer_size); + context_.Flush(); + + // TODO: revise wait in whole project + + ORT_ENFORCE(context_.Wait(staging_buffer.MapAsync(wgpu::MapMode::Read, 0, buffer_size, wgpu::CallbackMode::WaitAnyOnly, [](wgpu::MapAsyncStatus status, const char* message) { + ORT_ENFORCE(status == wgpu::MapAsyncStatus::Success, "Failed to download data from buffer: ", message); + })) == Status::OK()); + + auto mapped_data = staging_buffer.GetConstMappedRange(); + memcpy(dst, mapped_data, size); +} + +void BufferManager::RefreshPendingBuffers() { + pending_staging_buffers_.clear(); + storage_cache_->OnRefresh(); + uniform_cache_->OnRefresh(); + query_resolve_cache_->OnRefresh(); + default_cache_->OnRefresh(); +} + +IBufferCacheManager& BufferManager::GetCacheManager(WGPUBufferUsage usage) const { + if (usage & WGPUBufferUsage_Storage) { + return *storage_cache_; + } else if (usage & WGPUBufferUsage_Uniform) { + return *uniform_cache_; + } else if (usage & WGPUBufferUsage_QueryResolve) { + return *query_resolve_cache_; + } else { + return *default_cache_; + } +} + +IBufferCacheManager& BufferManager::GetCacheManager(WGPUBuffer buffer) const { + return GetCacheManager(wgpuBufferGetUsage(buffer)); +} + +std::unique_ptr BufferManagerFactory::Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) { + return std::make_unique(context, storage_buffer_cache_mode, uniform_buffer_cache_mode, query_resolve_buffer_cache_mode); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h new file mode 100644 index 0000000000000..00febfbc29f1b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/buffer_manager.h @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +enum class BufferCacheMode { + Disabled, + LazyRelease, + Simple, + Bucket +}; +std::ostream& operator<<(std::ostream& os, BufferCacheMode mode); + +// +// IBufferCacheManager is an interface for buffer cache management. +// +// By implementing this interface, we can have different buffer cache management strategies. +// Currently, we have 3 strategies: +// - Disabled: no cache. always allocate a new buffer and release it immediately after use. +// - LazyRelease: no cache. the difference from Disabled is that it delays the release of buffers until the next refresh. +// - Simple: a simple cache that always keeps buffers. when a buffer is requested, it tries to find a buffer in the cache. +// - Bucket: a cache that keeps buffers in different buckets based on the buffer size, with a maximum number of buffers in each bucket. +// +class IBufferCacheManager { + public: + virtual ~IBufferCacheManager() = default; + + // calculate actual buffer size to allocate based on the requested size. + virtual size_t CalculateBufferSize(size_t request_size) = 0; + + // return a buffer if available in cache. otherwise empty. + virtual WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) = 0; + + // register a newly created buffer + virtual void RegisterBuffer(WGPUBuffer buffer, size_t request_size) = 0; + + // release a buffer + virtual void ReleaseBuffer(WGPUBuffer buffer) = 0; + + // when a stream refresh is requested + virtual void OnRefresh() = 0; +}; + +// +// BufferManager manages operations on buffers. +// +class BufferManager { + public: + BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode); + + void Upload(void* src, WGPUBuffer dst, size_t size); + void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size); + WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst); + void Release(WGPUBuffer buffer); + void Download(WGPUBuffer src, void* dst, size_t size); + void RefreshPendingBuffers(); + + private: + IBufferCacheManager& GetCacheManager(WGPUBufferUsage usage) const; + IBufferCacheManager& GetCacheManager(WGPUBuffer buffer) const; + + WebGpuContext& context_; + std::unique_ptr storage_cache_; + std::unique_ptr uniform_cache_; + std::unique_ptr query_resolve_cache_; + std::unique_ptr default_cache_; + + std::vector pending_staging_buffers_; +}; + +class BufferManagerFactory { + public: + static std::unique_ptr Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode); + + private: + BufferManagerFactory() {} +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc new file mode 100644 index 0000000000000..ce4f3e49611e2 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/op_kernel.h" + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { +ComputeContext::ComputeContext(OpKernelContext& kernel_context) + : webgpu_context_{WebGpuContextFactory::GetContext(kernel_context.GetDeviceId())}, + kernel_context_{kernel_context} { +} + +void ComputeContext::PushErrorScope() { + if (webgpu_context_.ValidationMode() >= ValidationMode::Basic) { + webgpu_context_.Device().PushErrorScope(wgpu::ErrorFilter::Validation); + } +} + +Status ComputeContext::PopErrorScope() { + Status status{}; + + if (webgpu_context_.ValidationMode() >= ValidationMode::Basic) { + ORT_RETURN_IF_ERROR(webgpu_context_.Wait( + webgpu_context_.Device().PopErrorScope( + wgpu::CallbackMode::WaitAnyOnly, [](wgpu::PopErrorScopeStatus pop_status, wgpu::ErrorType error_type, char const* message, Status* status) { + ORT_ENFORCE(pop_status == wgpu::PopErrorScopeStatus::Success, "Instance dropped."); + if (error_type == wgpu::ErrorType::NoError) { + return; + } + *status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "WebGPU validation failed. ", message); + }, + &status))); + } + return status; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h new file mode 100644 index 0000000000000..b7ea8a58e232b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include + +#include "core/framework/execution_provider.h" + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +class Tensor; + +namespace webgpu { + +class WebGpuContext; + +class ComputeContext { + public: + ComputeContext(OpKernelContext& kernel_context); + + virtual ~ComputeContext() = default; + + // + // Get various information from the context. + // + + inline const wgpu::AdapterInfo& AdapterInfo() const { + return webgpu_context_.AdapterInfo(); + } + inline const wgpu::Limits& DeviceLimits() const { + return webgpu_context_.DeviceLimits(); + } + + // + // Get the kernel context. + // + inline OpKernelContext& KernelContext() { + return kernel_context_; + } + + // + // Get the logger. + // + inline const logging::Logger& Logger() const { + return kernel_context_.Logger(); + } + + // + // Get input tensor. + // + template + inline const T* Input(int index) const { + return kernel_context_.Input(index); + } + + // + // Get input count. + // + inline int InputCount() const { + return kernel_context_.InputCount(); + } + + // + // Set output tensor. + // + template + inline Tensor* Output(int index, TensorShapeType&& shape) { + return kernel_context_.Output(index, std::forward(shape)); + } + + // + // Get output count. + // + inline int OutputCount() const { + return kernel_context_.OutputCount(); + } + + // + // Create CPU tensor. + // + template + Tensor CreateCPUTensor(MLDataType data_type, TensorShapeType&& shape) { + AllocatorPtr allocator; + ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceCPUAllocator(&allocator)); + return {data_type, std::forward(shape), allocator}; + } + + // + // Create GPU tensor. + // + template + Tensor CreateGPUTensor(MLDataType data_type, TensorShapeType&& shape) { + AllocatorPtr allocator; + ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceAllocator(&allocator)); + return {data_type, std::forward(shape), allocator}; + } + + // + // Run a compute shader program. + // + inline Status RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); + } + + // + // Push error scope. + // + // This is useful only when "skip_validation" is not set. + // + void PushErrorScope(); + + // + // Pop error scope. + // + // This is useful only when "skip_validation" is not set. + // + Status PopErrorScope(); + + protected: + WebGpuContext& webgpu_context_; + OpKernelContext& kernel_context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc new file mode 100644 index 0000000000000..615ae11175782 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/data_transfer.cc @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { + return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || + (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::GPU) || + (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); +} + +common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + size_t bytes = src.SizeInBytes(); + if (bytes > 0) { + void const* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; + + if (dst_device.Type() == OrtDevice::GPU) { + if (src_device.Type() == OrtDevice::GPU) { + // copy from GPU to GPU + context_.BufferManager().MemCpy(static_cast(const_cast(src_data)), + static_cast(dst_data), bytes); + } else { + // copy from CPU to GPU + context_.BufferManager().Upload(const_cast(src_data), static_cast(dst_data), bytes); + } + } else /* if (src_device.Type() == OrtDevice::GPU) */ { + // copy from GPU to CPU + context_.BufferManager().Download(static_cast(const_cast(src_data)), dst_data, bytes); + } + } + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.h b/onnxruntime/core/providers/webgpu/data_transfer.h new file mode 100644 index 0000000000000..f9949576aa60b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/data_transfer.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/data_transfer.h" +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +class DataTransfer : public IDataTransfer { + public: + DataTransfer(const WebGpuContext& context) : context_{context} {}; + ~DataTransfer() {}; + + bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; + + common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; + + private: + const WebGpuContext& context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/docs/Best_Practices.md b/onnxruntime/core/providers/webgpu/docs/Best_Practices.md new file mode 100644 index 0000000000000..d519292b226d0 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/docs/Best_Practices.md @@ -0,0 +1,37 @@ +### Always use std::ostringstream to generate shader code if possible + +This helps to the performance of code generation. + +For example: + +```cpp +ss << "var " << name << " = " << value << ";\n"; +``` + +is better than + +```cpp +ss << ("var " + name + " = " + value + ";\n"); +``` + +### Avoid creating template class for kernel using data type as template parameter. + +This basically means that we should define class like this: + +```cpp +class Abs : public WebGpuKernel { + ... +}; +``` + +instead of + +```cpp + +template // T is tensor element type +class Abs : public WebGpuKernel { + ... +}; +``` + +This is because we don't really read and use `Tensor::Data()`. Tensor stores a handle to a WebGPU buffer but not a pointer to the data. Using template for data type only increases the binary size with no real benefit. diff --git a/onnxruntime/core/providers/webgpu/docs/Conventions.md b/onnxruntime/core/providers/webgpu/docs/Conventions.md new file mode 100644 index 0000000000000..fecccc76a4db7 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/docs/Conventions.md @@ -0,0 +1,40 @@ +### Use "webgpu" other than "wgpu" in this folder + +This is referring to the naming convention of variables, classes and namespace. + +ORT C API is using "wgpu". + +Let's keep it "webgpu" for this folder for now. I have a very good reason to do so: + +- search for "webgpu" in the code base shows the WebGPU EP related code and search for "wgpu" shows the WebGPU API related code. This helps me easier to find the code I want to look at. + +And anyway, it's not hard to change it back to "wgpu" if we want to. (but it's harder to change it from "wgpu" to "webgpu") + +### Use `OStringStream` defined in string_utils.h and macros defined in string_macros.h + +Type `onnxruntime::webgpu::OStringStream` is a type alias of Abseil's OStringStream. It's a lightweight implementation +of `std::ostream`. It's recommended to use `OStringStream` instead of `std::ostringstream` in the code base. + +The macros defined in `string_macros.h` are used to make coding easier: + +```cpp +std::string MyFunction() { + SS(code /* name of the string stream */, 2048 /* initial capacity */); + + code << "var my_var = "; + + // function call style string append. equivalent to: + // + // code << "vec4(" << type << ">(" << value1 << ", " << value2 << ", " << value3 << ", " << value4 << ")"; + // + SS_APPEND(code, "vec4(", type, ">(", value1, ", ", value2, ", ", value3, ", ", value4, ")"); + + return SS_GET(code); // return the string +} +``` + +### Use the subfolder for kernel implementation + +Operator implementation source code need to be put under a subfolder like "math"/"nn"/"tensor". + +See folder structure under onnxruntime/core/providers/cpu/ or onnxruntime/core/providers/cuda/ for examples. diff --git a/onnxruntime/core/providers/webgpu/docs/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/docs/How_to_Write_WebGPU_EP_Kernel.md new file mode 100644 index 0000000000000..3e501cd957e03 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/docs/How_to_Write_WebGPU_EP_Kernel.md @@ -0,0 +1,212 @@ +# How to Write WebGPU EP Kernel + +This document describes how to write a WebGPU EP kernel for ONNX Runtime. + +The following document will assume the operator name is `Example`, and you will see class `ExampleProgram` and `ExampleOpKernel` in the examples. Replace `Example` with the actual operator name you are implementing. + +Follow the following steps to create a WebGPU kernel: + +## 1. Decide _filename_ and _cateogory_, and create a new file at: + +`onnxruntime/core/providers/webgpu/{category}/{filename}.cc` + +- filename is usually a snake_case_name of the operator name, or a descriptive name if it includes multiple operators (eg. binary_elementwise_ops.cc) +- category is the subfolder representing the operator category (eg. math/nn/controlflow) + + see folder structure under onnxruntime/core/providers/cpu/ or onnxruntime/core/providers/cuda/ for examples + +## 2. Declare a new Program class + +### 2.1. The Program class should inherit from Program: + +```c++ +class ExampleProgram : public Program { +// ... +} +``` + +### 2.2. The Program class can define the following information: + +There are 3 types of definitions described as below. All of them are optional. If not specified, it is treated as empty. Those definitions are defined as static const members to ensure they don't depend on any runtime information. + +#### **constants** + +constants are declaration of values that are never changes in the shader code. They are inserted into the WGSL source code like this: + +```wgsl +const A : u32 = 64; +``` + +Use macro `WEBGPU_PROGRAM_DEFINE_CONSTANTS` to define constants in your Program class, or use `WEBGPU_PROGRAM_EXTEND_CONSTANTS` to extend the constants defined in the base class. + +#### **overridable constants** + +overridable constants are similar to constants, but they can be overridden before the compute pipeline is created. Overridable constants may or may not have a default value. They are inserted into the WGSL source code like this: + +```wgsl +override B : u32 = 64; +override C : f32; +``` + +Use macro `WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS` to define overridable constants in your Program class, or use `WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS` to extend the overridable constants defined in the base class. + +#### **uniform definitions** + +uniform definitions are declaration of uniform varables. Their names and type must be defined and cannot be changed. Their values(including length) can be set at runtime. + +Use macro `WEBGPU_PROGRAM_DEFINE_UNIFORMS_VARIABLES` to define uniform definitions in your Program class, or use `WEBGPU_PROGRAM_EXTEND_UNIFORMS_VARIABLES` to extend the uniform definitions defined in the base class. + +### 2.3. The Program class should override the `GenerateShaderCode` method: + +```c++ +Status GenerateShaderCode(ShaderHelper& sh) const override; +``` + +In the function implementation, `sh` is an instance of `ShaderHelper` which provides a set of helper functions to generate shader code. + +Example: + +```c++ +Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddVariable(ProgramVariableScope::Input, + "x", + ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4), + 1); + const auto& output = shader.AddVariable(ProgramVariableScope::Output, + "y", + ToProgramVariableDataType(Outputs()[0]->GetElementType(), 4), + 1); + shader.AppendImplementation(additional_impl_); + shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + "let a = ", input.GetByOffset("global_idx"), ";\n", + output.SetByOffset("global_idx", expression_)); + + return Status::OK(); +} +``` + +`ShaderHelper::AddVariable` creates an instace of `ShaderVariable`. The class `ShaderVariable` is similar to `IndicesHelper` in onnxruntime-web. It provides a set of helper functions as value/indices/offset getter/setter. + +`ShaderHelper::AppendImplementation` inserts additional implementation code into the shader code. It will be put before the main function. + +`ShaderHelper::MainFunctionBody` generates the main function body. It accepts arbitrary number of arguments and concatenates them into the main function body. + +### 2.3. Lifecycle of the Program class + +For each calls into the `ExampleOpKernel::ComputeInternal()` method, a new instance of the `ExampleProgram` class should be created as local variable (The detail will be explained in `ExampleOpKernel` as below). The Program instance is destroyed when reaching the end of scope. + +A few functions can be called on the Program instance: + +- call `ProgramBase::Inputs` and `ProgramBase::Outputs` to set input/output tensor info. +- call `ProgramBase::CacheHint` to set the cache hint. +- call `ProgramBase::UniformsVariables`(optional) and `ProgramBase::OverridableConstants`(optional) to set runtime info of uniforms and overridable constants. They need to match the corresponding definitions described above. +- call `ProgramBase::DispatchGroupSize` and `ProgramBase::WorkgroupSize`(optional) to set the dispatch group size and workgroup size. + +## 3. Declare a new OpKernel class + +### 3.1. The OpKernel class should inherit from WebGpuKernel: + +```c++ +class ExampleOpKernel : public WebGpuKernel { +// ... +} +``` + +### 3.2. The OpKernel class should override the `ComputeInternal` method: + +```c++ +Status ComputeInternal(ComputeContext& context) const override; +``` + +Usually, in the implementation, we do 3 things: + +- Create a local variable of the Program class. +- Set a few runtime info of the Program instance. +- Call `context.RunProgram(program)` to run the program and return the status. + +Complicated operators may do more things. Check header files and existing implementations for more details. + +## 4. Register the operator + +Register the operator just like any EP does. Check existing implementations for more details. + +Please note that registration is composed of 2 parts: + +- Use macros like `ONNX_OPERATOR_KERNEL_EX` or `ONNX_OPERATOR_VERSIONED_KERNEL_EX` (or wrap a new macro as what we usually do) to register the operator in kernel source code file. +- Add the operator to onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc + +## 5. Write tests + +This section is WIP. + +## 6. Build and test + +### Build + +use `build.bat --use_webgpu --skip_tests` to build the WebGPU EP. For Release build, append `--config Release` or `--config RelWithDebInfo` to the command line. + +### Prepare test data + +Assume `C:\code\onnxruntime` is the root of your onnxruntime repo in all documents below. + +if folder `C:\code\onnxruntime\js\test\data` does not exist, run the following in your onnxruntime repo root: + +``` +cd js +npm ci +npm run prepare-node-tests +``` + +### Run Suite test (temporary: this may change recently) + +to do suite test, find the "test_webgpu.bat" in your build folder (It's usually in `build\Windows\Debug\Debug`). run it for tests: + +``` +# run all tests +test_webgpu.bat + +# run a test list from args +test_webgpu.bat -m=test_abs;test_cos +``` + +To add more tests to the suite list, edit the file at `C:\code\onnxruntime\onnxruntime\test\providers\webgpu\test_webgpu.js`. After editing, run build again otherwise this file will not be copied to the build folder. + +> How does it work? +> +> The `test_webgpu.bat` calls `test_webgpu.js` with nodejs. +> +> The `test_webgpu.js` use the test list (either the suite list or from cmd args) to prepare a temporary folder and creates symbolic links to the test data folder (under `C:\code\onnxruntime\js\test\data`). Then it runs `onnx_test_runner` on the temporary folder. + +### Run single test / debug + +to test or debug a single test, find the "onnx_test_runner.exe" in your build folder. run it like: + +``` +onnx_test_runner.exe -v -e webgpu -a 0.001 -t 0.001 -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs +``` + +The `-C` flag is split by space for each key-value pair. Each key-value pair is separated by `|`. The key is the option name and the value is the option value. See `onnxruntime\core\providers\webgpu\webgpu_provider_options.h` for available WebGPU EP options. + +The `-a` and `-t` flags are used to specify the absolute and relative tolerance for the test. +- currently the value is set to `0.001` for both absolute and relative tolerance for the WebGPU EP. +- `onnx_test_runner` will try to load file `\testdata\onnx_backend_test_series_overrides.jsonc>` if available to set the default tolerance values. It is recommended to set the tolerance values in the command line to ensure consistent behavior. + > This is why the following command may have different results: + > + > ``` + > C:\code\onnxruntime> build\Windows\Debug\Debug\onnx_test_runner.exe -e webgpu C:\code\onnxruntime\js\test\data\node\opset9\test_asin_example + > ``` + > + > ``` + > C:\code\onnxruntime\build\Windows\Debug\Debug> onnx_test_runner.exe -e webgpu C:\code\onnxruntime\js\test\data\node\opset9\test_asin_example + > ``` + +Some features are useful but if you are troubleshooting and want to rule out the cause, you can: + +- set `storageBufferCacheMode` to `disabled` to disable the storage buffer cache. +- set `-M` and `-A` to disable memory pattern and memory arena. +- set `-j 1` to disable parallel execution (if you have multiple models to test). + +Example: +``` +onnx_test_runner.exe -v -A -M -j 1 -e webgpu -a 0.001 -t 0.001 -C "session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled" C:\code\onnxruntime\js\test\data\node\opset17\test_abs +``` diff --git a/onnxruntime/core/providers/webgpu/generator/range.cc b/onnxruntime/core/providers/webgpu/generator/range.cc new file mode 100644 index 0000000000000..f704dc4e2cf82 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/generator/range.cc @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/generator/range.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +template +Status Range::ComputeInternal(ComputeContext& context) const { + T start = context.Input(0)->Data()[0]; + T limit = context.Input(1)->Data()[0]; + T delta = context.Input(2)->Data()[0]; + + int64_t n = static_cast(ceil((1.0 * (limit - start)) / delta)); + if (n <= 0) { + n = 0; + } + auto* output_tensor = context.Output(0, TensorShape{n}); + if (n == 0) { + return Status::OK(); + } + + uint32_t output_size = SafeInt(n); + RangeProgram program{}; + program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Type}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + output_size, + *reinterpret_cast(&start), + *reinterpret_cast(&delta), + }); + + return context.RunProgram(program); +} + +Status RangeProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + + sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " let value = bitcast(uniforms.start) + output_value_t(global_idx) * bitcast(uniforms.delta);\n" + << output.SetByOffset("global_idx", "value"); + + return Status(); +} + +#define WEBGPU_RANGE_KERNEL(TYPE) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Range, \ + kOnnxDomain, \ + 11, \ + TYPE, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 0) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Range); + +WEBGPU_RANGE_KERNEL(float) +WEBGPU_RANGE_KERNEL(int32_t) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/generator/range.h b/onnxruntime/core/providers/webgpu/generator/range.h new file mode 100644 index 0000000000000..2f5812bb460ad --- /dev/null +++ b/onnxruntime/core/providers/webgpu/generator/range.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +template +class Range : public WebGpuKernel { + public: + explicit Range(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +class RangeProgram : public Program { + public: + RangeProgram() : Program{"Range"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"start", ProgramUniformVariableDataType::Uint32}, + {"delta", ProgramUniformVariableDataType::Uint32}); +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc new file mode 100644 index 0000000000000..6077ef0499069 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -0,0 +1,310 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/webgpu/math/binary_elementwise_ops.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { +Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); + + // check whether can use element-wise mode. + // If either A or B is scalar, or A and B have the same shape, element-wise mode can be used. + // In element-wise mode, no indices calculation is needed. + if (is_lhs_scalar_ || is_rhs_scalar_ || !is_broadcast_) { + // get A data + if (is_lhs_scalar_) { + shader.MainFunctionBody() << "let a = input_a_value_t(" << a.GetByOffset("0") << ".x);\n"; + } else { + shader.MainFunctionBody() << "let a = " << a.GetByOffset("global_idx") << ";\n"; + } + + // get B data + if (is_rhs_scalar_) { + shader.MainFunctionBody() << "let b = input_b_value_t(" << b.GetByOffset("0") << ".x);\n"; + } else { + shader.MainFunctionBody() << "let b = " << b.GetByOffset("global_idx") << ";\n"; + } + } else { + const auto& c_indices = shader.AddIndices("bcast_indices"); + // check whether can use vectorize mode. + // If either last dimension of A or B is divisible by 4, or the shared dimension is divisible by 4, vectorize mode + // can be enabled. + // In vectorize mode, the source data of A and B will be loaded only once to calculate 4 output values. + // Use indices helpers to calculate the offset of A and B. + if (vectorize_) { + const auto& a_indices = shader.AddIndices("a_indices"); + const auto& b_indices = shader.AddIndices("b_indices"); + + shader.MainFunctionBody() << "let outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") << ";\n" + << "let offset_a = " << a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b = " << b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; + // get A data + if (a.NumComponents() == 4) { + shader.MainFunctionBody() << "let a = " << a.GetByOffset("offset_a / 4") << ";\n"; + } else { + shader.MainFunctionBody() << "let a = input_a_value_t(" << a.GetByOffset("offset_a") << ");\n"; + } + + // get B data + if (b.NumComponents() == 4) { + shader.MainFunctionBody() << "let b = " << b.GetByOffset("offset_b / 4") << ";\n"; + } else { + shader.MainFunctionBody() << "let b = input_b_value_t(" << b.GetByOffset("offset_b") << ");\n"; + } + } else { + // In broadcast mode, each element of the vec4 value of A and B will be loaded separately to calculate the output value. + shader.MainFunctionBody() << "var outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") << ";\n" + << "let offset_a0 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b0 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 1") << ";\n" + << "let offset_a1 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b1 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 2") << ";\n" + << "let offset_a2 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b2 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 3") << ";\n" + << "let offset_a3 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b3 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; + + // get A data + shader.MainFunctionBody() << "let a = vec4(" + << a.GetByOffset("offset_a0") << ", " + << a.GetByOffset("offset_a1") << ", " + << a.GetByOffset("offset_a2") << ", " + << a.GetByOffset("offset_a3") << ");\n"; + // get B data + shader.MainFunctionBody() << "let b = vec4(" + << b.GetByOffset("offset_b0") << ", " + << b.GetByOffset("offset_b1") << ", " + << b.GetByOffset("offset_b2") << ", " + << b.GetByOffset("offset_b3") << ");\n"; + } + } + + shader.MainFunctionBody() << c.SetByOffset("global_idx", expression_); + return Status::OK(); +} + +Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { + auto lhs_tensor = context.Input(0); + auto rhs_tensor = context.Input(1); + const auto& lhs_shape = lhs_tensor->Shape(); + const auto& rhs_shape = rhs_tensor->Shape(); + + TensorShape output_shape; + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape)); + auto output_tensor = context.Output(0, output_shape); + int64_t size = output_shape.Size(); + if (size == 0) { + return Status::OK(); + } + + bool is_broadcast = lhs_shape != rhs_shape; + bool is_lhs_scalar = lhs_shape.IsScalar(); + bool is_rhs_scalar = rhs_shape.IsScalar(); + + bool vectorize = is_lhs_scalar || is_rhs_scalar || !is_broadcast; + bool a_last_dim_divisible_by_4 = false; + bool b_last_dim_divisible_by_4 = false; + bool shared_dimension_divisible_by_4 = false; + size_t num_shared_dimension = 0; + if (!vectorize) { + // check whether vectorize can be enabled + a_last_dim_divisible_by_4 = lhs_shape.NumDimensions() > 0 && lhs_shape[lhs_shape.NumDimensions() - 1] % 4 == 0; + b_last_dim_divisible_by_4 = rhs_shape.NumDimensions() > 0 && rhs_shape[rhs_shape.NumDimensions() - 1] % 4 == 0; + if (a_last_dim_divisible_by_4 || b_last_dim_divisible_by_4) { + vectorize = true; + } else { + size_t shared_dimension = 1; + for (size_t i = 1; i < output_shape.NumDimensions(); i++) { + size_t dimA = lhs_shape.NumDimensions() >= i ? lhs_shape[lhs_shape.NumDimensions() - i] : 1; + size_t dimB = rhs_shape.NumDimensions() >= i ? rhs_shape[rhs_shape.NumDimensions() - i] : 1; + if (dimA == dimB) { + shared_dimension *= dimA; + num_shared_dimension++; + } else { + break; + } + } + if (shared_dimension % 4 == 0) { + shared_dimension_divisible_by_4 = true; + vectorize = true; + } + } + } + + uint32_t vec_size = SafeInt((size + 3) / 4); + BinaryElementwiseProgram program{kernel_name_, + expression_, + is_broadcast, + is_lhs_scalar, + is_rhs_scalar, + vectorize}; + program + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(vec_size)}, + }) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}); + + if (is_lhs_scalar || is_rhs_scalar || !is_broadcast) { + // Mode Element-wise + // cache hint: "E{is_a_scalar}{is_b_scalar}" + program + .AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::Type, {is_lhs_scalar ? 1 : vec_size}, 4}, + {rhs_tensor, ProgramTensorMetadataDependency::Type, {is_rhs_scalar ? 1 : vec_size}, 4}}) + .CacheHint("E" + std::to_string(is_lhs_scalar) + std::to_string(is_rhs_scalar)); + } else if (vectorize) { + // reshape the dims to merge the shared dimension if available + bool need_reshape = shared_dimension_divisible_by_4 && num_shared_dimension > 1; + TensorShape reshaped_lhs_shape = need_reshape ? lhs_shape.Slice(0, lhs_shape.NumDimensions() - num_shared_dimension + 1) + : lhs_shape; + TensorShape reshaped_rhs_shape = need_reshape ? rhs_shape.Slice(0, rhs_shape.NumDimensions() - num_shared_dimension + 1) + : rhs_shape; + TensorShape reshaped_output_shape = need_reshape ? output_shape.Slice(0, output_shape.NumDimensions() - num_shared_dimension + 1) + : output_shape; + if (need_reshape) { + reshaped_lhs_shape[reshaped_lhs_shape.NumDimensions() - 1] = lhs_shape.SizeFromDimension(lhs_shape.NumDimensions() - num_shared_dimension); + reshaped_rhs_shape[reshaped_rhs_shape.NumDimensions() - 1] = rhs_shape.SizeFromDimension(rhs_shape.NumDimensions() - num_shared_dimension); + reshaped_output_shape[reshaped_output_shape.NumDimensions() - 1] = output_shape.SizeFromDimension(output_shape.NumDimensions() - num_shared_dimension); + } + + if (shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4) { + program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type, {(lhs_shape.Size() + 3) / 4}, 4}); + } else { + program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type}); + } + if (shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4) { + program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type, {(rhs_shape.Size() + 3) / 4}, 4}); + } else { + program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type}); + } + // Mode Vectorize broadcast + // cache hint: "V{a_rank};{b_rank};{output_rank}" + program + .AddIndices(reshaped_output_shape) + .AddIndices(reshaped_lhs_shape) + .AddIndices(reshaped_rhs_shape) + .CacheHint("V" + absl::StrJoin({reshaped_lhs_shape.NumDimensions(), + reshaped_rhs_shape.NumDimensions(), + reshaped_output_shape.NumDimensions()}, + ";")); + } else { + // Mode Broadcast + // cache hint: "B" + program + .AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {rhs_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddIndices(output_tensor->Shape()) + .CacheHint("B"); + } + + return context.RunProgram(program); +} + +#define WEBGPU_BINARY_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public BinaryElementwise { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : BinaryElementwise{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +#define WEBGPU_BINARY_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_KERNEL_2(OP_TYPE, VERSION, KERNEL_CLASS, TYPE, TYPE1) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", TYPE) \ + .TypeConstraint("T1", TYPE1), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_VERSIONED_KERNEL_2(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE, TYPE1) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", TYPE) \ + .TypeConstraint("T1", TYPE1), \ + KERNEL_CLASS); + +WEBGPU_BINARY_IMPL(Add, "a + b") +WEBGPU_BINARY_VERSIONED_KERNEL(Add, 7, 12, Add, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Add, 13, 13, Add, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Add, 14, Add, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Div, "a / b") +WEBGPU_BINARY_VERSIONED_KERNEL(Div, 7, 12, Div, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Div, 13, 13, Div, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Div, 14, Div, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Mul, "a * b") +WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 7, 12, Mul, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 13, 13, Mul, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Mul, 14, Mul, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Sub, "a - b") +WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 7, 12, Sub, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 13, 13, Sub, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Sub, 14, Sub, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Pow, "output_value_t(pow(vec4(a), vec4(b)))") +WEBGPU_BINARY_VERSIONED_KERNEL(Pow, 7, 11, Pow, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 12, 12, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 13, 14, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL_2(Pow, 15, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Equal, "vec4(a == b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 7, 10, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 11, 12, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 13, 18, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Equal, 19, Equal, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Greater, "vec4(a > b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Greater, 7, 8, Greater, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Greater, 9, 12, Greater, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Greater, 13, Greater, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Less, "vec4(a < b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Less, 7, 8, Less, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Less, 9, 12, Less, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Less, 13, Less, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(GreaterOrEqual, "vec4(a >= b)") +WEBGPU_BINARY_VERSIONED_KERNEL(GreaterOrEqual, 12, 15, GreaterOrEqual, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(GreaterOrEqual, 16, GreaterOrEqual, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(LessOrEqual, "vec4(a <= b)") +WEBGPU_BINARY_VERSIONED_KERNEL(LessOrEqual, 12, 15, LessOrEqual, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(LessOrEqual, 16, LessOrEqual, WebGpuSupportedNumberTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h new file mode 100644 index 0000000000000..84cbcdf3244d8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class BinaryElementwiseProgram final : public Program { + public: + BinaryElementwiseProgram(const std::string& kernel_name, + const std::string& expression, + const bool is_broadcast, + const bool is_lhs_scalar, + const bool is_rhs_scalar, + const bool vectorize) : Program{kernel_name}, + expression_{expression}, + is_broadcast_{is_broadcast}, + is_lhs_scalar_{is_lhs_scalar}, + is_rhs_scalar_{is_rhs_scalar}, + vectorize_{vectorize} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + std::string expression_; + bool is_broadcast_; + bool is_lhs_scalar_; + bool is_rhs_scalar_; + bool vectorize_; +}; + +class BinaryElementwise : public WebGpuKernel { + public: + BinaryElementwise(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression) : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression} {} + + protected: + Status ComputeInternal(ComputeContext& context) const final; + + private: + std::string kernel_name_; + std::string expression_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc new file mode 100644 index 0000000000000..f6d6b18a3d365 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -0,0 +1,308 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/providers/webgpu/math/unary_elementwise_ops.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { +Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | additional_usage_); + const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform); + shader.AdditionalImplementation() << additional_impl_; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " let a = " << input.GetByOffset("global_idx") << ";\n " + << output.SetByOffset("global_idx", expression_); + + return Status::OK(); +} + +Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + auto* output_tensor = context.Output(0, input_tensor->Shape()); + int64_t size = input_tensor->Shape().Size(); + if (size == 0) { + return Status::OK(); + } + SafeInt vec_size = (size + 3) / 4; + UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(vec_size)}, + }); + if (!cache_hint.empty()) { + program.CacheHint(cache_hint); + } + ORT_RETURN_IF_ERROR(ConfigureProgram(context, program)); + return context.RunProgram(program); +} + +#define WEBGPU_ELEMENTWISE_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public UnaryElementwise { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : UnaryElementwise{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + OP_TYPE_AND_CLASS_NAME); + +#define WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION_FROM, VERSION_TO, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + OP_TYPE_AND_CLASS_NAME); + +#define WEBGPU_ELEMENTWISE_BOOLEAN_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + OP_TYPE_AND_CLASS_NAME); + +// +// math +// + +WEBGPU_ELEMENTWISE_IMPL(Abs, "abs(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Abs, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Abs, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Neg, "-a") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Neg, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Neg, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Floor, "floor(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Floor, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Ceil, "ceil(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Ceil, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Reciprocal, "1.0/a") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Reciprocal, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sqrt, "sqrt(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Sqrt, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Exp, "exp(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderUsage::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Log, "log(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Log, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sigmoid, "1.0 / (1.0 + exp(-a))") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) + +class HardSigmoid final : public UnaryElementwise { + public: + HardSigmoid(const OpKernelInfo& info) + : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderUsage::UseElementTypeAlias} { + // attr[0] is alpha, attr[1] is beta + info.GetAttrOrDefault("alpha", attr, 0.2f); + info.GetAttrOrDefault("beta", attr + 1, 0.5f); + } + + Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { + program.AddUniformVariables({gsl::make_span(attr, 2)}); + return Status::OK(); + } + + protected: + float attr[2]; +}; + +WEBGPU_ELEMENTWISE_KERNEL(HardSigmoid, 6, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sin, "sin(a)") +WEBGPU_ELEMENTWISE_KERNEL(Sin, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Cos, "cos(a)") +WEBGPU_ELEMENTWISE_KERNEL(Cos, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Tan, "tan(a)") +WEBGPU_ELEMENTWISE_KERNEL(Tan, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Asin, "asin(a)") +WEBGPU_ELEMENTWISE_KERNEL(Asin, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Acos, "acos(a)") +WEBGPU_ELEMENTWISE_KERNEL(Acos, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Atan, "atan(a)") +WEBGPU_ELEMENTWISE_KERNEL(Atan, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sinh, "sinh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderUsage::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Asinh, "asinh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Asinh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Acosh, "acosh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Acosh, 9, WebGpuSupportedFloatTypes()) + +#if __APPLE__ +// Metal returns 0 for values >= 1.0. +// Need custom impl to return +inf for 1.0 (by dividing 1 by 0), and NaN for > 1.0 (by dividing 0 by 0) +WEBGPU_ELEMENTWISE_IMPL(Atanh, + "select(" + " select(x_value_t(1.0), x_value_t(0.0), a > x_value_t(1.0)) / x_value_t(0.0)," + " atanh(a)," + " a < x_value_t(1.0))", + "", + ShaderUsage::UseValueTypeAlias) +#else +WEBGPU_ELEMENTWISE_IMPL(Atanh, "atanh(a)") +#endif +WEBGPU_ELEMENTWISE_KERNEL(Atanh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Not, "!a") +WEBGPU_ELEMENTWISE_BOOLEAN_KERNEL(Not, 1) + +// No longer support Clip < opset 11 (where min and max are attributes) +// +// Use template class for "Clip" because the implementation is significantly different between float16 and float32 +template +class Clip final : public UnaryElementwise { + public: + Clip(const OpKernelInfo& info) + : UnaryElementwise{info, + "Clip", + std::is_same_v ? ClipF16Impl : ClipImpl, + "", ShaderUsage::UseElementTypeAlias} {} + + Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override { + const auto* clip_min_tensor = context.Input(1); + const auto* clip_max_tensor = context.Input(2); + + const T attr[] = {clip_min_tensor ? clip_min_tensor->Data()[0] + : std::numeric_limits::lowest(), + clip_max_tensor ? clip_max_tensor->Data()[0] + : std::numeric_limits::max()}; + if constexpr (std::is_same_v) { + // F16: stores span as a single float + float encoded_value = *reinterpret_cast(attr); + program.AddUniformVariable({encoded_value}); + } else { + static_assert(sizeof(T) == sizeof(float), "T must be f32, i32 or u32"); + // stores span as-is + program.AddUniformVariable({gsl::make_span(attr, 2)}); + } + return Status::OK(); + } + + // uniforms.attr is a f32 value. It is encoded as a float for 2 f16 values. + // bitcast>(uniforms.attr)[0] is clip_min, bitcast>(uniforms.attr)[1] is clip_max + constexpr static const char ClipF16Impl[] = "clamp(a, vec4(bitcast>(uniforms.attr)[0]), vec4(bitcast>(uniforms.attr)[1]))"; + + // the size of element of uniforms.attr should be the same as x_element_t. use bitcast to convert between them + // uniforms.attr[0] is clip_min, uniforms.attr[1] is clip_max + constexpr static const char ClipImpl[] = "clamp(a, vec4(bitcast(uniforms.attr[0])), vec4(bitcast(uniforms.attr[1])))"; +}; +#define WEBGPU_CLIP_KERNEL(TYPE) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(Clip, kOnnxDomain, 13, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip); +WEBGPU_CLIP_KERNEL(float) +WEBGPU_CLIP_KERNEL(MLFloat16) + +// +// activation +// + +class LinearUnit : public UnaryElementwise { + public: + LinearUnit(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression, + const std::string& additional_impl, + float default_alpha) + : UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} { + info.GetAttrOrDefault("alpha", &alpha_, default_alpha); + } + + Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { + program.AddUniformVariables({alpha_}); + return Status::OK(); + } + + protected: + float alpha_; +}; + +#define WEBGPU_LU_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public LinearUnit { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : LinearUnit{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0) +WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes()) + +class Gelu : public UnaryElementwise { + public: + Gelu(const OpKernelInfo& info) + : UnaryElementwise{info, + "Gelu", + info.GetAttrOrDefault("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr, + info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhImpl : ErfImpl, + ShaderUsage::UseValueTypeAlias} { + cache_hint = info.GetAttrOrDefault("approximate", "none"); + } +}; + +WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderUsage::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes()) + +WEBGPU_LU_IMPL(LeakyRelu, "select(x_element_t(uniforms.attr) * a, a, a >= vec4(0))", "", 0.01f) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(LeakyRelu, 16, WebGpuSupportedFloatTypes()) + +WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4(0), a, a > vec4(uniforms.attr))", "", 1.0f) +WEBGPU_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, WebGpuSupportedFloatTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h new file mode 100644 index 0000000000000..70fa81d21f95d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class UnaryElementwiseProgram final : public Program { + public: + UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderUsage usage) + : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl}, additional_usage_{usage} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"vec_size", ProgramUniformVariableDataType::Uint32}, // output size + {"attr", ProgramUniformVariableDataType::Float32}); // float type attribute(s) + // TODO: add u32/i32 attribute(s) if needed + + private: + std::string_view expression_; + std::string_view additional_impl_; + ShaderUsage additional_usage_; +}; + +// TODO: after upgrading to C++20, use consteval to make a compile-time constructor so that it will be safe to switch +// the std::string to std::string_view. This will avoid the cost of copying the string. + +class UnaryElementwise : public WebGpuKernel { + public: + UnaryElementwise(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression, + const std::string& additional_impl = "", + ShaderUsage usage = ShaderUsage::None) : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression}, + additional_impl_{additional_impl}, + additional_usage_{usage} {} + + protected: + std::string cache_hint; + + Status ComputeInternal(ComputeContext& context) const final; + virtual Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const { + program.AddUniformVariables({{}}); // empty for attribute(s) + return Status::OK(); + } + + private: + std::string kernel_name_; + std::string expression_; + std::string additional_impl_; + ShaderUsage additional_usage_; +}; + +constexpr const char ErfImpl[] = R"( +const r0 = 0.3275911; +const r1 = 0.254829592; +const r2 = -0.284496736; +const r3 = 1.421413741; +const r4 = -1.453152027; +const r5 = 1.061405429; + +fn erf_v(v: x_value_t) -> x_value_t { + let absv = abs(v); + let x = 1.0 / (1.0 + r0 * absv); + return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); +} +)"; + +constexpr const char HardSigmoidImpl[] = R"( +fn hard_sigmoid_v(v: vec4) -> vec4 { + let alpha = x_element_t(uniforms.attr[0]); + let beta_v = vec4(uniforms.attr[1]); + return max(vec4(0.0), + min(vec4(1.0), alpha * v + beta_v)); +} +)"; + +// built-in function tanh() does not work with large input (f32 88.7 or f16 11.09) +// https://github.com/gpuweb/gpuweb/issues/4458 +constexpr const char TanhImpl[] = R"( +fn tanh_v(a: x_value_t) -> x_value_t { + let expr = exp(-2 * abs(a)); + return sign(a) * (1 - expr) / (1 + expr); +} +)"; + +constexpr const char EluImpl[] = R"( +fn elu(a: x_element_t) -> x_element_t { + let alpha = x_element_t(uniforms.attr); + return select((exp(a) - 1.0) * alpha, a, a >= 0.0); +} + +fn elu_v(v: vec4) -> vec4 { + return vec4(elu(v.x), elu(v.y), elu(v.z), elu(v.w)); +} +)"; + +// default GELU expression, depending on ErfImpl +constexpr const char GeluExpr[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))"; + +// fast GELU expression, depending on TanhImpl +constexpr const char FastGeluExpr[] = "a * (0.5 + 0.5 * tanh_v(a * (0.035677408136300125 * a * a + 0.7978845608028654)))"; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc new file mode 100644 index 0000000000000..c2f7023526c77 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -0,0 +1,155 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/layer_norm.h" + +namespace onnxruntime { +namespace webgpu { + +static int GetMaxComponents(int64_t size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { + int64_t rank = static_cast(tensor_rank); + if (axis < -rank && axis >= rank) { + ORT_THROW("invalid axis: ", axis); + } + return SafeInt(axis < 0 ? axis + rank : axis); +} + +static std::string SumVector(std::string x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("scale", ShaderUsage::UseUniform); + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + + int components = x.NumComponents(); + std::string bias = (has_bias_) ? " + bias[j]" : ""; + std::string simpl1 = (simplified_) ? "" : " - mean * mean"; + std::string simpl2 = (simplified_) ? "" : " - mean"; + + shader.AdditionalImplementation() << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") + << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? "vec2" : "f32")) << ";\n"; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count") + << "let offset = global_idx * uniforms.norm_size_vectorized;\n" + << "var mean_vector = f32_val_t(0);\n" + << "var mean_square_vector = f32_val_t(0);\n" + << "for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {\n" + << " let value = f32_val_t(x[h + offset]);\n" + << " mean_vector += value;\n" + << " mean_square_vector += value * value;\n" + << "}\n" + << "let mean = " << SumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("mean_square_vector", components) << " / f32(uniforms.norm_size)" << simpl1 << " + uniforms.epsilon);\n" + << "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n" + << " let f32input = f32_val_t(x[j + offset]);\n" + << " let f32scale = f32_val_t(scale[j]);\n" + << " output[j + offset] = x_value_t((f32input" << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n" + << "}\n"; + + return Status::OK(); +} + +template +Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* x = context.Input(0); + const auto* scale = context.Input(1); + const auto* bias = context.Input(2); + + const auto x_shape = x->Shape(); + + auto* output = context.Output(0, x_shape); + + size_t data_size = x_shape.Size(); + if (data_size == 0) { + return Status::OK(); + } + + const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + + const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); + const uint32_t norm_count = SafeInt(x_shape.SizeToDimension(axis)); + const int64_t norm_size = x_shape.SizeFromDimension(axis); + const int components = GetMaxComponents(norm_size); + const uint32_t norm_size_vectorized = SafeInt((norm_size + components - 1) / components); + + const auto scale_size = scale->Shape().Size(); + const auto bias_size = (bias) ? bias->Shape().Size() : 0; + if (scale_size != norm_size || (bias && bias_size != norm_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Size of X.shape()[axis:] == ", norm_size, + ". Size of scale and bias (if provided) must match this. Got scale size of ", + scale_size, " and bias size of ", bias_size); + } + + LayerNormProgram program{bias != nullptr, is_fp16, simplified}; + + program + .CacheHint(simplified) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{scale, ProgramTensorMetadataDependency::Type, components}}) + .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) + .SetDispatchGroupSize((norm_count + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(norm_count)}, + }) + .AddUniformVariables({ + {static_cast(norm_size)}, + }) + .AddUniformVariables({ + {static_cast(norm_size_vectorized)}, + }) + .AddUniformVariables({ + {static_cast(epsilon_)}, + }); + + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::Type, components}); + } + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + LayerNormalization, + kOnnxDomain, + 17, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + LayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + LayerNorm); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.h b/onnxruntime/core/providers/webgpu/nn/layer_norm.h new file mode 100644 index 0000000000000..17a9edbf4dd01 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class LayerNormProgram final : public Program { + public: + LayerNormProgram(bool has_bias, + bool is_fp16, + bool simplified) : Program{"LayerNorm"}, + has_bias_{has_bias}, + is_fp16_{is_fp16}, + simplified_{simplified} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"norm_count", ProgramUniformVariableDataType::Uint32}, + {"norm_size", ProgramUniformVariableDataType::Uint32}, + {"norm_size_vectorized", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + bool has_bias_; + bool is_fp16_; + bool simplified_; +}; + +template +class LayerNorm final : public WebGpuKernel { + public: + LayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + info.GetAttrOrDefault("axis", &axis_, -1); + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05f); + info.GetAttrOrDefault("stash_type", &stash_type_, 1); + } + + Status ComputeInternal(ComputeContext& context) const override; + + protected: + std::string cache_hint; + + private: + int64_t axis_; + float epsilon_; + int64_t stash_type_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc new file mode 100644 index 0000000000000..d1d4c242c4697 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -0,0 +1,347 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/session/onnxruntime_c_api.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +ProgramUniformVariableValue::ProgramUniformVariableValue() + : length{0}, data_type{} {} // representing an empty uniform variable + +ProgramUniformVariableValue::ProgramUniformVariableValue(float value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float32, &value, sizeof(float)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(uint32_t value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Uint32, &value, sizeof(uint32_t)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(int32_t value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Int32, &value, sizeof(int32_t)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(MLFloat16 value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float16, &value, sizeof(MLFloat16)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float32, values.data(), sizeof(float), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Uint32, values.data(), sizeof(uint32_t), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Int32, values.data(), sizeof(int32_t), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float16, values.data(), sizeof(MLFloat16), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(ProgramUniformVariableDataType data_type, + const void* ptr, + size_t element_byte_size, + size_t length /* = 1 */) + : length{length}, data_type{data_type} { + ORT_ENFORCE(length > 0, "number of element of uniform variable must be greater than 0"); + + data.resize(length * element_byte_size); + memcpy(data.data(), ptr, length * element_byte_size); +} + +std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType type) { + os << ProgramUniformVariableDataTypeName[std::underlying_type::type(type)]; + return os; +} + +std::ostream& operator<<(std::ostream& os, ProgramConstantDataType type) { + os << ProgramConstantDataTypeName[std::underlying_type::type(type)]; + return os; +} + +std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency dep) { + bool first = true; + if ((dep & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { + os << "Type"; + first = false; + } + if ((dep & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { + if (!first) os << "|"; + os << "Rank"; + first = false; + } + if ((dep & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { + if (!first) os << "|"; + os << "Shape"; + first = false; + } + if (first) { + os << "None"; + } + + return os; +} + +#ifndef NDEBUG +constexpr std::string_view ProgramVariableDataTypeName[] = { + "f32", // Float32 + "f32x2", // Float32x2 + "f32x4", // Float32x4 + "f16", // Float16 + "f16x2", // Float16x2 + "f16x4", // Float16x4 + "i32", // Int32 + "i32x2", // Int32x2 + "i32x4", // Int32x4 + "u32", // Uint32 + "u32x2", // Uint32x2 + "u32x4", // Uint32x4 + "i64", // Int64 + "u64", // Uint64 + "boolx4", // Boolx4 + "u8x4", // Uint8x4 + "u8x8", // Uint8x8 + "u8x16", // Uint8x16 +}; +std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) { + os << ProgramVariableDataTypeName[std::underlying_type::type(type)]; + return os; +} +#endif + +int NumberOfComponents(ProgramVariableDataType type) { + switch (type) { + case ProgramVariableDataType::Float32: + case ProgramVariableDataType::Int32: + case ProgramVariableDataType::Uint32: + case ProgramVariableDataType::Int64: + case ProgramVariableDataType::Uint64: + case ProgramVariableDataType::Float16: + return 1; + case ProgramVariableDataType::Float32x2: + case ProgramVariableDataType::Int32x2: + case ProgramVariableDataType::Uint32x2: + case ProgramVariableDataType::Float16x2: + return 2; + case ProgramVariableDataType::Float32x4: + case ProgramVariableDataType::Int32x4: + case ProgramVariableDataType::Uint32x4: + case ProgramVariableDataType::Float16x4: + case ProgramVariableDataType::Boolx4: + case ProgramVariableDataType::Uint8x4: + return 4; + case ProgramVariableDataType::Uint8x8: + return 8; + case ProgramVariableDataType::Uint8x16: + return 16; + default: + return -1; + } +} + +ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component /* = 1 */) { + if (component == 1) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Float32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Float16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Int32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Uint32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return ProgramVariableDataType::Int64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return ProgramVariableDataType::Uint64; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 2) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Float32x2; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Float16x2; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Int32x2; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Uint32x2; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 4) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Float32x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Float16x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Int32x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Uint32x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return ProgramVariableDataType::Boolx4; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 8) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x8; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 16) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x16; + default: + return ProgramVariableDataType::InvalidType; + } + } else { + return ProgramVariableDataType::InvalidType; + } +} + +namespace { +TensorShape GetReducedShape(const TensorShape& shape, int component /* > 1 */) { + ORT_ENFORCE(shape.NumDimensions() > 0 && shape.GetDims()[shape.NumDimensions() - 1] % component == 0, + "Cannot reduce shape ", shape.ToString(), " by component=", component); + TensorShape reduced_shape = shape; + reduced_shape[reduced_shape.NumDimensions() - 1] /= component; + return reduced_shape; +} +} // namespace + +ProgramInput::ProgramInput(const Tensor* tensor) : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {} + +ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{component > 1}, + override_shape{} { + if (use_override_shape) { + override_shape = GetReducedShape(tensor->Shape(), component); + } +} + +ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{override_shape} {} + +ProgramOutput::ProgramOutput(Tensor* tensor) + : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {} + +ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{component > 1}, + override_shape{} { + if (use_override_shape) { + override_shape = GetReducedShape(tensor->Shape(), component); + } +} + +ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{override_shape} {} + +ProgramBase::ProgramBase(std::string_view name, ProgramMetadata&& metadata) + : name_{name}, + metadata_{metadata}, + dispatch_group_size_x_{0}, + dispatch_group_size_y_{0}, + dispatch_group_size_z_{0}, + workgroup_size_x_{0}, + workgroup_size_y_{0}, + workgroup_size_z_{0} { +} + +ProgramBase& ProgramBase::AddInput(ProgramInput&& input) { + inputs_.emplace_back(input); + return *this; +} + +ProgramBase& ProgramBase::AddInputs(std::initializer_list inputs) { + inputs_.insert(inputs_.end(), inputs.begin(), inputs.end()); + return *this; +} + +ProgramBase& ProgramBase::AddOutput(ProgramOutput&& output) { + outputs_.emplace_back(output); + return *this; +} + +ProgramBase& ProgramBase::AddOutputs(std::initializer_list outputs) { + outputs_.insert(outputs_.end(), outputs.begin(), outputs.end()); + return *this; +} + +ProgramBase& ProgramBase::AddIndices(const TensorShape& shape) { + indices_.emplace_back(shape); + return *this; +} + +ProgramBase& ProgramBase::AddIndices(TensorShape&& shape) { + indices_.emplace_back(shape); + return *this; +} + +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x) { + return SetDispatchGroupSize(x, 1, 1); +} + +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x, uint32_t y) { + return SetDispatchGroupSize(x, y, 1); +} + +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t z) { + dispatch_group_size_x_ = x; + dispatch_group_size_y_ = y; + dispatch_group_size_z_ = z; + return *this; +} + +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x) { + return SetWorkgroupSize(x, 1, 1); +} + +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x, uint32_t y) { + return SetWorkgroupSize(x, y, 1); +} + +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { + workgroup_size_x_ = x; + workgroup_size_y_ = y; + workgroup_size_z_ = z; + return *this; +} + +ProgramBase& ProgramBase::AddUniformVariable(ProgramUniformVariableValue&& variable) { + variables_.emplace_back(variable); + return *this; +} + +ProgramBase& ProgramBase::AddUniformVariables(std::initializer_list variables) { + variables_.insert(variables_.end(), variables.begin(), variables.end()); + return *this; +} + +ProgramBase& ProgramBase::SetOverridableConstants(std::initializer_list overridable_constants) { + overridable_constants_.insert(overridable_constants_.end(), overridable_constants.begin(), overridable_constants.end()); + return *this; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h new file mode 100644 index 0000000000000..1562ec158b40a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program.h @@ -0,0 +1,605 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include + +#include "core/common/common.h" +#include "core/common/safeint.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace webgpu { +class ShaderHelper; +class ComputeContext; +class WebGpuContext; + +// data type of uniform variable +enum class ProgramUniformVariableDataType { + Float32, + Float16, + Uint32, + Int32, +}; +std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType); + +constexpr size_t ProgramUniformVariableDataTypeSize[] = {sizeof(float), sizeof(uint16_t), sizeof(uint32_t), sizeof(int32_t)}; + +constexpr std::string_view ProgramUniformVariableDataTypeName[] = {"f32", "f16", "u32", "i32"}; + +// represents a runtime value of a uniform variable +struct ProgramUniformVariableValue { + ProgramUniformVariableValue(); // representing an empty uniform variable + ProgramUniformVariableValue(float value); + ProgramUniformVariableValue(uint32_t value); + ProgramUniformVariableValue(int32_t value); + ProgramUniformVariableValue(MLFloat16 value); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + + size_t length; + ProgramUniformVariableDataType data_type; + std::vector data; + + private: + ProgramUniformVariableValue(ProgramUniformVariableDataType data_type, const void* ptr, size_t element_byte_size, size_t length = 1); +}; + +// represents a uniform variable definition +struct ProgramUniformVariableDefinition { + constexpr ProgramUniformVariableDefinition(std::string_view name, ProgramUniformVariableDataType data_type) + : name{name}, data_type{data_type} {} + + std::string_view name; + ProgramUniformVariableDataType data_type; +}; + +// data type of constant +enum class ProgramConstantDataType { + Float32, + Float16, + Uint32, + Int32, + Bool +}; +std::ostream& operator<<(std::ostream& os, ProgramConstantDataType); + +constexpr std::string_view ProgramConstantDataTypeName[] = {"f32", "f16", "u32", "i32", "bool"}; + +// represents a constant in a program +struct ProgramConstant { + constexpr ProgramConstant(std::string_view name, float value) : name{name}, type{ProgramConstantDataType::Float32}, f32{value} {} + constexpr ProgramConstant(std::string_view name, uint32_t value) : name{name}, type{ProgramConstantDataType::Uint32}, u32{value} {} + constexpr ProgramConstant(std::string_view name, int32_t value) : name{name}, type{ProgramConstantDataType::Int32}, i32{value} {} + constexpr ProgramConstant(std::string_view name, MLFloat16 value) : name{name}, type{ProgramConstantDataType::Float16}, f16{value} {} + constexpr ProgramConstant(std::string_view name, bool value) : name{name}, type{ProgramConstantDataType::Bool}, boolean{value} {} + + std::string_view name; + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; +}; + +// represents a runtime value of an overridable constant +struct ProgramOverridableConstantValue { + constexpr ProgramOverridableConstantValue() : type{}, u32{}, has_value{false} {} // representing not overriding + constexpr ProgramOverridableConstantValue(float value) : type{ProgramConstantDataType::Float32}, f32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(uint32_t value) : type{ProgramConstantDataType::Uint32}, u32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(int32_t value) : type{ProgramConstantDataType::Int32}, i32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(MLFloat16 value) : type{ProgramConstantDataType::Float16}, f16{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(bool value) : type{ProgramConstantDataType::Bool}, boolean{value}, has_value{true} {} + + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; + bool has_value; +}; + +// represents an overridable constant definition. may or may not have a default value. +struct ProgramOverridableConstantDefinition { + constexpr ProgramOverridableConstantDefinition(std::string_view name, ProgramConstantDataType type) + : name{name}, type{type}, u32{}, has_default_value{false} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, float value) + : name{name}, type{ProgramConstantDataType::Float32}, f32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, uint32_t value) + : name{name}, type{ProgramConstantDataType::Uint32}, u32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, int32_t value) + : name{name}, type{ProgramConstantDataType::Int32}, i32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, MLFloat16 value) + : name{name}, type{ProgramConstantDataType::Float16}, f16{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, bool value) + : name{name}, type{ProgramConstantDataType::Bool}, boolean{value}, has_default_value{true} {} + + std::string_view name; + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; + bool has_default_value; +}; + +// represents whether the program shader depends on the type, rank, or shape of an input/output tensor +enum class ProgramTensorMetadataDependency : int { + None = 0, + Type = 1, + Rank = 2, + Shape = 4, + TypeAndRank = Type | Rank, + TypeAndShape = Type | Shape, +}; +std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency); + +inline ProgramTensorMetadataDependency operator|(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency)((int&)a | (int&)b); +} +inline ProgramTensorMetadataDependency operator&(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency)((int&)a & (int&)b); +} +inline ProgramTensorMetadataDependency& operator|=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency&)((int&)a |= (int&)b); +} +inline ProgramTensorMetadataDependency& operator&=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency&)((int&)a &= (int&)b); +} + +constexpr SafeInt WORKGROUP_SIZE = 64; + +// data type of variable +// +// this is not a full list of all possible data types in shader programs. +// it only includes what are used in WebGPU EP. +enum class ProgramVariableDataType { + InvalidType = -1, + Float32, + Float32x2, + Float32x4, + Float16, + Float16x2, + Float16x4, + Int32, + Int32x2, + Int32x4, + Uint32, + Uint32x2, + Uint32x4, + Int64, + Uint64, + Boolx4, + Uint8x4, + Uint8x8, + Uint8x16 +}; +#ifndef NDEBUG +std::ostream& operator<<(std::ostream& os, ProgramVariableDataType); +#endif + +int NumberOfComponents(ProgramVariableDataType type); + +ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1); + +struct ProgramInput { + ProgramInput(const Tensor* tensor); + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1); + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component); + + const Tensor* tensor; + ProgramTensorMetadataDependency dependency; + ProgramVariableDataType var_type; + bool use_override_shape; + TensorShape override_shape; +}; + +struct ProgramOutput { + ProgramOutput(Tensor* tensor); + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1); + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component); + + Tensor* tensor; + ProgramTensorMetadataDependency dependency; + ProgramVariableDataType var_type; + bool use_override_shape; + TensorShape override_shape; +}; + +enum class ValidationMode { + Disabled = 0, + WGPUOnly, + Basic, + Full +}; + +namespace details { +class ProgramWrapper; +} + +struct ProgramMetadata { + gsl::span constants; + gsl::span overridable_constants; + gsl::span uniform_variables; +}; + +class ProgramBase { + public: + // + // chain-style methods for setting properties + // + + // set the cache hint for the program + template + ProgramBase& CacheHint(T&&... hints) { + cache_hint_ = absl::StrJoin(std::forward_as_tuple(std::forward(hints)...), "|"); + return *this; + } + + // add a program input + ProgramBase& AddInput(ProgramInput&& input); + // add multiple program inputs + ProgramBase& AddInputs(std::initializer_list inputs); + // add a program output + ProgramBase& AddOutput(ProgramOutput&& output); + // add multiple program outputs + ProgramBase& AddOutputs(std::initializer_list outputs); + // add a program variable for indices + ProgramBase& AddIndices(const TensorShape& shape); + // add a program variable for indices + ProgramBase& AddIndices(TensorShape&& shape); + + // set the size of dispatch groups. Y and Z are 1 if not specified. + ProgramBase& SetDispatchGroupSize(uint32_t x); + // set the size of dispatch groups. Z is 1 if not specified. + ProgramBase& SetDispatchGroupSize(uint32_t x, uint32_t y); + // set the size of dispatch groups. + ProgramBase& SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t z); + + // set the size of a workgroup grid. Y and Z are 1 if not specified. + ProgramBase& SetWorkgroupSize(uint32_t x); + // set the size of a workgroup grid. Z is 1 if not specified. + ProgramBase& SetWorkgroupSize(uint32_t x, uint32_t y); + // set the size of a workgroup grid. + ProgramBase& SetWorkgroupSize(uint32_t x, uint32_t y, uint32_t z); + + // add a uniform variable. + // + // the specified uniform variable should match the uniform definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. + ProgramBase& AddUniformVariable(ProgramUniformVariableValue&& variable); + // add multiple uniform variables. + // + // the specified uniform variables should match the uniform definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. + ProgramBase& AddUniformVariables(std::initializer_list variables); + + // set the overridable constants + // + // the specified overridable constants should match the overridable constant definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS. + ProgramBase& SetOverridableConstants(std::initializer_list overridable_constants); + + // + // shader code generation + // + + virtual Status GenerateShaderCode(ShaderHelper& shader) const = 0; + + // + // Properties Getters + // + + inline const std::string& Name() const { return name_; } + inline const ProgramMetadata& Metadata() const { return metadata_; } + inline const std::string& CacheHint() const { return cache_hint_; } + inline const std::vector& Inputs() const { return inputs_; } + inline const std::vector& Outputs() const { return outputs_; } + inline const std::vector& Indices() const { return indices_; } + inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; } + inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; } + inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; } + inline uint32_t WorkgroupSizeX() const { return workgroup_size_x_; } + inline uint32_t WorkgroupSizeY() const { return workgroup_size_y_; } + inline uint32_t WorkgroupSizeZ() const { return workgroup_size_z_; } + inline const std::vector& UniformVariables() const { return variables_; } + inline const std::vector& OverridableConstants() const { return overridable_constants_; } + + protected: + virtual ~ProgramBase() = default; + + private: + // Make the constructor private to prevent direct instantiation or inheritance from this class + // Use the Program template class as base class to create a new program class + explicit ProgramBase(std::string_view name, ProgramMetadata&& metadata); + + std::string name_; + ProgramMetadata metadata_; + + std::string cache_hint_; + std::vector inputs_; + std::vector outputs_; + std::vector indices_; + + uint32_t dispatch_group_size_x_; + uint32_t dispatch_group_size_y_; + uint32_t dispatch_group_size_z_; + + uint32_t workgroup_size_x_; + uint32_t workgroup_size_y_; + uint32_t workgroup_size_z_; + + std::vector variables_; + std::vector overridable_constants_; + + friend class details::ProgramWrapper; +}; + +namespace details { +// class ProgramWrapper is for accessing private constructor of ProgramBase. +// only ProgramWrapper can access the constructor of ProgramBase because ProgramWrapper is the only friend class of +// ProgramBase. This design is used to prevent direct instantiation or inheritance from ProgramBase. +class ProgramWrapper : public ProgramBase { + protected: + template + ProgramWrapper(Args&&... args) : ProgramBase{std::forward(args)...} {} +}; + +#if defined(ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK) +#error "macro ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK is already defined" +#endif + +#define ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(identifier, element_type) \ + private: \ + template \ + static auto test_has_##identifier(int)->decltype(U::identifier, std::true_type{}); /* checks if member exists */ \ + template \ + static auto test_has_##identifier(...)->std::false_type; \ + \ + template ::value && /* - is a const std::array */ \ + std::is_const_v && /* - has "const" modifier */ \ + !std::is_member_pointer_v>> /* - is static */ \ + static auto test_has_##identifier##_with_correct_type(int)->std::true_type; \ + template \ + static auto test_has_##identifier##_with_correct_type(...)->std::false_type; \ + \ + public: \ + static constexpr bool has_##identifier = decltype(test_has_##identifier(0))::value; \ + static constexpr bool has_##identifier##_with_correct_type = decltype(test_has_##identifier##_with_correct_type(0))::value + +// the following template class checks whether the type is a const std::array +template +struct is_const_std_array : std::false_type {}; +template +struct is_const_std_array> : std::true_type {}; + +// the following template class checks whether certain static members exist in the derived class (SFINAE) +template +class DerivedProgramClassTypeCheck { + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(constants, ProgramConstant); + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(overridable_constants, ProgramOverridableConstantDefinition); + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(uniform_variables, ProgramUniformVariableDefinition); +}; + +// compile-time tests for the type check +// +// TODO: move this to test folder +namespace test { + +template +class TestTypeCheck { + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(a, int); +}; + +struct TestClass_Empty {}; +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_0 { + int b; +}; +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_1 { + int a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_2 { + const int a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_0 { + const int a[2]; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_1 { + static constexpr int a[] = {0}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_2 { + static int a[]; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_3 { + static const int a[]; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_0 { + std::array a = {1}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_1 { + static constexpr std::array a = {1, 2}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_2 { + static const std::array a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_3 { + static constexpr const std::array a = {1, 2, 3, 4}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_4 { + static std::array a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +} // namespace test + +#undef ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK + +} // namespace details + +template +class Program : public details::ProgramWrapper { + public: + template + Program(Args&&... args) : details::ProgramWrapper{std::forward(args)..., GetMetadata()} {} + + static ProgramMetadata GetMetadata() { + ProgramMetadata metadata; + if constexpr (details::DerivedProgramClassTypeCheck::has_constants) { + constexpr const ProgramConstant* ptr = T::constants.data(); + constexpr size_t len = T::constants.size(); + + static_assert(details::DerivedProgramClassTypeCheck::has_constants_with_correct_type, + "Derived class of \"Program\" has member \"constants\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_CONSTANTS() to declare constants."); + + metadata.constants = {ptr, len}; + } else { + metadata.constants = {}; + } + + if constexpr (details::DerivedProgramClassTypeCheck::has_overridable_constants) { + constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants.data(); + constexpr size_t len = T::overridable_constants.size(); + + static_assert(details::DerivedProgramClassTypeCheck::has_overridable_constants_with_correct_type, + "Derived class of \"Program\" has member \"overridable_constants\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS() to declare overridable constants."); + + metadata.overridable_constants = {ptr, len}; + } else { + metadata.overridable_constants = {}; + } + + if constexpr (details::DerivedProgramClassTypeCheck::has_uniform_variables) { + constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables.data(); + constexpr size_t len = T::uniform_variables.size(); + + static_assert(details::DerivedProgramClassTypeCheck::has_uniform_variables_with_correct_type, + "Derived class of \"Program\" has member \"uniform_variables\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES() or WEBGPU_PROGRAM_EXTEND_UNIFORM_VARIABLES() to declare uniform variables."); + + metadata.uniform_variables = {ptr, len}; + } else { + metadata.uniform_variables = {}; + } + + return metadata; + } +}; + +namespace details { +// helper function to convert a C-style array to std::array +// +// This is basically the same as std::to_array in C++20. +// +template +constexpr auto _to_std_array_impl(T (&arr)[N], std::index_sequence) -> std::array, N> { + return {{arr[Idx]...}}; +} + +template +constexpr auto _to_std_array(T (&arr)[N]) -> std::array, N> { + return _to_std_array_impl(arr, std::make_index_sequence{}); +} + +// helper function to concatenate a std::array and a C-style array to a std::array +// +template +constexpr std::array, L + R> _concat2_impl(const std::array& lhs, + T (&rhs)[R], + std::index_sequence, + std::index_sequence) { + return {{lhs[IdxL]..., rhs[IdxR]...}}; +} + +template +constexpr std::array, L + R> _concat2(const std::array& lhs, T (&rhs)[R]) { + return _concat2_impl(lhs, rhs, std::make_index_sequence{}, std::make_index_sequence{}); +} + +} // namespace details +#define WEBGPU_PROGRAM_DEFINE_(identifier, T, ...) \ + static constexpr const T identifier##_own[] = {__VA_ARGS__}; \ + static constexpr const auto identifier = \ + onnxruntime::webgpu::details::_to_std_array(identifier##_own) + +#define WEBGPU_PROGRAM_EXTEND_(identifier, T, BASE, ...) \ + static constexpr const T identifier##_own[] = {__VA_ARGS__}; \ + static constexpr const auto identifier = \ + onnxruntime::webgpu::details::_concat2(BASE::identifier, identifier##_own) + +#define WEBGPU_PROGRAM_DEFINE_CONSTANTS(...) \ + WEBGPU_PROGRAM_DEFINE_(constants, onnxruntime::webgpu::ProgramConstant, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_CONSTANTS(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(constants, onnxruntime::webgpu::ProgramConstant, BASE, __VA_ARGS__) + +#define WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS(...) \ + WEBGPU_PROGRAM_DEFINE_(overridable_constants, onnxruntime::webgpu::ProgramOverridableConstantDefinition, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(overridable_constants, onnxruntime::webgpu::ProgramOverridableConstantDefinition, BASE, __VA_ARGS__) + +#define WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(...) \ + WEBGPU_PROGRAM_DEFINE_(uniform_variables, onnxruntime::webgpu::ProgramUniformVariableDefinition, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_UNIFORM_VARIABLES(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(uniform_variables, onnxruntime::webgpu::ProgramUniformVariableDefinition, BASE, __VA_ARGS__) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc new file mode 100644 index 0000000000000..a5c21563dbfcd --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/program_cache_key.h" + +#include "core/providers/webgpu/string_macros.h" + +namespace onnxruntime { +namespace webgpu { + +// macro "D" - append to the ostream only in debug build +#ifndef NDEBUG // if debug build +#define D(str) << str +#else +#define D(str) +#endif + +namespace { +// append the info of an input or output to the cachekey +void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, + bool& first) { + if (first) { + first = false; + } else { + ss << '|'; + } + + if ((dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { +#ifndef NDEBUG // if debug build + ss << var_type; +#else + ss << static_cast(var_type); +#endif + ss << ';'; + } + + if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { + ss D("Dims=") << tensor.Shape().ToString(); + } else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { + ss D("Rank=") << tensor.Shape().NumDimensions(); + } +} +} // namespace + +std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch) { + SS(ss, kStringInitialSizeCacheKey); + + // final key format: + // =[]:::: + // + // = ||... + // = ,, + // = + // = ||... + // = + // = ||... + // = ; + ss << program.Name(); + + // append custom cache hint if any + if (auto& hint = program.CacheHint(); !hint.empty()) { + ss << '[' D("CacheHint=") << hint << ']'; + } + + // append workgroup size if overridden + if (auto x = program.WorkgroupSizeX(), y = program.WorkgroupSizeY(), z = program.WorkgroupSizeZ(); + x != 0 || y != 0 || z != 0) { + ss << ":" D("WorkgroupSize="); + // only append non-zero values. zero values are considered as use default + if (x > 0) { + ss << x; + } + ss << ","; + if (y > 0) { + ss << y; + } + ss << ","; + if (z > 0) { + ss << z; + } + } + + ss << ":" D("DispatchDim=") << (is_1d_dispatch ? "1" : "3"); + ss << ":" D("UniformSizes="); + bool first = true; + for (const auto& uniform : program.UniformVariables()) { + if (first) { + first = false; + } else { + ss << "|"; + } + if (uniform.length > 0) { + ss << uniform.length; + } + } + + ss << ":" D("Inputs="); + first = true; + for (const auto& input : program.Inputs()) { + AppendTensorInfo(ss, *input.tensor, input.var_type, input.dependency, first); + } + + ss << ":" D("Outputs="); + first = true; + for (const auto& output : program.Outputs()) { + AppendTensorInfo(ss, *output.tensor, output.var_type, output.dependency, first); + } + + return SS_GET(ss); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.h b/onnxruntime/core/providers/webgpu/program_cache_key.h new file mode 100644 index 0000000000000..22ba19ebd0f25 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_cache_key.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc new file mode 100644 index 0000000000000..297d211ff1262 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/common/common.h" +#include "core/common/safeint.h" + +#include "core/common/common.h" +#include "core/common/logging/logging.h" + +#include "core/providers/webgpu/program_manager.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniform_ranks) + : name{program.Name()}, + compute_pipeline{compute_pipeline}, + shape_uniform_ranks{shape_uniform_ranks} {} + +Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { + ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); + + auto limit_per_dimension = limits_.maxComputeWorkgroupsPerDimension; + if (x > limit_per_dimension || y > limit_per_dimension || z > limit_per_dimension) { + auto size = static_cast(x) * static_cast(y) * static_cast(z); + SafeInt dispatch_avg = std::ceil(std::sqrt(size)); + if (dispatch_avg > limit_per_dimension) { + dispatch_avg = std::ceil(std::cbrt(size)); + ORT_RETURN_IF(dispatch_avg > limit_per_dimension, "The dispatch group size exceeds WebGPU maximum."); + x = y = z = dispatch_avg; + } else { + x = y = dispatch_avg; + z = 1; + } + } + return Status::OK(); +} + +Status ProgramManager::Build(const ProgramBase& program, + const ProgramMetadata& program_metadata, +#ifndef NDEBUG // if debug build + const std::string& program_key, +#endif + uint32_t normalized_dispatch_x, + uint32_t normalized_dispatch_y, + uint32_t normalized_dispatch_z, + wgpu::ComputePipeline& compute_pipeline, + std::vector& shape_uniform_ranks) const { + ShaderHelper shader_helper{program, + program_metadata, + device_, + limits_, + normalized_dispatch_x, + normalized_dispatch_y, + normalized_dispatch_z}; + ORT_RETURN_IF_ERROR(shader_helper.Init()); + + ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); + + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputs()); + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForOutputs()); + ORT_RETURN_IF_ERROR(shader_helper.ValidateIndices()); + + // code is a large std::string that contains the final shader code + std::string code; + ORT_RETURN_IF_ERROR(shader_helper.GenerateSourceCode(code, shape_uniform_ranks)); + + LOGS_DEFAULT(VERBOSE) << "\n=== WebGPU Shader code [" << program.Name() +#ifndef NDEBUG // if debug build + << ", Key=\"" << program_key << "\"" +#endif + << "] Start ===\n\n" + << code + << "\n=== WebGPU Shader code [" << program.Name() +#ifndef NDEBUG // if debug build + << ", Key=\"" << program_key << "\"" +#endif + << "] End ===\n"; + + wgpu::ShaderModuleWGSLDescriptor wgsl_descriptor{}; + wgsl_descriptor.code = code.c_str(); + + wgpu::ShaderModuleDescriptor descriptor{}; + descriptor.nextInChain = &wgsl_descriptor; + + auto shader_module = device_.CreateShaderModule(&descriptor); + + // TODO: a new cache hierarchy for constants. + // + // Explaination: + // Currently, we use Uniforms for dynamic data. This helps to reduce the number of program artifacts. + // + // "dynamic data" here means the data the determined at runtime, such as the shape of the input tensor. + // + // However, some programs may not necessarily depend on dynamic data. For example, "Clip" may depend on the value of "min" and "max". + // We are using uniforms for the value of "min" and "max" in the current implementation, but usually "min" and "max" are determined + // earlier because they are either from Attributes or from the initializers of the model. + // + // Questions: + // - can we use one instance of ShaderModule to create multiple ComputePipeline? + // - is there any benefit to do so compared to the current implementation? + // + + // process overridable constants if available + size_t constant_count = program.OverridableConstants().size(); + + // making a copy of the constant names is required because they are stored as std::string_view in the program + // metadata. A value of std::string_view is not guaranteed to be a C-stlye string (null-terminated) and hence + // cannot be used directly in the WebGPU API (which expects a const char*). + std::vector constant_names; + constant_names.reserve(constant_count); + std::vector constant_entries; + constant_entries.reserve(constant_count); + for (size_t i = 0; i < constant_count; ++i) { + const auto& constant_override = program.OverridableConstants()[i]; + const auto& constant_def = program_metadata.overridable_constants[i]; + + if (constant_override.has_value) { + double value = 0; + switch (constant_override.type) { + case ProgramConstantDataType::Bool: + value = constant_override.boolean ? 1 : 0; + break; + case ProgramConstantDataType::Float16: + // convert f16(MLFloat16) -> f32(float) -> f64(double) + // because the value of a constant must be a double in WebGPU API, it is expensive to use f16 overridable constants. + value = constant_override.f16.ToFloat(); + break; + case ProgramConstantDataType::Float32: + value = constant_override.f32; + break; + case ProgramConstantDataType::Int32: + value = constant_override.i32; + break; + case ProgramConstantDataType::Uint32: + value = constant_override.u32; + break; + } + + const auto& name_string = constant_names.emplace_back(constant_def.name); + wgpu::ConstantEntry entry{}; + entry.key = name_string.c_str(); + entry.value = value; + constant_entries.push_back(std::move(entry)); + } + } + + wgpu::ProgrammableStageDescriptor compute_stage{}; + compute_stage.module = shader_module; + compute_stage.entryPoint = "main"; + if (!constant_entries.empty()) { + compute_stage.constants = constant_entries.data(); + compute_stage.constantCount = constant_entries.size(); + } + + wgpu::ComputePipelineDescriptor pipeline_descriptor{}; + pipeline_descriptor.compute = compute_stage; +#ifndef NDEBUG // if debug build + pipeline_descriptor.label = program.Name().c_str(); +#endif + + compute_pipeline = device_.CreateComputePipeline(&pipeline_descriptor); + + return Status(); +} + +const ProgramArtifact* ProgramManager::Get(const std::string& key) const { + auto result = programs_.find(key); + if (result != programs_.end()) { + return &result->second; + } + + return nullptr; +} + +const ProgramArtifact* ProgramManager::Set(const std::string& key, ProgramArtifact&& program) { + return &(programs_.emplace(key, std::move(program)).first->second); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h new file mode 100644 index 0000000000000..eded1cfa17970 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include + +#include "core/common/common.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +class Tensor; + +namespace webgpu { + +class ProgramArtifact { + public: + ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniform_ranks); + + const std::string name; + const wgpu::ComputePipeline compute_pipeline; + const std::vector shape_uniform_ranks; + + ProgramArtifact(ProgramArtifact&&) = default; + ProgramArtifact& operator=(ProgramArtifact&&) = delete; + + private: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProgramArtifact); +}; + +class ProgramManager { + public: + ProgramManager(const wgpu::Device& device, const wgpu::Limits& limits) : device_(device), limits_(limits) {} + + Status NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const; + + Status Build(const ProgramBase& program, + const ProgramMetadata& metadata, +#ifndef NDEBUG // if debug build + const std::string& program_key, +#endif + uint32_t normalized_dispatch_x, + uint32_t normalized_dispatch_y, + uint32_t normalized_dispatch_z, + wgpu::ComputePipeline& compute_pipeline, + std::vector& shape_uniform_ranks) const; + const ProgramArtifact* Get(const std::string& key) const; + const ProgramArtifact* Set(const std::string& key, ProgramArtifact&& program); + + private: + std::unordered_map programs_; + const wgpu::Device& device_; + const wgpu::Limits& limits_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc new file mode 100644 index 0000000000000..5685494556248 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -0,0 +1,530 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include "core/session/onnxruntime_c_api.h" + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/string_utils.h" +#include "core/providers/webgpu/string_macros.h" + +namespace onnxruntime { +namespace webgpu { + +ShaderHelper::ShaderHelper(const ProgramBase& program, + const ProgramMetadata& program_metadata, + const wgpu::Device& device, + const wgpu::Limits& limits, + uint32_t dispatch_group_size_x, + uint32_t dispatch_group_size_y, + uint32_t dispatch_group_size_z) + : device_{device}, + limits_{limits}, + dispatch_group_size_x_{dispatch_group_size_x}, + dispatch_group_size_y_{dispatch_group_size_y}, + dispatch_group_size_z_{dispatch_group_size_z}, + program_{program}, + program_metadata_{program_metadata}, + additional_implementation_ss_{&additional_implementation_}, + body_ss_{&body_} {} + +Status ShaderHelper::Init() { + // dispatch group size is normalized so no need to validate it here + + // validate workgroup size + auto workgroup_size_x = program_.WorkgroupSizeX(); + auto workgroup_size_y = program_.WorkgroupSizeY(); + auto workgroup_size_z = program_.WorkgroupSizeZ(); + + ORT_RETURN_IF_NOT(workgroup_size_x <= limits_.maxComputeWorkgroupSizeX && + workgroup_size_y <= limits_.maxComputeWorkgroupSizeY && + workgroup_size_z <= limits_.maxComputeWorkgroupSizeZ, + "Workgroup size exceeds the maximum allowed size [", + limits_.maxComputeWorkgroupSizeX, ", ", + limits_.maxComputeWorkgroupSizeY, ", ", + limits_.maxComputeWorkgroupSizeZ, "]"); + + ORT_RETURN_IF_NOT(workgroup_size_x * workgroup_size_y * workgroup_size_z <= limits_.maxComputeInvocationsPerWorkgroup, + "Workgroup size exceeds the maximum allowed invocations ", limits_.maxComputeInvocationsPerWorkgroup); + + // init body string stream + bool is_1d_dispatch = dispatch_group_size_y_ == 1 && dispatch_group_size_z_ == 1; + body_.reserve(4096); + additional_implementation_.reserve(1024); + + // append header for main function so it is ready for user to append main function body + body_ss_ << "@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)\n" + "fn main(@builtin(global_invocation_id) global_id : vec3,\n" + " @builtin(workgroup_id) workgroup_id : vec3,\n" + " @builtin(local_invocation_index) local_idx : u32,\n" + " @builtin(local_invocation_id) local_id : vec3"; + if (!is_1d_dispatch) { + body_ss_ << ",\n" + " @builtin(num_workgroups) num_workgroups : vec3"; + } + body_ss_ << ") {\n"; + if (is_1d_dispatch) { + body_ss_ << " let global_idx = global_id.x;\n" + " let workgroup_idx = workgroup_id.x;\n"; + } else { + body_ss_ << " let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;\n" + " let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; + } + + return Status::OK(); +} + +const ShaderVariableHelper& ShaderHelper::AddInput(const std::string& name, ShaderUsage usage) { + const size_t input_index = input_vars_.size(); + ORT_ENFORCE(input_index < program_.Inputs().size(), + "Too many inputs in the program (", program_.Inputs().size(), ")"); + + const auto& dims = program_.Inputs()[input_index].use_override_shape ? program_.Inputs()[input_index].override_shape + : program_.Inputs()[input_index].tensor->Shape(); + return AddVariableImpl(true, name, usage, dims); +} + +const ShaderVariableHelper& ShaderHelper::AddOutput(const std::string& name, ShaderUsage usage) { + const size_t output_index = output_vars_.size(); + ORT_ENFORCE(output_index < program_.Outputs().size(), + "Too many outputs in the program (", program_.Outputs().size(), ")"); + + const auto& dims = program_.Outputs()[output_index].use_override_shape ? program_.Outputs()[output_index].override_shape + : program_.Outputs()[output_index].tensor->Shape(); + return AddVariableImpl(false, name, usage, dims); +} + +const ShaderIndicesHelper& ShaderHelper::AddIndices(const std::string& name, bool use_uniform) { + const size_t indices_index = indices_vars_.size(); + return *indices_vars_.emplace_back( + std::make_unique(name, + ProgramVariableDataType::InvalidType, + use_uniform ? ShaderUsage::UseUniform : ShaderUsage::None, + program_.Indices()[indices_index])); +} + +#ifndef NDEBUG // if debug build +namespace { +// Validate if the tensor element type matches the program variable data type +Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType var_type) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float32 || + var_type == ProgramVariableDataType::Float32x2 || + var_type == ProgramVariableDataType::Float32x4, + "Unexpected program variable type ", int(var_type), " for float32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float16 || + var_type == ProgramVariableDataType::Float16x2 || + var_type == ProgramVariableDataType::Float16x4, + "Unexpected program variable type ", int(var_type), " for float16 tensor"); + + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || + var_type == ProgramVariableDataType::Int32x2 || + var_type == ProgramVariableDataType::Int32x4, + "Unexpected program variable type ", int(var_type), " for int32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint32 || + var_type == ProgramVariableDataType::Uint32x2 || + var_type == ProgramVariableDataType::Uint32x4, + "Unexpected program variable type ", int(var_type), " for uint32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int64, + "Unexpected program variable type ", int(var_type), " for int64 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint64, + "Unexpected program variable type ", int(var_type), " for uint64 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Boolx4, + "Unexpected program variable type ", int(var_type), " for bool tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint8x4 || + var_type == ProgramVariableDataType::Uint8x8 || + var_type == ProgramVariableDataType::Uint8x16, + "Unexpected program variable type ", int(var_type), " for uint8 tensor"); + break; + default: + ORT_RETURN_IF(true, "Unsupported data type: ", element_type); + // todo: add int4/uint4 + } + return Status::OK(); +} + +// Validate if the number of components and override shape match the original shape +Status ValidateVariableShape(const TensorShape& origin_shape, + bool use_override_shape, + const TensorShape& override_shape, + int num_components) { + if (use_override_shape) { + // if override shape specified, assert override_size == ceil( origin_size / 4 ) + ORT_RETURN_IF_NOT((origin_shape.Size() + num_components - 1) / num_components == override_shape.Size(), + "Tensor original shape ", origin_shape, " cannot reshape to ", override_shape, " with component number ", num_components); + } + + return Status::OK(); +} + +// Validate if the dependency and variable usage match +Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, ShaderUsage usage, bool is_input) { + bool dependency_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; + bool dependency_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + bool dependency_type = (dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type; + + // if dependency is already set for shape, it is no need to set for rank. + ORT_RETURN_IF(dependency_rank && dependency_shape, + "Dependency cannot set for both \"Rank\" and \"Shape\"."); + + // if dependency is set for shape, it's already part of the shader cache. no need to use uniform. + ORT_RETURN_IF(dependency_shape && (usage & ShaderUsage::UseUniform), + "Dependency is set for \"Shape\", using uniform for shape is not allowed."); + + // for input variable, check is more strict. + // this is because usually output shape is determined by the existing information, which is already part of the shader cache. + if (is_input) { + // if dependency is not set for type, should not use type alias for element and value. + // storage type is always used. so setting not depending on type is at user's own risk. + ORT_RETURN_IF(!dependency_type && (usage & (ShaderUsage::UseElementTypeAlias | ShaderUsage::UseValueTypeAlias)), + "Input dependency is not set for \"Type\", but type alias for element type or value type is used."); + + // if dependency is not set for rank and shape, the shader should not use shape and stride. + ORT_RETURN_IF(!dependency_rank && !dependency_shape && (usage & ShaderUsage::UseShapeAndStride), + "Input dependency is set for neither \"Rank\" nor \"Shape\", but variable shape and stride is used."); + } + + return Status::OK(); +} +} // namespace + +Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariableHelper& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(input.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableShape(input.tensor->Shape(), + input.use_override_shape, + input.use_override_shape ? input.override_shape : input.tensor->Shape(), + var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(input.dependency, var.usage_, true)); + + return Status::OK(); +} +Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariableHelper& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableShape(output.tensor->Shape(), + output.use_override_shape, + output.use_override_shape ? output.override_shape : output.tensor->Shape(), + var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(output.dependency, var.usage_, false)); + + return Status::OK(); +} + +#endif // NDEBUG + +const ShaderVariableHelper& ShaderHelper::AddVariableImpl(bool is_input, + const std::string& name, + ShaderUsage usage, + const TensorShape& dims) { + ORT_ENFORCE(input_vars_.size() + output_vars_.size() < limits_.maxStorageBuffersPerShaderStage, + "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); + + ProgramVariableDataType type = ProgramVariableDataType::InvalidType; + auto& vars = is_input ? input_vars_ : output_vars_; + + if (is_input) { + const auto& input = program_.Inputs()[vars.size()]; + type = input.var_type; + } else { + const auto& output = program_.Outputs()[vars.size()]; + type = output.var_type; + } + + const auto& var = vars.emplace_back(std::make_unique(name, type, usage, dims)); + return *var; +} + +Status ShaderHelper::ValidateShapeForInputs() const { + // Validate input as dependencies of shape_uniforms + ORT_RETURN_IF_NOT(input_vars_.size() == program_.Inputs().size(), + "Mismatched input variable count. Shader: ", input_vars_.size(), ", Program: ", program_.Inputs().size()); + for (size_t i = 0; i < input_vars_.size(); i++) { +#ifndef NDEBUG // if debug build + // Validate input shape + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], *input_vars_[i])); +#endif + + // check input dependencies with actual usages. + auto usage = input_vars_[i]->usage_; + auto dependency = program_.Inputs()[i].dependency; + bool use_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; + bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + + if (usage & ShaderUsage::UseShapeAndStride) { + if (usage & ShaderUsage::UseUniform) { + ORT_RETURN_IF_NOT((use_rank || input_vars_[i]->rank_ < 2) && !use_shape, + "When UseUniform is set in variable usage, the corresponding program input should depend on rank but not shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program input should depend on shape."); + // If you want neither hard-coded shape nor shape uniform, use a flattened shape (rank=1). + // This will not generate any shape variables in the shader, can you can only use offset to set/get values. + } + } + } + return Status::OK(); +} + +Status ShaderHelper::ValidateShapeForOutputs() const { + // Validate output as dependencies of shape_uniforms + ORT_RETURN_IF_NOT(output_vars_.size() == program_.Outputs().size(), + "Mismatched output variable count. Shader: ", output_vars_.size(), ", Program: ", program_.Outputs().size()); + + for (size_t i = 0; i < output_vars_.size(); i++) { +#ifndef NDEBUG // if debug build + // Validate output shape + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], *output_vars_[i])); +#endif + + // check output dependencies with actual usages. + auto usage = output_vars_[i]->usage_; + auto dependency = program_.Outputs()[i].dependency; + bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + + if (usage & ShaderUsage::UseShapeAndStride) { + if (usage & ShaderUsage::UseUniform) { + // output tensor shape check is looser than input tensor shape check, because output shape is always calculated so it is not + // necessarily a part of the cache key. + ORT_RETURN_IF_NOT(!use_shape, + "When UseUniform is set in variable usage, the corresponding program output should not depend on shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program output should depend on shape."); + } + } + } + return Status::OK(); +} + +Status ShaderHelper::ValidateIndices() const { + ORT_RETURN_IF_NOT(indices_vars_.size() == program_.Indices().size(), + "Mismatched indices variable count. Shader: ", indices_vars_.size(), ", Program: ", program_.Indices().size()); + + return Status::OK(); +} + +Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const { + SS(ss, kStringInitialSizeShaderSourceCode); + + // + // Section feature enabling + // + if (std::any_of(program_.Inputs().begin(), + program_.Inputs().end(), + [](const ProgramInput& input) { + return input.tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + }) || + std::any_of(program_.Outputs().begin(), + program_.Outputs().end(), + [](const ProgramOutput& output) { + return output.tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + })) { + ORT_RETURN_IF_NOT(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); + ss << "enable f16;\n"; + if (device_.HasFeature(wgpu::FeatureName::SubgroupsF16)) { + ss << "enable subgroups_f16;\n"; + } + } + if (device_.HasFeature(wgpu::FeatureName::Subgroups)) { + ss << "enable subgroups;\n"; + } + + // + // Section constants + // + ss << "const workgroup_size_x: u32 = " << (program_.WorkgroupSizeX() == 0 ? uint32_t(WORKGROUP_SIZE) : program_.WorkgroupSizeX()) + << ";\nconst workgroup_size_y: u32 = " << (program_.WorkgroupSizeY() == 0 ? uint32_t(1) : program_.WorkgroupSizeY()) + << ";\nconst workgroup_size_z: u32 = " << (program_.WorkgroupSizeZ() == 0 ? uint32_t(1) : program_.WorkgroupSizeZ()) + << ";\n"; + + for (const auto& constant : program_metadata_.constants) { + ss << "const " << constant.name << ": " << constant.type << " = "; + WriteConstantValue(ss, constant); + ss << ";\n"; + } + + size_t override_constant_count = program_metadata_.overridable_constants.size(); + for (size_t i = 0; i < override_constant_count; ++i) { + // size and type are previously checked to match + const auto& constant_def = program_metadata_.overridable_constants[i]; + const auto& constant_override = program_.OverridableConstants()[i]; + + ss << "override " << constant_def.name << ": " << constant_def.type << " = "; + if (constant_override.has_value) { + WriteConstantValue(ss, constant_override); + } else { + WriteConstantValue(ss, constant_def); + } + ss << ";\n"; + } + + // + // Input/output variables + // + size_t variable_count = 0; + for (const auto& input : input_vars_) { + ss << "@group(0) @binding(" << variable_count++ << ") var " << input->name_ << ": array<" << input->StorageType() << ">;\n"; + } + for (const auto& output : output_vars_) { + ss << "@group(0) @binding(" << variable_count++ << ") var " << output->name_ << ": array<" << output->StorageType() << ">;\n"; + } + + // + // uniform variables + // + + // store shape uniform ranks in shape_uniform_ranks + bool use_any_shape_uniform = false; + ORT_ENFORCE(shape_uniform_ranks.size() == 0); + shape_uniform_ranks.reserve(input_vars_.size() + output_vars_.size() + indices_vars_.size()); + + for (const auto& input : input_vars_) { + bool use_uniform = (input->usage_ & ShaderUsage::UseUniform) && + (input->usage_ & ShaderUsage::UseShapeAndStride) && + input->rank_ > 0; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? input->rank_ : 0); + } + for (const auto& output : output_vars_) { + bool use_uniform = (output->usage_ & ShaderUsage::UseUniform) && + (output->usage_ & ShaderUsage::UseShapeAndStride) && + output->rank_ > 0; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? output->rank_ : 0); + } + for (const auto& indices : indices_vars_) { + bool use_uniform = (indices->usage_ & ShaderUsage::UseUniform) && + (indices->usage_ & ShaderUsage::UseShapeAndStride) && + indices->rank_ > 0; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? indices->rank_ : 0); + } + + if (use_any_shape_uniform || std::any_of(program_.UniformVariables().cbegin(), + program_.UniformVariables().cend(), + [](const ProgramUniformVariableValue& x) { return x.length > 0; })) { + bool first = true; + ss << "struct Uniforms {"; + + // lambda append_uniform is used to append one uniform variable to the uniform struct + auto append_uniform = [&ss, &first](std::string_view name, ProgramUniformVariableDataType data_type, size_t length) { + if (length == 0) { + return; + } + + if (first) { + first = false; + } else { + ss << ","; + } + + auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : ""; + ss << "\n " << alignment << name << ": "; + if (length > 4) { + if (data_type == ProgramUniformVariableDataType::Float16) { + size_t array_size = (length + 7) / 8; + ss << "array, " << array_size << ">"; + } else { + size_t array_size = (length + 3) / 4; + ss << "array, " << array_size << ">"; + } + } else if (length > 1) { + ss << "vec" << length << "<" << data_type << ">"; + } else { + ss << data_type; + } + }; + + for (const auto& input : input_vars_) { + const size_t rank = input->rank_; + if (rank > 0 && (input->usage_ & ShaderUsage::UseUniform) && (input->usage_ & ShaderUsage::UseShapeAndStride)) { + std::string shape = input->name_ + "_shape"; + std::string stride = input->name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); + } + } + + for (const auto& output : output_vars_) { + const size_t rank = output->rank_; + if (rank > 0 && (output->usage_ & ShaderUsage::UseUniform) && (output->usage_ & ShaderUsage::UseShapeAndStride)) { + std::string shape = output->name_ + "_shape"; + std::string stride = output->name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); + } + } + + for (const auto& indices : indices_vars_) { + const size_t rank = indices->rank_; + if (rank > 0 && (indices->usage_ & ShaderUsage::UseUniform) && (indices->usage_ & ShaderUsage::UseShapeAndStride)) { + std::string shape = indices->name_ + "_shape"; + std::string stride = indices->name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); + } + } + + for (size_t i = 0; i < program_.UniformVariables().size(); i++) { + const auto& uniform_def = program_metadata_.uniform_variables[i]; + const auto& uniform_value = program_.UniformVariables()[i]; + append_uniform(uniform_def.name, uniform_def.data_type, uniform_value.length); + } + + ss << "\n};\n" + "@group(0) @binding(" + << variable_count << ") var uniforms: Uniforms;\n"; + } + + // + // Indices helper + // + ss << "\n"; + for (const auto& var : input_vars_) { + var->Impl(ss); + } + for (const auto& var : output_vars_) { + var->Impl(ss); + } + for (const auto& var : indices_vars_) { + var->Impl(ss); + } + ss << "\n"; + + // + // Additional Implementation + // + ss << additional_implementation_; + + // + // Main Function Body + // + ss << body_; + ss << "\n" + "}\n"; + + code = SS_GET(ss); + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h new file mode 100644 index 0000000000000..5e60c1293acea --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "core/common/safeint.h" +#include "core/framework/tensor_shape.h" + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/string_utils.h" + +namespace onnxruntime { +namespace webgpu { + +class ShaderHelper final { + // The content of a shader code is composed of the following parts: + // + // ** + // ** section: feature sets definition + // ** + // // this sections enable features like "enable f16;". need to be defined at the beginning of the shader. + // + // ** + // ** section: constants and overridable constants + // ** + // // this section defines constants and overridable constants. + // - constants are defined as "const a:f32 = 1.0;". It's hard coded in the shader. + // - overridable constants are defined as "override a:f32 = 1.0;" (may override or not) + // or "override b:u32;" (must override) + // the value can be overriden by pipeline creation config. + // + // ** + // ** section: inputs and outputs + // ** + // // this section defines input and output variables. + // user can call shader_helper.AddVariable() to add input and output variables. + // + // ** + // ** section: uniforms + // ** + // // this section defines uniform type and variables. + // + // ** + // ** section: indices helper generated utility functions + // ** + // // this section defines utility functions to calculate indices. + // + // ** + // ** section: additional implementation + // ** + // // this section contains additional implementation provided by the user. + // user can call shader_helper.AppendImplementation() to append additional implementation. + // + // ** + // ** section: main function + // ** + // // this section contains the main function of the shader. + // user can call shader_helper.MainFunctionBody() to set the main function body. + // + + public: + ShaderHelper(const ProgramBase& program, + const ProgramMetadata& program_metadata, + const wgpu::Device& device, + const wgpu::Limits& limits, + uint32_t dispatch_group_size_x, + uint32_t dispatch_group_size_y, + uint32_t dispatch_group_size_z); + + Status Init(); + + // Add an input variable to the shader. + // + // depending on the usage of the variable, additional code may be generated. + const ShaderVariableHelper& AddInput(const std::string& name, + ShaderUsage usage = ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseUniform); + + // Add an output variable to the shader. + // + // depending on the usage of the variable, additional code may be generated. + const ShaderVariableHelper& AddOutput(const std::string& name, + ShaderUsage usage = ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseUniform); + + // Add an indices variable to the shader. + const ShaderIndicesHelper& AddIndices(const std::string& name, bool use_uniform = true); + + // Get the string stream for additional implementation code to the shader. + inline OStringStream& AdditionalImplementation() { + return additional_implementation_ss_; + } + + // Get the string stream for the main function body of the shader. + inline OStringStream& MainFunctionBody() { + return body_ss_; + } + + std::string GuardAgainstOutOfBoundsWorkgroupSizes(std::string_view size) const { + return MakeStringWithClassicLocale(" if (global_idx >= ", size, ") { return; }\n"); + } + + private: + template // ConstantType is one of {ProgramConstant, ProgramOverridableConstantValue, ProgramOverridableConstantDefinition} + void WriteConstantValue(std::ostream& ss, const ConstantType& constant) const { + switch (constant.type) { + case ProgramConstantDataType::Float16: + ss << constant.f16.ToFloat(); + break; + case ProgramConstantDataType::Float32: + ss << constant.f32; + break; + case ProgramConstantDataType::Int32: + ss << constant.i32; + break; + case ProgramConstantDataType::Uint32: + ss << constant.u32; + break; + case ProgramConstantDataType::Bool: + ss << (constant.boolean ? "true" : "false"); + break; + default: + ORT_THROW("Invalid constant type", constant.type); + } + } + + const ShaderVariableHelper& AddVariableImpl(bool is_input, + const std::string& name, + ShaderUsage usage, + const TensorShape& dims); + +#ifndef NDEBUG // if debug build + Status ValidateVariable(const ProgramInput& input, const ShaderVariableHelper& var) const; + Status ValidateVariable(const ProgramOutput& output, const ShaderVariableHelper& var) const; +#endif + + Status ValidateShapeForInputs() const; + Status ValidateShapeForOutputs() const; + Status ValidateIndices() const; + + // Generate source code. + // + // This function: + // - performs validation if neccessary, + // - appends the ranks for variables to the shape_uniform_ranks. + // (The rank value is zero if no uniform is needed for the variable.) + // - generates the final source code. + // + // \param code The generated full WGSL source code. + // \param shape_uniform_ranks The ranks for variables that need a uniform for the shape. + // + Status GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const; + friend class ProgramManager; + + const wgpu::Device& device_; + const wgpu::Limits& limits_; + uint32_t dispatch_group_size_x_; + uint32_t dispatch_group_size_y_; + uint32_t dispatch_group_size_z_; + + const ProgramBase& program_; + const ProgramMetadata& program_metadata_; + + std::vector> input_vars_; + std::vector> output_vars_; + std::vector> indices_vars_; + std::string additional_implementation_; + OStringStream additional_implementation_ss_; + std::string body_; + OStringStream body_ss_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc new file mode 100644 index 0000000000000..e60a06800851d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -0,0 +1,327 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "core/common/safeint.h" +#include "core/providers/webgpu/shader_variable.h" + +#include "core/providers/webgpu/string_macros.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { +constexpr static const std::string_view STORAGE_TYPE[] = { + "f32", // Float32 + "vec2", // Float32x2 + "vec4", // Float32x4 + "f16", // Float16 + "vec2", // Float16x2 + "vec4", // Float16x4 + "i32", // Int32 + "vec2", // Int32x2 + "vec4", // Int32x4 + "u32", // Uint32 + "vec2", // Uint32x2 + "vec4", // Uint32x4 + "vec2", // Int64 + "vec2", // Uint64 + "u32", // Boolx4 + "u32", // Uint8x4 + "vec2", // Uint8x8 + "vec4", // Uint8x16 +}; + +constexpr static const std::string_view VALUE_TYPE[] = { + "f32", // Float32 + "vec2", // Float32x2 + "vec4", // Float32x4 + "f16", // Float16 + "vec2", // Float16x2 + "vec4", // Float16x4 + "i32", // Int32 + "vec2", // Int32x2 + "vec4", // Int32x4 + "u32", // Uint32 + "vec2", // Uint32x2 + "vec4", // Uint32x4 + "i32", // Int64 (trancated to i32) + "u32", // Uint64 (trancated to u32) + "vec4", // Boolx4 + "u32", // Uint8x4 (u32 as 4 elements of uint8) + "vec2", // Uint8x8 (vec2 as 2x4 elements of uint8) + "vec4", // Uint8x16 (vec4 as 4x4 elements of uint8) +}; + +constexpr static const std::string_view ELEMENT_TYPE[] = { + "f32", // Float32 + "f32", // Float32x2 + "f32", // Float32x4 + "f16", // Float16 + "f16", // Float16x2 + "f16", // Float16x4 + "i32", // Int32 + "i32", // Int32x2 + "i32", // Int32x4 + "u32", // Uint32 + "u32", // Uint32x2 + "u32", // Uint32x4 + "i32", // Int64 + "u32", // Uint64 + "bool", // Boolx4 + "u32", // Uint8x4 + "u32", // Uint8x8 + "u32", // Uint8x16 +}; + +inline std::string GetIndicesType(int rank) { + return rank < 2 ? "u32" + : (rank < 4 ? MakeStringWithClassicLocale("vec", rank, "") + : MakeStringWithClassicLocale("array")); +} + +} // namespace + +ShaderIndicesHelper::ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims) + : name_(name), + type_(type), + num_components_{NumberOfComponents(type)}, + rank_{SafeInt(dims.NumDimensions())}, + dims_{dims}, + usage_(usage), + indices_type_{GetIndicesType(rank_)}, + value_type_alias_{name_ + "_value_t"}, + element_type_alias_{name_ + "_element_t"}, + indices_type_alias_{name_ + "_indices_t"} {} + +ShaderVariableHelper::ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims) + : ShaderIndicesHelper{name, type, usage, dims} { + ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); + ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_); +} + +void ShaderIndicesHelper::Impl(std::ostream& ss) const { + // Start generating code + + const std::string shape = (usage_ & ShaderUsage::UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; + const std::string stride = (usage_ & ShaderUsage::UseUniform) ? "uniforms." + name_ + "_stride" : name_ + "_stride"; + + // Types + if (usage_ & ShaderUsage::UseValueTypeAlias) { + SS_APPEND(ss, "alias ", value_type_alias_, " = ", VALUE_TYPE[static_cast(type_)], ";\n"); + } + if (usage_ & ShaderUsage::UseIndicesTypeAlias) { + SS_APPEND(ss, "alias ", indices_type_alias_, " = ", indices_type_, ";\n"); + } + if (usage_ & ShaderUsage::UseElementTypeAlias) { + SS_APPEND(ss, "alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); + } + + // Need shape and strides when (not use uniform) and (use shape and stride is enabled) + if (!(usage_ & ShaderUsage::UseUniform) && (usage_ & ShaderUsage::UseShapeAndStride) && rank_ > 0) { + SS_APPEND(ss, "const ", shape, " = ", IndicesType(), "("); + + bool first = true; + for (auto dim : dims_.GetDims()) { + if (!first) { + ss << ","; + } + + ss << dim; + first = false; + } + ss << ");\n"; + + if (rank_ > 1) { + SS_APPEND(ss, "const ", stride, " = ", GetIndicesType(rank_ - 1), "("); + first = true; + for (int i = 1; i < rank_; i++) { + if (!first) { + ss << ","; + } + ss << dims_.SizeFromDimension(i); + first = false; + } + ss << ");\n"; + } + } + + // Implementation of "fn o2i_{name}" + if (usage_ & ShaderUsage::UseOffsetToIndices) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn o2i_", name_, "(offset : u32)->", IndicesType(), " {\n"); + SS_APPEND(ss, " var indices: ", IndicesType(), ";\n"); + SS_APPEND(ss, " var current = offset;\n"); + for (int i = 0; i < rank_ - 1; i++) { + auto current_stride = GetElementAt(stride, i, rank_ - 1); + SS_APPEND(ss, " let dim", i, " = current / ", current_stride, ";\n"); + SS_APPEND(ss, " let rest", i, " = current % ", current_stride, ";\n"); + SS_APPEND(ss, " indices[", i, "] = dim", i, ";\n"); + SS_APPEND(ss, " current = rest", i, ";\n"); + } + SS_APPEND(ss, " indices[", rank_ - 1, "] = current;\n"); + SS_APPEND(ss, " return indices;\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn i2o_{name}" + if (usage_ & ShaderUsage::UseIndicesToOffset) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); + SS_APPEND(ss, " return "); + for (int i = 0; i < rank_ - 1; i++) { + SS_APPEND(ss, "indices[", i, "] * ", GetElementAt(stride, i, rank_ - 1), " + "); + } + SS_APPEND(ss, "indices[", rank_ - 1, "];\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn {res_name}_bi2o_{name}" + if (usage_ & ShaderUsage::UseBroadcastedIndicesToOffset) { + if (rank_ > 0) { + for (const auto& broadcasted_result_ptr : broadcasted_to_) { + const auto& broadcasted_result = *broadcasted_result_ptr; + SS_APPEND(ss, "fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); + if (rank_ == 1) { + SS_APPEND(ss, " return ", broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", shape, ";\n"); + } else { + SS_APPEND(ss, " return "); + for (int i = 0; i < rank_ - 1; i++) { + auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); + std::string current_stride = rank_ == 2 ? stride : GetElementAt(stride, i, rank_ - 1); + SS_APPEND(ss, current_stride, " * (", idx, " % ", IndicesGet(shape, i), ") + "); + } + SS_APPEND(ss, broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); + } + SS_APPEND(ss, "}\n"); + } + } + } +} + +void ShaderVariableHelper::Impl(std::ostream& ss) const { + ShaderIndicesHelper::Impl(ss); + + // Implementation of "fn set_{name}" + if (usage_ & ShaderUsage::UseSet) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn set_", name_, "(d0: u32"); + for (int i = 1; i < rank_; i++) { + SS_APPEND(ss, ", d", i, ": u32"); + } + SS_APPEND(ss, ", value: ", ValueType(), ") {\n"); + SS_APPEND(ss, " set_", name_, "_by_indices(d0"); + for (int i = 1; i < rank_; i++) { + SS_APPEND(ss, ", d", i); + } + SS_APPEND(ss, ", value);\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn set_{name}_by_indices" + if (usage_ & ShaderUsage::UseSetByIndices) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn set_", name_, "_by_indices(indices: ", IndicesType(), ", value: ", ValueType(), ") {\n"); + SS_APPEND(ss, " ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn get_{name}" + if (usage_ & ShaderUsage::UseGet) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn get_", name_, "(d0: u32"); + for (int i = 1; i < rank_; i++) { + SS_APPEND(ss, ", d", i, ": u32"); + } + SS_APPEND(ss, ")->", ValueType(), " {\n"); + SS_APPEND(ss, " return get_", name_, "_by_indices(d0"); + for (int i = 1; i < rank_; i++) { + SS_APPEND(ss, ", d", i); + } + SS_APPEND(ss, ");\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn get_{name}_by_indices" + if (usage_ & ShaderUsage::UseGetByIndices) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n"); + SS_APPEND(ss, " return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); + SS_APPEND(ss, "}\n"); + } + } +} + +std::string ShaderVariableHelper::GetByOffsetImpl(std::string_view offset) const { + SS(ss, kStringInitialSizeGetByOffsetImpl); + + switch (type_) { + case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: + ORT_THROW("Invalid type"); + break; + case onnxruntime::webgpu::ProgramVariableDataType::Int64: + case onnxruntime::webgpu::ProgramVariableDataType::Uint64: + ss << ElementType() << "(" << name_ << "[" << offset << "].x)"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Boolx4: + ss << "vec4(bool(" + << name_ << "[" << offset << "] & 0xFFu), bool(" + << name_ << "[" << offset << "] & 0xFF00u), bool(" + << name_ << "[" << offset << "] & 0xFF0000u), bool(" + << name_ << "[" << offset << "] & 0xFF000000u))"; + break; + default: + ss << name_ << "[" << offset << "]"; + } + + return SS_GET(ss); +} + +std::string ShaderVariableHelper::SetByOffsetImpl(std::string_view offset, std::string_view value) const { + SS(ss, kStringInitialSizeSetByOffsetImpl); + + switch (type_) { + case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: + ORT_THROW("Invalid type"); + break; + case onnxruntime::webgpu::ProgramVariableDataType::Int64: + ss << name_ << "[" << offset << "]=vec2(u32(" << value << "), select(0u, 0xFFFFFFFFu, " << value << " < 0));"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Uint64: + ss << name_ << "[" << offset << "]=vec2(u32(" << value << "), 0u);"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Boolx4: + ss << name_ << "[" << offset << "]=dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(" << value << "));"; + break; + default: + ss << name_ << "[" << offset << "]=" << value << ";"; + } + + return SS_GET(ss); +} + +std::string_view ShaderVariableHelper::StorageType() const { + return STORAGE_TYPE[static_cast(type_)]; +} + +std::string_view ShaderVariableHelper::ValueType() const { + return (usage_ & ShaderUsage::UseValueTypeAlias) ? value_type_alias_ : VALUE_TYPE[static_cast(type_)]; +} + +std::string_view ShaderVariableHelper::ElementType() const { + return (usage_ & ShaderUsage::UseElementTypeAlias) ? element_type_alias_ : ELEMENT_TYPE[static_cast(type_)]; +} + +std::string_view ShaderIndicesHelper::IndicesType() const { + return (usage_ & ShaderUsage::UseIndicesTypeAlias) ? indices_type_alias_ : indices_type_; +} +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h new file mode 100644 index 0000000000000..4d4655925c980 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -0,0 +1,338 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/framework/tensor_shape.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +template || std::is_same_v>> +std::string GetElementAt(std::string_view var, const TIdx& idx, TRank rank, bool is_f16 = false) { + // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. + if (var.rfind("uniforms.", 0) == 0) { + if (rank > 4) { + if constexpr (std::is_integral_v) { + if (is_f16) { + return MakeStringWithClassicLocale(var, "[", idx / 8, "][", (idx % 8) / 4, "][", (idx % 8) % 4, "]"); + } else { + return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); + } + } else { + if (is_f16) { + return MakeStringWithClassicLocale(var, "[(", idx, ") / 8][(", idx, ") % 8 / 4][(", idx, ") % 8 % 4]"); + } else { + return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]"); + } + } + } + } + + return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : std::string{var}; +} + +struct ShaderUsage { + enum : uint32_t { + None = 0, // no usage. this means no additional implementation code will be generated. + UseIndicesTypeAlias = 1, // use type alias "{name}_indices_t" for indices (eg. u32, vec2, vec3, vec4, ...) + UseValueTypeAlias = 2, // use type alias "{name}_value_t" for value (eg. f32, vecT, vec4, ...) + UseElementTypeAlias = 4, // use type alias "{name}_element_t" for element (eg. f32, bool, ...) + UseShapeAndStride = 16, // use shape and stride for the variable + UseOffsetToIndices = 32, // use implementation of fn o2i_{name} + UseIndicesToOffset = 64, // use implementation of fn i2o_{name} + UseBroadcastedIndicesToOffset = 128, // use implementation of fn {broadcasted_result_name}_bi2o_{name} + UseSet = 256, // use implementation of fn set_{name} + UseSetByIndices = 512, // use implementation of fn set_{name}_by_indices + UseGet = 1024, // use implementation of fn get_{name} + UseGetByIndices = 2048, // use implementation of fn get_{name}_by_indices + UseUniform = 32768, // use uniform for shape and stride + } usage; + + ShaderUsage(decltype(usage) usage) : usage{usage} {} + ShaderUsage(uint32_t usage) : usage{usage} {} + + explicit operator bool() { + return usage != None; + } +}; + +// A helper class to make it easier to generate shader code related to indices calculation. +class ShaderIndicesHelper { + public: + ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims); + + ShaderIndicesHelper(ShaderIndicesHelper&&) = default; + ShaderIndicesHelper& operator=(ShaderIndicesHelper&&) = default; + + // get the number of components of the variable. + inline int NumComponents() const { return num_components_; } + + // get the rank of the indices. + inline int Rank() const; + + // create a WGSL expression ({varname}_indices_t) for getting indices from offset. + // \param offset: a WGSL expression (u32) representing the offset. + inline std::string OffsetToIndices(std::string_view offset_expr) const; + + // create a WGSL expression (u32) for getting offset from indices. + // \param indices: a WGSL expression ({varname}_indices_t) representing the indices. + inline std::string IndicesToOffset(std::string_view indices_expr) const; + + // create a WGSL expression (u32) for getting original offset from broadcasted indices. + // \param indices: a WGSL expression ({broadcasted_result_varname}_indices_t) representing the broadcasted indices. + // \param broadcasted_result: the broadcasted result variable. + inline std::string BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderIndicesHelper& broadcasted_result) const; + + // create a WGSL expression ({varname}_indices_t) as an indices literal + // \param init: a list of indices values. + template + inline std::string Indices(TIndices&&... indices_args) const; + + // create a WGSL statement for setting value of the specified dimension of the indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param idx: the index (i32|u32) of the dimension to set. + // \param value: the value (u32) to set. + template + inline std::string IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const; + + // create a WGSL expression (u32) for getting value of the specified dimension of the indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param idx: the index (i32|u32) of the dimension to get. + template + inline std::string IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const; + + protected: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderIndicesHelper); + + void Impl(std::ostream& ss) const; + + std::string_view IndicesType() const; + + std::string name_; + ProgramVariableDataType type_; // for variable + int num_components_; // for variable + int rank_; + TensorShape dims_; + + mutable ShaderUsage usage_; + mutable std::set broadcasted_to_; + + // unlike storage/element/value type, indices type is not a string view to a constant string. so we need to store it. + std::string indices_type_; + + // the alias for the types + std::string value_type_alias_; + std::string element_type_alias_; + std::string indices_type_alias_; + + friend class ShaderHelper; +}; + +// A helper class to make it easier to generate shader code related to a variable setting/getting and its indices calculation. +class ShaderVariableHelper : public ShaderIndicesHelper { + public: + ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims); + + ShaderVariableHelper(ShaderVariableHelper&&) = default; + ShaderVariableHelper& operator=(ShaderVariableHelper&&) = default; + + // create a WGSL statement for setting data at the given indices. + // \param args: a list of indices values (u32) followed by a value ({varname}_value_t). + template + inline std::string Set(TIndicesAndValue&&... args) const; + + // create a WGSL statement for setting data at the given indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param value: the value ({varname}_value_t) to set. + inline std::string SetByIndices(std::string_view indices_var, std::string_view value) const; + + // create a WGSL statement for setting data at the given offset. + // \param offset: a WGSL expression (u32) representing the offset. + // \param value: the value ({varname}_value_t) to set. + template + inline std::string SetByOffset(TOffset&& offset, TValue&& value) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given indices. + // \param indices: a list of indices values (u32). + template + inline std::string Get(TIndices&&... indices) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + inline std::string GetByIndices(std::string_view indices_var) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given offset. + // \param offset: a WGSL expression (u32) representing the offset. + template + inline std::string GetByOffset(TOffset&& offset) const; + + private: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper); + + void Impl(std::ostream& ss) const; + + std::string GetByOffsetImpl(std::string_view offset) const; + std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const; + std::string_view StorageType() const; + std::string_view ValueType() const; + std::string_view ElementType() const; + + friend class ShaderHelper; +}; + +inline ShaderUsage operator|(ShaderUsage a, ShaderUsage b) { + return (uint32_t)a.usage | (uint32_t)b.usage; +} +inline ShaderUsage operator&(ShaderUsage a, ShaderUsage b) { + return (uint32_t)a.usage & (uint32_t)b.usage; +} +inline ShaderUsage& operator|=(ShaderUsage& a, ShaderUsage b) { + (uint32_t&)a.usage |= (uint32_t)b.usage; + return a; +} +inline ShaderUsage& operator&=(ShaderUsage& a, ShaderUsage b) { + (uint32_t&)a.usage &= (uint32_t)b.usage; + return a; +} + +namespace detail { +template >> +std::string pass_as_string(T&& v) { + return std::to_string(std::forward(v)); +} +template +std::string_view pass_as_string(std::string_view sv) { + return sv; +} +template +std::string pass_as_string(T&& v) { + return std::forward(v); +} +} // namespace detail + +inline int ShaderIndicesHelper::Rank() const { + // getting the rank means the information is exposed to the shader. So we consider it as a usage of shape and stride. + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_; +} + +inline std::string ShaderIndicesHelper::OffsetToIndices(std::string_view offset_expr) const { + usage_ |= ShaderUsage::UseOffsetToIndices | ShaderUsage::UseShapeAndStride; + return rank_ < 2 ? std::string{offset_expr} + : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')'); +} + +inline std::string ShaderIndicesHelper::IndicesToOffset(std::string_view indices_expr) const { + usage_ |= ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride; + return rank_ < 2 ? std::string{indices_expr} + : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')'); +} + +inline std::string ShaderIndicesHelper::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderIndicesHelper& broadcasted_result) const { + ORT_ENFORCE(broadcasted_result.num_components_ == -1 || + num_components_ == -1 || + broadcasted_result.num_components_ == num_components_, + "number of components should be the same for 2 variables to calculate"); + usage_ |= ShaderUsage::UseBroadcastedIndicesToOffset | ShaderUsage::UseShapeAndStride; + broadcasted_to_.insert(&broadcasted_result); + return rank_ == 0 + ? "0" + : MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); +} + +template +inline std::string ShaderIndicesHelper::Indices(TIndices&&... indices_args) const { + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_ == 0 + ? "0" + : MakeStringWithClassicLocale(IndicesType(), "(", + absl::StrJoin(std::forward_as_tuple(std::forward(indices_args)...), ", "), + ')'); +} + +template +inline std::string ShaderIndicesHelper::IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const { + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_ < 2 ? MakeStringWithClassicLocale(indices_var, '=', value, ';') + : MakeStringWithClassicLocale(GetElementAt(indices_var, idx_expr, rank_), '=', value, ';'); +} + +template +inline std::string ShaderIndicesHelper::IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const { + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_ < 2 ? std::string{indices_var} + : GetElementAt(indices_var, idx_expr, rank_); +} + +template +inline std::string ShaderVariableHelper::SetByOffset(TOffset&& offset, TValue&& value) const { + return SetByOffsetImpl(detail::pass_as_string(offset), detail::pass_as_string(value)); +} + +template +inline std::string ShaderVariableHelper::Set(TIndicesAndValue&&... args) const { + usage_ |= ShaderUsage::UseShapeAndStride; + ORT_ENFORCE(sizeof...(TIndicesAndValue) == rank_ + 1, "Number of arguments should be ", rank_ + 1, "(rank + 1)"); + if constexpr (sizeof...(TIndicesAndValue) == 1) { + return SetByOffset("0", std::forward(args)...); + } else if constexpr (sizeof...(TIndicesAndValue) == 2) { + return SetByOffset(std::forward(args)...); + } else { + usage_ |= ShaderUsage::UseSet | ShaderUsage::UseSetByIndices | ShaderUsage::UseIndicesToOffset; + return MakeStringWithClassicLocale("set_", name_, '(', + absl::StrJoin(std::forward_as_tuple(std::forward(args)...), ", "), + ");"); + } +} + +inline std::string ShaderVariableHelper::SetByIndices(std::string_view indices_var, std::string_view value) const { + usage_ |= ShaderUsage::UseShapeAndStride; + if (rank_ < 2) { + return SetByOffset(indices_var, value); + } else { + usage_ |= ShaderUsage::UseSetByIndices | ShaderUsage::UseIndicesToOffset; + return MakeStringWithClassicLocale("set_", name_, "_by_indices(", indices_var, ", ", value, ");"); + } +} + +template +inline std::string ShaderVariableHelper::GetByOffset(TOffset&& offset) const { + return GetByOffsetImpl(detail::pass_as_string(offset)); +} + +template +inline std::string ShaderVariableHelper::Get(TIndices&&... indices) const { + usage_ |= ShaderUsage::UseShapeAndStride; + ORT_ENFORCE(sizeof...(TIndices) == rank_, "Number of arguments should be ", rank_, "(rank)"); + if constexpr (sizeof...(TIndices) == 0) { + return GetByOffset("0"); + } else if constexpr (sizeof...(TIndices) == 1) { + return GetByOffset(std::forward(indices)...); + } else { + usage_ |= ShaderUsage::UseGet | ShaderUsage::UseGetByIndices | ShaderUsage::UseIndicesToOffset; + return MakeStringWithClassicLocale("get_", name_, '(', + absl::StrJoin(std::forward_as_tuple(std::forward(indices)...), ", "), + ')'); + } +} + +inline std::string ShaderVariableHelper::GetByIndices(std::string_view indices_var) const { + usage_ |= ShaderUsage::UseShapeAndStride; + if (rank_ < 2) { + return GetByOffset(indices_var); + } else { + usage_ |= ShaderUsage::UseGetByIndices | ShaderUsage::UseIndicesToOffset; + return MakeStringWithClassicLocale("get_", name_, "_by_indices(", indices_var, ")"); + } +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/string_macros.h b/onnxruntime/core/providers/webgpu/string_macros.h new file mode 100644 index 0000000000000..7821d9c49a171 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/string_macros.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/string_utils.h" + +// macro "SS" - declare an ostream variable and its string buffer +#define SS(ss, reserve_size) \ + std::string ss##_str; \ + ss##_str.reserve(reserve_size); \ + ::onnxruntime::webgpu::OStringStream ss(&ss##_str) + +// macro "SS_GET" - get the string from the ostream +#define SS_GET(ss) ss##_str + +// macro "SS_APPEND" - use function call style to append to the ostream +#define SS_APPEND(ss, ...) ::onnxruntime::webgpu::detail::OStringStreamAppend(ss, __VA_ARGS__) diff --git a/onnxruntime/core/providers/webgpu/string_utils.h b/onnxruntime/core/providers/webgpu/string_utils.h new file mode 100644 index 0000000000000..e6d7097ad6182 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/string_utils.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/make_string.h" +#include + +namespace onnxruntime { +namespace webgpu { + +constexpr const size_t kStringInitialSizeSetByOffsetImpl = 128; +constexpr const size_t kStringInitialSizeGetByOffsetImpl = 128; +constexpr const size_t kStringInitialSizeShaderSourceCode = 2048; +#ifndef NDEBUG +constexpr const size_t kStringInitialSizeCacheKey = 512; +#else +constexpr const size_t kStringInitialSizeCacheKey = 256; +#endif + +using OStringStream = absl::strings_internal::OStringStream; + +namespace detail { +inline void OStringStreamAppendImpl(std::ostream& /*ss*/) noexcept { +} + +template +inline void OStringStreamAppendImpl(std::ostream& ss, const T& t) noexcept { + ss << t; +} + +template +inline void OStringStreamAppendImpl(std::ostream& ss, const T& t, const Args&... args) noexcept { + OStringStreamAppendImpl(ss, t); + OStringStreamAppendImpl(ss, args...); +} + +template +inline void OStringStreamAppend(std::ostream& ss, const Args&... args) { + return OStringStreamAppendImpl(ss, ::onnxruntime::detail::if_char_array_make_ptr_t(args)...); +} + +} // namespace detail + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc new file mode 100644 index 0000000000000..06eae971309c5 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/webgpu/tensor/cast.h" + +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { +const std::vector& CastOpTypeConstraints() { + // currently support boolean, integer and float types that explicitly allowed in WGSL: + // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section + // + static std::vector types{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + return types; +} +} // namespace + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 6, 8, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 9, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 13, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_KERNEL_EX( + Cast, + kOnnxDomain, + 19, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); + +Status Cast::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + auto* output_tensor = context.Output(0, input_tensor->Shape()); + int64_t size = input_tensor->Shape().Size(); + if (size == 0) { + return Status::OK(); + } + SafeInt vec_size = (size + 3) / 4; + + CastProgram program{to_}; + program + .AddInput({input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(vec_size)}, + }) + .CacheHint(std::to_string(to_)); + return context.RunProgram(program); +} + +Status CastProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& input = sh.AddInput("x", ShaderUsage::UseUniform); + const auto& output = sh.AddOutput("y", ShaderUsage::UseUniform); + std::string expression; + switch (to_) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + expression = "vec4(a)"; + break; + default: + ORT_NOT_IMPLEMENTED("Cast to type ", to_, " is not supported."); + } + sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " let a = " << input.GetByOffset("global_idx") << ";\n " + << output.SetByOffset("global_idx", expression); + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.h b/onnxruntime/core/providers/webgpu/tensor/cast.h new file mode 100644 index 0000000000000..47e8e6412be46 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/cast.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class CastProgram final : public Program { + public: + CastProgram(int32_t to) : Program{"Cast"}, to_{to} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + int32_t to_; +}; + +class Cast final : public WebGpuKernel { + public: + Cast(const OpKernelInfo& info) : WebGpuKernel(info) { + int64_t to; + Status status = info.GetAttr("to", &to); + ORT_ENFORCE(status.IsOK(), "Attribute to is not set."); + to_ = SafeInt(to); + + // ignore attribute 'saturate' as float8 is not supported in WebGPU + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int32_t to_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc new file mode 100644 index 0000000000000..c708f24dcc330 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/webgpu/tensor/concat.h" + +#include "core/common/inlined_containers.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +#define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + start, \ + end, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + Concat); + +#define WEBGPU_CONCAT_KERNEL(version) \ + ONNX_OPERATOR_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + version, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + Concat); + +WEBGPU_CONCAT_VERSIONED_KERNEL(1, 3) +WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) +WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) +WEBGPU_CONCAT_KERNEL(13) + +void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) { + os << "fn calculate_input_index(index: u32) -> u32 {\n" + << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" + << " if (index < " << GetElementAt("uniforms.size_in_concat_axis", "i", input_count) << ") {\n" + << " return i;\n" + << " }\n" + << " }\n" + << " return " << input_count << ";\n" + << "}\n"; +} + +void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const ShaderVariableHelper& output) { + os << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n"; + for (size_t i = 0; i < inputs.size(); ++i) { + if (i == 0) { + os << " if (input_index == 0u) {\n"; + } else if (i == inputs.size() - 1) { + os << " } else {\n"; + } else { + os << " } else if (input_index == " << i << "u) {\n"; + } + os << " " << output.SetByOffset("global_idx", inputs[i]->GetByIndices("indices")) << ";\n"; + } + os << " }\n" + "}\n"; +} + +Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { + size_t input_count = Inputs().size(); + std::vector inputs; + inputs.reserve(input_count); + for (size_t i = 0; i < input_count; ++i) { + inputs.push_back(&shader.AddInput("input_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias)); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + + // add implementation of fn calculate_input_index + AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count); + // add implementation of fn assign_output_data + AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output); + const std::string size_in_concat_axis = GetElementAt("uniforms.size_in_concat_axis", "input_index - 1", input_count); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " var indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " let indices_axis = " << output.IndicesGet("indices", axis_) << ";\n" + << " let input_index = calculate_input_index(indices_axis);\n" + << " if (input_index != 0u) {\n" + << " " << output.IndicesSet("indices", axis_, "indices_axis - " + size_in_concat_axis) << ";\n" + << " }\n" + " assign_output_data(global_idx, input_index, indices);\n"; + return Status::OK(); +} + +Status Concat::ComputeInternal(ComputeContext& context) const { + int input_count = context.InputCount(); + InlinedTensorsVector input_tensors; + input_tensors.reserve(input_count); + for (int i = 0; i < input_count; ++i) { + input_tensors.push_back(context.Input(i)); + } + + Prepare prepare; + ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), input_tensors, prepare)); + if (prepare.output_num_elements == 0) { + return Status::OK(); + } + + uint32_t output_size = gsl::narrow_cast(prepare.output_tensor->Shape().Size()); + + ConcatProgram program{prepare.axis}; + + std::vector sizes_in_concat_axis; + sizes_in_concat_axis.reserve(input_count); + uint32_t sum = 0; + for (int i = 0; i < input_count; ++i) { + const auto& input = prepare.inputs[i]; + if (input.tensor->Shape().Size() == 0) { + continue; + } + program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); + + auto axis_size = input.tensor->Shape()[prepare.axis]; + sum += static_cast(axis_size); + sizes_in_concat_axis.push_back(sum); + } + + size_t non_empty_input_count = sizes_in_concat_axis.size(); + + if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) { + // TODO: support when input_count + 1 > maxStorageBuffersPerShaderStage, by raising the limit or run the program in multiple passes. + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The number of storage buffer (input=", + input_count, ", output=1) exceeds the limit (", + context.DeviceLimits().maxStorageBuffersPerShaderStage, ") of the device."); + } + + program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ",")) + .AddOutputs({prepare.output_tensor}) + .SetDispatchGroupSize((prepare.output_num_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), + output_size}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.h b/onnxruntime/core/providers/webgpu/tensor/concat.h new file mode 100644 index 0000000000000..0f6e6dd327e33 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/concat.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/concatbase.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class ConcatProgram final : public Program { + public: + ConcatProgram(size_t axis) : Program{"Concat"}, axis_{axis} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"size_in_concat_axis", ProgramUniformVariableDataType::Uint32}, + {"output_size", ProgramUniformVariableDataType::Uint32}); + + private: + size_t axis_; +}; + +class Concat final : public WebGpuKernel, public ConcatBase { + public: + Concat(const OpKernelInfo& info) : WebGpuKernel(info), ConcatBase(info) { + } + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc new file mode 100644 index 0000000000000..84cdb35d77f0b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" + +#include "core/providers/webgpu/tensor/expand.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n " + << output.SetByOffset("global_idx", input.GetByOffset("input_offset")); + + return Status::OK(); +} + +Status Expand::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const auto* input_shape_tensor = context.Input(1); + + auto output_dims = input_shape_tensor->DataAsSpan(); + TensorShape output_shape{}; + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape)); + + auto* output_tensor = context.Output(0, output_shape); + uint32_t data_size = SafeInt(output_shape.Size()); + ExpandProgram program{}; + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank}}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {data_size}, + }); + return context.RunProgram(program); +} + +#define WEBGPU_EXPAND_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \ + KERNEL_CLASS); + +#define WEBGPU_EXPAND_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \ + KERNEL_CLASS); + +WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedFloatTypes()) +WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedFloatTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.h b/onnxruntime/core/providers/webgpu/tensor/expand.h new file mode 100644 index 0000000000000..046520b479257 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/expand.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class ExpandProgram final : public Program { + public: + ExpandProgram() : Program{"Expand"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}); +}; + +class Expand final : public WebGpuKernel { + public: + Expand(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc new file mode 100644 index 0000000000000..47b78e7015135 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/gather.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& data = shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " var indices_indices = input_indices_indices_t(0);\n"; + for (int i = 0; i < indices.Rank(); i++) { + shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis_ + i)) << ";\n"; + } + shader.MainFunctionBody() << " var idx = " << indices.GetByIndices("indices_indices") << ";\n" + << " if (idx < 0) {\n" + << " idx = idx + input_indices_value_t(" << data.IndicesGet("uniforms.data_shape", axis_) << ");\n" + << " }\n" + << " var data_indices : data_indices_t;\n"; + for (int i = 0, j = 0; i < data.Rank(); i++) { + if (i == SafeInt(axis_)) { + shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; + j += indices.Rank(); + } else { + shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n"; + j++; + } + } + + shader.MainFunctionBody() << " " << output.SetByOffset("global_idx", data.GetByIndices("data_indices")); + + return Status::OK(); +} + +Status Gather::ComputeInternal(ComputeContext& context) const { + Prepare p; + ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), p)); + uint32_t data_size = SafeInt(p.output_tensor->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + uint32_t axis = static_cast(p.axis); + GatherProgram program{axis}; + program + .AddInputs({{p.input_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {p.indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({p.output_tensor, ProgramTensorMetadataDependency::Rank}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .CacheHint(std::to_string(axis)) + .AddUniformVariables({{data_size}}); + return context.RunProgram(program); +} + +#define WEBGPU_GATHER_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), \ + KERNEL_CLASS); + +#define WEBGPU_GATHER_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), \ + KERNEL_CLASS); + +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 1, 10, Gather, WebGpuSupportedNumberTypes()) +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 11, 12, Gather, WebGpuSupportedNumberTypes()) +WEBGPU_GATHER_KERNEL(Gather, 13, Gather, WebGpuSupportedNumberTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.h b/onnxruntime/core/providers/webgpu/tensor/gather.h new file mode 100644 index 0000000000000..bebe13519ce43 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/tensor/gatherbase.h" + +namespace onnxruntime { +namespace webgpu { + +class GatherProgram final : public Program { + public: + GatherProgram(const uint32_t axis) : Program{"Gather"}, axis_{axis} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t axis_; +}; + +class Gather final : public WebGpuKernel, public GatherBase { + public: + Gather(const OpKernelInfo& info) : WebGpuKernel(info), GatherBase(info) {} + + protected: + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/reshape.cc b/onnxruntime/core/providers/webgpu/tensor/reshape.cc new file mode 100644 index 0000000000000..9ede015a0c99c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/reshape.cc @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/reshape.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + Reshape, + kOnnxDomain, + 21, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 14, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 13, 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 5, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/reshape.h b/onnxruntime/core/providers/webgpu/tensor/reshape.h new file mode 100644 index 0000000000000..4629598d068f7 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/reshape.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/data_transfer_manager.h" +#include "core/providers/cpu/tensor/reshape_helper.h" + +namespace onnxruntime { +namespace webgpu { + +class Reshape final : public OpKernel { + public: + Reshape(const OpKernelInfo& info) + : OpKernel{info}, + allow_zero_(info.GetAttrOrDefault("allowzero", static_cast(0)) == 1) { + } + + Status Compute(OpKernelContext* context) const override { + // Copy the second input tensor into the shape vector + const Tensor* shapeTensor = context->Input(1); + if (shapeTensor == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + if (shapeTensor->Shape().NumDimensions() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "A shape tensor must be a vector tensor, got ", shapeTensor->Shape().NumDimensions(), " dimensions"); + } + auto data_span = shapeTensor->template DataAsSpan(); + TensorShapeVector shape(data_span.begin(), data_span.end()); + const Tensor* X = context->Input(0); + if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + const TensorShape& X_shape = X->Shape(); + + ReshapeHelper helper(X_shape, shape, allow_zero_); + + Tensor* Y = context->Output(0, TensorShape(shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } + + private: + bool allow_zero_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/shape_op.cc b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc new file mode 100644 index 0000000000000..b211d48dab1c9 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/shape_op.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 1, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 13, 14, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 15, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_KERNEL_EX( + Shape, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/squeeze.cc b/onnxruntime/core/providers/webgpu/tensor/squeeze.cc new file mode 100644 index 0000000000000..136a1ba9776a0 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/squeeze.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/squeeze.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + Squeeze, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Squeeze); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/squeeze.h b/onnxruntime/core/providers/webgpu/tensor/squeeze.h new file mode 100644 index 0000000000000..bc80cb86d5e8e --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/squeeze.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/tensor/squeeze.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +class Squeeze final : public OpKernel, public SqueezeBase { + public: + explicit Squeeze(const OpKernelInfo& info) : OpKernel{info}, SqueezeBase(info) {} + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set"); + } + const TensorShape& X_shape = X->Shape(); + + TensorShapeVector axes; + size_t num_inputs = context->InputCount(); + if (num_inputs == 2) { // axes is an input + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a vector tensor."); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->Data(); + axes.assign(data, data + nDims); + } else { + axes.assign(axes_.begin(), axes_.end()); + } + + TensorShapeVector output_shape = ComputeOutputShape(X_shape, axes); + Tensor* Y = context->Output(0, TensorShape(output_shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/tile.cc b/onnxruntime/core/providers/webgpu/tensor/tile.cc new file mode 100644 index 0000000000000..841c36724df30 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/tile.cc @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/tile.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Tile, + kOnnxDomain, + 6, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Tile); + +ONNX_OPERATOR_KERNEL_EX( + Tile, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Tile); + +Status TileProgram::GenerateShaderCode(ShaderHelper& shader) const { + const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << "var input_indices: input_indices_t;\n"; + for (auto i = 0; i < input.Rank(); i++) { + std::string input_dim_i = absl::StrCat("input_dim_", i); + std::string input_dim_value = absl::StrCat("input_dim_", i, "_value"); + shader.MainFunctionBody() << "let " << input_dim_i << " = " << input.IndicesGet("uniforms.input_shape", i) << ";\n" + << "let " << input_dim_value << " = " << output.IndicesGet("output_indices", i) << " % " << input_dim_i << ";\n" + << input.IndicesSet("input_indices", i, input_dim_value) << ";\n"; + } + + shader.MainFunctionBody() << output.SetByOffset("global_idx", input.GetByIndices("input_indices")); + + return Status::OK(); +} + +Status Tile::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + size_t input_rank = input_shape.NumDimensions(); + + const auto* repeats_tensor = context.Input(1); + const auto* repeats_data = repeats_tensor->Data(); + std::vector repeats; + + for (size_t i = 0; i < static_cast(repeats_tensor->Shape().Size()); i++) { + repeats.push_back(static_cast(repeats_data[i])); + } + + auto output_dims = input_shape.AsShapeVector(); + for (size_t axis = 0; axis < input_rank; axis++) { + output_dims[axis] *= repeats[axis]; + } + + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + int64_t output_size = output_tensor->Shape().Size(); + + if (output_size == 0) { + return Status::OK(); + } + + TileProgram program{}; + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({output_tensor}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(output_size)}, + {repeats}}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/tile.h b/onnxruntime/core/providers/webgpu/tensor/tile.h new file mode 100644 index 0000000000000..9b6ab420b3252 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/tile.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class TileProgram final : public Program { + public: + TileProgram() : Program{"Tile"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"repeats", ProgramUniformVariableDataType::Uint32}); +}; + +class Tile final : public WebGpuKernel { + public: + Tile(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc new file mode 100644 index 0000000000000..c40ec43dd0009 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/transpose.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 1, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 13, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_KERNEL_EX( + Transpose, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +auto SqueezeShape(const gsl::span& shape, const gsl::span& adjusted_perm, InlinedVector& new_shape, InlinedVector& new_perm) { + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] != 1) { + new_shape.push_back(shape[i]); + } + if (shape[adjusted_perm[i]] != 1) { + new_perm.push_back(adjusted_perm[i]); + } + } +}; + +Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + + if (use_shared_) { + shader.AdditionalImplementation() << "var tile : array, tile_size>;\n"; + shader.MainFunctionBody() << " let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n" + " let workgroup_id_x = workgroup_idx % stride;\n" + " let workgroup_id_y = workgroup_idx / stride;\n" + " let input_col = workgroup_id_y * tile_size + local_id.x;\n" + " let input_row = workgroup_id_x * tile_size + local_id.y;\n" + " if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n" + << " tile[local_id.y][local_id.x] = " << input.GetByIndices("a_indices_t(input_row, input_col)") << ";\n" + << " }\n" + " workgroupBarrier();\n" + " let output_col = workgroup_id_x * tile_size + local_id.x;\n" + " let output_row = workgroup_id_y * tile_size + local_id.y;\n" + " if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n" + << " " << output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") << "\n" + << " }"; + } else { + shader.AdditionalImplementation() << "fn perm(i: output_indices_t)->a_indices_t {\n" + " var a: a_indices_t;\n"; + for (size_t i = 0; i < perm_.size(); ++i) { + shader.AdditionalImplementation() << " a[" << perm_[i] << "] = i[" << i << "];\n"; + } + shader.AdditionalImplementation() << " return a;\n" + "}\n"; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " let indices = " << output.OffsetToIndices("global_idx") + << ";\n" + " let x_indices = perm(indices);\n" + " " + << output.SetByOffset("global_idx", input.GetByIndices("x_indices")); + } + return Status::OK(); +} + +Status Transpose::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int32_t rank = gsl::narrow_cast(input_shape.NumDimensions()); + + TensorShapeVector output_dims(rank); + InlinedVector default_perm(rank); + const InlinedVector* p_perm = nullptr; + ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm)); + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + + InlinedVector new_shape{}; + InlinedVector new_perm{}; + SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm); + const bool channels_last = new_perm == InlinedVector({2, 3, 1}); + const bool channels_first = new_perm == InlinedVector({3, 1, 2}); + const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first; + auto new_input_shape = input_shape; + TensorShape new_output_shape(output_dims); + if (use_shared) { + new_input_shape = channels_last + ? TensorShape({new_shape[0], new_shape[1] * new_shape[2]}) + : channels_first + ? TensorShape({new_shape[0] * new_shape[1], new_shape[2]}) + : new_shape; + new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]}); + } + + uint32_t output_size = gsl::narrow_cast(input_tensor->Shape().Size()); + TransposeProgram program{*p_perm, use_shared}; + if (use_shared) { + program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1); + } + + program + .CacheHint(absl::StrJoin(*p_perm, "-")) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) + .SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) + .AddUniformVariables({ + {static_cast(output_size)}, + }); + + use_shared ? program.SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) + : program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h new file mode 100644 index 0000000000000..7cf5c1fe0865d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/transpose.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class Transpose final : public WebGpuKernel, public TransposeBase { + public: + Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { + } + Status ComputeInternal(ComputeContext& context) const override; + constexpr static uint32_t TILE_SIZE = 16; +}; + +class TransposeProgram final : public Program { + public: + TransposeProgram(const gsl::span& permutations, bool use_shared) + : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()), use_shared_(use_shared) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_CONSTANTS({"tile_size", Transpose::TILE_SIZE}); + + private: + InlinedVector perm_; + const bool use_shared_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc new file mode 100644 index 0000000000000..4bcef4fd79296 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/unsqueeze.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Unsqueeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Unsqueeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Unsqueeze); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.h b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.h new file mode 100644 index 0000000000000..0ae9d50f6d4e7 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/tensor/unsqueeze.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +class Unsqueeze final : public OpKernel, public UnsqueezeBase { + public: + explicit Unsqueeze(const OpKernelInfo& info) : OpKernel{info}, UnsqueezeBase(info) {} + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set"); + } + const TensorShape& X_shape = X->Shape(); + + TensorShapeVector axes; + size_t num_inputs = context->InputCount(); + if (num_inputs == 2) { // axes is an input + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 || + axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a scalar or a vector tensor."); + auto data_span = axes_tensor->template DataAsSpan(); + axes.assign(data_span.begin(), data_span.end()); + } else { + axes.assign(axes_.begin(), axes_.end()); + } + + TensorShapeVector output_shape = ComputeOutputShape(X_shape, axes); + Tensor* Y = context->Output(0, TensorShape(output_shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc new file mode 100644 index 0000000000000..dada446b4bd47 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/where.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +// Compute where operator output shape based upon three way broad-casting. +Status ComputeOutputShape(const TensorShape& cond_shape, + const TensorShape& x_shape, const TensorShape& y_shape, TensorShape& output_shape) { + size_t cond_rank = cond_shape.NumDimensions(); + size_t x_rank = x_shape.NumDimensions(); + size_t y_rank = y_shape.NumDimensions(); + size_t output_rank = std::max(std::max(cond_rank, x_rank), y_rank); + + std::vector output_dims(output_rank, 0); + for (size_t i = 0; i < output_rank; ++i) { + int64_t cond_dim = 1; + if (i < cond_rank) + cond_dim = cond_shape[cond_rank - 1 - i]; + + int64_t x_dim = 1; + if (i < x_rank) + x_dim = x_shape[x_rank - 1 - i]; + + int64_t y_dim = 1; + if (i < y_rank) + y_dim = y_shape[y_rank - 1 - i]; + + int64_t output_dim = std::max(std::max(cond_dim, x_dim), y_dim); + // special case to handle a dim of 0 which can be broadcast with a 1 + if (output_dim == 1) + output_dim = std::min(std::min(cond_dim, x_dim), y_dim); + + const auto node_name = "Where"; + if (cond_dim != output_dim && cond_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": condition operand cannot broadcast on dim ", cond_rank - 1 - i, + " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); + if (x_dim != output_dim && x_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": X operand cannot broadcast on dim ", x_rank - 1 - i, + " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); + if (y_dim != output_dim && y_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": Y operand cannot broadcast on dim ", y_rank - 1 - i, + " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); + output_dims[output_rank - 1 - i] = output_dim; + } + + output_shape = TensorShape(output_dims); + return Status::OK(); +} + +Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& c_input = shader.AddInput("c_data", ShaderUsage::UseUniform); + const auto& a_input = shader.AddInput("a_data", ShaderUsage::UseUniform); + const auto& b_input = shader.AddInput("b_data", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output_data", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); + + const auto expression = [](std::string_view a, std::string_view b, std::string_view c) -> auto { + return absl::StrCat("select(", b, ", ", a, ", ", c, ")"); + }; + + if (!is_broadcast_) { + shader.MainFunctionBody() << output.SetByOffset( + "global_idx", + expression(a_input.GetByOffset("global_idx"), b_input.GetByOffset("global_idx"), c_input.GetByOffset("global_idx"))); + + } else { + const auto& c_indices = shader.AddIndices("c_indices"); + const auto& a_indices = shader.AddIndices("a_indices"); + const auto& b_indices = shader.AddIndices("b_indices"); + const auto& output_indices = shader.AddIndices("output_indices"); + + const auto single_assignment = + [&expression, &shader, &output_indices, &a_indices, &b_indices, &c_indices]( + std::string_view rest_str, const std::string& x, std::string_view type_cast = "") + -> void { + const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]"; + const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]"; + const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))"; + + shader.MainFunctionBody() << "let output_indices" << x << " = " << output_indices.OffsetToIndices("global_idx * 4 + " + x) << ";\n" + << "let offset_a" << x << " = " << a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let offset_b" << x << " = " << b_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let offset_c" << x << " = " << c_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let index_a" << x << " = offset_a" << x << " / 4;\n" + << "let index_b" << x << " = offset_b" << x << " / 4;\n" + << "let index_c" << x << " = offset_c" << x << " / 4;\n" + << "let component_a" << x << " = offset_a" << x << " % 4;\n" + << "let component_b" << x << " = offset_b" << x << " % 4;\n" + << "let component_c" << x << " = offset_c" << x << " % 4;\n" + << rest_str << "[" << x << "] = " << type_cast << "(" << expression(a_expression, b_expression, c_expression) << ");\n"; + }; + + if (Outputs()[0].tensor->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_BOOL) { + shader.MainFunctionBody() << "var data = vec4(0);\n"; + single_assignment("data", "0", "u32"); + single_assignment("data", "1", "u32"); + single_assignment("data", "2", "u32"); + single_assignment("data", "3", "u32"); + shader.MainFunctionBody() << "output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));\n"; + } else { + single_assignment("output_data[global_idx]", "0"); + single_assignment("output_data[global_idx]", "1"); + single_assignment("output_data[global_idx]", "2"); + single_assignment("output_data[global_idx]", "3"); + } + } + + return Status::OK(); +} + +Status Where::ComputeInternal(ComputeContext& context) const { + const auto* cond_tensor = context.Input(0); + const auto* x_tensor = context.Input(1); + const auto* y_tensor = context.Input(2); + const auto& cond_shape = cond_tensor->Shape(); + const auto& x_shape = x_tensor->Shape(); + const auto& y_shape = y_tensor->Shape(); + + TensorShape output_shape; + ORT_RETURN_IF_ERROR(ComputeOutputShape(cond_shape, x_shape, y_shape, output_shape)); + auto* output_tensor = context.Output(0, output_shape); + const auto component = 4; + uint32_t vec_size = gsl::narrow_cast((output_shape.Size() + 3) / component); + const auto is_broadcast = !(x_shape == y_shape && + y_shape == cond_shape); + WhereProgram program{is_broadcast}; + program + .CacheHint(is_broadcast) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Rank, {(cond_shape.Size() + 3) / 4}, 4}, + {x_tensor, ProgramTensorMetadataDependency::Rank, {(x_shape.Size() + 3) / 4}, 4}, + {y_tensor, ProgramTensorMetadataDependency::Rank, {(y_shape.Size() + 3) / 4}, 4}}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddUniformVariables({ + {static_cast(vec_size)}, + }); + if (is_broadcast) { + program + .AddIndices(cond_shape) + .AddIndices(x_shape) + .AddIndices(y_shape) + .AddIndices(output_tensor->Shape()); + } + return context.RunProgram(program); +} + +namespace { +const std::vector& WhereOpTypeConstraints() { + // currently support boolean, integer and float types that explicitly allowed in WGSL: + // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section + // + static std::vector types{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + return types; +} +} // namespace + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Where, + kOnnxDomain, + 9, 15, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WhereOpTypeConstraints()), + Where); + +ONNX_OPERATOR_KERNEL_EX( + Where, + kOnnxDomain, + 16, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WhereOpTypeConstraints()), + Where); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/where.h b/onnxruntime/core/providers/webgpu/tensor/where.h new file mode 100644 index 0000000000000..e46b24e9ba2e5 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/where.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/transpose.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class WhereProgram final : public Program { + public: + WhereProgram(bool is_broadcast) : Program{"Where"}, is_broadcast_{is_broadcast} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + const bool is_broadcast_; +}; + +class Where final : public WebGpuKernel { + public: + Where(const OpKernelInfo& info) : WebGpuKernel{info} { + } + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc new file mode 100644 index 0000000000000..eb03bf83763f1 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -0,0 +1,677 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "dawn/dawn_proc.h" +#if !defined(USE_EXTERNAL_DAWN) +#include "dawn/native/DawnNative.h" +#endif + +#include "core/common/common.h" + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/program_cache_key.h" +#include "core/providers/webgpu/program_manager.h" +#include "core/providers/webgpu/string_macros.h" + +namespace onnxruntime { +namespace webgpu { + +void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info, const void* dawn_proc_table) { + std::call_once(init_flag_, [this, &webgpu_ep_info, dawn_proc_table]() { + // Initialization.Step.1 - Create wgpu::Instance + if (instance_ == nullptr) { + const DawnProcTable* dawn_procs = reinterpret_cast(dawn_proc_table); +#if !defined(USE_EXTERNAL_DAWN) + if (dawn_procs == nullptr) { + dawn_procs = &dawn::native::GetProcs(); + } +#else + ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided."); +#endif + dawnProcSetProcs(dawn_procs); + + wgpu::InstanceDescriptor instance_desc{}; + instance_desc.features.timedWaitAnyEnable = true; + instance_ = wgpu::CreateInstance(&instance_desc); + + ORT_ENFORCE(instance_ != nullptr, "Failed to create wgpu::Instance."); + } + + // Initialization.Step.2 - Create wgpu::Adapter + if (adapter_ == nullptr) { + wgpu::RequestAdapterOptions req_adapter_options = {}; + wgpu::DawnTogglesDescriptor adapter_toggles_desc = {}; + req_adapter_options.nextInChain = &adapter_toggles_desc; +#ifdef _WIN32 + req_adapter_options.backendType = wgpu::BackendType::D3D12; +#endif + req_adapter_options.powerPreference = wgpu::PowerPreference::HighPerformance; + + auto enabled_adapter_toggles = GetEnabledAdapterToggles(); + adapter_toggles_desc.enabledToggleCount = enabled_adapter_toggles.size(); + adapter_toggles_desc.enabledToggles = enabled_adapter_toggles.data(); + + wgpu::RequestAdapterCallbackInfo req_adapter_callback_info = {}; + req_adapter_callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; + req_adapter_callback_info.callback = [](WGPURequestAdapterStatus status, + WGPUAdapter adapter, const char* message, + void* userdata) { + ORT_ENFORCE(status == WGPURequestAdapterStatus_Success, "Failed to get a WebGPU adapter: ", message); + *static_cast(userdata) = wgpu::Adapter::Acquire(adapter); + }; + req_adapter_callback_info.userdata = &adapter_; + ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(instance_.RequestAdapter(&req_adapter_options, req_adapter_callback_info), UINT64_MAX)); + ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter."); + } + + // Initialization.Step.3 - Create wgpu::Device + if (device_ == nullptr) { + wgpu::DeviceDescriptor device_desc = {}; + wgpu::DawnTogglesDescriptor device_toggles_desc = {}; + device_desc.nextInChain = &device_toggles_desc; + + auto enabled_device_toggles = GetEnabledDeviceToggles(); + device_toggles_desc.enabledToggleCount = enabled_device_toggles.size(); + device_toggles_desc.enabledToggles = enabled_device_toggles.data(); + + auto disabled_device_toggles = GetDisabledDeviceToggles(); + device_toggles_desc.disabledToggleCount = disabled_device_toggles.size(); + device_toggles_desc.disabledToggles = disabled_device_toggles.data(); + + std::vector required_features = GetAvailableRequiredFeatures(adapter_); + if (required_features.size() > 0) { + device_desc.requiredFeatures = required_features.data(); + device_desc.requiredFeatureCount = required_features.size(); + } + wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter_); + device_desc.requiredLimits = &required_limits; + + // TODO: revise temporary error handling + device_desc.SetUncapturedErrorCallback([](const wgpu::Device& /*device*/, wgpu::ErrorType type, const char* message) { + LOGS_DEFAULT(ERROR) << "WebGPU device error(" << int(type) << "): " << message; + }); + // TODO: revise temporary device lost handling + device_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous, [](const wgpu::Device& /*device*/, wgpu::DeviceLostReason reason, const char* message) { + // cannot use ORT logger because it may be already destroyed + std::cerr << "WebGPU device lost (" << int(reason) << "): " << message; + }); + + wgpu::RequestDeviceCallbackInfo req_device_callback_info = {}; + req_device_callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; + req_device_callback_info.callback = [](WGPURequestDeviceStatus status, WGPUDevice device, char const* message, void* userdata) { + ORT_ENFORCE(status == WGPURequestDeviceStatus_Success, "Failed to get a WebGPU device: ", message); + *static_cast(userdata) = wgpu::Device::Acquire(device); + }; + req_device_callback_info.userdata = &device_; + ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter_.RequestDevice(&device_desc, req_device_callback_info), UINT64_MAX)); + ORT_ENFORCE(device_ != nullptr, "Failed to get a WebGPU device."); + } + + // cache adapter info + ORT_ENFORCE(Adapter().GetInfo(&adapter_info_)); + // cache device limits + wgpu::SupportedLimits device_supported_limits; + ORT_ENFORCE(Device().GetLimits(&device_supported_limits)); + device_limits_ = device_supported_limits.limits; + + // create buffer manager + buffer_mgr_ = BufferManagerFactory::Create(*this, webgpu_ep_info.storage_buffer_cache_mode, webgpu_ep_info.uniform_buffer_cache_mode, webgpu_ep_info.query_resolve_buffer_cache_mode); + + // create program manager + program_mgr_ = std::make_unique(Device(), DeviceLimits()); + + // set query type + if (device_.HasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) { + query_type_ = TimestampQueryType::InsidePasses; + } else if (device_.HasFeature(wgpu::FeatureName::TimestampQuery)) { + query_type_ = TimestampQueryType::AtPasses; + } else { + query_type_ = TimestampQueryType::None; + } + }); +} + +Status WebGpuContext::Wait(wgpu::Future f) { + auto status = instance_.WaitAny(f, UINT64_MAX); + if (status == wgpu::WaitStatus::Success) { + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); +} + +Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { + const auto& inputs = program.Inputs(); + const auto& outputs = program.Outputs(); + + if (outputs.size() == 0) { + return Status::OK(); + } + + if (ValidationMode() >= ValidationMode::Basic) { + ORT_ENFORCE(std::all_of(inputs.begin(), inputs.end(), [](const ProgramInput& input) { + const auto* tensor = input.tensor; + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + !strcmp(tensor->Location().name, WEBGPU_BUFFER); + }), + "All inputs must be tensors on WebGPU buffers."); + + ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](const ProgramOutput& output) { + const auto* tensor = output.tensor; + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + !strcmp(tensor->Location().name, WEBGPU_BUFFER); + }), + "All outputs must be tensors on WebGPU buffers."); + } + + const ProgramMetadata& metadata = program.Metadata(); + + // validate program metadata + if (ValidationMode() >= ValidationMode::Basic) { + const auto& [constants, overridable_constants, uniform_variables] = metadata; + + // check overridable constants + ORT_RETURN_IF(program.OverridableConstants().size() != overridable_constants.size(), + "Size of overridable constants mismatch in program \"", program.Name(), + "\", Expected: ", overridable_constants.size(), + ", Actual: ", program.OverridableConstants().size()); + + if (ValidationMode() >= ValidationMode::Full) { + size_t num_overridable_constants = program.OverridableConstants().size(); + for (size_t i = 0; i < num_overridable_constants; ++i) { + const auto& override_value = program.OverridableConstants()[i]; + const auto& definition = overridable_constants[i]; + ORT_RETURN_IF(override_value.has_value && override_value.type != definition.type, + "Overridable override_value[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.type, + ", Actual: ", override_value.type); + ORT_RETURN_IF(!override_value.has_value && !definition.has_default_value, + "Overridable override_value[", i, "] (", definition.name, ") no override_value specified in program \"", program.Name(), + "\""); + } + } + + // check uniform variables + ORT_RETURN_IF(program.UniformVariables().size() != uniform_variables.size(), + "Size of uniform_value variables mismatch in program \"", program.Name(), + "\", Expected: ", uniform_variables.size(), + ", Actual: ", program.UniformVariables().size()); + + if (ValidationMode() >= ValidationMode::Full) { + size_t num_uniform_variables = program.UniformVariables().size(); + for (size_t i = 0; i < num_uniform_variables; ++i) { + const auto& uniform_value = program.UniformVariables()[i]; + const auto& definition = uniform_variables[i]; + ORT_RETURN_IF(uniform_value.length > 0 && uniform_value.data_type != definition.data_type, + "Uniform variable[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.data_type, + ", Actual: ", uniform_value.data_type); + } + } + } + + uint32_t x = program.DispatchGroupSizeX(); + uint32_t y = program.DispatchGroupSizeY(); + uint32_t z = program.DispatchGroupSizeZ(); + ORT_RETURN_IF_ERROR(program_mgr_->NormalizeDispatchGroupSize(x, y, z)); + + bool is_1d_dispatch = (y == 1 && z == 1); + + auto key = CalculateProgramCacheKey(program, is_1d_dispatch); + + if (is_profiling_) { + PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(), + program.Name(), + key, + inputs, + outputs); + pending_kernels_.emplace_back(std::move(pending_kernel_info)); + } + + LOGS(context.Logger(), INFO) << "Starting program \"" << key << "\" (" << x << ", " << y << ", " << z << ")"; + + const auto* program_artifact = program_mgr_->Get(key); + if (program_artifact == nullptr) { + wgpu::ComputePipeline compute_pipeline; + std::vector shape_uniform_ranks; + auto status = program_mgr_->Build(program, + metadata, +#ifndef NDEBUG // if debug build + key, +#endif + x, + y, + z, + compute_pipeline, + shape_uniform_ranks); + ORT_RETURN_IF_ERROR(status); + program_artifact = program_mgr_->Set(key, ProgramArtifact{program, + std::move(compute_pipeline), + std::move(shape_uniform_ranks)}); +#ifndef NDEBUG // if debug build + ORT_ENFORCE(program_artifact != nullptr, "Program artifact should not be nullptr."); +#endif + } + + // prepare shape uniforms for shader variables (if any) and user defined uniforms + std::vector shape_uniforms; + shape_uniforms.reserve(program_artifact->shape_uniform_ranks.size() * 2); + if (ValidationMode() >= ValidationMode::Basic) { + ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size() + program.Indices().size(), + "Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(), + ") does not match current program (input: ", inputs.size(), + ", output: ", outputs.size(), + ", indices: ", program.Indices().size(), ")"); + } + + auto append_shape_uniforms = [&shape_uniforms, program_artifact](size_t i, const TensorShape& shape) { + SafeInt expected_rank = program_artifact->shape_uniform_ranks[i]; + if (expected_rank > 0) { + ORT_RETURN_IF(expected_rank != shape.NumDimensions(), + "Invalid program artifact: variable[", i, "] rank mismatch. Expected: ", (int)expected_rank, + ", Actual: ", shape.NumDimensions()); + + std::vector dims(expected_rank); + std::vector stride(expected_rank - 1); + for (size_t j = 0; j < expected_rank; ++j) { + dims[j] = SafeInt(shape[j]); + if (j < expected_rank - 1) { + stride[j] = SafeInt(shape.SizeFromDimension(j + 1)); + } + } + + shape_uniforms.emplace_back(gsl::make_span(dims)); + if (expected_rank > 1) { + shape_uniforms.emplace_back(gsl::make_span(stride)); + } + } + return Status::OK(); + }; + + for (size_t i = 0; i < inputs.size(); i++) { + ORT_RETURN_IF_ERROR(append_shape_uniforms(i, + inputs[i].use_override_shape ? inputs[i].override_shape : inputs[i].tensor->Shape())); + } + for (size_t i = 0; i < outputs.size(); i++) { + ORT_RETURN_IF_ERROR(append_shape_uniforms(i + inputs.size(), + outputs[i].use_override_shape ? outputs[i].override_shape : outputs[i].tensor->Shape())); + } + for (size_t i = 0; i < program.Indices().size(); i++) { + ORT_RETURN_IF_ERROR(append_shape_uniforms(i + inputs.size() + outputs.size(), program.Indices()[i])); + } + + const size_t uniform_count = shape_uniforms.size() + program.UniformVariables().size(); + size_t current_offset = 0; + std::vector> uniform_and_offsets; + uniform_and_offsets.reserve(uniform_count); + for (size_t i = 0; i < uniform_count; i++) { + const auto& uniform = i < shape_uniforms.size() ? shape_uniforms[i] + : program.UniformVariables()[i - shape_uniforms.size()]; + size_t length = uniform.length; + if (length == 0) { // skip zero-length uniform + continue; + } + + bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + + size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; + // https://www.w3.org/TR/WGSL/#alignof + size_t base_alignment = is_f16 + ? (length > 4 ? 16 : length > 2 ? 8 + : length * element_size) + : (length > 2 ? 16 : length * element_size); + size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; + + current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; + uniform_and_offsets.emplace_back(uniform, current_offset); + + // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). + // For float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). + size_t element_per_struct = is_f16 ? 8 : 4; + current_offset += + length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; + } + + // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set + // max_alignment_of_field to 16 since the underlying buffer has been rounded up to 16. + const size_t max_alignment_of_field = 16; + const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; + + WGPUBuffer uniform_buffer = nullptr; + if (uniform_buffer_total_size > 0) { + std::vector uniform_data_buffer(uniform_buffer_total_size); + + for (auto const& [uniform, offset] : uniform_and_offsets) { + memcpy(uniform_data_buffer.data() + offset, uniform.data.data(), uniform.data.size()); + } + + uniform_buffer = buffer_mgr_->Create(uniform_buffer_total_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); + device_.GetQueue().WriteBuffer(uniform_buffer, 0, uniform_data_buffer.data(), uniform_buffer_total_size); + } + + const auto& compute_pass_encoder = GetComputePassEncoder(); + + WriteTimestamp(num_pending_dispatches_ * 2); + + uint32_t entry_index = 0; + std::vector bind_group_entries; + for (const auto& input : inputs) { + bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(const_cast(input.tensor->DataRaw()))}); + } + for (const auto& output : outputs) { + bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(output.tensor->MutableDataRaw())}); + } + if (uniform_buffer) { + bind_group_entries.push_back({nullptr, entry_index++, uniform_buffer}); + } + + wgpu::BindGroupDescriptor bind_group_desc{}; + bind_group_desc.layout = program_artifact->compute_pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = bind_group_entries.size(); + bind_group_desc.entries = bind_group_entries.data(); + bind_group_desc.label = program_artifact->name.c_str(); + + auto bind_group = Device().CreateBindGroup(&bind_group_desc); + + // TODO support graph capture + + compute_pass_encoder.SetPipeline(program_artifact->compute_pipeline); + compute_pass_encoder.SetBindGroup(0, bind_group); + compute_pass_encoder.DispatchWorkgroups(x, y, z); + + if (uniform_buffer) { + buffer_mgr_->Release(uniform_buffer); + } + + WriteTimestamp(num_pending_dispatches_ * 2 + 1); + + ++num_pending_dispatches_; + + if (num_pending_dispatches_ >= max_num_pending_dispatches_ || + (is_profiling_ && query_type_ == TimestampQueryType::AtPasses)) { + EndComputePass(); + } + if (num_pending_dispatches_ >= max_num_pending_dispatches_) { + Flush(); + num_pending_dispatches_ = 0; + } + + return Status::OK(); +} + +std::vector WebGpuContext::GetEnabledAdapterToggles() const { + // See the description of all the toggles in toggles.cpp + // "use_dxc" for Shader Model 6+ features (e.g. float16) + // "allow_unsafe_apis" for chromium experimental features + constexpr const char* toggles[] = { + "use_dxc", + "allow_unsafe_apis", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector WebGpuContext::GetEnabledDeviceToggles() const { + // Enable / disable other toggles that may affect the performance. + // Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming" + constexpr const char* toggles[] = { + "skip_validation", // only use "skip_validation" when ValidationMode is set to "Disabled" + "disable_robustness", + "d3d_disable_ieee_strictness", + }; + return std::vector(ValidationMode() >= ValidationMode::WGPUOnly + ? std::begin(toggles) + 1 + : std::begin(toggles), + std::end(toggles)); +} + +std::vector WebGpuContext::GetDisabledDeviceToggles() const { + constexpr const char* toggles[] = { + "lazy_clear_resource_on_first_use", + "timestamp_quantization", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector WebGpuContext::GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const { + std::vector required_features; + constexpr wgpu::FeatureName features[]{ + wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, + wgpu::FeatureName::TimestampQuery, + wgpu::FeatureName::ShaderF16, + wgpu::FeatureName::Subgroups, + wgpu::FeatureName::SubgroupsF16}; + for (auto feature : features) { + if (adapter.HasFeature(feature)) { + required_features.push_back(feature); + } + } + return required_features; +} + +wgpu::RequiredLimits WebGpuContext::GetRequiredLimits(const wgpu::Adapter& adapter) const { + wgpu::RequiredLimits required_limits{}; + wgpu::SupportedLimits adapter_limits; + ORT_ENFORCE(adapter.GetLimits(&adapter_limits)); + + required_limits.limits.maxBindGroups = adapter_limits.limits.maxBindGroups; + required_limits.limits.maxComputeWorkgroupStorageSize = adapter_limits.limits.maxComputeWorkgroupStorageSize; + required_limits.limits.maxComputeWorkgroupsPerDimension = adapter_limits.limits.maxComputeWorkgroupsPerDimension; + required_limits.limits.maxStorageBufferBindingSize = adapter_limits.limits.maxStorageBufferBindingSize; + required_limits.limits.maxBufferSize = adapter_limits.limits.maxBufferSize; + required_limits.limits.maxComputeInvocationsPerWorkgroup = adapter_limits.limits.maxComputeInvocationsPerWorkgroup; + required_limits.limits.maxComputeWorkgroupSizeX = adapter_limits.limits.maxComputeWorkgroupSizeX; + required_limits.limits.maxComputeWorkgroupSizeY = adapter_limits.limits.maxComputeWorkgroupSizeY; + required_limits.limits.maxComputeWorkgroupSizeZ = adapter_limits.limits.maxComputeWorkgroupSizeZ; + + return required_limits; +} + +void WebGpuContext::WriteTimestamp(uint32_t query_index) { + if (!is_profiling_ || query_type_ != TimestampQueryType::InsidePasses) { + return; + } + + const auto& compute_pass_encoder = GetComputePassEncoder(); + compute_pass_encoder.WriteTimestamp(query_set_, query_index); +} + +void WebGpuContext::StartProfiling() { + if (query_type_ == TimestampQueryType::None) { + return; + } + + is_profiling_ = true; + + const uint32_t query_count = max_num_pending_dispatches_ * 2; + + if (!query_set_) { + // Create query set + wgpu::QuerySetDescriptor querySetDescriptor; + querySetDescriptor.count = query_count; + querySetDescriptor.type = wgpu::QueryType::Timestamp; + query_set_ = device_.CreateQuerySet(&querySetDescriptor); + } + + if (!query_resolve_buffer_) { + // Create resolve buffer + wgpu::BufferDescriptor bufferDescriptor; + bufferDescriptor.size = query_count * sizeof(uint64_t); + bufferDescriptor.usage = wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc | + wgpu::BufferUsage::CopyDst; + query_resolve_buffer_ = device_.CreateBuffer(&bufferDescriptor); + } +} + +void WebGpuContext::CollectProfilingData(profiling::Events& events) { + if (!pending_queries_.empty()) { + for (const auto& pending_query : pending_queries_) { + const auto& pending_kernels = pending_query.kernels; + const auto& query_read_buffer = pending_query.query_buffer; + + ORT_ENFORCE(Wait(query_read_buffer.MapAsync(wgpu::MapMode::Read, + 0, + query_read_buffer.GetSize(), + wgpu::CallbackMode::WaitAnyOnly, + [](wgpu::MapAsyncStatus status, const char* message) { + ORT_ENFORCE(status == wgpu::MapAsyncStatus::Success, "Failed to download data from buffer: ", message); + })) == Status::OK()); + auto mapped_data = static_cast(query_read_buffer.GetConstMappedRange()); + + for (size_t i = 0; i < pending_kernels.size(); i++) { + const PendingKernelInfo& pending_kernel_info = pending_kernels[i]; + const auto& inputs = pending_kernel_info.inputs; + const auto& outputs = pending_kernel_info.outputs; + + SS(shapes, 128); + for (size_t s = 0; s < inputs.size(); s++) { + const auto& input = inputs[s]; + shapes << "inputs[" << s << "] = " << input.override_shape.ToString() << " "; + } + for (size_t s = 0; s < outputs.size(); s++) { + const auto& output = outputs[s]; + shapes << "outputs[" << s << "] = " << output.override_shape.ToString() << " "; + } + + if (gpu_timestamp_offset_ == 0) { + gpu_timestamp_offset_ = mapped_data[i * 2]; + // TODO: apply CPU-GPU time offset so that timestamps are aligned + } + uint64_t start_time = mapped_data[i * 2] - gpu_timestamp_offset_; + uint64_t end_time = mapped_data[i * 2 + 1] - gpu_timestamp_offset_; + + const std::unordered_map& event_args = { + {"shapes", SS_GET(shapes)}, + {"cache_key", pending_kernel_info.cache_key}, + }; + + profiling::EventRecord event(profiling::API_EVENT, + -1, + -1, + pending_kernel_info.name, + static_cast(std::round(start_time / 1000.0)), + static_cast(std::round((end_time - start_time) / 1000.0)), + event_args); + events.emplace_back(std::move(event)); + } + + query_read_buffer.Unmap(); + query_read_buffer.Destroy(); + } + + pending_queries_.clear(); + } + + is_profiling_ = false; +} + +void WebGpuContext::EndProfiling(TimePoint /* tp */, profiling::Events& events, profiling::Events& cached_events) { + // This function is called when no active inference is ongoing. + ORT_ENFORCE(!is_profiling_, "Profiling is ongoing in an inference run."); + + if (query_type_ != TimestampQueryType::None) { + // No pending kernels or queries should be present at this point. They should have been collected in CollectProfilingData. + ORT_ENFORCE(pending_kernels_.empty() && pending_queries_.empty(), "Pending kernels or queries are not empty."); + + events.insert(events.end(), + std::make_move_iterator(cached_events.begin()), + std::make_move_iterator(cached_events.end())); + + cached_events.clear(); + } else { + LOGS_DEFAULT(WARNING) << "TimestampQuery is not supported in this device."; + } +} + +void WebGpuContext::Flush() { + if (!current_command_encoder_) { + return; + } + + EndComputePass(); + + if (is_profiling_ && num_pending_dispatches_ > 0) { + uint32_t query_count = num_pending_dispatches_ * 2; + current_command_encoder_.ResolveQuerySet( + query_set_, + 0, + query_count, + query_resolve_buffer_, + 0); + + wgpu::BufferDescriptor bufferDescriptor; + bufferDescriptor.size = query_count * sizeof(uint64_t); + bufferDescriptor.usage = wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst; + wgpu::Buffer query_read_buffer = device_.CreateBuffer(&bufferDescriptor); + + current_command_encoder_.CopyBufferToBuffer( + query_resolve_buffer_, + 0, + query_read_buffer, + 0, + query_count * sizeof(uint64_t)); + + pending_queries_.emplace_back(std::move(pending_kernels_), query_read_buffer); + pending_kernels_.clear(); + } + + auto command_buffer = current_command_encoder_.Finish(); + Device().GetQueue().Submit(1, &command_buffer); + BufferManager().RefreshPendingBuffers(); + current_command_encoder_ = nullptr; + num_pending_dispatches_ = 0; +} + +std::unordered_map> WebGpuContextFactory::contexts_; +std::mutex WebGpuContextFactory::mutex_; + +WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, + WGPUInstance instance, + WGPUAdapter adapter, + WGPUDevice device, + ValidationMode validation_mode) { + if (context_id == 0) { + // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. + ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr, + "WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device."); + } else { + // for context ID > 0, user must provide custom WebGPU instance, adapter and device. + ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr, + "WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device."); + } + + std::lock_guard lock(mutex_); + + auto it = contexts_.find(context_id); + if (it == contexts_.end()) { + auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device, validation_mode)); + it = contexts_.emplace(context_id, std::move(context)).first; + } else if (context_id != 0) { + ORT_ENFORCE(it->second->instance_.Get() == instance && it->second->adapter_.Get() == adapter && it->second->device_.Get() == device, + "WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device."); + } + return *it->second; +} + +WebGpuContext& WebGpuContextFactory::GetContext(int context_id) { + std::lock_guard lock(mutex_); + + auto it = contexts_.find(context_id); + ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found."); + + return *it->second; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h new file mode 100644 index 0000000000000..635204057b0f2 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -0,0 +1,185 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include + +#include "core/common/common.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/program_manager.h" + +namespace onnxruntime { +class Tensor; + +namespace webgpu { +class WebGpuContext; +class ComputeContext; +class ProgramBase; + +class WebGpuContextFactory { + public: + static WebGpuContext& CreateContext(int context_id, + WGPUInstance instance, + WGPUAdapter adapter, + WGPUDevice device, + ValidationMode validation_mode); + static WebGpuContext& GetContext(int context_id); + + private: + WebGpuContextFactory() {} + + static std::unordered_map> contexts_; + static std::mutex mutex_; +}; + +// Class WebGpuContext includes all necessary resources for the context. +class WebGpuContext final { + public: + void Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info, const void* dawn_proc_table); + + Status Wait(wgpu::Future f); + + const wgpu::Adapter& Adapter() const { return adapter_; } + const wgpu::Device& Device() const { return device_; } + + const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; } + const wgpu::Limits& DeviceLimits() const { return device_limits_; } + + const wgpu::CommandEncoder& GetCommandEncoder() { + if (!current_command_encoder_) { + current_command_encoder_ = device_.CreateCommandEncoder(); + } + return current_command_encoder_; + } + + const wgpu::ComputePassEncoder& GetComputePassEncoder() { + if (!current_compute_pass_encoder_) { + auto& command_encoder = GetCommandEncoder(); + + wgpu::ComputePassDescriptor compute_pass_desc{}; + + if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses) { + wgpu::ComputePassTimestampWrites timestampWrites = { + query_set_, num_pending_dispatches_ * 2, num_pending_dispatches_ * 2 + 1}; + compute_pass_desc.timestampWrites = ×tampWrites; + } + + current_compute_pass_encoder_ = command_encoder.BeginComputePass(&compute_pass_desc); + } + return current_compute_pass_encoder_; + } + + void EndComputePass() { + if (current_compute_pass_encoder_) { + current_compute_pass_encoder_.End(); + current_compute_pass_encoder_ = nullptr; + } + } + + void Flush(); + + webgpu::BufferManager& BufferManager() const { return *buffer_mgr_; } + + inline webgpu::ValidationMode ValidationMode() const { + return validation_mode_; + } + + void StartProfiling(); + void CollectProfilingData(profiling::Events& events); + void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events); + + Status Run(ComputeContext& context, const ProgramBase& program); + + private: + enum class TimestampQueryType { + None = 0, + InsidePasses, + AtPasses + }; + + WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode) + : query_type_{TimestampQueryType::None}, instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode} {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext); + + std::vector GetEnabledAdapterToggles() const; + std::vector GetEnabledDeviceToggles() const; + std::vector GetDisabledDeviceToggles() const; + std::vector GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const; + wgpu::RequiredLimits GetRequiredLimits(const wgpu::Adapter& adapter) const; + void WriteTimestamp(uint32_t query_index); + + TimestampQueryType query_type_; + wgpu::QuerySet query_set_; + wgpu::Buffer query_resolve_buffer_; + + struct PendingKernelInfo { + PendingKernelInfo(std::string_view kernel_name, + std::string_view program_name, + std::string_view cache_key, + const std::vector& inputs, + const std::vector& outputs) + : name{absl::StrJoin({kernel_name, program_name}, "_")}, cache_key{cache_key}, inputs{inputs}, outputs{outputs} {} + + PendingKernelInfo(PendingKernelInfo&&) = default; + PendingKernelInfo& operator=(PendingKernelInfo&&) = default; + ORT_DISALLOW_COPY_AND_ASSIGNMENT(PendingKernelInfo); + + std::string name; + std::string cache_key; + std::vector inputs; + std::vector outputs; + }; + + struct PendingQueryInfo { + PendingQueryInfo(std::vector&& kernels, wgpu::Buffer query_buffer) + : kernels{std::move(kernels)}, query_buffer{query_buffer} {} + + PendingQueryInfo(PendingQueryInfo&&) = default; + PendingQueryInfo& operator=(PendingQueryInfo&&) = default; + ORT_DISALLOW_COPY_AND_ASSIGNMENT(PendingQueryInfo); + + std::vector kernels; + wgpu::Buffer query_buffer; + }; + + // info of kernels pending submission for a single batch + std::vector pending_kernels_; + // info of queries pending appending to profiling events + std::vector pending_queries_; + + uint64_t gpu_timestamp_offset_ = 0; + bool is_profiling_ = false; + + std::once_flag init_flag_; + + wgpu::Instance instance_; + wgpu::Adapter adapter_; + wgpu::Device device_; + + webgpu::ValidationMode validation_mode_; + + wgpu::AdapterInfo adapter_info_; + wgpu::Limits device_limits_; + + wgpu::CommandEncoder current_command_encoder_; + wgpu::ComputePassEncoder current_compute_pass_encoder_; + + std::unique_ptr buffer_mgr_; + std::unique_ptr program_mgr_; + friend class WebGpuContextFactory; + + uint32_t num_pending_dispatches_ = 0; + const uint32_t max_num_pending_dispatches_ = 16; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 00ebdd5583d2e..821c60ab602ea 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -3,6 +3,9 @@ #include "core/providers/webgpu/webgpu_execution_provider.h" +#ifdef __EMSCRIPTEN__ +#include +#endif #include #include #include @@ -13,6 +16,7 @@ #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #endif +#include "allocator.h" #include "core/framework/compute_capability.h" #include "core/framework/data_transfer_manager.h" #include "core/framework/fallback_cpu_capability.h" @@ -20,6 +24,10 @@ #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/webgpu_profiler.h" + namespace onnxruntime { namespace webgpu { @@ -65,6 +73,329 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), Memcpy); +#define KERNEL_CREATE_INFO_VERSIONED(Start, End, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, End, Op)> + +#define KERNEL_CREATE_INFO(Start, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, Op)> + +#define KERNEL_CREATE_INFO_TYPED(Start, type, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, type, Op)> + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Abs); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Abs); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Neg); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Neg); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Floor); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Floor); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Ceil); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Ceil); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Reciprocal); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Reciprocal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Sqrt); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Sqrt); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Exp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Exp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Erf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Erf); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, HardSigmoid); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Log); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Log); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Sin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Cos); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Tan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Asin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Acos); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Atan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Sinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Cosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Asinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Acosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Atanh); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, Not); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 8, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Cast); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, float, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, float, Clip); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, Clip); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Clip); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, Elu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Relu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Relu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Relu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LeakyRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 20, Gelu); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMean); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMean); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceProd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceProd); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ReduceSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceL1); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceL1); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceL2); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceL2); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceSumSquare); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceSumSquare); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSumExp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSumExp); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Add); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Add); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Add); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Sub); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Sub); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Sub); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Mul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Mul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Mul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Div); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Div); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Div); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 11, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Pow); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 10, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Equal); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Greater); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Greater); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Greater); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 15, GreaterOrEqual); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Less); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Less); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Less); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 15, LessOrEqual); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LessOrEqual); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, 18, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Shape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Shape); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 5, 12, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 18, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Reshape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Reshape); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Squeeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Squeeze); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Unsqueeze); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 15, Where); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, Where); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Transpose); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, Conv); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, GlobalAveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, GlobalMaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gemm); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, MatMul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MatMul); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Softmax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Softmax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Softmax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 3, Concat); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 4, 10, Concat); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Concat); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Concat); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 12, Expand); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Expand); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 19, Resize); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Gather); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gather); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gather); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherElements); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 9, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Slice); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Slice); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 8, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Flatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tile); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tile); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 17, LayerNormalization); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, InstanceNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, float, Range); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, int32_t, Range); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, float, Einsum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Pad); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, If); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 13, CumSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, CumSum); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, int32_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -72,6 +403,320 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, // default entry to avoid the list becoming empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, + + // element-wise operators + // unary - math + KERNEL_CREATE_INFO_VERSIONED(6, 12, Abs), + KERNEL_CREATE_INFO(13, Abs), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Neg), + KERNEL_CREATE_INFO(13, Neg), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Floor), + KERNEL_CREATE_INFO(13, Floor), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Ceil), + KERNEL_CREATE_INFO(13, Ceil), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Reciprocal), + KERNEL_CREATE_INFO(13, Reciprocal), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sqrt), + KERNEL_CREATE_INFO(13, Sqrt), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Exp), + KERNEL_CREATE_INFO(13, Exp), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Erf), + KERNEL_CREATE_INFO(13, Erf), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), + KERNEL_CREATE_INFO(13, Sigmoid), + KERNEL_CREATE_INFO(6, HardSigmoid), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), + KERNEL_CREATE_INFO(13, Log), + + KERNEL_CREATE_INFO(7, Sin), + KERNEL_CREATE_INFO(7, Cos), + KERNEL_CREATE_INFO(7, Tan), + KERNEL_CREATE_INFO(7, Asin), + KERNEL_CREATE_INFO(7, Acos), + KERNEL_CREATE_INFO(7, Atan), + KERNEL_CREATE_INFO(9, Sinh), + KERNEL_CREATE_INFO(9, Cosh), + KERNEL_CREATE_INFO(9, Asinh), + KERNEL_CREATE_INFO(9, Acosh), + KERNEL_CREATE_INFO(9, Atanh), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Tanh), + KERNEL_CREATE_INFO(13, Tanh), + KERNEL_CREATE_INFO(1, Not), + + KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), + KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), + KERNEL_CREATE_INFO(19, Cast), + + // // activations + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + KERNEL_CREATE_INFO(6, Elu), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu), + KERNEL_CREATE_INFO(14, Relu), + KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu), + KERNEL_CREATE_INFO(16, LeakyRelu), + KERNEL_CREATE_INFO(10, ThresholdedRelu), + KERNEL_CREATE_INFO(20, Gelu), + + // // binary - math + KERNEL_CREATE_INFO_VERSIONED(7, 12, Add), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Add), + KERNEL_CREATE_INFO(14, Add), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Sub), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Sub), + KERNEL_CREATE_INFO(14, Sub), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Mul), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Mul), + KERNEL_CREATE_INFO(14, Mul), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Div), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Div), + KERNEL_CREATE_INFO(14, Div), + KERNEL_CREATE_INFO_VERSIONED(7, 11, Pow), + KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow), + KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow), + KERNEL_CREATE_INFO(15, Pow), + KERNEL_CREATE_INFO_VERSIONED(7, 10, Equal), + KERNEL_CREATE_INFO_VERSIONED(11, 12, Equal), + KERNEL_CREATE_INFO_VERSIONED(13, 18, Equal), + KERNEL_CREATE_INFO(19, Equal), + KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater), + KERNEL_CREATE_INFO(13, Greater), + KERNEL_CREATE_INFO_VERSIONED(12, 15, GreaterOrEqual), + KERNEL_CREATE_INFO(16, GreaterOrEqual), + KERNEL_CREATE_INFO_VERSIONED(7, 8, Less), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Less), + KERNEL_CREATE_INFO(13, Less), + KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual), + KERNEL_CREATE_INFO(16, LessOrEqual), + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), + KERNEL_CREATE_INFO(16, Where), + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -93,8 +738,76 @@ std::unique_ptr RegisterKernels() { using namespace webgpu; -WebGpuExecutionProvider::WebGpuExecutionProvider() - : IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)} {} +WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, + WebGpuContext& context, + WebGpuExecutionProviderInfo&& info) + : IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)}, + context_id_{context_id}, + context_{context}, + preferred_data_layout_{info.data_layout}, + force_cpu_node_names_{std::move(info.force_cpu_node_names)}, + enable_graph_capture_{info.enable_graph_capture} { +} + +std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { + AllocatorCreationInfo gpuBufferAllocatorCreationInfo([&](int) { + return std::make_unique(context_); + }, + 0, false); + return std::vector{CreateAllocator(gpuBufferAllocatorCreationInfo)}; +} + +std::vector> WebGpuExecutionProvider::GetCapability( + const onnxruntime::GraphViewer& graph, + const IKernelLookup& kernel_lookup) const { + InlinedVector candidates; + // `tenative_candidates` is a subset of `candidates`. + InlinedVector tenative_candidates; + for (auto& node_index : graph.GetNodesInTopologicalOrder()) { + const auto* p_node = graph.GetNode(node_index); + if (p_node == nullptr) + continue; + + const auto& node = *p_node; + if (!node.GetExecutionProviderType().empty()) { + // If the node was added by layout transformer, do not move it to CPU + if (node.GetExecutionProviderType() == kWebGpuExecutionProvider) { + candidates.push_back(node.Index()); + } + continue; + } + + const KernelCreateInfo* webgpu_kernel_def = kernel_lookup.LookUpKernel(node); + // none of the provided registries has a webgpu kernel for this node + if (webgpu_kernel_def == nullptr) { + LOGS(*GetLogger(), INFO) << "webgpu kernel not found in registries for Op type: " + << node.OpType() << " node name: " << node.Name(); + continue; + } + + // TODO: currently this lookup is O(N). If the list becomes large we should optimize this. + if (std::find(force_cpu_node_names_.cbegin(), + force_cpu_node_names_.cend(), + node.Name()) != force_cpu_node_names_.cend()) { + LOGS(*GetLogger(), INFO) << "Force CPU execution for node: " << node.Name(); + continue; + } + candidates.push_back(node.Index()); + tenative_candidates.push_back(node.Index()); + } + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates); + std::vector> result; + for (auto& node_index : candidates) { + if (cpu_nodes.count(node_index) > 0) { + continue; + } + + auto sub_graph = std::make_unique(); + sub_graph->nodes.push_back(node_index); + result.emplace_back(std::make_unique(std::move(sub_graph))); + } + return result; +} std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() const { static std::shared_ptr registry = webgpu::RegisterKernels(); @@ -102,7 +815,68 @@ std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() con return registry; } +std::unique_ptr WebGpuExecutionProvider::GetDataTransfer() const { + return std::make_unique(context_); +} + WebGpuExecutionProvider::~WebGpuExecutionProvider() { } +std::unique_ptr WebGpuExecutionProvider::GetProfiler() { + auto profiler = std::make_unique(context_); + profiler_ = profiler.get(); + return profiler; +} + +Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { + if (profiler_->Enabled()) { + context_.StartProfiling(); + } + + if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { + ORT_NOT_IMPLEMENTED("graph capture not implemented"); + } + return Status::OK(); +} + +Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /*run_options*/) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { + if (IsGraphCaptureAllowed()) { + ORT_NOT_IMPLEMENTED("graph capture not implemented"); + // is_graph_captured_ = true; + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + + context_.Flush(); + + if (profiler_->Enabled()) { + context_.CollectProfilingData(profiler_->Events()); + } + + return Status::OK(); +} + +bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const { + return enable_graph_capture_; +} + +bool WebGpuExecutionProvider::IsGraphCaptured(int) const { + return is_graph_captured_; +} + +Status WebGpuExecutionProvider::ReplayGraph(int) { + ORT_ENFORCE(IsGraphCaptured(0)); + ORT_ENFORCE(false); + return Status::OK(); +} + +bool WebGpuExecutionProvider::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +} + +void WebGpuExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + ++regular_run_count_before_graph_capture_; +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 537ecb9301f67..336395a1dd0dd 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -9,6 +9,7 @@ #include "core/graph/constants.h" #include "core/providers/providers.h" +struct pthreadpool; namespace onnxruntime { namespace webgpu { @@ -16,22 +17,78 @@ namespace webgpu { template KernelCreateInfo BuildKernelCreateInfo(); +class WebGpuContext; +enum class BufferCacheMode; +class WebGpuProfiler; } // namespace webgpu +struct WebGpuExecutionProviderInfo { + WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture) + : data_layout{data_layout}, + enable_graph_capture{enable_graph_capture}, + storage_buffer_cache_mode{}, + uniform_buffer_cache_mode{}, + query_resolve_buffer_cache_mode{}, + default_buffer_cache_mode{} {} + WebGpuExecutionProviderInfo(WebGpuExecutionProviderInfo&&) = default; + WebGpuExecutionProviderInfo& operator=(WebGpuExecutionProviderInfo&&) = default; + ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderInfo); + + DataLayout data_layout; + bool enable_graph_capture; + webgpu::BufferCacheMode storage_buffer_cache_mode; + webgpu::BufferCacheMode uniform_buffer_cache_mode; + webgpu::BufferCacheMode query_resolve_buffer_cache_mode; + webgpu::BufferCacheMode default_buffer_cache_mode; + std::vector force_cpu_node_names; +}; + class WebGpuExecutionProvider : public IExecutionProvider { public: - WebGpuExecutionProvider(); + WebGpuExecutionProvider(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderInfo&& info); ~WebGpuExecutionProvider() override; + std::vector> GetCapability( + const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_lookup*/) const override; + std::shared_ptr GetKernelRegistry() const override; + std::unique_ptr GetDataTransfer() const override; - DataLayout GetPreferredLayout() const override { return DataLayout::NHWC; } + DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } // WebGPU EP disallow concurrent run because actual implementation (eg. WebGPU backend) relies on global states to // work, and concurrent run with async function may mess up the states and cause undefined behavior. bool ConcurrentRunSupported() const override { return false; } + + std::vector CreatePreferredAllocators() override; + + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + + // WebGPU EP reuses the Device ID as the key to get the WebGpuContext instance. + int GetDeviceId() const override { return context_id_; } + + std::unique_ptr GetProfiler() override; + + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured(int graph_annotation_id) const override; + Status ReplayGraph(int graph_annotation_id) override; + + private: + bool IsGraphCaptureAllowed() const; + void IncrementRegularRunCountBeforeGraphCapture(); + int context_id_; + webgpu::WebGpuContext& context_; + webgpu::WebGpuProfiler* profiler_ = nullptr; + DataLayout preferred_data_layout_; + std::vector force_cpu_node_names_; + bool enable_graph_capture_ = false; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h new file mode 100644 index 0000000000000..d7682e751d9e4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/compute_context.h" + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +// ----------------------------------------------------------------------- +// Base class for WebGPU kernels +// ----------------------------------------------------------------------- +class WebGpuKernel : public OpKernel { + public: + explicit WebGpuKernel(const OpKernelInfo& info) + : OpKernel(info) { + } + + Status Compute(OpKernelContext* p_op_kernel_context) const override { + ComputeContext context{*p_op_kernel_context}; + + context.PushErrorScope(); + Status s = ComputeInternal(context); + ORT_RETURN_IF_ERROR(context.PopErrorScope()); + + return s; + } + + virtual Status ComputeInternal(ComputeContext& context) const = 0; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_profiler.cc b/onnxruntime/core/providers/webgpu/webgpu_profiler.cc new file mode 100644 index 0000000000000..ce973987e593a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_profiler.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/webgpu_profiler.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +WebGpuProfiler::WebGpuProfiler(WebGpuContext& context) : context_{context} {} + +bool WebGpuProfiler::StartProfiling(TimePoint) { + enabled_ = true; + return true; +} + +void WebGpuProfiler::EndProfiling(TimePoint tp, onnxruntime::profiling::Events& events) { + context_.EndProfiling(tp, events, events_); + enabled_ = false; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_profiler.h b/onnxruntime/core/providers/webgpu/webgpu_profiler.h new file mode 100644 index 0000000000000..d826d295a3842 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_profiler.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/profiler_common.h" + +namespace onnxruntime { + +namespace webgpu { +class WebGpuContext; + +class WebGpuProfiler final : public onnxruntime::profiling::EpProfiler { + public: + WebGpuProfiler(WebGpuContext& context); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuProfiler); + ~WebGpuProfiler() {} + bool StartProfiling(TimePoint) override; + void EndProfiling(TimePoint, onnxruntime::profiling::Events&) override; + void Start(uint64_t) override { + } + void Stop(uint64_t) override { + } + inline bool Enabled() const { return enabled_; } + inline onnxruntime::profiling::Events& Events() { return events_; } + + private: + WebGpuContext& context_; + bool enabled_{false}; + onnxruntime::profiling::Events events_; // cached GPU events +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 1a1f1a438c750..803c12274c08f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -4,21 +4,194 @@ #include #include "core/framework/error_code_helper.h" -#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/ort_apis.h" + +#include "core/providers/webgpu/webgpu_provider_options.h" +using namespace onnxruntime::webgpu::options; namespace onnxruntime { struct WebGpuProviderFactory : IExecutionProviderFactory { - WebGpuProviderFactory() {} + WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderInfo&& webgpu_ep_info) + : context_id_{context_id}, context_{context}, info_{std::move(webgpu_ep_info)} { + } std::unique_ptr CreateProvider() override { - return std::make_unique(); + return std::make_unique(context_id_, context_, std::move(info_)); } + + private: + int context_id_; + webgpu::WebGpuContext& context_; + WebGpuExecutionProviderInfo info_; }; -std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions&) { - return std::make_shared(); +std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { + // + // STEP.1 - prepare WebGpuExecutionProviderInfo + // + WebGpuExecutionProviderInfo webgpu_ep_info{ + // preferred layout is NHWC by default + DataLayout::NHWC, + // graph capture feature is disabled by default + false, + }; + + std::string preferred_layout_str; + if (config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { + if (preferred_layout_str == kPreferredLayout_NHWC) { + webgpu_ep_info.data_layout = DataLayout::NHWC; + } else if (preferred_layout_str == kPreferredLayout_NCHW) { + webgpu_ep_info.data_layout = DataLayout::NCHW; + } else { + ORT_THROW("Invalid preferred layout: ", preferred_layout_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP preferred layout: " << int(webgpu_ep_info.data_layout) << " (parsed from \"" + << preferred_layout_str << "\")"; + + std::string enable_graph_capture_str; + if (config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { + if (enable_graph_capture_str == kEnableGraphCapture_ON) { + webgpu_ep_info.enable_graph_capture = true; + } else if (enable_graph_capture_str == kEnableGraphCapture_OFF) { + webgpu_ep_info.enable_graph_capture = false; + } else { + ORT_THROW("Invalid enable graph capture: ", enable_graph_capture_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture; + + auto parse_buffer_cache_mode = [&config_options](const std::string& config_entry_str, + webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode { + std::string buffer_cache_mode_str; + if (config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { + if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { + return webgpu::BufferCacheMode::Disabled; + } else if (buffer_cache_mode_str == kBufferCacheMode_LazyRelease) { + return webgpu::BufferCacheMode::LazyRelease; + } else if (buffer_cache_mode_str == kBufferCacheMode_Simple) { + return webgpu::BufferCacheMode::Simple; + } else if (buffer_cache_mode_str == kBufferCacheMode_Bucket) { + return webgpu::BufferCacheMode::Bucket; + } else { + ORT_THROW("Invalid buffer cache mode: ", config_entry_str); + } + } else { + return default_value; + } + }; + + webgpu_ep_info.storage_buffer_cache_mode = parse_buffer_cache_mode(kStorageBufferCacheMode, webgpu::BufferCacheMode::Bucket); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << webgpu_ep_info.storage_buffer_cache_mode; + + webgpu_ep_info.uniform_buffer_cache_mode = parse_buffer_cache_mode(kUniformBufferCacheMode, webgpu::BufferCacheMode::Simple); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << webgpu_ep_info.uniform_buffer_cache_mode; + + webgpu_ep_info.query_resolve_buffer_cache_mode = parse_buffer_cache_mode(kQueryResolveBufferCacheMode, webgpu::BufferCacheMode::Disabled); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP query resolve buffer cache mode: " << webgpu_ep_info.query_resolve_buffer_cache_mode; + + webgpu_ep_info.default_buffer_cache_mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << webgpu_ep_info.default_buffer_cache_mode; + + webgpu::ValidationMode validation_mode = +#ifndef NDEBUG + webgpu::ValidationMode::Full // for debug build, enable full validation by default +#else + webgpu::ValidationMode::WGPUOnly // for release build, only enable WGPU validation. +#endif // !NDEBUG + ; + std::string validation_mode_str; + if (config_options.TryGetConfigEntry(kValidationMode, validation_mode_str)) { + if (validation_mode_str == kValidationMode_Disabled) { + validation_mode = webgpu::ValidationMode::Disabled; + } else if (validation_mode_str == kValidationMode_wgpuOnly) { + validation_mode = webgpu::ValidationMode::WGPUOnly; + } else if (validation_mode_str == kValidationMode_basic) { + validation_mode = webgpu::ValidationMode::Basic; + } else if (validation_mode_str == kValidationMode_full) { + validation_mode = webgpu::ValidationMode::Full; + } else { + ORT_THROW("Invalid validation mode: ", validation_mode_str); + } + } + + // parse force CPU node names + // The force CPU node names are separated by EOL (\n or \r\n) in the config entry. + // each line is a node name that will be forced to run on CPU. + std::string force_cpu_node_names_str; + if (config_options.TryGetConfigEntry(kForceCpuNodeNames, force_cpu_node_names_str)) { + std::vector force_cpu_node_names; + + // split the string by EOL (\n or \r\n) + std::istringstream ss(force_cpu_node_names_str); + std::string line; + while (std::getline(ss, line)) { + // skip empty lines + if (line.empty()) { + continue; + } + + force_cpu_node_names.push_back(line); + } + + webgpu_ep_info.force_cpu_node_names = std::move(force_cpu_node_names); + } + + // + // STEP.2 - prepare WebGpuContext + // + int context_id = 0; + std::string context_id_str; + if (config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { + ORT_ENFORCE(std::errc{} == + std::from_chars(context_id_str.data(), context_id_str.data() + context_id_str.size(), context_id).ec); + } + + size_t webgpu_instance = 0; + std::string webgpu_instance_str; + if (config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { + static_assert(sizeof(WGPUInstance) == sizeof(size_t), "WGPUInstance size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec); + } + + size_t webgpu_adapter = 0; + std::string webgpu_adapter_str; + if (config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) { + static_assert(sizeof(WGPUAdapter) == sizeof(size_t), "WGPUAdapter size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_adapter_str.data(), webgpu_adapter_str.data() + webgpu_adapter_str.size(), webgpu_adapter).ec); + } + + size_t webgpu_device = 0; + std::string webgpu_device_str; + if (config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) { + static_assert(sizeof(WGPUDevice) == sizeof(size_t), "WGPUDevice size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_device_str.data(), webgpu_device_str.data() + webgpu_device_str.size(), webgpu_device).ec); + } + + size_t dawn_proc_table = 0; + std::string dawn_proc_table_str; + if (config_options.TryGetConfigEntry(kDawnProcTable, dawn_proc_table_str)) { + ORT_ENFORCE(std::errc{} == + std::from_chars(dawn_proc_table_str.data(), dawn_proc_table_str.data() + dawn_proc_table_str.size(), dawn_proc_table).ec); + } + + auto& context = webgpu::WebGpuContextFactory::CreateContext(context_id, + reinterpret_cast(webgpu_instance), + reinterpret_cast(webgpu_adapter), + reinterpret_cast(webgpu_device), + validation_mode); + context.Initialize(webgpu_ep_info, reinterpret_cast(dawn_proc_table)); + + return std::make_shared(context_id, context, std::move(webgpu_ep_info)); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h index 6257a85d45760..e0030a3ec2a11 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h @@ -8,6 +8,8 @@ #include "core/framework/provider_options.h" #include "core/providers/providers.h" +#include "core/providers/webgpu/webgpu_provider_options.h" + namespace onnxruntime { struct ConfigOptions; diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h new file mode 100644 index 0000000000000..63befedffea84 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace webgpu { +namespace options { + +// The following are the options that can be set in the WebGPU provider options. + +constexpr const char* kPreferredLayout = "WebGPU:preferredLayout"; +constexpr const char* kEnableGraphCapture = "WebGPU:enableGraphCapture"; + +constexpr const char* kDawnProcTable = "WebGPU:dawnProcTable"; + +constexpr const char* kDeviceId = "WebGPU:deviceId"; +constexpr const char* kWebGpuInstance = "WebGPU:webgpuInstance"; +constexpr const char* kWebGpuAdapter = "WebGPU:webgpuAdapter"; +constexpr const char* kWebGpuDevice = "WebGPU:webgpuDevice"; + +constexpr const char* kStorageBufferCacheMode = "WebGPU:storageBufferCacheMode"; +constexpr const char* kUniformBufferCacheMode = "WebGPU:uniformBufferCacheMode"; +constexpr const char* kQueryResolveBufferCacheMode = "WebGPU:queryResolveBufferCacheMode"; +constexpr const char* kDefaultBufferCacheMode = "WebGPU:defaultBufferCacheMode"; + +constexpr const char* kValidationMode = "WebGPU:validationMode"; + +constexpr const char* kForceCpuNodeNames = "WebGPU:forceCpuNodeNames"; + +// The following are the possible values for the provider options. + +constexpr const char* kPreferredLayout_NCHW = "NCHW"; +constexpr const char* kPreferredLayout_NHWC = "NHWC"; + +constexpr const char* kEnableGraphCapture_ON = "1"; +constexpr const char* kEnableGraphCapture_OFF = "0"; + +constexpr const char* kBufferCacheMode_Disabled = "disabled"; +constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease"; +constexpr const char* kBufferCacheMode_Simple = "simple"; +constexpr const char* kBufferCacheMode_Bucket = "bucket"; + +constexpr const char* kValidationMode_Disabled = "disabled"; +constexpr const char* kValidationMode_wgpuOnly = "wgpuOnly"; +constexpr const char* kValidationMode_basic = "basic"; +constexpr const char* kValidationMode_full = "full"; + +} // namespace options +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h new file mode 100644 index 0000000000000..ff66cd535399e --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cpu/tensor/shape_op.h" + +namespace onnxruntime { +namespace webgpu { + +using SupportedNumberTypes = + TypeList< + float, + MLFloat16, + int32_t, + uint32_t>; + +using SupportedFloats = + TypeList< + float, + MLFloat16>; + +inline const std::vector& WebGpuSupportedNumberTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + +inline const std::vector& WebGpuSupportedFloatTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index f4f10dc4b4b97..d05fba192820a 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -26,6 +26,8 @@ def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice: return C.OrtDevice.cpu() elif device_type == "dml": return C.OrtDevice.dml() + elif device_type == "webgpu": + return C.OrtDevice.webgpu() elif device_type == "ort": return C.get_ort_device(device_index).device_type() else: diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 7af659851e4f8..b8c8293746533 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1580,6 +1580,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .def_static("fpga", []() { return OrtDevice::FPGA; }) .def_static("npu", []() { return OrtDevice::NPU; }) .def_static("dml", []() { return OrtDevice::GPU; }) + .def_static("webgpu", []() { return OrtDevice::GPU; }) .def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; }); py::class_ ort_arena_cfg_binding(m, "OrtArenaCfg"); diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 225931533615d..fa4916f8922f2 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -24,7 +24,7 @@ struct OrtStatus { char msg[1]; // a null-terminated string }; -#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_OPENVINO BACKEND_TVM BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN BACKEND_DML BACKEND_CANN +#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_OPENVINO BACKEND_TVM BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN BACKEND_DML BACKEND_CANN BACKEND_WEBGPU #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/providers.h" #include "core/providers/provider_factory_creators.h" @@ -111,6 +111,12 @@ struct OrtStatus { #define BACKEND_CANN "" #endif +#if USE_WEBGPU +#define BACKEND_WEBGPU "-WEBGPU" +#else +#define BACKEND_WEBGPU "" +#endif + #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cuda/cuda_execution_provider_info.h" diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc index 5cf749dc4c97c..a7d751f4472fc 100644 --- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc @@ -41,7 +41,7 @@ const std::vector GetExpectedResult(const std::vector& input_data, return ComputeGelu(add_bias_data); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) static void RunFastGeluGpuTest(const std::vector& input_data, const std::vector& bias_data, const std::vector& output_data, const std::vector& input_dims, const std::vector& bias_dims, const std::vector& output_dims, @@ -75,6 +75,8 @@ static void RunFastGeluGpuTest(const std::vector& input_data, const std:: execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); +#elif USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -142,7 +144,7 @@ static void RunFastGeluTest( std::vector input_dims = {batch_size, sequence_length, hidden_size}; std::vector bias_dims = {hidden_size}; std::vector output_dims = input_dims; -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); #endif RunFastGeluCpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); @@ -245,8 +247,8 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat32) { RunFastGeluTest(input_data, bias_data, batch_size, sequence_length, hidden_size); } -// CUDA and ROCm only for Float16 and BFloat16 type. -#if defined(USE_CUDA) || defined(USE_ROCM) +// CUDA, ROCm and WebGPU only for Float16 type. +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) TEST(FastGeluTest, FastGeluWithBiasFloat16_2) { int batch_size = 1; int sequence_length = 2; @@ -381,7 +383,10 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) { RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true); } +#endif +// CUDA and ROCm only for BFloat16 type. +#if defined(USE_CUDA) || defined(USE_ROCM) TEST(FastGeluTest, FastGeluWithBias_BFloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 9ecaa16a2ab24..f1e0e99a5fb79 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -120,7 +120,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16Input) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { @@ -134,7 +134,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) { @@ -192,7 +192,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16Input) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider, - kOpenVINOExecutionProvider, kNnapiExecutionProvider, kCoreMLExecutionProvider}); + kOpenVINOExecutionProvider, kNnapiExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { @@ -207,7 +207,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { @@ -222,7 +222,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput_Initializers) { diff --git a/onnxruntime/test/contrib_ops/layer_norm_test.cc b/onnxruntime/test/contrib_ops/layer_norm_test.cc index 438a1100ca95c..46082e1b0cd31 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_test.cc @@ -6,7 +6,7 @@ namespace onnxruntime { namespace test { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) constexpr auto k_epsilon_default = 1e-5f; constexpr auto k_random_data_min = -10.0f; constexpr auto k_random_data_max = 10.0f; @@ -65,8 +65,8 @@ static void TestLayerNorm(const std::vector& x_dims, std::vector Y_data = FillZeros(n_x_m_dims); test.AddOutput("output", n_x_m_dims, Y_data); -#ifndef USE_DML - // DML doesn't support more than one output for these ops yet +#if !defined(USE_DML) && !defined(USE_WEBGPU) + // DML and WebGPU don't support more than one output for these ops yet const std::vector& stats_dims = keep_dims ? n_and_ones_dims : n_dims; std::vector mean_data = FillZeros(stats_dims); std::vector var_data = FillZeros(stats_dims); @@ -84,6 +84,8 @@ static void TestLayerNorm(const std::vector& x_dims, test.CompareWithCPU(kRocmExecutionProvider); #elif USE_DML test.CompareWithCPU(kDmlExecutionProvider); +#elif USE_WEBGPU + test.CompareWithCPU(kWebGpuExecutionProvider); #endif } diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 8138829b057f2..eb6e316202da9 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -273,7 +273,11 @@ void TestMatMulNBitsTyped() { base_opts.output_abs_error = 0.1f; } else { if constexpr (std::is_same::value) { +#ifdef USE_WEBGPU + base_opts.output_abs_error = 0.03f; +#else base_opts.output_abs_error = 0.01f; +#endif } } @@ -288,7 +292,7 @@ void TestMatMulNBitsTyped() { RunTest(opts); } -#if !defined(USE_DML) +#if !defined(USE_DML) && !defined(USE_WEBGPU) { TestOptions opts = base_opts; opts.has_g_idx = true; @@ -319,7 +323,7 @@ void TestMatMulNBitsTyped() { opts.has_zero_point = true, opts.zp_is_4bit = false; RunTest(opts); } -#endif // !defined(USE_DML) +#endif // !defined(USE_DML) && !defined(USE_WEBGPU) } TEST(MatMulNBits, Float32_Accuracy0) { @@ -458,7 +462,7 @@ TEST(MatMulNBits, Float16_Accuracy4) { #endif #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) namespace { // Legacy test function. @@ -493,6 +497,9 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura #ifdef USE_DML execution_providers.push_back(DefaultDmlExecutionProvider()); #endif +#ifdef USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); +#endif RunTest(opts, std::move(execution_providers)); } else { @@ -537,6 +544,9 @@ TEST(MatMulNBits, Float16Large) { // absolute error of 0.08, but the A10 has errors going as high as 0.22. Ultimately, given the large number // of elements in this test, ULPs should probably be used instead of absolute/relative tolerances. float abs_error = 0.3f; +#elif USE_WEBGPU + // See Intel A770 to pass these tests with an absolute error of 0.08. + float abs_error = 0.08f; #else float abs_error = 0.05f; #endif diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 1d167b5dffdb5..6b6799d73fb56 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -49,6 +49,7 @@ static void RunMultiHeadAttentionTest( bool use_float16 = false, bool disable_cpu = false, // some cases not supported in cpu right now. bool disable_cuda = false, + bool disable_webgpu = false, bool disable_rocm = DISABLE_ROCM, // not supported in rocm right now. bool disable_dml = false) { kv_sequence_length = (kv_sequence_length == 0 ? sequence_length : kv_sequence_length); @@ -59,6 +60,7 @@ static void RunMultiHeadAttentionTest( bool enable_rocm = (nullptr != DefaultRocmExecutionProvider(/*test_tunable_op=*/true).get()) && !disable_rocm; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu; bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; + bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()) && !disable_webgpu; if (enable_rocm && !use_float16) { LOGS_DEFAULT(WARNING) << "ROCm MHA only have kernel for half datatype implemented, skip float datatype tests"; @@ -70,7 +72,7 @@ static void RunMultiHeadAttentionTest( enable_rocm = false; } - if (enable_cpu || enable_cuda || enable_rocm || enable_dml) { + if (enable_cpu || enable_cuda || enable_rocm || enable_dml || enable_webgpu) { OpTester tester("MultiHeadAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("mask_filter_value", static_cast(-10000.0f)); @@ -266,6 +268,12 @@ static void RunMultiHeadAttentionTest( execution_providers.push_back(DefaultDmlExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } + + if (enable_webgpu) { + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } } @@ -295,6 +303,7 @@ static void RunMultiHeadAttentionKernel( bool is_static_kv = true, bool disable_cpu = false, // some cases not supported in cpu right now. bool disable_cuda = false, + bool disable_webgpu = false, bool disable_rocm = DISABLE_ROCM, bool disable_dml = false) { if (kernel_type == AttentionKernelType::AttentionKernel_Default) { @@ -309,7 +318,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } @@ -325,7 +335,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } @@ -341,7 +352,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } @@ -358,7 +370,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } #endif @@ -376,7 +389,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); } if (kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { @@ -392,11 +406,30 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); } } -static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu = false, bool disable_cuda = false) { +enum RunMultiHeadAttentionTestToggles : uint32_t { + DISABLE_NONE = 0, + DISABLE_CPU = 1 << 0, + DISABLE_CUDA = 1 << 1, + DISABLE_WEBGPU = 1 << 2, +}; +inline RunMultiHeadAttentionTestToggles operator|(RunMultiHeadAttentionTestToggles a, RunMultiHeadAttentionTestToggles b) { + return static_cast(static_cast(a) | static_cast(b)); +} +inline RunMultiHeadAttentionTestToggles operator&(RunMultiHeadAttentionTestToggles a, RunMultiHeadAttentionTestToggles b) { + return static_cast(static_cast(a) & static_cast(b)); +} + +static void RunMultiHeadAttentionTests(AttentionTestData& data, + RunMultiHeadAttentionTestToggles toggles = DISABLE_NONE) { + bool disable_cpu = toggles & DISABLE_CPU; + bool disable_cuda = toggles & DISABLE_CUDA; + bool disable_webgpu = toggles & DISABLE_WEBGPU; + if (data.fp32_output_data.size() > 0) { constexpr bool use_float16 = false; @@ -407,7 +440,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -420,7 +453,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } } #endif @@ -431,7 +464,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } if (data.fp16_output_data.size() > 0) { @@ -443,7 +476,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention; @@ -453,7 +486,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -464,7 +497,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #endif @@ -475,7 +508,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } kernel_type = AttentionKernelType::AttentionKernel_Default; @@ -484,7 +517,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } } @@ -503,40 +536,40 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M ROCM_GTEST_SKIP("ROCm MHA does not support mask type of MASK_1D_KEY_SEQ_LEN"); AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, true); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) { AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, false); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) { AttentionTestData data; GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) { AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) { AttentionTestData data; GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } // This tests qk_head_size != v_head_size @@ -561,7 +594,7 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) { TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize8) { AttentionTestData data; GetCrossAttentionData_HeadSize8_NoBias(data); - RunMultiHeadAttentionTests(data, false, true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA); } // TODO (pavignol): Fix this regression @@ -571,7 +604,7 @@ TEST(MultiHeadAttentionTest, CrossAttentionWithPast) { ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetCrossAttentionDataWithPast(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } #endif @@ -579,27 +612,27 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithAttnBias_ForT5) { ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetSelfAttentionData_WithPast_WithAttnBias_ForT5(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU); } TEST(MultiHeadAttentionTest, AttentionCutlassAttnBias) { // ROCM_GTEST_SKIP("ROCm does not support cutlass"); AttentionTestData data; GetAttentionDataCutlassAttnBias(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) { // Whisper decoder cross attention without mask and different sequence lengths for Q and K/V AttentionTestData data; GetCrossAttentionData_DiffSequenceLengths(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); GetCrossAttentionData_DiffSequenceLengths_HeadSize8(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU); GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) { @@ -609,10 +642,10 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) RunMultiHeadAttentionTests(data); GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA); GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA); } // This test is disabled since it is not used in Whisper anymore, and it fails in ROCm. diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index 8675a997d29a1..7d5a70148747f 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -67,6 +67,7 @@ static void RunTest( : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; + bool enable_webgpu = nullptr != DefaultWebGpuExecutionProvider().get(); if (enable_cuda && !disable_cuda) { execution_providers.push_back(DefaultCudaExecutionProvider()); @@ -77,6 +78,9 @@ static void RunTest( if (tensor_type == TensorType::kFloat && !disable_cpu) { execution_providers.push_back(DefaultCpuExecutionProvider()); } + if (enable_webgpu) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } if (execution_providers.size() == 0) { // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline) return; diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index edf9064bb43c9..b9ca55073d411 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -62,6 +62,8 @@ static void RunOneTest( auto rocm_ep = DefaultRocmExecutionProvider(); auto dml_ep = DefaultDmlExecutionProvider(); auto cpu_ep = DefaultCpuExecutionProvider(); + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + std::vector> execution_providers; if (!use_float16) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); @@ -95,10 +97,14 @@ static void RunOneTest( if (cpu_ep != nullptr) { execution_providers.push_back(DefaultCpuExecutionProvider()); } + if (webgpu_ep != nullptr) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) || dml_ep != nullptr || - rocm_ep != nullptr) { + rocm_ep != nullptr || + webgpu_ep != nullptr) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); test.AddInput("input", input_dims, ToFloat16(input_data)); test.AddInput("skip", skip_dims, ToFloat16(skip_data)); @@ -132,7 +138,9 @@ static void RunOneTest( ToFloat16(sum_output_data)); } - if (dml_ep != nullptr) { + if (webgpu_ep != nullptr) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } else if (dml_ep != nullptr) { execution_providers.push_back(DefaultDmlExecutionProvider()); } else if (rocm_ep != nullptr) { execution_providers.push_back(DefaultRocmExecutionProvider()); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 61a8f7e23fe87..27c466756445c 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -693,6 +693,9 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { #endif #ifdef USE_ROCM ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider())); +#endif +#ifdef USE_WEBGPU + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultWebGpuExecutionProvider())); #endif ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); ASSERT_STATUS_OK(session_object.Initialize()); @@ -719,7 +722,7 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { ASSERT_TRUE(lines[size - 1].find("]") != string::npos); std::vector tags = {"pid", "dur", "ts", "ph", "X", "name", "args"}; - bool has_api_info = false; + [[maybe_unused]] bool has_api_info = false; for (size_t i = 1; i < size - 1; ++i) { for (auto& s : tags) { ASSERT_TRUE(lines[i].find(s) != string::npos); @@ -730,14 +733,16 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { #ifdef USE_ROCM has_api_info = has_api_info || lines[i].find("Api") != string::npos && lines[i].find("hipLaunch") != string::npos; +#endif +#ifdef USE_WEBGPU + has_api_info = has_api_info || lines[i].find("Api") != string::npos; #endif } } -#if defined(USE_ROCM) && defined(ENABLE_ROCM_PROFILING) +// Note that the apple device is a paravirtual device which may not support webgpu timestamp query. So skip the check on it. +#if (defined(USE_ROCM) && defined(ENABLE_ROCM_PROFILING)) || (defined(USE_WEBGPU) && !defined(__APPLE__)) ASSERT_TRUE(has_api_info); -#else - ASSERT_TRUE(has_api_info || true); #endif } diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index e69c87b2540e5..8f2e5282ede9a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -601,8 +601,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } else if (provider_name_ == onnxruntime::kWebGpuExecutionProvider) { #ifdef USE_WEBGPU - session_options.AppendExecutionProvider( - "WebGPU", {{"intra_op_num_threads", std::to_string(performance_test_config.run_config.intra_op_num_threads)}}); + session_options.AppendExecutionProvider("WebGPU", {}); #else ORT_THROW("WebGPU is not supported in this build\n"); #endif diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index b2e9034653746..d32e286ad933e 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -414,6 +414,28 @@ TEST(MathOpTest, Add_Broadcast_3x2_3x1) { #endif } +TEST(MathOpTest, Add_Broadcast_2x2x2_1x2x2) { + OpTester test("Add"); + + test.AddInput("A", {2, 2, 2}, + {101.0f, 102.0f, + 103.0f, 104.0f, + + 201.0f, 202.0f, + 203.0f, 204.0f}); + test.AddInput("B", {1, 2, 2}, + {010.0f, 020.0f, + 030.0f, 040.0f}); + test.AddOutput("C", {2, 2, 2}, + {111.0f, 122.0f, + 133.0f, 144.0f, + + 211.0f, 222.0f, + 233.0f, 244.0f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + TEST(MathOpTest, Add_Broadcast_2x1x4_1x3x1) { OpTester test("Add"); @@ -3181,7 +3203,14 @@ TEST(MathOpTest, Tan) { TEST(MathOpTest, Asin) { OpTester test("Asin"); - float abs_error = DefaultDmlExecutionProvider().get() != nullptr ? 0.0001f : -1.0f; + float abs_error = +#ifdef _WIN32 + // Set abs_error to 0.0001f for built-in function asin() in HLSL based EPs (DML and WebGPU) + DefaultDmlExecutionProvider().get() != nullptr || DefaultWebGpuExecutionProvider().get() != nullptr + ? 0.0001f + : +#endif + -1.0f; TrigFloatTest<::asinf>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}, abs_error); } diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc index b517b1a2837f0..5902fbe3ddd6f 100644 --- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc @@ -142,7 +142,7 @@ void RunTestWrapper() { RunTest({2, 1, 3}, {2, 2, 1}); RunTest({2, 1, 3}, {2, 2, 1}, true); -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) // _TileMemcpyKernelFromInput, vectorized 4 RunTest({256, 512}, {3, 1}); @@ -253,7 +253,7 @@ TEST(TensorOpTest, TileStringType) { RunTestWrapper(); } TEST(TensorOpTest, TileBoolType) { RunTestWrapperForBool(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) TEST(TensorOpTest, TileMLFloat16Type) { RunTestWrapper(); } #endif diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 9b1e87f6ec02e..8fc76da3495a8 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -152,6 +152,9 @@ def create_backend_test(test_name=None): if backend.supports_device("MIGRAPHX"): current_failing_tests += apply_filters(filters, "current_failing_tests_MIGRAPHX") + if backend.supports_device("WEBGPU"): + current_failing_tests += apply_filters(filters, "current_failing_tests_WEBGPU") + # Skip these tests for a "pure" DML onnxruntime python wheel. We keep these tests enabled for instances where both DML and CUDA # EPs are available (Windows GPU CI pipeline has this config) - these test will pass because CUDA has higher precedence than DML # and the nodes are assigned to only the CUDA EP (which supports these tests) diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 4b14d50127aa9..401e9f9f5c5b3 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -672,6 +672,30 @@ "^test_nonmaxsuppression_flipped_coordinates_cpu", "^test_nonmaxsuppression_center_point_box_format_cpu" ], + "current_failing_tests_WEBGPU": [ + "^test_layer_normalization_2d_axis0_cpu", + "^test_layer_normalization_2d_axis1_cpu", + "^test_layer_normalization_2d_axis_negative_1_cpu", + "^test_layer_normalization_2d_axis_negative_2_cpu", + "^test_layer_normalization_3d_axis0_epsilon_cpu", + "^test_layer_normalization_3d_axis1_epsilon_cpu", + "^test_layer_normalization_3d_axis2_epsilon_cpu", + "^test_layer_normalization_3d_axis_negative_1_epsilon_cpu", + "^test_layer_normalization_3d_axis_negative_2_epsilon_cpu", + "^test_layer_normalization_3d_axis_negative_3_epsilon_cpu", + "^test_layer_normalization_4d_axis0_cpu", + "^test_layer_normalization_4d_axis1_cpu", + "^test_layer_normalization_4d_axis2_cpu", + "^test_layer_normalization_4d_axis3_cpu", + "^test_layer_normalization_4d_axis_negative_1_cpu", + "^test_layer_normalization_4d_axis_negative_2_cpu", + "^test_layer_normalization_4d_axis_negative_3_cpu", + "^test_layer_normalization_4d_axis_negative_4_cpu", + "^test_layer_normalization_default_axis_cpu", + "^test_gelu_tanh_1_expanded_cpu", + "^test_gelu_tanh_2_expanded_cpu", + "^test_dynamicquantizelinear_expanded_cpu" + ], "current_failing_tests_pure_DML": [ "^test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_cpu", "^test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded_cpu", diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index d57a22f024d5f..57f748ab8b6bd 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -307,6 +307,10 @@ std::unique_ptr DefaultXnnpackExecutionProvider() { std::unique_ptr DefaultWebGpuExecutionProvider() { #ifdef USE_WEBGPU ConfigOptions config_options{}; + // Disable storage buffer cache + ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode, + webgpu::options::kBufferCacheMode_Disabled) + .IsOK()); return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); #else return nullptr; diff --git a/onnxruntime/test/webgpu/external_dawn/main.cc b/onnxruntime/test/webgpu/external_dawn/main.cc new file mode 100644 index 0000000000000..ed8d2eab94ce9 --- /dev/null +++ b/onnxruntime/test/webgpu/external_dawn/main.cc @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// Licensed under the MIT License. + +#include + +#include "core/session/onnxruntime_cxx_api.h" + +#include + +#include "dawn/native/DawnNative.h" + +#ifdef _WIN32 +int wmain(int argc, wchar_t* argv[]) { +#else +int main(int argc, char* argv[]) { +#endif + bool no_proc_table = argc > 0 && +#ifdef _WIN32 + wcscmp(L"--no_proc_table", argv[argc - 1]) == 0; +#else + strcmp("--no_proc_table", argv[argc - 1]) == 0; +#endif + + int retval = 0; + Ort::Env env{nullptr}; + try { + env = Ort::Env{ORT_LOGGING_LEVEL_WARNING, "Default"}; + + // model is https://github.com/onnx/onnx/blob/v1.15.0/onnx/backend/test/data/node/test_abs/model.onnx + constexpr uint8_t MODEL_DATA[] = {8, 7, 18, 12, 98, 97, 99, 107, 101, 110, + 100, 45, 116, 101, 115, 116, 58, 73, 10, 11, + 10, 1, 120, 18, 1, 121, 34, 3, 65, 98, + 115, 18, 8, 116, 101, 115, 116, 95, 97, 98, + 115, 90, 23, 10, 1, 120, 18, 18, 10, 16, + 8, 1, 18, 12, 10, 2, 8, 3, 10, 2, + 8, 4, 10, 2, 8, 5, 98, 23, 10, 1, + 121, 18, 18, 10, 16, 8, 1, 18, 12, 10, + 2, 8, 3, 10, 2, 8, 4, 10, 2, 8, + 5, 66, 4, 10, 0, 16, 13}; + + Ort::SessionOptions session_options; + session_options.DisableMemPattern(); + std::unordered_map provider_options; + if (!no_proc_table) { + provider_options["dawnProcTable"] = std::to_string(reinterpret_cast(&dawn::native::GetProcs())); + } + session_options.AppendExecutionProvider("WebGPU", provider_options); + Ort::Session session{env, MODEL_DATA, sizeof(MODEL_DATA), session_options}; + + if (no_proc_table) { + std::cerr << "DawnProcTable is not passing to ONNX Runtime, but no exception is thrown." << std::endl; + retval = -1; + } else { + // successfully initialized + std::cout << "Successfully initialized WebGPU EP." << std::endl; + retval = 0; + } + } catch (const std::exception& ex) { + std::cerr << ex.what() << std::endl; + + if (no_proc_table) { + std::cout << "DawnProcTable is not passing to ONNX Runtime, so an exception is thrown as expected." << std::endl; + retval = 0; + } else { + std::cerr << "Unexpected exception." << std::endl; + retval = -1; + } + } + + ::google::protobuf::ShutdownProtobufLibrary(); + return retval; +} diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 9624f9112c49f..2b1d9ba205482 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -571,6 +571,7 @@ def convert_arg_line_to_args(self, arg_line): ) parser.add_argument("--use_jsep", action="store_true", help="Build with JavaScript kernels.") parser.add_argument("--use_webgpu", action="store_true", help="Build with WebGPU support.") + parser.add_argument("--use_external_dawn", action="store_true", help="Treat Dawn as an external dependency.") parser.add_argument("--use_qnn", action="store_true", help="Build with QNN support.") parser.add_argument("--qnn_home", help="Path to QNN SDK dir.") parser.add_argument("--use_rknpu", action="store_true", help="Build with RKNPU.") @@ -1058,6 +1059,7 @@ def generate_build_tree( "-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"), "-Donnxruntime_USE_JSEP=" + ("ON" if args.use_jsep else "OFF"), "-Donnxruntime_USE_WEBGPU=" + ("ON" if args.use_webgpu else "OFF"), + "-Donnxruntime_USE_EXTERNAL_DAWN=" + ("ON" if args.use_external_dawn else "OFF"), # Training related flags "-Donnxruntime_ENABLE_NVTX_PROFILE=" + ("ON" if args.enable_nvtx_profile else "OFF"), "-Donnxruntime_ENABLE_TRAINING=" + ("ON" if args.enable_training else "OFF"), @@ -1320,6 +1322,9 @@ def generate_build_tree( if args.use_jsep and args.use_webgpu: raise BuildError("JSEP (--use_jsep) and WebGPU (--use_webgpu) cannot be enabled at the same time.") + if args.use_external_dawn and not args.use_webgpu: + raise BuildError("External Dawn (--use_external_dawn) must be enabled with WebGPU (--use_webgpu).") + if args.use_snpe: cmake_args += ["-Donnxruntime_USE_SNPE=ON"] diff --git a/tools/ci_build/github/android/default_full_aar_build_settings.json b/tools/ci_build/github/android/default_full_aar_build_settings.json index b0eff75812673..f08f246748a5a 100644 --- a/tools/ci_build/github/android/default_full_aar_build_settings.json +++ b/tools/ci_build/github/android/default_full_aar_build_settings.json @@ -16,6 +16,7 @@ "--build_shared_lib", "--use_nnapi", "--use_xnnpack", + "--use_webgpu", "--skip_tests" ] } diff --git a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json index 84d7e355ed5b4..6175ac3a0ad58 100644 --- a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json +++ b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json @@ -19,6 +19,7 @@ "--build_apple_framework", "--use_coreml", "--use_xnnpack", + "--use_webgpu", "--skip_tests", "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF" ], diff --git a/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json b/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json index e2d8f70c02cf3..4c2c9442ab217 100644 --- a/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json +++ b/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json @@ -24,12 +24,14 @@ "--ios", "--use_xcode", "--use_xnnpack", + "--use_webgpu", "--apple_deploy_target=13.0" ], "iphonesimulator": [ "--ios", "--use_xcode", "--use_xnnpack", + "--use_webgpu", "--apple_deploy_target=13.0" ], "macabi":[ diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index eb48b44db5a1f..7b27707428670 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -83,7 +83,7 @@ jobs: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --use_coreml --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" BuildJava: false BuildNodejs: false WithCache: ${{ parameters.WithCache }} @@ -95,7 +95,7 @@ jobs: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 BuildJava: true BuildNodejs: true WithCache: ${{ parameters.WithCache }} @@ -107,7 +107,7 @@ jobs: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu BuildJava: true BuildNodejs: true WithCache: ${{ parameters.WithCache }} diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml index c4db7735aaf2f..06f374afca57a 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml @@ -41,10 +41,11 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat + EnvSetupScript: setup_env.bat buildArch: x64 - # add --enable_pybind and --build_java if necessary + # add --build_java if necessary additionalBuildFlags: >- + --enable_pybind --build_nodejs --use_webgpu --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON @@ -56,3 +57,52 @@ stages: EnablePython: false WITH_CACHE: true MachinePool: onnxruntime-Win2022-VS2022-webgpu-A10 + +- stage: webgpu_external_dawn + dependsOn: [] + jobs: + - job: build_x64_RelWithDebInfo + variables: + DEPS_CACHE_DIR: $(Agent.TempDirectory)/deps_ccache + ORT_CACHE_DIR: $(Agent.TempDirectory)/ort_ccache + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + workspace: + clean: all + pool: onnxruntime-Win2022-VS2022-webgpu-A10 + timeoutInMinutes: 300 + steps: + - checkout: self + clean: true + submodules: none + + - template: templates/jobs/win-ci-prebuild-steps.yml + parameters: + EnvSetupScript: setup_env.bat + DownloadCUDA: false + DownloadTRT: false + BuildArch: x64 + BuildConfig: RelWithDebInfo + MachinePool: onnxruntime-Win2022-VS2022-webgpu-A10 + WithCache: true + Today: $(Today) + + - template: templates/jobs/win-ci-build-steps.yml + parameters: + WithCache: true + Today: $(TODAY) + CacheDir: $(ORT_CACHE_DIR) + AdditionalKey: " $(System.StageName) | RelWithDebInfo " + BuildPyArguments: '--config RelWithDebInfo --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --update --parallel --cmake_generator "Visual Studio 17 2022" --use_webgpu --use_external_dawn --skip_tests --target onnxruntime_webgpu_external_dawn_test' + MsbuildArguments: '-maxcpucount' + BuildArch: x64 + Platform: x64 + BuildConfig: RelWithDebInfo + + - script: | + onnxruntime_webgpu_external_dawn_test.exe + displayName: Run tests (onnxruntime_webgpu_external_dawn_test) + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + - script: | + onnxruntime_webgpu_external_dawn_test.exe --no_proc_table + displayName: Run tests (onnxruntime_webgpu_external_dawn_test) + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo'