From 4037bd4c8ee908bcb5f896744963869ef1b6b2ad Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 14:53:36 -0700 Subject: [PATCH 001/114] [WIP] WebGPU EP initial commit --- cmake/CMakeLists.txt | 6 + cmake/deps.txt | 1 + .../external/onnxruntime_external_deps.cmake | 11 + cmake/onnxruntime_providers.cmake | 7 + cmake/onnxruntime_providers_cpu.cmake | 7 +- cmake/onnxruntime_providers_webgpu.cmake | 30 + cmake/onnxruntime_providers_webnn.cmake | 2 +- cmake/onnxruntime_unittests.cmake | 22 + include/onnxruntime/core/common/string_join.h | 61 ++ include/onnxruntime/core/graph/constants.h | 1 + .../core/session/onnxruntime_c_api.h | 57 ++ .../core/session/onnxruntime_cxx_api.h | 3 + .../core/session/onnxruntime_cxx_inline.h | 19 + .../webgpu/webgpu_contrib_kernels.cc | 70 ++ .../webgpu/webgpu_contrib_kernels.h | 17 + .../core/providers/get_execution_providers.cc | 8 + .../providers/provider_factory_creators.h | 4 + .../webgpu/How_to_Write_WebGPU_EP_Kernel.md | 156 ++++ onnxruntime/core/providers/webgpu/README.md | 104 +++ .../core/providers/webgpu/allocator.cc | 38 + onnxruntime/core/providers/webgpu/allocator.h | 34 + .../core/providers/webgpu/buffer_manager.cc | 362 ++++++++ .../core/providers/webgpu/buffer_manager.h | 96 ++ .../core/providers/webgpu/compute_context.cc | 37 + .../core/providers/webgpu/compute_context.h | 97 ++ .../core/providers/webgpu/data_transfer.cc | 48 + .../core/providers/webgpu/data_transfer.h | 28 + .../webgpu/math/unary_elementwise_ops.cc | 68 ++ .../webgpu/math/unary_elementwise_ops.h | 28 + onnxruntime/core/providers/webgpu/program.cc | 196 ++++ onnxruntime/core/providers/webgpu/program.h | 491 ++++++++++ .../providers/webgpu/program_cache_key.cc | 90 ++ .../core/providers/webgpu/program_cache_key.h | 16 + .../core/providers/webgpu/program_manager.cc | 188 ++++ .../core/providers/webgpu/program_manager.h | 71 ++ .../core/providers/webgpu/shader_helper.cc | 204 +++++ .../core/providers/webgpu/shader_helper.h | 161 ++++ .../core/providers/webgpu/shader_macros.h | 66 ++ .../core/providers/webgpu/shader_variable.cc | 277 ++++++ .../core/providers/webgpu/shader_variable.h | 263 ++++++ .../core/providers/webgpu/webgpu_context.cc | 349 ++++++++ .../core/providers/webgpu/webgpu_context.h | 124 +++ .../webgpu/webgpu_execution_provider.cc | 837 ++++++++++++++++++ .../webgpu/webgpu_execution_provider.h | 77 ++ .../core/providers/webgpu/webgpu_kernel.h | 42 + .../webgpu/webgpu_provider_factory.cc | 144 +++ .../webgpu/webgpu_provider_factory_creator.h | 18 + .../webgpu/webgpu_provider_options.h | 40 + .../providers/webgpu/webgpu_supported_types.h | 34 + onnxruntime/core/session/inference_session.cc | 8 +- onnxruntime/core/session/onnxruntime_c_api.cc | 2 + onnxruntime/core/session/ort_apis.h | 7 + .../core/session/provider_registration.cc | 60 ++ onnxruntime/test/onnx/main.cc | 62 +- tools/ci_build/build.py | 5 + 55 files changed, 5246 insertions(+), 8 deletions(-) create mode 100644 cmake/onnxruntime_providers_webgpu.cmake create mode 100644 include/onnxruntime/core/common/string_join.h create mode 100644 onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc create mode 100644 onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h create mode 100644 onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md create mode 100644 onnxruntime/core/providers/webgpu/README.md create mode 100644 onnxruntime/core/providers/webgpu/allocator.cc create mode 100644 onnxruntime/core/providers/webgpu/allocator.h create mode 100644 onnxruntime/core/providers/webgpu/buffer_manager.cc create mode 100644 onnxruntime/core/providers/webgpu/buffer_manager.h create mode 100644 onnxruntime/core/providers/webgpu/compute_context.cc create mode 100644 onnxruntime/core/providers/webgpu/compute_context.h create mode 100644 onnxruntime/core/providers/webgpu/data_transfer.cc create mode 100644 onnxruntime/core/providers/webgpu/data_transfer.h create mode 100644 onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc create mode 100644 onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h create mode 100644 onnxruntime/core/providers/webgpu/program.cc create mode 100644 onnxruntime/core/providers/webgpu/program.h create mode 100644 onnxruntime/core/providers/webgpu/program_cache_key.cc create mode 100644 onnxruntime/core/providers/webgpu/program_cache_key.h create mode 100644 onnxruntime/core/providers/webgpu/program_manager.cc create mode 100644 onnxruntime/core/providers/webgpu/program_manager.h create mode 100644 onnxruntime/core/providers/webgpu/shader_helper.cc create mode 100644 onnxruntime/core/providers/webgpu/shader_helper.h create mode 100644 onnxruntime/core/providers/webgpu/shader_macros.h create mode 100644 onnxruntime/core/providers/webgpu/shader_variable.cc create mode 100644 onnxruntime/core/providers/webgpu/shader_variable.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_context.cc create mode 100644 onnxruntime/core/providers/webgpu/webgpu_context.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc create mode 100644 onnxruntime/core/providers/webgpu/webgpu_execution_provider.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_kernel.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc create mode 100644 onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_provider_options.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_supported_types.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 2e9a50e52217..db0f1ac6ba08 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -149,6 +149,7 @@ option(onnxruntime_TVM_USE_LLVM "Build TVM with LLVM. Set customized path to llv option(onnxruntime_TVM_USE_HASH "Build ipp-crypto library for support hash algorithm. It is defined for TVM only") option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF) option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF) +option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C++ interface." OFF) # Options related to reducing the binary size produced by the build # XNNPACK EP requires the internal NHWC contrib ops to be available, so this option must be OFF when onnxruntime_USE_XNNPACK is ON @@ -907,6 +908,11 @@ if (onnxruntime_USE_WEBNN) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBNN=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES webnn) endif() +if (onnxruntime_USE_WEBGPU) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBGPU=1) + list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu) +endif() if (onnxruntime_USE_CANN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CANN=1) diff --git a/cmake/deps.txt b/cmake/deps.txt index 2487ea144227..2ab00cdbeb30 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -59,3 +59,4 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d839 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029 +dawn;https://github.com/google/dawn/archive/9a912d8162d5a837950de14f8849230212e3f51c.zip;7f2cad3db905e2d846d8f2422623850a4463915f diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 4e5270747405..2dad3479c3c0 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -585,6 +585,17 @@ if (onnxruntime_USE_COREML) FetchContent_Populate(coremltools) endif() +if (onnxruntime_USE_WEBGPU) + FetchContent_Declare( + dawn + URL ${DEP_URL_dawn} + URL_HASH SHA1=${DEP_SHA1_dawn} + ) + set(DAWN_FETCH_DEPENDENCIES ON) + set(DAWN_ENABLE_INSTALL ON) + onnxruntime_fetchcontent_makeavailable(dawn) +endif() + message("Finished fetching external dependencies") set(onnxruntime_LINK_DIRS ) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 05a50a55db40..9666877cdc20 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -110,6 +110,9 @@ endif() if(onnxruntime_USE_WEBNN) set(PROVIDERS_WEBNN onnxruntime_providers_webnn) endif() +if(onnxruntime_USE_WEBGPU) + set(PROVIDERS_WEBGPU onnxruntime_providers_webgpu) +endif() if (onnxruntime_USE_CANN) set(PROVIDERS_CANN onnxruntime_providers_cann) endif() @@ -151,6 +154,10 @@ if (onnxruntime_USE_WEBNN) include(onnxruntime_providers_webnn.cmake) endif() +if (onnxruntime_USE_WEBGPU) + include(onnxruntime_providers_webgpu.cmake) +endif() + if (onnxruntime_USE_NNAPI_BUILTIN) include(onnxruntime_providers_nnapi.cmake) endif() diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index bbcc709b144a..219fb9753635 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -40,6 +40,11 @@ file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc" ) +file(GLOB_RECURSE onnxruntime_webgpu_contrib_ops_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.cc" +) + file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/*.h" "${ONNXRUNTIME_ROOT}/core/providers/*.cc" @@ -60,7 +65,7 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" ) endif() - set(onnxruntime_cpu_neural_speed_srcs + set(onnxruntime_cpu_neural_speed_srcs "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h" "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h" "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc" diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake new file mode 100644 index 000000000000..303ab9483c38 --- /dev/null +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "WebGPU EP can not be used in a basic minimal build. Please build with '--minimal_build extended'") + endif() + + # find_package(Dawn REQUIRED) + + add_compile_definitions(USE_WEBGPU=1) + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + add_definitions(-DENABLE_WEBASSEMBLY_THREADS=1) + endif() + file(GLOB_RECURSE onnxruntime_providers_webgpu_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.cc" + # "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + # "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" + ) + if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_webgpu_contrib_ops_cc_srcs}) + list(APPEND onnxruntime_providers_webgpu_cc_srcs ${onnxruntime_webgpu_contrib_ops_cc_srcs}) + endif() + + source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webgpu_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) + target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn) + + set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_providers_webnn.cmake b/cmake/onnxruntime_providers_webnn.cmake index 05c63c22244d..39ca476810f4 100644 --- a/cmake/onnxruntime_providers_webnn.cmake +++ b/cmake/onnxruntime_providers_webnn.cmake @@ -22,4 +22,4 @@ add_dependencies(onnxruntime_providers_webnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_webnn PROPERTIES FOLDER "ONNXRuntime") - set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX) \ No newline at end of file + set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index d7f4a0675e11..5434ead12f65 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -557,6 +557,10 @@ if(onnxruntime_USE_JSEP) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_js) endif() +if(onnxruntime_USE_WEBGPU) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_webgpu) +endif() + if(onnxruntime_USE_RKNPU) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_rknpu) endif() @@ -598,6 +602,7 @@ set(ONNXRUNTIME_TEST_LIBS ${PROVIDERS_NNAPI} ${PROVIDERS_VSINPU} ${PROVIDERS_JS} + ${PROVIDERS_WEBGPU} ${PROVIDERS_QNN} ${PROVIDERS_SNPE} ${PROVIDERS_RKNPU} @@ -658,6 +663,13 @@ if(onnxruntime_USE_JSEP) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_js) endif() +if(onnxruntime_USE_WEBGPU) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/webgpu/*) + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_webgpu) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_webgpu) + list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_webgpu) +endif() + # QNN EP tests require CPU EP op implementations for accuracy evaluation, so disable on minimal # or reduced op builds. if(onnxruntime_USE_QNN AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) @@ -1088,6 +1100,11 @@ if (NOT IOS) endif() set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest") + add_custom_command(TARGET onnx_test_runner POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy $ $ + COMMAND_EXPAND_LISTS + ) + if (onnxruntime_USE_TVM) if (WIN32) target_link_options(onnx_test_runner PRIVATE "/STACK:4000000") @@ -1218,6 +1235,11 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() endif() + add_custom_command(TARGET onnxruntime_perf_test POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy $ $ + COMMAND_EXPAND_LISTS + ) + if (onnxruntime_BUILD_SHARED_LIB) #It will dynamically link to onnxruntime. So please don't add onxruntime_graph/onxruntime_framework/... here. #onnxruntime_common is kind of ok because it is thin, tiny and totally stateless. diff --git a/include/onnxruntime/core/common/string_join.h b/include/onnxruntime/core/common/string_join.h new file mode 100644 index 000000000000..2c2181d4ad04 --- /dev/null +++ b/include/onnxruntime/core/common/string_join.h @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/make_string.h" + +namespace onnxruntime { + +namespace detail { + +template +inline void StringJoinImpl(const Separator& separator, std::ostringstream& ss) noexcept { +} + +template +inline void StringJoinImpl(const Separator& separator, std::ostringstream& ss, const T& t) noexcept { + ss << separator << t; +} + +template +inline void StringJoinImpl(const Separator& separator, std::ostringstream& ss, const T& t, const Args&... args) noexcept { + StringJoinImpl(separator, ss, t); + StringJoinImpl(separator, ss, args...); +} + +template +inline std::string StringJoinImpl(const Separator& separator, const Args&... args) noexcept { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + StringJoinImpl(separator, ss, args...); + return ss.str(); +} +} // namespace detail + +/** + * Makes a string by concatenating string representations of the arguments using the specified separator. + * Uses std::locale::classic() + */ +template +std::string StringJoin(const Separator& separator, const Args&... args) { + return detail::StringJoinImpl(separator, detail::if_char_array_make_ptr_t(args)...); +} + +// StringJoin versions for already-a-string types. + +template +inline std::string StringJoin(const Separator& /* separator */, const std::string& str) { + return str; +} + +template +inline std::string StringJoin(const Separator& /* separator */, const char* cstr) { + return cstr; +} + +} // namespace onnxruntime diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 39acb6b4f2aa..f072badd199b 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -50,6 +50,7 @@ constexpr const char* kSnpeExecutionProvider = "SNPEExecutionProvider"; constexpr const char* kTvmExecutionProvider = "TvmExecutionProvider"; constexpr const char* kXnnpackExecutionProvider = "XnnpackExecutionProvider"; constexpr const char* kWebNNExecutionProvider = "WebNNExecutionProvider"; +constexpr const char* kWebGpuExecutionProvider = "WebGpuExecutionProvider"; constexpr const char* kCannExecutionProvider = "CANNExecutionProvider"; constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; constexpr const char* kVSINPUExecutionProvider = "VSINPUExecutionProvider"; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4674db42fb1c..9e5d9339bffe 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -624,6 +624,32 @@ typedef struct OrtMIGraphXProviderOptions { bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false } OrtMIGraphXProviderOptions; +/** \brief WebGPU Execution Provider Options + * + * When a user wants to use WebGPU as the execution provider, there are 2 ways to specify the WebGPU device: + * + * 1. Use the default WebGPU device. The default WebGPU device is managed by WebGPU EP internally. The user doesn't + * need to provide any device information in this case. All the fields should be set to nullptr or 0. + * + * 2. Use a custom WebGPU device. The user should create their own handles of `WGPUInstance`, `WGPUAdapter`, and + * `WGPUDevice` and use arbitrary number in [1..65536) as the device id. The user should provide the handles + * and the device id in the options. + * + * When specifying an existing Device ID, the user should provide the handles of `WGPUInstance`, `WGPUAdapter`, and + * `WGPUDevice` in the options. The device id should be the same as the one used previously. + * + * It's user's responsibility to manage the lifecycle of the handles and ensure the handles are valid during the + * lifetime of the inference session. + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_WebGPU + */ +typedef struct OrtWebGPUProviderOptions { + int device_id; // WebGPU device id. + void* instance_handle; // WebGPU instance handle. + void* adapter_handle; // WebGPU adapter handle. + void* device_handle; // WebGPU device handle. +} OrtWebGPUProviderOptions; + /** \brief OpenVINO Provider Options * * \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO @@ -4667,6 +4693,37 @@ struct OrtApi { _In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, _In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, size_t num_external_initializer_files); + + /** \brief Append WebGPU execution provider to session options + * + * If WebGPU is not available, this function will return failure. + * + * \param[in] options + * \param[in] webgpu_options - specify the WebGPU provider options. + * \param[in] string_options_keys - keys to configure the string options + * \param[in] string_options_values - values to configure the string options + * \param[in] num_keys - number of keys passed in + * + * Supported keys are listed as below. All entries are optional. + * + * | Key | Possible Values | Default Value | + * | ------------------------------ | ---------------------------------------------- | -------------- | + * | "preferredLayout" | "NHWC" or "NCHW" | "NHWC" | + * | "enableGraphCapture" | "1" or "0" | "0" | + * | "storageBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "bucket" | + * | "uniformBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "lazyRelease" | + * | "queryResolveBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "disabled" | + * | "defaultBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "disabled" | + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_WebGPU, + _In_ OrtSessionOptions* options, _In_ const OrtWebGPUProviderOptions* webgpu_options, + _In_reads_(num_keys) const char* const* string_options_keys, + _In_reads_(num_keys) const char* const* string_options_values, + _In_ size_t num_keys); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 29a229f42716..cf30584e18a4 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -890,6 +890,9 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_WebGPU + SessionOptionsImpl& AppendExecutionProvider_WebGPU(const OrtWebGPUProviderOptions& webgpu_options, + const std::unordered_map& string_options = {}); /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK. SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, const std::unordered_map& provider_options = {}); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index d3a8cade4d28..e5c84395ad95 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -838,6 +838,25 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_MIG return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_WebGPU(const OrtWebGPUProviderOptions& webgpu_options, + const std::unordered_map& string_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_WebGPU(this->p_, &provider_options, keys.data(), values.data(), num_entries)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) { ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options)); diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc new file mode 100644 index 000000000000..91f51df588fc --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention); +// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); + +// template <> +// KernelCreateInfo BuildKernelCreateInfo() { +// KernelCreateInfo info; +// return info; +// } + +Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { + static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo + }; + + for (auto& function_table_entry : function_table) { + KernelCreateInfo info = function_table_entry(); + if (info.kernel_def != nullptr) { // filter disabled entries where type is void + ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info))); + } + } + return Status::OK(); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h new file mode 100644 index 000000000000..6cdf7382804f --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/kernel_registry.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc index 61c035bc29ed..d2a72c3a38b0 100644 --- a/onnxruntime/core/providers/get_execution_providers.cc +++ b/onnxruntime/core/providers/get_execution_providers.cc @@ -162,6 +162,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] = true, #else false, +#endif + }, + { + kWebGpuExecutionProvider, +#ifdef USE_WEBGPU + true, +#else + false, #endif }, { diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index 47d3f2f793d7..41e418d9eb97 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -94,6 +94,10 @@ #include "core/providers/webnn/webnn_provider_factory_creator.h" #endif +#if defined(USE_WEBGPU) +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#endif + #if defined(USE_CANN) #include "core/providers/cann/cann_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md new file mode 100644 index 000000000000..a5a71fd94bf4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -0,0 +1,156 @@ +# How to Write WebGPU EP Kernel + +This document describes how to write a WebGPU EP kernel for ONNX Runtime. + +The following document will assume the operator name is `Example`, and you will see class `ExampleProgram` and `ExampleOpKernel` in the examples. Replace `Example` with the actual operator name you are implementing. + +Follow the following steps to create a WebGPU kernel: + +## 1. Decide _filename_ and _cateogory_, and create a new file at: + +`onnxruntime/core/providers/webgpu/{category}/{filename}.cc` + +- filename is usually a snake_case_name of the operator name, or a descriptive name if it includes multiple operators (eg. binary_elementwise_ops.cc) +- category is the subfolder representing the operator category (eg. math/nn/controlflow) + + see folder structure under onnxruntime/core/providers/cpu/ or onnxruntime/core/providers/cuda/ for examples + +## 2. Declare a new Program class + +### 2.1. The Program class should inherit from Program: + +```c++ +class ExampleProgram : public Program { +// ... +} +``` + +### 2.2. The Program class can define the following information: + +There are 3 types of definitions described as below. All of them are optional. If not specified, it is treated as empty. Those definitions are defined as static const members to ensure they don't depend on any runtime information. + +#### **constants** + +constants are declaration of values that are never changes in the shader code. They are inserted into the WGSL source code like this: + +```wgsl +const A : u32 = 64; +``` + +Use macro `WEBGPU_PROGRAM_DEFINE_CONSTANTS` to define constants in your Program class. + +#### **overridable constants** + +overridable constants are similar to constants, but they can be overridden before the compute pipeline is created. Overridable constants may or may not have a default value. They are inserted into the WGSL source code like this: + +```wgsl +override B : u32 = 64; +override C : f32; +``` + +Use macro `WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS` to define overridable constants in your Program class. + +#### **uniform definitions** + +uniform definitions are declaration of uniform varables. Their names and type must be defined and cannot be changed. Their values(including length) can be set at runtime. + +Use macro `WEBGPU_PROGRAM_DEFINE_UNIFORMS` to define uniform definitions in your Program class. + +### 2.3. The Program class should override the `GenerateShaderCode` method: + +```c++ +Status GenerateShaderCode(ShaderHelper& sh) const override; +``` + +In the function implementation, `sh` is an instance of `ShaderHelper` which provides a set of helper functions to generate shader code. + +Example: + +```c++ +Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddVariable(ProgramVariableScope::Input, + "x", + ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4), + 1); + const auto& output = shader.AddVariable(ProgramVariableScope::Output, + "y", + ToProgramVariableDataType(Outputs()[0]->GetElementType(), 4), + 1); + shader.AppendImplementation(additional_impl_); + shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + "let a = ", input.GetByOffset("global_idx"), ";\n", + output.SetByOffset("global_idx", expression_)); + + return Status::OK(); +} +``` + +`ShaderHelper::AddVariable` creates an instace of `ShaderVariable`. The class `ShaderVariable` is similar to `IndicesHelper` in onnxruntime-web. It provides a set of helper functions as value/indices/offset getter/setter. + +`ShaderHelper::AppendImplementation` inserts additional implementation code into the shader code. It will be put before the main function. + +`ShaderHelper::MainFunctionBody` generates the main function body. It accepts arbitrary number of arguments and concatenates them into the main function body. + +### 2.3. Lifecycle of the Program class + +For each calls into the `ExampleOpKernel::ComputeInternal()` method, a new instance of the `ExampleProgram` class should be created as local variable (The detail will be explained in `ExampleOpKernel` as below). The Program instance is destroyed when reaching the end of scope. + +A few functions can be called on the Program instance: + +- call `ProgramBase::Inputs` and `ProgramBase::Outputs` to set input/output tensor info. +- call `ProgramBase::CacheHint` to set the cache hint. +- call `ProgramBase::UniformsVariables`(optional) and `ProgramBase::OverridableConstants`(optional) to set runtime info of uniforms and overridable constants. They need to match the corresponding definitions described above. +- call `ProgramBase::DispatchGroupSize` and `ProgramBase::WorkgroupSize`(optional) to set the dispatch group size and workgroup size. + +## 3. Declare a new OpKernel class + +### 3.1. The OpKernel class should inherit from WebGpuKernel: + +```c++ +class ExampleOpKernel : public WebGpuKernel { +// ... +} +``` + +### 3.2. The OpKernel class should override the `ComputeInternal` method: + +```c++ +Status ComputeInternal(ComputeContext& context) const override; +``` + +Usually, in the implementation, we do 3 things: +- Create a local variable of the Program class. +- Set a few runtime info of the Program instance. +- Call `context.RunProgram(program)` to run the program and return the status. + +Complicated operators may do more things. Check header files and existing implementations for more details. + +## 4. Register the operator + +Register the operator just like any EP does. Check existing implementations for more details. + +Please note that registration is composed of 2 parts: +- Use macros like `ONNX_OPERATOR_KERNEL_EX` or `ONNX_OPERATOR_VERSIONED_KERNEL_EX` (or wrap a new macro as what we usually do) to register the operator in kernel source code file. +- Add the operator to onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc + +## 5. Write tests + +This section is WIP. + +## 6. Build and test + +use `build.bat --use_webgpu` to build the WebGPU EP. For Release build, append `--config Release` or `--config RelWithDebInfo` to the command line. + +to test, find the "onnx_test_runner.exe" in your build folder. run it like: +``` +onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" --model_path=C:\code\onnxruntime\js\test\data\node\opset17\test_abs +``` + +> Assume C:\code\onnxruntime is the root of your onnxruntime repo +> +> if it does not exist, run the following in your onnxruntime repo root: +> ``` +> cd js +> npm ci +> npm run prepare-node-tests +> ``` diff --git a/onnxruntime/core/providers/webgpu/README.md b/onnxruntime/core/providers/webgpu/README.md new file mode 100644 index 000000000000..d9c4313c8bf3 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/README.md @@ -0,0 +1,104 @@ +# WebGPU Execution Provider + +This folder is for the WebGPU execution provider(WebGPU EP). Currently, WebGPU EP is working in progress. + +## Build WebGPU EP + +Just append `--use_webgpu` to the `build.bat` command line. + +Currently only works on Windows. + +## Troubleshooting + +TODO: add solutions to common problems. + +## Development Guide + +See [How to write WebGPU EP kernel](./How_to_Write_WebGPU_EP_Kernel.md) for more information. + +## Convention + +### Use "webgpu" other than "wgpu" in this folder + +This is referring to the naming convention of variables, classes and namespace. + +ORT C API is using "wgpu". + +Let's keep it "webgpu" for this folder for now. I have a very good reason to do so: + +- search for "webgpu" in the code base shows the WebGPU EP related code and search for "wgpu" shows the WebGPU API related code. This helps me easier to find the code I want to look at. + +And anyway, it's not hard to change it back to "wgpu" if we want to. (but it's harder to change it from "wgpu" to "webgpu") + +### Use macros defined in shader_macros.h + +Take `SS` as example. It's a macro defined in `shader_macros.h` and it's used to concatenate strings. It's just make the `std::ostream::operator<<` to be used in a function call style. + +I prefer to use the macro because I feel like it's easier to read. Check the following code: + +```cpp +ss << "vec4(" << type << ">(" << value1 << ", " << value2 << ", " << value3 << ", " << value4 << ")"; +``` + +vs. + +```cpp +SS("vec4<", type, ">(", value1, ", ", value2, ", ", value3, ", ", value4, ")"); +``` + +### Use the subfolder for kernel implementation + +Operator implementation source code need to be put under a subfolder like "math"/"nn"/"tensor". + +See folder structure under onnxruntime/core/providers/cpu/ or onnxruntime/core/providers/cuda/ for examples. + +## Best Practices + +### Always use std::ostringstream to generate shader code if possible + +This helps to the performance of code generation. + +For example: + +```cpp +ss << "var " << name << " = " << value << ";\n"; +``` + +is better than + +```cpp +ss << ("var " + name + " = " + value + ";\n"); +``` + +### Avoid creating template class for kernel using data type as template parameter. + +This basically means that we should define class like this: + +```cpp +class Abs : public WebGpuKernel { + ... +}; +``` + +instead of + +```cpp + +template // T is tensor element type +class Abs : public WebGpuKernel { + ... +}; +``` + +This is because we don't really read and use `Tensor::Data()`. Tensor stores a handle to a WebGPU buffer but not a pointer to the data. Using template for data type only increases the binary size with no real benefit. + +## TODO items + +The following items are not yet implemented: + +- [ ] Validation Switch (allows to change the behavior of whether perform specific validation checks) +- [ ] pushErrorScope/popErrorScope +- [ ] Graph Capture +- [ ] Profiling supported by WebGPU Query Buffer +- [ ] WebGPU resources tracking (mainly for buffers) +- [ ] Global hanlders( unhandled exceptions and device lost ) diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc new file mode 100644 index 000000000000..8e27acdc285d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "core/framework/session_state.h" +#include "core/providers/webgpu/allocator.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +void* GpuBufferAllocator::Alloc(size_t size) { + if (size == 0) { + return nullptr; + } + + auto buffer = context_.BufferManager().Create(size); + + stats_.num_allocs++; + return buffer; +} + +void GpuBufferAllocator::Free(void* p) { + if (p != nullptr) { + context_.BufferManager().Release(static_cast(p)); + stats_.num_allocs--; + } +} + +void GpuBufferAllocator::GetStats(AllocatorStats* stats) { + *stats = stats_; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h new file mode 100644 index 000000000000..51ca65a8b482 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocator.h" +#include "core/framework/ortdevice.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +class GpuBufferAllocator : public IAllocator { + public: + GpuBufferAllocator(const WebGpuContext& context) + : IAllocator( + OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), + 0, OrtMemTypeDefault)), + context_{context} { + } + + virtual void* Alloc(size_t size) override; + virtual void Free(void* p) override; + void GetStats(AllocatorStats* stats) override; + + private: + AllocatorStats stats_; + const WebGpuContext& context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc new file mode 100644 index 000000000000..d69b1210ade4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -0,0 +1,362 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_context.h" + +static int xx = 1; + +namespace onnxruntime { +namespace webgpu { + +size_t NormalizeBufferSize(size_t size) { + return (size + 15) / 16 * 16; +} + +class DisabledCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/) override { + // always return empty buffer + return nullptr; + } + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + void ReleaseBuffer(WGPUBuffer buffer) override { + wgpuBufferRelease(buffer); + } + + void OnRefresh() override { + // no-op + } +}; + +class LazyReleaseCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/) override { + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + pending_buffers_.clear(); + } + + std::vector pending_buffers_; +}; + +class SimpleCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buffers_.find(buffer_size); + if (it != buffers_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + for (auto& buffer : pending_buffers_) { + buffers_[wgpuBufferGetSize(buffer)].push_back(buffer); + } + pending_buffers_.clear(); + } + + std::map> buffers_; + std::vector pending_buffers_; +}; + +// TODO: maybe use different bucket size for storage and uniform buffers? +constexpr std::initializer_list> BUCKET_DEFAULT_LIMIT_TABLE = { + {64, 250}, + {128, 200}, + {256, 200}, + {512, 200}, + {2048, 230}, + {4096, 200}, + {8192, 50}, + {16384, 50}, + {32768, 50}, + {65536, 50}, + {131072, 50}, + {262144, 50}, + {524288, 50}, + {1048576, 50}, + {2097152, 30}, + {4194304, 20}, + {8388608, 10}, + {12582912, 10}, + {16777216, 10}, + {26214400, 15}, + {33554432, 22}, + {44236800, 2}, + {58982400, 6}, + // we don't want to cache the bucket sizes below but not caching them + // results in some major performance hits for models like sd-turbo. + {67108864, 6}, + {134217728, 6}, + {167772160, 6}, +}; + +class BucketCacheManager : public IBufferCacheManager { + public: + BucketCacheManager() : buckets_limit_{BUCKET_DEFAULT_LIMIT_TABLE} { + Initialize(); + } + BucketCacheManager(std::unordered_map&& buckets_limit) : buckets_limit_{buckets_limit} { + Initialize(); + } + + size_t CalculateBufferSize(size_t request_size) override { + // binary serch size + auto it = std::lower_bound(buckets_keys_.begin(), buckets_keys_.end(), request_size); + if (it == buckets_keys_.end()) { + return NormalizeBufferSize(request_size); + } else { + return *it; + } + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + // TODO: consider graph capture. currently not supported + + for (auto& buffer : pending_buffers_) { + auto buffer_size = wgpuBufferGetSize(buffer); + + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { + it->second.push_back(buffer); + } else { + wgpuBufferRelease(buffer); + } + } + } + + protected: + void Initialize() { + buckets_keys_.reserve(buckets_limit_.size()); + buckets_.reserve(buckets_limit_.size()); + for (const auto& pair : buckets_limit_) { + buckets_keys_.push_back(pair.first); + buckets_.emplace(pair.first, std::vector()); + } + std::sort(buckets_keys_.begin(), buckets_keys_.end()); + +#ifndef NDEBUG // if debug build + for (size_t i = 0; i < buckets_keys_.size(); ++i) { + ORT_ENFORCE(buckets_keys_[i] % 16 == 0, "Bucket sizes must be multiples of 16."); + } + + for (size_t i = 1; i < buckets_keys_.size(); ++i) { + ORT_ENFORCE(buckets_keys_[i] > buckets_keys_[i - 1], "Bucket sizes must be in increasing order."); + } +#endif + } + std::unordered_map buckets_limit_; + std::unordered_map> buckets_; + std::vector pending_buffers_; + std::vector buckets_keys_; +}; + +std::unique_ptr CreateBufferCacheManager(BufferCacheMode cache_mode) { + switch (cache_mode) { + case BufferCacheMode::Disabled: + return std::make_unique(); + case BufferCacheMode::LazyRelease: + return std::make_unique(); + case BufferCacheMode::Simple: + return std::make_unique(); + case BufferCacheMode::Bucket: + return std::make_unique(); + default: + ORT_NOT_IMPLEMENTED("Unsupported buffer cache mode"); + } +} + +std::ostream& operator<<(std::ostream& os, BufferCacheMode mode) { + switch (mode) { + case BufferCacheMode::Disabled: + os << "Disabled"; + break; + case BufferCacheMode::LazyRelease: + os << "LazyRelease"; + break; + case BufferCacheMode::Simple: + os << "Simple"; + break; + case BufferCacheMode::Bucket: + os << "Bucket"; + break; + default: + os << "Unknown(" << static_cast(mode) << ")"; + } + return os; +} + +BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) + : context_{context}, + storage_cache_{std::move(CreateBufferCacheManager(storage_buffer_cache_mode))}, + uniform_cache_{std::move(CreateBufferCacheManager(uniform_buffer_cache_mode))}, + query_resolve_cache_{std::move(CreateBufferCacheManager(query_resolve_buffer_cache_mode))}, + default_cache_{std::move(CreateBufferCacheManager(BufferCacheMode::Disabled))} { +} + +void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { + auto buffer_size = NormalizeBufferSize(size); + + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite; + desc.mappedAtCreation = true; + + auto staging_buffer = context_.Device().CreateBuffer(&desc); + auto mapped_data = staging_buffer.GetMappedRange(); + memcpy(mapped_data, src, size); + staging_buffer.Unmap(); + + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(staging_buffer, 0, dst, 0, buffer_size); + pending_staging_buffers_.push_back(staging_buffer); +} + +void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { + ORT_ENFORCE(src != dst, "Source and destination buffers must be different."); + + auto buffer_size = NormalizeBufferSize(size); + ORT_ENFORCE(buffer_size <= wgpuBufferGetSize(src) && buffer_size <= wgpuBufferGetSize(dst), + "Source and destination buffers must have enough space for the copy operation. src_size=", + wgpuBufferGetSize(src), ", dst_size=", wgpuBufferGetSize(dst), ", copy_size=", buffer_size, "."); + + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(src, 0, dst, 0, buffer_size); +} + +WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { + auto& cache = GetCacheManager(static_cast(usage)); + auto buffer_size = cache.CalculateBufferSize(size); + + auto buffer = cache.TryAcquireCachedBuffer(buffer_size); + if (buffer) { + return buffer; + } + + // cache miss, create a new buffer + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = usage; + // desc.label = std::to_string(xx++).c_str(); + buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle(); + + ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), "."); + + cache.RegisterBuffer(buffer, size); + return buffer; +} + +void BufferManager::Release(WGPUBuffer buffer) { + GetCacheManager(buffer).ReleaseBuffer(buffer); +} + +void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) { + auto buffer_size = NormalizeBufferSize(size); + + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead; + + auto staging_buffer = context_.Device().CreateBuffer(&desc); + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(src, 0, staging_buffer, 0, buffer_size); + context_.Flush(); + + // TODO: revise wait in whole project + + ORT_ENFORCE(context_.Wait(staging_buffer.MapAsync(wgpu::MapMode::Read, 0, buffer_size, wgpu::CallbackMode::WaitAnyOnly, [](wgpu::MapAsyncStatus status, const char* message) { + ORT_ENFORCE(status == wgpu::MapAsyncStatus::Success, "Failed to download data from buffer: ", message); + })) == Status::OK()); + + auto mapped_data = staging_buffer.GetConstMappedRange(); + memcpy(dst, mapped_data, size); +} + +void BufferManager::RefreshPendingBuffers() { + pending_staging_buffers_.clear(); + storage_cache_->OnRefresh(); + uniform_cache_->OnRefresh(); + query_resolve_cache_->OnRefresh(); + default_cache_->OnRefresh(); +} + +IBufferCacheManager& BufferManager::GetCacheManager(WGPUBufferUsage usage) const { + if (usage & WGPUBufferUsage_Storage) { + return *storage_cache_; + } else if (usage & WGPUBufferUsage_Uniform) { + return *uniform_cache_; + } else if (usage & WGPUBufferUsage_QueryResolve) { + return *query_resolve_cache_; + } else { + return *default_cache_; + } +} + +IBufferCacheManager& BufferManager::GetCacheManager(WGPUBuffer buffer) const { + return GetCacheManager(wgpuBufferGetUsage(buffer)); +} + +std::unique_ptr BufferManagerFactory::Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) { + return std::make_unique(context, storage_buffer_cache_mode, uniform_buffer_cache_mode, query_resolve_buffer_cache_mode); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h new file mode 100644 index 000000000000..c94f77b6b5fa --- /dev/null +++ b/onnxruntime/core/providers/webgpu/buffer_manager.h @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "core/framework/data_transfer.h" +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +enum class BufferCacheMode { + Disabled, + LazyRelease, + Simple, + Bucket +}; +std::ostream& operator<<(std::ostream& os, BufferCacheMode mode); + +// +// IBufferCacheManager is an interface for buffer cache management. +// +// By implementing this interface, we can have different buffer cache management strategies. +// Currently, we have 3 strategies: +// - Disabled: no cache. always allocate a new buffer and release it immediately after use. +// - LazyRelease: no cache. the difference from Disabled is that it delays the release of buffers until the next refresh. +// - Simple: a simple cache that always keeps buffers. when a buffer is requested, it tries to find a buffer in the cache. +// - Bucket: a cache that keeps buffers in different buckets based on the buffer size, with a maximum number of buffers in each bucket. +// +class IBufferCacheManager { + public: + virtual ~IBufferCacheManager() = default; + + // calculate actual buffer size to allocate based on the requested size. + virtual size_t CalculateBufferSize(size_t request_size) = 0; + + // return a buffer if available in cache. otherwise empty. + virtual WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) = 0; + + // register a newly created buffer + virtual void RegisterBuffer(WGPUBuffer buffer, size_t request_size) = 0; + + // release a buffer + virtual void ReleaseBuffer(WGPUBuffer buffer) = 0; + + // when a stream refresh is requested + virtual void OnRefresh() = 0; +}; + +// +// BufferManager manages operations on buffers. +// +class BufferManager { + public: + BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode); + + void Upload(void* src, WGPUBuffer dst, size_t size); + void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size); + WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst); + void Release(WGPUBuffer buffer); + void Download(WGPUBuffer src, void* dst, size_t size); + void RefreshPendingBuffers(); + + private: + IBufferCacheManager& GetCacheManager(WGPUBufferUsage usage) const; + IBufferCacheManager& GetCacheManager(WGPUBuffer buffer) const; + + WebGpuContext& context_; + std::unique_ptr storage_cache_; + std::unique_ptr uniform_cache_; + std::unique_ptr query_resolve_cache_; + std::unique_ptr default_cache_; + + std::vector pending_staging_buffers_; +}; + +class BufferManagerFactory { + public: + static std::unique_ptr Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode); + + private: + BufferManagerFactory() {} +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc new file mode 100644 index 000000000000..67c55f823d78 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/op_kernel.h" + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { +ComputeContext::ComputeContext(OpKernelContext& kernel_context) + : webgpu_context_{WebGpuContextFactory::GetContext(kernel_context.GetDeviceId())}, + kernel_context_{kernel_context} { +} + +const wgpu::AdapterInfo& ComputeContext::AdapterInfo() const { + return webgpu_context_.AdapterInfo(); +} + +const wgpu::Limits& ComputeContext::DeviceLimits() const { + return webgpu_context_.DeviceLimits(); +} + +int ComputeContext::InputCount() const { + return kernel_context_.InputCount(); +} + +int ComputeContext::OutputCount() const { + return kernel_context_.OutputCount(); +} + +Status ComputeContext::RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h new file mode 100644 index 000000000000..d7aeae240101 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include + +#include "core/framework/execution_provider.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { + +class Tensor; +class OpKernelContext; + +namespace webgpu { + +class WebGpuContext; + +class ComputeContext { + public: + ComputeContext(OpKernelContext& kernel_context); + + virtual ~ComputeContext() = default; + + // + // Get various information from the context. + // + + const wgpu::AdapterInfo& AdapterInfo() const; + const wgpu::Limits& DeviceLimits() const; + + // + // Get input tensor. + // + template + const T* Input(int index) const { + return kernel_context_.Input(index); + } + + // + // Get input count. + // + int InputCount() const; + + // + // Set output tensor. + // + template + Tensor* Output(int index, TensorShapeType&& shape) { + return kernel_context_.Output(index, std::forward(shape)); + } + + // + // Get output count. + // + int OutputCount() const; + + // + // Create CPU tensor. + // + template + Tensor CreateCPUTensor(MLDataType data_type, TensorShapeType&& shape) { + AllocatorPtr allocator; + ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceCPUAllocator(&allocator)); + return {data_type, std::forward(shape)..., allocator}; + } + + // + // Create GPU tensor. + // + template + Tensor CreateGPUTensor(MLDataType data_type, TensorShapeType&& shape) { + AllocatorPtr allocator; + ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceAllocator(&allocator)); + return {data_type, std::forward(shape)..., allocator}; + } + + // + // Run a compute shader program. + // + Status RunProgram(const ProgramBase& program); + + protected: + WebGpuContext& webgpu_context_; + OpKernelContext& kernel_context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc new file mode 100644 index 000000000000..615ae1117578 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/data_transfer.cc @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { + return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || + (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::GPU) || + (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); +} + +common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + size_t bytes = src.SizeInBytes(); + if (bytes > 0) { + void const* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; + + if (dst_device.Type() == OrtDevice::GPU) { + if (src_device.Type() == OrtDevice::GPU) { + // copy from GPU to GPU + context_.BufferManager().MemCpy(static_cast(const_cast(src_data)), + static_cast(dst_data), bytes); + } else { + // copy from CPU to GPU + context_.BufferManager().Upload(const_cast(src_data), static_cast(dst_data), bytes); + } + } else /* if (src_device.Type() == OrtDevice::GPU) */ { + // copy from GPU to CPU + context_.BufferManager().Download(static_cast(const_cast(src_data)), dst_data, bytes); + } + } + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.h b/onnxruntime/core/providers/webgpu/data_transfer.h new file mode 100644 index 000000000000..f9949576aa60 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/data_transfer.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/data_transfer.h" +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +class DataTransfer : public IDataTransfer { + public: + DataTransfer(const WebGpuContext& context) : context_{context} {}; + ~DataTransfer() {}; + + bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; + + common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; + + private: + const WebGpuContext& context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc new file mode 100644 index 000000000000..5c774df84638 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/math/unary_elementwise_ops.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { +Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddVariable(ProgramVariableScope::Input, + "x", + ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4), + 1); + const auto& output = shader.AddVariable(ProgramVariableScope::Output, + "y", + ToProgramVariableDataType(Outputs()[0]->GetElementType(), 4), + 1); + shader.AppendImplementation(additional_impl_); + shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + "let a = ", input.GetByOffset("global_idx"), ";\n", + output.SetByOffset("global_idx", expression_)); + + return Status::OK(); +} + +#define WEBGPU_ELEMENTWISE_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public WebGpuKernel { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : WebGpuKernel{info} {} \ + \ + protected: \ + Status ComputeInternal(ComputeContext& context) const override { \ + const auto* input_tensor = context.Input(0); \ + auto* output_tensor = context.Output(0, input_tensor->Shape()); \ + SafeInt vec_size = (input_tensor->Shape().Size() + 3) / 4; \ + UnaryElementwiseProgram program{#OP_TYPE, __VA_ARGS__}; \ + program \ + .Inputs({{input_tensor, ProgramInputTensorDependency::Type}}) \ + .Outputs({output_tensor}) \ + .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) \ + .UniformVariables({ \ + {static_cast(vec_size)}, \ + }); \ + return context.RunProgram(program); \ + } \ + }; + +#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +#define WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +WEBGPU_ELEMENTWISE_IMPL(Abs, "abs(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Abs, 6, 12, Abs, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Abs, 13, Abs, WebGpuSupportedFloatTypes()) + +// TODO: add other unary elementwise ops + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h new file mode 100644 index 000000000000..837f66af30dd --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class UnaryElementwiseProgram final : public Program { + public: + UnaryElementwiseProgram(const std::string& kernel_name, const std::string& expression, const std::string& additional_impl = "") + : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + std::string expression_; + std::string additional_impl_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc new file mode 100644 index 000000000000..8ba33bcafb31 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/session/onnxruntime_c_api.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +ProgramUniformVariableValue::ProgramUniformVariableValue() + : length{0}, data_type{} {} // representing an empty uniform variable + +ProgramUniformVariableValue::ProgramUniformVariableValue(float value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float32, &value, sizeof(float)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(uint32_t value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Uint32, &value, sizeof(uint32_t)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(int32_t value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Int32, &value, sizeof(int32_t)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(MLFloat16 value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float16, &value, sizeof(MLFloat16)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float32, values.data(), sizeof(float), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Uint32, values.data(), sizeof(uint32_t), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Int32, values.data(), sizeof(int32_t), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float16, values.data(), sizeof(MLFloat16), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(ProgramUniformVariableDataType data_type, + const void* ptr, + size_t element_byte_size, + size_t length /* = 1 */) + : length{length}, data_type{data_type} { + ORT_ENFORCE(length > 0, "number of element of uniform variable must be greater than 0"); + + data.resize(length * element_byte_size); + memcpy(data.data(), ptr, length * element_byte_size); +} + +std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType type) { + os << ProgramUniformVariableDataTypeName[static_cast(type)]; + return os; +} + +std::ostream& operator<<(std::ostream& os, ProgramConstantDataType type) { + os << ProgramConstantDataTypeName[static_cast(type)]; + return os; +} + +std::ostream& operator<<(std::ostream& os, ProgramInputTensorDependency dep) { + bool first = true; + if ((dep & ProgramInputTensorDependency::Type) == ProgramInputTensorDependency::Type) { + os << "Type"; + first = false; + } + if ((dep & ProgramInputTensorDependency::Rank) == ProgramInputTensorDependency::Rank) { + if (!first) os << "|"; + os << "Rank"; + first = false; + } + if ((dep & ProgramInputTensorDependency::Shape) == ProgramInputTensorDependency::Shape) { + if (!first) os << "|"; + os << "Shape"; + first = false; + } + if (first) { + os << "None"; + } + + return os; +} + +ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component /* = 1 */) { + if (component == 1) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Float32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Float16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Int32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Uint32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return ProgramVariableDataType::Int64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return ProgramVariableDataType::Uint64; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 2) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Vec2Float32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Vec2Float16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Vec2Int32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Vec2Uint32; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 4) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Vec4Float32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Vec4Float16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Vec4Int32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Vec4Uint32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return ProgramVariableDataType::Vec4Bool; + default: + return ProgramVariableDataType::InvalidType; + } + } else { + return ProgramVariableDataType::InvalidType; + } +} + +ProgramBase::ProgramBase(const std::string& name) + : name_{name}, + dispatch_group_size_x_{0}, + dispatch_group_size_y_{0}, + dispatch_group_size_z_{0}, + workgroup_size_x_{WORKGROUP_SIZE}, + workgroup_size_y_{1}, + workgroup_size_z_{1} { +} + +ProgramBase& ProgramBase::Inputs(std::initializer_list inputs) { + inputs_.assign(inputs.begin(), inputs.end()); + return *this; +} + +ProgramBase& ProgramBase::Outputs(std::initializer_list outputs) { + outputs_.assign(outputs.begin(), outputs.end()); + return *this; +} + +ProgramBase& ProgramBase::DispatchGroupSize(uint32_t x) { + return DispatchGroupSize(x, 1, 1); +} + +ProgramBase& ProgramBase::DispatchGroupSize(uint32_t x, uint32_t y) { + return DispatchGroupSize(x, y, 1); +} + +ProgramBase& ProgramBase::DispatchGroupSize(uint32_t x, uint32_t y, uint32_t z) { + dispatch_group_size_x_ = x; + dispatch_group_size_y_ = y; + dispatch_group_size_z_ = z; + return *this; +} + +ProgramBase& ProgramBase::WorkgroupSize(uint32_t x) { + return WorkgroupSize(x, 1, 1); +} + +ProgramBase& ProgramBase::WorkgroupSize(uint32_t x, uint32_t y) { + return WorkgroupSize(x, y, 1); +} + +ProgramBase& ProgramBase::WorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { + workgroup_size_x_ = x; + workgroup_size_y_ = y; + workgroup_size_z_ = z; + return *this; +} + +ProgramBase& ProgramBase::UniformVariables(std::initializer_list variables) { + variables_.insert(variables_.end(), variables.begin(), variables.end()); + return *this; +} + +ProgramBase& ProgramBase::OverridableConstants(std::initializer_list overridable_constants) { + overridable_constants_.insert(overridable_constants_.end(), overridable_constants.begin(), overridable_constants.end()); + return *this; +} + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h new file mode 100644 index 000000000000..6df918e2f7f7 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program.h @@ -0,0 +1,491 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/string_join.h" +#include "core/common/safeint.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace webgpu { +class ShaderHelper; +class ComputeContext; +class WebGpuContext; + +// data type of uniform variable +enum class ProgramUniformVariableDataType { + Float32, + Float16, + Uint32, + Int32, +}; +std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType); + +constexpr size_t ProgramUniformVariableDataTypeSize[] = {sizeof(float), sizeof(uint16_t), sizeof(uint32_t), sizeof(int32_t)}; + +constexpr std::string_view ProgramUniformVariableDataTypeName[] = {"f32", "f16", "u32", "i32"}; + +// represents a runtime value of a uniform variable +struct ProgramUniformVariableValue { + ProgramUniformVariableValue(); // representing an empty uniform variable + ProgramUniformVariableValue(float value); + ProgramUniformVariableValue(uint32_t value); + ProgramUniformVariableValue(int32_t value); + ProgramUniformVariableValue(MLFloat16 value); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + + size_t length; + ProgramUniformVariableDataType data_type; + std::vector data; + + private: + ProgramUniformVariableValue(ProgramUniformVariableDataType data_type, const void* ptr, size_t element_byte_size, size_t length = 1); +}; + +// represents a uniform variable definition +struct ProgramUniformVariableDefinition { + std::string_view name; + ProgramUniformVariableDataType data_type; +}; + +// data type of constant +enum class ProgramConstantDataType { + Float32, + Float16, + Uint32, + Int32, + Bool +}; +std::ostream& operator<<(std::ostream& os, ProgramConstantDataType); + +constexpr std::string_view ProgramConstantDataTypeName[] = {"f32", "f16", "u32", "i32", "bool"}; + +// represents a constant in a program +struct ProgramConstant { + constexpr ProgramConstant(std::string_view name, float value) : name{name}, type{ProgramConstantDataType::Float32}, f32{value} {} + constexpr ProgramConstant(std::string_view name, uint32_t value) : name{name}, type{ProgramConstantDataType::Uint32}, u32{value} {} + constexpr ProgramConstant(std::string_view name, int32_t value) : name{name}, type{ProgramConstantDataType::Int32}, i32{value} {} + constexpr ProgramConstant(std::string_view name, MLFloat16 value) : name{name}, type{ProgramConstantDataType::Float16}, f16{value} {} + constexpr ProgramConstant(std::string_view name, bool value) : name{name}, type{ProgramConstantDataType::Bool}, boolean{value} {} + + std::string_view name; + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; +}; + +// represents a runtime value of an overridable constant +struct ProgramOverridableConstantValue { + constexpr ProgramOverridableConstantValue() : type{}, u32{}, has_value{false} {} // representing not overriding + constexpr ProgramOverridableConstantValue(float value) : type{ProgramConstantDataType::Float32}, f32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(uint32_t value) : type{ProgramConstantDataType::Uint32}, u32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(int32_t value) : type{ProgramConstantDataType::Int32}, i32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(MLFloat16 value) : type{ProgramConstantDataType::Float16}, f16{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(bool value) : type{ProgramConstantDataType::Bool}, boolean{value}, has_value{true} {} + + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; + bool has_value; +}; + +// represents an overridable constant definition. may or may not have a default value. +struct ProgramOverridableConstantDefinition { + constexpr ProgramOverridableConstantDefinition(std::string_view name, ProgramConstantDataType type) + : name{name}, type{type}, u32{}, has_default_value{false} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, float value) + : name{name}, type{ProgramConstantDataType::Float32}, f32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, uint32_t value) + : name{name}, type{ProgramConstantDataType::Uint32}, u32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, int32_t value) + : name{name}, type{ProgramConstantDataType::Int32}, i32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, MLFloat16 value) + : name{name}, type{ProgramConstantDataType::Float16}, f16{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, bool value) + : name{name}, type{ProgramConstantDataType::Bool}, boolean{value}, has_default_value{true} {} + + std::string_view name; + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; + bool has_default_value; +}; + +// represents whether the program shader depends on the type, rank, or shape of an input/output tensor +enum class ProgramInputTensorDependency : int { + None = 0, + Type = 1, + Rank = 2, + Shape = 4, + TypeAndRank = Type | Rank, + TypeAndShape = Type | Shape, +}; +std::ostream& operator<<(std::ostream& os, ProgramInputTensorDependency); + +inline ProgramInputTensorDependency operator|(ProgramInputTensorDependency a, ProgramInputTensorDependency b) { + return (ProgramInputTensorDependency)((int&)a | (int&)b); +} +inline ProgramInputTensorDependency operator&(ProgramInputTensorDependency a, ProgramInputTensorDependency b) { + return (ProgramInputTensorDependency)((int&)a & (int&)b); +} +inline ProgramInputTensorDependency& operator|=(ProgramInputTensorDependency& a, ProgramInputTensorDependency b) { + return (ProgramInputTensorDependency&)((int&)a |= (int&)b); +} +inline ProgramInputTensorDependency& operator&=(ProgramInputTensorDependency& a, ProgramInputTensorDependency b) { + return (ProgramInputTensorDependency&)((int&)a &= (int&)b); +} + +struct ProgramInput { + const Tensor* tensor; + ProgramInputTensorDependency dependency; +}; + +constexpr SafeInt WORKGROUP_SIZE = 64; + +// represents the scope of a variable in a shader program. +// +// this is not a full list of all possible variable scopes in shader programs. +// it only includes what are used in WebGPU EP. +enum class ProgramVariableScope { + Input = 0, // storage buffer variable with access mode "read" + Output = 1, // storage buffer variable with access mode "read_write" + Local = 2, // local variable + + Count // should always be the last element +}; + +// data type of variable +// +// this is not a full list of all possible data types in shader programs. +// it only includes what are used in WebGPU EP. +enum class ProgramVariableDataType { + InvalidType = -1, + Float32, + Vec2Float32, + Vec4Float32, + Float16, + Vec2Float16, + Vec4Float16, + Int32, + Vec2Int32, + Vec4Int32, + Uint32, + Vec2Uint32, + Vec4Uint32, + Int64, + Uint64, + Vec4Bool, +}; + +ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1); + +namespace detail { +class ProgramWrapper; +} + +struct ProgramMetadata; + +class ProgramBase { + public: + // + // chain-style methods for setting properties + // + + // set the cache hint for the program + template + ProgramBase& CacheHint(CacheHintArgs&&... args) { + cache_hint_ = StringJoin("|", std::forward(args)...); + } + + // set one or more program inputs + ProgramBase& Inputs(std::initializer_list inputs); + // set one or more program outputs + ProgramBase& Outputs(std::initializer_list outputs); + + // set the size of dispatch groups. Y and Z are 1 if not specified. + ProgramBase& DispatchGroupSize(uint32_t x); + // set the size of dispatch groups. Z is 1 if not specified. + ProgramBase& DispatchGroupSize(uint32_t x, uint32_t y); + // set the size of dispatch groups. + ProgramBase& DispatchGroupSize(uint32_t x, uint32_t y, uint32_t z); + + // set the size of a workgroup grid. Y and Z are 1 if not specified. + ProgramBase& WorkgroupSize(uint32_t x); + // set the size of a workgroup grid. Z is 1 if not specified. + ProgramBase& WorkgroupSize(uint32_t x, uint32_t y); + // set the size of a workgroup grid. + ProgramBase& WorkgroupSize(uint32_t x, uint32_t y, uint32_t z); + + // set the uniform variables. + // + // the specified uniform variables should match the uniform definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. + ProgramBase& UniformVariables(std::initializer_list variables); + + // set the overridable constants + // + // the specified overridable constants should match the overridable constant definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS. + ProgramBase& OverridableConstants(std::initializer_list overridable_constants); + + // + // shader code generation + // + + virtual Status GenerateShaderCode(ShaderHelper& shader) const = 0; + + // + // abstract methods for getting metadata + // + // A derived class may contain any of the following static members: + // + // \code{.cpp} + // // define a list of constant that used in the shader program + // static constexpr const ProgramConstant constants[] = { ... }; + // + // // define a list of overridable constant that used in the shader program + // static constexpr const ProgramOverridableConstantDefinition overridable_constants[] = { ... }; + // + // // define a list of uniform variables that used in the shader program + // static constexpr const ProgramUniformVariableDefinition uniform_variables[] = { ... }; + // \endcode + // + // If those static members exist, the value of them will be used to generate the metadata. + virtual ProgramMetadata GetMetadata() const = 0; + + // + // Properties Getters + // + + inline const std::string& Name() const { return name_; } + inline const std::string& CacheHint() const { return cache_hint_; } + inline const std::vector& Inputs() const { return inputs_; } + inline const std::vector& Outputs() const { return outputs_; } + inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; } + inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; } + inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; } + inline uint32_t WorkgroupSizeX() const { return workgroup_size_x_; } + inline uint32_t WorkgroupSizeY() const { return workgroup_size_y_; } + inline uint32_t WorkgroupSizeZ() const { return workgroup_size_z_; } + inline const std::vector& UniformVariables() const { return variables_; } + inline const std::vector& OverridableConstants() const { return overridable_constants_; } + + protected: + virtual ~ProgramBase() = default; + + private: + // Make the constructor private to prevent direct instantiation or inheritance from this class + // Use the Program template class as base class to create a new program class + explicit ProgramBase(const std::string& name); + + std::string name_; + std::string cache_hint_; + std::vector inputs_; + std::vector outputs_; + + uint32_t dispatch_group_size_x_; + uint32_t dispatch_group_size_y_; + uint32_t dispatch_group_size_z_; + + uint32_t workgroup_size_x_; + uint32_t workgroup_size_y_; + uint32_t workgroup_size_z_; + + std::vector variables_; + std::vector overridable_constants_; + + friend class detail::ProgramWrapper; +}; + +namespace detail { +// class ProgramWrapper is for accessing private constructor of ProgramBase. +// only ProgramWrapper can access the constructor of ProgramBase because ProgramWrapper is the only friend class of +// ProgramBase. This design is used to prevent direct instantiation or inheritance from ProgramBase. +class ProgramWrapper : public ProgramBase { + protected: + template + ProgramWrapper(Args&&... args) : ProgramBase{std::forward(args)...} {} +}; + +#if defined(ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK) +#error "macro ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK is already defined" +#endif + +#define ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(identifier, element_type) \ + private: \ + template \ + static auto test_has_##identifier(int)->decltype(U::identifier, std::true_type{}); /* checks if member exists */ \ + template \ + static auto test_has_##identifier(...)->std::false_type; \ + \ + template && /* - is array */ \ + std::is_const_v && /* - has "const" modifier */ \ + std::is_convertible_v && /* - can convert to a const pointer */ \ + !std::is_member_pointer_v>> /* - is static */ \ + static auto test_has_##identifier##_with_correct_type(int)->std::true_type; \ + template \ + static auto test_has_##identifier##_with_correct_type(...)->std::false_type; \ + \ + public: \ + static constexpr bool has_##identifier = decltype(test_has_##identifier(0))::value; \ + static constexpr bool has_##identifier##_with_correct_type = decltype(test_has_##identifier##_with_correct_type(0))::value + +// the following template class checks whether certain static members exist in the derived class (SFINAE) +template +class DerivedProgramClassTypeCheck { + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(constants, ProgramConstant); + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(overridable_constants, ProgramOverridableConstantDefinition); + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(uniform_variables, ProgramUniformVariableDefinition); +}; + +// compile-time tests for the type check +namespace test { + +struct TestClass_Empty {}; +struct TestClass_0 { + int b; +}; +struct TestClass_1 { + int a; +}; +struct TestClass_2 { + const int a; +}; +struct TestClass_3 { + const int a[2]; +}; +struct TestClass_4 { + static constexpr int a[] = {0}; +}; +struct TestClass_5 { + static int a[]; +}; +struct TestClass_6 { + static const int a[]; +}; + +template +class TestTypeCheck { + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(a, int); +}; + +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +} // namespace test + +#undef ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK + +} // namespace detail + +struct ProgramMetadata { + gsl::span constants; + gsl::span overridable_constants; + gsl::span uniform_variables; +}; + +template +class Program : public detail::ProgramWrapper { + public: + template + Program(Args&&... args) : detail::ProgramWrapper{std::forward(args)...} {} + + virtual ProgramMetadata GetMetadata() const final { + ProgramMetadata metadata; + if constexpr (detail::DerivedProgramClassTypeCheck::has_constants) { + constexpr const ProgramConstant* ptr = T::constants; + constexpr size_t len = sizeof(T::constants) / sizeof(ProgramConstant); + + static_assert(detail::DerivedProgramClassTypeCheck::has_constants_with_correct_type && + sizeof(T::constants) % sizeof(ProgramConstant) == 0, + "Derived class of \"Program\" has member \"constants\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_CONSTANTS() to declare constants."); + + metadata.constants = {ptr, len}; + } else { + metadata.constants = {}; + } + + if constexpr (detail::DerivedProgramClassTypeCheck::has_overridable_constants) { + constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants; + constexpr size_t len = sizeof(T::overridable_constants) / sizeof(ProgramOverridableConstantDefinition); + + static_assert(detail::DerivedProgramClassTypeCheck::has_overridable_constants_with_correct_type && + sizeof(T::overridable_constants) % sizeof(ProgramOverridableConstantDefinition) == 0, + "Derived class of \"Program\" has member \"overridable_constants\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS() to declare overridable constants."); + + metadata.overridable_constants = {ptr, len}; + } else { + metadata.overridable_constants = {}; + } + + if constexpr (detail::DerivedProgramClassTypeCheck::has_uniform_variables) { + constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables; + constexpr size_t len = sizeof(T::uniform_variables) / sizeof(ProgramUniformVariableDefinition); + + static_assert(detail::DerivedProgramClassTypeCheck::has_uniform_variables_with_correct_type && + sizeof(T::uniform_variables) % sizeof(ProgramUniformVariableDefinition) == 0, + "Derived class of \"Program\" has member \"uniform_variables\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES() to declare uniform variables."); + + metadata.uniform_variables = {ptr, len}; + } else { + metadata.uniform_variables = {}; + } + + return metadata; + } +}; + +#define WEBGPU_PROGRAM_DEFINE_CONSTANTS(...) \ + static constexpr const onnxruntime::webgpu::ProgramConstant constants[] = {__VA_ARGS__} + +#define WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS(...) \ + static constexpr const onnxruntime::webgpu::ProgramOverridableConstantDefinition overridable_constants[] = {__VA_ARGS__} + +#define WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(...) \ + static constexpr const onnxruntime::webgpu::ProgramUniformVariableDefinition uniform_variables[] = {__VA_ARGS__} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc new file mode 100644 index 000000000000..d720c55fb542 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/program_cache_key.h" + +#include "core/providers/webgpu/shader_macros.h" + +namespace onnxruntime { +namespace webgpu { + +std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch) { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + + // final key format: + // =[]:::: + // + // = ||... + // = ,, + // = + // = ||... + // = + // = ||... + // = ; + ss << program.Name(); + + // append custom cache hint if any + if (auto& hint = program.CacheHint(); !hint.empty()) { + ss << "[" D("CacheHint=") << hint << "]"; + } + + // append workgroup size if overridden + if (auto x = program.WorkgroupSizeX(), y = program.WorkgroupSizeY(), z = program.WorkgroupSizeZ(); + x != 0 || y != 0 || z != 0) { + ss << ":" D("WorkgroupSize="); + // only append non-zero values. zero values are considered as use default + if (x > 0) { + ss << x; + } + ss << ","; + if (y > 0) { + ss << y; + } + ss << ","; + if (z > 0) { + ss << z; + } + } + + ss << ":" D("DispatchDim=") << is_1d_dispatch ? "1" : "3"; + ss << ":" D("UniformSizes="); + bool first = true; + for (const auto& uniform : program.UniformVariables()) { + if (first) { + first = false; + } else { + ss << "|"; + } + if (uniform.length > 0) { + ss << uniform.length; + } + } + ss << ":" D("Inputs="); + first = true; + for (const auto& input : program.Inputs()) { + if (first) { + first = false; + } else { + ss << "|"; + } + if ((input.dependency & ProgramInputTensorDependency::Type) == ProgramInputTensorDependency::Type) { +#ifndef NDEBUG // if debug build + ss << DataTypeImpl::ToString(input.tensor->DataType()); +#else + ss << input.tensor->GetElementType(); +#endif + } + ss << ";"; + if ((input.dependency & ProgramInputTensorDependency::Rank) == ProgramInputTensorDependency::Rank) { + ss D("Rank=") << input.tensor->Shape().NumDimensions(); + } else if ((input.dependency & ProgramInputTensorDependency::Shape) == ProgramInputTensorDependency::Shape) { + ss D("Dims=") << input.tensor->Shape().ToString(); + } + } + + return ss.str(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.h b/onnxruntime/core/providers/webgpu/program_cache_key.h new file mode 100644 index 000000000000..22ba19ebd0f2 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_cache_key.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc new file mode 100644 index 000000000000..de228a038b7d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/common.h" +#include "core/common/safeint.h" + +#include "core/common/common.h" +#include "core/common/logging/logging.h" + +#include "core/providers/webgpu/program_manager.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline) + : name{program.Name()}, compute_pipeline{compute_pipeline} { + // prepare uniform info + size_t current_offset = 0; + for (const auto& uniform : program.UniformVariables()) { + bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + size_t length = uniform.length; + size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; + // https://www.w3.org/TR/WGSL/#alignof + size_t base_alignment = is_f16 + ? (length > 4 ? 16 : length > 2 ? 8 + : length * element_size) + : (length > 2 ? 16 : length * element_size); + size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; + + current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; + uniforms.push_back({uniform.data_type, current_offset, length}); + + // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). + // For float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). + size_t element_per_struct = is_f16 ? 8 : 4; + current_offset += + length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; + } + + // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set + // max_alignment_of_field to 16 since the underlying buffer has been rounded up to 16. + const int max_alignment_of_field = 16; + uniform_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; +} + +Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { + ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); + + auto limit_per_dimension = limits_.maxComputeWorkgroupsPerDimension; + if (x > limit_per_dimension || y > limit_per_dimension || z > limit_per_dimension) { + auto size = static_cast(x) * static_cast(y) * static_cast(z); + SafeInt dispatch_avg = std::ceil(std::sqrt(size)); + if (dispatch_avg > limit_per_dimension) { + dispatch_avg = std::ceil(std::cbrt(size)); + ORT_RETURN_IF(dispatch_avg > limit_per_dimension, "The dispatch group size exceeds WebGPU maximum."); + x = y = z = dispatch_avg; + } else { + x = y = dispatch_avg; + z = 1; + } + } + return Status::OK(); +} + +Status ProgramManager::Build(const ProgramBase& program, + const ProgramMetadata& program_metadata, +#ifndef NDEBUG // if debug build + const std::string& program_key, +#endif + uint32_t normalized_dispatch_x, + uint32_t normalized_dispatch_y, + uint32_t normalized_dispatch_z, + wgpu::ComputePipeline& compute_pipeline) const { + ShaderHelper shader_helper{program, + program_metadata, + device_, + limits_, + normalized_dispatch_x, + normalized_dispatch_y, + normalized_dispatch_z}; + ORT_RETURN_IF_ERROR(shader_helper.Init()); + + ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); + + // code is a large std::string that contains the final shader code + auto code = shader_helper.GetFinalSourceCode(); + + LOGS_DEFAULT(VERBOSE) << "\n=== WebGPU Shader code [" << program.Name() +#ifndef NDEBUG // if debug build + << ", Key=\"" << program_key << "\"" +#endif + << "] Start ===\n\n" + << code + << "\n=== WebGPU Shader code [" << program.Name() +#ifndef NDEBUG // if debug build + << ", Key=\"" << program_key << "\"" +#endif + << "] End ===\n"; + + wgpu::ShaderModuleWGSLDescriptor wgsl_descriptor{}; + wgsl_descriptor.code = code.c_str(); + + wgpu::ShaderModuleDescriptor descriptor{}; + descriptor.nextInChain = &wgsl_descriptor; + + auto shader_module = device_.CreateShaderModule(&descriptor); + + // process overridable constants if available + size_t constant_count = program.OverridableConstants().size(); + + // making a copy of the constant names is required because they are stored as std::string_view in the program + // metadata. A value of std::string_view is not guaranteed to be a C-stlye string (null-terminated) and hence + // cannot be used directly in the WebGPU API (which expects a const char*). + std::vector constant_names; + constant_names.reserve(constant_count); + std::vector constant_entries; + constant_entries.reserve(constant_count); + for (size_t i = 0; i < constant_count; ++i) { + const auto& constant_override = program.OverridableConstants()[i]; + const auto& constant_def = program_metadata.overridable_constants[i]; + + if (constant_override.has_value) { + double value = 0; + switch (constant_override.type) { + case ProgramConstantDataType::Bool: + value = constant_override.boolean ? 1 : 0; + break; + case ProgramConstantDataType::Float16: + // convert f16(MLFloat16) -> f32(float) -> f64(double) + // because the value of a constant must be a double in WebGPU API, it is expensive to use f16 overridable constants. + value = constant_override.f16.ToFloat(); + break; + case ProgramConstantDataType::Float32: + value = constant_override.f32; + break; + case ProgramConstantDataType::Int32: + value = constant_override.i32; + break; + case ProgramConstantDataType::Uint32: + value = constant_override.u32; + break; + } + + const auto& name_string = constant_names.emplace_back(constant_def.name); + wgpu::ConstantEntry entry{}; + entry.key = name_string.c_str(); + entry.value = value; + constant_entries.push_back(std::move(entry)); + } + } + + wgpu::ProgrammableStageDescriptor compute_stage{}; + compute_stage.module = shader_module; + compute_stage.entryPoint = "main"; + if (!constant_entries.empty()) { + compute_stage.constants = constant_entries.data(); + compute_stage.constantCount = constant_entries.size(); + } + + wgpu::ComputePipelineDescriptor pipeline_descriptor{}; + pipeline_descriptor.compute = compute_stage; +#ifndef NDEBUG // if debug build + pipeline_descriptor.label = program.Name().c_str(); +#endif + + compute_pipeline = device_.CreateComputePipeline(&pipeline_descriptor); + + return Status(); +} + +const ProgramArtifact* ProgramManager::Get(const std::string& key) const { + auto result = programs_.find(key); + if (result != programs_.end()) { + return &result->second; + } + + return nullptr; +} + +const ProgramArtifact* ProgramManager::Set(const std::string& key, ProgramArtifact&& program) { + return &(programs_.emplace(key, std::move(program)).first->second); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h new file mode 100644 index 000000000000..9d1b7655c864 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include + +#include "core/common/common.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +class Tensor; + +namespace webgpu { + +struct ProgramUniformInfo { + ProgramUniformVariableDataType data_type; + size_t offset; + size_t length; +}; + +class ProgramArtifact { + public: + ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline); + + std::string name; + wgpu::ComputePipeline compute_pipeline; + std::vector uniforms; + size_t uniform_total_size; + + ProgramArtifact(ProgramArtifact&&) = default; + ProgramArtifact& operator=(ProgramArtifact&&) = default; + + private: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProgramArtifact); +}; + +class ProgramManager { + public: + ProgramManager(const wgpu::Device& device, const wgpu::Limits& limits) : device_(device), limits_(limits) {} + + Status NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const; + + Status Build(const ProgramBase& program, + const ProgramMetadata& metadata, +#ifndef NDEBUG // if debug build + const std::string& program_key, +#endif + uint32_t normalized_dispatch_x, + uint32_t normalized_dispatch_y, + uint32_t normalized_dispatch_z, + wgpu::ComputePipeline& compute_pipeline) const; + const ProgramArtifact* Get(const std::string& key) const; + const ProgramArtifact* Set(const std::string& key, ProgramArtifact&& program); + + private: + std::unordered_map programs_; + const wgpu::Device& device_; + const wgpu::Limits& limits_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc new file mode 100644 index 000000000000..203f11ff9000 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -0,0 +1,204 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "core/session/onnxruntime_c_api.h" + +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +ShaderHelper::ShaderHelper(const ProgramBase& program, + const ProgramMetadata& program_metadata, + const wgpu::Device& device, + const wgpu::Limits& limits, + uint32_t dispatch_group_size_x, + uint32_t dispatch_group_size_y, + uint32_t dispatch_group_size_z) + : device_{device}, + limits_{limits}, + dispatch_group_size_x_{dispatch_group_size_x}, + dispatch_group_size_y_{dispatch_group_size_y}, + dispatch_group_size_z_{dispatch_group_size_z}, + program_{program}, + program_metadata_{program_metadata}, + use_f16_{false} { +} + +Status ShaderHelper::Init() { + // dispatch group size is normalized so no need to validate it here + + // validate workgroup size + auto workgroup_size_x = program_.WorkgroupSizeX(); + auto workgroup_size_y = program_.WorkgroupSizeY(); + auto workgroup_size_z = program_.WorkgroupSizeZ(); + + ORT_RETURN_IF_NOT(workgroup_size_x > 0 && workgroup_size_y > 0 && workgroup_size_z > 0, + "Workgroup size must be greater than 0"); + ORT_RETURN_IF_NOT(workgroup_size_x <= limits_.maxComputeWorkgroupSizeX && + workgroup_size_y <= limits_.maxComputeWorkgroupSizeY && + workgroup_size_z <= limits_.maxComputeWorkgroupSizeZ, + "Workgroup size exceeds the maximum allowed size [", + limits_.maxComputeWorkgroupSizeX, ", ", + limits_.maxComputeWorkgroupSizeY, ", ", + limits_.maxComputeWorkgroupSizeZ, "]"); + + ORT_RETURN_IF_NOT(workgroup_size_x * workgroup_size_y * workgroup_size_z <= limits_.maxComputeInvocationsPerWorkgroup, + "Workgroup size exceeds the maximum allowed invocations ", limits_.maxComputeInvocationsPerWorkgroup); + + // init body string stream + bool is_1d_dispatch = dispatch_group_size_y_ == 1 && dispatch_group_size_z_ == 1; + body_.imbue(std::locale::classic()); + + // append header for main function so it is ready for user to append main function body + body_ << "@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)\n" + "fn main(@builtin(global_invocation_id) global_id : vec3,\n" + " @builtin(workgroup_id) workgroup_id : vec3,\n" + " @builtin(local_invocation_id) local_id : vec3"; + if (!is_1d_dispatch) { + body_ << ",\n" + " @builtin(local_invocation_index) local_idx : u32,\n" + " @builtin(num_workgroups) num_workgroups : vec3"; + } + body_ << ") {\n"; + if (is_1d_dispatch) { + body_ << " let global_idx = global_id.x;\n" + " let local_idx = local_id.x;\n"; + } else { + body_ << " let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x)\n" + " * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; + } + + // init additional implementation string stream + additional_implementation_.imbue(std::locale::classic()); + + return Status::OK(); +} + +std::string ShaderHelper::GetFinalSourceCode() { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + + // + // Section feature enabling + // + if (use_f16_) { + ORT_ENFORCE(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); + ss << "enable f16;\n"; + } + + // + // Section constants + // + ss << "\nconst workgroup_size_x: u32 = " << program_.WorkgroupSizeX() + << ";\nconst workgroup_size_y: u32 = " << program_.WorkgroupSizeY() + << ";\nconst workgroup_size_z: u32 = " << program_.WorkgroupSizeZ() << ";\n"; + + for (const auto& constant : program_metadata_.constants) { + ss << "const " << constant.name << ": " << constant.type << " = "; + WriteConstantValue(ss, constant); + ss << ";\n"; + } + + size_t override_constant_count = program_metadata_.overridable_constants.size(); + for (size_t i = 0; i < override_constant_count; ++i) { + // size and type are previously checked to match + const auto& constant_def = program_metadata_.overridable_constants[i]; + const auto& constant_override = program_.OverridableConstants()[i]; + + ss << "override " << constant_def.name << ": " << constant_def.type << " = "; + if (constant_override.has_value) { + WriteConstantValue(ss, constant_override); + } else { + WriteConstantValue(ss, constant_def); + } + ss << ";\n"; + } + + // + // Input/output variables + // + int variable_count = 0; + for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { + ss << "@group(0) @binding(" << variable_count++ << ") var " << input.name_ << ": array<" << input.StorageType() << ">;\n"; + } + for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { + ss << "@group(0) @binding(" << variable_count++ << ") var " << output.name_ << ": array<" << output.StorageType() << ">;\n"; + } + + // + // uniform variables + // + if (std::any_of(program_.UniformVariables().cbegin(), + program_.UniformVariables().cend(), + [](const ProgramUniformVariableValue& x) { return x.length > 0; })) { + bool first = true; + ss << "struct Uniforms {\n"; + + size_t uniform_count = program_.UniformVariables().size(); + for (size_t i = 0; i < uniform_count; i++) { + const auto& uniform_def = program_metadata_.uniform_variables[i]; + const auto& uniform_value = program_.UniformVariables()[i]; + + const auto& name = uniform_def.name; + const auto& data_type = uniform_def.data_type; + const auto length = uniform_value.length; + + if (first) { + first = false; + } else { + ss << ",\n"; + } + + auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : ""; + ss << " " << alignment << name << ": "; + if (length > 4) { + if (data_type == ProgramUniformVariableDataType::Float16) { + size_t array_size = (length + 7) / 8; + ss << "array, " << array_size << ">"; + } else { + size_t array_size = (length + 3) / 4; + ss << "array, " << array_size << ">"; + } + } else if (length > 1) { + ss << "vec" << length << "<" << data_type << ">"; + } else { + ss << data_type; + } + } + + ss << "};\n" + "@group(0) @binding(" + << variable_count << ") var uniforms: Uniforms;\n"; + } + + // + // Indices helper + // + ss << "\n"; + // for (const auto& group : vars_) { + // } + + // + // Additional Implementation + // + ss << additional_implementation_.str(); + additional_implementation_.str(""); + + // + // Main Function Body + // + ss << body_.str(); + body_.str(""); + ss << "\n" + "}\n"; + + return ss.str(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h new file mode 100644 index 000000000000..ac6dfebfef81 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "core/common/safeint.h" +#include "core/framework/tensor_shape.h" + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_variable.h" + +namespace onnxruntime { +namespace webgpu { + +class ShaderHelper final { + // The content of a shader code is composed of the following parts: + // + // ** + // ** section: feature sets definition + // ** + // // this sections enable features like "enable f16;". need to be defined at the beginning of the shader. + // + // ** + // ** section: constants and overridable constants + // ** + // // this section defines constants and overridable constants. + // - constants are defined as "const a:f32 = 1.0;". It's hard coded in the shader. + // - overridable constants are defined as "override a:f32 = 1.0;" (may override or not) + // or "override b:u32;" (must override) + // the value can be overriden by pipeline creation config. + // + // ** + // ** section: inputs and outputs + // ** + // // this section defines input and output variables. + // user can call shader_helper.AddVariable() to add input and output variables. + // + // ** + // ** section: uniforms + // ** + // // this section defines uniform type and variables. + // + // ** + // ** section: indices helper generated utility functions + // ** + // // this section defines utility functions to calculate indices. + // + // ** + // ** section: additional implementation + // ** + // // this section contains additional implementation provided by the user. + // user can call shader_helper.AppendImplementation() to append additional implementation. + // + // ** + // ** section: main function + // ** + // // this section contains the main function of the shader. + // user can call shader_helper.MainFunctionBody() to set the main function body. + // + + public: + ShaderHelper(const ProgramBase& program, + const ProgramMetadata& program_metadata, + const wgpu::Device& device, + const wgpu::Limits& limits, + uint32_t dispatch_group_size_x, + uint32_t dispatch_group_size_y, + uint32_t dispatch_group_size_z); + + Status Init(); + + const ShaderVariable& AddVariable(ProgramVariableScope scope, const std::string& name, ProgramVariableDataType type, int rank = 1) { + return AddVariableImpl(scope, name, type, rank); + } + const ShaderVariable& AddVariable(ProgramVariableScope scope, const std::string& name, ProgramVariableDataType type, const TensorShape& dims) { + return AddVariableImpl(scope, name, type, dims); + } + + template + inline std::ostringstream& AppendImplementation(Strs&&... impl) { + onnxruntime::detail::MakeStringImpl(additional_implementation_, std::forward(impl)...); + return additional_implementation_; + } + + template + inline std::ostringstream& MainFunctionBody(Strs&&... body) { + onnxruntime::detail::MakeStringImpl(body_, std::forward(body)...); + return body_; + } + + std::string GuardAgainstOutOfBoundsWorkgroupSizes(const std::string& size) const { + return " if (global_idx >= " + size + ") { return; }\n"; + } + + private: + template // T is one of {int, const TensorShape&} + const ShaderVariable& AddVariableImpl(ProgramVariableScope scope, const std::string& name, ProgramVariableDataType type, T&& arg) { + ORT_ENFORCE((scope == ProgramVariableScope::Input || scope == ProgramVariableScope::Output) && + vars_[static_cast(ProgramVariableScope::Input)].size() + vars_[static_cast(ProgramVariableScope::Output)].size() < limits_.maxStorageBuffersPerShaderStage, + "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); + + if (type == ProgramVariableDataType::Float16 || type == ProgramVariableDataType::Vec2Float16 || type == ProgramVariableDataType::Vec4Float16) { + use_f16_ = true; + } + + return vars_[static_cast(scope)].emplace_back(name, type, std::forward(arg)); + } + + template // ConstantType is one of {ProgramConstant, ProgramOverridableConstantValue, ProgramOverridableConstantDefinition} + void WriteConstantValue(std::ostringstream& ss, const ConstantType& constant) const { + switch (constant.type) { + case ProgramConstantDataType::Float16: + ss << constant.f16.ToFloat(); + break; + case ProgramConstantDataType::Float32: + ss << constant.f32; + break; + case ProgramConstantDataType::Int32: + ss << constant.i32; + break; + case ProgramConstantDataType::Uint32: + ss << constant.u32; + break; + case ProgramConstantDataType::Bool: + ss << (constant.boolean ? "true" : "false"); + break; + default: + ORT_THROW("Invalid constant type", constant.type); + } + } + + std::string GetFinalSourceCode(); + friend class ProgramManager; + + const wgpu::Device& device_; + const wgpu::Limits& limits_; + uint32_t dispatch_group_size_x_; + uint32_t dispatch_group_size_y_; + uint32_t dispatch_group_size_z_; + + const ProgramBase& program_; + const ProgramMetadata& program_metadata_; + + std::array, static_cast(ProgramVariableScope::Count)> vars_; + std::ostringstream ss2; + std::ostringstream additional_implementation_; + std::ostringstream body_; + + bool use_f16_ = false; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_macros.h b/onnxruntime/core/providers/webgpu/shader_macros.h new file mode 100644 index 000000000000..a1c61950e6a1 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_macros.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +// macro "D": append to the ostream only in debug build +// +// Usage example: +// +// ss << "error code: " << err_code D(" (") << D(err_msg) D(")"); +// +// This resolves to: (debug build) +// ss << "error code: " << err_code << " (" << err_msg << ")"; +// +// This resolves to: (release build) +// ss << "error code: " << err_code; + +#ifdef D +#undef D +#endif + +#ifndef NDEBUG // if debug build +#define D(str) << str +#else +#define D(str) +#endif + +// macro "DSS" append to the ostream only in debug build +// (assume variable "ss" is in scope) +// +// Usage example: +// +// DSS << "detail error message: " << err_msg; +// +// This resolves to: (debug build) +// ss << "detail error message: " << err_msg; +// +// This resolves to: (release build) +// if constexpr (false) ss << "detail error message: " << err_msg; // no-op + +#ifdef DSS +#undef DSS +#endif + +#ifndef NDEBUG // if debug build +#define DSS ss +#else +#define DSS \ + if constexpr (false) ss +#endif + +// macro "SS" - use function call style to append to the ostream +// (assume variable "ss" is in scope) +// +// Usage example: +// +// SS("error code: ", err_code, " (", err_msg, ")"); +// +// This resolves to: +// ss << "error code: " << err_code << " (" << err_msg << ")"; + +#ifdef SS +#undef SS +#endif + +#define SS(...) ::onnxruntime::detail::MakeStringImpl(ss, __VA_ARGS__) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc new file mode 100644 index 000000000000..d49d76c1ee85 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -0,0 +1,277 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "core/providers/webgpu/shader_variable.h" + +#include "core/providers/webgpu/shader_macros.h" + +namespace onnxruntime { +namespace webgpu { + +ShaderVariable::ShaderVariable(const std::string& name, ProgramVariableDataType type, int rank) + : name_(name), type_(type), rank_(rank), usage_(UseUniform) { + Init(); +} + +ShaderVariable::ShaderVariable(const std::string& name, ProgramVariableDataType type, const TensorShape& dims) + : name_(name), type_(type), rank_(static_cast(dims.NumDimensions())), dims_(dims), usage_(None) { + Init(); +} + +void ShaderVariable::Init() { + ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); +} + +void ShaderVariable::Impl(std::ostringstream& ss) { + // Start generating code + + const std::string value_t = name_ + "_value_t"; + const std::string indices_t = name_ + "_indices_t"; + + const std::string shape = (usage_ & UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; + const std::string stride = (usage_ & UseUniform) ? "uniforms." + name_ + "_stride" : name_ + "_stride"; + + // Types + SS("alias ", value_t, " = ", ValueType(), ";\n"); + SS("alias ", indices_t, " = ", IndicesType(), ";\n"); + + // Need shape and strides when (not use uniform) and (any other usage is enabled) + if (!(usage_ & UseUniform) && (usage_ & ~UseUniform)) { + SS("const ", shape, " = ", indices_t, "("); + + bool first = true; + for (auto dim : dims_.GetDims()) { + if (!first) { + ss << ","; + } + + ss << dim; + first = false; + } + ss << ");\n"; + + SS("const ", stride, " = ", indices_t, "("); + first = true; + for (int i = rank_ - 1; i >= 0; i--) { + if (!first) { + ss << ","; + } + ss << dims_.SizeToDimension(i); + first = false; + } + ss << ");\n"; + } + + // Implementation of "fn o2i_{name}" + if (usage_ & UseOffsetToIndices) { + if (rank_ >= 2) { + SS("fn o2i_", name_, "(offset : u32)->", indices_t, " {\n"); + SS(" var indices: ", indices_t, ";\n"); + SS(" var current = offset;\n"); + for (size_t i = 0; i < rank_ - 1; i++) { + auto current_stride = GetElementAt(stride, i, rank_); + SS(" let dim", i, " = current / ", current_stride, ";\n"); + SS(" let rest", i, " = current % ", current_stride, ";\n"); + SS(" indices[", i, "] = dim", i, ";\n"); + SS(" current = rest", i, ";\n"); + } + SS(" indices[", rank_ - 1, "] = current;\n"); + SS(" return indices;\n"); + SS("}\n"); + } + } + + // Implementation of "fn i2o_{name}" + if (usage_ & UseIndicesToOffset) { + if (rank_ >= 2) { + SS("fn i2o_", name_, "(indices : ", indices_t, ")->u32 {\n"); + SS(" return "); + for (size_t i = 0; i < rank_ - 1; i++) { + SS("indices[", i, "] * ", GetElementAt(stride, i, rank_), " + "); + } + SS("indices[", rank_ - 1, "];\n"); + SS("}\n"); + } + } + + // Implementation of "fn {res_name}_bi2o_{name}" + if (usage_ & UseBroadcastedIndicesToOffset) { + // TODO: do we need this if rank < 2? + for (const auto& iter : broadcasted_to_) { + const auto& broadcasted_result = iter.get(); + SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.IndicesType(), ")->u32 {\n"); + if (rank_ == 0) { + SS(" return 0;\n"); + } else { + SS(" return "); + for (size_t i = 0; i < rank_ - 1; i++) { + auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); + SS(IndicesGet(stride, i), " * (", idx, " % ", IndicesGet(shape, i), ") + "); + } + SS(broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); + } + SS("}\n"); + } + } + + // Implementation of "fn set_{name}" + if (usage_ & UseSet) { + if (rank_ >= 2) { + SS("fn set_", name_, "(d0: u32"); + for (size_t i = 1; i < rank_; i++) { + SS(", d", i, ": u32"); + } + SS(", value: ", value_t, ") {\n"); + SS(" set_", name_, "_by_indices(d0"); + for (size_t i = 1; i < rank_; i++) { + SS(", d", i); + } + SS(", value);\n"); + SS("}\n"); + } + } + + // Implementation of "fn set_{name}_by_indices" + if (usage_ & UseSetByIndices) { + if (rank_ >= 2) { + SS("fn set_", name_, "_by_indices(indices: ", indices_t, ", value: ", value_t, ") {\n"); + SS(" ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); + SS("}\n"); + } + } + + // Implementation of "fn get_{name}" + if (usage_ & UseGet) { + if (rank_ >= 2) { + SS("fn get_", name_, "(d0: u32"); + for (size_t i = 1; i < rank_; i++) { + SS(", d", i, ": u32"); + } + SS(")->", value_t, " {\n"); + SS(" return get_", name_, "_by_indices(d0"); + for (size_t i = 1; i < rank_; i++) { + SS(", d", i); + } + SS(");\n"); + SS("}\n"); + } + } + + // Implementation of "fn get_{name}_by_indices" + if (usage_ & UseGetByIndices) { + if (rank_ >= 2) { + SS("fn get_", name_, "_by_indices(indices: ", indices_t, ")->", value_t, " {\n"); + SS(" return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); + SS("}\n"); + } + } +} + +std::string ShaderVariable::GetByOffsetImpl(const std::string& offset) const { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + + switch (type_) { + case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: + ORT_THROW("Invalid type"); + break; + case onnxruntime::webgpu::ProgramVariableDataType::Int64: + ss << "i32(" << name_ << "[" << offset << "].x)"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Uint64: + ss << "u32(" << name_ << "[" << offset << "].x)"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Vec4Bool: + ss << "vec4(bool(" + << name_ << "[" << offset << "] & 0xFFu), bool(" + << name_ << "[" << offset << "] & 0xFF00u), bool(" + << name_ << "[" << offset << "] & 0xFF0000u), bool(" + << name_ << "[" << offset << "] & 0xFF000000u))"; + break; + default: + ss << name_ << "[" << offset << "]"; + } + + return ss.str(); +} + +std::string ShaderVariable::SetByOffsetImpl(const std::string& offset, const std::string& value) const { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + + switch (type_) { + case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: + ORT_THROW("Invalid type"); + break; + case onnxruntime::webgpu::ProgramVariableDataType::Int64: + ss << name_ << "[" << offset << "]=vec2(u32(" << value << "), select(0u, 0xFFFFFFFFu, " << value << " < 0));"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Uint64: + ss << name_ << "[" << offset << "]=vec2(u32(" << value << "), 0u);"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Vec4Bool: + ss << name_ << "[" << offset << "]=dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(" << value << "));"; + break; + default: + ss << name_ << "[" << offset << "]=" << value << ";"; + } + + return ss.str(); +} + +std::string_view ShaderVariable::StorageType() const { + constexpr static const std::string_view STORAGE_TYPE[] = { + "f32", // f32 + "vec2", // vec2f32 + "vec4", // vec4f32 + "f16", // f16 + "vec2", // vec2f16 + "vec4", // vec4f16 + "i32", // i32 + "vec2", // vec2i32 + "vec4", // vec4i32 + "u32", // u32 + "vec2", // vec2u32 + "vec4", // vec4u32 + "vec2", // int64 + "vec2", // uint64 + "u32", // vec4bool + }; + + return STORAGE_TYPE[static_cast(type_)]; +} + +std::string_view ShaderVariable::ValueType() const { + constexpr static const std::string_view VALUE_TYPE[] = { + "f32", // f32 + "f32", // vec2f32 + "f32", // vec4f32 + "f16", // f16 + "f16", // vec2f16 + "f16", // vec4f16 + "i32", // i32 + "i32", // vec2i32 + "i32", // vec4i32 + "u32", // u32 + "u32", // vec2u32 + "u32", // vec4u32 + "i32", // int64 (trancated to i32) + "u32", // uint64 (trancated to u32) + "vec4", // vec4bool + }; + + return VALUE_TYPE[static_cast(type_)]; +} + +std::string ShaderVariable::IndicesType() const { + return rank_ < 2 ? "u32" + : (rank_ < 4 ? MakeStringWithClassicLocale("vec", rank_, "") + : MakeStringWithClassicLocale("array")); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h new file mode 100644 index 000000000000..0a5cad823787 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/common/safeint.h" +#include "core/framework/tensor_shape.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +template +std::string GetElementAt(const std::string& var, const TIdx& idx, int rank, bool is_f16 = false) { + // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. + if (var.rfind("uniform.", 0) == 0) { + if (rank > 4) { + if constexpr (std::is_integral_v) { + if (is_f16) { + return MakeStringWithClassicLocale(var, "[", idx / 8, "][", (idx % 8) / 4, "][", (idx % 8) % 4, "]"); + } else { + return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); + } + } else { + if (is_f16) { + return MakeStringWithClassicLocale(var, "[(", idx, ") / 8][(", idx, ") % 8 / 4][(", idx, ") % 8 % 4]"); + } else { + return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]"); + } + } + } else { + return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : var; + } + } else { + return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : var; + } +} + +class ShaderVariable { + public: + ShaderVariable(const std::string& name, ProgramVariableDataType type, int rank); + ShaderVariable(const std::string& name, ProgramVariableDataType type, const TensorShape& dims); + + ShaderVariable(ShaderVariable&&) = default; + ShaderVariable& operator=(ShaderVariable&&) = default; + + // create a WGSL expression ({varname}_indices_t) for getting indices from offset. + // \param offset: a WGSL expression (u32) representing the offset. + inline std::string OffsetToIndices(const std::string& offset_expr) const; + + // create a WGSL expression (u32) for getting offset from indices. + // \param indices: a WGSL expression ({varname}_indices_t) representing the indices. + inline std::string IndicesToOffset(const std::string& indices_expr) const; + + // create a WGSL expression (u32) for getting original offset from broadcasted indices. + // \param indices: a WGSL expression ({broadcasted_result_varname}_indices_t) representing the broadcasted indices. + // \param broadcasted_result: the broadcasted result variable. + inline std::string BroadcastedIndicesToOffset(const std::string& indices_expr, const ShaderVariable& broadcasted_result) const; + + // create a WGSL expression ({varname}_indices_t) as an indices literal + // \param init: a list of indices values. + template + inline std::string Indices(TIndices&&... indices_args) const; + + // create a WGSL statement for setting value of the specified dimension of the indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param idx: the index (i32|u32) of the dimension to set. + // \param value: the value (u32) to set. + template + inline std::string IndicesSet(const std::string& indices_var, const TIdx& idx_expr, const TVal& value) const; + + // create a WGSL expression (u32) for getting value of the specified dimension of the indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param idx: the index (i32|u32) of the dimension to get. + template + inline std::string IndicesGet(const std::string& indices_var, const TIdx& idx_expr) const; + + // create a WGSL statement for setting data at the given indices. + // \param args: a list of indices values (u32) followed by a value ({varname}_value_t). + template + inline std::string Set(TIndicesAndValue&&... args) const; + + // create a WGSL statement for setting data at the given indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param value: the value ({varname}_value_t) to set. + inline std::string SetByIndices(const std::string& indices_var, const std::string& value) const; + + // create a WGSL statement for setting data at the given offset. + // \param offset: a WGSL expression (u32) representing the offset. + // \param value: the value ({varname}_value_t) to set. + template + inline std::string SetByOffset(TOffset&& offset, TValue&& value) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given indices. + // \param indices: a list of indices values (u32). + template + inline std::string Get(TIndices&&... indices) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + inline std::string GetByIndices(const std::string& indices_var) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given offset. + // \param offset: a WGSL expression (u32) representing the offset. + template + inline std::string GetByOffset(TOffset&& offset) const; + + private: + enum Usage : uint32_t { + None = 0, + UseOffsetToIndices = 1, + UseIndicesToOffset = 2, + UseBroadcastedIndicesToOffset = 4, + UseSet = 8, + UseSetByIndices = 16, + UseGet = 32, + UseGetByIndices = 64, + UseUniform = 128, + }; + + friend ShaderVariable::Usage operator|(ShaderVariable::Usage a, ShaderVariable::Usage b); + friend ShaderVariable::Usage operator&(ShaderVariable::Usage a, ShaderVariable::Usage b); + friend ShaderVariable::Usage& operator|=(ShaderVariable::Usage& a, ShaderVariable::Usage b); + friend ShaderVariable::Usage& operator&=(ShaderVariable::Usage& a, ShaderVariable::Usage b); + + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariable); + + void Init(); + void Impl(std::ostringstream& ss); + + std::string ShaderVariable::GetByOffsetImpl(const std::string& offset) const; + std::string SetByOffsetImpl(const std::string& offset, const std::string& value) const; + + std::string_view StorageType() const; + std::string_view ValueType() const; + std::string IndicesType() const; + + std::string name_; + ProgramVariableDataType type_; + int rank_; + TensorShape dims_; + + mutable Usage usage_; + mutable std::vector> broadcasted_to_; + + friend class ShaderHelper; +}; + +inline ShaderVariable::Usage operator|(ShaderVariable::Usage a, ShaderVariable::Usage b) { + return (ShaderVariable::Usage)((uint32_t&)a | (uint32_t&)b); +} +inline ShaderVariable::Usage operator&(ShaderVariable::Usage a, ShaderVariable::Usage b) { + return (ShaderVariable::Usage)((uint32_t&)a & (uint32_t&)b); +} +inline ShaderVariable::Usage& operator|=(ShaderVariable::Usage& a, ShaderVariable::Usage b) { + return (ShaderVariable::Usage&)((uint32_t&)a |= (uint32_t&)b); +} +inline ShaderVariable::Usage& operator&=(ShaderVariable::Usage& a, ShaderVariable::Usage b) { + return (ShaderVariable::Usage&)((uint32_t&)a &= (uint32_t&)b); +} + +namespace detail { +template >> +std::string pass_as_string(T&& v) { + return std::to_string(std::forward(v)); +} +template +std::string pass_as_string(const T& v) { + return v; +} +} // namespace detail + +inline std::string ShaderVariable::OffsetToIndices(const std::string& offset_expr) const { + usage_ |= UseOffsetToIndices; + return rank_ < 2 ? offset_expr : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')'); +} + +inline std::string ShaderVariable::IndicesToOffset(const std::string& indices_expr) const { + usage_ |= UseIndicesToOffset; + return rank_ < 2 ? indices_expr : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')'); +} + +inline std::string ShaderVariable::BroadcastedIndicesToOffset(const std::string& indices_expr, const ShaderVariable& broadcasted_result) const { + usage_ |= UseBroadcastedIndicesToOffset; + broadcasted_to_.push_back(broadcasted_result); + return MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); +} + +template +inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const { + return rank_ == 0 ? "" : MakeStringWithClassicLocale(name_, "_indices_t(", onnxruntime::detail::StringJoinImpl(", ", std::forward(indices_args)...), ')'); +} + +template +inline std::string ShaderVariable::IndicesSet(const std::string& indices_var, const TIdx& idx_expr, const TVal& value) const { + return rank_ < 2 ? MakeStringWithClassicLocale(indices_var, '=', value, ';') + : MakeStringWithClassicLocale(GetElementAt(indices_var, idx_expr, rank_), '=', value, ';'); +} + +template +inline std::string ShaderVariable::IndicesGet(const std::string& indices_var, const TIdx& idx_expr) const { + return rank_ < 2 ? indices_var : GetElementAt(indices_var, idx_expr, rank_); +} + +template +inline std::string ShaderVariable::SetByOffset(TOffset&& offset, TValue&& value) const { + return SetByOffsetImpl(detail::pass_as_string(offset), detail::pass_as_string(value)); +} + +template +inline std::string ShaderVariable::Set(TIndicesAndValue&&... args) const { + ORT_ENFORCE(sizeof...(TIndicesAndValue) == rank_ + 1, "Number of arguments should be ", rank_ + 1, "(rank + 1)"); + if constexpr (sizeof...(TIndicesAndValue) == 1) { + return SetByOffset("0", std::forward(args)...); + } else if constexpr (sizeof...(TIndicesAndValue) == 2) { + return SetByOffset(std::forward(args)...); + } else { + usage_ |= UseSet | UseSetByIndices | UseIndicesToOffset; + return MakeStringWithClassicLocale("set_", name_, '(', onnxruntime::detail::StringJoinImpl(", ", std::forward(args)...), ");"); + } +} + +inline std::string ShaderVariable::SetByIndices(const std::string& indices_var, const std::string& value) const { + if (rank_ < 2) { + return SetByOffset(indices_var, value); + } else { + usage_ |= UseSetByIndices | UseIndicesToOffset; + return MakeStringWithClassicLocale("set_", name_, "_by_indices(", indices_var, ", ", value, ");"); + } +} + +template +inline std::string ShaderVariable::GetByOffset(TOffset&& offset) const { + return GetByOffsetImpl(detail::pass_as_string(offset)); +} + +template +inline std::string ShaderVariable::Get(TIndices&&... indices) const { + ORT_ENFORCE(sizeof...(TIndices) == rank_, "Number of arguments should be ", rank_, "(rank)"); + if constexpr (sizeof...(TIndices) == 0) { + return GetByOffset("0"); + } else if constexpr (sizeof...(TIndices) == 1) { + return GetByOffset(std::forward(indices)...); + } else { + usage_ |= UseGet | UseGetByIndices | UseIndicesToOffset; + return MakeStringWithClassicLocale("get_", name_, '(', onnxruntime::detail::StringJoinImpl(", ", std::forward(indices)...), ')'); + } +} + +inline std::string ShaderVariable::GetByIndices(const std::string& indices_var) const { + if (rank_ < 2) { + return GetByOffset(indices_var); + } else { + usage_ |= UseGetByIndices | UseIndicesToOffset; + return MakeStringWithClassicLocale("get_", name_, "_by_indices(", indices_var, ")"); + } +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc new file mode 100644 index 000000000000..a891f5a8a551 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -0,0 +1,349 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/common/common.h" + +#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/program_cache_key.h" +#include "core/providers/webgpu/program_manager.h" + +namespace onnxruntime { +namespace webgpu { + +std::vector GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) { + std::vector required_features; + constexpr wgpu::FeatureName features[]{ + wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, + wgpu::FeatureName::TimestampQuery, + wgpu::FeatureName::ShaderF16}; + for (auto feature : features) { + if (adapter.HasFeature(feature)) { + required_features.push_back(feature); + } + } + return required_features; +} + +wgpu::RequiredLimits GetAvailableRequiredLimits(const wgpu::Adapter& adapter) { + wgpu::RequiredLimits required_limits{}; + wgpu::SupportedLimits adapter_limits; + ORT_ENFORCE(adapter.GetLimits(&adapter_limits)); + + required_limits.limits.maxBindGroups = adapter_limits.limits.maxBindGroups; + required_limits.limits.maxComputeWorkgroupStorageSize = adapter_limits.limits.maxComputeWorkgroupStorageSize; + required_limits.limits.maxComputeWorkgroupsPerDimension = adapter_limits.limits.maxComputeWorkgroupsPerDimension; + required_limits.limits.maxStorageBufferBindingSize = adapter_limits.limits.maxStorageBufferBindingSize; + required_limits.limits.maxBufferSize = adapter_limits.limits.maxBufferSize; + required_limits.limits.maxComputeInvocationsPerWorkgroup = adapter_limits.limits.maxComputeInvocationsPerWorkgroup; + required_limits.limits.maxComputeWorkgroupSizeX = adapter_limits.limits.maxComputeWorkgroupSizeX; + required_limits.limits.maxComputeWorkgroupSizeY = adapter_limits.limits.maxComputeWorkgroupSizeY; + required_limits.limits.maxComputeWorkgroupSizeZ = adapter_limits.limits.maxComputeWorkgroupSizeZ; + + return required_limits; +} + +void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info) { + std::call_once(init_flag_, [this, &webgpu_ep_info]() { + // Initialization.Step.1 - Create wgpu::Instance + if (instance_ == nullptr) { + wgpu::InstanceDescriptor instance_desc{}; + instance_desc.features.timedWaitAnyEnable = true; + instance_ = wgpu::CreateInstance(&instance_desc); + + ORT_ENFORCE(instance_ != nullptr, "Failed to create wgpu::Instance."); + } + + // Initialization.Step.2 - Create wgpu::Adapter + if (adapter_ == nullptr) { + wgpu::RequestAdapterOptions req_adapter_options = {}; + wgpu::RequestAdapterCallbackInfo req_adapter_callback_info = {}; + req_adapter_callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; + req_adapter_callback_info.callback = [](WGPURequestAdapterStatus status, + WGPUAdapter adapter, const char* message, + void* userdata) { + ORT_ENFORCE(status == WGPURequestAdapterStatus_Success, "Failed to get a WebGPU adapter: ", message); + *static_cast(userdata) = wgpu::Adapter::Acquire(adapter); + }; + req_adapter_callback_info.userdata = &adapter_; + ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(instance_.RequestAdapter(&req_adapter_options, req_adapter_callback_info), UINT64_MAX)); + ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter."); + } + + // Initialization.Step.3 - Create wgpu::Device + if (device_ == nullptr) { + wgpu::DeviceDescriptor device_desc = {}; + std::vector required_features = GetAvailableRequiredFeatures(adapter_); + if (required_features.size() > 0) { + device_desc.requiredFeatures = required_features.data(); + } + wgpu::RequiredLimits required_limits = GetAvailableRequiredLimits(adapter_); + device_desc.requiredLimits = &required_limits; + + // TODO: temporary error handling + device_desc.SetUncapturedErrorCallback([](const wgpu::Device& /*device*/, wgpu::ErrorType type, const char* message) { + LOGS_DEFAULT(ERROR) << "WebGPU device error(" << int(type) << "): " << message; + }); + + wgpu::RequestDeviceCallbackInfo req_device_callback_info = {}; + req_device_callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; + req_device_callback_info.callback = [](WGPURequestDeviceStatus status, WGPUDevice device, char const* message, void* userdata) { + ORT_ENFORCE(status == WGPURequestAdapterStatus_Success, "Failed to get a WebGPU device: ", message); + *static_cast(userdata) = wgpu::Device::Acquire(device); + }; + req_device_callback_info.userdata = &device_; + ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter_.RequestDevice(&device_desc, req_device_callback_info), UINT64_MAX)); + ORT_ENFORCE(device_ != nullptr, "Failed to get a WebGPU device."); + } + + // cache adapter info + ORT_ENFORCE(Adapter().GetInfo(&adapter_info_)); + // cache device limits + wgpu::SupportedLimits device_supported_limits; + ORT_ENFORCE(Device().GetLimits(&device_supported_limits)); + device_limits_ = device_supported_limits.limits; + + // create buffer manager + buffer_mgr_ = BufferManagerFactory::Create(*this, webgpu_ep_info.storage_buffer_cache_mode, webgpu_ep_info.uniform_buffer_cache_mode, webgpu_ep_info.query_resolve_buffer_cache_mode); + + // create program manager + program_mgr_ = std::make_unique(Device(), DeviceLimits()); + }); +} + +Status WebGpuContext::Wait(wgpu::Future f) { + auto status = instance_.WaitAny(f, UINT64_MAX); + if (status == wgpu::WaitStatus::Success) { + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); +} + +Status WebGpuContext::Run(const ComputeContext& /*context*/, const ProgramBase& program) { + const auto& inputs = program.Inputs(); + const auto& outputs = program.Outputs(); + +#ifndef NDEBUG // if debug build + ORT_ENFORCE(std::all_of(inputs.begin(), inputs.end(), [](const ProgramInput& input) { + const auto* tensor = input.tensor; + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + tensor->Location().name == WEBGPU_BUFFER; + }), + "All inputs must be tensors on WebGPU buffers."); + + ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](Tensor* tensor) { + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + tensor->Location().name == WEBGPU_BUFFER; + }), + "All outputs must be tensors on WebGPU buffers."); +#endif + + if (outputs.size() == 0) { + return Status::OK(); + } + + ProgramMetadata metadata = program.GetMetadata(); + + // validate program metadata + { + const auto& [constants, overridable_constants, uniform_variables] = metadata; + + // check overridable constants + ORT_RETURN_IF(program.OverridableConstants().size() != overridable_constants.size(), + "Size of overridable constants mismatch in program \"", program.Name(), + "\", Expected: ", overridable_constants.size(), + ", Actual: ", program.OverridableConstants().size()); + size_t num_overridable_constants = program.OverridableConstants().size(); + for (size_t i = 0; i < num_overridable_constants; ++i) { + const auto& override_value = program.OverridableConstants()[i]; + const auto& definition = overridable_constants[i]; + ORT_RETURN_IF(override_value.has_value && override_value.type != definition.type, + "Overridable override_value[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.type, + ", Actual: ", override_value.type); + ORT_RETURN_IF(!override_value.has_value && !definition.has_default_value, + "Overridable override_value[", i, "] (", definition.name, ") no override_value specified in program \"", program.Name(), + "\""); + } + + // check uniform variables + ORT_RETURN_IF(program.UniformVariables().size() != uniform_variables.size(), + "Size of uniform_value variables mismatch in program \"", program.Name(), + "\", Expected: ", uniform_variables.size(), + ", Actual: ", program.UniformVariables().size()); + size_t num_uniform_variables = program.UniformVariables().size(); + for (size_t i = 0; i < num_uniform_variables; ++i) { + const auto& uniform_value = program.UniformVariables()[i]; + const auto& definition = uniform_variables[i]; + ORT_RETURN_IF(uniform_value.length > 0 && uniform_value.data_type != definition.data_type, + "Uniform variable[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.data_type, + ", Actual: ", uniform_value.data_type); + } + } + + uint32_t x = program.DispatchGroupSizeX(); + uint32_t y = program.DispatchGroupSizeY(); + uint32_t z = program.DispatchGroupSizeZ(); + ORT_RETURN_IF_ERROR(program_mgr_->NormalizeDispatchGroupSize(x, y, z)); + + bool is_1d_dispatch = (y == 1 && z == 1); + + auto key = CalculateProgramCacheKey(program, is_1d_dispatch); + + const auto* program_artifact = program_mgr_->Get(key); + if (program_artifact == nullptr) { + wgpu::ComputePipeline compute_pipeline; + auto status = program_mgr_->Build(program, + metadata, +#ifndef NDEBUG // if debug build + key, +#endif + x, + y, + z, + compute_pipeline); + ORT_RETURN_IF_ERROR(status); + program_artifact = program_mgr_->Set(key, ProgramArtifact{program, std::move(compute_pipeline)}); +#ifndef NDEBUG // if debug build + ORT_ENFORCE(program_artifact != nullptr, "Program artifact should not be nullptr."); +#endif + } + + std::vector input_buffers; + input_buffers.reserve(inputs.size()); + for (const auto& input : inputs) { + input_buffers.push_back(reinterpret_cast(const_cast(input.tensor->DataRaw()))); + } + + std::vector output_buffers; + output_buffers.reserve(outputs.size()); + for (const auto& output : outputs) { + output_buffers.push_back(reinterpret_cast(output->MutableDataRaw())); + } + + WGPUBuffer uniform_buffer = nullptr; + auto uniform_buffer_size = program_artifact->uniform_total_size; + if (uniform_buffer_size > 0) { + auto num_uniforms = program.UniformVariables().size(); + ORT_ENFORCE(program_artifact->uniforms.size() == num_uniforms, + "Uniforms size mismatch. Artifact: ", program_artifact->uniforms.size(), ", Current: ", num_uniforms); + + std::vector uniform_data(uniform_buffer_size); + + for (size_t i = 0; i < num_uniforms; ++i) { + const auto& uniform = program.UniformVariables()[i]; + const auto& artifact_uniform = program_artifact->uniforms[i]; + + ORT_ENFORCE(uniform.data_type == artifact_uniform.data_type, + "Uniform[", i, "] data type mismatch. Artifact: ", artifact_uniform.data_type, + ", Current: ", uniform.data_type); + ORT_ENFORCE(uniform.length == artifact_uniform.length, + "Uniform[", i, "] elements number mismatch. Artifact: ", artifact_uniform.length, ", Current: ", uniform.length); + ORT_ENFORCE(uniform.data.size() == artifact_uniform.length * ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)], + "Uniform[", i, "] data size mismatch. Artifact: ", artifact_uniform.length * ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)], + ", Current: ", uniform.data.size()); + + auto offset = artifact_uniform.offset; + auto size = uniform.data.size(); + memcpy(uniform_data.data() + offset, uniform.data.data(), size); + } + + uniform_buffer = buffer_mgr_->Create(uniform_buffer_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); + device_.GetQueue().WriteBuffer(uniform_buffer, 0, uniform_data.data(), uniform_buffer_size); + } + + const auto& compute_pass_encoder = GetComputePassEncoder(); + + // TODO: write timestamp query + + uint32_t entry_index = 0; + std::vector bind_group_entries; + for (const auto& input : inputs) { + bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(const_cast(input.tensor->DataRaw()))}); + } + for (const auto& output : outputs) { + bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(output->MutableDataRaw())}); + } + if (uniform_buffer) { + bind_group_entries.push_back({nullptr, entry_index++, uniform_buffer}); + } + + wgpu::BindGroupDescriptor bind_group_desc{}; + bind_group_desc.layout = program_artifact->compute_pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = bind_group_entries.size(); + bind_group_desc.entries = bind_group_entries.data(); + bind_group_desc.label = program_artifact->name.c_str(); + + auto bind_group = Device().CreateBindGroup(&bind_group_desc); + + // TODO support graph capture + + compute_pass_encoder.SetPipeline(program_artifact->compute_pipeline); + compute_pass_encoder.SetBindGroup(0, bind_group); + compute_pass_encoder.DispatchWorkgroups(x, y, z); + + if (uniform_buffer) { + buffer_mgr_->Release(uniform_buffer); + } + + // TODO: write timestamp query + + ++num_pending_dispatches_; + + // if (querytype == at-passes) { EndComputePass(); } + + if (num_pending_dispatches_ >= max_num_pending_dispatches_) { + Flush(); + } + + return Status::OK(); +} + +std::unordered_map> WebGpuContextFactory::contexts_; +std::mutex WebGpuContextFactory::mutex_; + +WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) { + if (context_id == 0) { + // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. + ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr, + "WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device."); + } else { + // for context ID > 0, user must provide custom WebGPU instance, adapter and device. + ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr, + "WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device."); + } + + std::lock_guard lock(mutex_); + + auto it = contexts_.find(context_id); + if (it == contexts_.end()) { + auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device)); + it = contexts_.emplace(context_id, std::move(context)).first; + } else if (context_id != 0) { + ORT_ENFORCE(it->second->instance_.Get() == instance && it->second->adapter_.Get() == adapter && it->second->device_.Get() == device, + "WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device."); + } + return *it->second; +} + +WebGpuContext& WebGpuContextFactory::GetContext(int context_id) { + std::lock_guard lock(mutex_); + + auto it = contexts_.find(context_id); + ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found."); + + return *it->second; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h new file mode 100644 index 000000000000..d8b0c2b48b06 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include + +#include "core/common/common.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/program_manager.h" + +namespace onnxruntime { +class Tensor; + +namespace webgpu { +class WebGpuContext; +class ComputeContext; +class ProgramBase; + +class WebGpuContextFactory { + public: + static WebGpuContext& CreateContext(int context_id, WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device); + static WebGpuContext& GetContext(int context_id); + + private: + WebGpuContextFactory() {} + + static std::unordered_map> contexts_; + static std::mutex mutex_; +}; + +// Class WebGpuContext includes all necessary resources for the context. +class WebGpuContext final { + public: + void Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info); + + Status Wait(wgpu::Future f); + + const wgpu::Adapter& Adapter() const { return adapter_; } + const wgpu::Device& Device() const { return device_; } + + const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; } + const wgpu::Limits& DeviceLimits() const { return device_limits_; } + + const wgpu::CommandEncoder& GetCommandEncoder() { + if (!current_command_encoder_) { + current_command_encoder_ = device_.CreateCommandEncoder(); + } + return current_command_encoder_; + } + + const wgpu::ComputePassEncoder& GetComputePassEncoder() { + if (!current_compute_pass_encoder_) { + auto& command_encoder = GetCommandEncoder(); + + wgpu::ComputePassDescriptor compute_pass_desc{}; + + // TODO: add support for GPU Query + + current_compute_pass_encoder_ = command_encoder.BeginComputePass(&compute_pass_desc); + } + return current_compute_pass_encoder_; + } + + void EndComputePass() { + if (current_compute_pass_encoder_) { + current_compute_pass_encoder_.End(); + current_compute_pass_encoder_ = nullptr; + } + } + + void Flush() { + if (!current_command_encoder_) { + return; + } + + EndComputePass(); + + // TODO: add support for GPU Query + + auto command_buffer = current_command_encoder_.Finish(); + Device().GetQueue().Submit(1, &command_buffer); + BufferManager().RefreshPendingBuffers(); + current_command_encoder_ = nullptr; + } + + webgpu::BufferManager& BufferManager() const { return *buffer_mgr_; } + + Status Run(const ComputeContext& context, const ProgramBase& program); + + private: + WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) : instance_{instance}, adapter_{adapter}, device_{device} {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext); + + std::once_flag init_flag_; + + wgpu::Instance instance_; + wgpu::Adapter adapter_; + wgpu::Device device_; + + wgpu::AdapterInfo adapter_info_; + wgpu::Limits device_limits_; + + wgpu::CommandEncoder current_command_encoder_; + wgpu::ComputePassEncoder current_compute_pass_encoder_; + + std::unique_ptr buffer_mgr_; + std::unique_ptr program_mgr_; + friend class WebGpuContextFactory; + + int num_pending_dispatches_ = 0; + const int max_num_pending_dispatches_ = 16; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc new file mode 100644 index 000000000000..e7688d1fafb9 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -0,0 +1,837 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/webgpu_execution_provider.h" + +#ifdef __EMSCRIPTEN__ +#include +#endif +#include +#include +#include +#include +#include + +#ifndef DISABLE_CONTRIB_OPS +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#endif + +#include "allocator.h" +#include "core/framework/compute_capability.h" +#include "core/framework/data_transfer_manager.h" +#include "core/framework/fallback_cpu_capability.h" +#include "core/framework/kernel_registry.h" +#include "core/graph/function_utils.h" +#include "core/graph/indexed_sub_graph.h" +#include "data_transfer.h" + +namespace onnxruntime { + +namespace webgpu { +template <> +KernelCreateInfo BuildKernelCreateInfo() { + KernelCreateInfo info; + return info; +} + +class Memcpy final : public OpKernel { + public: + Memcpy(const OpKernelInfo& info) : OpKernel(info) {} + + Status Compute(OpKernelContext* ctx) const override { + const auto* X = ctx->Input(0); + Tensor* Y = ctx->Output(0, X->Shape()); + return Info().GetDataTransferManager().CopyTensor(*X, *Y); + } +}; + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, MemcpyToHost); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyFromHost, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPU, 0) + .ExecQueueId(0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Memcpy); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyToHost, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .ExecQueueId(1) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Memcpy); + +#define KERNEL_CREATE_INFO_VERSIONED(Start, End, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, End, Op)> + +#define KERNEL_CREATE_INFO(Start, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, Op)> + +#define KERNEL_CREATE_INFO_TYPED(Start, type, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, type, Op)> + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Abs); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Abs); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Neg); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Neg); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Floor); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Floor); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Ceil); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Ceil); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Reciprocal); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Reciprocal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Sqrt); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Sqrt); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Exp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Exp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Erf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Erf); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, HardSigmoid); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Log); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Log); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Sin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Cos); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Tan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Asin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Acos); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Atan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Sinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Cosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Asinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Acosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Atanh); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, Not); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 8, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Cast); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 10, Clip); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, Clip); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, Clip); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Clip); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, Elu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Relu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Relu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Relu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LeakyRelu); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMean); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMean); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceProd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceProd); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ReduceSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceL1); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceL1); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceL2); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceL2); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceSumSquare); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceSumSquare); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSumExp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSumExp); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Add); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Add); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Add); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Sub); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Sub); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Sub); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Mul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Mul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Mul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Div); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Div); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Div); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 11, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Pow); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 10, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Equal); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Greater); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Greater); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Greater); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 15, GreaterOrEqual); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Less); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Less); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Less); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 15, LessOrEqual); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LessOrEqual); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Shape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, Shape); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 5, 12, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Reshape); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Squeeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Squeeze); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Unsqueeze); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 15, Where); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, Where); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Transpose); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, Conv); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, GlobalAveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, GlobalMaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gemm); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, MatMul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MatMul); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Softmax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Softmax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Softmax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 3, Concat); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 4, 10, Concat); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Concat); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Concat); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 12, Expand); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Expand); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 19, Resize); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Gather); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gather); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gather); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherElements); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 9, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Slice); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Slice); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 8, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Flatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tile); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tile); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 17, LayerNormalization); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, InstanceNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, Range); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, float, Einsum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Pad); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, If); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 13, CumSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, CumSum); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, int32_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear); + +std::unique_ptr RegisterKernels() { + auto kernel_registry = std::make_unique(); + + static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, // default entry to avoid the list becoming empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // element-wise operators + // unary - math + KERNEL_CREATE_INFO_VERSIONED(6, 12, Abs), + KERNEL_CREATE_INFO(13, Abs), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Neg), + // KERNEL_CREATE_INFO(13, Neg), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Floor), + // KERNEL_CREATE_INFO(13, Floor), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Ceil), + // KERNEL_CREATE_INFO(13, Ceil), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Reciprocal), + // KERNEL_CREATE_INFO(13, Reciprocal), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Sqrt), + // KERNEL_CREATE_INFO(13, Sqrt), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Exp), + // KERNEL_CREATE_INFO(13, Exp), + // KERNEL_CREATE_INFO_VERSIONED(9, 12, Erf), + // KERNEL_CREATE_INFO(13, Erf), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), + // KERNEL_CREATE_INFO(13, Sigmoid), + // KERNEL_CREATE_INFO(6, HardSigmoid), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), + // KERNEL_CREATE_INFO(13, Log), + + // KERNEL_CREATE_INFO(7, Sin), + // KERNEL_CREATE_INFO(7, Cos), + // KERNEL_CREATE_INFO(7, Tan), + // KERNEL_CREATE_INFO(7, Asin), + // KERNEL_CREATE_INFO(7, Acos), + // KERNEL_CREATE_INFO(7, Atan), + // KERNEL_CREATE_INFO(9, Sinh), + // KERNEL_CREATE_INFO(9, Cosh), + // KERNEL_CREATE_INFO(9, Asinh), + // KERNEL_CREATE_INFO(9, Acosh), + // KERNEL_CREATE_INFO(9, Atanh), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Tanh), + // KERNEL_CREATE_INFO(13, Tanh), + // KERNEL_CREATE_INFO(1, Not), + + // KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), + // KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), + // KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), + // KERNEL_CREATE_INFO(19, Cast), + + // // activations + // KERNEL_CREATE_INFO_VERSIONED(6, 10, Clip), + // KERNEL_CREATE_INFO_VERSIONED(11, 11, Clip), + // KERNEL_CREATE_INFO_VERSIONED(12, 12, Clip), + // KERNEL_CREATE_INFO(13, Clip), + // KERNEL_CREATE_INFO(6, Elu), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu), + // KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu), + // KERNEL_CREATE_INFO(14, Relu), + // KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu), + // KERNEL_CREATE_INFO(16, LeakyRelu), + // KERNEL_CREATE_INFO(10, ThresholdedRelu), + + // // binary - math + // KERNEL_CREATE_INFO_VERSIONED(7, 12, Add), + // KERNEL_CREATE_INFO_VERSIONED(13, 13, Add), + // KERNEL_CREATE_INFO(14, Add), + // KERNEL_CREATE_INFO_VERSIONED(7, 12, Sub), + // KERNEL_CREATE_INFO_VERSIONED(13, 13, Sub), + // KERNEL_CREATE_INFO(14, Sub), + // KERNEL_CREATE_INFO_VERSIONED(7, 12, Mul), + // KERNEL_CREATE_INFO_VERSIONED(13, 13, Mul), + // KERNEL_CREATE_INFO(14, Mul), + // KERNEL_CREATE_INFO_VERSIONED(7, 12, Div), + // KERNEL_CREATE_INFO_VERSIONED(13, 13, Div), + // KERNEL_CREATE_INFO(14, Div), + // KERNEL_CREATE_INFO_VERSIONED(7, 11, Pow), + // KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow), + // KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow), + // KERNEL_CREATE_INFO(15, Pow), + // KERNEL_CREATE_INFO_VERSIONED(7, 10, Equal), + // KERNEL_CREATE_INFO_VERSIONED(11, 12, Equal), + // KERNEL_CREATE_INFO_VERSIONED(13, 18, Equal), + // KERNEL_CREATE_INFO(19, Equal), + // KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater), + // KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater), + // KERNEL_CREATE_INFO(13, Greater), + // KERNEL_CREATE_INFO_VERSIONED(12, 15, GreaterOrEqual), + // KERNEL_CREATE_INFO(16, GreaterOrEqual), + // KERNEL_CREATE_INFO_VERSIONED(7, 8, Less), + // KERNEL_CREATE_INFO_VERSIONED(9, 12, Less), + // KERNEL_CREATE_INFO(13, Less), + // KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual), + // KERNEL_CREATE_INFO(16, LessOrEqual), + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), + // KERNEL_CREATE_INFO(16, Where), + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + }; + + for (auto& function_table_entry : function_table) { + KernelCreateInfo info = function_table_entry(); + if (info.kernel_def != nullptr) { // filter disabled entries where type is void + ORT_THROW_IF_ERROR(kernel_registry->Register(std::move(info))); + } + } + +#ifndef DISABLE_CONTRIB_OPS + Status status = ::onnxruntime::contrib::webgpu::RegisterWebGpuContribKernels(*kernel_registry); + ORT_ENFORCE(status.IsOK(), "Failed to register WebGPU contrib kernels: " + status.ErrorMessage()); +#endif + + return kernel_registry; +} + +} // namespace webgpu + +using namespace webgpu; + +WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, + const WebGpuContext& context, + const WebGpuExecutionProviderInfo& info) + : IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)}, + context_id_{context_id}, + context_{context}, + preferred_data_layout_{info.data_layout}, + enable_graph_capture_{info.enable_graph_capture} { +} + +std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { + AllocatorCreationInfo gpuBufferAllocatorCreationInfo([&](int) { + return std::make_unique(context_); + }, + 0, false); + return std::vector{CreateAllocator(gpuBufferAllocatorCreationInfo)}; +} + +std::vector> WebGpuExecutionProvider::GetCapability( + const onnxruntime::GraphViewer& graph, + const IKernelLookup& kernel_lookup) const { + InlinedVector candidates; + // `tenative_candidates` is a subset of `candidates`. + InlinedVector tenative_candidates; + for (auto& node_index : graph.GetNodesInTopologicalOrder()) { + const auto* p_node = graph.GetNode(node_index); + if (p_node == nullptr) + continue; + + const auto& node = *p_node; + if (!node.GetExecutionProviderType().empty()) { + // If the node was added by layout transformer, do not move it to CPU + if (node.GetExecutionProviderType() == kWebGpuExecutionProvider) { + candidates.push_back(node.Index()); + } + continue; + } + + const KernelCreateInfo* webgpu_kernel_def = kernel_lookup.LookUpKernel(node); + // none of the provided registries has a webgpu kernel for this node + if (webgpu_kernel_def == nullptr) { + LOGS(*GetLogger(), INFO) << "webgpu kernel not found in registries for Op type: " + << node.OpType() << " node name: " << node.Name(); + continue; + } + candidates.push_back(node.Index()); + tenative_candidates.push_back(node.Index()); + } + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates); + std::vector> result; + for (auto& node_index : candidates) { + if (cpu_nodes.count(node_index) > 0) { + continue; + } + + auto sub_graph = std::make_unique(); + sub_graph->nodes.push_back(node_index); + result.emplace_back(std::make_unique(std::move(sub_graph))); + } + return result; +} + +std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() const { + static std::shared_ptr registry = webgpu::RegisterKernels(); + + return registry; +} + +std::unique_ptr WebGpuExecutionProvider::GetDataTransfer() const { + return std::make_unique(context_); +} + +WebGpuExecutionProvider::~WebGpuExecutionProvider() { +} + +Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { + if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { + ORT_NOT_IMPLEMENTED("graph capture not implemented"); + } + return Status::OK(); +} + +Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /*run_options*/) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { + if (IsGraphCaptureAllowed()) { + ORT_NOT_IMPLEMENTED("graph capture not implemented"); + // is_graph_captured_ = true; + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + + return Status::OK(); +} + +bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const { + return enable_graph_capture_; +} + +bool WebGpuExecutionProvider::IsGraphCaptured(int) const { + return is_graph_captured_; +} + +Status WebGpuExecutionProvider::ReplayGraph(int) { + ORT_ENFORCE(IsGraphCaptured(0)); + ORT_ENFORCE(false); + return Status::OK(); +} + +bool WebGpuExecutionProvider::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +} + +void WebGpuExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + ++regular_run_count_before_graph_capture_; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h new file mode 100644 index 000000000000..6fb2381637a6 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/execution_provider.h" +#include "core/framework/session_options.h" +#include "core/graph/constants.h" +#include "core/providers/providers.h" + +struct pthreadpool; +namespace onnxruntime { +namespace webgpu { + +// forward declaration for this EP's namespace. +template +KernelCreateInfo BuildKernelCreateInfo(); + +class WebGpuContext; +enum class BufferCacheMode; +} // namespace webgpu + +struct WebGpuExecutionProviderInfo { + DataLayout data_layout; + bool enable_graph_capture; + webgpu::BufferCacheMode storage_buffer_cache_mode; + webgpu::BufferCacheMode uniform_buffer_cache_mode; + webgpu::BufferCacheMode query_resolve_buffer_cache_mode; + webgpu::BufferCacheMode default_buffer_cache_mode; +}; + +class WebGpuExecutionProvider : public IExecutionProvider { + public: + WebGpuExecutionProvider(int context_id, const webgpu::WebGpuContext& context, const WebGpuExecutionProviderInfo& info); + ~WebGpuExecutionProvider() override; + + std::vector> GetCapability( + const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_lookup*/) const override; + + std::shared_ptr GetKernelRegistry() const override; + std::unique_ptr GetDataTransfer() const override; + + DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } + + FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } + + // WebGPU EP disallow concurrent run because actual implementation (eg. WebGPU backend) relies on global states to + // work, and concurrent run with async function may mess up the states and cause undefined behavior. + bool ConcurrentRunSupported() const override { return false; } + + std::vector CreatePreferredAllocators() override; + + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + + // WebGPU EP reuses the Device ID as the key to get the WebGpuContext instance. + int GetDeviceId() const override { return context_id_; } + + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured(int graph_annotation_id) const override; + Status ReplayGraph(int graph_annotation_id) override; + + private: + bool IsGraphCaptureAllowed() const; + void IncrementRegularRunCountBeforeGraphCapture(); + int context_id_; + const webgpu::WebGpuContext& context_; + DataLayout preferred_data_layout_; + bool enable_graph_capture_ = false; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h new file mode 100644 index 000000000000..6486987501d1 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/compute_context.h" + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +// ----------------------------------------------------------------------- +// Base class for WebGPU kernels +// ----------------------------------------------------------------------- +class WebGpuKernel : public OpKernel { + public: + explicit WebGpuKernel(const OpKernelInfo& info) + : OpKernel(info) { + } + + Status Compute(OpKernelContext* p_op_kernel_context) const override { + ComputeContext context{*p_op_kernel_context}; + auto s = ComputeInternal(context); + // use this to precisely locate the node where CUDA failure comes from + // if (cudaSuccess != cudaDeviceSynchronize()) + // __debugbreak(); + // if (s.IsOK()) { + // auto err = cudaGetLastError(); + // if (err != cudaSuccess) { + // return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA error ", cudaGetErrorName(err), ":", cudaGetErrorString(err)); + // } + // } + return s; + } + + virtual Status ComputeInternal(ComputeContext& context) const = 0; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc new file mode 100644 index 000000000000..93258b84c511 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/framework/error_code_helper.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/ort_apis.h" + +#include "core/providers/webgpu/webgpu_provider_options.h" +using namespace onnxruntime::webgpu::options; + +namespace onnxruntime { + +struct WebGpuProviderFactory : IExecutionProviderFactory { + WebGpuProviderFactory(int context_id, const webgpu::WebGpuContext& context, const WebGpuExecutionProviderInfo& webgpu_ep_info) + : context_id_{context_id}, context_{context}, info_{webgpu_ep_info} { + } + + std::unique_ptr CreateProvider() override { + return std::make_unique(context_id_, context_, info_); + } + + private: + int context_id_; + const webgpu::WebGpuContext& context_; + WebGpuExecutionProviderInfo info_; +}; + +std::shared_ptr WebGpuProviderFactoryCreator::Create(const SessionOptions* session_options) { + // + // STEP.1 - prepare WebGpuExecutionProviderInfo + // + WebGpuExecutionProviderInfo webgpu_ep_info{ + // preferred layout is NHWC by default + DataLayout::NHWC, + // graph capture feature is disabled by default + false, + }; + + std::string preferred_layout_str; + if (session_options->config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { + if (preferred_layout_str == kPreferredLayout_NHWC) { + webgpu_ep_info.data_layout = DataLayout::NHWC; + } else if (preferred_layout_str == kPreferredLayout_NCHW) { + webgpu_ep_info.data_layout = DataLayout::NCHW; + } else { + ORT_THROW("Invalid preferred layout: ", preferred_layout_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP preferred layout: " << int(webgpu_ep_info.data_layout) << " (parsed from \"" + << preferred_layout_str << "\")"; + + std::string enable_graph_capture_str; + if (session_options->config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { + if (enable_graph_capture_str == kkEnableGraphCapture_ON) { + webgpu_ep_info.enable_graph_capture = true; + } else if (enable_graph_capture_str == kkEnableGraphCapture_OFF) { + webgpu_ep_info.enable_graph_capture = false; + } else { + ORT_THROW("Invalid enable graph capture: ", enable_graph_capture_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture; + + auto parse_buffer_cache_mode = [session_options](const std::string& config_entry_str, webgpu::BufferCacheMode default) -> webgpu::BufferCacheMode { + std::string buffer_cache_mode_str; + if (session_options->config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { + if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { + return webgpu::BufferCacheMode::Disabled; + } else if (buffer_cache_mode_str == kBufferCacheMode_LazyRelease) { + return webgpu::BufferCacheMode::LazyRelease; + } else if (buffer_cache_mode_str == kBufferCacheMode_Simple) { + return webgpu::BufferCacheMode::Simple; + } else if (buffer_cache_mode_str == kBufferCacheMode_Bucket) { + return webgpu::BufferCacheMode::Bucket; + } else { + ORT_THROW("Invalid buffer cache mode: ", config_entry_str); + } + } else { + return default; + } + }; + + webgpu_ep_info.storage_buffer_cache_mode = parse_buffer_cache_mode(kStorageBufferCacheMode, webgpu::BufferCacheMode::Bucket); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << webgpu_ep_info.storage_buffer_cache_mode; + + webgpu_ep_info.uniform_buffer_cache_mode = parse_buffer_cache_mode(kUniformBufferCacheMode, webgpu::BufferCacheMode::LazyRelease); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << webgpu_ep_info.uniform_buffer_cache_mode; + + webgpu_ep_info.query_resolve_buffer_cache_mode = parse_buffer_cache_mode(kQueryResolveBufferCacheMode, webgpu::BufferCacheMode::Disabled); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP query resolve buffer cache mode: " << webgpu_ep_info.query_resolve_buffer_cache_mode; + + webgpu_ep_info.default_buffer_cache_mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << webgpu_ep_info.default_buffer_cache_mode; + + // + // STEP.2 - prepare WebGpuContext + // + int context_id = 0; + std::string context_id_str; + if (session_options->config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { + ORT_ENFORCE(std::errc{} == + std::from_chars(context_id_str.data(), context_id_str.data() + context_id_str.size(), context_id).ec); + } + + size_t webgpu_instance = 0; + std::string webgpu_instance_str; + if (session_options->config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { + static_assert(sizeof(WGPUInstance) == sizeof(size_t), "WGPUInstance size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec); + } + + size_t webgpu_adapter = 0; + std::string webgpu_adapter_str; + if (session_options->config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) { + static_assert(sizeof(WGPUAdapter) == sizeof(size_t), "WGPUAdapter size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_adapter_str.data(), webgpu_adapter_str.data() + webgpu_adapter_str.size(), webgpu_adapter).ec); + } + + size_t webgpu_device = 0; + std::string webgpu_device_str; + if (session_options->config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) { + static_assert(sizeof(WGPUDevice) == sizeof(size_t), "WGPUDevice size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_device_str.data(), webgpu_device_str.data() + webgpu_device_str.size(), webgpu_device).ec); + } + + auto& context = webgpu::WebGpuContextFactory::CreateContext(context_id, + reinterpret_cast(webgpu_instance), + reinterpret_cast(webgpu_adapter), + reinterpret_cast(webgpu_device)); + context.Initialize(webgpu_ep_info); + + return std::make_shared(context_id, context, webgpu_ep_info); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h new file mode 100644 index 000000000000..7fac9234b949 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/framework/provider_options.h" +#include "core/providers/providers.h" + +namespace onnxruntime { +struct SessionOptions; + +struct WebGpuProviderFactoryCreator { + static std::shared_ptr Create(const SessionOptions* session_options); +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h new file mode 100644 index 000000000000..65ccbd800b12 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace webgpu { +namespace options { + +// The following are the options that can be set in the WebGPU provider options. + +constexpr const char* kPreferredLayout = "preferredLayout"; +constexpr const char* kEnableGraphCapture = "enableGraphCapture"; + +constexpr const char* kDeviceId = "deviceId"; +constexpr const char* kWebGpuInstance = "webgpuInstance"; +constexpr const char* kWebGpuAdapter = "webgpuAdapter"; +constexpr const char* kWebGpuDevice = "webgpuDevice"; + +constexpr const char* kStorageBufferCacheMode = "storageBufferCacheMode"; +constexpr const char* kUniformBufferCacheMode = "uniformBufferCacheMode"; +constexpr const char* kQueryResolveBufferCacheMode = "queryResolveBufferCacheMode"; +constexpr const char* kDefaultBufferCacheMode = "defaultBufferCacheMode"; + +// The following are the possible values for the provider options. + +constexpr const char* kPreferredLayout_NCHW = "NCHW"; +constexpr const char* kPreferredLayout_NHWC = "NHWC"; + +constexpr const char* kkEnableGraphCapture_ON = "1"; +constexpr const char* kkEnableGraphCapture_OFF = "0"; + +constexpr const char* kBufferCacheMode_Disabled = "disabled"; +constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease"; +constexpr const char* kBufferCacheMode_Simple = "simple"; +constexpr const char* kBufferCacheMode_Bucket = "bucket"; + +} // namespace options +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h new file mode 100644 index 000000000000..fccaef2c5357 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cpu/tensor/shape_op.h" + +namespace onnxruntime { +namespace webgpu { + +using SupportedTypes = + TypeList< + float, + MLFloat16, + int32_t, + uint32_t>; + +using SupportedFloats = + TypeList< + float, + MLFloat16>; + +inline const std::vector& WebGpuSupportedDataTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + +inline const std::vector& WebGpuSupportedFloatTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b9e017df5baa..dced1cf0e146 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -759,12 +759,12 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr // Some session option values (default or user provided) may not work with some EPs. // Rather than put the onus on the user to know these, make the appropriate change while logging the change. - if (provider_type == onnxruntime::kDmlExecutionProvider) { - // DML's memory is not byte addressable and hence mem pattern doesn't work. + if (provider_type == onnxruntime::kDmlExecutionProvider || provider_type == onnxruntime::kWebGpuExecutionProvider) { + // DML and WebGPU memory is not byte addressable and hence mem pattern doesn't work. if (session_options_.enable_mem_pattern) { LOGS(*session_logger_, INFO) - << "Having memory pattern enabled is not supported while using the DML Execution Provider. " - << "So disabling it for this session since it uses the DML Execution Provider."; + << "Having memory pattern enabled is not supported while using " << provider_type << ". " + << "So disabling it for this session since it uses " << provider_type << "."; session_options_.enable_mem_pattern = false; } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 1a5484ddc005..f231b0148b37 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2730,6 +2730,8 @@ static constexpr OrtApi ort_api_1_to_20 = { &OrtApis::KernelInfoGetAllocator, &OrtApis::AddExternalInitializersFromFilesInMemory, // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) + + &OrtApis::SessionOptionsAppendExecutionProvider_WebGPU, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index fcae173e6c16..fd765feae6ad 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -384,6 +384,13 @@ ORT_API_STATUS_IMPL(InvokeOp, ORT_API(void, ReleaseOp, _Frees_ptr_opt_ OrtOp* op); +ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_WebGPU, + _In_ OrtSessionOptions* options, + _In_ const OrtWebGPUProviderOptions* webgpu_options, + _In_reads_(num_keys) const char* const* string_options_keys, + _In_reads_(num_keys) const char* const* string_options_values, + _In_ size_t num_keys); + ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* provider_name, diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index db8b97f6d2c1..d2f8579fef7e 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "core/common/common.h" #include "core/common/logging/logging.h" @@ -131,6 +132,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options)); #else status = create_not_supported_status(); +#endif + } else if (strcmp(provider_name, "WebGPU") == 0) { +#if defined(USE_WEBGPU) + options->provider_factories.push_back(WebGpuProviderFactoryCreator::Create(&(options->value))); +#else + status = create_not_supported_status(); #endif } else if (strcmp(provider_name, "AZURE") == 0) { #if defined(USE_AZURE) @@ -158,6 +165,59 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_WebGPU, + _In_ OrtSessionOptions* options, + _In_ const OrtWebGPUProviderOptions* webgpu_options, + _In_reads_(num_keys) const char* const* string_options_keys, + _In_reads_(num_keys) const char* const* string_options_values, + _In_ size_t num_keys) { + API_IMPL_BEGIN + std::vector options_keys; + options_keys.reserve(num_keys + 4); + std::vector options_values; + options_values.reserve(num_keys + 4); + + // the following code uses std::to_chars() to convert int/size_t to string. + // unlike std::to_string(), std::to_chars() is guaranteed locale-independent. + // + // uint64_t to string is no more than 20 characters, and + // int32_t to string is no more than 11 characters. + static_assert(sizeof(size_t) == 4 || sizeof(size_t) == 8); + char buffer[sizeof(size_t) == 4 ? 11 : 20]; + + auto res = std::to_chars(buffer, buffer + sizeof(buffer), webgpu_options->device_id); + ORT_ENFORCE(res.ec == std::errc(), "Failed to convert device_id to string"); + std::string device_id(buffer, res.ptr - buffer); + options_keys.push_back("deviceId"); + options_values.push_back(device_id.c_str()); + + res = std::to_chars(buffer, buffer + sizeof(buffer), reinterpret_cast(webgpu_options->instance_handle)); + ORT_ENFORCE(res.ec == std::errc(), "Failed to convert instance_handle to string"); + std::string instance_handle(buffer, res.ptr - buffer); + options_keys.push_back("webgpuInstance"); + options_values.push_back(instance_handle.c_str()); + + res = std::to_chars(buffer, buffer + sizeof(buffer), reinterpret_cast(webgpu_options->adapter_handle)); + ORT_ENFORCE(res.ec == std::errc(), "Failed to convert adapter_handle to string"); + std::string adapter_handle(buffer, res.ptr - buffer); + options_keys.push_back("webgpuAdapter"); + options_values.push_back(adapter_handle.c_str()); + + res = std::to_chars(buffer, buffer + sizeof(buffer), reinterpret_cast(webgpu_options->device_handle)); + ORT_ENFORCE(res.ec == std::errc(), "Failed to convert device_handle to string"); + std::string device_handle(buffer, res.ptr - buffer); + options_keys.push_back("webgpuDevice"); + options_values.push_back(device_handle.c_str()); + + for (size_t i = 0; i != num_keys; ++i) { + options_keys.push_back(string_options_keys[i]); + options_values.push_back(string_options_values[i]); + } + + return OrtApis::SessionOptionsAppendExecutionProvider(options, "WebGPU", options_keys.data(), options_values.data(), options_keys.size()); + API_IMPL_END +} + #if defined(__APPLE__) || defined(ORT_MINIMAL_BUILD) static OrtStatus* CreateNotEnabledStatus(const std::string& ep) { return OrtApis::CreateStatus(ORT_FAIL, (ep + " execution provider is not enabled in this build. ").c_str()); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 0397bba90438..17b5cce6a4d6 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -47,13 +47,16 @@ void usage() { "\t-v: verbose\n" "\t-n [test_case_name]: Specifies a single test case to run.\n" "\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', 'vsinpu'" - "'openvino', 'rocm', 'migraphx', 'acl', 'armnn', 'xnnpack', 'nnapi', 'qnn', 'snpe' or 'coreml'. " + "'openvino', 'rocm', 'migraphx', 'acl', 'armnn', 'xnnpack', 'webgpu', 'nnapi', 'qnn', 'snpe' or 'coreml'. " "Default: 'cpu'.\n" "\t-p: Pause after launch, can attach debugger and continue\n" "\t-x: Use parallel executor, default (without -x): sequential executor.\n" "\t-d [device_id]: Specifies the device id for multi-device (e.g. GPU). The value should > 0\n" "\t-t: Specify custom relative tolerance values for output value comparison. default: 1e-5\n" "\t-a: Specify custom absolute tolerance values for output value comparison. default: 1e-5\n" + "\t-C: Specify session configuration entries as key-value pairs: -C \"| |\" \n" + "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" + "\t [Example] -C \"session.disable_cpu_ep_fallback|1 ep.context_enable|1\" \n" "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/folderpath/libQnnCpu.so'.\n" "\t [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n" @@ -123,6 +126,39 @@ static TestTolerances LoadTestTolerances(bool enable_cuda, bool enable_openvino, overrides_json["atol_default"], overrides_json["rtol_default"], absolute_overrides, relative_overrides); } +static bool ParseSessionConfigs(const std::string& configs_string, + std::unordered_map& session_configs) { + std::istringstream ss(configs_string); + std::string token; + + while (ss >> token) { + if (token == "") { + continue; + } + + std::string_view token_sv(token); + + auto pos = token_sv.find("|"); + if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) { + // Error: must use a '|' to separate the key and value for session configuration entries. + return false; + } + + std::string key(token_sv.substr(0, pos)); + std::string value(token_sv.substr(pos + 1)); + + auto it = session_configs.find(key); + if (it != session_configs.end()) { + // Error: specified duplicate session configuration entry: {key} + return false; + } + + session_configs.insert(std::make_pair(std::move(key), std::move(value))); + } + + return true; +} + #ifdef _WIN32 int GetNumCpuCores() { SYSTEM_LOGICAL_PROCESSOR_INFORMATION buffer[256]; @@ -179,6 +215,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { bool enable_armnn = false; bool enable_rocm = false; bool enable_migraphx = false; + bool enable_webgpu = false; bool enable_xnnpack = false; bool override_tolerance = false; double atol = 1e-5; @@ -188,6 +225,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { bool user_graph_optimization_level_set = false; bool set_denormal_as_zero = false; std::basic_string ep_runtime_config_string; + std::unordered_map session_config_entries; std::string provider_name = "cpu"; OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_ERROR; @@ -198,7 +236,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { bool pause = false; { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("Ac:hj:Mn:r:e:t:a:xvo:d:i:pzfb"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("Ac:hj:Mn:r:e:t:a:xvo:d:C:i:pzfb"))) != -1) { switch (ch) { case 'A': enable_cpu_mem_arena = false; @@ -267,6 +305,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) { enable_rocm = true; } else if (!CompareCString(optarg, ORT_TSTR("migraphx"))) { enable_migraphx = true; + } else if (!CompareCString(optarg, ORT_TSTR("webgpu"))) { + enable_webgpu = true; } else if (!CompareCString(optarg, ORT_TSTR("xnnpack"))) { enable_xnnpack = true; } else { @@ -323,6 +363,11 @@ int real_main(int argc, char* argv[], Ort::Env& env) { return -1; } break; + case 'C': + if (!ParseSessionConfigs(ToUTF8String(optarg), session_config_entries)) { + return -1; + } + break; case 'i': ep_runtime_config_string = optarg; break; @@ -409,6 +454,10 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (disable_ep_context_embed_mode) sf.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + for (auto& it : session_config_entries) { + sf.AddConfigEntry(it.first.c_str(), it.second.c_str()); + } + if (enable_tensorrt) { #ifdef USE_TENSORRT Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(sf, device_id)); @@ -698,6 +747,15 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } + if (enable_webgpu) { +#ifdef USE_WEBGPU + sf.AppendExecutionProvider("WebGPU", {}); +#else + fprintf(stderr, "WebGPU is not supported in this build"); + return -1; +#endif + } + if (user_graph_optimization_level_set) { sf.SetGraphOptimizationLevel(graph_optimization_level); } diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 587d035541c4..9e71e35a9290 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -563,6 +563,7 @@ def convert_arg_line_to_args(self, arg_line): "--nnapi_min_api", type=int, help="Minimum Android API level to enable NNAPI, should be no less than 27" ) parser.add_argument("--use_jsep", action="store_true", help="Build with JavaScript kernels.") + parser.add_argument("--use_webgpu", action="store_true", help="Build with WebGPU support.") parser.add_argument("--use_qnn", action="store_true", help="Build with QNN support.") parser.add_argument("--qnn_home", help="Path to QNN SDK dir.") parser.add_argument("--use_rknpu", action="store_true", help="Build with RKNPU.") @@ -1054,6 +1055,7 @@ def generate_build_tree( "-Donnxruntime_ARMNN_RELU_USE_CPU=" + ("OFF" if args.armnn_relu else "ON"), "-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"), "-Donnxruntime_USE_JSEP=" + ("ON" if args.use_jsep else "OFF"), + "-Donnxruntime_USE_WEBGPU=" + ("ON" if args.use_webgpu else "OFF"), # Training related flags "-Donnxruntime_ENABLE_NVTX_PROFILE=" + ("ON" if args.enable_nvtx_profile else "OFF"), "-Donnxruntime_ENABLE_TRAINING=" + ("ON" if args.enable_training else "OFF"), @@ -1310,6 +1312,9 @@ def generate_build_tree( raise BuildError("WebNN is only available for WebAssembly build.") cmake_args += ["-Donnxruntime_USE_WEBNN=ON"] + if args.use_jsep and args.use_webgpu: + raise BuildError("JSEP (--use_jsep) and WebGPU (--use_webgpu) cannot be enabled at the same time.") + if args.use_snpe: cmake_args += ["-Donnxruntime_USE_SNPE=ON"] From 9c362501db6c8cf645c22cc5b76b14993462ab02 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 16:07:13 -0700 Subject: [PATCH 002/114] update C-API --- .../core/session/onnxruntime_c_api.h | 18 ++++++++++++------ .../core/session/onnxruntime_cxx_api.h | 6 +++--- .../core/session/onnxruntime_cxx_inline.h | 10 +++++----- onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- onnxruntime/core/session/ort_apis.h | 4 ++-- .../core/session/provider_registration.cc | 4 +++- 6 files changed, 26 insertions(+), 18 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9e5d9339bffe..e6049b45c8ec 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -641,14 +641,20 @@ typedef struct OrtMIGraphXProviderOptions { * It's user's responsibility to manage the lifecycle of the handles and ensure the handles are valid during the * lifetime of the inference session. * - * \see OrtApi::SessionOptionsAppendExecutionProvider_WebGPU + * About DawnProcTable: + * + * When using an ONNX Runtime build that is not directly linked dawn during the build, a pointer to the runtime memory + * address of the DawnProcTable should be provided. Otherwise, keep it as nullptr. + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_WGPU */ -typedef struct OrtWebGPUProviderOptions { +typedef struct OrtWGPUProviderOptions { int device_id; // WebGPU device id. void* instance_handle; // WebGPU instance handle. void* adapter_handle; // WebGPU adapter handle. void* device_handle; // WebGPU device handle. -} OrtWebGPUProviderOptions; + void* dawn_proc_table; // DawnProcTable pointer. +} OrtWGPUProviderOptions; /** \brief OpenVINO Provider Options * @@ -4699,7 +4705,7 @@ struct OrtApi { * If WebGPU is not available, this function will return failure. * * \param[in] options - * \param[in] webgpu_options - specify the WebGPU provider options. + * \param[in] wgpu_options - specify the WebGPU provider options. * \param[in] string_options_keys - keys to configure the string options * \param[in] string_options_values - values to configure the string options * \param[in] num_keys - number of keys passed in @@ -4719,8 +4725,8 @@ struct OrtApi { * * \since Version 1.20. */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_WebGPU, - _In_ OrtSessionOptions* options, _In_ const OrtWebGPUProviderOptions* webgpu_options, + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_WGPU, + _In_ OrtSessionOptions* options, _In_ const OrtWGPUProviderOptions* wgpu_options, _In_reads_(num_keys) const char* const* string_options_keys, _In_reads_(num_keys) const char* const* string_options_values, _In_ size_t num_keys); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index cf30584e18a4..f85dad0a41ea 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -890,9 +890,9 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_WebGPU - SessionOptionsImpl& AppendExecutionProvider_WebGPU(const OrtWebGPUProviderOptions& webgpu_options, - const std::unordered_map& string_options = {}); + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_WGPU + SessionOptionsImpl& AppendExecutionProvider_WGPU(const OrtWGPUProviderOptions& wgpu_options, + const std::unordered_map& string_options = {}); /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK. SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, const std::unordered_map& provider_options = {}); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index e5c84395ad95..b675ff04268f 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -839,21 +839,21 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_MIG } template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_WebGPU(const OrtWebGPUProviderOptions& webgpu_options, - const std::unordered_map& string_options) { - auto num_entries = provider_options.size(); +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_WGPU(const OrtWGPUProviderOptions& wgpu_options, + const std::unordered_map& string_options) { + auto num_entries = string_options.size(); std::vector keys, values; if (num_entries > 0) { keys.reserve(num_entries); values.reserve(num_entries); - for (const auto& entry : provider_options) { + for (const auto& entry : string_options) { keys.push_back(entry.first.c_str()); values.push_back(entry.second.c_str()); } } - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_WebGPU(this->p_, &provider_options, keys.data(), values.data(), num_entries)); + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_WGPU(this->p_, &wgpu_options, keys.data(), values.data(), num_entries)); return *this; } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index f231b0148b37..3e787bb17aee 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2731,7 +2731,7 @@ static constexpr OrtApi ort_api_1_to_20 = { &OrtApis::AddExternalInitializersFromFilesInMemory, // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) - &OrtApis::SessionOptionsAppendExecutionProvider_WebGPU, + &OrtApis::SessionOptionsAppendExecutionProvider_WGPU, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index fd765feae6ad..86cb3f3122d6 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -384,9 +384,9 @@ ORT_API_STATUS_IMPL(InvokeOp, ORT_API(void, ReleaseOp, _Frees_ptr_opt_ OrtOp* op); -ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_WebGPU, +ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_WGPU, _In_ OrtSessionOptions* options, - _In_ const OrtWebGPUProviderOptions* webgpu_options, + _In_ const OrtWGPUProviderOptions* wgpu_options, _In_reads_(num_keys) const char* const* string_options_keys, _In_reads_(num_keys) const char* const* string_options_values, _In_ size_t num_keys); diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index d2f8579fef7e..1938ea3fd2c1 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -165,7 +165,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_WebGPU, +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_WGPU, _In_ OrtSessionOptions* options, _In_ const OrtWebGPUProviderOptions* webgpu_options, _In_reads_(num_keys) const char* const* string_options_keys, @@ -209,6 +209,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_WebGPU, options_keys.push_back("webgpuDevice"); options_values.push_back(device_handle.c_str()); + // TODO: dawn proc table + for (size_t i = 0; i != num_keys; ++i) { options_keys.push_back(string_options_keys[i]); options_values.push_back(string_options_values[i]); From 3a0756d4f2cc091dea9c4e04d9053f6931b5e0ac Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 16:54:41 -0700 Subject: [PATCH 003/114] fix build break --- onnxruntime/core/session/provider_registration.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 1938ea3fd2c1..da97cdc25ab1 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -167,7 +167,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_WGPU, _In_ OrtSessionOptions* options, - _In_ const OrtWebGPUProviderOptions* webgpu_options, + _In_ const OrtWGPUProviderOptions* webgpu_options, _In_reads_(num_keys) const char* const* string_options_keys, _In_reads_(num_keys) const char* const* string_options_values, _In_ size_t num_keys) { From 5199e9858993b85b9f7809c94054e65a2c4ed5e3 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 16:59:44 -0700 Subject: [PATCH 004/114] add an empty symbols.txt file --- onnxruntime/core/providers/webgpu/symbols.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/symbols.txt diff --git a/onnxruntime/core/providers/webgpu/symbols.txt b/onnxruntime/core/providers/webgpu/symbols.txt new file mode 100644 index 000000000000..e69de29bb2d1 From 1c68dbd361157e331923668e0c5c04b1d0d17864 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 17:18:40 -0700 Subject: [PATCH 005/114] fix an error in doc --- .../core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index a5a71fd94bf4..87309f6673bb 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -143,7 +143,7 @@ use `build.bat --use_webgpu` to build the WebGPU EP. For Release build, append ` to test, find the "onnx_test_runner.exe" in your build folder. run it like: ``` -onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" --model_path=C:\code\onnxruntime\js\test\data\node\opset17\test_abs +onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` > Assume C:\code\onnxruntime is the root of your onnxruntime repo From 7db03de2ccdd60b246405d88c34a02934e70f0f2 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 17:31:49 -0700 Subject: [PATCH 006/114] remove string_join.h in favor of absl::StrJoin --- include/onnxruntime/core/common/string_join.h | 61 ------------------- onnxruntime/core/providers/webgpu/program.h | 5 +- .../core/providers/webgpu/shader_variable.h | 14 ++++- 3 files changed, 14 insertions(+), 66 deletions(-) delete mode 100644 include/onnxruntime/core/common/string_join.h diff --git a/include/onnxruntime/core/common/string_join.h b/include/onnxruntime/core/common/string_join.h deleted file mode 100644 index 2c2181d4ad04..000000000000 --- a/include/onnxruntime/core/common/string_join.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "core/common/make_string.h" - -namespace onnxruntime { - -namespace detail { - -template -inline void StringJoinImpl(const Separator& separator, std::ostringstream& ss) noexcept { -} - -template -inline void StringJoinImpl(const Separator& separator, std::ostringstream& ss, const T& t) noexcept { - ss << separator << t; -} - -template -inline void StringJoinImpl(const Separator& separator, std::ostringstream& ss, const T& t, const Args&... args) noexcept { - StringJoinImpl(separator, ss, t); - StringJoinImpl(separator, ss, args...); -} - -template -inline std::string StringJoinImpl(const Separator& separator, const Args&... args) noexcept { - std::ostringstream ss; - ss.imbue(std::locale::classic()); - StringJoinImpl(separator, ss, args...); - return ss.str(); -} -} // namespace detail - -/** - * Makes a string by concatenating string representations of the arguments using the specified separator. - * Uses std::locale::classic() - */ -template -std::string StringJoin(const Separator& separator, const Args&... args) { - return detail::StringJoinImpl(separator, detail::if_char_array_make_ptr_t(args)...); -} - -// StringJoin versions for already-a-string types. - -template -inline std::string StringJoin(const Separator& /* separator */, const std::string& str) { - return str; -} - -template -inline std::string StringJoin(const Separator& /* separator */, const char* cstr) { - return cstr; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 6df918e2f7f7..277c00e08901 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -7,8 +7,9 @@ #include #include +#include + #include "core/common/common.h" -#include "core/common/string_join.h" #include "core/common/safeint.h" #include "core/framework/tensor.h" @@ -218,7 +219,7 @@ class ProgramBase { // set the cache hint for the program template ProgramBase& CacheHint(CacheHintArgs&&... args) { - cache_hint_ = StringJoin("|", std::forward(args)...); + cache_hint_ = absl::StrJoin(std::forward_as_tuple(std::forward(args)...), "|"); } // set one or more program inputs diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 0a5cad823787..65a015c8e7ba 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -191,7 +191,11 @@ inline std::string ShaderVariable::BroadcastedIndicesToOffset(const std::string& template inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const { - return rank_ == 0 ? "" : MakeStringWithClassicLocale(name_, "_indices_t(", onnxruntime::detail::StringJoinImpl(", ", std::forward(indices_args)...), ')'); + return rank_ == 0 + ? "" + : MakeStringWithClassicLocale(name_, "_indices_t(", + absl::StrJoin(std::forward_as_tuple(std::forward(indices_args)...), ", "), + ')'); } template @@ -219,7 +223,9 @@ inline std::string ShaderVariable::Set(TIndicesAndValue&&... args) const { return SetByOffset(std::forward(args)...); } else { usage_ |= UseSet | UseSetByIndices | UseIndicesToOffset; - return MakeStringWithClassicLocale("set_", name_, '(', onnxruntime::detail::StringJoinImpl(", ", std::forward(args)...), ");"); + return MakeStringWithClassicLocale("set_", name_, '(', + absl::StrJoin(std::forward_as_tuple(std::forward(args)...), ", "), + ");"); } } @@ -246,7 +252,9 @@ inline std::string ShaderVariable::Get(TIndices&&... indices) const { return GetByOffset(std::forward(indices)...); } else { usage_ |= UseGet | UseGetByIndices | UseIndicesToOffset; - return MakeStringWithClassicLocale("get_", name_, '(', onnxruntime::detail::StringJoinImpl(", ", std::forward(indices)...), ')'); + return MakeStringWithClassicLocale("get_", name_, '(', + absl::StrJoin(std::forward_as_tuple(std::forward(indices)...), ", "), + ')'); } } From 6a373c231ea048799e5359a929170a3fffda0a6c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 21:26:46 -0700 Subject: [PATCH 007/114] fix DLL copy --- cmake/onnxruntime_providers_webgpu.cmake | 7 +++++++ cmake/onnxruntime_unittests.cmake | 10 ---------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 303ab9483c38..587c4b2c1ff2 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -27,4 +27,11 @@ onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn) + # Copy webgpu_dawn.dll to the output directory + add_custom_command( + TARGET onnxruntime_providers_webgpu + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different "$" "$" + VERBATIM ) + set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 5434ead12f65..511c25dd6d15 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1100,11 +1100,6 @@ if (NOT IOS) endif() set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest") - add_custom_command(TARGET onnx_test_runner POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy $ $ - COMMAND_EXPAND_LISTS - ) - if (onnxruntime_USE_TVM) if (WIN32) target_link_options(onnx_test_runner PRIVATE "/STACK:4000000") @@ -1235,11 +1230,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() endif() - add_custom_command(TARGET onnxruntime_perf_test POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy $ $ - COMMAND_EXPAND_LISTS - ) - if (onnxruntime_BUILD_SHARED_LIB) #It will dynamically link to onnxruntime. So please don't add onxruntime_graph/onxruntime_framework/... here. #onnxruntime_common is kind of ok because it is thin, tiny and totally stateless. From ee42bba8a2e19030c8d8cea3e0d0d08c279082d6 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 29 Aug 2024 01:11:14 -0700 Subject: [PATCH 008/114] update doc: require --skip_tests --- .../core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md | 2 +- onnxruntime/core/providers/webgpu/README.md | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index 87309f6673bb..3c20130ae2ce 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -139,7 +139,7 @@ This section is WIP. ## 6. Build and test -use `build.bat --use_webgpu` to build the WebGPU EP. For Release build, append `--config Release` or `--config RelWithDebInfo` to the command line. +use `build.bat --use_webgpu --skip_tests` to build the WebGPU EP. For Release build, append `--config Release` or `--config RelWithDebInfo` to the command line. to test, find the "onnx_test_runner.exe" in your build folder. run it like: ``` diff --git a/onnxruntime/core/providers/webgpu/README.md b/onnxruntime/core/providers/webgpu/README.md index d9c4313c8bf3..20864d360914 100644 --- a/onnxruntime/core/providers/webgpu/README.md +++ b/onnxruntime/core/providers/webgpu/README.md @@ -4,7 +4,9 @@ This folder is for the WebGPU execution provider(WebGPU EP). Currently, WebGPU E ## Build WebGPU EP -Just append `--use_webgpu` to the `build.bat` command line. +Just append `--use_webgpu --skip_tests` to the `build.bat` command line. + +NOTE: `--skip_tests` is required for now. All existing tests are for CPU EP anyway so no need to run them. Currently only works on Windows. From 3f46e5c6e6aa8311540f3dbfb01a45dce2e35bab Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 29 Aug 2024 02:10:44 -0700 Subject: [PATCH 009/114] update dawn version --- cmake/deps.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index 2ab00cdbeb30..597c051b5f47 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -59,4 +59,4 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d839 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029 -dawn;https://github.com/google/dawn/archive/9a912d8162d5a837950de14f8849230212e3f51c.zip;7f2cad3db905e2d846d8f2422623850a4463915f +dawn;https://github.com/google/dawn/archive/511eb80847afe6bded34ec491a38d5d78ba2d604.zip;c493f5aca5586f6634e25d0121c85df71189fb99 From 9f61279361e33e3cec0891a8cd95869d841bc17a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 29 Aug 2024 13:11:38 -0700 Subject: [PATCH 010/114] disable Tint tests --- cmake/external/onnxruntime_external_deps.cmake | 1 + 1 file changed, 1 insertion(+) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 2dad3479c3c0..6640609aa71d 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -593,6 +593,7 @@ if (onnxruntime_USE_WEBGPU) ) set(DAWN_FETCH_DEPENDENCIES ON) set(DAWN_ENABLE_INSTALL ON) + set(TINT_BUILD_TESTS OFF) onnxruntime_fetchcontent_makeavailable(dawn) endif() From 6bb6335a71bdfbeb45f0f2ed9bd20ebb004967ab Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 29 Aug 2024 16:02:16 -0700 Subject: [PATCH 011/114] fix one build break in Linux --- onnxruntime/core/providers/webgpu/compute_context.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index d7aeae240101..4d567b088fc1 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -70,7 +70,7 @@ class ComputeContext { Tensor CreateCPUTensor(MLDataType data_type, TensorShapeType&& shape) { AllocatorPtr allocator; ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceCPUAllocator(&allocator)); - return {data_type, std::forward(shape)..., allocator}; + return {data_type, std::forward(shape), allocator}; } // @@ -80,7 +80,7 @@ class ComputeContext { Tensor CreateGPUTensor(MLDataType data_type, TensorShapeType&& shape) { AllocatorPtr allocator; ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceAllocator(&allocator)); - return {data_type, std::forward(shape)..., allocator}; + return {data_type, std::forward(shape), allocator}; } // From d839dbc213e8402b75b595173fa6592f8e3cc021 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 29 Aug 2024 19:52:47 -0700 Subject: [PATCH 012/114] remove unused variables --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index a891f5a8a551..5f09223b2271 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -219,18 +219,6 @@ Status WebGpuContext::Run(const ComputeContext& /*context*/, const ProgramBase& #endif } - std::vector input_buffers; - input_buffers.reserve(inputs.size()); - for (const auto& input : inputs) { - input_buffers.push_back(reinterpret_cast(const_cast(input.tensor->DataRaw()))); - } - - std::vector output_buffers; - output_buffers.reserve(outputs.size()); - for (const auto& output : outputs) { - output_buffers.push_back(reinterpret_cast(output->MutableDataRaw())); - } - WGPUBuffer uniform_buffer = nullptr; auto uniform_buffer_size = program_artifact->uniform_total_size; if (uniform_buffer_size > 0) { From b70943d92b0c225eae7b9576a22d6633342aff6e Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 30 Aug 2024 14:02:58 -0700 Subject: [PATCH 013/114] make webgpu build on linux and known to most tools (#21937) Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- cmake/onnxruntime.cmake | 3 ++- cmake/onnxruntime_python.cmake | 1 + onnxruntime/core/providers/webgpu/compute_context.h | 2 +- onnxruntime/core/providers/webgpu/shader_variable.h | 4 ++-- .../core/providers/webgpu/webgpu_provider_factory.cc | 5 +++-- onnxruntime/python/onnxruntime_pybind_state.cc | 4 ++++ onnxruntime/test/perftest/command_args_parser.cc | 6 ++++-- onnxruntime/test/perftest/ort_test_session.cc | 7 +++++++ onnxruntime/test/util/default_providers.cc | 8 ++++++++ onnxruntime/test/util/include/default_providers.h | 1 + tools/ci_build/gen_def.py | 1 + 11 files changed, 34 insertions(+), 8 deletions(-) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 927b4ac84b03..52b6bda34686 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -38,7 +38,7 @@ function(get_c_cxx_api_headers HEADERS_VAR) # need to add header files for enabled EPs foreach(f ${ONNXRUNTIME_PROVIDER_NAMES}) - # The header files in include/onnxruntime/core/providers/cuda directory cannot be flattened to the same directory + # The header files in include/onnxruntime/core/providers/cuda directory cannot be flattened to the same directory # with onnxruntime_c_api.h . Most other EPs probably also do not work in this way. if((NOT f STREQUAL cuda) AND (NOT f STREQUAL rocm)) file(GLOB _provider_headers CONFIGURE_DEPENDS @@ -200,6 +200,7 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_RKNPU} ${PROVIDERS_VSINPU} ${PROVIDERS_XNNPACK} + ${PROVIDERS_WEBGPU} ${PROVIDERS_WEBNN} ${PROVIDERS_AZURE} ${PROVIDERS_INTERNAL_TESTING} diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index b2dbe4b3da5e..c5ba54421723 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -178,6 +178,7 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE ${PROVIDERS_ACL} ${PROVIDERS_ARMNN} ${PROVIDERS_XNNPACK} + ${PROVIDERS_WEBGPU} ${PROVIDERS_AZURE} ${PROVIDERS_QNN} onnxruntime_optimizer diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 4d567b088fc1..9c352d3d76dd 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -14,11 +14,11 @@ #include "core/framework/execution_provider.h" #include "core/providers/webgpu/program.h" +#include "core/framework/op_kernel.h" namespace onnxruntime { class Tensor; -class OpKernelContext; namespace webgpu { diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 65a015c8e7ba..ef95e26e6df7 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -131,7 +131,7 @@ class ShaderVariable { void Init(); void Impl(std::ostringstream& ss); - std::string ShaderVariable::GetByOffsetImpl(const std::string& offset) const; + std::string GetByOffsetImpl(const std::string& offset) const; std::string SetByOffsetImpl(const std::string& offset, const std::string& value) const; std::string_view StorageType() const; @@ -140,7 +140,7 @@ class ShaderVariable { std::string name_; ProgramVariableDataType type_; - int rank_; + size_t rank_; TensorShape dims_; mutable Usage usage_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 93258b84c511..e871b66f1dc9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -67,7 +67,8 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( } LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture; - auto parse_buffer_cache_mode = [session_options](const std::string& config_entry_str, webgpu::BufferCacheMode default) -> webgpu::BufferCacheMode { + auto parse_buffer_cache_mode = [session_options](const std::string& config_entry_str, + webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode { std::string buffer_cache_mode_str; if (session_options->config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { @@ -82,7 +83,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( ORT_THROW("Invalid buffer cache mode: ", config_entry_str); } } else { - return default; + return default_value; } }; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 47b8d75f22ae..036585586d9a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1207,6 +1207,10 @@ std::unique_ptr CreateExecutionProviderInstance( return onnxruntime::XnnpackProviderFactoryCreator::Create( cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options) ->CreateProvider(); +#endif + } else if (type == kWebGpuExecutionProvider) { +#if defined(USE_WEBGPU) + return onnxruntime::WebGpuProviderFactoryCreator::Create(&session_options)->CreateProvider(); #endif } else if (type == kCannExecutionProvider) { #ifdef USE_CANN diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 84c3bc16346f..0b8e291ec7fb 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -37,8 +37,8 @@ namespace perftest { "\t-A: Disable memory arena\n" "\t-I: Generate tensor input binding (Free dimensions are treated as 1.)\n" "\t-c [parallel runs]: Specifies the (max) number of runs to invoke simultaneously. Default:1.\n" - "\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|qnn|snpe|rocm|migraphx|xnnpack|vitisai]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', " - "'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack' or 'vitisai'. " + "\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|qnn|snpe|rocm|migraphx|xnnpack|vitisai:webgpu]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', " + "'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack', 'vitisai' or 'webgpu'. " "Default:'cpu'.\n" "\t-b [tf|ort]: backend to use. Default:ort\n" "\t-r [repeated_times]: Specifies the repeated times if running in 'times' test mode.Default:1000.\n" @@ -279,6 +279,8 @@ static bool ParseSessionConfigs(const std::string& configs_string, test_config.machine_config.provider_type_name = onnxruntime::kXnnpackExecutionProvider; } else if (!CompareCString(optarg, ORT_TSTR("vitisai"))) { test_config.machine_config.provider_type_name = onnxruntime::kVitisAIExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("webgpu"))) { + test_config.machine_config.provider_type_name = onnxruntime::kWebGpuExecutionProvider; } else { return false; } diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index fc1bdb10d745..57a20e2d03ee 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -551,6 +551,13 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); "XNNPACK", {{"intra_op_num_threads", std::to_string(performance_test_config.run_config.intra_op_num_threads)}}); #else ORT_THROW("Xnnpack is not supported in this build\n"); +#endif + } else if (provider_name_ == onnxruntime::kWebGpuExecutionProvider) { +#ifdef USE_WEBGPU + session_options.AppendExecutionProvider( + "WebGPU", {{"intra_op_num_threads", std::to_string(performance_test_config.run_config.intra_op_num_threads)}}); +#else + ORT_THROW("WebGpu is not supported in this build\n"); #endif } else if (provider_name_ == onnxruntime::kVitisAIExecutionProvider) { #ifdef USE_VITISAI diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 1feba20e32bb..871285269daf 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -301,6 +301,14 @@ std::unique_ptr DefaultXnnpackExecutionProvider() { #endif } +std::unique_ptr DefaultWebGpuExecutionProvider() { +#ifdef USE_WEBGPU + return WebGpuProviderFactoryCreator::Create(nullptr)->CreateProvider(); +#else + return nullptr; +#endif +} + std::unique_ptr DefaultCannExecutionProvider() { #ifdef USE_CANN OrtCANNProviderOptions provider_options{}; diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 606dfc068d39..610b5b4ced68 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -62,6 +62,7 @@ std::unique_ptr DefaultQnnExecutionProvider(); std::unique_ptr QnnExecutionProviderWithOptions(const ProviderOptions& options, const SessionOptions* session_options = nullptr); std::unique_ptr DefaultXnnpackExecutionProvider(); +std::unique_ptr DefaultWebGpuExecutionProvider(); std::unique_ptr DefaultCannExecutionProvider(); std::unique_ptr DefaultDmlExecutionProvider(); diff --git a/tools/ci_build/gen_def.py b/tools/ci_build/gen_def.py index c4add6f0e891..765e9d135b7f 100755 --- a/tools/ci_build/gen_def.py +++ b/tools/ci_build/gen_def.py @@ -80,6 +80,7 @@ def parse_arguments(): "dnnl", "tensorrt", "azure", + "webgpu" ): file.write(f"#include \n") file.write("void* GetFunctionEntryByName(const char* name){\n") From 843726753f8bb35bb2a900add041a53f4e4245e8 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 14:45:28 -0700 Subject: [PATCH 014/114] revert type of ShaderVariable::rank_ to int --- onnxruntime/core/providers/webgpu/shader_variable.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index ef95e26e6df7..15d2259c34a9 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -140,7 +140,7 @@ class ShaderVariable { std::string name_; ProgramVariableDataType type_; - size_t rank_; + int rank_; TensorShape dims_; mutable Usage usage_; From 3caf032a9848d922428e4e2a5f798be1649b3c72 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 15:49:26 -0700 Subject: [PATCH 015/114] output Impl() for variables --- onnxruntime/core/providers/webgpu/shader_helper.cc | 8 ++++++-- onnxruntime/core/providers/webgpu/shader_variable.cc | 2 +- onnxruntime/core/providers/webgpu/shader_variable.h | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 203f11ff9000..d3466b6d611a 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -180,8 +180,12 @@ std::string ShaderHelper::GetFinalSourceCode() { // Indices helper // ss << "\n"; - // for (const auto& group : vars_) { - // } + for (const auto& var_group : vars_) { + for (const auto& var : var_group) { + var.Impl(ss); + } + ss << "\n"; + } // // Additional Implementation diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index d49d76c1ee85..4bff31e9dd30 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -26,7 +26,7 @@ void ShaderVariable::Init() { ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); } -void ShaderVariable::Impl(std::ostringstream& ss) { +void ShaderVariable::Impl(std::ostringstream& ss) const { // Start generating code const std::string value_t = name_ + "_value_t"; diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 15d2259c34a9..fbdb6590a735 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -129,7 +129,7 @@ class ShaderVariable { ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariable); void Init(); - void Impl(std::ostringstream& ss); + void Impl(std::ostringstream& ss) const; std::string GetByOffsetImpl(const std::string& offset) const; std::string SetByOffsetImpl(const std::string& offset, const std::string& value) const; From 84494c4344027b15951f47246be941e7c72a3604 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:03:16 -0700 Subject: [PATCH 016/114] code formatting --- tools/ci_build/gen_def.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/gen_def.py b/tools/ci_build/gen_def.py index 765e9d135b7f..2b7790ec4e68 100755 --- a/tools/ci_build/gen_def.py +++ b/tools/ci_build/gen_def.py @@ -80,7 +80,7 @@ def parse_arguments(): "dnnl", "tensorrt", "azure", - "webgpu" + "webgpu", ): file.write(f"#include \n") file.write("void* GetFunctionEntryByName(const char* name){\n") From aa70163a7a0431402b46408570fb97b41558bb27 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:04:07 -0700 Subject: [PATCH 017/114] better format of Uniform --- onnxruntime/core/providers/webgpu/shader_helper.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index d3466b6d611a..3986b13e0a7d 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -137,7 +137,7 @@ std::string ShaderHelper::GetFinalSourceCode() { program_.UniformVariables().cend(), [](const ProgramUniformVariableValue& x) { return x.length > 0; })) { bool first = true; - ss << "struct Uniforms {\n"; + ss << "struct Uniforms {"; size_t uniform_count = program_.UniformVariables().size(); for (size_t i = 0; i < uniform_count; i++) { @@ -151,11 +151,11 @@ std::string ShaderHelper::GetFinalSourceCode() { if (first) { first = false; } else { - ss << ",\n"; + ss << ","; } auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : ""; - ss << " " << alignment << name << ": "; + ss << "\n " << alignment << name << ": "; if (length > 4) { if (data_type == ProgramUniformVariableDataType::Float16) { size_t array_size = (length + 7) / 8; @@ -171,7 +171,7 @@ std::string ShaderHelper::GetFinalSourceCode() { } } - ss << "};\n" + ss << "\n};\n" "@group(0) @binding(" << variable_count << ") var uniforms: Uniforms;\n"; } From d772db7ae7a2710a0f6ca6f9338186029dcb1e3c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:06:02 -0700 Subject: [PATCH 018/114] revise document --- .../core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md | 4 ++-- onnxruntime/core/providers/webgpu/README.md | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index 3c20130ae2ce..a7123ac4a580 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -146,9 +146,9 @@ to test, find the "onnx_test_runner.exe" in your build folder. run it like: onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` -> Assume C:\code\onnxruntime is the root of your onnxruntime repo +> Assume `C:\code\onnxruntime` is the root of your onnxruntime repo > -> if it does not exist, run the following in your onnxruntime repo root: +> if folder `C:\code\onnxruntime\js\test\data` does not exist, run the following in your onnxruntime repo root: > ``` > cd js > npm ci diff --git a/onnxruntime/core/providers/webgpu/README.md b/onnxruntime/core/providers/webgpu/README.md index 20864d360914..999f1fecbda7 100644 --- a/onnxruntime/core/providers/webgpu/README.md +++ b/onnxruntime/core/providers/webgpu/README.md @@ -4,11 +4,14 @@ This folder is for the WebGPU execution provider(WebGPU EP). Currently, WebGPU E ## Build WebGPU EP -Just append `--use_webgpu --skip_tests` to the `build.bat` command line. +Just append `--use_webgpu --skip_tests` to the `build.bat`/`build.sh` command line. NOTE: `--skip_tests` is required for now. All existing tests are for CPU EP anyway so no need to run them. -Currently only works on Windows. +For linux, a few dependencies need to be installed: +```sh +apt-get install libx11-dev libxrandr-dev libxinerama-dev libxcursor-dev libxi-dev libx11-dev libx11-xcb-dev +``` ## Troubleshooting From 6ef3dadfa5c4290922f9d3874858c071ed07f36c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 21:20:29 -0700 Subject: [PATCH 019/114] more build fix for linux --- .../core/providers/webgpu/buffer_manager.cc | 2 -- onnxruntime/core/providers/webgpu/program.h | 1 + .../core/providers/webgpu/program_cache_key.cc | 2 +- .../core/providers/webgpu/shader_variable.cc | 14 +++++++------- .../core/providers/webgpu/webgpu_context.cc | 2 +- .../providers/webgpu/webgpu_execution_provider.h | 7 +++++++ 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index d69b1210ade4..e1f065b65f13 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -4,8 +4,6 @@ #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/webgpu_context.h" -static int xx = 1; - namespace onnxruntime { namespace webgpu { diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 277c00e08901..812e44e014ee 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -220,6 +220,7 @@ class ProgramBase { template ProgramBase& CacheHint(CacheHintArgs&&... args) { cache_hint_ = absl::StrJoin(std::forward_as_tuple(std::forward(args)...), "|"); + return *this; } // set one or more program inputs diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index d720c55fb542..a4530910944d 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -47,7 +47,7 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp } } - ss << ":" D("DispatchDim=") << is_1d_dispatch ? "1" : "3"; + ss << ":" D("DispatchDim=") << (is_1d_dispatch ? "1" : "3"); ss << ":" D("UniformSizes="); bool first = true; for (const auto& uniform : program.UniformVariables()) { diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 4bff31e9dd30..9483ab19036c 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -72,7 +72,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS("fn o2i_", name_, "(offset : u32)->", indices_t, " {\n"); SS(" var indices: ", indices_t, ";\n"); SS(" var current = offset;\n"); - for (size_t i = 0; i < rank_ - 1; i++) { + for (int i = 0; i < rank_ - 1; i++) { auto current_stride = GetElementAt(stride, i, rank_); SS(" let dim", i, " = current / ", current_stride, ";\n"); SS(" let rest", i, " = current % ", current_stride, ";\n"); @@ -90,7 +90,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { if (rank_ >= 2) { SS("fn i2o_", name_, "(indices : ", indices_t, ")->u32 {\n"); SS(" return "); - for (size_t i = 0; i < rank_ - 1; i++) { + for (int i = 0; i < rank_ - 1; i++) { SS("indices[", i, "] * ", GetElementAt(stride, i, rank_), " + "); } SS("indices[", rank_ - 1, "];\n"); @@ -108,7 +108,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS(" return 0;\n"); } else { SS(" return "); - for (size_t i = 0; i < rank_ - 1; i++) { + for (int i = 0; i < rank_ - 1; i++) { auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); SS(IndicesGet(stride, i), " * (", idx, " % ", IndicesGet(shape, i), ") + "); } @@ -122,12 +122,12 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { if (usage_ & UseSet) { if (rank_ >= 2) { SS("fn set_", name_, "(d0: u32"); - for (size_t i = 1; i < rank_; i++) { + for (int i = 1; i < rank_; i++) { SS(", d", i, ": u32"); } SS(", value: ", value_t, ") {\n"); SS(" set_", name_, "_by_indices(d0"); - for (size_t i = 1; i < rank_; i++) { + for (int i = 1; i < rank_; i++) { SS(", d", i); } SS(", value);\n"); @@ -148,12 +148,12 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { if (usage_ & UseGet) { if (rank_ >= 2) { SS("fn get_", name_, "(d0: u32"); - for (size_t i = 1; i < rank_; i++) { + for (int i = 1; i < rank_; i++) { SS(", d", i, ": u32"); } SS(")->", value_t, " {\n"); SS(" return get_", name_, "_by_indices(d0"); - for (size_t i = 1; i < rank_; i++) { + for (int i = 1; i < rank_; i++) { SS(", d", i); } SS(");\n"); diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 5f09223b2271..049a729f5c98 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -93,7 +93,7 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info wgpu::RequestDeviceCallbackInfo req_device_callback_info = {}; req_device_callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; req_device_callback_info.callback = [](WGPURequestDeviceStatus status, WGPUDevice device, char const* message, void* userdata) { - ORT_ENFORCE(status == WGPURequestAdapterStatus_Success, "Failed to get a WebGPU device: ", message); + ORT_ENFORCE(status == WGPURequestDeviceStatus_Success, "Failed to get a WebGPU device: ", message); *static_cast(userdata) = wgpu::Device::Acquire(device); }; req_device_callback_info.userdata = &device_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 6fb2381637a6..4b2d2882b6ec 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -22,6 +22,13 @@ enum class BufferCacheMode; } // namespace webgpu struct WebGpuExecutionProviderInfo { + WebGpuExecutionProviderInfo(DataLayout data_layout1, bool enable_graph_capture1) + : data_layout{data_layout1} + , enable_graph_capture{enable_graph_capture1} + , storage_buffer_cache_mode{} + , uniform_buffer_cache_mode{} + , query_resolve_buffer_cache_mode{} + , default_buffer_cache_mode{} {} DataLayout data_layout; bool enable_graph_capture; webgpu::BufferCacheMode storage_buffer_cache_mode; From a56f6c3edae7ad7c15cffc4032cd13c7e4b2f452 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 21:23:36 -0700 Subject: [PATCH 020/114] apply formatter --- .../providers/webgpu/webgpu_execution_provider.h | 12 ++++++------ test_webgpu.bat | 12 ++++++++++++ test_webgpu_cases.txt | 1 + 3 files changed, 19 insertions(+), 6 deletions(-) create mode 100644 test_webgpu.bat create mode 100644 test_webgpu_cases.txt diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 4b2d2882b6ec..5f27fad14afc 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -23,12 +23,12 @@ enum class BufferCacheMode; struct WebGpuExecutionProviderInfo { WebGpuExecutionProviderInfo(DataLayout data_layout1, bool enable_graph_capture1) - : data_layout{data_layout1} - , enable_graph_capture{enable_graph_capture1} - , storage_buffer_cache_mode{} - , uniform_buffer_cache_mode{} - , query_resolve_buffer_cache_mode{} - , default_buffer_cache_mode{} {} + : data_layout{data_layout1}, + enable_graph_capture{enable_graph_capture1}, + storage_buffer_cache_mode{}, + uniform_buffer_cache_mode{}, + query_resolve_buffer_cache_mode{}, + default_buffer_cache_mode{} {} DataLayout data_layout; bool enable_graph_capture; webgpu::BufferCacheMode storage_buffer_cache_mode; diff --git a/test_webgpu.bat b/test_webgpu.bat new file mode 100644 index 000000000000..feec724c1a7d --- /dev/null +++ b/test_webgpu.bat @@ -0,0 +1,12 @@ +rem @echo off +:: if file js\test\data\node\__generated_onnx_node_tests not found, generate it +if not exist "%~dp0js\test\data\node\__generated_onnx_node_tests" ( + pushd "%~dp0js" + call npm ci + call npm run prepare-node-tests + popd +) + +for /F "tokens=*" %%A in (%~dp0test_webgpu_cases.txt) do ( + echo %%A +) diff --git a/test_webgpu_cases.txt b/test_webgpu_cases.txt new file mode 100644 index 000000000000..4cc29f5b13ed --- /dev/null +++ b/test_webgpu_cases.txt @@ -0,0 +1 @@ +test_abs From 12cd79d6742e8967e697fddf144fcd55dcf1c5cc Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 31 Aug 2024 01:38:11 -0700 Subject: [PATCH 021/114] simple test runner --- cmake/onnxruntime_unittests.cmake | 16 +++ .../webgpu/How_to_Write_WebGPU_EP_Kernel.md | 13 ++- .../test/providers/webgpu/test_webgpu.bat | 3 + .../test/providers/webgpu/test_webgpu.js | 98 +++++++++++++++++++ test_webgpu.bat | 12 --- test_webgpu_cases.txt | 1 - 6 files changed, 129 insertions(+), 14 deletions(-) create mode 100644 onnxruntime/test/providers/webgpu/test_webgpu.bat create mode 100644 onnxruntime/test/providers/webgpu/test_webgpu.js delete mode 100644 test_webgpu.bat delete mode 100644 test_webgpu_cases.txt diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index b050698a5570..6c43680ecc75 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1120,6 +1120,22 @@ if (NOT IOS) LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} BUNDLE DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + + ## TODO: remove this when merging to main branch + # + # should support better test runner + # + if (onnxruntime_USE_WEBGPU) + add_custom_command( + TARGET onnx_test_runner + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${ONNXRUNTIME_ROOT}/test/providers/webgpu/test_webgpu.js" + "${ONNXRUNTIME_ROOT}/test/providers/webgpu/test_webgpu.bat" + "$" + VERBATIM ) + endif() + endif() if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index a7123ac4a580..a27a7b3131bd 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -141,7 +141,18 @@ This section is WIP. use `build.bat --use_webgpu --skip_tests` to build the WebGPU EP. For Release build, append `--config Release` or `--config RelWithDebInfo` to the command line. -to test, find the "onnx_test_runner.exe" in your build folder. run it like: +to test, find the "test_webgpu.bat" in your build folder. run it for tests: +``` +# run all tests +test_webgpu.bat + +# run a specific test +test_webgpu.bat test_abs +``` + + + +to test or debug a single test, find the "onnx_test_runner.exe" in your build folder. run it like: ``` onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` diff --git a/onnxruntime/test/providers/webgpu/test_webgpu.bat b/onnxruntime/test/providers/webgpu/test_webgpu.bat new file mode 100644 index 000000000000..fad6569c2457 --- /dev/null +++ b/onnxruntime/test/providers/webgpu/test_webgpu.bat @@ -0,0 +1,3 @@ +@echo off + +node "%~dp0test_webgpu.js" %* diff --git a/onnxruntime/test/providers/webgpu/test_webgpu.js b/onnxruntime/test/providers/webgpu/test_webgpu.js new file mode 100644 index 000000000000..111f321ccbbd --- /dev/null +++ b/onnxruntime/test/providers/webgpu/test_webgpu.js @@ -0,0 +1,98 @@ +const HELP = ` + Call onnx_test_runner to test WebGPU EP. + + Usage: node test_webgpu.js [options] + + Options: + -h Print this help message. + -t= Path of the test data folder (eg. "../../../js/test/data/node") + -v Verbose mode. + -m= ';' separated list of test names (eg. test_abs) +`; + +const DEFAULT_TESTS = [ + 'test_abs', +]; + +const path = require('path'); +const fs = require('fs'); +const { spawnSync } = require('child_process'); + +const ONNX_TEST_RUNNER_FILENAME = path.join(__dirname, + 'onnx_test_runner' + (process.platform === 'win32' ? '.exe' : '')); + +if (process.argv.includes('-h')) { + console.log(HELP); + process.exit(0); +} + +const VERBOSE = process.argv.includes('-v'); +let test_data_path = process.argv.find(arg => arg.startsWith('-t=')); +if (!test_data_path) { + test_data_path = path.join(__dirname, (process.platform === 'win32' ? '../' : '') + '../../../js/test/data/node'); +} else { + test_data_path = test_data_path.substring(3); +} + +const test_models = []; +const test_model_list = process.argv.find(arg => arg.startsWith('-m=')); +if (test_model_list) { + test_model_list.substring(3).split(';').forEach(test_model => { + test_models.push(test_model); + }); +} +const tests = new Set(test_model_list ? test_models : DEFAULT_TESTS); +const test_cases = []; +fs.readdirSync(test_data_path, { withFileTypes: true }).forEach(dirent => { + if (dirent.isDirectory()) { + const opset = dirent.name; + fs.readdirSync(path.join(test_data_path, opset), { withFileTypes: true }).forEach(dirent => { + if (dirent.isDirectory()) { + const name = dirent.name; + if (tests.has(name)) { + test_cases.push(path.join(test_data_path, opset, name)); + } + } + }); + } +}); + +let passed = []; +let not_implemented = []; +let failed = []; +test_cases.forEach(test_case => { + process.stdout.write(`Running test case: "${test_case}"...`); + const args = [ + '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1"', test_case, + ]; + if (VERBOSE) { + args.unshift('-v'); + } + const p = spawnSync(ONNX_TEST_RUNNER_FILENAME, args, { shell: true, stdio: ['ignore', 'pipe', 'pipe'] }); + if (p.status !== 0) { + process.stdout.write('Failed\n'); + failed.push(test_case); + } else if (!p.stdout.toString().includes('Not implemented: 0')) { + process.stdout.write('Not Implemented\n'); + not_implemented.push(test_case); + } else { + process.stdout.write('OK\n'); + passed.push(test_case); + } +}); + +console.log(`\n${passed.length} tests passed.`); +console.log(`\n${not_implemented.length} tests not implemented:`); +not_implemented.slice(0, 3).forEach(test_case => { + console.log(` ${test_case}`); +}); +if (not_implemented.length > 3) { + console.log(` ...`); +} +console.log(`\n${failed.length} tests failed:`); +failed.slice(0, 3).forEach(test_case => { + console.log(` ${test_case}`); +}); +if (failed.length > 3) { + console.log(` ...`); +} diff --git a/test_webgpu.bat b/test_webgpu.bat deleted file mode 100644 index feec724c1a7d..000000000000 --- a/test_webgpu.bat +++ /dev/null @@ -1,12 +0,0 @@ -rem @echo off -:: if file js\test\data\node\__generated_onnx_node_tests not found, generate it -if not exist "%~dp0js\test\data\node\__generated_onnx_node_tests" ( - pushd "%~dp0js" - call npm ci - call npm run prepare-node-tests - popd -) - -for /F "tokens=*" %%A in (%~dp0test_webgpu_cases.txt) do ( - echo %%A -) diff --git a/test_webgpu_cases.txt b/test_webgpu_cases.txt deleted file mode 100644 index 4cc29f5b13ed..000000000000 --- a/test_webgpu_cases.txt +++ /dev/null @@ -1 +0,0 @@ -test_abs From 14c89661ae854f502ee7ef2f85b15915bd123af9 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 31 Aug 2024 16:38:53 -0700 Subject: [PATCH 022/114] Program macros update - allow extend --- .../webgpu/How_to_Write_WebGPU_EP_Kernel.md | 6 +- onnxruntime/core/providers/webgpu/program.h | 219 +++++++++++++----- 2 files changed, 159 insertions(+), 66 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index a27a7b3131bd..7ae7e2b37fc2 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -37,7 +37,7 @@ constants are declaration of values that are never changes in the shader code. T const A : u32 = 64; ``` -Use macro `WEBGPU_PROGRAM_DEFINE_CONSTANTS` to define constants in your Program class. +Use macro `WEBGPU_PROGRAM_DEFINE_CONSTANTS` to define constants in your Program class, or use `WEBGPU_PROGRAM_EXTEND_CONSTANTS` to extend the constants defined in the base class. #### **overridable constants** @@ -48,13 +48,13 @@ override B : u32 = 64; override C : f32; ``` -Use macro `WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS` to define overridable constants in your Program class. +Use macro `WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS` to define overridable constants in your Program class, or use `WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS` to extend the overridable constants defined in the base class. #### **uniform definitions** uniform definitions are declaration of uniform varables. Their names and type must be defined and cannot be changed. Their values(including length) can be set at runtime. -Use macro `WEBGPU_PROGRAM_DEFINE_UNIFORMS` to define uniform definitions in your Program class. +Use macro `WEBGPU_PROGRAM_DEFINE_UNIFORMS_VARIABLES` to define uniform definitions in your Program class, or use `WEBGPU_PROGRAM_EXTEND_UNIFORMS_VARIABLES` to extend the uniform definitions defined in the base class. ### 2.3. The Program class should override the `GenerateShaderCode` method: diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 812e44e014ee..d056ee8577f1 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -54,6 +54,9 @@ struct ProgramUniformVariableValue { // represents a uniform variable definition struct ProgramUniformVariableDefinition { + constexpr ProgramUniformVariableDefinition(std::string_view name, ProgramUniformVariableDataType data_type) + : name{name}, data_type{data_type} {} + std::string_view name; ProgramUniformVariableDataType data_type; }; @@ -337,27 +340,32 @@ class ProgramWrapper : public ProgramBase { #error "macro ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK is already defined" #endif -#define ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(identifier, element_type) \ - private: \ - template \ - static auto test_has_##identifier(int)->decltype(U::identifier, std::true_type{}); /* checks if member exists */ \ - template \ - static auto test_has_##identifier(...)->std::false_type; \ - \ - template && /* - is array */ \ - std::is_const_v && /* - has "const" modifier */ \ - std::is_convertible_v && /* - can convert to a const pointer */ \ - !std::is_member_pointer_v>> /* - is static */ \ - static auto test_has_##identifier##_with_correct_type(int)->std::true_type; \ - template \ - static auto test_has_##identifier##_with_correct_type(...)->std::false_type; \ - \ - public: \ - static constexpr bool has_##identifier = decltype(test_has_##identifier(0))::value; \ +#define ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(identifier, element_type) \ + private: \ + template \ + static auto test_has_##identifier(int)->decltype(U::identifier, std::true_type{}); /* checks if member exists */ \ + template \ + static auto test_has_##identifier(...)->std::false_type; \ + \ + template ::value && /* - is a const std::array */ \ + std::is_const_v && /* - has "const" modifier */ \ + !std::is_member_pointer_v>> /* - is static */ \ + static auto test_has_##identifier##_with_correct_type(int)->std::true_type; \ + template \ + static auto test_has_##identifier##_with_correct_type(...)->std::false_type; \ + \ + public: \ + static constexpr bool has_##identifier = decltype(test_has_##identifier(0))::value; \ static constexpr bool has_##identifier##_with_correct_type = decltype(test_has_##identifier##_with_correct_type(0))::value +// the following template class checks whether the type is a const std::array +template +struct is_const_std_array : std::false_type {}; +template +struct is_const_std_array> : std::true_type {}; + // the following template class checks whether certain static members exist in the derived class (SFINAE) template class DerivedProgramClassTypeCheck { @@ -367,52 +375,90 @@ class DerivedProgramClassTypeCheck { }; // compile-time tests for the type check +// +// TODO: move this to test folder namespace test { +template +class TestTypeCheck { + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(a, int); +}; + struct TestClass_Empty {}; -struct TestClass_0 { +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_0 { int b; }; -struct TestClass_1 { +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_1 { int a; }; -struct TestClass_2 { +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_2 { const int a; }; -struct TestClass_3 { +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_0 { const int a[2]; }; -struct TestClass_4 { +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_1 { static constexpr int a[] = {0}; }; -struct TestClass_5 { +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_2 { static int a[]; }; -struct TestClass_6 { +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_3 { static const int a[]; }; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); -template -class TestTypeCheck { - ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(a, int); +struct TestClass_StdArray_0 { + std::array a = {1}; }; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(!TestTypeCheck::has_a); -static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(!TestTypeCheck::has_a); -static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(TestTypeCheck::has_a); -static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(TestTypeCheck::has_a); -static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(TestTypeCheck::has_a); -static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(TestTypeCheck::has_a); -static_assert(TestTypeCheck::has_a_with_correct_type); -static_assert(TestTypeCheck::has_a); -static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(TestTypeCheck::has_a); -static_assert(TestTypeCheck::has_a_with_correct_type); +struct TestClass_StdArray_1 { + static constexpr std::array a = {1, 2}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_2 { + static const std::array a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_3 { + static constexpr const std::array a = {1, 2, 3, 4}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_4 { + static std::array a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); } // namespace test @@ -435,13 +481,12 @@ class Program : public detail::ProgramWrapper { virtual ProgramMetadata GetMetadata() const final { ProgramMetadata metadata; if constexpr (detail::DerivedProgramClassTypeCheck::has_constants) { - constexpr const ProgramConstant* ptr = T::constants; - constexpr size_t len = sizeof(T::constants) / sizeof(ProgramConstant); + constexpr const ProgramConstant* ptr = T::constants.data(); + constexpr size_t len = T::constants.size(); - static_assert(detail::DerivedProgramClassTypeCheck::has_constants_with_correct_type && - sizeof(T::constants) % sizeof(ProgramConstant) == 0, + static_assert(detail::DerivedProgramClassTypeCheck::has_constants_with_correct_type, "Derived class of \"Program\" has member \"constants\" but its type is incorrect. " - "Please use macro WEBGPU_PROGRAM_DEFINE_CONSTANTS() to declare constants."); + "Please use macro WEBGPU_PROGRAM_DEFINE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_CONSTANTS() to declare constants."); metadata.constants = {ptr, len}; } else { @@ -449,13 +494,12 @@ class Program : public detail::ProgramWrapper { } if constexpr (detail::DerivedProgramClassTypeCheck::has_overridable_constants) { - constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants; - constexpr size_t len = sizeof(T::overridable_constants) / sizeof(ProgramOverridableConstantDefinition); + constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants.data(); + constexpr size_t len = T::overridable_constants.size(); - static_assert(detail::DerivedProgramClassTypeCheck::has_overridable_constants_with_correct_type && - sizeof(T::overridable_constants) % sizeof(ProgramOverridableConstantDefinition) == 0, + static_assert(detail::DerivedProgramClassTypeCheck::has_overridable_constants_with_correct_type, "Derived class of \"Program\" has member \"overridable_constants\" but its type is incorrect. " - "Please use macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS() to declare overridable constants."); + "Please use macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS() to declare overridable constants."); metadata.overridable_constants = {ptr, len}; } else { @@ -463,13 +507,12 @@ class Program : public detail::ProgramWrapper { } if constexpr (detail::DerivedProgramClassTypeCheck::has_uniform_variables) { - constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables; - constexpr size_t len = sizeof(T::uniform_variables) / sizeof(ProgramUniformVariableDefinition); + constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables.data(); + constexpr size_t len = T::uniform_variables.size(); - static_assert(detail::DerivedProgramClassTypeCheck::has_uniform_variables_with_correct_type && - sizeof(T::uniform_variables) % sizeof(ProgramUniformVariableDefinition) == 0, + static_assert(detail::DerivedProgramClassTypeCheck::has_uniform_variables_with_correct_type, "Derived class of \"Program\" has member \"uniform_variables\" but its type is incorrect. " - "Please use macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES() to declare uniform variables."); + "Please use macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES() or WEBGPU_PROGRAM_EXTEND_UNIFORM_VARIABLES() to declare uniform variables."); metadata.uniform_variables = {ptr, len}; } else { @@ -480,14 +523,64 @@ class Program : public detail::ProgramWrapper { } }; +namespace detail { +// helper function to convert a C-style array to std::array +// +// This is basically the same as std::to_array in C++20. +// +template +constexpr auto _to_std_array_impl(T (&arr)[N], std::index_sequence) -> std::array, N> { + return {{arr[Idx]...}}; +} + +template +constexpr auto _to_std_array(T (&arr)[N]) -> std::array, N> { + return _to_std_array_impl(arr, std::make_index_sequence{}); +} + +// helper function to concatenate a std::array and a C-style array to a std::array +// +template +constexpr std::array, L + R> _concat2_impl(const std::array& lhs, + T (&rhs)[R], + std::index_sequence, + std::index_sequence) { + return {{lhs[IdxL]..., rhs[IdxR]...}}; +} + +template +constexpr std::array, L + R> _concat2(const std::array& lhs, T (&rhs)[R]) { + return _concat2_impl(lhs, rhs, std::make_index_sequence{}, std::make_index_sequence{}); +} + +} // namespace detail +#define WEBGPU_PROGRAM_DEFINE_(identifier, T, ...) \ + static constexpr const T identifier##_own[] = {__VA_ARGS__}; \ + static constexpr const auto identifier = \ + onnxruntime::webgpu::detail::_to_std_array(identifier##_own) + +#define WEBGPU_PROGRAM_EXTEND_(identifier, T, BASE, ...) \ + static constexpr const T identifier##_own[] = {__VA_ARGS__}; \ + static constexpr const auto identifier = \ + onnxruntime::webgpu::detail::_concat2(BASE::identifier, identifier##_own) + #define WEBGPU_PROGRAM_DEFINE_CONSTANTS(...) \ - static constexpr const onnxruntime::webgpu::ProgramConstant constants[] = {__VA_ARGS__} + WEBGPU_PROGRAM_DEFINE_(constants, onnxruntime::webgpu::ProgramConstant, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_CONSTANTS(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(constants, onnxruntime::webgpu::ProgramConstant, BASE, __VA_ARGS__) #define WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS(...) \ - static constexpr const onnxruntime::webgpu::ProgramOverridableConstantDefinition overridable_constants[] = {__VA_ARGS__} + WEBGPU_PROGRAM_DEFINE_(overridable_constants, onnxruntime::webgpu::ProgramOverridableConstantDefinition, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(overridable_constants, onnxruntime::webgpu::ProgramOverridableConstantDefinition, BASE, __VA_ARGS__) #define WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(...) \ - static constexpr const onnxruntime::webgpu::ProgramUniformVariableDefinition uniform_variables[] = {__VA_ARGS__} + WEBGPU_PROGRAM_DEFINE_(uniform_variables, onnxruntime::webgpu::ProgramUniformVariableDefinition, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_UNIFORM_VARIABLES(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(uniform_variables, onnxruntime::webgpu::ProgramUniformVariableDefinition, BASE, __VA_ARGS__) } // namespace webgpu } // namespace onnxruntime From 4fff35f99fe26df47c862882179eafe1695a961d Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 31 Aug 2024 21:42:12 -0700 Subject: [PATCH 023/114] fix BucketCacheManager --- onnxruntime/core/providers/webgpu/buffer_manager.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index e1f065b65f13..da544e1d1ed6 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -176,6 +176,8 @@ class BucketCacheManager : public IBufferCacheManager { wgpuBufferRelease(buffer); } } + + pending_buffers_.clear(); } protected: From 4fd8ad19327db52b9fa147ccb5bc5a0b8978acc1 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 31 Aug 2024 22:09:51 -0700 Subject: [PATCH 024/114] add a method to get logger from ComputeContext --- .../core/providers/webgpu/compute_context.cc | 20 ------------ .../core/providers/webgpu/compute_context.h | 31 ++++++++++++++----- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index 67c55f823d78..b7a1af5b26ef 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -13,25 +13,5 @@ ComputeContext::ComputeContext(OpKernelContext& kernel_context) kernel_context_{kernel_context} { } -const wgpu::AdapterInfo& ComputeContext::AdapterInfo() const { - return webgpu_context_.AdapterInfo(); -} - -const wgpu::Limits& ComputeContext::DeviceLimits() const { - return webgpu_context_.DeviceLimits(); -} - -int ComputeContext::InputCount() const { - return kernel_context_.InputCount(); -} - -int ComputeContext::OutputCount() const { - return kernel_context_.OutputCount(); -} - -Status ComputeContext::RunProgram(const ProgramBase& program) { - return webgpu_context_.Run(*this, program); -} - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 9c352d3d76dd..ab090956b4d4 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -34,34 +34,49 @@ class ComputeContext { // Get various information from the context. // - const wgpu::AdapterInfo& AdapterInfo() const; - const wgpu::Limits& DeviceLimits() const; + inline const wgpu::AdapterInfo& AdapterInfo() const { + return webgpu_context_.AdapterInfo(); + } + inline const wgpu::Limits& DeviceLimits() const { + return webgpu_context_.DeviceLimits(); + } + + // + // Get the logger + // + inline const logging::Logger& Logger() const { + return kernel_context_.Logger(); + } // // Get input tensor. // template - const T* Input(int index) const { + inline const T* Input(int index) const { return kernel_context_.Input(index); } // // Get input count. // - int InputCount() const; + inline int InputCount() const { + return kernel_context_.InputCount(); + } // // Set output tensor. // template - Tensor* Output(int index, TensorShapeType&& shape) { + inline Tensor* Output(int index, TensorShapeType&& shape) { return kernel_context_.Output(index, std::forward(shape)); } // // Get output count. // - int OutputCount() const; + inline int OutputCount() const { + return kernel_context_.OutputCount(); + } // // Create CPU tensor. @@ -86,7 +101,9 @@ class ComputeContext { // // Run a compute shader program. // - Status RunProgram(const ProgramBase& program); + inline Status RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); + } protected: WebGpuContext& webgpu_context_; From 3bd92adcf6ea537477b6b65171a14e2429c1443e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 31 Aug 2024 22:33:25 -0700 Subject: [PATCH 025/114] add verbose log for cache key --- onnxruntime/core/providers/webgpu/compute_context.h | 1 + onnxruntime/core/providers/webgpu/webgpu_context.cc | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index ab090956b4d4..132f629ac745 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -14,6 +14,7 @@ #include "core/framework/execution_provider.h" #include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_context.h" #include "core/framework/op_kernel.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 049a729f5c98..9e51cc08eec0 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -6,6 +6,7 @@ #include "core/common/common.h" +#include "core/providers/webgpu/compute_context.h" #include "core/providers/webgpu/webgpu_context.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/webgpu_execution_provider.h" @@ -124,7 +125,7 @@ Status WebGpuContext::Wait(wgpu::Future f) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); } -Status WebGpuContext::Run(const ComputeContext& /*context*/, const ProgramBase& program) { +Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& program) { const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); @@ -200,6 +201,8 @@ Status WebGpuContext::Run(const ComputeContext& /*context*/, const ProgramBase& auto key = CalculateProgramCacheKey(program, is_1d_dispatch); + LOGS(context.Logger(), INFO) << "Starting program \"" << key << "\" (" << x << ", " << y << ", " << z << ")"; + const auto* program_artifact = program_mgr_->Get(key); if (program_artifact == nullptr) { wgpu::ComputePipeline compute_pipeline; From 6a1bbfe907cf1f5a2b494edde9ebbdcbfe1795d9 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 1 Sep 2024 01:02:31 -0700 Subject: [PATCH 026/114] revise suite test --- .../webgpu/How_to_Write_WebGPU_EP_Kernel.md | 54 +- .../test/providers/webgpu/test_webgpu.js | 1138 ++++++++++++++++- 2 files changed, 1130 insertions(+), 62 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index 7ae7e2b37fc2..624cfd80dd8f 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -119,6 +119,7 @@ Status ComputeInternal(ComputeContext& context) const override; ``` Usually, in the implementation, we do 3 things: + - Create a local variable of the Program class. - Set a few runtime info of the Program instance. - Call `context.RunProgram(program)` to run the program and return the status. @@ -130,6 +131,7 @@ Complicated operators may do more things. Check header files and existing implem Register the operator just like any EP does. Check existing implementations for more details. Please note that registration is composed of 2 parts: + - Use macros like `ONNX_OPERATOR_KERNEL_EX` or `ONNX_OPERATOR_VERSIONED_KERNEL_EX` (or wrap a new macro as what we usually do) to register the operator in kernel source code file. - Add the operator to onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -139,29 +141,59 @@ This section is WIP. ## 6. Build and test +### Build + use `build.bat --use_webgpu --skip_tests` to build the WebGPU EP. For Release build, append `--config Release` or `--config RelWithDebInfo` to the command line. -to test, find the "test_webgpu.bat" in your build folder. run it for tests: +### Prepare test data + +Assume `C:\code\onnxruntime` is the root of your onnxruntime repo in all documents below. + +if folder `C:\code\onnxruntime\js\test\data` does not exist, run the following in your onnxruntime repo root: + +``` +cd js +npm ci +npm run prepare-node-tests +``` + +### Run Suite test (temporary: this may change recently) + +to do suite test, find the "test_webgpu.bat" in your build folder (It's usually in `build\Windows\Debug\Debug`). run it for tests: + ``` # run all tests test_webgpu.bat -# run a specific test -test_webgpu.bat test_abs +# run a test list from args +test_webgpu.bat -m=test_abs;test_cos ``` +To add more tests to the suite list, edit the file at `C:\code\onnxruntime\onnxruntime\test\providers\webgpu\test_webgpu.js`. After editing, run build again otherwise this file will not be copied to the build folder. +> How does it work? +> +> The `test_webgpu.bat` calls `test_webgpu.js` with nodejs. +> +> The `test_webgpu.js` use the test list (either the suite list or from cmd args) to prepare a temporary folder and creates symbolic links to the test data folder (under `C:\code\onnxruntime\js\test\data`). Then it runs `onnx_test_runner` on the temporary folder. + +### Run single test / debug to test or debug a single test, find the "onnx_test_runner.exe" in your build folder. run it like: + ``` onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` -> Assume `C:\code\onnxruntime` is the root of your onnxruntime repo -> -> if folder `C:\code\onnxruntime\js\test\data` does not exist, run the following in your onnxruntime repo root: -> ``` -> cd js -> npm ci -> npm run prepare-node-tests -> ``` +The `-C` flag is split by space for each key-value pair. Each key-value pair is separated by `|`. The key is the option name and the value is the option value. + +Some features are useful but if you are troubleshooting and want to rule out the cause, you can: + +- set `storageBufferCacheMode` to `disabled` to disable the storage buffer cache. +- set `-M` and `-A` to disable memory pattern and memory arena. +- set `-j 1` to disable parallel execution (if you have multiple models to test). + +Example: +``` +onnx_test_runner.exe -v -A -M -j 1 -e webgpu -C "session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled" C:\code\onnxruntime\js\test\data\node\opset17\test_abs +``` diff --git a/onnxruntime/test/providers/webgpu/test_webgpu.js b/onnxruntime/test/providers/webgpu/test_webgpu.js index 111f321ccbbd..254bded19ae7 100644 --- a/onnxruntime/test/providers/webgpu/test_webgpu.js +++ b/onnxruntime/test/providers/webgpu/test_webgpu.js @@ -11,7 +11,1047 @@ const HELP = ` `; const DEFAULT_TESTS = [ - 'test_abs', + "test_abs", + "test_acos_example", + "test_acos", + "test_acosh_example", + "test_acosh", + // // "test_adagrad_multiple", + // // "test_adagrad", + // // "test_adam_multiple", + // // "test_adam", + "test_add_bcast", + // "test_add_uint8", + "test_add", + // "test_and_bcast3v1d", + // "test_and_bcast3v2d", + // "test_and_bcast4v2d", + // "test_and_bcast4v3d", + // "test_and_bcast4v4d", + // "test_and2d", + // "test_and3d", + // "test_and4d", + "test_argmax_default_axis_example_select_last_index", + "test_argmax_default_axis_example", + "test_argmax_default_axis_random_select_last_index", + "test_argmax_default_axis_random", + "test_argmax_keepdims_example_select_last_index", + "test_argmax_keepdims_example", + "test_argmax_keepdims_random_select_last_index", + "test_argmax_keepdims_random", + "test_argmax_negative_axis_keepdims_example_select_last_index", + "test_argmax_negative_axis_keepdims_example", + "test_argmax_negative_axis_keepdims_random_select_last_index", + "test_argmax_negative_axis_keepdims_random", + "test_argmax_no_keepdims_example_select_last_index", + "test_argmax_no_keepdims_example", + "test_argmax_no_keepdims_random_select_last_index", + "test_argmax_no_keepdims_random", + "test_argmin_default_axis_example_select_last_index", + "test_argmin_default_axis_example", + "test_argmin_default_axis_random_select_last_index", + "test_argmin_default_axis_random", + "test_argmin_keepdims_example_select_last_index", + "test_argmin_keepdims_example", + "test_argmin_keepdims_random_select_last_index", + "test_argmin_keepdims_random", + "test_argmin_negative_axis_keepdims_example_select_last_index", + "test_argmin_negative_axis_keepdims_example", + "test_argmin_negative_axis_keepdims_random_select_last_index", + "test_argmin_negative_axis_keepdims_random", + "test_argmin_no_keepdims_example_select_last_index", + "test_argmin_no_keepdims_example", + "test_argmin_no_keepdims_random_select_last_index", + "test_argmin_no_keepdims_random", + "test_asin_example", + "test_asin", + "test_asinh_example", + "test_asinh", + "test_atan_example", + "test_atan", + "test_atanh_example", + "test_atanh", + // "test_averagepool_1d_default", + // "test_averagepool_2d_ceil", + "test_averagepool_2d_default", + "test_averagepool_2d_pads_count_include_pad", + "test_averagepool_2d_pads", + "test_averagepool_2d_precomputed_pads_count_include_pad", + "test_averagepool_2d_precomputed_pads", + "test_averagepool_2d_precomputed_same_upper", + "test_averagepool_2d_precomputed_strides", + "test_averagepool_2d_same_lower", + "test_averagepool_2d_same_upper", + "test_averagepool_2d_strides", + // "test_averagepool_3d_default", + "test_basic_conv_with_padding", + "test_basic_conv_without_padding", + // "test_basic_convinteger", + // "test_batchnorm_epsilon_training_mode", + "test_batchnorm_epsilon", + // "test_batchnorm_example_training_mode", + "test_batchnorm_example", + // // "test_bernoulli_double_expanded", + // // "test_bernoulli_double", + // // "test_bernoulli_expanded", + // // "test_bernoulli_seed_expanded", + // // "test_bernoulli_seed", + // // "test_bernoulli", + // // "test_bitshift_left_uint16", + // // "test_bitshift_left_uint32", + // // "test_bitshift_left_uint64", + // // "test_bitshift_left_uint8", + // // "test_bitshift_right_uint16", + // // "test_bitshift_right_uint32", + // // "test_bitshift_right_uint64", + // // "test_bitshift_right_uint8", + // // "test_blackmanwindow_expanded", + // // "test_blackmanwindow_symmetric_expanded", + // // "test_blackmanwindow_symmetric", + // // "test_blackmanwindow", + // // "test_cast_BFLOAT16_to_FLOAT", + // // "test_cast_DOUBLE_to_FLOAT", + // // "test_cast_DOUBLE_to_FLOAT16", + // // "test_cast_FLOAT_to_BFLOAT16", + // // "test_cast_FLOAT_to_DOUBLE", + // // "test_cast_FLOAT_to_FLOAT16", + // // "test_cast_FLOAT_to_STRING", + // // "test_cast_FLOAT16_to_DOUBLE", + // // "test_cast_FLOAT16_to_FLOAT", + // // "test_cast_STRING_to_FLOAT", + // // "test_castlike_BFLOAT16_to_FLOAT_expanded", + // // "test_castlike_BFLOAT16_to_FLOAT", + // // "test_castlike_DOUBLE_to_FLOAT_expanded", + // // "test_castlike_DOUBLE_to_FLOAT", + // // "test_castlike_DOUBLE_to_FLOAT16_expanded", + // // "test_castlike_DOUBLE_to_FLOAT16", + // // "test_castlike_FLOAT_to_BFLOAT16_expanded", + // // "test_castlike_FLOAT_to_BFLOAT16", + // // "test_castlike_FLOAT_to_DOUBLE_expanded", + // // "test_castlike_FLOAT_to_DOUBLE", + // // "test_castlike_FLOAT_to_FLOAT16_expanded", + // // "test_castlike_FLOAT_to_FLOAT16", + // // "test_castlike_FLOAT_to_STRING_expanded", + // // "test_castlike_FLOAT_to_STRING", + // // "test_castlike_FLOAT16_to_DOUBLE_expanded", + // // "test_castlike_FLOAT16_to_DOUBLE", + // // "test_castlike_FLOAT16_to_FLOAT_expanded", + // // "test_castlike_FLOAT16_to_FLOAT", + // // "test_castlike_STRING_to_FLOAT_expanded", + // // "test_castlike_STRING_to_FLOAT", + "test_ceil_example", + "test_ceil", + // "test_celu_expanded", + // "test_celu", + // "test_clip_default_inbounds", + // "test_clip_default_int8_inbounds", + // "test_clip_default_int8_max", + // "test_clip_default_int8_min", + // "test_clip_default_max", + // "test_clip_default_min", + // "test_clip_example", + // "test_clip_inbounds", + // "test_clip_outbounds", + // "test_clip_splitbounds", + // "test_clip", + // // "test_compress_0", + // // "test_compress_1", + // // "test_compress_default_axis", + // // "test_compress_negative_axis", + "test_concat_1d_axis_0", + "test_concat_1d_axis_negative_1", + "test_concat_2d_axis_0", + "test_concat_2d_axis_1", + "test_concat_2d_axis_negative_1", + "test_concat_2d_axis_negative_2", + "test_concat_3d_axis_0", + "test_concat_3d_axis_1", + "test_concat_3d_axis_2", + "test_concat_3d_axis_negative_1", + "test_concat_3d_axis_negative_2", + "test_concat_3d_axis_negative_3", + "test_conv_with_autopad_same", + "test_conv_with_strides_and_asymmetric_padding", + "test_conv_with_strides_no_padding", + "test_conv_with_strides_padding", + // // "test_convinteger_with_padding", + // // "test_convinteger_without_padding", + "test_convtranspose_1d", + // // "test_convtranspose_3d", + "test_convtranspose_autopad_same", + "test_convtranspose_dilations", + "test_convtranspose_kernel_shape", + "opset{9,17}/test_convtranspose_output_shape", + "test_convtranspose_pad", + "test_convtranspose_pads", + "test_convtranspose_with_kernel", + "test_convtranspose", + "test_cos_example", + "test_cos", + "test_cosh_example", + "test_cosh", + // "test_cumsum_1d_exclusive", + // "test_cumsum_1d_reverse_exclusive", + // "test_cumsum_1d_reverse", + // "test_cumsum_1d", + // "test_cumsum_2d_axis_0", + // "test_cumsum_2d_axis_1", + // "test_cumsum_2d_negative_axis", + "test_depthtospace_crd_mode_example", + "test_depthtospace_crd_mode", + "test_depthtospace_dcr_mode", + "test_depthtospace_example", + "test_depthtospace", + "test_dequantizelinear_axis", + "test_dequantizelinear", + // // "test_det_2d", + // // "test_det_nd", + // // "test_dft_axis", + // // "test_dft_inverse", + // // "test_dft", + "test_div_bcast", + "test_div_example", + // "test_div_uint8", + "test_div", + // // "test_dropout_default_mask_ratio", + // // "test_dropout_default_mask", + // // "test_dropout_default_old", + // // "test_dropout_default_ratio", + // // "test_dropout_default", + // // "test_dropout_random_old", + // // "test_dropout_random", + // // "test_dynamic_slice_default_axes", + // // "test_dynamic_slice_end_out_of_bounds", + // // "test_dynamic_slice_neg", + // // "test_dynamic_slice_start_out_of_bounds", + // // "test_dynamic_slice", + // // "test_dynamicquantizelinear_expanded", + // // "test_dynamicquantizelinear_max_adjusted_expanded", + // // "test_dynamicquantizelinear_max_adjusted", + // // "test_dynamicquantizelinear_min_adjusted_expanded", + // // "test_dynamicquantizelinear_min_adjusted", + // // "test_dynamicquantizelinear", + "test_edge_pad", + // "test_einsum_batch_diagonal", + // "test_einsum_batch_matmul", + // "test_einsum_inner_prod", + // "test_einsum_sum", + // "test_einsum_transpose", + "test_elu_default", + "test_elu_example", + "test_elu", + // "test_equal_bcast", + // "test_equal", + "test_erf", + "test_exp_example", + "test_exp", + "test_expand_dim_changed", + "test_expand_dim_unchanged", + // "test_eyelike_populate_off_main_diagonal", + // "test_eyelike_with_dtype", + // "test_eyelike_without_dtype", + "test_flatten_axis0", + "test_flatten_axis1", + "test_flatten_axis2", + "test_flatten_axis3", + "test_flatten_default_axis", + "test_flatten_negative_axis1", + "test_flatten_negative_axis2", + "test_flatten_negative_axis3", + "test_flatten_negative_axis4", + "test_floor_example", + "test_floor", + "test_gather_0", + "test_gather_1", + "test_gather_2d_indices", + "test_gather_negative_indices", + "test_gather_elements_0", + "test_gather_elements_1", + "test_gather_elements_negative_indices", + // "test_gather_negative_indices", + // // "test_gathernd_example_float32", + // // "test_gathernd_example_int32_batch_dim1", + // // "test_gathernd_example_int32", + "test_gemm_all_attributes", + "test_gemm_alpha", + "test_gemm_beta", + "test_gemm_broadcast", + "test_gemm_default_matrix_bias", + "test_gemm_default_no_bias", + // "test_gemm_default_scalar_bias", + "test_gemm_default_single_elem_vector_bias", + "test_gemm_default_vector_bias", + "test_gemm_default_zero_bias", + "test_gemm_nobroadcast", + "test_gemm_transposeA", + "test_gemm_transposeB", + "test_globalaveragepool_precomputed", + "test_globalaveragepool", + "test_globalmaxpool_precomputed", + "test_globalmaxpool", + "test_greater_bcast", + "test_greater_equal_bcast_expanded", + "test_greater_equal_bcast", + "test_greater_equal_expanded", + "test_greater_equal", + "test_greater", + // // "test_gridsample_aligncorners_true", + // // "test_gridsample_bicubic", + // // "test_gridsample_bilinear", + // // "test_gridsample_border_padding", + // // "test_gridsample_nearest", + // // "test_gridsample_reflection_padding", + // // "test_gridsample_zeros_padding", + // // "test_gridsample", + // // "test_gru_batchwise", + // // "test_gru_defaults", + // // "test_gru_seq_length", + // // "test_gru_with_initial_bias", + // // "test_hammingwindow_expanded", + // // "test_hammingwindow_symmetric_expanded", + // // "test_hammingwindow_symmetric", + // // "test_hammingwindow", + // // "test_hannwindow_expanded", + // // "test_hannwindow_symmetric_expanded", + // // "test_hannwindow_symmetric", + // // "test_hannwindow", + // // "test_hardmax_axis_0", + // // "test_hardmax_axis_1", + // // "test_hardmax_axis_2", + // // "test_hardmax_default_axis", + // // "test_hardmax_example", + // // "test_hardmax_negative_axis", + // // "test_hardmax_one_hot", + "test_hardsigmoid_default", + "test_hardsigmoid_example", + "test_hardsigmoid", + // // "test_hardswish_expanded", + // // "test_hardswish", + "test_if", + // TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra + // supports Sequence and Optional types + // "test_if_seq", + // "test_if_opt", + "test_instancenorm_epsilon", + "test_instancenorm_example", + // "test_isinf_negative", + // "test_isinf_positive", + // "test_isinf", + // "test_isnan", + "test_layer_normalization_2d_axis_negative_1_expanded", + "test_layer_normalization_2d_axis_negative_1", + "test_layer_normalization_2d_axis_negative_2_expanded", + "test_layer_normalization_2d_axis_negative_2", + "test_layer_normalization_2d_axis0_expanded", + "test_layer_normalization_2d_axis0", + "test_layer_normalization_2d_axis1_expanded", + "test_layer_normalization_2d_axis1", + // // "test_layer_normalization_3d_axis_negative_1_epsilon_expanded", + "test_layer_normalization_3d_axis_negative_1_epsilon", + // // "test_layer_normalization_3d_axis_negative_2_epsilon_expanded", + "test_layer_normalization_3d_axis_negative_2_epsilon", + // // "test_layer_normalization_3d_axis_negative_3_epsilon_expanded", + "test_layer_normalization_3d_axis_negative_3_epsilon", + // // "test_layer_normalization_3d_axis0_epsilon_expanded", + "test_layer_normalization_3d_axis0_epsilon", + "test_layer_normalization_3d_axis1_epsilon_expanded", + "test_layer_normalization_3d_axis1_epsilon", + // // "test_layer_normalization_3d_axis2_epsilon_expanded", + "test_layer_normalization_3d_axis2_epsilon", + "test_layer_normalization_4d_axis_negative_1_expanded", + "test_layer_normalization_4d_axis_negative_1", + // // "test_layer_normalization_4d_axis_negative_2_expanded", + "test_layer_normalization_4d_axis_negative_2", + // "test_layer_normalization_4d_axis_negative_3_expanded", + "test_layer_normalization_4d_axis_negative_3", + // "test_layer_normalization_4d_axis_negative_4_expanded", + "test_layer_normalization_4d_axis_negative_4", + "test_layer_normalization_4d_axis0_expanded", + "test_layer_normalization_4d_axis0", + "test_layer_normalization_4d_axis1_expanded", + "test_layer_normalization_4d_axis1", + // // "test_layer_normalization_4d_axis2_expanded", + "test_layer_normalization_4d_axis2", + "test_layer_normalization_4d_axis3_expanded", + "test_layer_normalization_4d_axis3", + "test_layer_normalization_default_axis_expanded", + "test_layer_normalization_default_axis", + "test_leakyrelu_default", + "test_leakyrelu_example", + "test_leakyrelu", + "test_less_bcast", + "test_less_equal_bcast_expanded", + "test_less_equal_bcast", + "test_less_equal_expanded", + "test_less_equal", + "test_less", + "test_log_example", + "test_log", + // // "test_logsoftmax_axis_0_expanded", + // // "test_logsoftmax_axis_0", + // // "test_logsoftmax_axis_1_expanded", + // // "test_logsoftmax_axis_1", + // // "test_logsoftmax_axis_2_expanded", + // // "test_logsoftmax_axis_2", + // // "test_logsoftmax_default_axis_expanded", + // // "test_logsoftmax_default_axis", + // // "test_logsoftmax_example_1_expanded", + // // "test_logsoftmax_example_1", + // // "test_logsoftmax_large_number_expanded", + // // "test_logsoftmax_large_number", + // // "test_logsoftmax_negative_axis_expanded", + // // "test_logsoftmax_negative_axis", + // "test_lrn_default", + // "test_lrn", + // // "test_lstm_batchwise", + // // "test_lstm_defaults", + // // "test_lstm_with_initial_bias", + // // "test_lstm_with_peepholes", + "test_matmul_2d", + "test_matmul_3d", + "test_matmul_4d", + // // "test_matmulinteger", + // "test_max_example", + // "test_max_float16", + // "test_max_float32", + // "test_max_float64", + // "test_max_int16", + // "test_max_int32", + // "test_max_int64", + // "test_max_int8", + // "test_max_one_input", + // "test_max_two_inputs", + // "test_max_uint16", + // "test_max_uint32", + // "test_max_uint64", + // "test_max_uint8", + // "test_maxpool_1d_default", + // "test_maxpool_2d_ceil", + "test_maxpool_2d_default", + // "test_maxpool_2d_dilations", + "test_maxpool_2d_pads", + "test_maxpool_2d_precomputed_pads", + "test_maxpool_2d_precomputed_same_upper", + "test_maxpool_2d_precomputed_strides", + "test_maxpool_2d_same_lower", + "test_maxpool_2d_same_upper", + "test_maxpool_2d_strides", + // "test_maxpool_2d_uint8", + // "test_maxpool_3d_default", + // "test_maxpool_with_argmax_2d_precomputed_pads", + // "test_maxpool_with_argmax_2d_precomputed_strides", + // // "test_maxunpool_export_with_output_shape", + // // "test_maxunpool_export_without_output_shape", + // // "test_mean_example", + // // "test_mean_one_input", + // // "test_mean_two_inputs", + // // "test_melweightmatrix", + // "test_min_example", + // "test_min_float16", + // "test_min_float32", + // "test_min_float64", + // "test_min_int16", + // "test_min_int32", + // "test_min_int64", + // "test_min_int8", + // "test_min_one_input", + // "test_min_two_inputs", + // "test_min_uint16", + // "test_min_uint32", + // "test_min_uint64", + // "test_min_uint8", + // "test_mod_bcast", + // "test_mod_broadcast", + // "test_mod_float_mixed_sign_example", + // "test_mod_fmod_mixed_sign_example", + // "test_mod_int64_fmod", + // "test_mod_int64_mixed_sign_example", + // "test_mod_mixed_sign_float16", + // "test_mod_mixed_sign_float32", + // "test_mod_mixed_sign_float64", + // "test_mod_mixed_sign_int16", + // "test_mod_mixed_sign_int32", + // "test_mod_mixed_sign_int64", + // "test_mod_mixed_sign_int8", + // "test_mod_uint16", + // "test_mod_uint32", + // "test_mod_uint64", + // "test_mod_uint8", + // // "test_momentum_multiple", + // // "test_momentum", + "test_mul_bcast", + "test_mul_example", + // "test_mul_uint8", + "test_mul", + // "test_mvn_expanded", + // "test_mvn", + "test_neg_example", + "test_neg", + // // "test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NC_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NC", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_mean_weight_negative_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_weight", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_no_weight_reduction_mean_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_no_weight_reduction_mean_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_mean_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_mean", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_sum_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_sum", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_mean_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_mean", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight", + // // "test_nesterov_momentum", + // // "test_nllloss_NC_expanded", + // // "test_nllloss_NC", + // // "test_nllloss_NCd1_expanded", + // // "test_nllloss_NCd1_ii_expanded", + // // "test_nllloss_NCd1_ii", + // // "test_nllloss_NCd1_mean_weight_negative_ii_expanded", + // // "test_nllloss_NCd1_mean_weight_negative_ii", + // // "test_nllloss_NCd1_weight_expanded", + // // "test_nllloss_NCd1_weight_ii_expanded", + // // "test_nllloss_NCd1_weight_ii", + // // "test_nllloss_NCd1_weight", + // // "test_nllloss_NCd1", + // // "test_nllloss_NCd1d2_expanded", + // // "test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded", + // // "test_nllloss_NCd1d2_no_weight_reduction_mean_ii", + // // "test_nllloss_NCd1d2_reduction_mean_expanded", + // // "test_nllloss_NCd1d2_reduction_mean", + // // "test_nllloss_NCd1d2_reduction_sum_expanded", + // // "test_nllloss_NCd1d2_reduction_sum", + // // "test_nllloss_NCd1d2_with_weight_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_mean_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_mean", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum_ii", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum", + // // "test_nllloss_NCd1d2_with_weight", + // // "test_nllloss_NCd1d2", + // // "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded", + // // "test_nllloss_NCd1d2d3_none_no_weight_negative_ii", + // // "test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded", + // // "test_nllloss_NCd1d2d3_sum_weight_high_ii", + // // "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded", + // // "test_nllloss_NCd1d2d3d4d5_mean_weight", + // // "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded", + // // "test_nllloss_NCd1d2d3d4d5_none_no_weight", + // "test_nonmaxsuppression_center_point_box_format", + // "test_nonmaxsuppression_flipped_coordinates", + // "test_nonmaxsuppression_identical_boxes", + // "test_nonmaxsuppression_limit_output_size", + // "test_nonmaxsuppression_single_box", + // "test_nonmaxsuppression_suppress_by_IOU_and_scores", + // "test_nonmaxsuppression_suppress_by_IOU", + // "test_nonmaxsuppression_two_batches", + // "test_nonmaxsuppression_two_classes", + // "test_nonzero_example", + "test_not_2d", + "test_not_3d", + "test_not_4d", + // // "test_onehot_negative_indices", + // // "test_onehot_with_axis", + // // "test_onehot_with_negative_axis", + // // "test_onehot_without_axis", + // // "test_optional_get_element_sequence", + // // "test_optional_get_element", + // // "test_optional_has_element_empty", + // // "test_optional_has_element", + // "test_or_bcast3v1d", + // "test_or_bcast3v2d", + // "test_or_bcast4v2d", + // "test_or_bcast4v3d", + // "test_or_bcast4v4d", + // "test_or2d", + // "test_or3d", + // "test_or4d", + "test_pow_bcast_array", + "test_pow_bcast_scalar", + "test_pow_example", + // "test_pow_types_float", + // "test_pow_types_float32_int32", + // "test_pow_types_float32_int64", + // "test_pow_types_float32_uint32", + // "test_pow_types_float32_uint64", + // "test_pow_types_int", + // "test_pow_types_int32_float32", + // "test_pow_types_int32_int32", + // "test_pow_types_int64_float32", + // "test_pow_types_int64_int64", + "test_pow", + // "test_prelu_broadcast", + // "test_prelu_example", + // // "test_qlinearconv", + // // "test_qlinearmatmul_2D", + // // "test_qlinearmatmul_3D", + // // "test_quantizelinear_axis", + // // "test_quantizelinear", + "test_range_float_type_positive_delta_expanded", + "test_range_float_type_positive_delta", + "test_range_int32_type_negative_delta_expanded", + "test_range_int32_type_negative_delta", + "test_reciprocal_example", + "test_reciprocal", + "test_reduce_l1_default_axes_keepdims_example", + "test_reduce_l1_default_axes_keepdims_random", + "test_reduce_l1_do_not_keepdims_example", + "test_reduce_l1_do_not_keepdims_random", + "test_reduce_l1_keep_dims_example", + "test_reduce_l1_keep_dims_random", + "test_reduce_l1_negative_axes_keep_dims_example", + "test_reduce_l1_negative_axes_keep_dims_random", + "test_reduce_l2_default_axes_keepdims_example", + "test_reduce_l2_default_axes_keepdims_random", + "test_reduce_l2_do_not_keepdims_example", + "test_reduce_l2_do_not_keepdims_random", + "test_reduce_l2_keep_dims_example", + "test_reduce_l2_keep_dims_random", + "test_reduce_l2_negative_axes_keep_dims_example", + "test_reduce_l2_negative_axes_keep_dims_random", + "test_reduce_log_sum_asc_axes", + "test_reduce_log_sum_default", + "test_reduce_log_sum_desc_axes", + // tests "test_reduce_log_sum_exp_*" on opset17/opset18 are excluded because they use float64. + "opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_example", + "opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_random", + "opset{7,8,9}/test_reduce_log_sum_exp_do_not_keepdims_example", + "opset{7,8,9}/test_reduce_log_sum_exp_do_not_keepdims_random", + "opset{7,8,9}/test_reduce_log_sum_exp_keepdims_example", + "opset{7,8,9}/test_reduce_log_sum_exp_keepdims_random", + "opset11/test_reduce_log_sum_exp_negative_axes_keepdims_example", + "opset11/test_reduce_log_sum_exp_negative_axes_keepdims_random", + "test_reduce_log_sum_negative_axes", + "test_reduce_log_sum", + "test_reduce_max_default_axes_keepdim_example", + "test_reduce_max_default_axes_keepdims_random", + "test_reduce_max_do_not_keepdims_example", + "test_reduce_max_do_not_keepdims_random", + "test_reduce_max_keepdims_example", + "test_reduce_max_keepdims_random", + "test_reduce_max_negative_axes_keepdims_example", + "test_reduce_max_negative_axes_keepdims_random", + "test_reduce_mean_default_axes_keepdims_example", + "test_reduce_mean_default_axes_keepdims_random", + "test_reduce_mean_do_not_keepdims_example", + "test_reduce_mean_do_not_keepdims_random", + "test_reduce_mean_keepdims_example", + "test_reduce_mean_keepdims_random", + "test_reduce_mean_negative_axes_keepdims_example", + "test_reduce_mean_negative_axes_keepdims_random", + "test_reduce_min_default_axes_keepdims_example", + "test_reduce_min_default_axes_keepdims_random", + "test_reduce_min_do_not_keepdims_example", + "test_reduce_min_do_not_keepdims_random", + "test_reduce_min_keepdims_example", + "test_reduce_min_keepdims_random", + "test_reduce_min_negative_axes_keepdims_example", + "test_reduce_min_negative_axes_keepdims_random", + "test_reduce_prod_default_axes_keepdims_example", + "test_reduce_prod_default_axes_keepdims_random", + "test_reduce_prod_do_not_keepdims_example", + "test_reduce_prod_do_not_keepdims_random", + "test_reduce_prod_keepdims_example", + "test_reduce_prod_keepdims_random", + "test_reduce_prod_negative_axes_keepdims_example", + "test_reduce_prod_negative_axes_keepdims_random", + "test_reduce_sum_default_axes_keepdims_example", + "test_reduce_sum_default_axes_keepdims_random", + "test_reduce_sum_do_not_keepdims_example", + "test_reduce_sum_do_not_keepdims_random", + "test_reduce_sum_empty_axes_input_noop_example", + "test_reduce_sum_empty_axes_input_noop_random", + "test_reduce_sum_keepdims_example", + "test_reduce_sum_keepdims_random", + "test_reduce_sum_negative_axes_keepdims_example", + "test_reduce_sum_negative_axes_keepdims_random", + "test_reduce_sum_square_default_axes_keepdims_example", + "test_reduce_sum_square_default_axes_keepdims_random", + "test_reduce_sum_square_do_not_keepdims_example", + "test_reduce_sum_square_do_not_keepdims_random", + "test_reduce_sum_square_keepdims_example", + "test_reduce_sum_square_keepdims_random", + "test_reduce_sum_square_negative_axes_keepdims_example", + "test_reduce_sum_square_negative_axes_keepdims_random", + "test_reflect_pad", + "test_relu", + // "test_reshape_allowzero_reordered", + "test_reshape_extended_dims", + "test_reshape_negative_dim", + "test_reshape_negative_extended_dims", + "test_reshape_one_dim", + "test_reshape_reduced_dims", + "test_reshape_reordered_all_dims", + "test_reshape_reordered_dims", + "test_reshape_reordered_last_dims", + "test_reshape_zero_and_negative_dim", + "test_reshape_zero_dim", + "test_resize_downsample_linear", + "test_resize_downsample_nearest", + "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside", + // "test_resize_downsample_scales_cubic_align_corners", + "test_resize_downsample_scales_cubic", + // "test_resize_downsample_scales_linear_align_corners", + "test_resize_downsample_scales_linear", + "test_resize_downsample_scales_nearest", + "test_resize_downsample_sizes_cubic", + "test_resize_downsample_sizes_linear_pytorch_half_pixel", + "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn", + "test_resize_downsample_sizes_nearest", + "test_resize_nearest", + "test_resize_tf_crop_and_resize", + "test_resize_upsample_linear", + "test_resize_upsample_nearest", + "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside", + "test_resize_upsample_scales_cubic_align_corners", + "test_resize_upsample_scales_cubic_asymmetric", + "test_resize_upsample_scales_cubic", + "test_resize_upsample_scales_linear_align_corners", + "test_resize_upsample_scales_linear", + "test_resize_upsample_scales_nearest", + "test_resize_upsample_sizes_cubic", + "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel", + "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners", + "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", + "test_resize_upsample_sizes_nearest", + // // "test_reversesequence_batch", + // // "test_reversesequence_time", + // // "test_rnn_seq_length", + // // "test_roialign_aligned_false", + // // "test_roialign_aligned_true", + // // "test_roialign", + // // "test_round", + // // "test_scan_sum", + // // "test_scan9_sum", + // // "test_scatter_elements_with_axis", + // // "test_scatter_elements_with_duplicate_indices", + // // "test_scatter_elements_with_negative_indices", + // // "test_scatter_elements_without_axis", + // // "test_scatter_with_axis", + // // "test_scatter_without_axis", + // // "test_scatternd_add", + // // "test_scatternd_multiply", + // // "test_scatternd", + // // "test_sce_mean_3d_expanded", + // // "test_sce_mean_3d_log_prob_expanded", + // // "test_sce_mean_3d_log_prob", + // // "test_sce_mean_3d", + // // "test_sce_mean_expanded", + // // "test_sce_mean_log_prob_expanded", + // // "test_sce_mean_log_prob", + // // "test_sce_mean_no_weight_ii_3d_expanded", + // // "test_sce_mean_no_weight_ii_3d_log_prob_expanded", + // // "test_sce_mean_no_weight_ii_3d_log_prob", + // // "test_sce_mean_no_weight_ii_3d", + // // "test_sce_mean_no_weight_ii_4d_expanded", + // // "test_sce_mean_no_weight_ii_4d_log_prob_expanded", + // // "test_sce_mean_no_weight_ii_4d_log_prob", + // // "test_sce_mean_no_weight_ii_4d", + // // "test_sce_mean_no_weight_ii_expanded", + // // "test_sce_mean_no_weight_ii_log_prob_expanded", + // // "test_sce_mean_no_weight_ii_log_prob", + // // "test_sce_mean_no_weight_ii", + // // "test_sce_mean_weight_expanded", + // // "test_sce_mean_weight_ii_3d_expanded", + // // "test_sce_mean_weight_ii_3d_log_prob_expanded", + // // "test_sce_mean_weight_ii_3d_log_prob", + // // "test_sce_mean_weight_ii_3d", + // // "test_sce_mean_weight_ii_4d_expanded", + // // "test_sce_mean_weight_ii_4d_log_prob_expanded", + // // "test_sce_mean_weight_ii_4d_log_prob", + // // "test_sce_mean_weight_ii_4d", + // // "test_sce_mean_weight_ii_expanded", + // // "test_sce_mean_weight_ii_log_prob_expanded", + // // "test_sce_mean_weight_ii_log_prob", + // // "test_sce_mean_weight_ii", + // // "test_sce_mean_weight_log_prob_expanded", + // // "test_sce_mean_weight_log_prob", + // // "test_sce_mean_weight", + // // "test_sce_mean", + // // "test_sce_NCd1_mean_weight_negative_ii_expanded", + // // "test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded", + // // "test_sce_NCd1_mean_weight_negative_ii_log_prob", + // // "test_sce_NCd1_mean_weight_negative_ii", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii", + // // "test_sce_NCd1d2d3_sum_weight_high_ii_expanded", + // // "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded", + // // "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob", + // // "test_sce_NCd1d2d3_sum_weight_high_ii", + // // "test_sce_NCd1d2d3d4d5_mean_weight_expanded", + // // "test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded", + // // "test_sce_NCd1d2d3d4d5_mean_weight_log_prob", + // // "test_sce_NCd1d2d3d4d5_mean_weight", + // // "test_sce_NCd1d2d3d4d5_none_no_weight_expanded", + // // "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded", + // // "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob", + // // "test_sce_NCd1d2d3d4d5_none_no_weight", + // // "test_sce_none_expanded", + // // "test_sce_none_log_prob_expanded", + // // "test_sce_none_log_prob", + // // "test_sce_none_weights_expanded", + // // "test_sce_none_weights_log_prob_expanded", + // // "test_sce_none_weights_log_prob", + // // "test_sce_none_weights", + // // "test_sce_none", + // // "test_sce_sum_expanded", + // // "test_sce_sum_log_prob_expanded", + // // "test_sce_sum_log_prob", + // // "test_sce_sum", + // "test_selu_default", + // "test_selu_example", + // "test_selu", + // // "test_sequence_insert_at_back", + // // "test_sequence_insert_at_front", + // // "test_sequence_map_add_1_sequence_1_tensor_expanded", + // // "test_sequence_map_add_1_sequence_1_tensor", + // // "test_sequence_map_add_2_sequences_expanded", + // // "test_sequence_map_add_2_sequences", + // // "test_sequence_map_extract_shapes_expanded", + // // "test_sequence_map_extract_shapes", + // // "test_sequence_map_identity_1_sequence_1_tensor_expanded", + // // "test_sequence_map_identity_1_sequence_1_tensor", + // // "test_sequence_map_identity_1_sequence_expanded", + // // "test_sequence_map_identity_1_sequence", + // // "test_sequence_map_identity_2_sequences_expanded", + // // "test_sequence_map_identity_2_sequences", + // "test_shrink_hard", + // "test_shrink_soft", + "test_sigmoid_example", + "test_sigmoid", + // "test_sign", + // "test_simple_rnn_batchwise", + // "test_simple_rnn_defaults", + // "test_simple_rnn_with_initial_bias", + "test_sin_example", + "test_sin", + "test_sinh_example", + "test_sinh", + // // "test_size_example", + // // "test_size", + "test_slice_default_axes", + "test_slice_default_steps", + // "test_slice_end_out_of_bounds", + "test_slice_neg_steps", + "test_slice_neg", + "test_slice_negative_axes", + // "test_slice_start_out_of_bounds", + "test_slice", + // "test_softmax_axis_0_expanded", + // "test_softmax_axis_0", + // "test_softmax_axis_1_expanded", + // "test_softmax_axis_1", + "test_softmax_axis_2_expanded", + "test_softmax_axis_2", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight", + // "test_softmax_cross_entropy_mean_3d_expanded", + // "test_softmax_cross_entropy_mean_3d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_3d_log_prob", + // "test_softmax_cross_entropy_mean_3d", + // "test_softmax_cross_entropy_mean_expanded", + // "test_softmax_cross_entropy_mean_log_prob_expanded", + // "test_softmax_cross_entropy_mean_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index", + // "test_softmax_cross_entropy_mean_weight_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d_log_prob", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d_log_prob", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d", + // "test_softmax_cross_entropy_mean_weight_ignore_index_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_log_prob", + // "test_softmax_cross_entropy_mean_weight_ignore_index", + // "test_softmax_cross_entropy_mean_weight_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_log_prob", + // "test_softmax_cross_entropy_mean_weight", + // "test_softmax_cross_entropy_mean", + // "test_softmax_cross_entropy_none_expanded", + // "test_softmax_cross_entropy_none_log_prob_expanded", + // "test_softmax_cross_entropy_none_log_prob", + // "test_softmax_cross_entropy_none_weights_expanded", + // "test_softmax_cross_entropy_none_weights_log_prob_expanded", + // "test_softmax_cross_entropy_none_weights_log_prob", + // "test_softmax_cross_entropy_none_weights", + // "test_softmax_cross_entropy_none", + // "test_softmax_cross_entropy_sum_expanded", + // "test_softmax_cross_entropy_sum_log_prob_expanded", + // "test_softmax_cross_entropy_sum_log_prob", + // "test_softmax_cross_entropy_sum", + "opset13/test_softmax_default_axis_expanded", + "opset13/test_softmax_default_axis", + "test_softmax_example_expanded", + "test_softmax_example", + "test_softmax_large_number_expanded", + "test_softmax_large_number", + "test_softmax_negative_axis_expanded", + "test_softmax_negative_axis", + // // "test_softplus_example", + // // "test_softplus", + // // "test_softsign_example", + // // "test_softsign", + // "test_spacetodepth_example", + // "test_spacetodepth", + "test_split_equal_parts_1d", + "test_split_equal_parts_2d", + "test_split_equal_parts_default_axis", + "test_split_variable_parts_1d", + "test_split_variable_parts_2d", + "test_split_variable_parts_default_axis", + "test_split_zero_size_splits", + "test_sqrt_example", + "test_sqrt", + "test_squeeze_negative_axes", + "test_squeeze", + // // "test_stft_with_window", + // // "test_stft", + // // "test_strnormalizer_export_monday_casesensintive_lower", + // // "test_strnormalizer_export_monday_casesensintive_nochangecase", + // // "test_strnormalizer_export_monday_casesensintive_upper", + // // "test_strnormalizer_export_monday_empty_output", + // // "test_strnormalizer_export_monday_insensintive_upper_twodim", + // // "test_strnormalizer_nostopwords_nochangecase", + "test_sub_bcast", + "test_sub_example", + // "test_sub_uint8", + "test_sub", + // "test_sum_example", + // "test_sum_one_input", + // "test_sum_two_inputs", + "test_tan_example", + "test_tan", + "test_tanh_example", + "test_tanh", + // // "test_tfidfvectorizer_tf_batch_onlybigrams_skip0", + // // "test_tfidfvectorizer_tf_batch_onlybigrams_skip5", + // // "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5", + // // "test_tfidfvectorizer_tf_only_bigrams_skip0", + // // "test_tfidfvectorizer_tf_onlybigrams_levelempty", + // // "test_tfidfvectorizer_tf_onlybigrams_skip5", + // // "test_tfidfvectorizer_tf_uniandbigrams_skip5", + "test_thresholdedrelu_default", + "test_thresholdedrelu_example", + "test_thresholdedrelu", + "test_tile_precomputed", + "test_tile", + // // "test_top_k_negative_axis", + // // "test_top_k_smallest", + // // "test_top_k", + // // "test_training_dropout_default_mask", + // // "test_training_dropout_default", + // // "test_training_dropout_mask", + // // "test_training_dropout_zero_ratio_mask", + // // "test_training_dropout_zero_ratio", + // // "test_training_dropout", + "test_transpose_all_permutations_0", + "test_transpose_all_permutations_1", + "test_transpose_all_permutations_2", + "test_transpose_all_permutations_3", + "test_transpose_all_permutations_4", + "test_transpose_all_permutations_5", + "test_transpose_default", + // "test_tril_neg", + // "test_tril_one_row_neg", + // "test_tril_out_neg", + // "test_tril_out_pos", + // "test_tril_pos", + // "test_tril_square_neg", + // "test_tril_square", + // "test_tril_zero", + // "test_tril", + // "test_triu_neg", + // "test_triu_one_row", + // "test_triu_out_neg_out", + // "test_triu_out_pos", + // "test_triu_pos", + // "test_triu_square_neg", + // "test_triu_square", + // "test_triu_zero", + // "test_triu", + // // "test_unique_not_sorted_without_axis", + // // "test_unique_sorted_with_axis_3d", + // // "test_unique_sorted_with_axis", + // // "test_unique_sorted_with_negative_axis", + // // "test_unique_sorted_without_axis", + "test_unsqueeze_axis_0", + "test_unsqueeze_axis_1", + "test_unsqueeze_axis_2", + "test_unsqueeze_axis_3", + "test_unsqueeze_negative_axes", + "test_unsqueeze_three_axes", + "test_unsqueeze_two_axes", + "test_unsqueeze_unsorted_axes", + "test_unsqueeze", + "test_wrap_pad" + // "test_upsample_nearest", + // "test_where_example", + // "test_where_long_example", + // "test_xor_bcast3v1d", + // "test_xor_bcast3v2d", + // "test_xor_bcast4v2d", + // "test_xor_bcast4v3d", + // "test_xor_bcast4v4d", + // "test_xor2d", + // "test_xor3d", + // "test_xor4d" ]; const path = require('path'); @@ -21,6 +1061,12 @@ const { spawnSync } = require('child_process'); const ONNX_TEST_RUNNER_FILENAME = path.join(__dirname, 'onnx_test_runner' + (process.platform === 'win32' ? '.exe' : '')); +if (!fs.existsSync(ONNX_TEST_RUNNER_FILENAME)) { + console.error('Error: onnx_test_runner not found.'); + console.error('Please perform a build and run this script in the build folder.'); + process.exit(1); +} + if (process.argv.includes('-h')) { console.log(HELP); process.exit(0); @@ -34,65 +1080,55 @@ if (!test_data_path) { test_data_path = test_data_path.substring(3); } -const test_models = []; +let test_models = DEFAULT_TESTS; const test_model_list = process.argv.find(arg => arg.startsWith('-m=')); if (test_model_list) { + test_models = []; test_model_list.substring(3).split(';').forEach(test_model => { test_models.push(test_model); }); } -const tests = new Set(test_model_list ? test_models : DEFAULT_TESTS); -const test_cases = []; -fs.readdirSync(test_data_path, { withFileTypes: true }).forEach(dirent => { - if (dirent.isDirectory()) { - const opset = dirent.name; - fs.readdirSync(path.join(test_data_path, opset), { withFileTypes: true }).forEach(dirent => { - if (dirent.isDirectory()) { - const name = dirent.name; - if (tests.has(name)) { - test_cases.push(path.join(test_data_path, opset, name)); - } - } - }); - } -}); +const tests = new Set(test_models); + +const TEST_ROOT = path.join(__dirname, 'webgpu_test_root'); -let passed = []; -let not_implemented = []; -let failed = []; -test_cases.forEach(test_case => { - process.stdout.write(`Running test case: "${test_case}"...`); - const args = [ - '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1"', test_case, - ]; - if (VERBOSE) { - args.unshift('-v'); +let test_data_ready = false; +const test_list_json_data = JSON.stringify(test_models, null, 2); +const test_list_json_filepath = path.join(TEST_ROOT, 'test_list.json'); +if (fs.existsSync(TEST_ROOT)) { + if (fs.existsSync(test_list_json_filepath)) { + test_data_ready = fs.readFileSync(test_list_json_filepath).toString() == test_list_json_data; } - const p = spawnSync(ONNX_TEST_RUNNER_FILENAME, args, { shell: true, stdio: ['ignore', 'pipe', 'pipe'] }); - if (p.status !== 0) { - process.stdout.write('Failed\n'); - failed.push(test_case); - } else if (!p.stdout.toString().includes('Not implemented: 0')) { - process.stdout.write('Not Implemented\n'); - not_implemented.push(test_case); - } else { - process.stdout.write('OK\n'); - passed.push(test_case); + if (!test_data_ready) { + fs.rmdirSync(TEST_ROOT, { recursive: true }); } -}); +} +if (!test_data_ready) { + fs.mkdirSync(TEST_ROOT); -console.log(`\n${passed.length} tests passed.`); -console.log(`\n${not_implemented.length} tests not implemented:`); -not_implemented.slice(0, 3).forEach(test_case => { - console.log(` ${test_case}`); -}); -if (not_implemented.length > 3) { - console.log(` ...`); + fs.readdirSync(test_data_path, { withFileTypes: true }).forEach(dirent => { + if (dirent.isDirectory()) { + const opset = dirent.name; + fs.readdirSync(path.join(test_data_path, opset), { withFileTypes: true }).forEach(dirent => { + if (dirent.isDirectory()) { + const name = dirent.name; + if (tests.has(name)) { + fs.symlinkSync(path.join(test_data_path, opset, name), path.join(TEST_ROOT, `${opset}_${name}`), 'junction'); + } + } + }); + } + }); + fs.writeFileSync(test_list_json_filepath, test_list_json_data); } -console.log(`\n${failed.length} tests failed:`); -failed.slice(0, 3).forEach(test_case => { - console.log(` ${test_case}`); -}); -if (failed.length > 3) { - console.log(` ...`); + +const args = ['-A', '-M', '-j', '1', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled"', 'webgpu_test_root']; +if (VERBOSE) { + args.unshift('-v'); } +process.exit( + spawnSync( + ONNX_TEST_RUNNER_FILENAME, + args, + { shell: true, cwd: __dirname, stdio: 'inherit' } + ).status); From 947aee18a2b15d1aa2501f546aa6426a20bf6466 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 1 Sep 2024 12:46:01 -0700 Subject: [PATCH 027/114] device lost handler --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 9e51cc08eec0..776fbb069bb5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -86,10 +86,15 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info wgpu::RequiredLimits required_limits = GetAvailableRequiredLimits(adapter_); device_desc.requiredLimits = &required_limits; - // TODO: temporary error handling + // TODO: revise temporary error handling device_desc.SetUncapturedErrorCallback([](const wgpu::Device& /*device*/, wgpu::ErrorType type, const char* message) { LOGS_DEFAULT(ERROR) << "WebGPU device error(" << int(type) << "): " << message; }); + // TODO: revise temporary device lost handling + device_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous, [](const wgpu::Device& /*device*/, wgpu::DeviceLostReason reason, const char* message) { + // cannot use ORT logger because it may be already destroyed + std::cerr << "WebGPU device lost (" << int(reason) << "): " << message; + }); wgpu::RequestDeviceCallbackInfo req_device_callback_info = {}; req_device_callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; From 99b2578a49444684ff820afa56ebb84a911c35de Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 1 Sep 2024 14:38:47 -0700 Subject: [PATCH 028/114] add '-a' and '-t' to test runner --- .../webgpu/How_to_Write_WebGPU_EP_Kernel.md | 19 ++++++++++++++++--- .../test/providers/webgpu/test_webgpu.js | 3 ++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index 624cfd80dd8f..9bc19a2099a4 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -182,10 +182,23 @@ To add more tests to the suite list, edit the file at `C:\code\onnxruntime\onnxr to test or debug a single test, find the "onnx_test_runner.exe" in your build folder. run it like: ``` -onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs +onnx_test_runner.exe -v -e webgpu -a 0.0001 -t 0.0001 -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` -The `-C` flag is split by space for each key-value pair. Each key-value pair is separated by `|`. The key is the option name and the value is the option value. +The `-C` flag is split by space for each key-value pair. Each key-value pair is separated by `|`. The key is the option name and the value is the option value. See `onnxruntime\core\providers\webgpu\webgpu_provider_options.h` for available WebGPU EP options. + +The `-a` and `-t` flags are used to specify the absolute and relative tolerance for the test. +- currently the value is set to `0.0001` for both absolute and relative tolerance for the WebGPU EP. +- `onnx_test_runner` will try to load file `\testdata\onnx_backend_test_series_overrides.jsonc>` if available to set the default tolerance values. It is recommended to set the tolerance values in the command line to ensure consistent behavior. + > This is why the following command may have different results: + > + > ``` + > C:\code\onnxruntime> build\Windows\Debug\Debug\onnx_test_runner.exe -e webgpu C:\code\onnxruntime\js\test\data\node\opset9\test_asin_example + > ``` + > + > ``` + > C:\code\onnxruntime\build\Windows\Debug\Debug> onnx_test_runner.exe -e webgpu C:\code\onnxruntime\js\test\data\node\opset9\test_asin_example + > ``` Some features are useful but if you are troubleshooting and want to rule out the cause, you can: @@ -195,5 +208,5 @@ Some features are useful but if you are troubleshooting and want to rule out the Example: ``` -onnx_test_runner.exe -v -A -M -j 1 -e webgpu -C "session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled" C:\code\onnxruntime\js\test\data\node\opset17\test_abs +onnx_test_runner.exe -v -A -M -j 1 -e webgpu -a 0.0001 -t 0.0001 -C "session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` diff --git a/onnxruntime/test/providers/webgpu/test_webgpu.js b/onnxruntime/test/providers/webgpu/test_webgpu.js index 254bded19ae7..e6d28c9e5b4d 100644 --- a/onnxruntime/test/providers/webgpu/test_webgpu.js +++ b/onnxruntime/test/providers/webgpu/test_webgpu.js @@ -1122,7 +1122,8 @@ if (!test_data_ready) { fs.writeFileSync(test_list_json_filepath, test_list_json_data); } -const args = ['-A', '-M', '-j', '1', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled"', 'webgpu_test_root']; +// const args = ['-A', '-M', '-j', '1', '-t', '0.0001', '-a', '0.0001', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled"', 'webgpu_test_root']; +const args = ['-j', '1', '-t', '0.0001', '-a', '0.0001', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1"', 'webgpu_test_root']; if (VERBOSE) { args.unshift('-v'); } From aa7b3f52aaef02e6faed4acd4668f597507b6672 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 1 Sep 2024 15:25:51 -0700 Subject: [PATCH 029/114] atol/rtol 0.0001 -> 0.001 --- .../core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md | 6 +++--- onnxruntime/test/providers/webgpu/test_webgpu.js | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index 9bc19a2099a4..3e501cd957e0 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -182,13 +182,13 @@ To add more tests to the suite list, edit the file at `C:\code\onnxruntime\onnxr to test or debug a single test, find the "onnx_test_runner.exe" in your build folder. run it like: ``` -onnx_test_runner.exe -v -e webgpu -a 0.0001 -t 0.0001 -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs +onnx_test_runner.exe -v -e webgpu -a 0.001 -t 0.001 -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` The `-C` flag is split by space for each key-value pair. Each key-value pair is separated by `|`. The key is the option name and the value is the option value. See `onnxruntime\core\providers\webgpu\webgpu_provider_options.h` for available WebGPU EP options. The `-a` and `-t` flags are used to specify the absolute and relative tolerance for the test. -- currently the value is set to `0.0001` for both absolute and relative tolerance for the WebGPU EP. +- currently the value is set to `0.001` for both absolute and relative tolerance for the WebGPU EP. - `onnx_test_runner` will try to load file `\testdata\onnx_backend_test_series_overrides.jsonc>` if available to set the default tolerance values. It is recommended to set the tolerance values in the command line to ensure consistent behavior. > This is why the following command may have different results: > @@ -208,5 +208,5 @@ Some features are useful but if you are troubleshooting and want to rule out the Example: ``` -onnx_test_runner.exe -v -A -M -j 1 -e webgpu -a 0.0001 -t 0.0001 -C "session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled" C:\code\onnxruntime\js\test\data\node\opset17\test_abs +onnx_test_runner.exe -v -A -M -j 1 -e webgpu -a 0.001 -t 0.001 -C "session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` diff --git a/onnxruntime/test/providers/webgpu/test_webgpu.js b/onnxruntime/test/providers/webgpu/test_webgpu.js index e6d28c9e5b4d..d6c452e1625c 100644 --- a/onnxruntime/test/providers/webgpu/test_webgpu.js +++ b/onnxruntime/test/providers/webgpu/test_webgpu.js @@ -1122,8 +1122,8 @@ if (!test_data_ready) { fs.writeFileSync(test_list_json_filepath, test_list_json_data); } -// const args = ['-A', '-M', '-j', '1', '-t', '0.0001', '-a', '0.0001', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled"', 'webgpu_test_root']; -const args = ['-j', '1', '-t', '0.0001', '-a', '0.0001', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1"', 'webgpu_test_root']; +// const args = ['-A', '-M', '-j', '1', '-t', '0.001', '-a', '0.001', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled"', 'webgpu_test_root']; +const args = ['-j', '1', '-t', '0.001', '-a', '0.001', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1"', 'webgpu_test_root']; if (VERBOSE) { args.unshift('-v'); } From e659acd0eb3d35603fb77e7ff1e7ac3943c6d1a1 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 1 Sep 2024 23:31:52 -0700 Subject: [PATCH 030/114] Fix uniform --- .../core/providers/webgpu/program_manager.cc | 48 +++++------- .../core/providers/webgpu/program_manager.h | 8 -- .../core/providers/webgpu/shader_helper.cc | 4 + .../core/providers/webgpu/shader_variable.cc | 8 +- .../core/providers/webgpu/shader_variable.h | 61 ++++++++------- .../core/providers/webgpu/webgpu_context.cc | 74 ++++++++++++------- 6 files changed, 105 insertions(+), 98 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index de228a038b7d..00036a915f69 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -14,37 +14,7 @@ namespace onnxruntime { namespace webgpu { ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline) - : name{program.Name()}, compute_pipeline{compute_pipeline} { - // prepare uniform info - size_t current_offset = 0; - for (const auto& uniform : program.UniformVariables()) { - bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; - size_t length = uniform.length; - size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; - // https://www.w3.org/TR/WGSL/#alignof - size_t base_alignment = is_f16 - ? (length > 4 ? 16 : length > 2 ? 8 - : length * element_size) - : (length > 2 ? 16 : length * element_size); - size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; - - current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; - uniforms.push_back({uniform.data_type, current_offset, length}); - - // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where - // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). - // For float16 type, when length > 4, the uniform variable is of type array,N>, where - // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). - size_t element_per_struct = is_f16 ? 8 : 4; - current_offset += - length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; - } - - // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set - // max_alignment_of_field to 16 since the underlying buffer has been rounded up to 16. - const int max_alignment_of_field = 16; - uniform_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; -} + : name{program.Name()}, compute_pipeline{compute_pipeline} {} Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); @@ -108,6 +78,22 @@ Status ProgramManager::Build(const ProgramBase& program, auto shader_module = device_.CreateShaderModule(&descriptor); + // TODO: a new cache hierarchy for constants. + // + // Explaination: + // Currently, we use Uniforms for dynamic data. This helps to reduce the number of program artifacts. + // + // "dynamic data" here means the data the determined at runtime, such as the shape of the input tensor. + // + // However, some programs may not necessarily depend on dynamic data. For example, "Clip" may depend on the value of "min" and "max". + // We are using uniforms for the value of "min" and "max" in the current implementation, but usually "min" and "max" are determined + // earlier because they are either from Attributes or from the initializers of the model. + // + // Questions: + // - can we use one instance of ShaderModule to create multiple ComputePipeline? + // - is there any benefit to do so compared to the current implementation? + // + // process overridable constants if available size_t constant_count = program.OverridableConstants().size(); diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h index 9d1b7655c864..087c75bfee77 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.h +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -21,20 +21,12 @@ class Tensor; namespace webgpu { -struct ProgramUniformInfo { - ProgramUniformVariableDataType data_type; - size_t offset; - size_t length; -}; - class ProgramArtifact { public: ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline); std::string name; wgpu::ComputePipeline compute_pipeline; - std::vector uniforms; - size_t uniform_total_size; ProgramArtifact(ProgramArtifact&&) = default; ProgramArtifact& operator=(ProgramArtifact&&) = default; diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 3986b13e0a7d..5883696430de 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -148,6 +148,10 @@ std::string ShaderHelper::GetFinalSourceCode() { const auto& data_type = uniform_def.data_type; const auto length = uniform_value.length; + if (length == 0) { + continue; + } + if (first) { first = false; } else { diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 9483ab19036c..fda4ad72deb2 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -12,12 +12,12 @@ namespace onnxruntime { namespace webgpu { -ShaderVariable::ShaderVariable(const std::string& name, ProgramVariableDataType type, int rank) +ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, int rank) : name_(name), type_(type), rank_(rank), usage_(UseUniform) { Init(); } -ShaderVariable::ShaderVariable(const std::string& name, ProgramVariableDataType type, const TensorShape& dims) +ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, const TensorShape& dims) : name_(name), type_(type), rank_(static_cast(dims.NumDimensions())), dims_(dims), usage_(None) { Init(); } @@ -171,7 +171,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } } -std::string ShaderVariable::GetByOffsetImpl(const std::string& offset) const { +std::string ShaderVariable::GetByOffsetImpl(std::string_view offset) const { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -199,7 +199,7 @@ std::string ShaderVariable::GetByOffsetImpl(const std::string& offset) const { return ss.str(); } -std::string ShaderVariable::SetByOffsetImpl(const std::string& offset, const std::string& value) const { +std::string ShaderVariable::SetByOffsetImpl(std::string_view offset, std::string_view value) const { std::ostringstream ss; ss.imbue(std::locale::classic()); diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index fbdb6590a735..34d767414841 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -14,7 +14,7 @@ namespace onnxruntime { namespace webgpu { template -std::string GetElementAt(const std::string& var, const TIdx& idx, int rank, bool is_f16 = false) { +std::string GetElementAt(std::string_view var, const TIdx& idx, int rank, bool is_f16 = false) { // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. if (var.rfind("uniform.", 0) == 0) { if (rank > 4) { @@ -31,34 +31,32 @@ std::string GetElementAt(const std::string& var, const TIdx& idx, int rank, bool return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]"); } } - } else { - return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : var; } - } else { - return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : var; } + + return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : std::string{var}; } class ShaderVariable { public: - ShaderVariable(const std::string& name, ProgramVariableDataType type, int rank); - ShaderVariable(const std::string& name, ProgramVariableDataType type, const TensorShape& dims); + ShaderVariable(std::string_view name, ProgramVariableDataType type, int rank); + ShaderVariable(std::string_view name, ProgramVariableDataType type, const TensorShape& dims); ShaderVariable(ShaderVariable&&) = default; ShaderVariable& operator=(ShaderVariable&&) = default; // create a WGSL expression ({varname}_indices_t) for getting indices from offset. // \param offset: a WGSL expression (u32) representing the offset. - inline std::string OffsetToIndices(const std::string& offset_expr) const; + inline std::string OffsetToIndices(std::string_view offset_expr) const; // create a WGSL expression (u32) for getting offset from indices. // \param indices: a WGSL expression ({varname}_indices_t) representing the indices. - inline std::string IndicesToOffset(const std::string& indices_expr) const; + inline std::string IndicesToOffset(std::string_view indices_expr) const; // create a WGSL expression (u32) for getting original offset from broadcasted indices. // \param indices: a WGSL expression ({broadcasted_result_varname}_indices_t) representing the broadcasted indices. // \param broadcasted_result: the broadcasted result variable. - inline std::string BroadcastedIndicesToOffset(const std::string& indices_expr, const ShaderVariable& broadcasted_result) const; + inline std::string BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const; // create a WGSL expression ({varname}_indices_t) as an indices literal // \param init: a list of indices values. @@ -70,13 +68,13 @@ class ShaderVariable { // \param idx: the index (i32|u32) of the dimension to set. // \param value: the value (u32) to set. template - inline std::string IndicesSet(const std::string& indices_var, const TIdx& idx_expr, const TVal& value) const; + inline std::string IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const; // create a WGSL expression (u32) for getting value of the specified dimension of the indices. // \param indices_var: name of the indices variable ({varname}_indices_t). // \param idx: the index (i32|u32) of the dimension to get. template - inline std::string IndicesGet(const std::string& indices_var, const TIdx& idx_expr) const; + inline std::string IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const; // create a WGSL statement for setting data at the given indices. // \param args: a list of indices values (u32) followed by a value ({varname}_value_t). @@ -86,7 +84,7 @@ class ShaderVariable { // create a WGSL statement for setting data at the given indices. // \param indices_var: name of the indices variable ({varname}_indices_t). // \param value: the value ({varname}_value_t) to set. - inline std::string SetByIndices(const std::string& indices_var, const std::string& value) const; + inline std::string SetByIndices(std::string_view indices_var, std::string_view value) const; // create a WGSL statement for setting data at the given offset. // \param offset: a WGSL expression (u32) representing the offset. @@ -101,7 +99,7 @@ class ShaderVariable { // create a WGSL expression ({varname}_value_t) for getting data at the given indices. // \param indices_var: name of the indices variable ({varname}_indices_t). - inline std::string GetByIndices(const std::string& indices_var) const; + inline std::string GetByIndices(std::string_view indices_var) const; // create a WGSL expression ({varname}_value_t) for getting data at the given offset. // \param offset: a WGSL expression (u32) representing the offset. @@ -131,8 +129,8 @@ class ShaderVariable { void Init(); void Impl(std::ostringstream& ss) const; - std::string GetByOffsetImpl(const std::string& offset) const; - std::string SetByOffsetImpl(const std::string& offset, const std::string& value) const; + std::string GetByOffsetImpl(std::string_view offset) const; + std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const; std::string_view StorageType() const; std::string_view ValueType() const; @@ -167,23 +165,29 @@ template >> std::string pass_as_string(T&& v) { return std::to_string(std::forward(v)); } +template +std::string_view pass_as_string(std::string_view sv) { + return sv; +} template -std::string pass_as_string(const T& v) { - return v; +std::string pass_as_string(T&& v) { + return std::forward(v); } } // namespace detail -inline std::string ShaderVariable::OffsetToIndices(const std::string& offset_expr) const { +inline std::string ShaderVariable::OffsetToIndices(std::string_view offset_expr) const { usage_ |= UseOffsetToIndices; - return rank_ < 2 ? offset_expr : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')'); + return rank_ < 2 ? std::string{offset_expr} + : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')'); } -inline std::string ShaderVariable::IndicesToOffset(const std::string& indices_expr) const { +inline std::string ShaderVariable::IndicesToOffset(std::string_view indices_expr) const { usage_ |= UseIndicesToOffset; - return rank_ < 2 ? indices_expr : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')'); + return rank_ < 2 ? std::string{indices_expr} + : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')'); } -inline std::string ShaderVariable::BroadcastedIndicesToOffset(const std::string& indices_expr, const ShaderVariable& broadcasted_result) const { +inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const { usage_ |= UseBroadcastedIndicesToOffset; broadcasted_to_.push_back(broadcasted_result); return MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); @@ -199,14 +203,15 @@ inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const { } template -inline std::string ShaderVariable::IndicesSet(const std::string& indices_var, const TIdx& idx_expr, const TVal& value) const { +inline std::string ShaderVariable::IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const { return rank_ < 2 ? MakeStringWithClassicLocale(indices_var, '=', value, ';') : MakeStringWithClassicLocale(GetElementAt(indices_var, idx_expr, rank_), '=', value, ';'); } template -inline std::string ShaderVariable::IndicesGet(const std::string& indices_var, const TIdx& idx_expr) const { - return rank_ < 2 ? indices_var : GetElementAt(indices_var, idx_expr, rank_); +inline std::string ShaderVariable::IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const { + return rank_ < 2 ? std::string{indices_var} + : GetElementAt(indices_var, idx_expr, rank_); } template @@ -229,7 +234,7 @@ inline std::string ShaderVariable::Set(TIndicesAndValue&&... args) const { } } -inline std::string ShaderVariable::SetByIndices(const std::string& indices_var, const std::string& value) const { +inline std::string ShaderVariable::SetByIndices(std::string_view indices_var, std::string_view value) const { if (rank_ < 2) { return SetByOffset(indices_var, value); } else { @@ -258,7 +263,7 @@ inline std::string ShaderVariable::Get(TIndices&&... indices) const { } } -inline std::string ShaderVariable::GetByIndices(const std::string& indices_var) const { +inline std::string ShaderVariable::GetByIndices(std::string_view indices_var) const { if (rank_ < 2) { return GetByOffset(indices_var); } else { diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 776fbb069bb5..d2428d8bb7be 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -157,7 +157,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog return Status::OK(); } - ProgramMetadata metadata = program.GetMetadata(); + const ProgramMetadata metadata = program.GetMetadata(); // validate program metadata { @@ -227,35 +227,55 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog #endif } + // prepare uniform info + const auto& uniforms = program.UniformVariables(); + size_t current_offset = 0; + std::vector> uniform_and_offsets; + uniform_and_offsets.reserve(uniforms.size()); + for (const auto& uniform : uniforms) { + bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + size_t length = uniform.length; + + // skip zero-length uniform + if (length == 0) { + continue; + } + + size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; + // https://www.w3.org/TR/WGSL/#alignof + size_t base_alignment = is_f16 + ? (length > 4 ? 16 : length > 2 ? 8 + : length * element_size) + : (length > 2 ? 16 : length * element_size); + size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; + + current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; + uniform_and_offsets.emplace_back(uniform, current_offset); + + // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). + // For float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). + size_t element_per_struct = is_f16 ? 8 : 4; + current_offset += + length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; + } + + // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set + // max_alignment_of_field to 16 since the underlying buffer has been rounded up to 16. + const size_t max_alignment_of_field = 16; + const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; + WGPUBuffer uniform_buffer = nullptr; - auto uniform_buffer_size = program_artifact->uniform_total_size; - if (uniform_buffer_size > 0) { - auto num_uniforms = program.UniformVariables().size(); - ORT_ENFORCE(program_artifact->uniforms.size() == num_uniforms, - "Uniforms size mismatch. Artifact: ", program_artifact->uniforms.size(), ", Current: ", num_uniforms); - - std::vector uniform_data(uniform_buffer_size); - - for (size_t i = 0; i < num_uniforms; ++i) { - const auto& uniform = program.UniformVariables()[i]; - const auto& artifact_uniform = program_artifact->uniforms[i]; - - ORT_ENFORCE(uniform.data_type == artifact_uniform.data_type, - "Uniform[", i, "] data type mismatch. Artifact: ", artifact_uniform.data_type, - ", Current: ", uniform.data_type); - ORT_ENFORCE(uniform.length == artifact_uniform.length, - "Uniform[", i, "] elements number mismatch. Artifact: ", artifact_uniform.length, ", Current: ", uniform.length); - ORT_ENFORCE(uniform.data.size() == artifact_uniform.length * ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)], - "Uniform[", i, "] data size mismatch. Artifact: ", artifact_uniform.length * ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)], - ", Current: ", uniform.data.size()); - - auto offset = artifact_uniform.offset; - auto size = uniform.data.size(); - memcpy(uniform_data.data() + offset, uniform.data.data(), size); + if (uniform_buffer_total_size > 0) { + std::vector uniform_data_buffer(uniform_buffer_total_size); + + for (auto const& [uniform, offset] : uniform_and_offsets) { + memcpy(uniform_data_buffer.data() + offset, uniform.data.data(), uniform.data.size()); } - uniform_buffer = buffer_mgr_->Create(uniform_buffer_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); - device_.GetQueue().WriteBuffer(uniform_buffer, 0, uniform_data.data(), uniform_buffer_size); + uniform_buffer = buffer_mgr_->Create(uniform_buffer_total_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); + device_.GetQueue().WriteBuffer(uniform_buffer, 0, uniform_data_buffer.data(), uniform_buffer_total_size); } const auto& compute_pass_encoder = GetComputePassEncoder(); From 6ad89c56bfe85794c50f1e547446d649083be53a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 1 Sep 2024 23:32:51 -0700 Subject: [PATCH 031/114] add some unary ops --- .../webgpu/math/unary_elementwise_ops.cc | 193 +++++++++++++++--- .../webgpu/math/unary_elementwise_ops.h | 33 ++- .../webgpu/webgpu_execution_provider.cc | 66 +++--- 3 files changed, 225 insertions(+), 67 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 5c774df84638..0ae48ccbd634 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -24,43 +24,172 @@ Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -#define WEBGPU_ELEMENTWISE_IMPL(OP_TYPE, ...) \ - class OP_TYPE final : public WebGpuKernel { \ - public: \ - OP_TYPE(const OpKernelInfo& info) : WebGpuKernel{info} {} \ - \ - protected: \ - Status ComputeInternal(ComputeContext& context) const override { \ - const auto* input_tensor = context.Input(0); \ - auto* output_tensor = context.Output(0, input_tensor->Shape()); \ - SafeInt vec_size = (input_tensor->Shape().Size() + 3) / 4; \ - UnaryElementwiseProgram program{#OP_TYPE, __VA_ARGS__}; \ - program \ - .Inputs({{input_tensor, ProgramInputTensorDependency::Type}}) \ - .Outputs({output_tensor}) \ - .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) \ - .UniformVariables({ \ - {static_cast(vec_size)}, \ - }); \ - return context.RunProgram(program); \ - } \ +Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + auto* output_tensor = context.Output(0, input_tensor->Shape()); + SafeInt vec_size = (input_tensor->Shape().Size() + 3) / 4; + UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_}; + program + .Inputs({{input_tensor, ProgramInputTensorDependency::Type}}) + .Outputs({output_tensor}) + .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .UniformVariables({ + {static_cast(vec_size)}, + }); + ORT_RETURN_IF_ERROR(ConfigureProgram(program)); + return context.RunProgram(program); +} + +#define WEBGPU_ELEMENTWISE_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public UnaryElementwise { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : UnaryElementwise{info, #OP_TYPE, __VA_ARGS__} {} \ }; -#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ - ONNX_OPERATOR_KERNEL_EX( \ - OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", TYPE), \ - KERNEL_CLASS); +#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + OP_TYPE_AND_CLASS_NAME); + +#define WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION_FROM, VERSION_TO, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + OP_TYPE_AND_CLASS_NAME); -#define WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ - OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", TYPE), \ - KERNEL_CLASS); +// +// math +// WEBGPU_ELEMENTWISE_IMPL(Abs, "abs(a)") -WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Abs, 6, 12, Abs, WebGpuSupportedFloatTypes()) -WEBGPU_ELEMENTWISE_KERNEL(Abs, 13, Abs, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Abs, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Abs, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Neg, "-a") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Neg, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Neg, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Floor, "floor(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Floor, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Ceil, "ceil(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Ceil, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Reciprocal, "1.0/a") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Reciprocal, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sqrt, "sqrt(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Sqrt, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Exp, "exp(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes()) + +constexpr char ErfImpl[] = R"( +const r0: x_value_t = 0.3275911; +const r1: x_value_t = 0.254829592; +const r2: x_value_t = -0.284496736; +const r3: x_value_t = 1.421413741; +const r4: x_value_t = -1.453152027; +const r5: x_value_t = 1.061405429; + +fn erf_v(v: vec4) -> vec4 { + let absv = abs(v); + let x = 1.0 / (1.0 + r0 * absv); + return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); +} +)"; + +WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Log, "log(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Log, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sigmoid, "1.0 / (1.0 + exp(-a))") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) + +class HardSigmoid final : public UnaryElementwise { + public: + HardSigmoid(const OpKernelInfo& info) + : UnaryElementwise{ + info, + "HardSigmoid", + // alpha = uniforms.f32_attr[0] + // beta = uniforms.f32_attr[1] + "max(vec4(0.0), min(vec4(1.0), x_value_t(uniforms.f32_attr[0]) * a + vec4(uniforms.f32_attr[1])))"} { + info.GetAttrOrDefault("alpha", attr, 0.2f); + info.GetAttrOrDefault("beta", attr + 1, 0.5f); + } + + Status ConfigureProgram(UnaryElementwiseProgram& program) const override { + program.UniformVariables({gsl::make_span(attr, 2), {}}); + return Status::OK(); + } + + protected: + float attr[2]; +}; + +WEBGPU_ELEMENTWISE_KERNEL(HardSigmoid, 6, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sin, "sin(a)") +WEBGPU_ELEMENTWISE_KERNEL(Sin, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Cos, "cos(a)") +WEBGPU_ELEMENTWISE_KERNEL(Cos, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Tan, "tan(a)") +WEBGPU_ELEMENTWISE_KERNEL(Tan, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Asin, "asin(a)") +WEBGPU_ELEMENTWISE_KERNEL(Asin, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Acos, "acos(a)") +WEBGPU_ELEMENTWISE_KERNEL(Acos, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Atan, "atan(a)") +WEBGPU_ELEMENTWISE_KERNEL(Atan, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sinh, "sinh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Asinh, "asinh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Asinh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Acosh, "acosh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Acosh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Atanh, "atanh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Atanh, 9, WebGpuSupportedFloatTypes()) + +// todo: logical ops + +// +// activation +// + +// todo: clip + +// constexpr char EluImpl[] = R"( +//)"; +// +// WEBGPU_ELEMENTWISE_IMPL(Elu, "elu_v(a)", ) // TODO: add other unary elementwise ops diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index 837f66af30dd..dbf15248b6b1 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -11,15 +11,44 @@ namespace webgpu { class UnaryElementwiseProgram final : public Program { public: - UnaryElementwiseProgram(const std::string& kernel_name, const std::string& expression, const std::string& additional_impl = "") + UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl) : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl} { } Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"vec_size", ProgramUniformVariableDataType::Uint32}, // output size + {"f32_attr", ProgramUniformVariableDataType::Float32}, // float type attribute(s) + {"i32_attr", ProgramUniformVariableDataType::Int32}); // int type attribute(s) private: + std::string_view expression_; + std::string_view additional_impl_; +}; + +// TODO: after upgrading to C++20, use consteval to make a compile-time constructor so that it will be safe to switch +// the std::string to std::string_view. This will avoid the cost of copying the string. + +class UnaryElementwise : public WebGpuKernel { + public: + UnaryElementwise(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression, + const std::string& additional_impl = "") : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression}, + additional_impl_{additional_impl} {} + + protected: + Status ComputeInternal(ComputeContext& context) const final; + virtual Status ConfigureProgram(UnaryElementwiseProgram& program) const { + program.UniformVariables({{}, {}}); // empty for both float and int attribute(s) + return Status::OK(); + } + + private: + std::string kernel_name_; std::string expression_; std::string additional_impl_; }; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index e7688d1fafb9..202742a1c79b 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -397,39 +397,39 @@ std::unique_ptr RegisterKernels() { // unary - math KERNEL_CREATE_INFO_VERSIONED(6, 12, Abs), KERNEL_CREATE_INFO(13, Abs), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Neg), - // KERNEL_CREATE_INFO(13, Neg), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Floor), - // KERNEL_CREATE_INFO(13, Floor), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Ceil), - // KERNEL_CREATE_INFO(13, Ceil), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Reciprocal), - // KERNEL_CREATE_INFO(13, Reciprocal), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Sqrt), - // KERNEL_CREATE_INFO(13, Sqrt), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Exp), - // KERNEL_CREATE_INFO(13, Exp), - // KERNEL_CREATE_INFO_VERSIONED(9, 12, Erf), - // KERNEL_CREATE_INFO(13, Erf), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), - // KERNEL_CREATE_INFO(13, Sigmoid), - // KERNEL_CREATE_INFO(6, HardSigmoid), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), - // KERNEL_CREATE_INFO(13, Log), - - // KERNEL_CREATE_INFO(7, Sin), - // KERNEL_CREATE_INFO(7, Cos), - // KERNEL_CREATE_INFO(7, Tan), - // KERNEL_CREATE_INFO(7, Asin), - // KERNEL_CREATE_INFO(7, Acos), - // KERNEL_CREATE_INFO(7, Atan), - // KERNEL_CREATE_INFO(9, Sinh), - // KERNEL_CREATE_INFO(9, Cosh), - // KERNEL_CREATE_INFO(9, Asinh), - // KERNEL_CREATE_INFO(9, Acosh), - // KERNEL_CREATE_INFO(9, Atanh), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Tanh), - // KERNEL_CREATE_INFO(13, Tanh), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Neg), + KERNEL_CREATE_INFO(13, Neg), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Floor), + KERNEL_CREATE_INFO(13, Floor), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Ceil), + KERNEL_CREATE_INFO(13, Ceil), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Reciprocal), + KERNEL_CREATE_INFO(13, Reciprocal), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sqrt), + KERNEL_CREATE_INFO(13, Sqrt), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Exp), + KERNEL_CREATE_INFO(13, Exp), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Erf), + KERNEL_CREATE_INFO(13, Erf), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), + KERNEL_CREATE_INFO(13, Sigmoid), + KERNEL_CREATE_INFO(6, HardSigmoid), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), + KERNEL_CREATE_INFO(13, Log), + + KERNEL_CREATE_INFO(7, Sin), + KERNEL_CREATE_INFO(7, Cos), + KERNEL_CREATE_INFO(7, Tan), + KERNEL_CREATE_INFO(7, Asin), + KERNEL_CREATE_INFO(7, Acos), + KERNEL_CREATE_INFO(7, Atan), + KERNEL_CREATE_INFO(9, Sinh), + KERNEL_CREATE_INFO(9, Cosh), + KERNEL_CREATE_INFO(9, Asinh), + KERNEL_CREATE_INFO(9, Acosh), + KERNEL_CREATE_INFO(9, Atanh), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Tanh), + KERNEL_CREATE_INFO(13, Tanh), // KERNEL_CREATE_INFO(1, Not), // KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), From 8361fc3e440b53bb20671831ed7e10631d4fb528 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 2 Sep 2024 16:58:24 -0700 Subject: [PATCH 032/114] various of fixes --- .../webgpu/math/unary_elementwise_ops.cc | 58 +++--- .../webgpu/math/unary_elementwise_ops.h | 17 +- onnxruntime/core/providers/webgpu/program.cc | 40 ++++- onnxruntime/core/providers/webgpu/program.h | 53 ++++-- .../providers/webgpu/program_cache_key.cc | 50 ++++-- .../core/providers/webgpu/program_manager.cc | 3 +- .../core/providers/webgpu/shader_helper.cc | 169 +++++++++++++++++- .../core/providers/webgpu/shader_helper.h | 43 +++-- .../core/providers/webgpu/shader_variable.cc | 90 ++++++---- .../core/providers/webgpu/shader_variable.h | 34 ++-- .../core/providers/webgpu/webgpu_context.cc | 5 +- 11 files changed, 410 insertions(+), 152 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 0ae48ccbd634..97dd2c598463 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -2,20 +2,17 @@ // Licensed under the MIT License. #include "core/providers/webgpu/math/unary_elementwise_ops.h" -#include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" namespace onnxruntime { namespace webgpu { Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddVariable(ProgramVariableScope::Input, - "x", - ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4), - 1); - const auto& output = shader.AddVariable(ProgramVariableScope::Output, - "y", - ToProgramVariableDataType(Outputs()[0]->GetElementType(), 4), - 1); + const auto& input = shader.AddInput("x", + ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4), + ShaderVariable::UseUniform | additional_usage_); + const auto& output = shader.AddOutput("y", + ToProgramVariableDataType(Outputs()[0].tensor->GetElementType(), 4), + ShaderVariable::UseUniform); shader.AppendImplementation(additional_impl_); shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), "let a = ", input.GetByOffset("global_idx"), ";\n", @@ -27,11 +24,12 @@ Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { const auto* input_tensor = context.Input(0); auto* output_tensor = context.Output(0, input_tensor->Shape()); - SafeInt vec_size = (input_tensor->Shape().Size() + 3) / 4; - UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_}; + int64_t size = input_tensor->Shape().Size(); + SafeInt vec_size = (size + 3) / 4; + UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; program - .Inputs({{input_tensor, ProgramInputTensorDependency::Type}}) - .Outputs({output_tensor}) + .Inputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}}}) + .Outputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}}}) .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .UniformVariables({ {static_cast(vec_size)}, @@ -91,21 +89,21 @@ WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes()) constexpr char ErfImpl[] = R"( -const r0: x_value_t = 0.3275911; -const r1: x_value_t = 0.254829592; -const r2: x_value_t = -0.284496736; -const r3: x_value_t = 1.421413741; -const r4: x_value_t = -1.453152027; -const r5: x_value_t = 1.061405429; - -fn erf_v(v: vec4) -> vec4 { +const r0 = 0.3275911; +const r1 = 0.254829592; +const r2 = -0.284496736; +const r3 = 1.421413741; +const r4 = -1.453152027; +const r5 = 1.061405429; + +fn erf_v(v: x_value_t) -> x_value_t { let absv = abs(v); let x = 1.0 / (1.0 + r0 * absv); return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); } )"; -WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl) +WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderVariable::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes()) @@ -117,15 +115,19 @@ WEBGPU_ELEMENTWISE_IMPL(Sigmoid, "1.0 / (1.0 + exp(-a))") WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) +constexpr char HardSigmoidImpl[] = R"( +fn hard_sigmoid_v(v: x_value_t) -> x_value_t { + let alpha = x_element_t(uniforms.f32_attr[0]); + let beta_v = vec4(uniforms.f32_attr[1]); + return max(vec4(0.0), + min(vec4(1.0), alpha * v + beta_v)); +} +)"; class HardSigmoid final : public UnaryElementwise { public: HardSigmoid(const OpKernelInfo& info) - : UnaryElementwise{ - info, - "HardSigmoid", - // alpha = uniforms.f32_attr[0] - // beta = uniforms.f32_attr[1] - "max(vec4(0.0), min(vec4(1.0), x_value_t(uniforms.f32_attr[0]) * a + vec4(uniforms.f32_attr[1])))"} { + : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderVariable::UseElementTypeAlias | ShaderVariable::UseValueTypeAlias} { + // attr[0] is alpha, attr[1] is beta info.GetAttrOrDefault("alpha", attr, 0.2f); info.GetAttrOrDefault("beta", attr + 1, 0.5f); } diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index dbf15248b6b1..2d084bf227f7 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -4,6 +4,7 @@ #pragma once #include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/program.h" namespace onnxruntime { @@ -11,8 +12,8 @@ namespace webgpu { class UnaryElementwiseProgram final : public Program { public: - UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl) - : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl} { + UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderVariable::Usage usage) + : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl}, additional_usage_{usage} { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -25,6 +26,7 @@ class UnaryElementwiseProgram final : public Program { private: std::string_view expression_; std::string_view additional_impl_; + ShaderVariable::Usage additional_usage_; }; // TODO: after upgrading to C++20, use consteval to make a compile-time constructor so that it will be safe to switch @@ -35,10 +37,12 @@ class UnaryElementwise : public WebGpuKernel { UnaryElementwise(const OpKernelInfo& info, const std::string& kernel_name, const std::string& expression, - const std::string& additional_impl = "") : WebGpuKernel{info}, - kernel_name_{kernel_name}, - expression_{expression}, - additional_impl_{additional_impl} {} + const std::string& additional_impl = "", + ShaderVariable::Usage usage = ShaderVariable::None) : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression}, + additional_impl_{additional_impl}, + additional_usage_{usage} {} protected: Status ComputeInternal(ComputeContext& context) const final; @@ -51,6 +55,7 @@ class UnaryElementwise : public WebGpuKernel { std::string kernel_name_; std::string expression_; std::string additional_impl_; + ShaderVariable::Usage additional_usage_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 8ba33bcafb31..91f86d2cf681 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "core/session/onnxruntime_c_api.h" @@ -49,27 +50,27 @@ ProgramUniformVariableValue::ProgramUniformVariableValue(ProgramUniformVariableD } std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType type) { - os << ProgramUniformVariableDataTypeName[static_cast(type)]; + os << ProgramUniformVariableDataTypeName[std::underlying_type::type(type)]; return os; } std::ostream& operator<<(std::ostream& os, ProgramConstantDataType type) { - os << ProgramConstantDataTypeName[static_cast(type)]; + os << ProgramConstantDataTypeName[std::underlying_type::type(type)]; return os; } -std::ostream& operator<<(std::ostream& os, ProgramInputTensorDependency dep) { +std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency dep) { bool first = true; - if ((dep & ProgramInputTensorDependency::Type) == ProgramInputTensorDependency::Type) { + if ((dep & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { os << "Type"; first = false; } - if ((dep & ProgramInputTensorDependency::Rank) == ProgramInputTensorDependency::Rank) { + if ((dep & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { if (!first) os << "|"; os << "Rank"; first = false; } - if ((dep & ProgramInputTensorDependency::Shape) == ProgramInputTensorDependency::Shape) { + if ((dep & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { if (!first) os << "|"; os << "Shape"; first = false; @@ -81,6 +82,31 @@ std::ostream& operator<<(std::ostream& os, ProgramInputTensorDependency dep) { return os; } +int NumberOfComponents(ProgramVariableDataType type) { + switch (type) { + case ProgramVariableDataType::Float32: + case ProgramVariableDataType::Int32: + case ProgramVariableDataType::Uint32: + case ProgramVariableDataType::Int64: + case ProgramVariableDataType::Uint64: + case ProgramVariableDataType::Float16: + return 1; + case ProgramVariableDataType::Vec2Float32: + case ProgramVariableDataType::Vec2Int32: + case ProgramVariableDataType::Vec2Uint32: + case ProgramVariableDataType::Vec2Float16: + return 2; + case ProgramVariableDataType::Vec4Float32: + case ProgramVariableDataType::Vec4Int32: + case ProgramVariableDataType::Vec4Uint32: + case ProgramVariableDataType::Vec4Float16: + case ProgramVariableDataType::Vec4Bool: + return 4; + default: + return -1; + } +} + ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component /* = 1 */) { if (component == 1) { switch (element_type) { @@ -147,7 +173,7 @@ ProgramBase& ProgramBase::Inputs(std::initializer_list inputs) { return *this; } -ProgramBase& ProgramBase::Outputs(std::initializer_list outputs) { +ProgramBase& ProgramBase::Outputs(std::initializer_list outputs) { outputs_.assign(outputs.begin(), outputs.end()); return *this; } diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index d056ee8577f1..c48bdb1a4ff1 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -140,7 +140,7 @@ struct ProgramOverridableConstantDefinition { }; // represents whether the program shader depends on the type, rank, or shape of an input/output tensor -enum class ProgramInputTensorDependency : int { +enum class ProgramTensorMetadataDependency : int { None = 0, Type = 1, Rank = 2, @@ -148,24 +148,47 @@ enum class ProgramInputTensorDependency : int { TypeAndRank = Type | Rank, TypeAndShape = Type | Shape, }; -std::ostream& operator<<(std::ostream& os, ProgramInputTensorDependency); +std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency); -inline ProgramInputTensorDependency operator|(ProgramInputTensorDependency a, ProgramInputTensorDependency b) { - return (ProgramInputTensorDependency)((int&)a | (int&)b); +inline ProgramTensorMetadataDependency operator|(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency)((int&)a | (int&)b); } -inline ProgramInputTensorDependency operator&(ProgramInputTensorDependency a, ProgramInputTensorDependency b) { - return (ProgramInputTensorDependency)((int&)a & (int&)b); +inline ProgramTensorMetadataDependency operator&(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency)((int&)a & (int&)b); } -inline ProgramInputTensorDependency& operator|=(ProgramInputTensorDependency& a, ProgramInputTensorDependency b) { - return (ProgramInputTensorDependency&)((int&)a |= (int&)b); +inline ProgramTensorMetadataDependency& operator|=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency&)((int&)a |= (int&)b); } -inline ProgramInputTensorDependency& operator&=(ProgramInputTensorDependency& a, ProgramInputTensorDependency b) { - return (ProgramInputTensorDependency&)((int&)a &= (int&)b); +inline ProgramTensorMetadataDependency& operator&=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency&)((int&)a &= (int&)b); } struct ProgramInput { + ProgramInput(const Tensor* tensor) + : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {} + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency) + : tensor{tensor}, dependency{dependency}, use_override_shape{false}, override_shape{} {} + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape) + : tensor{tensor}, dependency{dependency}, use_override_shape{true}, override_shape{override_shape} {} + const Tensor* tensor; - ProgramInputTensorDependency dependency; + ProgramTensorMetadataDependency dependency; + bool use_override_shape; + TensorShape override_shape; +}; + +struct ProgramOutput { + ProgramOutput(Tensor* tensor) + : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {} + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency) + : tensor{tensor}, dependency{dependency}, use_override_shape{false}, override_shape{} {} + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape) + : tensor{tensor}, dependency{dependency}, use_override_shape{true}, override_shape{override_shape} {} + + Tensor* tensor; + ProgramTensorMetadataDependency dependency; + bool use_override_shape; + TensorShape override_shape; }; constexpr SafeInt WORKGROUP_SIZE = 64; @@ -205,6 +228,8 @@ enum class ProgramVariableDataType { Vec4Bool, }; +int NumberOfComponents(ProgramVariableDataType type); + ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1); namespace detail { @@ -229,7 +254,7 @@ class ProgramBase { // set one or more program inputs ProgramBase& Inputs(std::initializer_list inputs); // set one or more program outputs - ProgramBase& Outputs(std::initializer_list outputs); + ProgramBase& Outputs(std::initializer_list outputs); // set the size of dispatch groups. Y and Z are 1 if not specified. ProgramBase& DispatchGroupSize(uint32_t x); @@ -289,7 +314,7 @@ class ProgramBase { inline const std::string& Name() const { return name_; } inline const std::string& CacheHint() const { return cache_hint_; } inline const std::vector& Inputs() const { return inputs_; } - inline const std::vector& Outputs() const { return outputs_; } + inline const std::vector& Outputs() const { return outputs_; } inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; } inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; } inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; } @@ -310,7 +335,7 @@ class ProgramBase { std::string name_; std::string cache_hint_; std::vector inputs_; - std::vector outputs_; + std::vector outputs_; uint32_t dispatch_group_size_x_; uint32_t dispatch_group_size_y_; diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index a4530910944d..7bea82a1b0c6 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -8,6 +8,29 @@ namespace onnxruntime { namespace webgpu { +namespace { +void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramTensorMetadataDependency dependency, bool& first) { + if (first) { + first = false; + } else { + ss << "|"; + } + if ((dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { +#ifndef NDEBUG // if debug build + ss << DataTypeImpl::ToString(tensor.DataType()); +#else + ss << output.tensor->GetElementType(); +#endif + } + ss << ";"; + if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { + ss D("Rank=") << tensor.Shape().NumDimensions(); + } else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { + ss D("Dims=") << tensor.Shape().ToString(); + } +} +} // namespace + std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch) { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -34,6 +57,7 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp x != 0 || y != 0 || z != 0) { ss << ":" D("WorkgroupSize="); // only append non-zero values. zero values are considered as use default + // todo: this is actually not working correctly. revisit this logic. currently even if it's default, the value is not zero and will be appended if (x > 0) { ss << x; } @@ -60,27 +84,17 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp ss << uniform.length; } } + ss << ":" D("Inputs="); first = true; for (const auto& input : program.Inputs()) { - if (first) { - first = false; - } else { - ss << "|"; - } - if ((input.dependency & ProgramInputTensorDependency::Type) == ProgramInputTensorDependency::Type) { -#ifndef NDEBUG // if debug build - ss << DataTypeImpl::ToString(input.tensor->DataType()); -#else - ss << input.tensor->GetElementType(); -#endif - } - ss << ";"; - if ((input.dependency & ProgramInputTensorDependency::Rank) == ProgramInputTensorDependency::Rank) { - ss D("Rank=") << input.tensor->Shape().NumDimensions(); - } else if ((input.dependency & ProgramInputTensorDependency::Shape) == ProgramInputTensorDependency::Shape) { - ss D("Dims=") << input.tensor->Shape().ToString(); - } + AppendTensorInfo(ss, *input.tensor, input.dependency, first); + } + + ss << ":" D("Outputs="); + first = true; + for (const auto& output : program.Outputs()) { + AppendTensorInfo(ss, *output.tensor, output.dependency, first); } return ss.str(); diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index 00036a915f69..a10412f21f49 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -56,7 +56,8 @@ Status ProgramManager::Build(const ProgramBase& program, ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); // code is a large std::string that contains the final shader code - auto code = shader_helper.GetFinalSourceCode(); + std::string code; + ORT_RETURN_IF_ERROR(shader_helper.GetFinalSourceCode(code)); LOGS_DEFAULT(VERBOSE) << "\n=== WebGPU Shader code [" << program.Name() #ifndef NDEBUG // if debug build diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 5883696430de..c8c79dd6233d 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -4,10 +4,12 @@ #include #include #include +#include #include "core/session/onnxruntime_c_api.h" #include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" namespace onnxruntime { namespace webgpu { @@ -79,7 +81,145 @@ Status ShaderHelper::Init() { return Status::OK(); } -std::string ShaderHelper::GetFinalSourceCode() { +const ShaderVariable& ShaderHelper::AddInput(const std::string& name, ProgramVariableDataType type, ShaderVariable::Usage usage) { + const size_t input_index = vars_[std::underlying_type::type(ProgramVariableScope::Input)].size(); + ORT_ENFORCE(input_index < program_.Inputs().size(), + "Too many inputs in the program (", program_.Inputs().size(), ")"); + + const auto& dims = program_.Inputs()[input_index].use_override_shape ? program_.Inputs()[input_index].override_shape + : program_.Inputs()[input_index].tensor->Shape(); + return AddVariableImpl(ProgramVariableScope::Input, name, type, usage, dims); +} + +const ShaderVariable& ShaderHelper::AddOutput(const std::string& name, ProgramVariableDataType type, ShaderVariable::Usage usage) { + const size_t output_index = vars_[std::underlying_type::type(ProgramVariableScope::Output)].size(); + ORT_ENFORCE(output_index < program_.Outputs().size(), + "Too many outputs in the program (", program_.Outputs().size(), ")"); + + const auto& dims = program_.Outputs()[output_index].use_override_shape ? program_.Outputs()[output_index].override_shape + : program_.Outputs()[output_index].tensor->Shape(); + return AddVariableImpl(ProgramVariableScope::Output, name, type, usage, dims); +} + +#ifndef NDEBUG // if debug build +namespace { +Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType var_type) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float32 || + var_type == ProgramVariableDataType::Vec2Float32 || + var_type == ProgramVariableDataType::Vec4Float32, + "Unexpected program variable type ", int(var_type), " for float32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float16 || + var_type == ProgramVariableDataType::Vec2Float16 || + var_type == ProgramVariableDataType::Vec4Float16, + "Unexpected program variable type ", int(var_type), " for float16 tensor"); + + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || + var_type == ProgramVariableDataType::Vec2Int32 || + var_type == ProgramVariableDataType::Vec4Int32, + "Unexpected program variable type ", int(var_type), " for int32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint32 || + var_type == ProgramVariableDataType::Vec2Uint32 || + var_type == ProgramVariableDataType::Vec4Uint32, + "Unexpected program variable type ", int(var_type), " for uint32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int64, + "Unexpected program variable type ", int(var_type), " for int64 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint64, + "Unexpected program variable type ", int(var_type), " for uint64 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Vec4Bool, + "Unexpected program variable type ", int(var_type), " for bool tensor"); + break; + default: + ORT_RETURN_IF(true, "Unsupported data type: ", element_type); + // todo: add int4/uint4 + } + return Status::OK(); +} + +using RankOrShape = std::variant>; + +Status ValidateVariableShape(const TensorShape& origin_shape, + bool use_override_shape, + const TensorShape& override_shape, + int num_components) { + if (use_override_shape) { + // if override shape specified, assert override_size == ceil( origin_size / 4 ) + ORT_RETURN_IF_NOT((origin_shape.Size() + num_components - 1) / num_components == override_shape.Size(), + "Tensor original shape ", origin_shape, " cannot reshape to ", override_shape, " with component number ", num_components); + } else if (num_components > 1) { + // if shape is not overriden, assert origin_shape[-1] % 4 == 0 + ORT_RETURN_IF_NOT(origin_shape.Size() > 0 && origin_shape[origin_shape.Size() - 1] % num_components == 0, + "Tensor original shape ", origin_shape, " cannot be divided by component number ", num_components); + } + + // if (use_uniform) { + // const auto& rank = std::get(rank_or_shape); + // ORT_RETURN_IF_NOT(rank == SafeInt(override_shape.NumDimensions()), + // "Shader variable rank ", rank, " does not match the tensor shape ", override_shape); + // } else { + // const auto& shape = std::get>(rank_or_shape).get(); + // ORT_RETURN_IF(use_override_shape, "Cannot specify both variable shape and program input/output shape override"); + // ORT_RETURN_IF_NOT(origin_shape.Size() == shape.Size() * num_components, + // "Tensor original shape ", origin_shape, " cannot reshape to ", shape, " with component number ", num_components); + // } + return Status::OK(); +} +} // namespace + +const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, + const std::string& name, + ProgramVariableDataType type, + ShaderVariable::Usage usage, + const TensorShape& dims) { + if (scope == ProgramVariableScope::Input || scope == ProgramVariableScope::Output) { + ORT_ENFORCE(vars_[std::underlying_type::type(ProgramVariableScope::Input)].size() + + vars_[std::underlying_type::type(ProgramVariableScope::Output)].size() < + limits_.maxStorageBuffersPerShaderStage, + "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); + } + + if (type == ProgramVariableDataType::Float16 || type == ProgramVariableDataType::Vec2Float16 || type == ProgramVariableDataType::Vec4Float16) { + use_f16_ = true; + } + + if (scope == ProgramVariableScope::Local) { + ORT_NOT_IMPLEMENTED("Local variables are not supported yet."); + } + + return vars_[std::underlying_type::type(scope)].emplace_back(name, type, usage, dims); +} + +Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(input.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableShape(input.tensor->Shape(), + input.use_override_shape, + input.use_override_shape ? input.override_shape : input.tensor->Shape(), + var.num_components_)); + + return Status::OK(); +} +Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); + + // todo: add reshaped shape and check + return Status::OK(); +} +#endif + +Status ShaderHelper::GetFinalSourceCode(std::string& code) { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -87,14 +227,14 @@ std::string ShaderHelper::GetFinalSourceCode() { // Section feature enabling // if (use_f16_) { - ORT_ENFORCE(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); + ORT_RETURN_IF_NOT(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); ss << "enable f16;\n"; } // // Section constants // - ss << "\nconst workgroup_size_x: u32 = " << program_.WorkgroupSizeX() + ss << "const workgroup_size_x: u32 = " << program_.WorkgroupSizeX() << ";\nconst workgroup_size_y: u32 = " << program_.WorkgroupSizeY() << ";\nconst workgroup_size_z: u32 = " << program_.WorkgroupSizeZ() << ";\n"; @@ -122,11 +262,23 @@ std::string ShaderHelper::GetFinalSourceCode() { // // Input/output variables // - int variable_count = 0; - for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { + size_t variable_count = 0; + const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; + ORT_RETURN_IF_NOT(input_vars.size() == program_.Inputs().size(), + "Mismatched input variable count. Shader: ", variable_count, ", Program: ", program_.Inputs().size()); + for (const auto& input : input_vars) { +#ifndef NDEBUG // if debug build + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[variable_count], input)); +#endif ss << "@group(0) @binding(" << variable_count++ << ") var " << input.name_ << ": array<" << input.StorageType() << ">;\n"; } - for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { + const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; + ORT_RETURN_IF_NOT(output_vars.size() == program_.Outputs().size(), + "Mismatched output variable count. Shader: ", variable_count, ", Program: ", program_.Outputs().size()); + for (const auto& output : output_vars) { +#ifndef NDEBUG // if debug build + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[variable_count - input_vars.size()], output)); +#endif ss << "@group(0) @binding(" << variable_count++ << ") var " << output.name_ << ": array<" << output.StorageType() << ">;\n"; } @@ -188,8 +340,8 @@ std::string ShaderHelper::GetFinalSourceCode() { for (const auto& var : var_group) { var.Impl(ss); } - ss << "\n"; } + ss << "\n"; // // Additional Implementation @@ -205,7 +357,8 @@ std::string ShaderHelper::GetFinalSourceCode() { ss << "\n" "}\n"; - return ss.str(); + code = ss.str(); + return Status::OK(); } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index ac6dfebfef81..e1f008ff6a90 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -77,12 +77,13 @@ class ShaderHelper final { Status Init(); - const ShaderVariable& AddVariable(ProgramVariableScope scope, const std::string& name, ProgramVariableDataType type, int rank = 1) { - return AddVariableImpl(scope, name, type, rank); - } - const ShaderVariable& AddVariable(ProgramVariableScope scope, const std::string& name, ProgramVariableDataType type, const TensorShape& dims) { - return AddVariableImpl(scope, name, type, dims); - } + const ShaderVariable& AddInput(const std::string& name, + ProgramVariableDataType type, + ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); + + const ShaderVariable& AddOutput(const std::string& name, + ProgramVariableDataType type, + ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); template inline std::ostringstream& AppendImplementation(Strs&&... impl) { @@ -91,8 +92,8 @@ class ShaderHelper final { } template - inline std::ostringstream& MainFunctionBody(Strs&&... body) { - onnxruntime::detail::MakeStringImpl(body_, std::forward(body)...); + inline std::ostringstream& MainFunctionBody(const Strs&... body) { + onnxruntime::detail::MakeStringImpl(body_, std::forward>(body)...); return body_; } @@ -101,19 +102,6 @@ class ShaderHelper final { } private: - template // T is one of {int, const TensorShape&} - const ShaderVariable& AddVariableImpl(ProgramVariableScope scope, const std::string& name, ProgramVariableDataType type, T&& arg) { - ORT_ENFORCE((scope == ProgramVariableScope::Input || scope == ProgramVariableScope::Output) && - vars_[static_cast(ProgramVariableScope::Input)].size() + vars_[static_cast(ProgramVariableScope::Output)].size() < limits_.maxStorageBuffersPerShaderStage, - "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); - - if (type == ProgramVariableDataType::Float16 || type == ProgramVariableDataType::Vec2Float16 || type == ProgramVariableDataType::Vec4Float16) { - use_f16_ = true; - } - - return vars_[static_cast(scope)].emplace_back(name, type, std::forward(arg)); - } - template // ConstantType is one of {ProgramConstant, ProgramOverridableConstantValue, ProgramOverridableConstantDefinition} void WriteConstantValue(std::ostringstream& ss, const ConstantType& constant) const { switch (constant.type) { @@ -137,7 +125,18 @@ class ShaderHelper final { } } - std::string GetFinalSourceCode(); + const ShaderVariable& AddVariableImpl(ProgramVariableScope scope, + const std::string& name, + ProgramVariableDataType type, + ShaderVariable::Usage usage, + const TensorShape& dims); + +#ifndef NDEBUG // if debug build + Status ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const; + Status ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const; +#endif + + Status GetFinalSourceCode(std::string& code); friend class ProgramManager; const wgpu::Device& device_; diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index fda4ad72deb2..9a4ebc80bf66 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -5,6 +5,7 @@ #include #include +#include "core/common/safeint.h" #include "core/providers/webgpu/shader_variable.h" #include "core/providers/webgpu/shader_macros.h" @@ -12,18 +13,15 @@ namespace onnxruntime { namespace webgpu { -ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, int rank) - : name_(name), type_(type), rank_(rank), usage_(UseUniform) { - Init(); -} - -ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, const TensorShape& dims) - : name_(name), type_(type), rank_(static_cast(dims.NumDimensions())), dims_(dims), usage_(None) { - Init(); -} - -void ShaderVariable::Init() { +ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims) + : name_(name), + type_(type), + num_components_{NumberOfComponents(type)}, + rank_{SafeInt(dims.NumDimensions())}, + dims_{dims}, + usage_(usage) { ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); + ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_); } void ShaderVariable::Impl(std::ostringstream& ss) const { @@ -31,17 +29,27 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { const std::string value_t = name_ + "_value_t"; const std::string indices_t = name_ + "_indices_t"; + const std::string element_t = name_ + "_element_t"; const std::string shape = (usage_ & UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; const std::string stride = (usage_ & UseUniform) ? "uniforms." + name_ + "_stride" : name_ + "_stride"; // Types - SS("alias ", value_t, " = ", ValueType(), ";\n"); - SS("alias ", indices_t, " = ", IndicesType(), ";\n"); + std::string_view value_type = (usage_ & UseValueTypeAlias) ? value_t : ValueType(); + if (usage_ & UseValueTypeAlias) { + SS("alias ", name_, "_value_t = ", ValueType(), ";\n"); + } + std::string_view indices_type = (usage_ & UseIndicesTypeAlias) ? indices_t : IndicesType(); + if (usage_ & UseIndicesTypeAlias) { + SS("alias ", name_, "_indices_t = ", IndicesType(), ";\n"); + } + if (usage_ & UseElementTypeAlias) { + SS("alias ", name_, "_element_t = ", ElementType(), ";\n"); + } // Need shape and strides when (not use uniform) and (any other usage is enabled) if (!(usage_ & UseUniform) && (usage_ & ~UseUniform)) { - SS("const ", shape, " = ", indices_t, "("); + SS("const ", shape, " = ", indices_type, "("); bool first = true; for (auto dim : dims_.GetDims()) { @@ -54,7 +62,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } ss << ");\n"; - SS("const ", stride, " = ", indices_t, "("); + SS("const ", stride, " = ", indices_type, "("); first = true; for (int i = rank_ - 1; i >= 0; i--) { if (!first) { @@ -69,8 +77,8 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn o2i_{name}" if (usage_ & UseOffsetToIndices) { if (rank_ >= 2) { - SS("fn o2i_", name_, "(offset : u32)->", indices_t, " {\n"); - SS(" var indices: ", indices_t, ";\n"); + SS("fn o2i_", name_, "(offset : u32)->", indices_type, " {\n"); + SS(" var indices: ", indices_type, ";\n"); SS(" var current = offset;\n"); for (int i = 0; i < rank_ - 1; i++) { auto current_stride = GetElementAt(stride, i, rank_); @@ -88,7 +96,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn i2o_{name}" if (usage_ & UseIndicesToOffset) { if (rank_ >= 2) { - SS("fn i2o_", name_, "(indices : ", indices_t, ")->u32 {\n"); + SS("fn i2o_", name_, "(indices : ", indices_type, ")->u32 {\n"); SS(" return "); for (int i = 0; i < rank_ - 1; i++) { SS("indices[", i, "] * ", GetElementAt(stride, i, rank_), " + "); @@ -125,7 +133,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { for (int i = 1; i < rank_; i++) { SS(", d", i, ": u32"); } - SS(", value: ", value_t, ") {\n"); + SS(", value: ", value_type, ") {\n"); SS(" set_", name_, "_by_indices(d0"); for (int i = 1; i < rank_; i++) { SS(", d", i); @@ -138,7 +146,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn set_{name}_by_indices" if (usage_ & UseSetByIndices) { if (rank_ >= 2) { - SS("fn set_", name_, "_by_indices(indices: ", indices_t, ", value: ", value_t, ") {\n"); + SS("fn set_", name_, "_by_indices(indices: ", indices_type, ", value: ", value_type, ") {\n"); SS(" ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); SS("}\n"); } @@ -151,7 +159,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { for (int i = 1; i < rank_; i++) { SS(", d", i, ": u32"); } - SS(")->", value_t, " {\n"); + SS(")->", value_type, " {\n"); SS(" return get_", name_, "_by_indices(d0"); for (int i = 1; i < rank_; i++) { SS(", d", i); @@ -164,7 +172,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn get_{name}_by_indices" if (usage_ & UseGetByIndices) { if (rank_ >= 2) { - SS("fn get_", name_, "_by_indices(indices: ", indices_t, ")->", value_t, " {\n"); + SS("fn get_", name_, "_by_indices(indices: ", indices_type, ")->", value_type, " {\n"); SS(" return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); SS("}\n"); } @@ -248,17 +256,17 @@ std::string_view ShaderVariable::StorageType() const { std::string_view ShaderVariable::ValueType() const { constexpr static const std::string_view VALUE_TYPE[] = { "f32", // f32 - "f32", // vec2f32 - "f32", // vec4f32 + "vec2", // vec2f32 + "vec4", // vec4f32 "f16", // f16 - "f16", // vec2f16 - "f16", // vec4f16 + "vec2", // vec2f16 + "vec4", // vec4f16 "i32", // i32 - "i32", // vec2i32 - "i32", // vec4i32 + "vec2", // vec2i32 + "vec4", // vec4i32 "u32", // u32 - "u32", // vec2u32 - "u32", // vec4u32 + "vec2", // vec2u32 + "vec4", // vec4u32 "i32", // int64 (trancated to i32) "u32", // uint64 (trancated to u32) "vec4", // vec4bool @@ -267,6 +275,28 @@ std::string_view ShaderVariable::ValueType() const { return VALUE_TYPE[static_cast(type_)]; } +std::string_view ShaderVariable::ElementType() const { + constexpr static const std::string_view ELEMENT_TYPE[] = { + "f32", // f32 + "f32", // vec2f32 + "f32", // vec4f32 + "f16", // f16 + "f16", // vec2f16 + "f16", // vec4f16 + "i32", // i32 + "i32", // vec2i32 + "i32", // vec4i32 + "u32", // u32 + "u32", // vec2u32 + "u32", // vec4u32 + "i32", // int64 + "u32", // uint64 + "bool", // vec4bool + }; + + return ELEMENT_TYPE[static_cast(type_)]; +} + std::string ShaderVariable::IndicesType() const { return rank_ < 2 ? "u32" : (rank_ < 4 ? MakeStringWithClassicLocale("vec", rank_, "") diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 34d767414841..86eaaac5e159 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -5,7 +5,6 @@ #include -#include "core/common/safeint.h" #include "core/framework/tensor_shape.h" #include "core/providers/webgpu/program.h" @@ -39,8 +38,22 @@ std::string GetElementAt(std::string_view var, const TIdx& idx, int rank, bool i class ShaderVariable { public: - ShaderVariable(std::string_view name, ProgramVariableDataType type, int rank); - ShaderVariable(std::string_view name, ProgramVariableDataType type, const TensorShape& dims); + enum Usage : uint32_t { + None = 0, // no usage. this means no additional implementation code will be generated. + UseIndicesTypeAlias = 1, // use type alias "{name}_indices_t" for indices (eg. u32, vec2, vec3, vec4, ...) + UseValueTypeAlias = 2, // use type alias "{name}_value_t" for value (eg. f32, vecT, vec4, ...) + UseElementTypeAlias = 4, // use type alias "{name}_element_t" for element (eg. f32, bool, ...) + UseOffsetToIndices = 8, // use implementation of fn o2i_{name} + UseIndicesToOffset = 16, // use implementation of fn i2o_{name} + UseBroadcastedIndicesToOffset = 32, // use implementation of fn {broadcasted_result_name}_bi2o_{name} + UseSet = 64, // use implementation of fn set_{name} + UseSetByIndices = 128, // use implementation of fn set_{name}_by_indices + UseGet = 256, // use implementation of fn get_{name} + UseGetByIndices = 512, // use implementation of fn get_{name}_by_indices + UseUniform = 1024, // use uniform for shape and stride + }; + + ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims); ShaderVariable(ShaderVariable&&) = default; ShaderVariable& operator=(ShaderVariable&&) = default; @@ -107,18 +120,6 @@ class ShaderVariable { inline std::string GetByOffset(TOffset&& offset) const; private: - enum Usage : uint32_t { - None = 0, - UseOffsetToIndices = 1, - UseIndicesToOffset = 2, - UseBroadcastedIndicesToOffset = 4, - UseSet = 8, - UseSetByIndices = 16, - UseGet = 32, - UseGetByIndices = 64, - UseUniform = 128, - }; - friend ShaderVariable::Usage operator|(ShaderVariable::Usage a, ShaderVariable::Usage b); friend ShaderVariable::Usage operator&(ShaderVariable::Usage a, ShaderVariable::Usage b); friend ShaderVariable::Usage& operator|=(ShaderVariable::Usage& a, ShaderVariable::Usage b); @@ -126,7 +127,6 @@ class ShaderVariable { ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariable); - void Init(); void Impl(std::ostringstream& ss) const; std::string GetByOffsetImpl(std::string_view offset) const; @@ -134,10 +134,12 @@ class ShaderVariable { std::string_view StorageType() const; std::string_view ValueType() const; + std::string_view ElementType() const; std::string IndicesType() const; std::string name_; ProgramVariableDataType type_; + int num_components_; int rank_; TensorShape dims_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index d2428d8bb7be..e5852d9a3a6a 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -144,7 +144,8 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog }), "All inputs must be tensors on WebGPU buffers."); - ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](Tensor* tensor) { + ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](const ProgramOutput& output) { + const auto* tensor = output.tensor; return tensor != nullptr && tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && tensor->Location().device.Type() == OrtDevice::GPU && @@ -288,7 +289,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(const_cast(input.tensor->DataRaw()))}); } for (const auto& output : outputs) { - bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(output->MutableDataRaw())}); + bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(output.tensor->MutableDataRaw())}); } if (uniform_buffer) { bind_group_entries.push_back({nullptr, entry_index++, uniform_buffer}); From c89159d4407685451a0284a0569cbc69b8be09da Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 2 Sep 2024 19:28:30 -0700 Subject: [PATCH 033/114] fix workgroup_size, cache key stringnify and indices type --- onnxruntime/core/providers/webgpu/program.cc | 6 +++--- onnxruntime/core/providers/webgpu/program_cache_key.cc | 10 +++++----- onnxruntime/core/providers/webgpu/shader_helper.cc | 8 +++----- onnxruntime/core/providers/webgpu/shader_variable.cc | 4 ++-- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 91f86d2cf681..4a5785dc4def 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -163,9 +163,9 @@ ProgramBase::ProgramBase(const std::string& name) dispatch_group_size_x_{0}, dispatch_group_size_y_{0}, dispatch_group_size_z_{0}, - workgroup_size_x_{WORKGROUP_SIZE}, - workgroup_size_y_{1}, - workgroup_size_z_{1} { + workgroup_size_x_{0}, + workgroup_size_y_{0}, + workgroup_size_z_{0} { } ProgramBase& ProgramBase::Inputs(std::initializer_list inputs) { diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index 7bea82a1b0c6..944fbb0bf8a5 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -13,7 +13,7 @@ void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramTenso if (first) { first = false; } else { - ss << "|"; + ss << '|'; } if ((dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { #ifndef NDEBUG // if debug build @@ -21,12 +21,12 @@ void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramTenso #else ss << output.tensor->GetElementType(); #endif + ss << ';'; } - ss << ";"; if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { - ss D("Rank=") << tensor.Shape().NumDimensions(); - } else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { ss D("Dims=") << tensor.Shape().ToString(); + } else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { + ss D("Rank=") << tensor.Shape().NumDimensions(); } } } // namespace @@ -49,7 +49,7 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp // append custom cache hint if any if (auto& hint = program.CacheHint(); !hint.empty()) { - ss << "[" D("CacheHint=") << hint << "]"; + ss << '[' D("CacheHint=") << hint << ']'; } // append workgroup size if overridden diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index c8c79dd6233d..054910a7dd57 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -35,12 +35,10 @@ Status ShaderHelper::Init() { // dispatch group size is normalized so no need to validate it here // validate workgroup size - auto workgroup_size_x = program_.WorkgroupSizeX(); - auto workgroup_size_y = program_.WorkgroupSizeY(); - auto workgroup_size_z = program_.WorkgroupSizeZ(); + auto workgroup_size_x = program_.WorkgroupSizeX() == 0 ? WORKGROUP_SIZE : program_.WorkgroupSizeX(); + auto workgroup_size_y = program_.WorkgroupSizeY() == 0 ? 1 : program_.WorkgroupSizeY(); + auto workgroup_size_z = program_.WorkgroupSizeZ() == 0 ? 1 : program_.WorkgroupSizeZ(); - ORT_RETURN_IF_NOT(workgroup_size_x > 0 && workgroup_size_y > 0 && workgroup_size_z > 0, - "Workgroup size must be greater than 0"); ORT_RETURN_IF_NOT(workgroup_size_x <= limits_.maxComputeWorkgroupSizeX && workgroup_size_y <= limits_.maxComputeWorkgroupSizeY && workgroup_size_z <= limits_.maxComputeWorkgroupSizeZ, diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 9a4ebc80bf66..ef80fd3c57f6 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -28,7 +28,6 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Start generating code const std::string value_t = name_ + "_value_t"; - const std::string indices_t = name_ + "_indices_t"; const std::string element_t = name_ + "_element_t"; const std::string shape = (usage_ & UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; @@ -36,10 +35,11 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Types std::string_view value_type = (usage_ & UseValueTypeAlias) ? value_t : ValueType(); + const std::string indices_type = (usage_ & UseIndicesTypeAlias) ? name_ + "_indices_t" : IndicesType(); + if (usage_ & UseValueTypeAlias) { SS("alias ", name_, "_value_t = ", ValueType(), ";\n"); } - std::string_view indices_type = (usage_ & UseIndicesTypeAlias) ? indices_t : IndicesType(); if (usage_ & UseIndicesTypeAlias) { SS("alias ", name_, "_indices_t = ", IndicesType(), ";\n"); } From 5ea5936a2e0927bd2f54242de83087e21376f75e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 2 Sep 2024 20:37:33 -0700 Subject: [PATCH 034/114] shape_uniforms preparation --- .../core/providers/webgpu/program_manager.cc | 9 ++++-- .../core/providers/webgpu/program_manager.h | 6 ++-- .../core/providers/webgpu/shader_helper.cc | 23 ++++++++++++++ .../core/providers/webgpu/shader_helper.h | 31 ++++++++++++++----- .../core/providers/webgpu/webgpu_context.cc | 15 +++++++-- 5 files changed, 69 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index a10412f21f49..ff956b46697c 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -13,8 +13,8 @@ namespace onnxruntime { namespace webgpu { -ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline) - : name{program.Name()}, compute_pipeline{compute_pipeline} {} +ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniforms) + : name{program.Name()}, compute_pipeline{compute_pipeline}, shape_uniforms{shape_uniforms} {} Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); @@ -43,7 +43,8 @@ Status ProgramManager::Build(const ProgramBase& program, uint32_t normalized_dispatch_x, uint32_t normalized_dispatch_y, uint32_t normalized_dispatch_z, - wgpu::ComputePipeline& compute_pipeline) const { + wgpu::ComputePipeline& compute_pipeline, + std::vector& shape_uniforms) const { ShaderHelper shader_helper{program, program_metadata, device_, @@ -55,6 +56,8 @@ Status ProgramManager::Build(const ProgramBase& program, ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); + ORT_RETURN_IF_ERROR(shader_helper.AppendShapeUniformValues(shape_uniforms)); + // code is a large std::string that contains the final shader code std::string code; ORT_RETURN_IF_ERROR(shader_helper.GetFinalSourceCode(code)); diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h index 087c75bfee77..5f4c28a140a5 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.h +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -23,10 +23,11 @@ namespace webgpu { class ProgramArtifact { public: - ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline); + ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniforms); std::string name; wgpu::ComputePipeline compute_pipeline; + std::vector shape_uniforms; ProgramArtifact(ProgramArtifact&&) = default; ProgramArtifact& operator=(ProgramArtifact&&) = default; @@ -49,7 +50,8 @@ class ProgramManager { uint32_t normalized_dispatch_x, uint32_t normalized_dispatch_y, uint32_t normalized_dispatch_z, - wgpu::ComputePipeline& compute_pipeline) const; + wgpu::ComputePipeline& compute_pipeline, + std::vector& shape_uniforms) const; const ProgramArtifact* Get(const std::string& key) const; const ProgramArtifact* Set(const std::string& key, ProgramArtifact&& program); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 054910a7dd57..cf040ddfa427 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -215,6 +215,29 @@ Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderV // todo: add reshaped shape and check return Status::OK(); } + +Status ShaderHelper::AppendShapeUniformValues(std::vector& /*shape_uniforms*/) const { + // TODO: move input/output check(validation) here + // TODO: also check input dependencies with actual usages. + // [deps] [usages] + // input -> use shape && !use_uniform -> OK + // input -> use shape && use_uniform -> err + // input -> !use shape && !use_uniform -> err: must use shape if not using uniform + // input -> !use shape && use_uniform -> + // use_rank -> OK + // !use_rank -> err: must use rank + // + // output -> do not check + + // TODO: tensor shape and strides adding to uniforms (in front) + // when: use_rank && rank >=2 + // info need for codegen: [rank, variable name] content -> "vecN {name}_shape, vecN {name}_strides" + // // further optimization: strides can be vecN-1 + // minimal info stored in artifact: array<[rank, variable name] | not_use > + + return Status::OK(); +} + #endif Status ShaderHelper::GetFinalSourceCode(std::string& code) { diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index e1f008ff6a90..dc46b6275426 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -77,28 +77,41 @@ class ShaderHelper final { Status Init(); + // Add an input variable to the shader. + // + // depending on the usage of the variable, additional code may be generated. const ShaderVariable& AddInput(const std::string& name, ProgramVariableDataType type, ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); + // Add an output variable to the shader. + // + // depending on the usage of the variable, additional code may be generated. const ShaderVariable& AddOutput(const std::string& name, ProgramVariableDataType type, ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); + // Append additional implementation code to the shader. + // + // can be called multiple times. template - inline std::ostringstream& AppendImplementation(Strs&&... impl) { + inline ShaderHelper& AppendImplementation(Strs&&... impl) { onnxruntime::detail::MakeStringImpl(additional_implementation_, std::forward(impl)...); - return additional_implementation_; + return *this; } + // Set the main function body of the shader. + // + // can be called only once. template - inline std::ostringstream& MainFunctionBody(const Strs&... body) { + inline void MainFunctionBody(const Strs&... body) { + ORT_ENFORCE(!body_set_, "Main function body is already set"); onnxruntime::detail::MakeStringImpl(body_, std::forward>(body)...); - return body_; + body_set_ = true; } - std::string GuardAgainstOutOfBoundsWorkgroupSizes(const std::string& size) const { - return " if (global_idx >= " + size + ") { return; }\n"; + std::string GuardAgainstOutOfBoundsWorkgroupSizes(std::string_view size) const { + return MakeStringWithClassicLocale(" if (global_idx >= ", size, ") { return; }\n"); } private: @@ -135,7 +148,9 @@ class ShaderHelper final { Status ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const; Status ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const; #endif - + // Append the uniform values of all shape variables. Including shape/strides of input/output variables, + // if UseUniform is set in the usage of the variable. + Status AppendShapeUniformValues(std::vector& shape_uniforms) const; Status GetFinalSourceCode(std::string& code); friend class ProgramManager; @@ -149,11 +164,11 @@ class ShaderHelper final { const ProgramMetadata& program_metadata_; std::array, static_cast(ProgramVariableScope::Count)> vars_; - std::ostringstream ss2; std::ostringstream additional_implementation_; std::ostringstream body_; bool use_f16_ = false; + bool body_set_ = false; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index e5852d9a3a6a..638bc7c1f7c0 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -212,6 +212,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog const auto* program_artifact = program_mgr_->Get(key); if (program_artifact == nullptr) { wgpu::ComputePipeline compute_pipeline; + std::vector shape_uniforms; auto status = program_mgr_->Build(program, metadata, #ifndef NDEBUG // if debug build @@ -220,15 +221,25 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog x, y, z, - compute_pipeline); + compute_pipeline, + shape_uniforms); ORT_RETURN_IF_ERROR(status); - program_artifact = program_mgr_->Set(key, ProgramArtifact{program, std::move(compute_pipeline)}); + program_artifact = program_mgr_->Set(key, ProgramArtifact{program, + std::move(compute_pipeline), + std::move(shape_uniforms)}); #ifndef NDEBUG // if debug build ORT_ENFORCE(program_artifact != nullptr, "Program artifact should not be nullptr."); #endif } // prepare uniform info + + // TODO: also append artifacts uniform info and fill in actual input/output (override) shape value + + // foreach (uniform in artifact) { + // check if match; + // if match, create ProgramUniformVariableValue + // } const auto& uniforms = program.UniformVariables(); size_t current_offset = 0; std::vector> uniform_and_offsets; From 7d8305445379227709457b46389d02bf77dd048f Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 01:06:34 -0700 Subject: [PATCH 035/114] allow uniforms of input/output shape/stride being added automatically --- .../providers/webgpu/program_cache_key.cc | 1 - .../core/providers/webgpu/program_manager.cc | 14 +- .../core/providers/webgpu/program_manager.h | 10 +- .../core/providers/webgpu/shader_helper.cc | 174 ++++++++++++------ .../core/providers/webgpu/shader_helper.h | 19 +- .../core/providers/webgpu/webgpu_context.cc | 47 +++-- 6 files changed, 177 insertions(+), 88 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index 944fbb0bf8a5..c6ab16a73423 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -57,7 +57,6 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp x != 0 || y != 0 || z != 0) { ss << ":" D("WorkgroupSize="); // only append non-zero values. zero values are considered as use default - // todo: this is actually not working correctly. revisit this logic. currently even if it's default, the value is not zero and will be appended if (x > 0) { ss << x; } diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index ff956b46697c..3e4fbd33a6bd 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/common/common.h" #include "core/common/safeint.h" @@ -13,8 +15,10 @@ namespace onnxruntime { namespace webgpu { -ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniforms) - : name{program.Name()}, compute_pipeline{compute_pipeline}, shape_uniforms{shape_uniforms} {} +ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniform_ranks) + : name{program.Name()}, + compute_pipeline{compute_pipeline}, + shape_uniform_ranks{shape_uniform_ranks} {} Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); @@ -44,7 +48,7 @@ Status ProgramManager::Build(const ProgramBase& program, uint32_t normalized_dispatch_y, uint32_t normalized_dispatch_z, wgpu::ComputePipeline& compute_pipeline, - std::vector& shape_uniforms) const { + std::vector& shape_uniform_ranks) const { ShaderHelper shader_helper{program, program_metadata, device_, @@ -56,11 +60,11 @@ Status ProgramManager::Build(const ProgramBase& program, ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); - ORT_RETURN_IF_ERROR(shader_helper.AppendShapeUniformValues(shape_uniforms)); + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputsAndOutputs()); // code is a large std::string that contains the final shader code std::string code; - ORT_RETURN_IF_ERROR(shader_helper.GetFinalSourceCode(code)); + ORT_RETURN_IF_ERROR(shader_helper.GenerateSourceCode(code, shape_uniform_ranks)); LOGS_DEFAULT(VERBOSE) << "\n=== WebGPU Shader code [" << program.Name() #ifndef NDEBUG // if debug build diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h index 5f4c28a140a5..782788910e3a 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.h +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -23,11 +23,11 @@ namespace webgpu { class ProgramArtifact { public: - ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniforms); + ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniform_ranks); - std::string name; - wgpu::ComputePipeline compute_pipeline; - std::vector shape_uniforms; + const std::string name; + const wgpu::ComputePipeline compute_pipeline; + const std::vector shape_uniform_ranks; ProgramArtifact(ProgramArtifact&&) = default; ProgramArtifact& operator=(ProgramArtifact&&) = default; @@ -51,7 +51,7 @@ class ProgramManager { uint32_t normalized_dispatch_y, uint32_t normalized_dispatch_z, wgpu::ComputePipeline& compute_pipeline, - std::vector& shape_uniforms) const; + std::vector& shape_uniform_ranks) const; const ProgramArtifact* Get(const std::string& key) const; const ProgramArtifact* Set(const std::string& key, ProgramArtifact&& program); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index cf040ddfa427..d06a6573ab2b 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -35,9 +35,9 @@ Status ShaderHelper::Init() { // dispatch group size is normalized so no need to validate it here // validate workgroup size - auto workgroup_size_x = program_.WorkgroupSizeX() == 0 ? WORKGROUP_SIZE : program_.WorkgroupSizeX(); - auto workgroup_size_y = program_.WorkgroupSizeY() == 0 ? 1 : program_.WorkgroupSizeY(); - auto workgroup_size_z = program_.WorkgroupSizeZ() == 0 ? 1 : program_.WorkgroupSizeZ(); + auto workgroup_size_x = program_.WorkgroupSizeX(); + auto workgroup_size_y = program_.WorkgroupSizeY(); + auto workgroup_size_z = program_.WorkgroupSizeZ(); ORT_RETURN_IF_NOT(workgroup_size_x <= limits_.maxComputeWorkgroupSizeX && workgroup_size_y <= limits_.maxComputeWorkgroupSizeY && @@ -163,16 +163,6 @@ Status ValidateVariableShape(const TensorShape& origin_shape, "Tensor original shape ", origin_shape, " cannot be divided by component number ", num_components); } - // if (use_uniform) { - // const auto& rank = std::get(rank_or_shape); - // ORT_RETURN_IF_NOT(rank == SafeInt(override_shape.NumDimensions()), - // "Shader variable rank ", rank, " does not match the tensor shape ", override_shape); - // } else { - // const auto& shape = std::get>(rank_or_shape).get(); - // ORT_RETURN_IF(use_override_shape, "Cannot specify both variable shape and program input/output shape override"); - // ORT_RETURN_IF_NOT(origin_shape.Size() == shape.Size() * num_components, - // "Tensor original shape ", origin_shape, " cannot reshape to ", shape, " with component number ", num_components); - // } return Status::OK(); } } // namespace @@ -211,36 +201,75 @@ Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVar } Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const { ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); - - // todo: add reshaped shape and check + ORT_RETURN_IF_ERROR(ValidateVariableShape(output.tensor->Shape(), + output.use_override_shape, + output.use_override_shape ? output.override_shape : output.tensor->Shape(), + var.num_components_)); return Status::OK(); } -Status ShaderHelper::AppendShapeUniformValues(std::vector& /*shape_uniforms*/) const { - // TODO: move input/output check(validation) here - // TODO: also check input dependencies with actual usages. - // [deps] [usages] - // input -> use shape && !use_uniform -> OK - // input -> use shape && use_uniform -> err - // input -> !use shape && !use_uniform -> err: must use shape if not using uniform - // input -> !use shape && use_uniform -> - // use_rank -> OK - // !use_rank -> err: must use rank - // - // output -> do not check +Status ShaderHelper::ValidateShapeForInputsAndOutputs() const { + const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; + const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; + + // Validate input/output as dependencies of shape_uniforms + ORT_RETURN_IF_NOT(input_vars.size() == program_.Inputs().size(), + "Mismatched input variable count. Shader: ", input_vars.size(), ", Program: ", program_.Inputs().size()); + ORT_RETURN_IF_NOT(output_vars.size() == program_.Outputs().size(), + "Mismatched output variable count. Shader: ", output_vars.size(), ", Program: ", program_.Outputs().size()); + + for (size_t i = 0; i < input_vars.size(); i++) { +#ifndef NDEBUG // if debug build + // Validate input shape + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], input_vars[i])); +#endif - // TODO: tensor shape and strides adding to uniforms (in front) - // when: use_rank && rank >=2 - // info need for codegen: [rank, variable name] content -> "vecN {name}_shape, vecN {name}_strides" - // // further optimization: strides can be vecN-1 - // minimal info stored in artifact: array<[rank, variable name] | not_use > + // check input dependencies with actual usages. + auto usage = input_vars[i].usage_; + bool use_uniform = (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform; + auto dependency = program_.Inputs()[i].dependency; + bool use_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; + bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + if (use_uniform) { + ORT_RETURN_IF_NOT((use_rank || input_vars[i].rank_ < 2) && !use_shape, + "When UseUniform is set in variable usage, the corresponding program input should depend on rank but not shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program input should depend on shape."); + // If you want neither hard-coded shape nor shape uniform, set UseUniform with a flattened shape (rank=1). + // This will not generate any shape variables in the shader, can you can only use offset to set/get values. + } + } + + for (size_t i = 0; i < output_vars.size(); i++) { +#ifndef NDEBUG // if debug build + // Validate output shape + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], output_vars[i])); +#endif + + // check output dependencies with actual usages. + auto usage = output_vars[i].usage_; + bool use_uniform = (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform; + auto dependency = program_.Outputs()[i].dependency; + bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + + if (use_uniform) { + // output tensor shape check is looser than input tensor shape check, because output shape is always calculated so it is not + // necessarily a part of the cache key. + ORT_RETURN_IF_NOT(!use_shape, + "When UseUniform is set in variable usage, the corresponding program output should not depend on shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program output should depend on shape."); + } + } return Status::OK(); } #endif -Status ShaderHelper::GetFinalSourceCode(std::string& code) { +Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -255,9 +284,10 @@ Status ShaderHelper::GetFinalSourceCode(std::string& code) { // // Section constants // - ss << "const workgroup_size_x: u32 = " << program_.WorkgroupSizeX() - << ";\nconst workgroup_size_y: u32 = " << program_.WorkgroupSizeY() - << ";\nconst workgroup_size_z: u32 = " << program_.WorkgroupSizeZ() << ";\n"; + ss << "const workgroup_size_x: u32 = " << (program_.WorkgroupSizeX() == 0 ? uint32_t(WORKGROUP_SIZE) : program_.WorkgroupSizeX()) + << ";\nconst workgroup_size_y: u32 = " << (program_.WorkgroupSizeY() == 0 ? uint32_t(1) : program_.WorkgroupSizeY()) + << ";\nconst workgroup_size_z: u32 = " << (program_.WorkgroupSizeZ() == 0 ? uint32_t(1) : program_.WorkgroupSizeZ()) + << ";\n"; for (const auto& constant : program_metadata_.constants) { ss << "const " << constant.name << ": " << constant.type << " = "; @@ -285,44 +315,44 @@ Status ShaderHelper::GetFinalSourceCode(std::string& code) { // size_t variable_count = 0; const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; - ORT_RETURN_IF_NOT(input_vars.size() == program_.Inputs().size(), - "Mismatched input variable count. Shader: ", variable_count, ", Program: ", program_.Inputs().size()); for (const auto& input : input_vars) { -#ifndef NDEBUG // if debug build - ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[variable_count], input)); -#endif ss << "@group(0) @binding(" << variable_count++ << ") var " << input.name_ << ": array<" << input.StorageType() << ">;\n"; } const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; - ORT_RETURN_IF_NOT(output_vars.size() == program_.Outputs().size(), - "Mismatched output variable count. Shader: ", variable_count, ", Program: ", program_.Outputs().size()); for (const auto& output : output_vars) { -#ifndef NDEBUG // if debug build - ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[variable_count - input_vars.size()], output)); -#endif ss << "@group(0) @binding(" << variable_count++ << ") var " << output.name_ << ": array<" << output.StorageType() << ">;\n"; } // // uniform variables // - if (std::any_of(program_.UniformVariables().cbegin(), - program_.UniformVariables().cend(), - [](const ProgramUniformVariableValue& x) { return x.length > 0; })) { - bool first = true; - ss << "struct Uniforms {"; - size_t uniform_count = program_.UniformVariables().size(); - for (size_t i = 0; i < uniform_count; i++) { - const auto& uniform_def = program_metadata_.uniform_variables[i]; - const auto& uniform_value = program_.UniformVariables()[i]; + // store shape uniform ranks in shape_uniform_ranks + bool use_any_shape_uniform = false; + ORT_ENFORCE(shape_uniform_ranks.size() == 0); + shape_uniform_ranks.reserve(input_vars.size() + output_vars.size()); + + for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { + bool use_uniform = (input.usage_ & ShaderVariable::UseUniform) == ShaderVariable::UseUniform && input.rank_ > 1; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? input.rank_ : 0); + } + for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { + bool use_uniform = (output.usage_ & ShaderVariable::UseUniform) == ShaderVariable::UseUniform && output.rank_ > 1; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? output.rank_ : 0); + } - const auto& name = uniform_def.name; - const auto& data_type = uniform_def.data_type; - const auto length = uniform_value.length; + if (use_any_shape_uniform || std::any_of(program_.UniformVariables().cbegin(), + program_.UniformVariables().cend(), + [](const ProgramUniformVariableValue& x) { return x.length > 0; })) { + bool first = true; + ss << "struct Uniforms {"; + // lambda append_uniform is used to append one uniform variable to the uniform struct + auto append_uniform = [&ss, &first](std::string_view name, ProgramUniformVariableDataType data_type, size_t length) { if (length == 0) { - continue; + return; } if (first) { @@ -346,6 +376,30 @@ Status ShaderHelper::GetFinalSourceCode(std::string& code) { } else { ss << data_type; } + }; + + for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { + if (input.rank_ > 1) { + std::string shape = input.name_ + "_shape"; + std::string stride = input.name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, input.rank_); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, input.rank_); + } + } + + for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { + if (output.rank_ > 1) { + std::string shape = output.name_ + "_shape"; + std::string stride = output.name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, output.rank_); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, output.rank_); + } + } + + for (size_t i = 0; i < program_.UniformVariables().size(); i++) { + const auto& uniform_def = program_metadata_.uniform_variables[i]; + const auto& uniform_value = program_.UniformVariables()[i]; + append_uniform(uniform_def.name, uniform_def.data_type, uniform_value.length); } ss << "\n};\n" @@ -368,13 +422,11 @@ Status ShaderHelper::GetFinalSourceCode(std::string& code) { // Additional Implementation // ss << additional_implementation_.str(); - additional_implementation_.str(""); // // Main Function Body // ss << body_.str(); - body_.str(""); ss << "\n" "}\n"; diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index dc46b6275426..bb04c4ad628a 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -148,10 +148,21 @@ class ShaderHelper final { Status ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const; Status ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const; #endif - // Append the uniform values of all shape variables. Including shape/strides of input/output variables, - // if UseUniform is set in the usage of the variable. - Status AppendShapeUniformValues(std::vector& shape_uniforms) const; - Status GetFinalSourceCode(std::string& code); + + Status ShaderHelper::ValidateShapeForInputsAndOutputs() const; + + // Generate source code. + // + // This function: + // - performs validation if neccessary, + // - appends the ranks for variables to the shape_uniform_ranks. + // (The rank value is zero if no uniform is needed for the variable.) + // - generates the final source code. + // + // \param code The generated full WGSL source code. + // \param shape_uniform_ranks The ranks for variables that need a uniform for the shape. + // + Status GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const; friend class ProgramManager; const wgpu::Device& device_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 638bc7c1f7c0..7c9763d6937f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -212,7 +212,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog const auto* program_artifact = program_mgr_->Get(key); if (program_artifact == nullptr) { wgpu::ComputePipeline compute_pipeline; - std::vector shape_uniforms; + std::vector shape_uniform_ranks; auto status = program_mgr_->Build(program, metadata, #ifndef NDEBUG // if debug build @@ -222,29 +222,52 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog y, z, compute_pipeline, - shape_uniforms); + shape_uniform_ranks); ORT_RETURN_IF_ERROR(status); program_artifact = program_mgr_->Set(key, ProgramArtifact{program, std::move(compute_pipeline), - std::move(shape_uniforms)}); + std::move(shape_uniform_ranks)}); #ifndef NDEBUG // if debug build ORT_ENFORCE(program_artifact != nullptr, "Program artifact should not be nullptr."); #endif } - // prepare uniform info + // prepare shape uniforms for shader variables (if any) and user defined uniforms + std::vector shape_uniforms; + shape_uniforms.reserve(program_artifact->shape_uniform_ranks.size() * 2); + ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size(), + "Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(), + ") does not match current program (input: ", inputs.size(), ", output: ", outputs.size(), ")"); + for (size_t i = 0; i < program_artifact->shape_uniform_ranks.size(); ++i) { + SafeInt expected_rank = program_artifact->shape_uniform_ranks[i]; + if (expected_rank > 0) { + const auto& shape = i < inputs.size() ? (inputs[i].use_override_shape ? inputs[i].override_shape + : inputs[i].tensor->Shape()) + : (outputs[i - inputs.size()].use_override_shape ? outputs[i - inputs.size()].override_shape + : outputs[i - inputs.size()].tensor->Shape()); + ORT_RETURN_IF(expected_rank != shape.NumDimensions(), + "Invalid program artifact: variable[", i, "] rank mismatch. Expected: ", (int)expected_rank, + ", Actual: ", shape.NumDimensions()); + + std::vector dims(shape.NumDimensions()); + std::vector stride(shape.NumDimensions()); + for (size_t j = 0; j < shape.NumDimensions(); ++j) { + dims[j] = SafeInt(shape[j]); + stride[j] = SafeInt(shape.SizeFromDimension(j)); + } - // TODO: also append artifacts uniform info and fill in actual input/output (override) shape value + shape_uniforms.emplace_back(gsl::make_span(dims)); + shape_uniforms.emplace_back(gsl::make_span(stride)); + } + } - // foreach (uniform in artifact) { - // check if match; - // if match, create ProgramUniformVariableValue - // } - const auto& uniforms = program.UniformVariables(); + const size_t uniform_count = shape_uniforms.size() + program.UniformVariables().size(); size_t current_offset = 0; std::vector> uniform_and_offsets; - uniform_and_offsets.reserve(uniforms.size()); - for (const auto& uniform : uniforms) { + uniform_and_offsets.reserve(uniform_count); + for (size_t i = 0; i < uniform_count; i++) { + const auto& uniform = i < shape_uniforms.size() ? shape_uniforms[i] + : program.UniformVariables()[i - shape_uniforms.size()]; bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; size_t length = uniform.length; From 1d53ac89429586768170aaef23133983b0e78bc3 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 02:18:41 -0700 Subject: [PATCH 036/114] fix build (linux) --- onnxruntime/core/providers/webgpu/shader_helper.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index bb04c4ad628a..ca1bf9ce7ff5 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -149,7 +149,7 @@ class ShaderHelper final { Status ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const; #endif - Status ShaderHelper::ValidateShapeForInputsAndOutputs() const; + Status ValidateShapeForInputsAndOutputs() const; // Generate source code. // From 4d52602a208daae68fab6510eabac5eeb5aac87f Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 02:21:21 -0700 Subject: [PATCH 037/114] fix stride --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 7c9763d6937f..755ebbfd174c 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -253,7 +253,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog std::vector stride(shape.NumDimensions()); for (size_t j = 0; j < shape.NumDimensions(); ++j) { dims[j] = SafeInt(shape[j]); - stride[j] = SafeInt(shape.SizeFromDimension(j)); + stride[j] = SafeInt(shape.SizeFromDimension(j + 1)); } shape_uniforms.emplace_back(gsl::make_span(dims)); From 3761aad4bd668af38813a330959d70ed30fb2060 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 02:38:24 -0700 Subject: [PATCH 038/114] fix "{res_name}_bi2o_{name}" --- onnxruntime/core/providers/webgpu/shader_variable.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index ef80fd3c57f6..d19116f57099 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -116,11 +116,11 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS(" return 0;\n"); } else { SS(" return "); - for (int i = 0; i < rank_ - 1; i++) { + for (int i = rank_ - 1; i >= 0; i--) { auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); SS(IndicesGet(stride, i), " * (", idx, " % ", IndicesGet(shape, i), ") + "); } - SS(broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); + SS("0;\n"); } SS("}\n"); } From 351da844d364d7c762deea845e7e4a81d816c8b2 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 4 Sep 2024 02:57:38 +0800 Subject: [PATCH 039/114] Add Expand operator (#21933) ### Description ### Motivation and Context --- .../core/providers/webgpu/shader_helper.cc | 4 +- .../core/providers/webgpu/shader_variable.cc | 4 +- .../core/providers/webgpu/tensor/expand.cc | 95 +++++++++++++++++++ .../core/providers/webgpu/tensor/expand.h | 30 ++++++ .../webgpu/webgpu_execution_provider.cc | 4 +- 5 files changed, 131 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/expand.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/expand.h diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index d06a6573ab2b..e6ae5ae0d940 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -379,7 +379,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha }; for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - if (input.rank_ > 1) { + if (input.rank_ > 1 && (input.usage_ & ShaderVariable::Usage::UseUniform)) { std::string shape = input.name_ + "_shape"; std::string stride = input.name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, input.rank_); @@ -388,7 +388,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha } for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - if (output.rank_ > 1) { + if (output.rank_ > 1 && (output.usage_ & ShaderVariable::Usage::UseUniform)) { std::string shape = output.name_ + "_shape"; std::string stride = output.name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, output.rank_); diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index d19116f57099..a652d720dbf7 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -64,11 +64,11 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS("const ", stride, " = ", indices_type, "("); first = true; - for (int i = rank_ - 1; i >= 0; i--) { + for (int i = 1; i <= rank_; i++) { if (!first) { ss << ","; } - ss << dims_.SizeToDimension(i); + ss << dims_.SizeFromDimension(i); first = false; } ss << ");\n"; diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc new file mode 100644 index 000000000000..4d241da54415 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/expand.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { +Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) { + size_t lhs_rank = lhs_shape.NumDimensions(); + size_t rhs_rank = rhs_shape.NumDimensions(); + size_t out_rank = std::max(lhs_rank, rhs_rank); + + std::vector output_dims(out_rank, 0); + for (size_t i = 0; i < out_rank; ++i) { + int64_t lhs_dim = 1; + if (i < lhs_rank) + lhs_dim = lhs_shape[lhs_rank - 1 - i]; + int64_t rhs_dim = 1; + if (i < rhs_rank) + rhs_dim = rhs_shape[rhs_rank - 1 - i]; + int64_t max = std::max(lhs_dim, rhs_dim); + int64_t min = std::min(lhs_dim, rhs_dim); + int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0. + if (lhs_dim != out_dim && lhs_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i, + " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); + if (rhs_dim != out_dim && rhs_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i, + " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); + output_dims[out_rank - 1 - i] = out_dim; + } + out_shape = TensorShape(output_dims); + return Status::OK(); +} +} // namespace + +Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", + ToProgramVariableDataType(Inputs()[0].tensor->GetElementType()), + ShaderVariable::UseUniform); + const auto& output = shader.AddOutput("output", + ToProgramVariableDataType(Outputs()[0].tensor->GetElementType()), + ShaderVariable::UseUniform); + + shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + "let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", + "let input_offset = ", input.BroadcastedIndicesToOffset("output_indices", output), ";\n", + output.SetByOffset("global_idx", input.GetByOffset("input_offset"))); + + return Status::OK(); +} + +Status Expand::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const auto* input_shape_tensor = context.Input(1); + + const auto* p_shape = input_shape_tensor->Data(); + TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()}; + TensorShape output_shape(output_dims); + ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape)); + + auto* output_tensor = context.Output(0, output_shape); + SafeInt vec_size = output_shape.Size(); + ExpandProgram program{"Expand"}; + program + .Inputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .Outputs({{output_tensor, ProgramTensorMetadataDependency::Rank}}) + .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .UniformVariables({ + {static_cast(vec_size)}, + }); + return context.RunProgram(program); +} + +#define WEBGPU_EXPAND_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \ + KERNEL_CLASS); + +#define WEBGPU_EXPAND_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \ + KERNEL_CLASS); + +WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedFloatTypes()) +WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedFloatTypes()) + +} // namespace webgpu +}; // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.h b/onnxruntime/core/providers/webgpu/tensor/expand.h new file mode 100644 index 000000000000..a5c24f1fa496 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/expand.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class ExpandProgram final : public Program { + public: + ExpandProgram(const std::string& kernel_name) : Program{kernel_name} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); +}; + +class Expand final : public WebGpuKernel { + public: + Expand(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 202742a1c79b..1ee7a51618f7 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -622,8 +622,8 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, From 0b7ce771a7c8008607a223a2e9ed4fceab333b82 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 13:36:52 -0700 Subject: [PATCH 040/114] support onnxruntime_test_all --- .../webgpu/webgpu_provider_factory.cc | 18 +++++++++--------- .../webgpu/webgpu_provider_factory_creator.h | 6 ++++-- .../core/session/provider_registration.cc | 2 +- onnxruntime/test/providers/base_tester.cc | 3 +++ onnxruntime/test/util/default_providers.cc | 7 ++++++- 5 files changed, 23 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index e871b66f1dc9..3848ccfc19f5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -31,7 +31,7 @@ struct WebGpuProviderFactory : IExecutionProviderFactory { WebGpuExecutionProviderInfo info_; }; -std::shared_ptr WebGpuProviderFactoryCreator::Create(const SessionOptions* session_options) { +std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { // // STEP.1 - prepare WebGpuExecutionProviderInfo // @@ -43,7 +43,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( }; std::string preferred_layout_str; - if (session_options->config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { + if (config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { if (preferred_layout_str == kPreferredLayout_NHWC) { webgpu_ep_info.data_layout = DataLayout::NHWC; } else if (preferred_layout_str == kPreferredLayout_NCHW) { @@ -56,7 +56,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( << preferred_layout_str << "\")"; std::string enable_graph_capture_str; - if (session_options->config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { + if (config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { if (enable_graph_capture_str == kkEnableGraphCapture_ON) { webgpu_ep_info.enable_graph_capture = true; } else if (enable_graph_capture_str == kkEnableGraphCapture_OFF) { @@ -67,10 +67,10 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( } LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture; - auto parse_buffer_cache_mode = [session_options](const std::string& config_entry_str, + auto parse_buffer_cache_mode = [&config_options](const std::string& config_entry_str, webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode { std::string buffer_cache_mode_str; - if (session_options->config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { + if (config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { return webgpu::BufferCacheMode::Disabled; } else if (buffer_cache_mode_str == kBufferCacheMode_LazyRelease) { @@ -104,14 +104,14 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( // int context_id = 0; std::string context_id_str; - if (session_options->config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { + if (config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { ORT_ENFORCE(std::errc{} == std::from_chars(context_id_str.data(), context_id_str.data() + context_id_str.size(), context_id).ec); } size_t webgpu_instance = 0; std::string webgpu_instance_str; - if (session_options->config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { + if (config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { static_assert(sizeof(WGPUInstance) == sizeof(size_t), "WGPUInstance size mismatch"); ORT_ENFORCE(std::errc{} == std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec); @@ -119,7 +119,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( size_t webgpu_adapter = 0; std::string webgpu_adapter_str; - if (session_options->config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) { + if (config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) { static_assert(sizeof(WGPUAdapter) == sizeof(size_t), "WGPUAdapter size mismatch"); ORT_ENFORCE(std::errc{} == std::from_chars(webgpu_adapter_str.data(), webgpu_adapter_str.data() + webgpu_adapter_str.size(), webgpu_adapter).ec); @@ -127,7 +127,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( size_t webgpu_device = 0; std::string webgpu_device_str; - if (session_options->config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) { + if (config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) { static_assert(sizeof(WGPUDevice) == sizeof(size_t), "WGPUDevice size mismatch"); ORT_ENFORCE(std::errc{} == std::from_chars(webgpu_device_str.data(), webgpu_device_str.data() + webgpu_device_str.size(), webgpu_device).ec); diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h index 7fac9234b949..e0030a3ec2a1 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h @@ -8,11 +8,13 @@ #include "core/framework/provider_options.h" #include "core/providers/providers.h" +#include "core/providers/webgpu/webgpu_provider_options.h" + namespace onnxruntime { -struct SessionOptions; +struct ConfigOptions; struct WebGpuProviderFactoryCreator { - static std::shared_ptr Create(const SessionOptions* session_options); + static std::shared_ptr Create(const ConfigOptions& config_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index da97cdc25ab1..156b59a7af10 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -135,7 +135,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, #endif } else if (strcmp(provider_name, "WebGPU") == 0) { #if defined(USE_WEBGPU) - options->provider_factories.push_back(WebGpuProviderFactoryCreator::Create(&(options->value))); + options->provider_factories.push_back(WebGpuProviderFactoryCreator::Create(options->value.config_options)); #else status = create_not_supported_status(); #endif diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 01de15e6f8ec..dea39bc99d3e 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -657,6 +657,7 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, kQnnExecutionProvider, kSnpeExecutionProvider, kXnnpackExecutionProvider, + kWebGpuExecutionProvider, }; // need to special case any synthetic EP names in the exclude list @@ -712,6 +713,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, execution_provider = DefaultXnnpackExecutionProvider(); else if (provider_type == onnxruntime::kDmlExecutionProvider) execution_provider = DefaultDmlExecutionProvider(); + else if (provider_type == onnxruntime::kWebGpuExecutionProvider) + execution_provider = DefaultWebGpuExecutionProvider(); // skip if execution provider is disabled if (execution_provider == nullptr) diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 871285269daf..c9c64003ddab 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -303,7 +303,12 @@ std::unique_ptr DefaultXnnpackExecutionProvider() { std::unique_ptr DefaultWebGpuExecutionProvider() { #ifdef USE_WEBGPU - return WebGpuProviderFactoryCreator::Create(nullptr)->CreateProvider(); + ConfigOptions config_options{}; + // Disable storage buffer cache + ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode, + webgpu::options::kBufferCacheMode_Disabled) + .IsOK()); + return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); #else return nullptr; #endif From 33726b1aa5f435a0d6e7701b2de6eff29404f66d Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 3 Sep 2024 15:34:00 -0700 Subject: [PATCH 041/114] reflect change in WebGpuProviderFactoryCreator::Create signature (#21971) reflect change in WebGpuProviderFactoryCreator::Create signature --- onnxruntime/python/onnxruntime_pybind_state.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 036585586d9a..01889df8fec1 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1210,7 +1210,7 @@ std::unique_ptr CreateExecutionProviderInstance( #endif } else if (type == kWebGpuExecutionProvider) { #if defined(USE_WEBGPU) - return onnxruntime::WebGpuProviderFactoryCreator::Create(&session_options)->CreateProvider(); + return onnxruntime::WebGpuProviderFactoryCreator::Create(session_options.config_options)->CreateProvider(); #endif } else if (type == kCannExecutionProvider) { #ifdef USE_CANN From 50ea9eb959855544f4f42d62073bb3c0ad1505c4 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 3 Sep 2024 16:41:49 -0700 Subject: [PATCH 042/114] compare the content of WEBGPU_BUFFER, not the address (#21967) On linux (not sure about windows) WEBGPU_BUFFER is defined in multiple object files and comparing the address is not sufficient - use strcmp. onnxruntime uses strcmp for the most but there are some other places that compare against address which might make trouble if passed acrross object file boundary. --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 755ebbfd174c..343da693c716 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -140,7 +140,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog return tensor != nullptr && tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && tensor->Location().device.Type() == OrtDevice::GPU && - tensor->Location().name == WEBGPU_BUFFER; + !strcmp(tensor->Location().name, WEBGPU_BUFFER); }), "All inputs must be tensors on WebGPU buffers."); @@ -149,7 +149,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog return tensor != nullptr && tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && tensor->Location().device.Type() == OrtDevice::GPU && - tensor->Location().name == WEBGPU_BUFFER; + !strcmp(tensor->Location().name, WEBGPU_BUFFER); }), "All outputs must be tensors on WebGPU buffers."); #endif From d6f6148fd58e33287c49ba5c0db13c2cac4d7d30 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 16:38:32 -0700 Subject: [PATCH 043/114] fix tanh --- .../core/providers/webgpu/math/unary_elementwise_ops.cc | 4 +++- .../test/providers/cpu/math/element_wise_ops_test.cc | 9 ++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 97dd2c598463..9d47cab34729 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -167,7 +167,9 @@ WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)") WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) -WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh(a)") +// built-in function tanh() does not work with large input (f32 88.7 or f16 11.09) +// https://github.com/gpuweb/gpuweb/issues/4458 +WEBGPU_ELEMENTWISE_IMPL(Tanh, "sign(a) * (1 - exp(-2 * abs(a))) / (1 + exp(-2 * abs(a)))") WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index bd3d21d4929f..4ca915dd394c 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -3016,7 +3016,14 @@ TEST(MathOpTest, Tan) { TEST(MathOpTest, Asin) { OpTester test("Asin"); - float abs_error = DefaultDmlExecutionProvider().get() != nullptr ? 0.0001f : -1.0f; + float abs_error = +#ifdef _WIN32 + // Set abs_error to 0.0001f for built-in function asin() in HLSL based EPs (DML and WebGPU) + DefaultDmlExecutionProvider().get() != nullptr || DefaultWebGpuExecutionProvider().get() != nullptr + ? 0.0001f + : +#endif + -1.0f; TrigFloatTest<::asinf>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}, abs_error); } From 626edafbd87e70cdc4040786c92edda0f8845b82 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 17:08:50 -0700 Subject: [PATCH 044/114] support size==0 for element wise operators --- .../core/providers/webgpu/math/unary_elementwise_ops.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 9d47cab34729..079a19221377 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -25,6 +25,9 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { const auto* input_tensor = context.Input(0); auto* output_tensor = context.Output(0, input_tensor->Shape()); int64_t size = input_tensor->Shape().Size(); + if (size == 0) { + return Status::OK(); + } SafeInt vec_size = (size + 3) / 4; UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; program From bacc54cc09e2140a1aa33fb1202dfd2295ecc8c0 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 20:20:58 -0700 Subject: [PATCH 045/114] use shared ComputeBroadcastOutputShape() --- .../core/providers/webgpu/tensor/expand.cc | 34 ++----------------- 1 file changed, 3 insertions(+), 31 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 4d241da54415..53991365d654 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/providers/common.h" + #include "core/providers/webgpu/tensor/expand.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -8,36 +10,6 @@ namespace onnxruntime { namespace webgpu { -namespace { -Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) { - size_t lhs_rank = lhs_shape.NumDimensions(); - size_t rhs_rank = rhs_shape.NumDimensions(); - size_t out_rank = std::max(lhs_rank, rhs_rank); - - std::vector output_dims(out_rank, 0); - for (size_t i = 0; i < out_rank; ++i) { - int64_t lhs_dim = 1; - if (i < lhs_rank) - lhs_dim = lhs_shape[lhs_rank - 1 - i]; - int64_t rhs_dim = 1; - if (i < rhs_rank) - rhs_dim = rhs_shape[rhs_rank - 1 - i]; - int64_t max = std::max(lhs_dim, rhs_dim); - int64_t min = std::min(lhs_dim, rhs_dim); - int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0. - if (lhs_dim != out_dim && lhs_dim != 1) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i, - " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); - if (rhs_dim != out_dim && rhs_dim != 1) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i, - " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); - output_dims[out_rank - 1 - i] = out_dim; - } - out_shape = TensorShape(output_dims); - return Status::OK(); -} -} // namespace - Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ToProgramVariableDataType(Inputs()[0].tensor->GetElementType()), @@ -61,7 +33,7 @@ Status Expand::ComputeInternal(ComputeContext& context) const { const auto* p_shape = input_shape_tensor->Data(); TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()}; TensorShape output_shape(output_dims); - ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape)); + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape)); auto* output_tensor = context.Output(0, output_shape); SafeInt vec_size = output_shape.Size(); From 7ecc5bbaac5e09508dc789add1ce2a8cf96e5338 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 4 Sep 2024 12:11:08 -0700 Subject: [PATCH 046/114] add workgroup_idx --- onnxruntime/core/providers/webgpu/shader_helper.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index e6ae5ae0d940..245de6d7c2ed 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -67,10 +67,11 @@ Status ShaderHelper::Init() { body_ << ") {\n"; if (is_1d_dispatch) { body_ << " let global_idx = global_id.x;\n" - " let local_idx = local_id.x;\n"; + " let local_idx = local_id.x;\n" + " let workgroup_idx = workgroup_id.x;\n"; } else { - body_ << " let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x)\n" - " * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; + body_ << " let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;\n" + " let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; } // init additional implementation string stream From ae836b129c71ee942e5db5a3319eced43dbec279 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 4 Sep 2024 15:05:04 -0700 Subject: [PATCH 047/114] expose name for shader variable --- .../core/providers/webgpu/shader_variable.cc | 162 +++++++++--------- .../core/providers/webgpu/shader_variable.h | 14 +- 2 files changed, 93 insertions(+), 83 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index a652d720dbf7..0b7a7d390057 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -13,13 +13,76 @@ namespace onnxruntime { namespace webgpu { +namespace { +constexpr static const std::string_view STORAGE_TYPE[] = { + "f32", // f32 + "vec2", // vec2f32 + "vec4", // vec4f32 + "f16", // f16 + "vec2", // vec2f16 + "vec4", // vec4f16 + "i32", // i32 + "vec2", // vec2i32 + "vec4", // vec4i32 + "u32", // u32 + "vec2", // vec2u32 + "vec4", // vec4u32 + "vec2", // int64 + "vec2", // uint64 + "u32", // vec4bool +}; + +constexpr static const std::string_view VALUE_TYPE[] = { + "f32", // f32 + "vec2", // vec2f32 + "vec4", // vec4f32 + "f16", // f16 + "vec2", // vec2f16 + "vec4", // vec4f16 + "i32", // i32 + "vec2", // vec2i32 + "vec4", // vec4i32 + "u32", // u32 + "vec2", // vec2u32 + "vec4", // vec4u32 + "i32", // int64 (trancated to i32) + "u32", // uint64 (trancated to u32) + "vec4", // vec4bool +}; + +constexpr static const std::string_view ELEMENT_TYPE[] = { + "f32", // f32 + "f32", // vec2f32 + "f32", // vec4f32 + "f16", // f16 + "f16", // vec2f16 + "f16", // vec4f16 + "i32", // i32 + "i32", // vec2i32 + "i32", // vec4i32 + "u32", // u32 + "u32", // vec2u32 + "u32", // vec4u32 + "i32", // int64 + "u32", // uint64 + "bool", // vec4bool +}; + +} // namespace + ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims) : name_(name), type_(type), num_components_{NumberOfComponents(type)}, rank_{SafeInt(dims.NumDimensions())}, dims_{dims}, - usage_(usage) { + usage_(usage), + indices_type_{rank_ < 2 ? "u32" + : (rank_ < 4 ? MakeStringWithClassicLocale("vec", rank_, "") + : MakeStringWithClassicLocale("array"))}, + value_type_alias_{name_ + "_value_t"}, + element_type_alias_{name_ + "_element_t"}, + indices_type_alias_{name_ + "_indices_t"} { ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_); } @@ -27,29 +90,23 @@ ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType ty void ShaderVariable::Impl(std::ostringstream& ss) const { // Start generating code - const std::string value_t = name_ + "_value_t"; - const std::string element_t = name_ + "_element_t"; - const std::string shape = (usage_ & UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; const std::string stride = (usage_ & UseUniform) ? "uniforms." + name_ + "_stride" : name_ + "_stride"; // Types - std::string_view value_type = (usage_ & UseValueTypeAlias) ? value_t : ValueType(); - const std::string indices_type = (usage_ & UseIndicesTypeAlias) ? name_ + "_indices_t" : IndicesType(); - if (usage_ & UseValueTypeAlias) { - SS("alias ", name_, "_value_t = ", ValueType(), ";\n"); + SS("alias ", value_type_alias_, " = ", VALUE_TYPE[static_cast(type_)], ";\n"); } if (usage_ & UseIndicesTypeAlias) { - SS("alias ", name_, "_indices_t = ", IndicesType(), ";\n"); + SS("alias ", indices_type_alias_, " = ", indices_type_, ";\n"); } if (usage_ & UseElementTypeAlias) { - SS("alias ", name_, "_element_t = ", ElementType(), ";\n"); + SS("alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); } // Need shape and strides when (not use uniform) and (any other usage is enabled) if (!(usage_ & UseUniform) && (usage_ & ~UseUniform)) { - SS("const ", shape, " = ", indices_type, "("); + SS("const ", shape, " = ", IndicesType(), "("); bool first = true; for (auto dim : dims_.GetDims()) { @@ -62,7 +119,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } ss << ");\n"; - SS("const ", stride, " = ", indices_type, "("); + SS("const ", stride, " = ", IndicesType(), "("); first = true; for (int i = 1; i <= rank_; i++) { if (!first) { @@ -77,8 +134,8 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn o2i_{name}" if (usage_ & UseOffsetToIndices) { if (rank_ >= 2) { - SS("fn o2i_", name_, "(offset : u32)->", indices_type, " {\n"); - SS(" var indices: ", indices_type, ";\n"); + SS("fn o2i_", name_, "(offset : u32)->", IndicesType(), " {\n"); + SS(" var indices: ", IndicesType(), ";\n"); SS(" var current = offset;\n"); for (int i = 0; i < rank_ - 1; i++) { auto current_stride = GetElementAt(stride, i, rank_); @@ -96,7 +153,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn i2o_{name}" if (usage_ & UseIndicesToOffset) { if (rank_ >= 2) { - SS("fn i2o_", name_, "(indices : ", indices_type, ")->u32 {\n"); + SS("fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); SS(" return "); for (int i = 0; i < rank_ - 1; i++) { SS("indices[", i, "] * ", GetElementAt(stride, i, rank_), " + "); @@ -111,7 +168,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // TODO: do we need this if rank < 2? for (const auto& iter : broadcasted_to_) { const auto& broadcasted_result = iter.get(); - SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.IndicesType(), ")->u32 {\n"); + SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); if (rank_ == 0) { SS(" return 0;\n"); } else { @@ -133,7 +190,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { for (int i = 1; i < rank_; i++) { SS(", d", i, ": u32"); } - SS(", value: ", value_type, ") {\n"); + SS(", value: ", ValueType(), ") {\n"); SS(" set_", name_, "_by_indices(d0"); for (int i = 1; i < rank_; i++) { SS(", d", i); @@ -146,7 +203,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn set_{name}_by_indices" if (usage_ & UseSetByIndices) { if (rank_ >= 2) { - SS("fn set_", name_, "_by_indices(indices: ", indices_type, ", value: ", value_type, ") {\n"); + SS("fn set_", name_, "_by_indices(indices: ", IndicesType(), ", value: ", ValueType(), ") {\n"); SS(" ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); SS("}\n"); } @@ -159,7 +216,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { for (int i = 1; i < rank_; i++) { SS(", d", i, ": u32"); } - SS(")->", value_type, " {\n"); + SS(")->", ValueType(), " {\n"); SS(" return get_", name_, "_by_indices(d0"); for (int i = 1; i < rank_; i++) { SS(", d", i); @@ -172,7 +229,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn get_{name}_by_indices" if (usage_ & UseGetByIndices) { if (rank_ >= 2) { - SS("fn get_", name_, "_by_indices(indices: ", indices_type, ")->", value_type, " {\n"); + SS("fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n"); SS(" return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); SS("}\n"); } @@ -232,76 +289,19 @@ std::string ShaderVariable::SetByOffsetImpl(std::string_view offset, std::string } std::string_view ShaderVariable::StorageType() const { - constexpr static const std::string_view STORAGE_TYPE[] = { - "f32", // f32 - "vec2", // vec2f32 - "vec4", // vec4f32 - "f16", // f16 - "vec2", // vec2f16 - "vec4", // vec4f16 - "i32", // i32 - "vec2", // vec2i32 - "vec4", // vec4i32 - "u32", // u32 - "vec2", // vec2u32 - "vec4", // vec4u32 - "vec2", // int64 - "vec2", // uint64 - "u32", // vec4bool - }; - return STORAGE_TYPE[static_cast(type_)]; } std::string_view ShaderVariable::ValueType() const { - constexpr static const std::string_view VALUE_TYPE[] = { - "f32", // f32 - "vec2", // vec2f32 - "vec4", // vec4f32 - "f16", // f16 - "vec2", // vec2f16 - "vec4", // vec4f16 - "i32", // i32 - "vec2", // vec2i32 - "vec4", // vec4i32 - "u32", // u32 - "vec2", // vec2u32 - "vec4", // vec4u32 - "i32", // int64 (trancated to i32) - "u32", // uint64 (trancated to u32) - "vec4", // vec4bool - }; - - return VALUE_TYPE[static_cast(type_)]; + return (usage_ & UseValueTypeAlias) ? value_type_alias_ : VALUE_TYPE[static_cast(type_)]; } std::string_view ShaderVariable::ElementType() const { - constexpr static const std::string_view ELEMENT_TYPE[] = { - "f32", // f32 - "f32", // vec2f32 - "f32", // vec4f32 - "f16", // f16 - "f16", // vec2f16 - "f16", // vec4f16 - "i32", // i32 - "i32", // vec2i32 - "i32", // vec4i32 - "u32", // u32 - "u32", // vec2u32 - "u32", // vec4u32 - "i32", // int64 - "u32", // uint64 - "bool", // vec4bool - }; - - return ELEMENT_TYPE[static_cast(type_)]; + return (usage_ & UseElementTypeAlias) ? element_type_alias_ : ELEMENT_TYPE[static_cast(type_)]; } -std::string ShaderVariable::IndicesType() const { - return rank_ < 2 ? "u32" - : (rank_ < 4 ? MakeStringWithClassicLocale("vec", rank_, "") - : MakeStringWithClassicLocale("array")); +std::string_view ShaderVariable::IndicesType() const { + return (usage_ & UseIndicesTypeAlias) ? indices_type_alias_ : indices_type_; } - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 86eaaac5e159..778017a50dda 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -58,6 +58,9 @@ class ShaderVariable { ShaderVariable(ShaderVariable&&) = default; ShaderVariable& operator=(ShaderVariable&&) = default; + // get the name of the variable. + std::string_view Name() const; + // create a WGSL expression ({varname}_indices_t) for getting indices from offset. // \param offset: a WGSL expression (u32) representing the offset. inline std::string OffsetToIndices(std::string_view offset_expr) const; @@ -131,11 +134,10 @@ class ShaderVariable { std::string GetByOffsetImpl(std::string_view offset) const; std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const; - std::string_view StorageType() const; std::string_view ValueType() const; std::string_view ElementType() const; - std::string IndicesType() const; + std::string_view IndicesType() const; std::string name_; ProgramVariableDataType type_; @@ -146,6 +148,14 @@ class ShaderVariable { mutable Usage usage_; mutable std::vector> broadcasted_to_; + // unlike storage/element/value type, indices type is not a string view to a constant string. so we need to store it. + std::string indices_type_; + + // the alias for the types + std::string value_type_alias_; + std::string element_type_alias_; + std::string indices_type_alias_; + friend class ShaderHelper; }; From 243078b0de15bbfebbf8b639da7d02ea8863ed70 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 4 Sep 2024 17:08:16 -0700 Subject: [PATCH 048/114] add uniform for 1D variable --- .../core/providers/webgpu/shader_helper.cc | 18 ++--- .../core/providers/webgpu/shader_variable.cc | 66 ++++++++++--------- .../core/providers/webgpu/shader_variable.h | 6 +- .../core/providers/webgpu/webgpu_context.cc | 21 +++--- 4 files changed, 62 insertions(+), 49 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 245de6d7c2ed..cd3507a6439a 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -334,12 +334,12 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha shape_uniform_ranks.reserve(input_vars.size() + output_vars.size()); for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - bool use_uniform = (input.usage_ & ShaderVariable::UseUniform) == ShaderVariable::UseUniform && input.rank_ > 1; + bool use_uniform = (input.usage_ & ShaderVariable::UseUniform) && input.rank_ > 0; use_any_shape_uniform |= use_uniform; shape_uniform_ranks.push_back(use_uniform ? input.rank_ : 0); } for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - bool use_uniform = (output.usage_ & ShaderVariable::UseUniform) == ShaderVariable::UseUniform && output.rank_ > 1; + bool use_uniform = (output.usage_ & ShaderVariable::UseUniform) && output.rank_ > 0; use_any_shape_uniform |= use_uniform; shape_uniform_ranks.push_back(use_uniform ? output.rank_ : 0); } @@ -380,20 +380,22 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha }; for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - if (input.rank_ > 1 && (input.usage_ & ShaderVariable::Usage::UseUniform)) { + const size_t rank = input.rank_; + if (rank > 0 && (input.usage_ & ShaderVariable::Usage::UseUniform)) { std::string shape = input.name_ + "_shape"; std::string stride = input.name_ + "_stride"; - append_uniform(shape, ProgramUniformVariableDataType::Uint32, input.rank_); - append_uniform(stride, ProgramUniformVariableDataType::Uint32, input.rank_); + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); } } for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - if (output.rank_ > 1 && (output.usage_ & ShaderVariable::Usage::UseUniform)) { + const size_t rank = output.rank_; + if (rank > 0 && (output.usage_ & ShaderVariable::Usage::UseUniform)) { std::string shape = output.name_ + "_shape"; std::string stride = output.name_ + "_stride"; - append_uniform(shape, ProgramUniformVariableDataType::Uint32, output.rank_); - append_uniform(stride, ProgramUniformVariableDataType::Uint32, output.rank_); + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); } } diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 0b7a7d390057..98720c785481 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -68,6 +68,12 @@ constexpr static const std::string_view ELEMENT_TYPE[] = { "bool", // vec4bool }; +inline std::string GetIndicesType(int rank) { + return rank < 2 ? "u32" + : (rank < 4 ? MakeStringWithClassicLocale("vec", rank, "") + : MakeStringWithClassicLocale("array")); +} + } // namespace ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims) @@ -77,9 +83,7 @@ ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType ty rank_{SafeInt(dims.NumDimensions())}, dims_{dims}, usage_(usage), - indices_type_{rank_ < 2 ? "u32" - : (rank_ < 4 ? MakeStringWithClassicLocale("vec", rank_, "") - : MakeStringWithClassicLocale("array"))}, + indices_type_{GetIndicesType(rank_)}, value_type_alias_{name_ + "_value_t"}, element_type_alias_{name_ + "_element_t"}, indices_type_alias_{name_ + "_indices_t"} { @@ -105,7 +109,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Need shape and strides when (not use uniform) and (any other usage is enabled) - if (!(usage_ & UseUniform) && (usage_ & ~UseUniform)) { + if (!(usage_ & UseUniform) && (usage_ & ~UseUniform) && rank_ > 0) { SS("const ", shape, " = ", IndicesType(), "("); bool first = true; @@ -119,16 +123,18 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } ss << ");\n"; - SS("const ", stride, " = ", IndicesType(), "("); - first = true; - for (int i = 1; i <= rank_; i++) { - if (!first) { - ss << ","; + if (rank_ > 1) { + SS("const ", stride, " = ", GetIndicesType(rank_ - 1), "("); + first = true; + for (int i = 1; i < rank_; i++) { + if (!first) { + ss << ","; + } + ss << dims_.SizeFromDimension(i); + first = false; } - ss << dims_.SizeFromDimension(i); - first = false; + ss << ");\n"; } - ss << ");\n"; } // Implementation of "fn o2i_{name}" @@ -138,7 +144,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS(" var indices: ", IndicesType(), ";\n"); SS(" var current = offset;\n"); for (int i = 0; i < rank_ - 1; i++) { - auto current_stride = GetElementAt(stride, i, rank_); + auto current_stride = GetElementAt(stride, i, rank_ - 1); SS(" let dim", i, " = current / ", current_stride, ";\n"); SS(" let rest", i, " = current % ", current_stride, ";\n"); SS(" indices[", i, "] = dim", i, ";\n"); @@ -156,7 +162,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS("fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); SS(" return "); for (int i = 0; i < rank_ - 1; i++) { - SS("indices[", i, "] * ", GetElementAt(stride, i, rank_), " + "); + SS("indices[", i, "] * ", GetElementAt(stride, i, rank_ - 1), " + "); } SS("indices[", rank_ - 1, "];\n"); SS("}\n"); @@ -165,21 +171,23 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn {res_name}_bi2o_{name}" if (usage_ & UseBroadcastedIndicesToOffset) { - // TODO: do we need this if rank < 2? - for (const auto& iter : broadcasted_to_) { - const auto& broadcasted_result = iter.get(); - SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); - if (rank_ == 0) { - SS(" return 0;\n"); - } else { - SS(" return "); - for (int i = rank_ - 1; i >= 0; i--) { - auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); - SS(IndicesGet(stride, i), " * (", idx, " % ", IndicesGet(shape, i), ") + "); + if (rank_ > 0) { + for (const auto& iter : broadcasted_to_) { + const auto& broadcasted_result = iter.get(); + SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); + if (rank_ == 1) { + SS(" return ", broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", shape, ";\n"); + } else { + SS(" return "); + for (int i = 0; i < rank_ - 1; i++) { + auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); + std::string current_stride = rank_ == 2 ? stride : GetElementAt(stride, i, rank_ - 1); + SS(current_stride, " * (", idx, " % ", IndicesGet(shape, i), ") + "); + } + SS(broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); } - SS("0;\n"); + SS("}\n"); } - SS("}\n"); } } @@ -245,10 +253,8 @@ std::string ShaderVariable::GetByOffsetImpl(std::string_view offset) const { ORT_THROW("Invalid type"); break; case onnxruntime::webgpu::ProgramVariableDataType::Int64: - ss << "i32(" << name_ << "[" << offset << "].x)"; - break; case onnxruntime::webgpu::ProgramVariableDataType::Uint64: - ss << "u32(" << name_ << "[" << offset << "].x)"; + ss << ElementType() << "(" << name_ << "[" << offset << "].x)"; break; case onnxruntime::webgpu::ProgramVariableDataType::Vec4Bool: ss << "vec4(bool(" diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 778017a50dda..c6d28975bae3 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -202,13 +202,15 @@ inline std::string ShaderVariable::IndicesToOffset(std::string_view indices_expr inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const { usage_ |= UseBroadcastedIndicesToOffset; broadcasted_to_.push_back(broadcasted_result); - return MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); + return rank_ == 0 + ? "0" + : MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); } template inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const { return rank_ == 0 - ? "" + ? "0" : MakeStringWithClassicLocale(name_, "_indices_t(", absl::StrJoin(std::forward_as_tuple(std::forward(indices_args)...), ", "), ')'); diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 343da693c716..599ee9bbb82f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -249,15 +249,19 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog "Invalid program artifact: variable[", i, "] rank mismatch. Expected: ", (int)expected_rank, ", Actual: ", shape.NumDimensions()); - std::vector dims(shape.NumDimensions()); - std::vector stride(shape.NumDimensions()); - for (size_t j = 0; j < shape.NumDimensions(); ++j) { + std::vector dims(expected_rank); + std::vector stride(expected_rank - 1); + for (size_t j = 0; j < expected_rank; ++j) { dims[j] = SafeInt(shape[j]); - stride[j] = SafeInt(shape.SizeFromDimension(j + 1)); + if (j < expected_rank - 1) { + stride[j] = SafeInt(shape.SizeFromDimension(j + 1)); + } } shape_uniforms.emplace_back(gsl::make_span(dims)); - shape_uniforms.emplace_back(gsl::make_span(stride)); + if (expected_rank > 1) { + shape_uniforms.emplace_back(gsl::make_span(stride)); + } } } @@ -268,14 +272,13 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog for (size_t i = 0; i < uniform_count; i++) { const auto& uniform = i < shape_uniforms.size() ? shape_uniforms[i] : program.UniformVariables()[i - shape_uniforms.size()]; - bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; size_t length = uniform.length; - - // skip zero-length uniform - if (length == 0) { + if (length == 0) { // skip zero-length uniform continue; } + bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; // https://www.w3.org/TR/WGSL/#alignof size_t base_alignment = is_f16 From 4d48d287feb3ae48ffb01e905fce70411d78416e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 4 Sep 2024 17:50:08 -0700 Subject: [PATCH 049/114] fix GetElementAt with uniform --- onnxruntime/core/providers/webgpu/shader_variable.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index c6d28975bae3..b8b44de92911 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -15,7 +15,7 @@ namespace webgpu { template std::string GetElementAt(std::string_view var, const TIdx& idx, int rank, bool is_f16 = false) { // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. - if (var.rfind("uniform.", 0) == 0) { + if (var.rfind("uniforms.", 0) == 0) { if (rank > 4) { if constexpr (std::is_integral_v) { if (is_f16) { From dbe673bebc2e3159360fac62a49b486f5058fe9c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 5 Sep 2024 00:37:28 -0700 Subject: [PATCH 050/114] document update folder --- onnxruntime/core/providers/webgpu/README.md | 80 ++----------------- .../providers/webgpu/docs/Best_Practices.md | 37 +++++++++ .../core/providers/webgpu/docs/Conventions.md | 33 ++++++++ .../How_to_Write_WebGPU_EP_Kernel.md | 0 4 files changed, 75 insertions(+), 75 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/docs/Best_Practices.md create mode 100644 onnxruntime/core/providers/webgpu/docs/Conventions.md rename onnxruntime/core/providers/webgpu/{ => docs}/How_to_Write_WebGPU_EP_Kernel.md (100%) diff --git a/onnxruntime/core/providers/webgpu/README.md b/onnxruntime/core/providers/webgpu/README.md index 999f1fecbda7..fe0d99b1d602 100644 --- a/onnxruntime/core/providers/webgpu/README.md +++ b/onnxruntime/core/providers/webgpu/README.md @@ -4,9 +4,7 @@ This folder is for the WebGPU execution provider(WebGPU EP). Currently, WebGPU E ## Build WebGPU EP -Just append `--use_webgpu --skip_tests` to the `build.bat`/`build.sh` command line. - -NOTE: `--skip_tests` is required for now. All existing tests are for CPU EP anyway so no need to run them. +Just append `--use_webgpu` to the `build.bat`/`build.sh` command line. For linux, a few dependencies need to be installed: ```sh @@ -19,83 +17,15 @@ TODO: add solutions to common problems. ## Development Guide -See [How to write WebGPU EP kernel](./How_to_Write_WebGPU_EP_Kernel.md) for more information. - -## Convention - -### Use "webgpu" other than "wgpu" in this folder - -This is referring to the naming convention of variables, classes and namespace. - -ORT C API is using "wgpu". - -Let's keep it "webgpu" for this folder for now. I have a very good reason to do so: - -- search for "webgpu" in the code base shows the WebGPU EP related code and search for "wgpu" shows the WebGPU API related code. This helps me easier to find the code I want to look at. - -And anyway, it's not hard to change it back to "wgpu" if we want to. (but it's harder to change it from "wgpu" to "webgpu") - -### Use macros defined in shader_macros.h - -Take `SS` as example. It's a macro defined in `shader_macros.h` and it's used to concatenate strings. It's just make the `std::ostream::operator<<` to be used in a function call style. - -I prefer to use the macro because I feel like it's easier to read. Check the following code: - -```cpp -ss << "vec4(" << type << ">(" << value1 << ", " << value2 << ", " << value3 << ", " << value4 << ")"; -``` - -vs. - -```cpp -SS("vec4<", type, ">(", value1, ", ", value2, ", ", value3, ", ", value4, ")"); -``` - -### Use the subfolder for kernel implementation +See [How to write WebGPU EP kernel](./docs/How_to_Write_WebGPU_EP_Kernel.md) for more information. -Operator implementation source code need to be put under a subfolder like "math"/"nn"/"tensor". +## Conventions -See folder structure under onnxruntime/core/providers/cpu/ or onnxruntime/core/providers/cuda/ for examples. +See [Conventions](./docs/Conventions.md) for more information. ## Best Practices -### Always use std::ostringstream to generate shader code if possible - -This helps to the performance of code generation. - -For example: - -```cpp -ss << "var " << name << " = " << value << ";\n"; -``` - -is better than - -```cpp -ss << ("var " + name + " = " + value + ";\n"); -``` - -### Avoid creating template class for kernel using data type as template parameter. - -This basically means that we should define class like this: - -```cpp -class Abs : public WebGpuKernel { - ... -}; -``` - -instead of - -```cpp - -template // T is tensor element type -class Abs : public WebGpuKernel { - ... -}; -``` - -This is because we don't really read and use `Tensor::Data()`. Tensor stores a handle to a WebGPU buffer but not a pointer to the data. Using template for data type only increases the binary size with no real benefit. +See [Best Practices](./docs/Best_Practices.md) for more information. ## TODO items diff --git a/onnxruntime/core/providers/webgpu/docs/Best_Practices.md b/onnxruntime/core/providers/webgpu/docs/Best_Practices.md new file mode 100644 index 000000000000..d519292b226d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/docs/Best_Practices.md @@ -0,0 +1,37 @@ +### Always use std::ostringstream to generate shader code if possible + +This helps to the performance of code generation. + +For example: + +```cpp +ss << "var " << name << " = " << value << ";\n"; +``` + +is better than + +```cpp +ss << ("var " + name + " = " + value + ";\n"); +``` + +### Avoid creating template class for kernel using data type as template parameter. + +This basically means that we should define class like this: + +```cpp +class Abs : public WebGpuKernel { + ... +}; +``` + +instead of + +```cpp + +template // T is tensor element type +class Abs : public WebGpuKernel { + ... +}; +``` + +This is because we don't really read and use `Tensor::Data()`. Tensor stores a handle to a WebGPU buffer but not a pointer to the data. Using template for data type only increases the binary size with no real benefit. diff --git a/onnxruntime/core/providers/webgpu/docs/Conventions.md b/onnxruntime/core/providers/webgpu/docs/Conventions.md new file mode 100644 index 000000000000..1a86e508cdda --- /dev/null +++ b/onnxruntime/core/providers/webgpu/docs/Conventions.md @@ -0,0 +1,33 @@ +### Use "webgpu" other than "wgpu" in this folder + +This is referring to the naming convention of variables, classes and namespace. + +ORT C API is using "wgpu". + +Let's keep it "webgpu" for this folder for now. I have a very good reason to do so: + +- search for "webgpu" in the code base shows the WebGPU EP related code and search for "wgpu" shows the WebGPU API related code. This helps me easier to find the code I want to look at. + +And anyway, it's not hard to change it back to "wgpu" if we want to. (but it's harder to change it from "wgpu" to "webgpu") + +### Use macros defined in shader_macros.h + +Take `SS` as example. It's a macro defined in `shader_macros.h` and it's used to concatenate strings. It's just make the `std::ostream::operator<<` to be used in a function call style. + +I prefer to use the macro because I feel like it's easier to read. Check the following code: + +```cpp +ss << "vec4(" << type << ">(" << value1 << ", " << value2 << ", " << value3 << ", " << value4 << ")"; +``` + +vs. + +```cpp +SS("vec4<", type, ">(", value1, ", ", value2, ", ", value3, ", ", value4, ")"); +``` + +### Use the subfolder for kernel implementation + +Operator implementation source code need to be put under a subfolder like "math"/"nn"/"tensor". + +See folder structure under onnxruntime/core/providers/cpu/ or onnxruntime/core/providers/cuda/ for examples. diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/docs/How_to_Write_WebGPU_EP_Kernel.md similarity index 100% rename from onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md rename to onnxruntime/core/providers/webgpu/docs/How_to_Write_WebGPU_EP_Kernel.md From 38f182e65e7a312d60eacd4bc29c8d5de19141c0 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 5 Sep 2024 01:04:13 -0700 Subject: [PATCH 051/114] fix adapter/device creating: add toggles --- .../external/onnxruntime_external_deps.cmake | 2 + .../core/providers/webgpu/webgpu_context.cc | 61 ++++++++++++++++++- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 6640609aa71d..a8ab4a53b9f3 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -594,6 +594,8 @@ if (onnxruntime_USE_WEBGPU) set(DAWN_FETCH_DEPENDENCIES ON) set(DAWN_ENABLE_INSTALL ON) set(TINT_BUILD_TESTS OFF) + set(DAWN_USE_BUILT_DXC ON) + set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF) onnxruntime_fetchcontent_makeavailable(dawn) endif() diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 599ee9bbb82f..276d74905adb 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -17,12 +17,46 @@ namespace onnxruntime { namespace webgpu { +namespace { + +std::vector GetEnabledAdapterToggles() { + // See the description of all the toggles in toggles.cpp + // "use_dxc" for Shader Model 6+ features (e.g. float16) + // "allow_unsafe_apis" for chromium experimental features + constexpr const char* toggles[] = { + "use_dxc", + "allow_unsafe_apis", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector GetEnabledDeviceToggles() { + // Enable / disable other toggles that may affect the performance. + // Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming" + constexpr const char* toggles[] = { + "skip_validation", + "disable_robustness", + "disable_workgroup_init", + "d3d_disable_ieee_strictness", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector GetDisabledDeviceToggles() { + constexpr const char* toggles[] = { + "lazy_clear_resource_on_first_use", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + std::vector GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) { std::vector required_features; constexpr wgpu::FeatureName features[]{ wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, wgpu::FeatureName::TimestampQuery, - wgpu::FeatureName::ShaderF16}; + wgpu::FeatureName::ShaderF16, + wgpu::FeatureName::Subgroups, + wgpu::FeatureName::SubgroupsF16}; for (auto feature : features) { if (adapter.HasFeature(feature)) { required_features.push_back(feature); @@ -31,7 +65,7 @@ std::vector GetAvailableRequiredFeatures(const wgpu::Adapter& return required_features; } -wgpu::RequiredLimits GetAvailableRequiredLimits(const wgpu::Adapter& adapter) { +wgpu::RequiredLimits GetRequiredLimits(const wgpu::Adapter& adapter) { wgpu::RequiredLimits required_limits{}; wgpu::SupportedLimits adapter_limits; ORT_ENFORCE(adapter.GetLimits(&adapter_limits)); @@ -49,6 +83,8 @@ wgpu::RequiredLimits GetAvailableRequiredLimits(const wgpu::Adapter& adapter) { return required_limits; } +} // namespace + void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info) { std::call_once(init_flag_, [this, &webgpu_ep_info]() { // Initialization.Step.1 - Create wgpu::Instance @@ -63,6 +99,13 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info // Initialization.Step.2 - Create wgpu::Adapter if (adapter_ == nullptr) { wgpu::RequestAdapterOptions req_adapter_options = {}; + wgpu::DawnTogglesDescriptor adapter_toggles_desc = {}; + req_adapter_options.nextInChain = &adapter_toggles_desc; + + auto enabled_adapter_toggles = GetEnabledAdapterToggles(); + adapter_toggles_desc.enabledToggleCount = enabled_adapter_toggles.size(); + adapter_toggles_desc.enabledToggles = enabled_adapter_toggles.data(); + wgpu::RequestAdapterCallbackInfo req_adapter_callback_info = {}; req_adapter_callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; req_adapter_callback_info.callback = [](WGPURequestAdapterStatus status, @@ -79,11 +122,23 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info // Initialization.Step.3 - Create wgpu::Device if (device_ == nullptr) { wgpu::DeviceDescriptor device_desc = {}; + wgpu::DawnTogglesDescriptor device_toggles_desc = {}; + device_desc.nextInChain = &device_toggles_desc; + + auto enabled_device_toggles = GetEnabledDeviceToggles(); + device_toggles_desc.enabledToggleCount = enabled_device_toggles.size(); + device_toggles_desc.enabledToggles = enabled_device_toggles.data(); + + auto disabled_device_toggles = GetDisabledDeviceToggles(); + device_toggles_desc.disabledToggleCount = disabled_device_toggles.size(); + device_toggles_desc.disabledToggles = disabled_device_toggles.data(); + std::vector required_features = GetAvailableRequiredFeatures(adapter_); if (required_features.size() > 0) { device_desc.requiredFeatures = required_features.data(); + device_desc.requiredFeatureCount = required_features.size(); } - wgpu::RequiredLimits required_limits = GetAvailableRequiredLimits(adapter_); + wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter_); device_desc.requiredLimits = &required_limits; // TODO: revise temporary error handling From eb80f7c4e28c08d735e1ad8efdb91335dfb5cd1c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 5 Sep 2024 17:16:51 -0700 Subject: [PATCH 052/114] more strict shape&stride usage check --- .../core/providers/webgpu/shader_helper.cc | 51 ++++++++++++++++--- .../core/providers/webgpu/shader_variable.cc | 4 +- .../core/providers/webgpu/shader_variable.h | 40 +++++++++------ 3 files changed, 71 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index cd3507a6439a..bf791a36858b 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -102,6 +102,7 @@ const ShaderVariable& ShaderHelper::AddOutput(const std::string& name, ProgramVa #ifndef NDEBUG // if debug build namespace { +// Validate if the tensor element type matches the program variable data type Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType var_type) { switch (element_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: @@ -148,8 +149,7 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va return Status::OK(); } -using RankOrShape = std::variant>; - +// Validate if the number of components and override shape match the original shape Status ValidateVariableShape(const TensorShape& origin_shape, bool use_override_shape, const TensorShape& override_shape, @@ -166,6 +166,36 @@ Status ValidateVariableShape(const TensorShape& origin_shape, return Status::OK(); } + +// Validate if the dependency and variable usage match +Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, ShaderVariable::Usage usage, bool is_input) { + bool dependency_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; + bool dependency_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + bool dependency_type = (dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type; + + // if dependency is already set for shape, it is no need to set for rank. + ORT_RETURN_IF(dependency_rank && dependency_shape, + "Dependency cannot set for both \"Rank\" and \"Shape\"."); + + // if dependency is set for shape, it's already part of the shader cache. no need to use uniform. + ORT_RETURN_IF(dependency_shape && (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform, + "Dependency is set for \"Shape\", using uniform for shape is not allowed."); + + // for input variable, check is more strict. + // this is because usually output shape is determined by the existing information, which is already part of the shader cache. + if (is_input) { + // if dependency is not set for type, should not use type alias for element and value. + // storage type is always used. so setting not depending on type is at user's own risk. + ORT_RETURN_IF(!dependency_type && (usage & (ShaderVariable::UseElementTypeAlias | ShaderVariable::UseValueTypeAlias)), + "Input dependency is not set for \"Type\", but type alias for element type or value type is used."); + + // if dependency is not set for rank and shape, the shader should not use shape and stride. + ORT_RETURN_IF(!dependency_rank && !dependency_shape && (usage & ShaderVariable::UseShapeAndStride), + "Input dependency is set for neither \"Rank\" nor \"Shape\", but variable shape and stride is used."); + } + + return Status::OK(); +} } // namespace const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, @@ -197,6 +227,7 @@ Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVar input.use_override_shape, input.use_override_shape ? input.override_shape : input.tensor->Shape(), var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(input.dependency, var.usage_, true)); return Status::OK(); } @@ -206,6 +237,8 @@ Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderV output.use_override_shape, output.use_override_shape ? output.override_shape : output.tensor->Shape(), var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(output.dependency, var.usage_, false)); + return Status::OK(); } @@ -280,6 +313,12 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha if (use_f16_) { ORT_RETURN_IF_NOT(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); ss << "enable f16;\n"; + if (device_.HasFeature(wgpu::FeatureName::SubgroupsF16)) { + ss << "enable subgroups_f16;\n"; + } + } + if (device_.HasFeature(wgpu::FeatureName::Subgroups)) { + ss << "enable subgroups;\n"; } // @@ -334,12 +373,12 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha shape_uniform_ranks.reserve(input_vars.size() + output_vars.size()); for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - bool use_uniform = (input.usage_ & ShaderVariable::UseUniform) && input.rank_ > 0; + bool use_uniform = (input.usage_ & ShaderVariable::UseUniform) && (input.usage_ & ShaderVariable::UseShapeAndStride) && input.rank_ > 0; use_any_shape_uniform |= use_uniform; shape_uniform_ranks.push_back(use_uniform ? input.rank_ : 0); } for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - bool use_uniform = (output.usage_ & ShaderVariable::UseUniform) && output.rank_ > 0; + bool use_uniform = (output.usage_ & ShaderVariable::UseUniform) && (output.usage_ & ShaderVariable::UseShapeAndStride) && output.rank_ > 0; use_any_shape_uniform |= use_uniform; shape_uniform_ranks.push_back(use_uniform ? output.rank_ : 0); } @@ -381,7 +420,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { const size_t rank = input.rank_; - if (rank > 0 && (input.usage_ & ShaderVariable::Usage::UseUniform)) { + if (rank > 0 && (input.usage_ & ShaderVariable::Usage::UseUniform) && (input.usage_ & ShaderVariable::Usage::UseShapeAndStride)) { std::string shape = input.name_ + "_shape"; std::string stride = input.name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); @@ -391,7 +430,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { const size_t rank = output.rank_; - if (rank > 0 && (output.usage_ & ShaderVariable::Usage::UseUniform)) { + if (rank > 0 && (output.usage_ & ShaderVariable::Usage::UseUniform) && (output.usage_ & ShaderVariable::Usage::UseShapeAndStride)) { std::string shape = output.name_ + "_shape"; std::string stride = output.name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 98720c785481..f5fc236aca71 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -108,8 +108,8 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS("alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); } - // Need shape and strides when (not use uniform) and (any other usage is enabled) - if (!(usage_ & UseUniform) && (usage_ & ~UseUniform) && rank_ > 0) { + // Need shape and strides when (not use uniform) and (use shape and stride is enabled) + if (!(usage_ & UseUniform) && (usage_ & UseShapeAndStride) && rank_ > 0) { SS("const ", shape, " = ", IndicesType(), "("); bool first = true; diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index b8b44de92911..aa186d58740e 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -39,18 +39,19 @@ std::string GetElementAt(std::string_view var, const TIdx& idx, int rank, bool i class ShaderVariable { public: enum Usage : uint32_t { - None = 0, // no usage. this means no additional implementation code will be generated. - UseIndicesTypeAlias = 1, // use type alias "{name}_indices_t" for indices (eg. u32, vec2, vec3, vec4, ...) - UseValueTypeAlias = 2, // use type alias "{name}_value_t" for value (eg. f32, vecT, vec4, ...) - UseElementTypeAlias = 4, // use type alias "{name}_element_t" for element (eg. f32, bool, ...) - UseOffsetToIndices = 8, // use implementation of fn o2i_{name} - UseIndicesToOffset = 16, // use implementation of fn i2o_{name} - UseBroadcastedIndicesToOffset = 32, // use implementation of fn {broadcasted_result_name}_bi2o_{name} - UseSet = 64, // use implementation of fn set_{name} - UseSetByIndices = 128, // use implementation of fn set_{name}_by_indices - UseGet = 256, // use implementation of fn get_{name} - UseGetByIndices = 512, // use implementation of fn get_{name}_by_indices - UseUniform = 1024, // use uniform for shape and stride + None = 0, // no usage. this means no additional implementation code will be generated. + UseIndicesTypeAlias = 1, // use type alias "{name}_indices_t" for indices (eg. u32, vec2, vec3, vec4, ...) + UseValueTypeAlias = 2, // use type alias "{name}_value_t" for value (eg. f32, vecT, vec4, ...) + UseElementTypeAlias = 4, // use type alias "{name}_element_t" for element (eg. f32, bool, ...) + UseShapeAndStride = 16, // use shape and stride for the variable + UseOffsetToIndices = 32, // use implementation of fn o2i_{name} + UseIndicesToOffset = 64, // use implementation of fn i2o_{name} + UseBroadcastedIndicesToOffset = 128, // use implementation of fn {broadcasted_result_name}_bi2o_{name} + UseSet = 256, // use implementation of fn set_{name} + UseSetByIndices = 512, // use implementation of fn set_{name}_by_indices + UseGet = 1024, // use implementation of fn get_{name} + UseGetByIndices = 2048, // use implementation of fn get_{name}_by_indices + UseUniform = 32768, // use uniform for shape and stride }; ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims); @@ -188,19 +189,19 @@ std::string pass_as_string(T&& v) { } // namespace detail inline std::string ShaderVariable::OffsetToIndices(std::string_view offset_expr) const { - usage_ |= UseOffsetToIndices; + usage_ |= UseOffsetToIndices | UseShapeAndStride; return rank_ < 2 ? std::string{offset_expr} : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')'); } inline std::string ShaderVariable::IndicesToOffset(std::string_view indices_expr) const { - usage_ |= UseIndicesToOffset; + usage_ |= UseIndicesToOffset | UseShapeAndStride; return rank_ < 2 ? std::string{indices_expr} : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')'); } inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const { - usage_ |= UseBroadcastedIndicesToOffset; + usage_ |= UseBroadcastedIndicesToOffset | UseShapeAndStride; broadcasted_to_.push_back(broadcasted_result); return rank_ == 0 ? "0" @@ -209,21 +210,24 @@ inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view i template inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const { + usage_ |= UseShapeAndStride; return rank_ == 0 ? "0" - : MakeStringWithClassicLocale(name_, "_indices_t(", + : MakeStringWithClassicLocale(IndicesType(), "(", absl::StrJoin(std::forward_as_tuple(std::forward(indices_args)...), ", "), ')'); } template inline std::string ShaderVariable::IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const { + usage_ |= UseShapeAndStride; return rank_ < 2 ? MakeStringWithClassicLocale(indices_var, '=', value, ';') : MakeStringWithClassicLocale(GetElementAt(indices_var, idx_expr, rank_), '=', value, ';'); } template inline std::string ShaderVariable::IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const { + usage_ |= UseShapeAndStride; return rank_ < 2 ? std::string{indices_var} : GetElementAt(indices_var, idx_expr, rank_); } @@ -235,6 +239,7 @@ inline std::string ShaderVariable::SetByOffset(TOffset&& offset, TValue&& value) template inline std::string ShaderVariable::Set(TIndicesAndValue&&... args) const { + usage_ |= UseShapeAndStride; ORT_ENFORCE(sizeof...(TIndicesAndValue) == rank_ + 1, "Number of arguments should be ", rank_ + 1, "(rank + 1)"); if constexpr (sizeof...(TIndicesAndValue) == 1) { return SetByOffset("0", std::forward(args)...); @@ -249,6 +254,7 @@ inline std::string ShaderVariable::Set(TIndicesAndValue&&... args) const { } inline std::string ShaderVariable::SetByIndices(std::string_view indices_var, std::string_view value) const { + usage_ |= UseShapeAndStride; if (rank_ < 2) { return SetByOffset(indices_var, value); } else { @@ -264,6 +270,7 @@ inline std::string ShaderVariable::GetByOffset(TOffset&& offset) const { template inline std::string ShaderVariable::Get(TIndices&&... indices) const { + usage_ |= UseShapeAndStride; ORT_ENFORCE(sizeof...(TIndices) == rank_, "Number of arguments should be ", rank_, "(rank)"); if constexpr (sizeof...(TIndices) == 0) { return GetByOffset("0"); @@ -278,6 +285,7 @@ inline std::string ShaderVariable::Get(TIndices&&... indices) const { } inline std::string ShaderVariable::GetByIndices(std::string_view indices_var) const { + usage_ |= UseShapeAndStride; if (rank_ < 2) { return GetByOffset(indices_var); } else { From 39d55098f9db0d5c46b3085e9e8fcd5f76abb943 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 5 Sep 2024 23:47:23 -0700 Subject: [PATCH 053/114] fix vector realloc --- .../core/providers/webgpu/shader_helper.cc | 47 ++++++++++--------- .../core/providers/webgpu/shader_helper.h | 2 +- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index bf791a36858b..f43806d1406c 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -218,7 +218,8 @@ const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, ORT_NOT_IMPLEMENTED("Local variables are not supported yet."); } - return vars_[std::underlying_type::type(scope)].emplace_back(name, type, usage, dims); + const auto& var = vars_[std::underlying_type::type(scope)].emplace_back(std::make_unique(name, type, usage, dims)); + return *var; } Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const { @@ -255,18 +256,18 @@ Status ShaderHelper::ValidateShapeForInputsAndOutputs() const { for (size_t i = 0; i < input_vars.size(); i++) { #ifndef NDEBUG // if debug build // Validate input shape - ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], input_vars[i])); + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], *input_vars[i])); #endif // check input dependencies with actual usages. - auto usage = input_vars[i].usage_; + auto usage = input_vars[i]->usage_; bool use_uniform = (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform; auto dependency = program_.Inputs()[i].dependency; bool use_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; if (use_uniform) { - ORT_RETURN_IF_NOT((use_rank || input_vars[i].rank_ < 2) && !use_shape, + ORT_RETURN_IF_NOT((use_rank || input_vars[i]->rank_ < 2) && !use_shape, "When UseUniform is set in variable usage, the corresponding program input should depend on rank but not shape."); } else { ORT_RETURN_IF_NOT(use_shape, @@ -279,11 +280,11 @@ Status ShaderHelper::ValidateShapeForInputsAndOutputs() const { for (size_t i = 0; i < output_vars.size(); i++) { #ifndef NDEBUG // if debug build // Validate output shape - ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], output_vars[i])); + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], *output_vars[i])); #endif // check output dependencies with actual usages. - auto usage = output_vars[i].usage_; + auto usage = output_vars[i]->usage_; bool use_uniform = (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform; auto dependency = program_.Outputs()[i].dependency; bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; @@ -356,11 +357,11 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha size_t variable_count = 0; const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; for (const auto& input : input_vars) { - ss << "@group(0) @binding(" << variable_count++ << ") var " << input.name_ << ": array<" << input.StorageType() << ">;\n"; + ss << "@group(0) @binding(" << variable_count++ << ") var " << input->name_ << ": array<" << input->StorageType() << ">;\n"; } const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; for (const auto& output : output_vars) { - ss << "@group(0) @binding(" << variable_count++ << ") var " << output.name_ << ": array<" << output.StorageType() << ">;\n"; + ss << "@group(0) @binding(" << variable_count++ << ") var " << output->name_ << ": array<" << output->StorageType() << ">;\n"; } // @@ -373,14 +374,18 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha shape_uniform_ranks.reserve(input_vars.size() + output_vars.size()); for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - bool use_uniform = (input.usage_ & ShaderVariable::UseUniform) && (input.usage_ & ShaderVariable::UseShapeAndStride) && input.rank_ > 0; + bool use_uniform = (input->usage_ & ShaderVariable::UseUniform) && + (input->usage_ & ShaderVariable::UseShapeAndStride) && + input->rank_ > 0; use_any_shape_uniform |= use_uniform; - shape_uniform_ranks.push_back(use_uniform ? input.rank_ : 0); + shape_uniform_ranks.push_back(use_uniform ? input->rank_ : 0); } for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - bool use_uniform = (output.usage_ & ShaderVariable::UseUniform) && (output.usage_ & ShaderVariable::UseShapeAndStride) && output.rank_ > 0; + bool use_uniform = (output->usage_ & ShaderVariable::UseUniform) && + (output->usage_ & ShaderVariable::UseShapeAndStride) && + output->rank_ > 0; use_any_shape_uniform |= use_uniform; - shape_uniform_ranks.push_back(use_uniform ? output.rank_ : 0); + shape_uniform_ranks.push_back(use_uniform ? output->rank_ : 0); } if (use_any_shape_uniform || std::any_of(program_.UniformVariables().cbegin(), @@ -419,20 +424,20 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha }; for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - const size_t rank = input.rank_; - if (rank > 0 && (input.usage_ & ShaderVariable::Usage::UseUniform) && (input.usage_ & ShaderVariable::Usage::UseShapeAndStride)) { - std::string shape = input.name_ + "_shape"; - std::string stride = input.name_ + "_stride"; + const size_t rank = input->rank_; + if (rank > 0 && (input->usage_ & ShaderVariable::Usage::UseUniform) && (input->usage_ & ShaderVariable::Usage::UseShapeAndStride)) { + std::string shape = input->name_ + "_shape"; + std::string stride = input->name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); } } for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - const size_t rank = output.rank_; - if (rank > 0 && (output.usage_ & ShaderVariable::Usage::UseUniform) && (output.usage_ & ShaderVariable::Usage::UseShapeAndStride)) { - std::string shape = output.name_ + "_shape"; - std::string stride = output.name_ + "_stride"; + const size_t rank = output->rank_; + if (rank > 0 && (output->usage_ & ShaderVariable::Usage::UseUniform) && (output->usage_ & ShaderVariable::Usage::UseShapeAndStride)) { + std::string shape = output->name_ + "_shape"; + std::string stride = output->name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); } @@ -455,7 +460,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha ss << "\n"; for (const auto& var_group : vars_) { for (const auto& var : var_group) { - var.Impl(ss); + var->Impl(ss); } } ss << "\n"; diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index ca1bf9ce7ff5..23c1ff42b0df 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -174,7 +174,7 @@ class ShaderHelper final { const ProgramBase& program_; const ProgramMetadata& program_metadata_; - std::array, static_cast(ProgramVariableScope::Count)> vars_; + std::array>, static_cast(ProgramVariableScope::Count)> vars_; std::ostringstream additional_implementation_; std::ostringstream body_; From cd961c3a75d12014f472a41e6eca81fa0639f63c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 5 Sep 2024 23:59:39 -0700 Subject: [PATCH 054/114] simplify cache hint interface. --- onnxruntime/core/providers/webgpu/program.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index c48bdb1a4ff1..6b339af767f5 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -245,9 +245,9 @@ class ProgramBase { // // set the cache hint for the program - template - ProgramBase& CacheHint(CacheHintArgs&&... args) { - cache_hint_ = absl::StrJoin(std::forward_as_tuple(std::forward(args)...), "|"); + template + ProgramBase& CacheHint(T&& hint) { + cache_hint_ = std::forward(hint); return *this; } From ddc2fbb7e948e21b566f2826b0e2c96e196db8ab Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 00:43:04 -0700 Subject: [PATCH 055/114] revise expand --- .../core/providers/webgpu/tensor/expand.cc | 17 ++++++++--------- .../core/providers/webgpu/tensor/expand.h | 5 ++--- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 53991365d654..9052095dec67 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -18,7 +18,7 @@ Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { ToProgramVariableDataType(Outputs()[0].tensor->GetElementType()), ShaderVariable::UseUniform); - shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), "let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", "let input_offset = ", input.BroadcastedIndicesToOffset("output_indices", output), ";\n", output.SetByOffset("global_idx", input.GetByOffset("input_offset"))); @@ -30,20 +30,19 @@ Status Expand::ComputeInternal(ComputeContext& context) const { const auto* input_tensor = context.Input(0); const auto* input_shape_tensor = context.Input(1); - const auto* p_shape = input_shape_tensor->Data(); - TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()}; - TensorShape output_shape(output_dims); + auto output_dims = input_shape_tensor->DataAsSpan(); + TensorShape output_shape{}; ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape)); auto* output_tensor = context.Output(0, output_shape); - SafeInt vec_size = output_shape.Size(); - ExpandProgram program{"Expand"}; + uint32_t data_size = SafeInt(output_shape.Size()); + ExpandProgram program{}; program .Inputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) .Outputs({{output_tensor, ProgramTensorMetadataDependency::Rank}}) - .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .DispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .UniformVariables({ - {static_cast(vec_size)}, + {data_size}, }); return context.RunProgram(program); } @@ -64,4 +63,4 @@ WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedFloatTypes( WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedFloatTypes()) } // namespace webgpu -}; // namespace onnxruntime +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.h b/onnxruntime/core/providers/webgpu/tensor/expand.h index a5c24f1fa496..046520b47925 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.h +++ b/onnxruntime/core/providers/webgpu/tensor/expand.h @@ -11,12 +11,11 @@ namespace webgpu { class ExpandProgram final : public Program { public: - ExpandProgram(const std::string& kernel_name) : Program{kernel_name} { - } + ExpandProgram() : Program{"Expand"} {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}); }; class Expand final : public WebGpuKernel { From e8be835cae04ed21efa72a8e0d5ab417cbef7992 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 00:50:54 -0700 Subject: [PATCH 056/114] revise unary --- .../core/providers/webgpu/math/unary_elementwise_ops.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 079a19221377..630cfce486ca 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -15,7 +15,7 @@ Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { ShaderVariable::UseUniform); shader.AppendImplementation(additional_impl_); shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - "let a = ", input.GetByOffset("global_idx"), ";\n", + " let a = ", input.GetByOffset("global_idx"), ";\n ", output.SetByOffset("global_idx", expression_)); return Status::OK(); @@ -119,7 +119,7 @@ WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) constexpr char HardSigmoidImpl[] = R"( -fn hard_sigmoid_v(v: x_value_t) -> x_value_t { +fn hard_sigmoid_v(v: vec4) -> vec4 { let alpha = x_element_t(uniforms.f32_attr[0]); let beta_v = vec4(uniforms.f32_attr[1]); return max(vec4(0.0), @@ -129,7 +129,7 @@ fn hard_sigmoid_v(v: x_value_t) -> x_value_t { class HardSigmoid final : public UnaryElementwise { public: HardSigmoid(const OpKernelInfo& info) - : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderVariable::UseElementTypeAlias | ShaderVariable::UseValueTypeAlias} { + : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderVariable::UseElementTypeAlias} { // attr[0] is alpha, attr[1] is beta info.GetAttrOrDefault("alpha", attr, 0.2f); info.GetAttrOrDefault("beta", attr + 1, 0.5f); From bd7d592386932b5dd55793dd4a44328808114269 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 15:28:35 -0700 Subject: [PATCH 057/114] Elu/Relu/LeakyRelu/ThresholdedRelu/Gelu --- .../webgpu/math/unary_elementwise_ops.cc | 87 +++++++++++++++++-- .../webgpu/math/unary_elementwise_ops.h | 2 + .../webgpu/webgpu_execution_provider.cc | 19 ++-- 3 files changed, 94 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 630cfce486ca..baa92fdc5c3d 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -37,6 +37,9 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { .UniformVariables({ {static_cast(vec_size)}, }); + if (!cache_hint.empty()) { + program.CacheHint(cache_hint); + } ORT_RETURN_IF_ERROR(ConfigureProgram(program)); return context.RunProgram(program); } @@ -172,7 +175,13 @@ WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) // built-in function tanh() does not work with large input (f32 88.7 or f16 11.09) // https://github.com/gpuweb/gpuweb/issues/4458 -WEBGPU_ELEMENTWISE_IMPL(Tanh, "sign(a) * (1 - exp(-2 * abs(a))) / (1 + exp(-2 * abs(a)))") +constexpr char TanhImpl[] = R"( +fn tanh_v(a: x_value_t) -> x_value_t { + let expr = exp(-2 * abs(a)); + return sign(a) * (1 - expr) / (1 + expr); +} +)"; +WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderVariable::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) @@ -193,10 +202,78 @@ WEBGPU_ELEMENTWISE_KERNEL(Atanh, 9, WebGpuSupportedFloatTypes()) // todo: clip -// constexpr char EluImpl[] = R"( -//)"; -// -// WEBGPU_ELEMENTWISE_IMPL(Elu, "elu_v(a)", ) +class LinearUnit : public UnaryElementwise { + public: + LinearUnit(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression, + const std::string& additional_impl, + float default_alpha) + : UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderVariable::UseElementTypeAlias} { + info.GetAttrOrDefault("alpha", &alpha_, default_alpha); + } + + Status ConfigureProgram(UnaryElementwiseProgram& program) const override { + program.UniformVariables({alpha_, {}}); + return Status::OK(); + } + + protected: + float alpha_; +}; + +#define WEBGPU_LU_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public LinearUnit { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : LinearUnit{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +constexpr char EluImpl[] = R"( +fn elu(a: x_element_t) -> x_element_t { + let alpha = x_element_t(uniforms.f32_attr); + return select((exp(a) - 1.0) * alpha, a, a >= 0.0); +} + +fn elu_v(v: vec4) -> vec4 { + return vec4(elu(v.x), elu(v.y), elu(v.z), elu(v.w)); +} +)"; + +WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0) +WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes()) + +// TODO: support attribute "approximate" +class Gelu : public UnaryElementwise { + public: + Gelu(const OpKernelInfo& info) + : UnaryElementwise{info, + "Gelu", + info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhBasedImpl : DefaultImpl, + info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhImpl : ErfImpl, + ShaderVariable::UseValueTypeAlias} { + cache_hint = info.GetAttrOrDefault("approximate", "none"); + } + + constexpr static const char DefaultImpl[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))"; + constexpr static const char TanhBasedImpl[] = "0.5 * a * (1 + tanh_v(0.7978845608028654 * (a + 0.044715 * a * a * a)))"; + + protected: + float alpha_; +}; + +WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderVariable::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes()) + +WEBGPU_LU_IMPL(LeakyRelu, "select(x_element_t(uniforms.f32_attr) * a, a, a >= vec4(0))", "", 0.01f) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(LeakyRelu, 16, WebGpuSupportedFloatTypes()) + +WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4(0), a, a > vec4(uniforms.f32_attr))", "", 1.0f) +WEBGPU_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, WebGpuSupportedFloatTypes()) // TODO: add other unary elementwise ops diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index 2d084bf227f7..711b0b0a6044 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -45,6 +45,8 @@ class UnaryElementwise : public WebGpuKernel { additional_usage_{usage} {} protected: + std::string cache_hint; + Status ComputeInternal(ComputeContext& context) const final; virtual Status ConfigureProgram(UnaryElementwiseProgram& program) const { program.UniformVariables({{}, {}}); // empty for both float and int attribute(s) diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 1ee7a51618f7..decc74b59cae 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -134,6 +134,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Relu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LeakyRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 20, Gelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); @@ -186,8 +188,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSumExp); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSumExp); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); - class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Add); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Add); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Add); @@ -442,13 +442,14 @@ std::unique_ptr RegisterKernels() { // KERNEL_CREATE_INFO_VERSIONED(11, 11, Clip), // KERNEL_CREATE_INFO_VERSIONED(12, 12, Clip), // KERNEL_CREATE_INFO(13, Clip), - // KERNEL_CREATE_INFO(6, Elu), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu), - // KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu), - // KERNEL_CREATE_INFO(14, Relu), - // KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu), - // KERNEL_CREATE_INFO(16, LeakyRelu), - // KERNEL_CREATE_INFO(10, ThresholdedRelu), + KERNEL_CREATE_INFO(6, Elu), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu), + KERNEL_CREATE_INFO(14, Relu), + KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu), + KERNEL_CREATE_INFO(16, LeakyRelu), + KERNEL_CREATE_INFO(10, ThresholdedRelu), + KERNEL_CREATE_INFO(20, Gelu), // // binary - math // KERNEL_CREATE_INFO_VERSIONED(7, 12, Add), From 601e50f142478db99923453112c051918eef2a07 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 15:40:03 -0700 Subject: [PATCH 058/114] remove unused field in class Gelu --- .../core/providers/webgpu/math/unary_elementwise_ops.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index baa92fdc5c3d..2c015524e1ac 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -256,9 +256,6 @@ class Gelu : public UnaryElementwise { constexpr static const char DefaultImpl[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))"; constexpr static const char TanhBasedImpl[] = "0.5 * a * (1 + tanh_v(0.7978845608028654 * (a + 0.044715 * a * a * a)))"; - - protected: - float alpha_; }; WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes()) From 8f36da219dab7073f7a054e4b24ade0ea934d39c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 16:53:12 -0700 Subject: [PATCH 059/114] remove out-of-dated comments --- onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 2c015524e1ac..d28769f08071 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -242,7 +242,6 @@ fn elu_v(v: vec4) -> vec4 { WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0) WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes()) -// TODO: support attribute "approximate" class Gelu : public UnaryElementwise { public: Gelu(const OpKernelInfo& info) From 72ebd856efc9a1088105d5f4a190ef07ecf3110b Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 19:01:33 -0700 Subject: [PATCH 060/114] Clip --- .../webgpu/math/unary_elementwise_ops.cc | 94 ++++++++++++++++--- .../webgpu/math/unary_elementwise_ops.h | 10 +- 2 files changed, 84 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index d28769f08071..ceaae426ddde 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/providers/webgpu/math/unary_elementwise_ops.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -40,7 +42,7 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { if (!cache_hint.empty()) { program.CacheHint(cache_hint); } - ORT_RETURN_IF_ERROR(ConfigureProgram(program)); + ORT_RETURN_IF_ERROR(ConfigureProgram(context, program)); return context.RunProgram(program); } @@ -62,6 +64,12 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { KernelDefBuilder().TypeConstraint("T", TYPE), \ OP_TYPE_AND_CLASS_NAME); +#define WEBGPU_ELEMENTWISE_BOOLEAN_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + OP_TYPE_AND_CLASS_NAME); + // // math // @@ -123,8 +131,8 @@ WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) constexpr char HardSigmoidImpl[] = R"( fn hard_sigmoid_v(v: vec4) -> vec4 { - let alpha = x_element_t(uniforms.f32_attr[0]); - let beta_v = vec4(uniforms.f32_attr[1]); + let alpha = x_element_t(uniforms.attr[0]); + let beta_v = vec4(uniforms.attr[1]); return max(vec4(0.0), min(vec4(1.0), alpha * v + beta_v)); } @@ -138,8 +146,8 @@ class HardSigmoid final : public UnaryElementwise { info.GetAttrOrDefault("beta", attr + 1, 0.5f); } - Status ConfigureProgram(UnaryElementwiseProgram& program) const override { - program.UniformVariables({gsl::make_span(attr, 2), {}}); + Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { + program.UniformVariables({gsl::make_span(attr, 2)}); return Status::OK(); } @@ -194,14 +202,72 @@ WEBGPU_ELEMENTWISE_KERNEL(Acosh, 9, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_IMPL(Atanh, "atanh(a)") WEBGPU_ELEMENTWISE_KERNEL(Atanh, 9, WebGpuSupportedFloatTypes()) -// todo: logical ops +WEBGPU_ELEMENTWISE_IMPL(Not, "!a") +WEBGPU_ELEMENTWISE_BOOLEAN_KERNEL(Not, 1) + +// No longer support Clip < opset 11 (where min and max are attributes) +// +// Use template class for "Clip" because the implementation is significantly different between float16 and float32 +template +class Clip final : public UnaryElementwise { + public: + Clip(const OpKernelInfo& info) + : UnaryElementwise{info, + "Clip", + std::is_same_v ? ClipF16Impl : ClipImpl, + "", ShaderVariable::UseElementTypeAlias} {} + + Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override { + const auto* clip_min_tensor = context.Input(1); + const auto* clip_max_tensor = context.Input(2); + const T attr[] = {clip_min_tensor->Data()[0], + clip_max_tensor->Data()[0]}; + if constexpr (std::is_same_v) { + // F16: stores span as a single float + float encoded_value = *reinterpret_cast(attr); + program.UniformVariables({encoded_value}); + } else { + static_assert(sizeof(T) == sizeof(float), "T must be f32, i32 or u32"); + // stores span as-is + program.UniformVariables({gsl::make_span(attr, 2)}); + } + return Status::OK(); + } + + // uniforms.attr is a f32 value. It is encoded as a float for 2 f16 values. + // bitcast>(uniforms.attr)[0] is clip_min, bitcast>(uniforms.attr)[1] is clip_max + constexpr static const char ClipF16Impl[] = "clamp(a, vec4(bitcast>(uniforms.attr)[0]), vec4(bitcast>(uniforms.attr)[1]))"; + + // the size of element of uniforms.attr should be the same as x_element_t. use bitcast to convert between them + // uniforms.attr[0] is clip_min, uniforms.attr[1] is clip_max + constexpr static const char ClipImpl[] = "clamp(a, vec4(bitcast(uniforms.attr[0])), vec4(bitcast(uniforms.attr[1])))"; +}; +#define WEBGPU_CLIP_KERNEL(TYPE) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(Clip, kOnnxDomain, 13, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip); +WEBGPU_CLIP_KERNEL(float) +WEBGPU_CLIP_KERNEL(MLFloat16) // // activation // -// todo: clip - class LinearUnit : public UnaryElementwise { public: LinearUnit(const OpKernelInfo& info, @@ -213,8 +279,8 @@ class LinearUnit : public UnaryElementwise { info.GetAttrOrDefault("alpha", &alpha_, default_alpha); } - Status ConfigureProgram(UnaryElementwiseProgram& program) const override { - program.UniformVariables({alpha_, {}}); + Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { + program.UniformVariables({alpha_}); return Status::OK(); } @@ -230,7 +296,7 @@ class LinearUnit : public UnaryElementwise { constexpr char EluImpl[] = R"( fn elu(a: x_element_t) -> x_element_t { - let alpha = x_element_t(uniforms.f32_attr); + let alpha = x_element_t(uniforms.attr); return select((exp(a) - 1.0) * alpha, a, a >= 0.0); } @@ -264,14 +330,12 @@ WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes()) -WEBGPU_LU_IMPL(LeakyRelu, "select(x_element_t(uniforms.f32_attr) * a, a, a >= vec4(0))", "", 0.01f) +WEBGPU_LU_IMPL(LeakyRelu, "select(x_element_t(uniforms.attr) * a, a, a >= vec4(0))", "", 0.01f) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(LeakyRelu, 16, WebGpuSupportedFloatTypes()) -WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4(0), a, a > vec4(uniforms.f32_attr))", "", 1.0f) +WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4(0), a, a > vec4(uniforms.attr))", "", 1.0f) WEBGPU_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, WebGpuSupportedFloatTypes()) -// TODO: add other unary elementwise ops - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index 711b0b0a6044..d870278f4c09 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -19,9 +19,9 @@ class UnaryElementwiseProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"vec_size", ProgramUniformVariableDataType::Uint32}, // output size - {"f32_attr", ProgramUniformVariableDataType::Float32}, // float type attribute(s) - {"i32_attr", ProgramUniformVariableDataType::Int32}); // int type attribute(s) + {"vec_size", ProgramUniformVariableDataType::Uint32}, // output size + {"attr", ProgramUniformVariableDataType::Float32}); // float type attribute(s) + // TODO: add u32/i32 attribute(s) if needed private: std::string_view expression_; @@ -48,8 +48,8 @@ class UnaryElementwise : public WebGpuKernel { std::string cache_hint; Status ComputeInternal(ComputeContext& context) const final; - virtual Status ConfigureProgram(UnaryElementwiseProgram& program) const { - program.UniformVariables({{}, {}}); // empty for both float and int attribute(s) + virtual Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const { + program.UniformVariables({{}}); // empty for attribute(s) return Status::OK(); } From a3244aeb685c1d048b684672429ae5b17af343de Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 23:24:11 -0700 Subject: [PATCH 061/114] fix rank in shader helper --- onnxruntime/core/providers/webgpu/shader_helper.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index f43806d1406c..7e6130dd4e91 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -160,7 +160,7 @@ Status ValidateVariableShape(const TensorShape& origin_shape, "Tensor original shape ", origin_shape, " cannot reshape to ", override_shape, " with component number ", num_components); } else if (num_components > 1) { // if shape is not overriden, assert origin_shape[-1] % 4 == 0 - ORT_RETURN_IF_NOT(origin_shape.Size() > 0 && origin_shape[origin_shape.Size() - 1] % num_components == 0, + ORT_RETURN_IF_NOT(origin_shape.Size() > 0 && origin_shape[origin_shape.NumDimensions() - 1] % num_components == 0, "Tensor original shape ", origin_shape, " cannot be divided by component number ", num_components); } From 5a2ae8c54347c0a51bad33d1ef55b9c1077a098c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 01:56:28 -0700 Subject: [PATCH 062/114] fix shader variable --- onnxruntime/core/providers/webgpu/shader_variable.cc | 4 ++-- onnxruntime/core/providers/webgpu/shader_variable.h | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index f5fc236aca71..07c5915be466 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -172,8 +172,8 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn {res_name}_bi2o_{name}" if (usage_ & UseBroadcastedIndicesToOffset) { if (rank_ > 0) { - for (const auto& iter : broadcasted_to_) { - const auto& broadcasted_result = iter.get(); + for (const auto& broadcasted_result_ptr : broadcasted_to_) { + const auto& broadcasted_result = *broadcasted_result_ptr; SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); if (rank_ == 1) { SS(" return ", broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", shape, ";\n"); diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index aa186d58740e..d4281dd31d65 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/framework/tensor_shape.h" @@ -60,7 +61,7 @@ class ShaderVariable { ShaderVariable& operator=(ShaderVariable&&) = default; // get the name of the variable. - std::string_view Name() const; + inline std::string_view Name() const { return name_; } // create a WGSL expression ({varname}_indices_t) for getting indices from offset. // \param offset: a WGSL expression (u32) representing the offset. @@ -147,7 +148,7 @@ class ShaderVariable { TensorShape dims_; mutable Usage usage_; - mutable std::vector> broadcasted_to_; + mutable std::set broadcasted_to_; // unlike storage/element/value type, indices type is not a string view to a constant string. so we need to store it. std::string indices_type_; @@ -202,7 +203,7 @@ inline std::string ShaderVariable::IndicesToOffset(std::string_view indices_expr inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const { usage_ |= UseBroadcastedIndicesToOffset | UseShapeAndStride; - broadcasted_to_.push_back(broadcasted_result); + broadcasted_to_.insert(&broadcasted_result); return rank_ == 0 ? "0" : MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); From aa54ff8012d45ad6dc7ade798957cefa971397f8 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 03:47:09 -0700 Subject: [PATCH 063/114] move components number from variable to program --- .../webgpu/math/unary_elementwise_ops.cc | 12 +-- onnxruntime/core/providers/webgpu/program.h | 74 ++++++++++++------- .../core/providers/webgpu/shader_helper.cc | 39 ++++++---- .../core/providers/webgpu/shader_helper.h | 4 - .../core/providers/webgpu/tensor/expand.cc | 8 +- 5 files changed, 77 insertions(+), 60 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index ceaae426ddde..8d8f855ec20a 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -9,12 +9,8 @@ namespace onnxruntime { namespace webgpu { Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("x", - ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4), - ShaderVariable::UseUniform | additional_usage_); - const auto& output = shader.AddOutput("y", - ToProgramVariableDataType(Outputs()[0].tensor->GetElementType(), 4), - ShaderVariable::UseUniform); + const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | additional_usage_); + const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform); shader.AppendImplementation(additional_impl_); shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), " let a = ", input.GetByOffset("global_idx"), ";\n ", @@ -33,8 +29,8 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { SafeInt vec_size = (size + 3) / 4; UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; program - .Inputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}}}) - .Outputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}}}) + .Inputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}}) + .Outputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}}) .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .UniformVariables({ {static_cast(vec_size)}, diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 6b339af767f5..38e7a842aa32 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -163,34 +163,6 @@ inline ProgramTensorMetadataDependency& operator&=(ProgramTensorMetadataDependen return (ProgramTensorMetadataDependency&)((int&)a &= (int&)b); } -struct ProgramInput { - ProgramInput(const Tensor* tensor) - : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {} - ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency) - : tensor{tensor}, dependency{dependency}, use_override_shape{false}, override_shape{} {} - ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape) - : tensor{tensor}, dependency{dependency}, use_override_shape{true}, override_shape{override_shape} {} - - const Tensor* tensor; - ProgramTensorMetadataDependency dependency; - bool use_override_shape; - TensorShape override_shape; -}; - -struct ProgramOutput { - ProgramOutput(Tensor* tensor) - : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {} - ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency) - : tensor{tensor}, dependency{dependency}, use_override_shape{false}, override_shape{} {} - ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape) - : tensor{tensor}, dependency{dependency}, use_override_shape{true}, override_shape{override_shape} {} - - Tensor* tensor; - ProgramTensorMetadataDependency dependency; - bool use_override_shape; - TensorShape override_shape; -}; - constexpr SafeInt WORKGROUP_SIZE = 64; // represents the scope of a variable in a shader program. @@ -232,6 +204,52 @@ int NumberOfComponents(ProgramVariableDataType type); ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1); +struct ProgramInput { + ProgramInput(const Tensor* tensor) + : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {} + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{false}, + override_shape{} {} + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{override_shape} {} + + const Tensor* tensor; + ProgramTensorMetadataDependency dependency; + ProgramVariableDataType var_type; + bool use_override_shape; + TensorShape override_shape; +}; + +struct ProgramOutput { + ProgramOutput(Tensor* tensor) + : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {} + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{false}, + override_shape{} {} + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{override_shape} {} + + Tensor* tensor; + ProgramTensorMetadataDependency dependency; + ProgramVariableDataType var_type; + bool use_override_shape; + TensorShape override_shape; +}; + namespace detail { class ProgramWrapper; } diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 7e6130dd4e91..cd21f4752f30 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -27,9 +27,7 @@ ShaderHelper::ShaderHelper(const ProgramBase& program, dispatch_group_size_y_{dispatch_group_size_y}, dispatch_group_size_z_{dispatch_group_size_z}, program_{program}, - program_metadata_{program_metadata}, - use_f16_{false} { -} + program_metadata_{program_metadata} {} Status ShaderHelper::Init() { // dispatch group size is normalized so no need to validate it here @@ -80,24 +78,24 @@ Status ShaderHelper::Init() { return Status::OK(); } -const ShaderVariable& ShaderHelper::AddInput(const std::string& name, ProgramVariableDataType type, ShaderVariable::Usage usage) { +const ShaderVariable& ShaderHelper::AddInput(const std::string& name, ShaderVariable::Usage usage) { const size_t input_index = vars_[std::underlying_type::type(ProgramVariableScope::Input)].size(); ORT_ENFORCE(input_index < program_.Inputs().size(), "Too many inputs in the program (", program_.Inputs().size(), ")"); const auto& dims = program_.Inputs()[input_index].use_override_shape ? program_.Inputs()[input_index].override_shape : program_.Inputs()[input_index].tensor->Shape(); - return AddVariableImpl(ProgramVariableScope::Input, name, type, usage, dims); + return AddVariableImpl(ProgramVariableScope::Input, name, usage, dims); } -const ShaderVariable& ShaderHelper::AddOutput(const std::string& name, ProgramVariableDataType type, ShaderVariable::Usage usage) { +const ShaderVariable& ShaderHelper::AddOutput(const std::string& name, ShaderVariable::Usage usage) { const size_t output_index = vars_[std::underlying_type::type(ProgramVariableScope::Output)].size(); ORT_ENFORCE(output_index < program_.Outputs().size(), "Too many outputs in the program (", program_.Outputs().size(), ")"); const auto& dims = program_.Outputs()[output_index].use_override_shape ? program_.Outputs()[output_index].override_shape : program_.Outputs()[output_index].tensor->Shape(); - return AddVariableImpl(ProgramVariableScope::Output, name, type, usage, dims); + return AddVariableImpl(ProgramVariableScope::Output, name, usage, dims); } #ifndef NDEBUG // if debug build @@ -200,7 +198,6 @@ Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, Sh const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, const std::string& name, - ProgramVariableDataType type, ShaderVariable::Usage usage, const TensorShape& dims) { if (scope == ProgramVariableScope::Input || scope == ProgramVariableScope::Output) { @@ -210,15 +207,20 @@ const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); } - if (type == ProgramVariableDataType::Float16 || type == ProgramVariableDataType::Vec2Float16 || type == ProgramVariableDataType::Vec4Float16) { - use_f16_ = true; - } + auto& vars = vars_[std::underlying_type::type(scope)]; + ProgramVariableDataType type = ProgramVariableDataType::InvalidType; - if (scope == ProgramVariableScope::Local) { + if (scope == ProgramVariableScope::Input) { + const auto& input = program_.Inputs()[vars.size()]; + type = input.var_type; + } else if (scope == ProgramVariableScope::Output) { + const auto& output = program_.Outputs()[vars.size()]; + type = output.var_type; + } else { ORT_NOT_IMPLEMENTED("Local variables are not supported yet."); } - const auto& var = vars_[std::underlying_type::type(scope)].emplace_back(std::make_unique(name, type, usage, dims)); + const auto& var = vars.emplace_back(std::make_unique(name, type, usage, dims)); return *var; } @@ -311,7 +313,16 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha // // Section feature enabling // - if (use_f16_) { + if (std::any_of(program_.Inputs().begin(), + program_.Inputs().end(), + [](const ProgramInput& input) { + return input.tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + }) || + std::any_of(program_.Outputs().begin(), + program_.Outputs().end(), + [](const ProgramOutput& output) { + return output.tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + })) { ORT_RETURN_IF_NOT(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); ss << "enable f16;\n"; if (device_.HasFeature(wgpu::FeatureName::SubgroupsF16)) { diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index 23c1ff42b0df..08ff111f8a69 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -81,14 +81,12 @@ class ShaderHelper final { // // depending on the usage of the variable, additional code may be generated. const ShaderVariable& AddInput(const std::string& name, - ProgramVariableDataType type, ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); // Add an output variable to the shader. // // depending on the usage of the variable, additional code may be generated. const ShaderVariable& AddOutput(const std::string& name, - ProgramVariableDataType type, ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); // Append additional implementation code to the shader. @@ -140,7 +138,6 @@ class ShaderHelper final { const ShaderVariable& AddVariableImpl(ProgramVariableScope scope, const std::string& name, - ProgramVariableDataType type, ShaderVariable::Usage usage, const TensorShape& dims); @@ -178,7 +175,6 @@ class ShaderHelper final { std::ostringstream additional_implementation_; std::ostringstream body_; - bool use_f16_ = false; bool body_set_ = false; }; diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 9052095dec67..82451c939824 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -11,12 +11,8 @@ namespace onnxruntime { namespace webgpu { Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("input", - ToProgramVariableDataType(Inputs()[0].tensor->GetElementType()), - ShaderVariable::UseUniform); - const auto& output = shader.AddOutput("output", - ToProgramVariableDataType(Outputs()[0].tensor->GetElementType()), - ShaderVariable::UseUniform); + const auto& input = shader.AddInput("input", ShaderVariable::UseUniform); + const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform); shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), "let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", From 969384d23c2ea6043d9450b8007dc59455674f4e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 14:04:19 -0700 Subject: [PATCH 064/114] mark components in cache key --- onnxruntime/core/providers/webgpu/program.cc | 24 +++++++++++++++++++ onnxruntime/core/providers/webgpu/program.h | 3 +++ .../providers/webgpu/program_cache_key.cc | 11 +++++---- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 4a5785dc4def..d4a2b24172d0 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -82,6 +82,30 @@ std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency dep) return os; } +#ifndef NDEBUG +constexpr std::string_view ProgramVariableDataTypeName[] = { + "f32", // f32 + "f32x2", // vec2f32 + "f32x4", // vec4f32 + "f16", // f16 + "f16x2", // vec2f16 + "f16x4", // vec4f16 + "i32", // i32 + "i32x2", // vec2i32 + "i32x4", // vec4i32 + "u32", // u32 + "u32x2", // vec2u32 + "u32x4", // vec4u32 + "i64", // int64 + "u64", // uint64 + "boolx4", // vec4bool +}; +std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) { + os << ProgramVariableDataTypeName[std::underlying_type::type(type)]; + return os; +} +#endif + int NumberOfComponents(ProgramVariableDataType type) { switch (type) { case ProgramVariableDataType::Float32: diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 38e7a842aa32..e162cddbb640 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -199,6 +199,9 @@ enum class ProgramVariableDataType { Uint64, Vec4Bool, }; +#ifndef NDEBUG +std::ostream& operator<<(std::ostream& os, ProgramVariableDataType); +#endif int NumberOfComponents(ProgramVariableDataType type); diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index c6ab16a73423..09a536f7916b 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -9,7 +9,8 @@ namespace onnxruntime { namespace webgpu { namespace { -void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramTensorMetadataDependency dependency, bool& first) { +// append the info of an input or output to the cachekey +void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, bool& first) { if (first) { first = false; } else { @@ -17,9 +18,9 @@ void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramTenso } if ((dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { #ifndef NDEBUG // if debug build - ss << DataTypeImpl::ToString(tensor.DataType()); + ss << var_type; #else - ss << output.tensor->GetElementType(); + ss << static_cast(var_type); #endif ss << ';'; } @@ -87,13 +88,13 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp ss << ":" D("Inputs="); first = true; for (const auto& input : program.Inputs()) { - AppendTensorInfo(ss, *input.tensor, input.dependency, first); + AppendTensorInfo(ss, *input.tensor, input.var_type, input.dependency, first); } ss << ":" D("Outputs="); first = true; for (const auto& output : program.Outputs()) { - AppendTensorInfo(ss, *output.tensor, output.dependency, first); + AppendTensorInfo(ss, *output.tensor, output.var_type, output.dependency, first); } return ss.str(); From 6b824861ad565b5b549cd80aad186cd71ec8835c Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 10 Sep 2024 08:30:54 +0800 Subject: [PATCH 065/114] Add FastGelu op (#21991) ### Description ### Motivation and Context --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../contrib_ops/webgpu/bert/fast_gelu.cc | 84 +++++++++++++++++++ .../contrib_ops/webgpu/bert/fast_gelu.h | 38 +++++++++ .../webgpu/webgpu_contrib_kernels.cc | 42 +++++----- .../webgpu/webgpu_contrib_kernels.h | 5 +- onnxruntime/core/providers/webgpu/program.cc | 19 ++++- onnxruntime/core/providers/webgpu/program.h | 13 ++- .../test/contrib_ops/fastgelu_op_test.cc | 13 ++- 7 files changed, 184 insertions(+), 30 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc new file mode 100644 index 000000000000..42f056206f3f --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "fast_gelu.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + FastGelu, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + FastGelu); + +Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform); + + std::string add_bias = ""; + if (Inputs().size() > 1) { + const auto& bias = shader.AddInput("bias", ShaderVariable::UseUniform | ShaderVariable::UseShapeAndStride); + add_bias = bias_components_ == 1 ? " let bias_offset = global_idx * 4;\n" + " x += input_value_t(" + + bias.GetByOffset("bias_offset % uniforms.bias_shape") + ", " + + bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") + ", " + + bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") + ", " + + bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") + ");\n" + : " x += " + bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; + } + + shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + " var x = ", input.GetByOffset("global_idx"), ";\n", + add_bias, + " let y = x * (0.5 + 0.5 * tanh(x * (0.035677408136300125 * x * x + 0.7978845608028654)));\n ", + output.SetByOffset("global_idx", "y")); + + return Status::OK(); +} + +Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto* bias = context.Input(1); + auto* output = context.Output(0, input->Shape()); + + uint32_t data_size = SafeInt(output->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + const auto vec_size = (data_size + 3) / 4; + uint32_t bias_size = 0; + int bias_components = 1; + + if (bias != nullptr) { + bias_size = SafeInt(bias->Shape().Size()); + if (bias_size % 4 == 0) { + bias_components = 4; + bias_size = bias_size / 4; + } + } + + FastGeluProgram program{bias_components}; + program.Input({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .Output({output, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .UniformVariable({vec_size}); + + if (bias != nullptr) { + program.Input({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}) + .CacheHint(std::to_string(bias_components)); + } + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h new file mode 100644 index 000000000000..fa40d52bf301 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class FastGeluProgram final : public Program { + public: + FastGeluProgram(int bias_components) : Program{"FastGelu"}, bias_components_{bias_components} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + int bias_components_; +}; + +class FastGelu final : public WebGpuKernel { + public: + FastGelu(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 91f51df588fc..def104b6cb10 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -26,11 +26,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Sk class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); -// template <> -// KernelCreateInfo BuildKernelCreateInfo() { -// KernelCreateInfo info; -// return info; -// } +template <> +KernelCreateInfo BuildKernelCreateInfo() { + KernelCreateInfo info; + return info; +} Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -38,22 +38,22 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h index 6cdf7382804f..d73859de7823 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h @@ -3,13 +3,16 @@ #pragma once -#include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" namespace onnxruntime { namespace contrib { namespace webgpu { +// forward declaration for this EP's namespace. +template +KernelCreateInfo BuildKernelCreateInfo(); + Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry); } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index d4a2b24172d0..b05b576b4bc3 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -192,13 +192,23 @@ ProgramBase::ProgramBase(const std::string& name) workgroup_size_z_{0} { } +ProgramBase& ProgramBase::Input(ProgramInput&& input) { + inputs_.emplace_back(input); + return *this; +} + ProgramBase& ProgramBase::Inputs(std::initializer_list inputs) { - inputs_.assign(inputs.begin(), inputs.end()); + inputs_.insert(inputs_.end(), inputs.begin(), inputs.end()); + return *this; +} + +ProgramBase& ProgramBase::Output(ProgramOutput&& output) { + outputs_.emplace_back(output); return *this; } ProgramBase& ProgramBase::Outputs(std::initializer_list outputs) { - outputs_.assign(outputs.begin(), outputs.end()); + outputs_.insert(outputs_.end(), outputs.begin(), outputs.end()); return *this; } @@ -232,6 +242,11 @@ ProgramBase& ProgramBase::WorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { return *this; } +ProgramBase& ProgramBase::UniformVariable(ProgramUniformVariableValue&& variable) { + variables_.emplace_back(variable); + return *this; +} + ProgramBase& ProgramBase::UniformVariables(std::initializer_list variables) { variables_.insert(variables_.end(), variables.begin(), variables.end()); return *this; diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index e162cddbb640..f5f75747dbe5 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -272,8 +272,12 @@ class ProgramBase { return *this; } - // set one or more program inputs + // append a program input + ProgramBase& Input(ProgramInput&& input); + // append multiple program inputs ProgramBase& Inputs(std::initializer_list inputs); + // append a program output + ProgramBase& Output(ProgramOutput&& output); // set one or more program outputs ProgramBase& Outputs(std::initializer_list outputs); @@ -291,7 +295,12 @@ class ProgramBase { // set the size of a workgroup grid. ProgramBase& WorkgroupSize(uint32_t x, uint32_t y, uint32_t z); - // set the uniform variables. + // append a uniform variable. + // + // the specified uniform variable should match the uniform definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. + ProgramBase& UniformVariable(ProgramUniformVariableValue&& variable); + // append multiple uniform variables. // // the specified uniform variables should match the uniform definition in the class, // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc index 5cf749dc4c97..a7d751f4472f 100644 --- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc @@ -41,7 +41,7 @@ const std::vector GetExpectedResult(const std::vector& input_data, return ComputeGelu(add_bias_data); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) static void RunFastGeluGpuTest(const std::vector& input_data, const std::vector& bias_data, const std::vector& output_data, const std::vector& input_dims, const std::vector& bias_dims, const std::vector& output_dims, @@ -75,6 +75,8 @@ static void RunFastGeluGpuTest(const std::vector& input_data, const std:: execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); +#elif USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -142,7 +144,7 @@ static void RunFastGeluTest( std::vector input_dims = {batch_size, sequence_length, hidden_size}; std::vector bias_dims = {hidden_size}; std::vector output_dims = input_dims; -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); #endif RunFastGeluCpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); @@ -245,8 +247,8 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat32) { RunFastGeluTest(input_data, bias_data, batch_size, sequence_length, hidden_size); } -// CUDA and ROCm only for Float16 and BFloat16 type. -#if defined(USE_CUDA) || defined(USE_ROCM) +// CUDA, ROCm and WebGPU only for Float16 type. +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) TEST(FastGeluTest, FastGeluWithBiasFloat16_2) { int batch_size = 1; int sequence_length = 2; @@ -381,7 +383,10 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) { RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true); } +#endif +// CUDA and ROCm only for BFloat16 type. +#if defined(USE_CUDA) || defined(USE_ROCM) TEST(FastGeluTest, FastGeluWithBias_BFloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; From 2b3e7c2d81395c4aca7ad0a4f2ffb583e8622051 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:04:39 -0700 Subject: [PATCH 066/114] use 'set/add' as prefix for some functions --- .../contrib_ops/webgpu/bert/fast_gelu.cc | 20 +++++----- .../webgpu/math/unary_elementwise_ops.cc | 22 +++++------ .../webgpu/math/unary_elementwise_ops.h | 2 +- onnxruntime/core/providers/webgpu/program.cc | 34 ++++++++--------- onnxruntime/core/providers/webgpu/program.h | 38 +++++++++---------- .../core/providers/webgpu/shader_helper.h | 2 +- .../core/providers/webgpu/tensor/expand.cc | 16 ++++---- 7 files changed, 67 insertions(+), 67 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index 42f056206f3f..40c083c76d33 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -35,11 +35,11 @@ Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { : " x += " + bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; } - shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - " var x = ", input.GetByOffset("global_idx"), ";\n", - add_bias, - " let y = x * (0.5 + 0.5 * tanh(x * (0.035677408136300125 * x * x + 0.7978845608028654)));\n ", - output.SetByOffset("global_idx", "y")); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + " var x = ", input.GetByOffset("global_idx"), ";\n", + add_bias, + " let y = x * (0.5 + 0.5 * tanh(x * (0.035677408136300125 * x * x + 0.7978845608028654)));\n ", + output.SetByOffset("global_idx", "y")); return Status::OK(); } @@ -67,13 +67,13 @@ Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) c } FastGeluProgram program{bias_components}; - program.Input({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) - .Output({output, ProgramTensorMetadataDependency::None, {vec_size}, 4}) - .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .UniformVariable({vec_size}); + program.AddInput({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddOutput({output, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariable({vec_size}); if (bias != nullptr) { - program.Input({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}) + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}) .CacheHint(std::to_string(bias_components)); } return context.RunProgram(program); diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 8d8f855ec20a..272ff43a68df 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -12,9 +12,9 @@ Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | additional_usage_); const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform); shader.AppendImplementation(additional_impl_); - shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - " let a = ", input.GetByOffset("global_idx"), ";\n ", - output.SetByOffset("global_idx", expression_)); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + " let a = ", input.GetByOffset("global_idx"), ";\n ", + output.SetByOffset("global_idx", expression_)); return Status::OK(); } @@ -29,10 +29,10 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { SafeInt vec_size = (size + 3) / 4; UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; program - .Inputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}}) - .Outputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}}) - .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .UniformVariables({ + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ {static_cast(vec_size)}, }); if (!cache_hint.empty()) { @@ -143,7 +143,7 @@ class HardSigmoid final : public UnaryElementwise { } Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { - program.UniformVariables({gsl::make_span(attr, 2)}); + program.AddUniformVariables({gsl::make_span(attr, 2)}); return Status::OK(); } @@ -221,11 +221,11 @@ class Clip final : public UnaryElementwise { if constexpr (std::is_same_v) { // F16: stores span as a single float float encoded_value = *reinterpret_cast(attr); - program.UniformVariables({encoded_value}); + program.AddUniformVariable({encoded_value}); } else { static_assert(sizeof(T) == sizeof(float), "T must be f32, i32 or u32"); // stores span as-is - program.UniformVariables({gsl::make_span(attr, 2)}); + program.AddUniformVariable({gsl::make_span(attr, 2)}); } return Status::OK(); } @@ -276,7 +276,7 @@ class LinearUnit : public UnaryElementwise { } Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { - program.UniformVariables({alpha_}); + program.AddUniformVariables({alpha_}); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index d870278f4c09..2691d67e1f9f 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -49,7 +49,7 @@ class UnaryElementwise : public WebGpuKernel { Status ComputeInternal(ComputeContext& context) const final; virtual Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const { - program.UniformVariables({{}}); // empty for attribute(s) + program.AddUniformVariables({{}}); // empty for attribute(s) return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index b05b576b4bc3..023fa78a4196 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -192,67 +192,67 @@ ProgramBase::ProgramBase(const std::string& name) workgroup_size_z_{0} { } -ProgramBase& ProgramBase::Input(ProgramInput&& input) { +ProgramBase& ProgramBase::AddInput(ProgramInput&& input) { inputs_.emplace_back(input); return *this; } -ProgramBase& ProgramBase::Inputs(std::initializer_list inputs) { +ProgramBase& ProgramBase::AddInputs(std::initializer_list inputs) { inputs_.insert(inputs_.end(), inputs.begin(), inputs.end()); return *this; } -ProgramBase& ProgramBase::Output(ProgramOutput&& output) { +ProgramBase& ProgramBase::AddOutput(ProgramOutput&& output) { outputs_.emplace_back(output); return *this; } -ProgramBase& ProgramBase::Outputs(std::initializer_list outputs) { +ProgramBase& ProgramBase::AddOutputs(std::initializer_list outputs) { outputs_.insert(outputs_.end(), outputs.begin(), outputs.end()); return *this; } -ProgramBase& ProgramBase::DispatchGroupSize(uint32_t x) { - return DispatchGroupSize(x, 1, 1); +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x) { + return SetDispatchGroupSize(x, 1, 1); } -ProgramBase& ProgramBase::DispatchGroupSize(uint32_t x, uint32_t y) { - return DispatchGroupSize(x, y, 1); +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x, uint32_t y) { + return SetDispatchGroupSize(x, y, 1); } -ProgramBase& ProgramBase::DispatchGroupSize(uint32_t x, uint32_t y, uint32_t z) { +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t z) { dispatch_group_size_x_ = x; dispatch_group_size_y_ = y; dispatch_group_size_z_ = z; return *this; } -ProgramBase& ProgramBase::WorkgroupSize(uint32_t x) { - return WorkgroupSize(x, 1, 1); +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x) { + return SetWorkgroupSize(x, 1, 1); } -ProgramBase& ProgramBase::WorkgroupSize(uint32_t x, uint32_t y) { - return WorkgroupSize(x, y, 1); +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x, uint32_t y) { + return SetWorkgroupSize(x, y, 1); } -ProgramBase& ProgramBase::WorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { workgroup_size_x_ = x; workgroup_size_y_ = y; workgroup_size_z_ = z; return *this; } -ProgramBase& ProgramBase::UniformVariable(ProgramUniformVariableValue&& variable) { +ProgramBase& ProgramBase::AddUniformVariable(ProgramUniformVariableValue&& variable) { variables_.emplace_back(variable); return *this; } -ProgramBase& ProgramBase::UniformVariables(std::initializer_list variables) { +ProgramBase& ProgramBase::AddUniformVariables(std::initializer_list variables) { variables_.insert(variables_.end(), variables.begin(), variables.end()); return *this; } -ProgramBase& ProgramBase::OverridableConstants(std::initializer_list overridable_constants) { +ProgramBase& ProgramBase::SetOverridableConstants(std::initializer_list overridable_constants) { overridable_constants_.insert(overridable_constants_.end(), overridable_constants.begin(), overridable_constants.end()); return *this; } diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index f5f75747dbe5..ae3d82a6371d 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -272,45 +272,45 @@ class ProgramBase { return *this; } - // append a program input - ProgramBase& Input(ProgramInput&& input); - // append multiple program inputs - ProgramBase& Inputs(std::initializer_list inputs); - // append a program output - ProgramBase& Output(ProgramOutput&& output); - // set one or more program outputs - ProgramBase& Outputs(std::initializer_list outputs); + // add a program input + ProgramBase& AddInput(ProgramInput&& input); + // add multiple program inputs + ProgramBase& AddInputs(std::initializer_list inputs); + // add a program output + ProgramBase& AddOutput(ProgramOutput&& output); + // add multiple program outputs + ProgramBase& AddOutputs(std::initializer_list outputs); // set the size of dispatch groups. Y and Z are 1 if not specified. - ProgramBase& DispatchGroupSize(uint32_t x); + ProgramBase& SetDispatchGroupSize(uint32_t x); // set the size of dispatch groups. Z is 1 if not specified. - ProgramBase& DispatchGroupSize(uint32_t x, uint32_t y); + ProgramBase& SetDispatchGroupSize(uint32_t x, uint32_t y); // set the size of dispatch groups. - ProgramBase& DispatchGroupSize(uint32_t x, uint32_t y, uint32_t z); + ProgramBase& SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t z); // set the size of a workgroup grid. Y and Z are 1 if not specified. - ProgramBase& WorkgroupSize(uint32_t x); + ProgramBase& SetWorkgroupSize(uint32_t x); // set the size of a workgroup grid. Z is 1 if not specified. - ProgramBase& WorkgroupSize(uint32_t x, uint32_t y); + ProgramBase& SetWorkgroupSize(uint32_t x, uint32_t y); // set the size of a workgroup grid. - ProgramBase& WorkgroupSize(uint32_t x, uint32_t y, uint32_t z); + ProgramBase& SetWorkgroupSize(uint32_t x, uint32_t y, uint32_t z); - // append a uniform variable. + // add a uniform variable. // // the specified uniform variable should match the uniform definition in the class, // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. - ProgramBase& UniformVariable(ProgramUniformVariableValue&& variable); - // append multiple uniform variables. + ProgramBase& AddUniformVariable(ProgramUniformVariableValue&& variable); + // add multiple uniform variables. // // the specified uniform variables should match the uniform definition in the class, // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. - ProgramBase& UniformVariables(std::initializer_list variables); + ProgramBase& AddUniformVariables(std::initializer_list variables); // set the overridable constants // // the specified overridable constants should match the overridable constant definition in the class, // specified by macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS. - ProgramBase& OverridableConstants(std::initializer_list overridable_constants); + ProgramBase& SetOverridableConstants(std::initializer_list overridable_constants); // // shader code generation diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index 08ff111f8a69..811ae3cfa15c 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -102,7 +102,7 @@ class ShaderHelper final { // // can be called only once. template - inline void MainFunctionBody(const Strs&... body) { + inline void SetMainFunctionBody(const Strs&... body) { ORT_ENFORCE(!body_set_, "Main function body is already set"); onnxruntime::detail::MakeStringImpl(body_, std::forward>(body)...); body_set_ = true; diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 82451c939824..45084472d353 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -14,10 +14,10 @@ Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderVariable::UseUniform); const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform); - shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), - "let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", - "let input_offset = ", input.BroadcastedIndicesToOffset("output_indices", output), ";\n", - output.SetByOffset("global_idx", input.GetByOffset("input_offset"))); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), + " let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", + " let input_offset = ", input.BroadcastedIndicesToOffset("output_indices", output), ";\n ", + output.SetByOffset("global_idx", input.GetByOffset("input_offset"))); return Status::OK(); } @@ -34,10 +34,10 @@ Status Expand::ComputeInternal(ComputeContext& context) const { uint32_t data_size = SafeInt(output_shape.Size()); ExpandProgram program{}; program - .Inputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .Outputs({{output_tensor, ProgramTensorMetadataDependency::Rank}}) - .DispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .UniformVariables({ + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank}}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ {data_size}, }); return context.RunProgram(program); From ef0d53b78c518e56cbb5d69173bc9f4aa8ace387 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:06:32 -0700 Subject: [PATCH 067/114] remove unnecessary cache hint for FastGelu --- onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index 40c083c76d33..f6631025f0b3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -73,8 +73,7 @@ Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) c .AddUniformVariable({vec_size}); if (bias != nullptr) { - program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}) - .CacheHint(std::to_string(bias_components)); + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}); } return context.RunProgram(program); } From c4ca47f763f6e7be8b021576195c44ea8992dc54 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 21:17:38 -0700 Subject: [PATCH 068/114] revise unary - expose consts in header --- .../webgpu/math/unary_elementwise_ops.cc | 47 +---------------- .../webgpu/math/unary_elementwise_ops.h | 50 +++++++++++++++++++ 2 files changed, 51 insertions(+), 46 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 272ff43a68df..b4b397b2c4b5 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -98,21 +98,6 @@ WEBGPU_ELEMENTWISE_IMPL(Exp, "exp(a)") WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes()) -constexpr char ErfImpl[] = R"( -const r0 = 0.3275911; -const r1 = 0.254829592; -const r2 = -0.284496736; -const r3 = 1.421413741; -const r4 = -1.453152027; -const r5 = 1.061405429; - -fn erf_v(v: x_value_t) -> x_value_t { - let absv = abs(v); - let x = 1.0 / (1.0 + r0 * absv); - return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); -} -)"; - WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderVariable::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes()) @@ -125,14 +110,6 @@ WEBGPU_ELEMENTWISE_IMPL(Sigmoid, "1.0 / (1.0 + exp(-a))") WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) -constexpr char HardSigmoidImpl[] = R"( -fn hard_sigmoid_v(v: vec4) -> vec4 { - let alpha = x_element_t(uniforms.attr[0]); - let beta_v = vec4(uniforms.attr[1]); - return max(vec4(0.0), - min(vec4(1.0), alpha * v + beta_v)); -} -)"; class HardSigmoid final : public UnaryElementwise { public: HardSigmoid(const OpKernelInfo& info) @@ -177,14 +154,6 @@ WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)") WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) -// built-in function tanh() does not work with large input (f32 88.7 or f16 11.09) -// https://github.com/gpuweb/gpuweb/issues/4458 -constexpr char TanhImpl[] = R"( -fn tanh_v(a: x_value_t) -> x_value_t { - let expr = exp(-2 * abs(a)); - return sign(a) * (1 - expr) / (1 + expr); -} -)"; WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderVariable::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) @@ -290,17 +259,6 @@ class LinearUnit : public UnaryElementwise { OP_TYPE(const OpKernelInfo& info) : LinearUnit{info, #OP_TYPE, __VA_ARGS__} {} \ }; -constexpr char EluImpl[] = R"( -fn elu(a: x_element_t) -> x_element_t { - let alpha = x_element_t(uniforms.attr); - return select((exp(a) - 1.0) * alpha, a, a >= 0.0); -} - -fn elu_v(v: vec4) -> vec4 { - return vec4(elu(v.x), elu(v.y), elu(v.z), elu(v.w)); -} -)"; - WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0) WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes()) @@ -309,14 +267,11 @@ class Gelu : public UnaryElementwise { Gelu(const OpKernelInfo& info) : UnaryElementwise{info, "Gelu", - info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhBasedImpl : DefaultImpl, + info.GetAttrOrDefault("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr, info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhImpl : ErfImpl, ShaderVariable::UseValueTypeAlias} { cache_hint = info.GetAttrOrDefault("approximate", "none"); } - - constexpr static const char DefaultImpl[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))"; - constexpr static const char TanhBasedImpl[] = "0.5 * a * (1 + tanh_v(0.7978845608028654 * (a + 0.044715 * a * a * a)))"; }; WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes()) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index 2691d67e1f9f..de85c18da117 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -60,5 +60,55 @@ class UnaryElementwise : public WebGpuKernel { ShaderVariable::Usage additional_usage_; }; +constexpr const char ErfImpl[] = R"( +const r0 = 0.3275911; +const r1 = 0.254829592; +const r2 = -0.284496736; +const r3 = 1.421413741; +const r4 = -1.453152027; +const r5 = 1.061405429; + +fn erf_v(v: x_value_t) -> x_value_t { + let absv = abs(v); + let x = 1.0 / (1.0 + r0 * absv); + return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); +} +)"; + +constexpr const char HardSigmoidImpl[] = R"( +fn hard_sigmoid_v(v: vec4) -> vec4 { + let alpha = x_element_t(uniforms.attr[0]); + let beta_v = vec4(uniforms.attr[1]); + return max(vec4(0.0), + min(vec4(1.0), alpha * v + beta_v)); +} +)"; + +// built-in function tanh() does not work with large input (f32 88.7 or f16 11.09) +// https://github.com/gpuweb/gpuweb/issues/4458 +constexpr const char TanhImpl[] = R"( +fn tanh_v(a: x_value_t) -> x_value_t { + let expr = exp(-2 * abs(a)); + return sign(a) * (1 - expr) / (1 + expr); +} +)"; + +constexpr const char EluImpl[] = R"( +fn elu(a: x_element_t) -> x_element_t { + let alpha = x_element_t(uniforms.attr); + return select((exp(a) - 1.0) * alpha, a, a >= 0.0); +} + +fn elu_v(v: vec4) -> vec4 { + return vec4(elu(v.x), elu(v.y), elu(v.z), elu(v.w)); +} +)"; + +// default GELU expression, depending on ErfImpl +constexpr const char GeluExpr[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))"; + +// fast GELU expression, depending on TanhImpl +constexpr const char FastGeluExpr[] = "a * (0.5 + 0.5 * tanh_v(a * (0.035677408136300125 * a * a + 0.7978845608028654)))"; + } // namespace webgpu } // namespace onnxruntime From 8806d57727be2723ff4d84fd33fb9504503fb7b5 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 21:53:24 -0700 Subject: [PATCH 069/114] use path for header file --- onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index f6631025f0b3..7d8bef1e66f4 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "fast_gelu.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/bert/fast_gelu.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" namespace onnxruntime { From 0568e2b6e59e68f5c47265deb3c2a2739804025c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:21:46 -0700 Subject: [PATCH 070/114] a few revises to the code (#22047) --- .../core/providers/webgpu/buffer_manager.cc | 8 ++-- .../providers/webgpu/program_cache_key.cc | 5 +- .../core/providers/webgpu/program_manager.h | 2 +- .../core/providers/webgpu/shader_helper.cc | 46 +++++++++---------- 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index da544e1d1ed6..8751338d2417 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -243,10 +243,10 @@ std::ostream& operator<<(std::ostream& os, BufferCacheMode mode) { BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) : context_{context}, - storage_cache_{std::move(CreateBufferCacheManager(storage_buffer_cache_mode))}, - uniform_cache_{std::move(CreateBufferCacheManager(uniform_buffer_cache_mode))}, - query_resolve_cache_{std::move(CreateBufferCacheManager(query_resolve_buffer_cache_mode))}, - default_cache_{std::move(CreateBufferCacheManager(BufferCacheMode::Disabled))} { + storage_cache_{CreateBufferCacheManager(storage_buffer_cache_mode)}, + uniform_cache_{CreateBufferCacheManager(uniform_buffer_cache_mode)}, + query_resolve_cache_{CreateBufferCacheManager(query_resolve_buffer_cache_mode)}, + default_cache_{CreateBufferCacheManager(BufferCacheMode::Disabled)} { } void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index 09a536f7916b..6c7ef2bc89c6 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -10,12 +10,14 @@ namespace webgpu { namespace { // append the info of an input or output to the cachekey -void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, bool& first) { +void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, + bool& first) { if (first) { first = false; } else { ss << '|'; } + if ((dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { #ifndef NDEBUG // if debug build ss << var_type; @@ -24,6 +26,7 @@ void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVaria #endif ss << ';'; } + if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { ss D("Dims=") << tensor.Shape().ToString(); } else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h index 782788910e3a..eded1cfa1797 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.h +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -30,7 +30,7 @@ class ProgramArtifact { const std::vector shape_uniform_ranks; ProgramArtifact(ProgramArtifact&&) = default; - ProgramArtifact& operator=(ProgramArtifact&&) = default; + ProgramArtifact& operator=(ProgramArtifact&&) = delete; private: ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProgramArtifact); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index cd21f4752f30..be89efae5fc9 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -196,6 +196,29 @@ Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, Sh } } // namespace +Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(input.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableShape(input.tensor->Shape(), + input.use_override_shape, + input.use_override_shape ? input.override_shape : input.tensor->Shape(), + var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(input.dependency, var.usage_, true)); + + return Status::OK(); +} +Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableShape(output.tensor->Shape(), + output.use_override_shape, + output.use_override_shape ? output.override_shape : output.tensor->Shape(), + var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(output.dependency, var.usage_, false)); + + return Status::OK(); +} + +#endif // NDEBUG + const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, const std::string& name, ShaderVariable::Usage usage, @@ -224,27 +247,6 @@ const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, return *var; } -Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const { - ORT_RETURN_IF_ERROR(ValidateVariableDataType(input.tensor->GetElementType(), var.type_)); - ORT_RETURN_IF_ERROR(ValidateVariableShape(input.tensor->Shape(), - input.use_override_shape, - input.use_override_shape ? input.override_shape : input.tensor->Shape(), - var.num_components_)); - ORT_RETURN_IF_ERROR(ValidateVariableDependency(input.dependency, var.usage_, true)); - - return Status::OK(); -} -Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const { - ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); - ORT_RETURN_IF_ERROR(ValidateVariableShape(output.tensor->Shape(), - output.use_override_shape, - output.use_override_shape ? output.override_shape : output.tensor->Shape(), - var.num_components_)); - ORT_RETURN_IF_ERROR(ValidateVariableDependency(output.dependency, var.usage_, false)); - - return Status::OK(); -} - Status ShaderHelper::ValidateShapeForInputsAndOutputs() const { const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; @@ -304,8 +306,6 @@ Status ShaderHelper::ValidateShapeForInputsAndOutputs() const { return Status::OK(); } -#endif - Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const { std::ostringstream ss; ss.imbue(std::locale::classic()); From b7a9c0e90a164ce2f97d39ac51a5d6b3e6646a72 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 10 Sep 2024 17:42:39 -0700 Subject: [PATCH 071/114] use OrtMutex --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 6 +++--- onnxruntime/core/providers/webgpu/webgpu_context.h | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 276d74905adb..01d7704d2be2 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -419,7 +419,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog } std::unordered_map> WebGpuContextFactory::contexts_; -std::mutex WebGpuContextFactory::mutex_; +OrtMutex WebGpuContextFactory::mutex_; WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) { if (context_id == 0) { @@ -432,7 +432,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance "WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device."); } - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); auto it = contexts_.find(context_id); if (it == contexts_.end()) { @@ -446,7 +446,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance } WebGpuContext& WebGpuContextFactory::GetContext(int context_id) { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); auto it = contexts_.find(context_id); ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found."); diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index d8b0c2b48b06..2086213e248f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -13,6 +13,7 @@ #include #include "core/common/common.h" +#include "core/platform/ort_mutex.h" #include "core/providers/webgpu/webgpu_execution_provider.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/program_manager.h" @@ -34,7 +35,7 @@ class WebGpuContextFactory { WebGpuContextFactory() {} static std::unordered_map> contexts_; - static std::mutex mutex_; + static OrtMutex mutex_; }; // Class WebGpuContext includes all necessary resources for the context. From d4a963d7bf7e9be9b09e41056eda4d6e9a9fe550 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 11 Sep 2024 15:31:11 +0800 Subject: [PATCH 072/114] [webgpu-native] Add transpose op (#21986) Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../core/providers/webgpu/tensor/transpose.cc | 103 ++++++++++++++++++ .../core/providers/webgpu/tensor/transpose.h | 37 +++++++ .../webgpu/webgpu_execution_provider.cc | 6 +- .../providers/webgpu/webgpu_supported_types.h | 6 +- 4 files changed, 146 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/transpose.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/transpose.h diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc new file mode 100644 index 000000000000..68af858d515c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/transpose.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 1, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 13, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_KERNEL_EX( + Transpose, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +const std::string AppendPermFunction(gsl::span perm) { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + ss << "fn perm(i: y_indices_t)->x_indices_t {\n" + " var a: x_indices_t;\n"; + for (auto i = 0; i < perm.size(); ++i) { + ss << " a[" << perm[i] << "] = i[" << i << "];\n"; + } + ss << " return a;\n" + "}\n"; + return ss.str(); +} + +Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | ShaderVariable::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform | ShaderVariable::UseIndicesTypeAlias); + shader.AppendImplementation(AppendPermFunction(this->perm_)); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), + " let indices = ", output.OffsetToIndices("global_idx"), + ";\n" + " let x_indices = perm(indices); \n" + " ", + output.SetByOffset("global_idx", input.GetByIndices("x_indices"))); + return Status::OK(); +} + +Status Transpose::ComputeInternal(ComputeContext& context) const { + // TODO: there is an optimized version of transpose to port. + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int32_t rank = gsl::narrow_cast(input_shape.NumDimensions()); + + TensorShapeVector output_dims(rank); + InlinedVector default_perm(rank); + const InlinedVector* p_perm = nullptr; + ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm)); + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + + uint32_t output_size = gsl::narrow_cast(input_tensor->Shape().Size()); + TransposeProgram program{*p_perm}; + program + .CacheHint(absl::StrJoin(*p_perm, "-")) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({output_tensor}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(output_size)}, + }); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h new file mode 100644 index 000000000000..3ca5674d5dfa --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/transpose.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class TransposeProgram final : public Program { + public: + TransposeProgram(const gsl::span& permutations) + : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}); + + private: + InlinedVector perm_; +}; + +class Transpose final : public WebGpuKernel, public TransposeBase { + public: + Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { + } + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index decc74b59cae..ae5b429fb230 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -239,7 +239,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, Where); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Transpose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace); @@ -552,8 +552,8 @@ std::unique_ptr RegisterKernels() { // KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), // KERNEL_CREATE_INFO(16, Where), - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h index fccaef2c5357..ff66cd535399 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h +++ b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h @@ -8,7 +8,7 @@ namespace onnxruntime { namespace webgpu { -using SupportedTypes = +using SupportedNumberTypes = TypeList< float, MLFloat16, @@ -20,8 +20,8 @@ using SupportedFloats = float, MLFloat16>; -inline const std::vector& WebGpuSupportedDataTypes() { - static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); +inline const std::vector& WebGpuSupportedNumberTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); return supportedDataTypes; } From 8b61532e73f7d04b65c00ebfc719d03e085a9a06 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 11 Sep 2024 16:34:09 -0700 Subject: [PATCH 073/114] PushErrorScope and PopErrorScope --- .../core/providers/webgpu/compute_context.cc | 20 +++++++++++++++++++ .../core/providers/webgpu/compute_context.h | 14 +++++++++++++ .../core/providers/webgpu/webgpu_context.cc | 5 +++++ .../core/providers/webgpu/webgpu_kernel.h | 18 ++++++++--------- 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index b7a1af5b26ef..62289b7cd6aa 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -13,5 +13,25 @@ ComputeContext::ComputeContext(OpKernelContext& kernel_context) kernel_context_{kernel_context} { } +void ComputeContext::PushErrorScope() { + webgpu_context_.Device().PushErrorScope(wgpu::ErrorFilter::Validation); +} + +Status ComputeContext::PopErrorScope() { + Status status{}; + + ORT_RETURN_IF_ERROR(webgpu_context_.Wait( + webgpu_context_.Device().PopErrorScope( + wgpu::CallbackMode::WaitAnyOnly, [](wgpu::PopErrorScopeStatus pop_status, wgpu::ErrorType error_type, char const* message, Status* status) { + ORT_ENFORCE(pop_status == wgpu::PopErrorScopeStatus::Success, "Instance dropped."); + if (error_type == wgpu::ErrorType::NoError) { + return; + } + *status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "WebGPU validation failed. ", message); + }, + &status))); + return status; +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 132f629ac745..c98480523ae6 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -106,6 +106,20 @@ class ComputeContext { return webgpu_context_.Run(*this, program); } + // + // Push error scope. + // + // This is useful only when "skip_validation" is not set. + // + void PushErrorScope(); + + // + // Pop error scope. + // + // This is useful only when "skip_validation" is not set. + // + Status PopErrorScope(); + protected: WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 01d7704d2be2..ec8f0cda10ee 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -34,7 +34,12 @@ std::vector GetEnabledDeviceToggles() { // Enable / disable other toggles that may affect the performance. // Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming" constexpr const char* toggles[] = { +#ifdef NDEBUG + // todo: when skip validation, the process may crash. + // need careful decision to enable this toggle. + // revisit this flag before release. "skip_validation", +#endif "disable_robustness", "disable_workgroup_init", "d3d_disable_ieee_strictness", diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index 6486987501d1..72fea52313f9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -22,16 +22,14 @@ class WebGpuKernel : public OpKernel { Status Compute(OpKernelContext* p_op_kernel_context) const override { ComputeContext context{*p_op_kernel_context}; - auto s = ComputeInternal(context); - // use this to precisely locate the node where CUDA failure comes from - // if (cudaSuccess != cudaDeviceSynchronize()) - // __debugbreak(); - // if (s.IsOK()) { - // auto err = cudaGetLastError(); - // if (err != cudaSuccess) { - // return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA error ", cudaGetErrorName(err), ":", cudaGetErrorString(err)); - // } - // } +#ifndef NDEBUG + context.PushErrorScope(); +#endif + Status s = ComputeInternal(context); +#ifndef NDEBUG + ORT_RETURN_IF_ERROR(context.PopErrorScope()); +#endif + return s; } From dce0f181a272668fe459915374ca1cc1424525ff Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 11 Sep 2024 18:05:21 -0700 Subject: [PATCH 074/114] placeholder for setting proc table --- .../webgpu/webgpu_provider_factory.cc | 18 ++++++++++++++++-- .../providers/webgpu/webgpu_provider_options.h | 2 ++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 3848ccfc19f5..b03bddf408b6 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -3,6 +3,8 @@ #include +#include + #include "core/framework/error_code_helper.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/webgpu_execution_provider.h" @@ -33,7 +35,19 @@ struct WebGpuProviderFactory : IExecutionProviderFactory { std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { // - // STEP.1 - prepare WebGpuExecutionProviderInfo + // STEP.1 - set dawn proc table + // + std::string dawn_proc_table_str; + if (config_options.TryGetConfigEntry(kDawnProcTable, dawn_proc_table_str)) { + size_t dawn_proc_table = 0; + ORT_ENFORCE(std::errc{} == + std::from_chars(dawn_proc_table_str.data(), dawn_proc_table_str.data() + dawn_proc_table_str.size(), dawn_proc_table).ec); + // TODO: do this for static link build + // dawnProcSetProcs(reinterpret_cast(dawn_proc_table)); + } + + // + // STEP.2 - prepare WebGpuExecutionProviderInfo // WebGpuExecutionProviderInfo webgpu_ep_info{ // preferred layout is NHWC by default @@ -100,7 +114,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << webgpu_ep_info.default_buffer_cache_mode; // - // STEP.2 - prepare WebGpuContext + // STEP.3 - prepare WebGpuContext // int context_id = 0; std::string context_id_str; diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h index 65ccbd800b12..334f21c737af 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -17,6 +17,8 @@ constexpr const char* kWebGpuInstance = "webgpuInstance"; constexpr const char* kWebGpuAdapter = "webgpuAdapter"; constexpr const char* kWebGpuDevice = "webgpuDevice"; +constexpr const char* kDawnProcTable = "dawnProcTable"; + constexpr const char* kStorageBufferCacheMode = "storageBufferCacheMode"; constexpr const char* kUniformBufferCacheMode = "uniformBufferCacheMode"; constexpr const char* kQueryResolveBufferCacheMode = "queryResolveBufferCacheMode"; From 8978d8954bf343f2d6ed5426b2406812e69fc82e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 11 Sep 2024 20:21:55 -0700 Subject: [PATCH 075/114] Revert "placeholder for setting proc table" This reverts commit dce0f181a272668fe459915374ca1cc1424525ff. --- .../webgpu/webgpu_provider_factory.cc | 18 ++---------------- .../providers/webgpu/webgpu_provider_options.h | 2 -- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index b03bddf408b6..3848ccfc19f5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -3,8 +3,6 @@ #include -#include - #include "core/framework/error_code_helper.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/webgpu_execution_provider.h" @@ -35,19 +33,7 @@ struct WebGpuProviderFactory : IExecutionProviderFactory { std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { // - // STEP.1 - set dawn proc table - // - std::string dawn_proc_table_str; - if (config_options.TryGetConfigEntry(kDawnProcTable, dawn_proc_table_str)) { - size_t dawn_proc_table = 0; - ORT_ENFORCE(std::errc{} == - std::from_chars(dawn_proc_table_str.data(), dawn_proc_table_str.data() + dawn_proc_table_str.size(), dawn_proc_table).ec); - // TODO: do this for static link build - // dawnProcSetProcs(reinterpret_cast(dawn_proc_table)); - } - - // - // STEP.2 - prepare WebGpuExecutionProviderInfo + // STEP.1 - prepare WebGpuExecutionProviderInfo // WebGpuExecutionProviderInfo webgpu_ep_info{ // preferred layout is NHWC by default @@ -114,7 +100,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << webgpu_ep_info.default_buffer_cache_mode; // - // STEP.3 - prepare WebGpuContext + // STEP.2 - prepare WebGpuContext // int context_id = 0; std::string context_id_str; diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h index 334f21c737af..65ccbd800b12 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -17,8 +17,6 @@ constexpr const char* kWebGpuInstance = "webgpuInstance"; constexpr const char* kWebGpuAdapter = "webgpuAdapter"; constexpr const char* kWebGpuDevice = "webgpuDevice"; -constexpr const char* kDawnProcTable = "dawnProcTable"; - constexpr const char* kStorageBufferCacheMode = "storageBufferCacheMode"; constexpr const char* kUniformBufferCacheMode = "uniformBufferCacheMode"; constexpr const char* kQueryResolveBufferCacheMode = "queryResolveBufferCacheMode"; From 43ccaf45b6a791a8acb8b2e323a1cc6a38d33b13 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 11 Sep 2024 22:34:24 -0700 Subject: [PATCH 076/114] allow setting "ValidationMode" --- onnxruntime/core/providers/webgpu/program.h | 7 + .../core/providers/webgpu/webgpu_context.cc | 242 +++++++++--------- .../core/providers/webgpu/webgpu_context.h | 21 +- .../webgpu/webgpu_execution_provider.cc | 3 +- .../webgpu/webgpu_execution_provider.h | 6 +- .../webgpu/webgpu_provider_factory.cc | 25 +- .../webgpu/webgpu_provider_options.h | 7 + 7 files changed, 186 insertions(+), 125 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index ae3d82a6371d..0daf24766136 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -253,6 +253,13 @@ struct ProgramOutput { TensorShape override_shape; }; +enum class ValidationMode { + Disabled = 0, + WGPUOnly, + Basic, + Full +}; + namespace detail { class ProgramWrapper; } diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index ec8f0cda10ee..11a337cd3e37 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -17,79 +17,6 @@ namespace onnxruntime { namespace webgpu { -namespace { - -std::vector GetEnabledAdapterToggles() { - // See the description of all the toggles in toggles.cpp - // "use_dxc" for Shader Model 6+ features (e.g. float16) - // "allow_unsafe_apis" for chromium experimental features - constexpr const char* toggles[] = { - "use_dxc", - "allow_unsafe_apis", - }; - return std::vector(std::begin(toggles), std::end(toggles)); -} - -std::vector GetEnabledDeviceToggles() { - // Enable / disable other toggles that may affect the performance. - // Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming" - constexpr const char* toggles[] = { -#ifdef NDEBUG - // todo: when skip validation, the process may crash. - // need careful decision to enable this toggle. - // revisit this flag before release. - "skip_validation", -#endif - "disable_robustness", - "disable_workgroup_init", - "d3d_disable_ieee_strictness", - }; - return std::vector(std::begin(toggles), std::end(toggles)); -} - -std::vector GetDisabledDeviceToggles() { - constexpr const char* toggles[] = { - "lazy_clear_resource_on_first_use", - }; - return std::vector(std::begin(toggles), std::end(toggles)); -} - -std::vector GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) { - std::vector required_features; - constexpr wgpu::FeatureName features[]{ - wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, - wgpu::FeatureName::TimestampQuery, - wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::Subgroups, - wgpu::FeatureName::SubgroupsF16}; - for (auto feature : features) { - if (adapter.HasFeature(feature)) { - required_features.push_back(feature); - } - } - return required_features; -} - -wgpu::RequiredLimits GetRequiredLimits(const wgpu::Adapter& adapter) { - wgpu::RequiredLimits required_limits{}; - wgpu::SupportedLimits adapter_limits; - ORT_ENFORCE(adapter.GetLimits(&adapter_limits)); - - required_limits.limits.maxBindGroups = adapter_limits.limits.maxBindGroups; - required_limits.limits.maxComputeWorkgroupStorageSize = adapter_limits.limits.maxComputeWorkgroupStorageSize; - required_limits.limits.maxComputeWorkgroupsPerDimension = adapter_limits.limits.maxComputeWorkgroupsPerDimension; - required_limits.limits.maxStorageBufferBindingSize = adapter_limits.limits.maxStorageBufferBindingSize; - required_limits.limits.maxBufferSize = adapter_limits.limits.maxBufferSize; - required_limits.limits.maxComputeInvocationsPerWorkgroup = adapter_limits.limits.maxComputeInvocationsPerWorkgroup; - required_limits.limits.maxComputeWorkgroupSizeX = adapter_limits.limits.maxComputeWorkgroupSizeX; - required_limits.limits.maxComputeWorkgroupSizeY = adapter_limits.limits.maxComputeWorkgroupSizeY; - required_limits.limits.maxComputeWorkgroupSizeZ = adapter_limits.limits.maxComputeWorkgroupSizeZ; - - return required_limits; -} - -} // namespace - void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info) { std::call_once(init_flag_, [this, &webgpu_ep_info]() { // Initialization.Step.1 - Create wgpu::Instance @@ -194,34 +121,34 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); -#ifndef NDEBUG // if debug build - ORT_ENFORCE(std::all_of(inputs.begin(), inputs.end(), [](const ProgramInput& input) { - const auto* tensor = input.tensor; - return tensor != nullptr && - tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && - tensor->Location().device.Type() == OrtDevice::GPU && - !strcmp(tensor->Location().name, WEBGPU_BUFFER); - }), - "All inputs must be tensors on WebGPU buffers."); - - ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](const ProgramOutput& output) { - const auto* tensor = output.tensor; - return tensor != nullptr && - tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && - tensor->Location().device.Type() == OrtDevice::GPU && - !strcmp(tensor->Location().name, WEBGPU_BUFFER); - }), - "All outputs must be tensors on WebGPU buffers."); -#endif - if (outputs.size() == 0) { return Status::OK(); } + if (ValidationMode() >= ValidationMode::Basic) { + ORT_ENFORCE(std::all_of(inputs.begin(), inputs.end(), [](const ProgramInput& input) { + const auto* tensor = input.tensor; + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + !strcmp(tensor->Location().name, WEBGPU_BUFFER); + }), + "All inputs must be tensors on WebGPU buffers."); + + ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](const ProgramOutput& output) { + const auto* tensor = output.tensor; + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + !strcmp(tensor->Location().name, WEBGPU_BUFFER); + }), + "All outputs must be tensors on WebGPU buffers."); + } + const ProgramMetadata metadata = program.GetMetadata(); // validate program metadata - { + if (ValidationMode() >= ValidationMode::Basic) { const auto& [constants, overridable_constants, uniform_variables] = metadata; // check overridable constants @@ -229,17 +156,20 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog "Size of overridable constants mismatch in program \"", program.Name(), "\", Expected: ", overridable_constants.size(), ", Actual: ", program.OverridableConstants().size()); - size_t num_overridable_constants = program.OverridableConstants().size(); - for (size_t i = 0; i < num_overridable_constants; ++i) { - const auto& override_value = program.OverridableConstants()[i]; - const auto& definition = overridable_constants[i]; - ORT_RETURN_IF(override_value.has_value && override_value.type != definition.type, - "Overridable override_value[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), - "\", Expected: ", definition.type, - ", Actual: ", override_value.type); - ORT_RETURN_IF(!override_value.has_value && !definition.has_default_value, - "Overridable override_value[", i, "] (", definition.name, ") no override_value specified in program \"", program.Name(), - "\""); + + if (ValidationMode() >= ValidationMode::Full) { + size_t num_overridable_constants = program.OverridableConstants().size(); + for (size_t i = 0; i < num_overridable_constants; ++i) { + const auto& override_value = program.OverridableConstants()[i]; + const auto& definition = overridable_constants[i]; + ORT_RETURN_IF(override_value.has_value && override_value.type != definition.type, + "Overridable override_value[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.type, + ", Actual: ", override_value.type); + ORT_RETURN_IF(!override_value.has_value && !definition.has_default_value, + "Overridable override_value[", i, "] (", definition.name, ") no override_value specified in program \"", program.Name(), + "\""); + } } // check uniform variables @@ -247,14 +177,17 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog "Size of uniform_value variables mismatch in program \"", program.Name(), "\", Expected: ", uniform_variables.size(), ", Actual: ", program.UniformVariables().size()); - size_t num_uniform_variables = program.UniformVariables().size(); - for (size_t i = 0; i < num_uniform_variables; ++i) { - const auto& uniform_value = program.UniformVariables()[i]; - const auto& definition = uniform_variables[i]; - ORT_RETURN_IF(uniform_value.length > 0 && uniform_value.data_type != definition.data_type, - "Uniform variable[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), - "\", Expected: ", definition.data_type, - ", Actual: ", uniform_value.data_type); + + if (ValidationMode() >= ValidationMode::Full) { + size_t num_uniform_variables = program.UniformVariables().size(); + for (size_t i = 0; i < num_uniform_variables; ++i) { + const auto& uniform_value = program.UniformVariables()[i]; + const auto& definition = uniform_variables[i]; + ORT_RETURN_IF(uniform_value.length > 0 && uniform_value.data_type != definition.data_type, + "Uniform variable[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.data_type, + ", Actual: ", uniform_value.data_type); + } } } @@ -295,9 +228,11 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog // prepare shape uniforms for shader variables (if any) and user defined uniforms std::vector shape_uniforms; shape_uniforms.reserve(program_artifact->shape_uniform_ranks.size() * 2); - ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size(), - "Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(), - ") does not match current program (input: ", inputs.size(), ", output: ", outputs.size(), ")"); + if (ValidationMode() >= ValidationMode::Basic) { + ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size(), + "Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(), + ") does not match current program (input: ", inputs.size(), ", output: ", outputs.size(), ")"); + } for (size_t i = 0; i < program_artifact->shape_uniform_ranks.size(); ++i) { SafeInt expected_rank = program_artifact->shape_uniform_ranks[i]; if (expected_rank > 0) { @@ -423,10 +358,81 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog return Status::OK(); } +std::vector WebGpuContext::GetEnabledAdapterToggles() const { + // See the description of all the toggles in toggles.cpp + // "use_dxc" for Shader Model 6+ features (e.g. float16) + // "allow_unsafe_apis" for chromium experimental features + constexpr const char* toggles[] = { + "use_dxc", + "allow_unsafe_apis", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector WebGpuContext::GetEnabledDeviceToggles() const { + // Enable / disable other toggles that may affect the performance. + // Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming" + constexpr const char* toggles[] = { + "skip_validation", // only use "skip_validation" when ValidationMode is set to "Disabled" + "disable_robustness", + "disable_workgroup_init", + "d3d_disable_ieee_strictness", + }; + return std::vector(ValidationMode() >= ValidationMode::WGPUOnly + ? std::begin(toggles) + 1 + : std::begin(toggles), + std::end(toggles)); +} + +std::vector WebGpuContext::GetDisabledDeviceToggles() const { + constexpr const char* toggles[] = { + "lazy_clear_resource_on_first_use", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector WebGpuContext::GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const { + std::vector required_features; + constexpr wgpu::FeatureName features[]{ + wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, + wgpu::FeatureName::TimestampQuery, + wgpu::FeatureName::ShaderF16, + wgpu::FeatureName::Subgroups, + wgpu::FeatureName::SubgroupsF16}; + for (auto feature : features) { + if (adapter.HasFeature(feature)) { + required_features.push_back(feature); + } + } + return required_features; +} + +wgpu::RequiredLimits WebGpuContext::GetRequiredLimits(const wgpu::Adapter& adapter) const { + wgpu::RequiredLimits required_limits{}; + wgpu::SupportedLimits adapter_limits; + ORT_ENFORCE(adapter.GetLimits(&adapter_limits)); + + required_limits.limits.maxBindGroups = adapter_limits.limits.maxBindGroups; + required_limits.limits.maxComputeWorkgroupStorageSize = adapter_limits.limits.maxComputeWorkgroupStorageSize; + required_limits.limits.maxComputeWorkgroupsPerDimension = adapter_limits.limits.maxComputeWorkgroupsPerDimension; + required_limits.limits.maxStorageBufferBindingSize = adapter_limits.limits.maxStorageBufferBindingSize; + required_limits.limits.maxBufferSize = adapter_limits.limits.maxBufferSize; + required_limits.limits.maxComputeInvocationsPerWorkgroup = adapter_limits.limits.maxComputeInvocationsPerWorkgroup; + required_limits.limits.maxComputeWorkgroupSizeX = adapter_limits.limits.maxComputeWorkgroupSizeX; + required_limits.limits.maxComputeWorkgroupSizeY = adapter_limits.limits.maxComputeWorkgroupSizeY; + required_limits.limits.maxComputeWorkgroupSizeZ = adapter_limits.limits.maxComputeWorkgroupSizeZ; + + return required_limits; +} + std::unordered_map> WebGpuContextFactory::contexts_; OrtMutex WebGpuContextFactory::mutex_; -WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) { +WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, + WGPUInstance instance, + WGPUAdapter adapter, + WGPUDevice device, + ValidationMode validation_mode) { if (context_id == 0) { // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr, @@ -441,7 +447,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance auto it = contexts_.find(context_id); if (it == contexts_.end()) { - auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device)); + auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device, validation_mode)); it = contexts_.emplace(context_id, std::move(context)).first; } else if (context_id != 0) { ORT_ENFORCE(it->second->instance_.Get() == instance && it->second->adapter_.Get() == adapter && it->second->device_.Get() == device, diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 2086213e248f..3251364e85ce 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -28,7 +28,11 @@ class ProgramBase; class WebGpuContextFactory { public: - static WebGpuContext& CreateContext(int context_id, WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device); + static WebGpuContext& CreateContext(int context_id, + WGPUInstance instance, + WGPUAdapter adapter, + WGPUDevice device, + ValidationMode validation_mode); static WebGpuContext& GetContext(int context_id); private: @@ -95,18 +99,31 @@ class WebGpuContext final { webgpu::BufferManager& BufferManager() const { return *buffer_mgr_; } + inline webgpu::ValidationMode ValidationMode() const { + return validation_mode_; + } + Status Run(const ComputeContext& context, const ProgramBase& program); private: - WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) : instance_{instance}, adapter_{adapter}, device_{device} {} + WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode) + : instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode} {} ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext); + std::vector WebGpuContext::GetEnabledAdapterToggles() const; + std::vector WebGpuContext::GetEnabledDeviceToggles() const; + std::vector WebGpuContext::GetDisabledDeviceToggles() const; + std::vector WebGpuContext::GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const; + wgpu::RequiredLimits WebGpuContext::GetRequiredLimits(const wgpu::Adapter& adapter) const; + std::once_flag init_flag_; wgpu::Instance instance_; wgpu::Adapter adapter_; wgpu::Device device_; + webgpu::ValidationMode validation_mode_; + wgpu::AdapterInfo adapter_info_; wgpu::Limits device_limits_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index ae5b429fb230..d049cbbf6456 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -23,7 +23,8 @@ #include "core/framework/kernel_registry.h" #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" -#include "data_transfer.h" + +#include "core/providers/webgpu/data_transfer.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 5f27fad14afc..db9de9dc4933 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -22,9 +22,9 @@ enum class BufferCacheMode; } // namespace webgpu struct WebGpuExecutionProviderInfo { - WebGpuExecutionProviderInfo(DataLayout data_layout1, bool enable_graph_capture1) - : data_layout{data_layout1}, - enable_graph_capture{enable_graph_capture1}, + WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture) + : data_layout{data_layout}, + enable_graph_capture{enable_graph_capture}, storage_buffer_cache_mode{}, uniform_buffer_cache_mode{}, query_resolve_buffer_cache_mode{}, diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 3848ccfc19f5..4ceaa0623859 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -99,6 +99,28 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( webgpu_ep_info.default_buffer_cache_mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled); LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << webgpu_ep_info.default_buffer_cache_mode; + webgpu::ValidationMode validation_mode = +#ifndef NDEBUG + webgpu::ValidationMode::Full // for debug build, enable full validation by default +#else + webgpu::ValidationMode::WGPUOnly // for release build, only enable WGPU validation. +#endif // !NDEBUG + ; + std::string validation_mode_str; + if (config_options.TryGetConfigEntry(kValidationMode, validation_mode_str)) { + if (validation_mode_str == kValidationMode_Disabled) { + validation_mode = webgpu::ValidationMode::Disabled; + } else if (validation_mode_str == kValidationMode_wgpuOnly) { + validation_mode = webgpu::ValidationMode::WGPUOnly; + } else if (validation_mode_str == kValidationMode_basic) { + validation_mode = webgpu::ValidationMode::Basic; + } else if (validation_mode_str == kValidationMode_full) { + validation_mode = webgpu::ValidationMode::Full; + } else { + ORT_THROW("Invalid validation mode: ", validation_mode_str); + } + } + // // STEP.2 - prepare WebGpuContext // @@ -136,7 +158,8 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( auto& context = webgpu::WebGpuContextFactory::CreateContext(context_id, reinterpret_cast(webgpu_instance), reinterpret_cast(webgpu_adapter), - reinterpret_cast(webgpu_device)); + reinterpret_cast(webgpu_device), + validation_mode); context.Initialize(webgpu_ep_info); return std::make_shared(context_id, context, webgpu_ep_info); diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h index 65ccbd800b12..ebbca55a8c70 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -22,6 +22,8 @@ constexpr const char* kUniformBufferCacheMode = "uniformBufferCacheMode"; constexpr const char* kQueryResolveBufferCacheMode = "queryResolveBufferCacheMode"; constexpr const char* kDefaultBufferCacheMode = "defaultBufferCacheMode"; +constexpr const char* kValidationMode = "validationMode"; + // The following are the possible values for the provider options. constexpr const char* kPreferredLayout_NCHW = "NCHW"; @@ -35,6 +37,11 @@ constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease"; constexpr const char* kBufferCacheMode_Simple = "simple"; constexpr const char* kBufferCacheMode_Bucket = "bucket"; +constexpr const char* kValidationMode_Disabled = "disabled"; +constexpr const char* kValidationMode_wgpuOnly = "wgpuOnly"; +constexpr const char* kValidationMode_basic = "basic"; +constexpr const char* kValidationMode_full = "full"; + } // namespace options } // namespace webgpu } // namespace onnxruntime From eae4c3f22937b2bc40f11a32a7c6ac5094c3cc3e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 12 Sep 2024 21:01:45 -0700 Subject: [PATCH 077/114] make shape/stride correct when component != 1 --- onnxruntime/core/providers/webgpu/program.cc | 51 +++++++++++++++++++ onnxruntime/core/providers/webgpu/program.h | 34 +++---------- .../core/providers/webgpu/shader_helper.cc | 4 -- 3 files changed, 57 insertions(+), 32 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 023fa78a4196..f12f6fb8a01c 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -182,6 +182,57 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp } } +namespace { +TensorShape GetReducedShape(const TensorShape& shape, int component /* > 1 */) { + ORT_ENFORCE(shape.NumDimensions() > 0 && shape.GetDims()[shape.NumDimensions() - 1] % component == 0, + "Cannot reduce shape ", shape.ToString(), " by component=", component); + TensorShape reduced_shape = shape; + reduced_shape[reduced_shape.NumDimensions() - 1] /= component; + return reduced_shape; +} +} // namespace + +ProgramInput::ProgramInput(const Tensor* tensor) : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {} + +ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{component > 1}, + override_shape{} { + if (use_override_shape) { + override_shape = GetReducedShape(tensor->Shape(), component); + } +} + +ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{override_shape} {} + +ProgramOutput::ProgramOutput(Tensor* tensor) + : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {} + +ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{component > 1}, + override_shape{} { + if (use_override_shape) { + override_shape = GetReducedShape(tensor->Shape(), component); + } +} + +ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{override_shape} {} + ProgramBase::ProgramBase(const std::string& name) : name_{name}, dispatch_group_size_x_{0}, diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 0daf24766136..2a2d4160e161 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -208,20 +208,9 @@ int NumberOfComponents(ProgramVariableDataType type); ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1); struct ProgramInput { - ProgramInput(const Tensor* tensor) - : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {} - ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1) - : tensor{tensor}, - dependency{dependency}, - var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, - use_override_shape{false}, - override_shape{} {} - ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) - : tensor{tensor}, - dependency{dependency}, - var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, - use_override_shape{true}, - override_shape{override_shape} {} + ProgramInput(const Tensor* tensor); + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1); + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component); const Tensor* tensor; ProgramTensorMetadataDependency dependency; @@ -231,20 +220,9 @@ struct ProgramInput { }; struct ProgramOutput { - ProgramOutput(Tensor* tensor) - : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {} - ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1) - : tensor{tensor}, - dependency{dependency}, - var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, - use_override_shape{false}, - override_shape{} {} - ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) - : tensor{tensor}, - dependency{dependency}, - var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, - use_override_shape{true}, - override_shape{override_shape} {} + ProgramOutput(Tensor* tensor); + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1); + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component); Tensor* tensor; ProgramTensorMetadataDependency dependency; diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index be89efae5fc9..64ed98c78507 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -156,10 +156,6 @@ Status ValidateVariableShape(const TensorShape& origin_shape, // if override shape specified, assert override_size == ceil( origin_size / 4 ) ORT_RETURN_IF_NOT((origin_shape.Size() + num_components - 1) / num_components == override_shape.Size(), "Tensor original shape ", origin_shape, " cannot reshape to ", override_shape, " with component number ", num_components); - } else if (num_components > 1) { - // if shape is not overriden, assert origin_shape[-1] % 4 == 0 - ORT_RETURN_IF_NOT(origin_shape.Size() > 0 && origin_shape[origin_shape.NumDimensions() - 1] % num_components == 0, - "Tensor original shape ", origin_shape, " cannot be divided by component number ", num_components); } return Status::OK(); From b8c369d9c89b02d6517f6026c49d3063453fe999 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 12 Sep 2024 23:02:44 -0700 Subject: [PATCH 078/114] expose number of components --- onnxruntime/core/providers/webgpu/shader_variable.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index d4281dd31d65..71822a61f7a7 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -63,6 +63,9 @@ class ShaderVariable { // get the name of the variable. inline std::string_view Name() const { return name_; } + // get the number of components of the variable. + inline int NumComponents() const { return num_components_; } + // create a WGSL expression ({varname}_indices_t) for getting indices from offset. // \param offset: a WGSL expression (u32) representing the offset. inline std::string OffsetToIndices(std::string_view offset_expr) const; From c3086d693c7b6c1d6461fb9f97c2bc5d3d5759a0 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 13 Sep 2024 16:53:32 +1000 Subject: [PATCH 079/114] Fix build errors --- onnxruntime/core/providers/webgpu/webgpu_context.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 3251364e85ce..f74dda38fca0 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -110,11 +110,11 @@ class WebGpuContext final { : instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode} {} ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext); - std::vector WebGpuContext::GetEnabledAdapterToggles() const; - std::vector WebGpuContext::GetEnabledDeviceToggles() const; - std::vector WebGpuContext::GetDisabledDeviceToggles() const; - std::vector WebGpuContext::GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const; - wgpu::RequiredLimits WebGpuContext::GetRequiredLimits(const wgpu::Adapter& adapter) const; + std::vector GetEnabledAdapterToggles() const; + std::vector GetEnabledDeviceToggles() const; + std::vector GetDisabledDeviceToggles() const; + std::vector GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const; + wgpu::RequiredLimits GetRequiredLimits(const wgpu::Adapter& adapter) const; std::once_flag init_flag_; From c5cf2abe24c37031fbb883e0fd101822bf81ef71 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Fri, 13 Sep 2024 17:36:30 -0700 Subject: [PATCH 080/114] [WebGPU EP] Support Shape operator (#22095) ### Description Shape operator ### Motivation and Context --- .../core/providers/webgpu/tensor/shape_op.cc | 78 +++++++++++++++++++ .../webgpu/webgpu_execution_provider.cc | 5 +- 2 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/shape_op.cc diff --git a/onnxruntime/core/providers/webgpu/tensor/shape_op.cc b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc new file mode 100644 index 000000000000..b211d48dab1c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/shape_op.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 1, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 13, 14, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 15, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_KERNEL_EX( + Shape, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index d049cbbf6456..444f07e1664b 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -222,7 +222,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Shape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Shape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, 18, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Shape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Shape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 5, 12, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); From 0bc714fd4333c7a8eb0202ec76bbab717125acee Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 16 Sep 2024 23:11:54 -0700 Subject: [PATCH 081/114] [webgpu EP] Binary operators (#22112) based on: - #22058 --------- Co-authored-by: Qin Jiajia --- .../contrib_ops/webgpu/bert/fast_gelu.cc | 6 +- .../webgpu/math/binary_elementwise_ops.cc | 311 ++++++++++++++++++ .../webgpu/math/binary_elementwise_ops.h | 56 ++++ .../webgpu/math/unary_elementwise_ops.cc | 18 +- .../webgpu/math/unary_elementwise_ops.h | 16 +- onnxruntime/core/providers/webgpu/program.cc | 12 +- onnxruntime/core/providers/webgpu/program.h | 18 +- .../core/providers/webgpu/program_manager.cc | 4 +- .../core/providers/webgpu/shader_helper.cc | 194 ++++++----- .../core/providers/webgpu/shader_helper.h | 31 +- .../core/providers/webgpu/shader_variable.cc | 57 ++-- .../core/providers/webgpu/shader_variable.h | 159 +++++---- .../core/providers/webgpu/tensor/expand.cc | 4 +- .../core/providers/webgpu/tensor/transpose.cc | 4 +- .../core/providers/webgpu/webgpu_context.cc | 26 +- .../webgpu/webgpu_execution_provider.cc | 60 ++-- .../cpu/math/element_wise_ops_test.cc | 22 ++ 17 files changed, 736 insertions(+), 262 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc create mode 100644 onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index 7d8bef1e66f4..50debe26ce45 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -20,12 +20,12 @@ ONNX_OPERATOR_KERNEL_EX( FastGelu); Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("input", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); - const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform); + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); std::string add_bias = ""; if (Inputs().size() > 1) { - const auto& bias = shader.AddInput("bias", ShaderVariable::UseUniform | ShaderVariable::UseShapeAndStride); + const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride); add_bias = bias_components_ == 1 ? " let bias_offset = global_idx * 4;\n" " x += input_value_t(" + bias.GetByOffset("bias_offset % uniforms.bias_shape") + ", " + diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc new file mode 100644 index 000000000000..9d9eff2ccdde --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -0,0 +1,311 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/webgpu/math/binary_elementwise_ops.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { +Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + std::string common; + std::string get_a_data = is_lhs_scalar_ ? "let a = input_a_value_t(" + a.GetByOffset("0") + ".x" + ");\n" + : "let a = " + a.GetByOffset("global_idx") + ";\n"; + std::string get_b_data = is_rhs_scalar_ ? "let b = input_b_value_t(" + b.GetByOffset("0") + ".x" + ");\n" + : "let b = " + b.GetByOffset("global_idx") + ";\n"; + // check whether can use element-wise mode. + // If either A or B is scalar, or A and B have the same shape, element-wise mode can be used. + // In element-wise mode, no indices calculation is needed. + if (!is_lhs_scalar_ && !is_rhs_scalar_ && is_broadcast_) { + const auto& c_indices = shader.AddIndices("bcast_indices"); + // check whether can use vectorize mode. + // If either last dimension of A or B is divisible by 4, or the shared dimension is divisible by 4, vectorize mode + // can be enabled. + // In vectorize mode, the source data of A and B will be loaded only once to calculate 4 output values. + // Use indices helpers to calculate the offset of A and B. + if (vectorize_) { + const auto& a_indices = shader.AddIndices("a_indices"); + const auto& b_indices = shader.AddIndices("b_indices"); + common = "let outputIndices = " + c_indices.OffsetToIndices("global_idx * 4") + + ";\n" + "let offset_a = " + + a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "let offset_b = " + + b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) + ";\n"; + get_a_data = a.NumComponents() == 4 ? "let a = " + a.GetByOffset("offset_a / 4") + ";\n" + : "let a = input_b_value_t(" + a.GetByOffset("offset_a") + ");\n"; + get_b_data = b.NumComponents() == 4 ? "let b = " + b.GetByOffset("offset_b / 4") + ";\n" + : "let b = input_a_value_t(" + b.GetByOffset("offset_b") + ");\n"; + } else { + // In broadcast mode, each element of the vec4 value of A and B will be loaded separately to calculate the output value. + common = "var outputIndices = " + c_indices.OffsetToIndices("global_idx * 4") + + ";\n" + "let offset_a0 = " + + a.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "let offset_b0 = " + + b.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "outputIndices = " + + c_indices.OffsetToIndices("global_idx * 4 + 1") + + ";\n" + "let offset_a1 = " + + a.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "let offset_b1 = " + + b.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "outputIndices = " + + c_indices.OffsetToIndices("global_idx * 4 + 2") + + ";\n" + "let offset_a2 = " + + a.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "let offset_b2 = " + + b.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "outputIndices = " + + c_indices.OffsetToIndices("global_idx * 4 + 3") + + ";\n" + "let offset_a3 = " + + a.BroadcastedIndicesToOffset("outputIndices", c_indices) + + ";\n" + "let offset_b3 = " + + b.BroadcastedIndicesToOffset("outputIndices", c_indices) + ";\n"; + get_a_data = "let a = vec4(" + a.GetByOffset("offset_a0") + ", " + + a.GetByOffset("offset_a1") + ", " + + a.GetByOffset("offset_a2") + ", " + + a.GetByOffset("offset_a3") + ");\n"; + get_b_data = "let b = vec4(" + b.GetByOffset("offset_b0") + ", " + + b.GetByOffset("offset_b1") + ", " + + b.GetByOffset("offset_b2") + ", " + + b.GetByOffset("offset_b3") + ");\n"; + } + } + + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + common, get_a_data, get_b_data, + c.SetByOffset("global_idx", expression_)); + return Status::OK(); +} + +Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { + auto lhs_tensor = context.Input(0); + auto rhs_tensor = context.Input(1); + const auto& lhs_shape = lhs_tensor->Shape(); + const auto& rhs_shape = rhs_tensor->Shape(); + + TensorShape output_shape; + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape)); + auto output_tensor = context.Output(0, output_shape); + int64_t size = output_shape.Size(); + if (size == 0) { + return Status::OK(); + } + + bool is_broadcast = lhs_shape != rhs_shape; + bool is_lhs_scalar = lhs_shape.IsScalar(); + bool is_rhs_scalar = rhs_shape.IsScalar(); + + bool vectorize = is_lhs_scalar || is_rhs_scalar || !is_broadcast; + bool a_last_dim_divisible_by_4 = false; + bool b_last_dim_divisible_by_4 = false; + bool shared_dimension_divisible_by_4 = false; + size_t num_shared_dimension = 0; + if (!vectorize) { + // check whether vectorize can be enabled + a_last_dim_divisible_by_4 = lhs_shape.NumDimensions() > 0 && lhs_shape[lhs_shape.NumDimensions() - 1] % 4 == 0; + b_last_dim_divisible_by_4 = rhs_shape.NumDimensions() > 0 && rhs_shape[rhs_shape.NumDimensions() - 1] % 4 == 0; + if (a_last_dim_divisible_by_4 || b_last_dim_divisible_by_4) { + vectorize = true; + } else { + size_t shared_dimension = 1; + for (size_t i = 1; i < output_shape.NumDimensions(); i++) { + size_t dimA = lhs_shape.NumDimensions() >= i ? lhs_shape[lhs_shape.NumDimensions() - i] : 1; + size_t dimB = rhs_shape.NumDimensions() >= i ? rhs_shape[rhs_shape.NumDimensions() - i] : 1; + if (dimA == dimB) { + shared_dimension *= dimA; + num_shared_dimension++; + } else { + break; + } + } + if (shared_dimension % 4 == 0) { + shared_dimension_divisible_by_4 = true; + vectorize = true; + } + } + } + + SafeInt vec_size = (size + 3) / 4; + BinaryElementwiseProgram program{kernel_name_, + expression_, + is_broadcast, + is_lhs_scalar, + is_rhs_scalar, + vectorize}; + program + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(vec_size)}, + }) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}); + + if (is_lhs_scalar || is_rhs_scalar || !is_broadcast) { + // Mode Element-wise + // cache hint: "E{is_a_scalar}{is_b_scalar}" + program + .AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::Type, {is_lhs_scalar ? 1 : vec_size}, 4}, + {rhs_tensor, ProgramTensorMetadataDependency::Type, {is_rhs_scalar ? 1 : vec_size}, 4}}) + .CacheHint("E" + std::to_string(is_lhs_scalar) + std::to_string(is_rhs_scalar)); + } else if (vectorize) { + // reshape the dims to merge the shared dimension if available + bool need_reshape = shared_dimension_divisible_by_4 && num_shared_dimension > 1; + TensorShape reshaped_lhs_shape = need_reshape ? lhs_shape.Slice(0, lhs_shape.NumDimensions() - num_shared_dimension + 1) + : lhs_shape; + TensorShape reshaped_rhs_shape = need_reshape ? rhs_shape.Slice(0, rhs_shape.NumDimensions() - num_shared_dimension + 1) + : rhs_shape; + TensorShape reshaped_output_shape = need_reshape ? output_shape.Slice(0, output_shape.NumDimensions() - num_shared_dimension + 1) + : output_shape; + if (need_reshape) { + reshaped_lhs_shape[reshaped_lhs_shape.NumDimensions() - 1] = lhs_shape.SizeFromDimension(lhs_shape.NumDimensions() - num_shared_dimension); + reshaped_rhs_shape[reshaped_rhs_shape.NumDimensions() - 1] = rhs_shape.SizeFromDimension(rhs_shape.NumDimensions() - num_shared_dimension); + reshaped_output_shape[reshaped_output_shape.NumDimensions() - 1] = output_shape.SizeFromDimension(output_shape.NumDimensions() - num_shared_dimension); + } + + if (shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4) { + program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type, {(lhs_shape.Size() + 3) / 4}, 4}); + } else { + program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type}); + } + if (shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4) { + program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type, {(rhs_shape.Size() + 3) / 4}, 4}); + } else { + program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type}); + } + // Mode Vectorize broadcast + // cache hint: "V{a_rank};{b_rank};{output_rank}" + program + .AddIndices(reshaped_output_shape) + .AddIndices(reshaped_lhs_shape) + .AddIndices(reshaped_rhs_shape) + .CacheHint("V" + absl::StrJoin({reshaped_lhs_shape.NumDimensions(), + reshaped_rhs_shape.NumDimensions(), + reshaped_output_shape.NumDimensions()}, + ";")); + } else { + // Mode Broadcast + // cache hint: "B" + program + .AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {rhs_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddIndices(output_tensor->Shape()) + .CacheHint("B"); + } + + return context.RunProgram(program); +} + +#define WEBGPU_BINARY_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public BinaryElementwise { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : BinaryElementwise{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +#define WEBGPU_BINARY_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_KERNEL_2(OP_TYPE, VERSION, KERNEL_CLASS, TYPE, TYPE1) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", TYPE) \ + .TypeConstraint("T1", TYPE1), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_VERSIONED_KERNEL_2(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE, TYPE1) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", TYPE) \ + .TypeConstraint("T1", TYPE1), \ + KERNEL_CLASS); + +WEBGPU_BINARY_IMPL(Add, "a + b") +WEBGPU_BINARY_VERSIONED_KERNEL(Add, 7, 12, Add, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Add, 13, 13, Add, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Add, 14, Add, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Div, "a / b") +WEBGPU_BINARY_VERSIONED_KERNEL(Div, 7, 12, Div, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Div, 13, 13, Div, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Div, 14, Div, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Mul, "a * b") +WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 7, 12, Mul, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 13, 13, Mul, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Mul, 14, Mul, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Sub, "a - b") +WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 7, 12, Sub, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 13, 13, Sub, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Sub, 14, Sub, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Pow, "output_value_t(pow(output_value_t(a), output_value_t(b)))") +WEBGPU_BINARY_VERSIONED_KERNEL(Pow, 7, 11, Pow, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 12, 12, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 13, 14, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL_2(Pow, 15, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Equal, "vec4(a == b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 7, 10, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 11, 12, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 13, 18, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Equal, 19, Equal, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Greater, "vec4(a > b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Greater, 7, 8, Greater, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Greater, 9, 12, Greater, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Greater, 13, Greater, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Less, "vec4(a < b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Less, 7, 8, Less, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Less, 9, 12, Less, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Less, 13, Less, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(GreaterOrEqual, "vec4(a >= b)") +WEBGPU_BINARY_VERSIONED_KERNEL(GreaterOrEqual, 12, 15, GreaterOrEqual, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(GreaterOrEqual, 16, GreaterOrEqual, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(LessOrEqual, "vec4(a <= b)") +WEBGPU_BINARY_VERSIONED_KERNEL(LessOrEqual, 12, 15, LessOrEqual, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(LessOrEqual, 16, LessOrEqual, WebGpuSupportedNumberTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h new file mode 100644 index 000000000000..84cbcdf3244d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class BinaryElementwiseProgram final : public Program { + public: + BinaryElementwiseProgram(const std::string& kernel_name, + const std::string& expression, + const bool is_broadcast, + const bool is_lhs_scalar, + const bool is_rhs_scalar, + const bool vectorize) : Program{kernel_name}, + expression_{expression}, + is_broadcast_{is_broadcast}, + is_lhs_scalar_{is_lhs_scalar}, + is_rhs_scalar_{is_rhs_scalar}, + vectorize_{vectorize} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + std::string expression_; + bool is_broadcast_; + bool is_lhs_scalar_; + bool is_rhs_scalar_; + bool vectorize_; +}; + +class BinaryElementwise : public WebGpuKernel { + public: + BinaryElementwise(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression) : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression} {} + + protected: + Status ComputeInternal(ComputeContext& context) const final; + + private: + std::string kernel_name_; + std::string expression_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index b4b397b2c4b5..870dd3df24c7 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -9,8 +9,8 @@ namespace onnxruntime { namespace webgpu { Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | additional_usage_); - const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform); + const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | additional_usage_); + const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform); shader.AppendImplementation(additional_impl_); shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), " let a = ", input.GetByOffset("global_idx"), ";\n ", @@ -98,7 +98,7 @@ WEBGPU_ELEMENTWISE_IMPL(Exp, "exp(a)") WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes()) -WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderVariable::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderUsage::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes()) @@ -113,7 +113,7 @@ WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) class HardSigmoid final : public UnaryElementwise { public: HardSigmoid(const OpKernelInfo& info) - : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderVariable::UseElementTypeAlias} { + : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderUsage::UseElementTypeAlias} { // attr[0] is alpha, attr[1] is beta info.GetAttrOrDefault("alpha", attr, 0.2f); info.GetAttrOrDefault("beta", attr + 1, 0.5f); @@ -154,7 +154,7 @@ WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)") WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) -WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderVariable::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderUsage::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) @@ -180,7 +180,7 @@ class Clip final : public UnaryElementwise { : UnaryElementwise{info, "Clip", std::is_same_v ? ClipF16Impl : ClipImpl, - "", ShaderVariable::UseElementTypeAlias} {} + "", ShaderUsage::UseElementTypeAlias} {} Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override { const auto* clip_min_tensor = context.Input(1); @@ -240,7 +240,7 @@ class LinearUnit : public UnaryElementwise { const std::string& expression, const std::string& additional_impl, float default_alpha) - : UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderVariable::UseElementTypeAlias} { + : UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} { info.GetAttrOrDefault("alpha", &alpha_, default_alpha); } @@ -269,14 +269,14 @@ class Gelu : public UnaryElementwise { "Gelu", info.GetAttrOrDefault("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr, info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhImpl : ErfImpl, - ShaderVariable::UseValueTypeAlias} { + ShaderUsage::UseValueTypeAlias} { cache_hint = info.GetAttrOrDefault("approximate", "none"); } }; WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes()) -WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderVariable::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderUsage::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes()) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index de85c18da117..70fa81d21f95 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -12,7 +12,7 @@ namespace webgpu { class UnaryElementwiseProgram final : public Program { public: - UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderVariable::Usage usage) + UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderUsage usage) : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl}, additional_usage_{usage} { } @@ -26,7 +26,7 @@ class UnaryElementwiseProgram final : public Program { private: std::string_view expression_; std::string_view additional_impl_; - ShaderVariable::Usage additional_usage_; + ShaderUsage additional_usage_; }; // TODO: after upgrading to C++20, use consteval to make a compile-time constructor so that it will be safe to switch @@ -38,11 +38,11 @@ class UnaryElementwise : public WebGpuKernel { const std::string& kernel_name, const std::string& expression, const std::string& additional_impl = "", - ShaderVariable::Usage usage = ShaderVariable::None) : WebGpuKernel{info}, - kernel_name_{kernel_name}, - expression_{expression}, - additional_impl_{additional_impl}, - additional_usage_{usage} {} + ShaderUsage usage = ShaderUsage::None) : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression}, + additional_impl_{additional_impl}, + additional_usage_{usage} {} protected: std::string cache_hint; @@ -57,7 +57,7 @@ class UnaryElementwise : public WebGpuKernel { std::string kernel_name_; std::string expression_; std::string additional_impl_; - ShaderVariable::Usage additional_usage_; + ShaderUsage additional_usage_; }; constexpr const char ErfImpl[] = R"( diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index f12f6fb8a01c..21c63f75d26d 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -263,6 +263,16 @@ ProgramBase& ProgramBase::AddOutputs(std::initializer_list output return *this; } +ProgramBase& ProgramBase::AddIndices(const TensorShape& shape) { + indices_.emplace_back(shape); + return *this; +} + +ProgramBase& ProgramBase::AddIndices(TensorShape&& shape) { + indices_.emplace_back(shape); + return *this; +} + ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x) { return SetDispatchGroupSize(x, 1, 1); } @@ -309,4 +319,4 @@ ProgramBase& ProgramBase::SetOverridableConstants(std::initializer_list WORKGROUP_SIZE = 64; -// represents the scope of a variable in a shader program. -// -// this is not a full list of all possible variable scopes in shader programs. -// it only includes what are used in WebGPU EP. -enum class ProgramVariableScope { - Input = 0, // storage buffer variable with access mode "read" - Output = 1, // storage buffer variable with access mode "read_write" - Local = 2, // local variable - - Count // should always be the last element -}; - // data type of variable // // this is not a full list of all possible data types in shader programs. @@ -265,6 +253,10 @@ class ProgramBase { ProgramBase& AddOutput(ProgramOutput&& output); // add multiple program outputs ProgramBase& AddOutputs(std::initializer_list outputs); + // add a program variable for indices + ProgramBase& AddIndices(const TensorShape& shape); + // add a program variable for indices + ProgramBase& AddIndices(TensorShape&& shape); // set the size of dispatch groups. Y and Z are 1 if not specified. ProgramBase& SetDispatchGroupSize(uint32_t x); @@ -330,6 +322,7 @@ class ProgramBase { inline const std::string& CacheHint() const { return cache_hint_; } inline const std::vector& Inputs() const { return inputs_; } inline const std::vector& Outputs() const { return outputs_; } + inline const std::vector& Indices() const { return indices_; } inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; } inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; } inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; } @@ -351,6 +344,7 @@ class ProgramBase { std::string cache_hint_; std::vector inputs_; std::vector outputs_; + std::vector indices_; uint32_t dispatch_group_size_x_; uint32_t dispatch_group_size_y_; diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index 3e4fbd33a6bd..297d211ff126 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -60,7 +60,9 @@ Status ProgramManager::Build(const ProgramBase& program, ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); - ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputsAndOutputs()); + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputs()); + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForOutputs()); + ORT_RETURN_IF_ERROR(shader_helper.ValidateIndices()); // code is a large std::string that contains the final shader code std::string code; diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 64ed98c78507..c229e821cbf8 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -78,24 +78,33 @@ Status ShaderHelper::Init() { return Status::OK(); } -const ShaderVariable& ShaderHelper::AddInput(const std::string& name, ShaderVariable::Usage usage) { - const size_t input_index = vars_[std::underlying_type::type(ProgramVariableScope::Input)].size(); +const ShaderVariableHelper& ShaderHelper::AddInput(const std::string& name, ShaderUsage usage) { + const size_t input_index = input_vars_.size(); ORT_ENFORCE(input_index < program_.Inputs().size(), "Too many inputs in the program (", program_.Inputs().size(), ")"); const auto& dims = program_.Inputs()[input_index].use_override_shape ? program_.Inputs()[input_index].override_shape : program_.Inputs()[input_index].tensor->Shape(); - return AddVariableImpl(ProgramVariableScope::Input, name, usage, dims); + return AddVariableImpl(true, name, usage, dims); } -const ShaderVariable& ShaderHelper::AddOutput(const std::string& name, ShaderVariable::Usage usage) { - const size_t output_index = vars_[std::underlying_type::type(ProgramVariableScope::Output)].size(); +const ShaderVariableHelper& ShaderHelper::AddOutput(const std::string& name, ShaderUsage usage) { + const size_t output_index = output_vars_.size(); ORT_ENFORCE(output_index < program_.Outputs().size(), "Too many outputs in the program (", program_.Outputs().size(), ")"); const auto& dims = program_.Outputs()[output_index].use_override_shape ? program_.Outputs()[output_index].override_shape : program_.Outputs()[output_index].tensor->Shape(); - return AddVariableImpl(ProgramVariableScope::Output, name, usage, dims); + return AddVariableImpl(false, name, usage, dims); +} + +const ShaderIndicesHelper& ShaderHelper::AddIndices(const std::string& name, bool use_uniform) { + const size_t indices_index = indices_vars_.size(); + return *indices_vars_.emplace_back( + std::make_unique(name, + ProgramVariableDataType::InvalidType, + use_uniform ? ShaderUsage::UseUniform : ShaderUsage::None, + program_.Indices()[indices_index])); } #ifndef NDEBUG // if debug build @@ -162,7 +171,7 @@ Status ValidateVariableShape(const TensorShape& origin_shape, } // Validate if the dependency and variable usage match -Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, ShaderVariable::Usage usage, bool is_input) { +Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, ShaderUsage usage, bool is_input) { bool dependency_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; bool dependency_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; bool dependency_type = (dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type; @@ -172,7 +181,7 @@ Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, Sh "Dependency cannot set for both \"Rank\" and \"Shape\"."); // if dependency is set for shape, it's already part of the shader cache. no need to use uniform. - ORT_RETURN_IF(dependency_shape && (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform, + ORT_RETURN_IF(dependency_shape && (usage & ShaderUsage::UseUniform), "Dependency is set for \"Shape\", using uniform for shape is not allowed."); // for input variable, check is more strict. @@ -180,11 +189,11 @@ Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, Sh if (is_input) { // if dependency is not set for type, should not use type alias for element and value. // storage type is always used. so setting not depending on type is at user's own risk. - ORT_RETURN_IF(!dependency_type && (usage & (ShaderVariable::UseElementTypeAlias | ShaderVariable::UseValueTypeAlias)), + ORT_RETURN_IF(!dependency_type && (usage & (ShaderUsage::UseElementTypeAlias | ShaderUsage::UseValueTypeAlias)), "Input dependency is not set for \"Type\", but type alias for element type or value type is used."); // if dependency is not set for rank and shape, the shader should not use shape and stride. - ORT_RETURN_IF(!dependency_rank && !dependency_shape && (usage & ShaderVariable::UseShapeAndStride), + ORT_RETURN_IF(!dependency_rank && !dependency_shape && (usage & ShaderUsage::UseShapeAndStride), "Input dependency is set for neither \"Rank\" nor \"Shape\", but variable shape and stride is used."); } @@ -192,7 +201,7 @@ Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, Sh } } // namespace -Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const { +Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariableHelper& var) const { ORT_RETURN_IF_ERROR(ValidateVariableDataType(input.tensor->GetElementType(), var.type_)); ORT_RETURN_IF_ERROR(ValidateVariableShape(input.tensor->Shape(), input.use_override_shape, @@ -202,7 +211,7 @@ Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVar return Status::OK(); } -Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const { +Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariableHelper& var) const { ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); ORT_RETURN_IF_ERROR(ValidateVariableShape(output.tensor->Shape(), output.use_override_shape, @@ -215,93 +224,97 @@ Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderV #endif // NDEBUG -const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, - const std::string& name, - ShaderVariable::Usage usage, - const TensorShape& dims) { - if (scope == ProgramVariableScope::Input || scope == ProgramVariableScope::Output) { - ORT_ENFORCE(vars_[std::underlying_type::type(ProgramVariableScope::Input)].size() + - vars_[std::underlying_type::type(ProgramVariableScope::Output)].size() < - limits_.maxStorageBuffersPerShaderStage, - "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); - } +const ShaderVariableHelper& ShaderHelper::AddVariableImpl(bool is_input, + const std::string& name, + ShaderUsage usage, + const TensorShape& dims) { + ORT_ENFORCE(input_vars_.size() + output_vars_.size() < limits_.maxStorageBuffersPerShaderStage, + "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); - auto& vars = vars_[std::underlying_type::type(scope)]; ProgramVariableDataType type = ProgramVariableDataType::InvalidType; + auto& vars = is_input ? input_vars_ : output_vars_; - if (scope == ProgramVariableScope::Input) { + if (is_input) { const auto& input = program_.Inputs()[vars.size()]; type = input.var_type; - } else if (scope == ProgramVariableScope::Output) { + } else { const auto& output = program_.Outputs()[vars.size()]; type = output.var_type; - } else { - ORT_NOT_IMPLEMENTED("Local variables are not supported yet."); } - const auto& var = vars.emplace_back(std::make_unique(name, type, usage, dims)); + const auto& var = vars.emplace_back(std::make_unique(name, type, usage, dims)); return *var; } -Status ShaderHelper::ValidateShapeForInputsAndOutputs() const { - const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; - const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; - - // Validate input/output as dependencies of shape_uniforms - ORT_RETURN_IF_NOT(input_vars.size() == program_.Inputs().size(), - "Mismatched input variable count. Shader: ", input_vars.size(), ", Program: ", program_.Inputs().size()); - ORT_RETURN_IF_NOT(output_vars.size() == program_.Outputs().size(), - "Mismatched output variable count. Shader: ", output_vars.size(), ", Program: ", program_.Outputs().size()); - - for (size_t i = 0; i < input_vars.size(); i++) { +Status ShaderHelper::ValidateShapeForInputs() const { + // Validate input as dependencies of shape_uniforms + ORT_RETURN_IF_NOT(input_vars_.size() == program_.Inputs().size(), + "Mismatched input variable count. Shader: ", input_vars_.size(), ", Program: ", program_.Inputs().size()); + for (size_t i = 0; i < input_vars_.size(); i++) { #ifndef NDEBUG // if debug build // Validate input shape - ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], *input_vars[i])); + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], *input_vars_[i])); #endif // check input dependencies with actual usages. - auto usage = input_vars[i]->usage_; - bool use_uniform = (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform; + auto usage = input_vars_[i]->usage_; auto dependency = program_.Inputs()[i].dependency; bool use_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; - if (use_uniform) { - ORT_RETURN_IF_NOT((use_rank || input_vars[i]->rank_ < 2) && !use_shape, - "When UseUniform is set in variable usage, the corresponding program input should depend on rank but not shape."); - } else { - ORT_RETURN_IF_NOT(use_shape, - "When UseUniform is not set in variable usage, the corresponding program input should depend on shape."); - // If you want neither hard-coded shape nor shape uniform, set UseUniform with a flattened shape (rank=1). - // This will not generate any shape variables in the shader, can you can only use offset to set/get values. + if (usage & ShaderUsage::UseShapeAndStride) { + if (usage & ShaderUsage::UseUniform) { + ORT_RETURN_IF_NOT((use_rank || input_vars_[i]->rank_ < 2) && !use_shape, + "When UseUniform is set in variable usage, the corresponding program input should depend on rank but not shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program input should depend on shape."); + // If you want neither hard-coded shape nor shape uniform, use a flattened shape (rank=1). + // This will not generate any shape variables in the shader, can you can only use offset to set/get values. + } } } + return Status::OK(); +} + +Status ShaderHelper::ValidateShapeForOutputs() const { + // Validate output as dependencies of shape_uniforms + ORT_RETURN_IF_NOT(output_vars_.size() == program_.Outputs().size(), + "Mismatched output variable count. Shader: ", output_vars_.size(), ", Program: ", program_.Outputs().size()); - for (size_t i = 0; i < output_vars.size(); i++) { + for (size_t i = 0; i < output_vars_.size(); i++) { #ifndef NDEBUG // if debug build // Validate output shape - ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], *output_vars[i])); + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], *output_vars_[i])); #endif // check output dependencies with actual usages. - auto usage = output_vars[i]->usage_; - bool use_uniform = (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform; + auto usage = output_vars_[i]->usage_; auto dependency = program_.Outputs()[i].dependency; bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; - if (use_uniform) { - // output tensor shape check is looser than input tensor shape check, because output shape is always calculated so it is not - // necessarily a part of the cache key. - ORT_RETURN_IF_NOT(!use_shape, - "When UseUniform is set in variable usage, the corresponding program output should not depend on shape."); - } else { - ORT_RETURN_IF_NOT(use_shape, - "When UseUniform is not set in variable usage, the corresponding program output should depend on shape."); + if (usage & ShaderUsage::UseShapeAndStride) { + if (usage & ShaderUsage::UseUniform) { + // output tensor shape check is looser than input tensor shape check, because output shape is always calculated so it is not + // necessarily a part of the cache key. + ORT_RETURN_IF_NOT(!use_shape, + "When UseUniform is set in variable usage, the corresponding program output should not depend on shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program output should depend on shape."); + } } } return Status::OK(); } +Status ShaderHelper::ValidateIndices() const { + ORT_RETURN_IF_NOT(indices_vars_.size() == program_.Indices().size(), + "Mismatched indices variable count. Shader: ", indices_vars_.size(), ", Program: ", program_.Indices().size()); + + return Status::OK(); +} + Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -362,12 +375,10 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha // Input/output variables // size_t variable_count = 0; - const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; - for (const auto& input : input_vars) { + for (const auto& input : input_vars_) { ss << "@group(0) @binding(" << variable_count++ << ") var " << input->name_ << ": array<" << input->StorageType() << ">;\n"; } - const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; - for (const auto& output : output_vars) { + for (const auto& output : output_vars_) { ss << "@group(0) @binding(" << variable_count++ << ") var " << output->name_ << ": array<" << output->StorageType() << ">;\n"; } @@ -378,22 +389,29 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha // store shape uniform ranks in shape_uniform_ranks bool use_any_shape_uniform = false; ORT_ENFORCE(shape_uniform_ranks.size() == 0); - shape_uniform_ranks.reserve(input_vars.size() + output_vars.size()); + shape_uniform_ranks.reserve(input_vars_.size() + output_vars_.size() + indices_vars_.size()); - for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - bool use_uniform = (input->usage_ & ShaderVariable::UseUniform) && - (input->usage_ & ShaderVariable::UseShapeAndStride) && + for (const auto& input : input_vars_) { + bool use_uniform = (input->usage_ & ShaderUsage::UseUniform) && + (input->usage_ & ShaderUsage::UseShapeAndStride) && input->rank_ > 0; use_any_shape_uniform |= use_uniform; shape_uniform_ranks.push_back(use_uniform ? input->rank_ : 0); } - for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - bool use_uniform = (output->usage_ & ShaderVariable::UseUniform) && - (output->usage_ & ShaderVariable::UseShapeAndStride) && + for (const auto& output : output_vars_) { + bool use_uniform = (output->usage_ & ShaderUsage::UseUniform) && + (output->usage_ & ShaderUsage::UseShapeAndStride) && output->rank_ > 0; use_any_shape_uniform |= use_uniform; shape_uniform_ranks.push_back(use_uniform ? output->rank_ : 0); } + for (const auto& indices : indices_vars_) { + bool use_uniform = (indices->usage_ & ShaderUsage::UseUniform) && + (indices->usage_ & ShaderUsage::UseShapeAndStride) && + indices->rank_ > 0; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? indices->rank_ : 0); + } if (use_any_shape_uniform || std::any_of(program_.UniformVariables().cbegin(), program_.UniformVariables().cend(), @@ -430,9 +448,9 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha } }; - for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { + for (const auto& input : input_vars_) { const size_t rank = input->rank_; - if (rank > 0 && (input->usage_ & ShaderVariable::Usage::UseUniform) && (input->usage_ & ShaderVariable::Usage::UseShapeAndStride)) { + if (rank > 0 && (input->usage_ & ShaderUsage::UseUniform) && (input->usage_ & ShaderUsage::UseShapeAndStride)) { std::string shape = input->name_ + "_shape"; std::string stride = input->name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); @@ -440,9 +458,9 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha } } - for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { + for (const auto& output : output_vars_) { const size_t rank = output->rank_; - if (rank > 0 && (output->usage_ & ShaderVariable::Usage::UseUniform) && (output->usage_ & ShaderVariable::Usage::UseShapeAndStride)) { + if (rank > 0 && (output->usage_ & ShaderUsage::UseUniform) && (output->usage_ & ShaderUsage::UseShapeAndStride)) { std::string shape = output->name_ + "_shape"; std::string stride = output->name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); @@ -450,6 +468,16 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha } } + for (const auto& indices : indices_vars_) { + const size_t rank = indices->rank_; + if (rank > 0 && (indices->usage_ & ShaderUsage::UseUniform) && (indices->usage_ & ShaderUsage::UseShapeAndStride)) { + std::string shape = indices->name_ + "_shape"; + std::string stride = indices->name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); + } + } + for (size_t i = 0; i < program_.UniformVariables().size(); i++) { const auto& uniform_def = program_metadata_.uniform_variables[i]; const auto& uniform_value = program_.UniformVariables()[i]; @@ -465,10 +493,14 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha // Indices helper // ss << "\n"; - for (const auto& var_group : vars_) { - for (const auto& var : var_group) { - var->Impl(ss); - } + for (const auto& var : input_vars_) { + var->Impl(ss); + } + for (const auto& var : output_vars_) { + var->Impl(ss); + } + for (const auto& var : indices_vars_) { + var->Impl(ss); } ss << "\n"; diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index 811ae3cfa15c..bdc14669cfb5 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -80,14 +80,17 @@ class ShaderHelper final { // Add an input variable to the shader. // // depending on the usage of the variable, additional code may be generated. - const ShaderVariable& AddInput(const std::string& name, - ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); + const ShaderVariableHelper& AddInput(const std::string& name, + ShaderUsage usage = ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseUniform); // Add an output variable to the shader. // // depending on the usage of the variable, additional code may be generated. - const ShaderVariable& AddOutput(const std::string& name, - ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); + const ShaderVariableHelper& AddOutput(const std::string& name, + ShaderUsage usage = ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseUniform); + + // Add an indices variable to the shader. + const ShaderIndicesHelper& AddIndices(const std::string& name, bool use_uniform = true); // Append additional implementation code to the shader. // @@ -136,17 +139,19 @@ class ShaderHelper final { } } - const ShaderVariable& AddVariableImpl(ProgramVariableScope scope, - const std::string& name, - ShaderVariable::Usage usage, - const TensorShape& dims); + const ShaderVariableHelper& AddVariableImpl(bool is_input, + const std::string& name, + ShaderUsage usage, + const TensorShape& dims); #ifndef NDEBUG // if debug build - Status ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const; - Status ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const; + Status ValidateVariable(const ProgramInput& input, const ShaderVariableHelper& var) const; + Status ValidateVariable(const ProgramOutput& output, const ShaderVariableHelper& var) const; #endif - Status ValidateShapeForInputsAndOutputs() const; + Status ValidateShapeForInputs() const; + Status ValidateShapeForOutputs() const; + Status ValidateIndices() const; // Generate source code. // @@ -171,7 +176,9 @@ class ShaderHelper final { const ProgramBase& program_; const ProgramMetadata& program_metadata_; - std::array>, static_cast(ProgramVariableScope::Count)> vars_; + std::vector> input_vars_; + std::vector> output_vars_; + std::vector> indices_vars_; std::ostringstream additional_implementation_; std::ostringstream body_; diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 07c5915be466..f2a5b049b477 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -76,7 +76,7 @@ inline std::string GetIndicesType(int rank) { } // namespace -ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims) +ShaderIndicesHelper::ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims) : name_(name), type_(type), num_components_{NumberOfComponents(type)}, @@ -86,30 +86,33 @@ ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType ty indices_type_{GetIndicesType(rank_)}, value_type_alias_{name_ + "_value_t"}, element_type_alias_{name_ + "_element_t"}, - indices_type_alias_{name_ + "_indices_t"} { + indices_type_alias_{name_ + "_indices_t"} {} + +ShaderVariableHelper::ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims) + : ShaderIndicesHelper{name, type, usage, dims} { ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_); } -void ShaderVariable::Impl(std::ostringstream& ss) const { +void ShaderIndicesHelper::Impl(std::ostringstream& ss) const { // Start generating code - const std::string shape = (usage_ & UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; - const std::string stride = (usage_ & UseUniform) ? "uniforms." + name_ + "_stride" : name_ + "_stride"; + const std::string shape = (usage_ & ShaderUsage::UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; + const std::string stride = (usage_ & ShaderUsage::UseUniform) ? "uniforms." + name_ + "_stride" : name_ + "_stride"; // Types - if (usage_ & UseValueTypeAlias) { + if (usage_ & ShaderUsage::UseValueTypeAlias) { SS("alias ", value_type_alias_, " = ", VALUE_TYPE[static_cast(type_)], ";\n"); } - if (usage_ & UseIndicesTypeAlias) { + if (usage_ & ShaderUsage::UseIndicesTypeAlias) { SS("alias ", indices_type_alias_, " = ", indices_type_, ";\n"); } - if (usage_ & UseElementTypeAlias) { + if (usage_ & ShaderUsage::UseElementTypeAlias) { SS("alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); } // Need shape and strides when (not use uniform) and (use shape and stride is enabled) - if (!(usage_ & UseUniform) && (usage_ & UseShapeAndStride) && rank_ > 0) { + if (!(usage_ & ShaderUsage::UseUniform) && (usage_ & ShaderUsage::UseShapeAndStride) && rank_ > 0) { SS("const ", shape, " = ", IndicesType(), "("); bool first = true; @@ -138,7 +141,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Implementation of "fn o2i_{name}" - if (usage_ & UseOffsetToIndices) { + if (usage_ & ShaderUsage::UseOffsetToIndices) { if (rank_ >= 2) { SS("fn o2i_", name_, "(offset : u32)->", IndicesType(), " {\n"); SS(" var indices: ", IndicesType(), ";\n"); @@ -157,7 +160,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Implementation of "fn i2o_{name}" - if (usage_ & UseIndicesToOffset) { + if (usage_ & ShaderUsage::UseIndicesToOffset) { if (rank_ >= 2) { SS("fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); SS(" return "); @@ -170,7 +173,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Implementation of "fn {res_name}_bi2o_{name}" - if (usage_ & UseBroadcastedIndicesToOffset) { + if (usage_ & ShaderUsage::UseBroadcastedIndicesToOffset) { if (rank_ > 0) { for (const auto& broadcasted_result_ptr : broadcasted_to_) { const auto& broadcasted_result = *broadcasted_result_ptr; @@ -190,9 +193,13 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } } } +} + +void ShaderVariableHelper::Impl(std::ostringstream& ss) const { + ShaderIndicesHelper::Impl(ss); // Implementation of "fn set_{name}" - if (usage_ & UseSet) { + if (usage_ & ShaderUsage::UseSet) { if (rank_ >= 2) { SS("fn set_", name_, "(d0: u32"); for (int i = 1; i < rank_; i++) { @@ -209,7 +216,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Implementation of "fn set_{name}_by_indices" - if (usage_ & UseSetByIndices) { + if (usage_ & ShaderUsage::UseSetByIndices) { if (rank_ >= 2) { SS("fn set_", name_, "_by_indices(indices: ", IndicesType(), ", value: ", ValueType(), ") {\n"); SS(" ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); @@ -218,7 +225,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Implementation of "fn get_{name}" - if (usage_ & UseGet) { + if (usage_ & ShaderUsage::UseGet) { if (rank_ >= 2) { SS("fn get_", name_, "(d0: u32"); for (int i = 1; i < rank_; i++) { @@ -235,7 +242,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Implementation of "fn get_{name}_by_indices" - if (usage_ & UseGetByIndices) { + if (usage_ & ShaderUsage::UseGetByIndices) { if (rank_ >= 2) { SS("fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n"); SS(" return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); @@ -244,7 +251,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } } -std::string ShaderVariable::GetByOffsetImpl(std::string_view offset) const { +std::string ShaderVariableHelper::GetByOffsetImpl(std::string_view offset) const { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -270,7 +277,7 @@ std::string ShaderVariable::GetByOffsetImpl(std::string_view offset) const { return ss.str(); } -std::string ShaderVariable::SetByOffsetImpl(std::string_view offset, std::string_view value) const { +std::string ShaderVariableHelper::SetByOffsetImpl(std::string_view offset, std::string_view value) const { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -294,20 +301,20 @@ std::string ShaderVariable::SetByOffsetImpl(std::string_view offset, std::string return ss.str(); } -std::string_view ShaderVariable::StorageType() const { +std::string_view ShaderVariableHelper::StorageType() const { return STORAGE_TYPE[static_cast(type_)]; } -std::string_view ShaderVariable::ValueType() const { - return (usage_ & UseValueTypeAlias) ? value_type_alias_ : VALUE_TYPE[static_cast(type_)]; +std::string_view ShaderVariableHelper::ValueType() const { + return (usage_ & ShaderUsage::UseValueTypeAlias) ? value_type_alias_ : VALUE_TYPE[static_cast(type_)]; } -std::string_view ShaderVariable::ElementType() const { - return (usage_ & UseElementTypeAlias) ? element_type_alias_ : ELEMENT_TYPE[static_cast(type_)]; +std::string_view ShaderVariableHelper::ElementType() const { + return (usage_ & ShaderUsage::UseElementTypeAlias) ? element_type_alias_ : ELEMENT_TYPE[static_cast(type_)]; } -std::string_view ShaderVariable::IndicesType() const { - return (usage_ & UseIndicesTypeAlias) ? indices_type_alias_ : indices_type_; +std::string_view ShaderIndicesHelper::IndicesType() const { + return (usage_ & ShaderUsage::UseIndicesTypeAlias) ? indices_type_alias_ : indices_type_; } } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 71822a61f7a7..326c6814410d 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -37,9 +37,8 @@ std::string GetElementAt(std::string_view var, const TIdx& idx, int rank, bool i return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : std::string{var}; } -class ShaderVariable { - public: - enum Usage : uint32_t { +struct ShaderUsage { + enum : uint32_t { None = 0, // no usage. this means no additional implementation code will be generated. UseIndicesTypeAlias = 1, // use type alias "{name}_indices_t" for indices (eg. u32, vec2, vec3, vec4, ...) UseValueTypeAlias = 2, // use type alias "{name}_value_t" for value (eg. f32, vecT, vec4, ...) @@ -53,17 +52,21 @@ class ShaderVariable { UseGet = 1024, // use implementation of fn get_{name} UseGetByIndices = 2048, // use implementation of fn get_{name}_by_indices UseUniform = 32768, // use uniform for shape and stride - }; + } usage; - ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims); + ShaderUsage(decltype(usage) usage) : usage{usage} {} + ShaderUsage(uint32_t usage) : usage{usage} {} - ShaderVariable(ShaderVariable&&) = default; - ShaderVariable& operator=(ShaderVariable&&) = default; + explicit operator bool() { + return usage != None; + } +}; - // get the name of the variable. - inline std::string_view Name() const { return name_; } +// A helper class to make it easier to generate shader code related to indices calculation. +class ShaderIndicesHelper { + public: + ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims); - // get the number of components of the variable. inline int NumComponents() const { return num_components_; } // create a WGSL expression ({varname}_indices_t) for getting indices from offset. @@ -77,7 +80,7 @@ class ShaderVariable { // create a WGSL expression (u32) for getting original offset from broadcasted indices. // \param indices: a WGSL expression ({broadcasted_result_varname}_indices_t) representing the broadcasted indices. // \param broadcasted_result: the broadcasted result variable. - inline std::string BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const; + inline std::string BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderIndicesHelper& broadcasted_result) const; // create a WGSL expression ({varname}_indices_t) as an indices literal // \param init: a list of indices values. @@ -97,6 +100,41 @@ class ShaderVariable { template inline std::string IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const; + protected: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderIndicesHelper); + + void Impl(std::ostringstream& ss) const; + + std::string_view IndicesType() const; + + std::string name_; + ProgramVariableDataType type_; // for variable + int num_components_; // for variable + int rank_; + TensorShape dims_; + + mutable ShaderUsage usage_; + mutable std::set broadcasted_to_; + + // unlike storage/element/value type, indices type is not a string view to a constant string. so we need to store it. + std::string indices_type_; + + // the alias for the types + std::string value_type_alias_; + std::string element_type_alias_; + std::string indices_type_alias_; + + friend class ShaderHelper; +}; + +// A helper class to make it easier to generate shader code related to a variable setting/getting and its indices calculation. +class ShaderVariableHelper : public ShaderIndicesHelper { + public: + ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims); + + ShaderVariableHelper(ShaderVariableHelper&&) = default; + ShaderVariableHelper& operator=(ShaderVariableHelper&&) = default; + // create a WGSL statement for setting data at the given indices. // \param args: a list of indices values (u32) followed by a value ({varname}_value_t). template @@ -128,12 +166,7 @@ class ShaderVariable { inline std::string GetByOffset(TOffset&& offset) const; private: - friend ShaderVariable::Usage operator|(ShaderVariable::Usage a, ShaderVariable::Usage b); - friend ShaderVariable::Usage operator&(ShaderVariable::Usage a, ShaderVariable::Usage b); - friend ShaderVariable::Usage& operator|=(ShaderVariable::Usage& a, ShaderVariable::Usage b); - friend ShaderVariable::Usage& operator&=(ShaderVariable::Usage& a, ShaderVariable::Usage b); - - ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariable); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper); void Impl(std::ostringstream& ss) const; @@ -142,39 +175,23 @@ class ShaderVariable { std::string_view StorageType() const; std::string_view ValueType() const; std::string_view ElementType() const; - std::string_view IndicesType() const; - - std::string name_; - ProgramVariableDataType type_; - int num_components_; - int rank_; - TensorShape dims_; - - mutable Usage usage_; - mutable std::set broadcasted_to_; - - // unlike storage/element/value type, indices type is not a string view to a constant string. so we need to store it. - std::string indices_type_; - - // the alias for the types - std::string value_type_alias_; - std::string element_type_alias_; - std::string indices_type_alias_; friend class ShaderHelper; }; -inline ShaderVariable::Usage operator|(ShaderVariable::Usage a, ShaderVariable::Usage b) { - return (ShaderVariable::Usage)((uint32_t&)a | (uint32_t&)b); +inline ShaderUsage operator|(ShaderUsage a, ShaderUsage b) { + return (uint32_t)a.usage | (uint32_t)b.usage; } -inline ShaderVariable::Usage operator&(ShaderVariable::Usage a, ShaderVariable::Usage b) { - return (ShaderVariable::Usage)((uint32_t&)a & (uint32_t&)b); +inline ShaderUsage operator&(ShaderUsage a, ShaderUsage b) { + return (uint32_t)a.usage & (uint32_t)b.usage; } -inline ShaderVariable::Usage& operator|=(ShaderVariable::Usage& a, ShaderVariable::Usage b) { - return (ShaderVariable::Usage&)((uint32_t&)a |= (uint32_t&)b); +inline ShaderUsage& operator|=(ShaderUsage& a, ShaderUsage b) { + (uint32_t&)a.usage |= (uint32_t)b.usage; + return a; } -inline ShaderVariable::Usage& operator&=(ShaderVariable::Usage& a, ShaderVariable::Usage b) { - return (ShaderVariable::Usage&)((uint32_t&)a &= (uint32_t&)b); +inline ShaderUsage& operator&=(ShaderUsage& a, ShaderUsage b) { + (uint32_t&)a.usage &= (uint32_t)b.usage; + return a; } namespace detail { @@ -192,20 +209,24 @@ std::string pass_as_string(T&& v) { } } // namespace detail -inline std::string ShaderVariable::OffsetToIndices(std::string_view offset_expr) const { - usage_ |= UseOffsetToIndices | UseShapeAndStride; +inline std::string ShaderIndicesHelper::OffsetToIndices(std::string_view offset_expr) const { + usage_ |= ShaderUsage::UseOffsetToIndices | ShaderUsage::UseShapeAndStride; return rank_ < 2 ? std::string{offset_expr} : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')'); } -inline std::string ShaderVariable::IndicesToOffset(std::string_view indices_expr) const { - usage_ |= UseIndicesToOffset | UseShapeAndStride; +inline std::string ShaderIndicesHelper::IndicesToOffset(std::string_view indices_expr) const { + usage_ |= ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride; return rank_ < 2 ? std::string{indices_expr} : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')'); } -inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const { - usage_ |= UseBroadcastedIndicesToOffset | UseShapeAndStride; +inline std::string ShaderIndicesHelper::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderIndicesHelper& broadcasted_result) const { + ORT_ENFORCE(broadcasted_result.num_components_ == -1 || + num_components_ == -1 || + broadcasted_result.num_components_ == num_components_, + "number of components should be the same for 2 variables to calculate"); + usage_ |= ShaderUsage::UseBroadcastedIndicesToOffset | ShaderUsage::UseShapeAndStride; broadcasted_to_.insert(&broadcasted_result); return rank_ == 0 ? "0" @@ -213,8 +234,8 @@ inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view i } template -inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const { - usage_ |= UseShapeAndStride; +inline std::string ShaderIndicesHelper::Indices(TIndices&&... indices_args) const { + usage_ |= ShaderUsage::UseShapeAndStride; return rank_ == 0 ? "0" : MakeStringWithClassicLocale(IndicesType(), "(", @@ -223,77 +244,77 @@ inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const { } template -inline std::string ShaderVariable::IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const { - usage_ |= UseShapeAndStride; +inline std::string ShaderIndicesHelper::IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const { + usage_ |= ShaderUsage::UseShapeAndStride; return rank_ < 2 ? MakeStringWithClassicLocale(indices_var, '=', value, ';') : MakeStringWithClassicLocale(GetElementAt(indices_var, idx_expr, rank_), '=', value, ';'); } template -inline std::string ShaderVariable::IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const { - usage_ |= UseShapeAndStride; +inline std::string ShaderIndicesHelper::IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const { + usage_ |= ShaderUsage::UseShapeAndStride; return rank_ < 2 ? std::string{indices_var} : GetElementAt(indices_var, idx_expr, rank_); } template -inline std::string ShaderVariable::SetByOffset(TOffset&& offset, TValue&& value) const { +inline std::string ShaderVariableHelper::SetByOffset(TOffset&& offset, TValue&& value) const { return SetByOffsetImpl(detail::pass_as_string(offset), detail::pass_as_string(value)); } template -inline std::string ShaderVariable::Set(TIndicesAndValue&&... args) const { - usage_ |= UseShapeAndStride; +inline std::string ShaderVariableHelper::Set(TIndicesAndValue&&... args) const { + usage_ |= ShaderUsage::UseShapeAndStride; ORT_ENFORCE(sizeof...(TIndicesAndValue) == rank_ + 1, "Number of arguments should be ", rank_ + 1, "(rank + 1)"); if constexpr (sizeof...(TIndicesAndValue) == 1) { return SetByOffset("0", std::forward(args)...); } else if constexpr (sizeof...(TIndicesAndValue) == 2) { return SetByOffset(std::forward(args)...); } else { - usage_ |= UseSet | UseSetByIndices | UseIndicesToOffset; + usage_ |= ShaderUsage::UseSet | ShaderUsage::UseSetByIndices | ShaderUsage::UseIndicesToOffset; return MakeStringWithClassicLocale("set_", name_, '(', absl::StrJoin(std::forward_as_tuple(std::forward(args)...), ", "), ");"); } } -inline std::string ShaderVariable::SetByIndices(std::string_view indices_var, std::string_view value) const { - usage_ |= UseShapeAndStride; +inline std::string ShaderVariableHelper::SetByIndices(std::string_view indices_var, std::string_view value) const { + usage_ |= ShaderUsage::UseShapeAndStride; if (rank_ < 2) { return SetByOffset(indices_var, value); } else { - usage_ |= UseSetByIndices | UseIndicesToOffset; + usage_ |= ShaderUsage::UseSetByIndices | ShaderUsage::UseIndicesToOffset; return MakeStringWithClassicLocale("set_", name_, "_by_indices(", indices_var, ", ", value, ");"); } } template -inline std::string ShaderVariable::GetByOffset(TOffset&& offset) const { +inline std::string ShaderVariableHelper::GetByOffset(TOffset&& offset) const { return GetByOffsetImpl(detail::pass_as_string(offset)); } template -inline std::string ShaderVariable::Get(TIndices&&... indices) const { - usage_ |= UseShapeAndStride; +inline std::string ShaderVariableHelper::Get(TIndices&&... indices) const { + usage_ |= ShaderUsage::UseShapeAndStride; ORT_ENFORCE(sizeof...(TIndices) == rank_, "Number of arguments should be ", rank_, "(rank)"); if constexpr (sizeof...(TIndices) == 0) { return GetByOffset("0"); } else if constexpr (sizeof...(TIndices) == 1) { return GetByOffset(std::forward(indices)...); } else { - usage_ |= UseGet | UseGetByIndices | UseIndicesToOffset; + usage_ |= ShaderUsage::UseGet | ShaderUsage::UseGetByIndices | ShaderUsage::UseIndicesToOffset; return MakeStringWithClassicLocale("get_", name_, '(', absl::StrJoin(std::forward_as_tuple(std::forward(indices)...), ", "), ')'); } } -inline std::string ShaderVariable::GetByIndices(std::string_view indices_var) const { - usage_ |= UseShapeAndStride; +inline std::string ShaderVariableHelper::GetByIndices(std::string_view indices_var) const { + usage_ |= ShaderUsage::UseShapeAndStride; if (rank_ < 2) { return GetByOffset(indices_var); } else { - usage_ |= UseGetByIndices | UseIndicesToOffset; + usage_ |= ShaderUsage::UseGetByIndices | ShaderUsage::UseIndicesToOffset; return MakeStringWithClassicLocale("get_", name_, "_by_indices(", indices_var, ")"); } } diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 45084472d353..a10658365188 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -11,8 +11,8 @@ namespace onnxruntime { namespace webgpu { Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("input", ShaderVariable::UseUniform); - const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform); + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), " let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 68af858d515c..b620e83843b2 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -61,8 +61,8 @@ const std::string AppendPermFunction(gsl::span perm) { } Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | ShaderVariable::UseIndicesTypeAlias); - const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform | ShaderVariable::UseIndicesTypeAlias); + const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); shader.AppendImplementation(AppendPermFunction(this->perm_)); shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), " let indices = ", output.OffsetToIndices("global_idx"), diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 11a337cd3e37..66b1c2c7fafa 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -229,17 +229,16 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog std::vector shape_uniforms; shape_uniforms.reserve(program_artifact->shape_uniform_ranks.size() * 2); if (ValidationMode() >= ValidationMode::Basic) { - ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size(), + ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size() + program.Indices().size(), "Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(), - ") does not match current program (input: ", inputs.size(), ", output: ", outputs.size(), ")"); + ") does not match current program (input: ", inputs.size(), + ", output: ", outputs.size(), + ", indices: ", program.Indices().size(), ")"); } - for (size_t i = 0; i < program_artifact->shape_uniform_ranks.size(); ++i) { + + auto append_shape_uniforms = [&shape_uniforms, program_artifact](size_t i, const TensorShape& shape) { SafeInt expected_rank = program_artifact->shape_uniform_ranks[i]; if (expected_rank > 0) { - const auto& shape = i < inputs.size() ? (inputs[i].use_override_shape ? inputs[i].override_shape - : inputs[i].tensor->Shape()) - : (outputs[i - inputs.size()].use_override_shape ? outputs[i - inputs.size()].override_shape - : outputs[i - inputs.size()].tensor->Shape()); ORT_RETURN_IF(expected_rank != shape.NumDimensions(), "Invalid program artifact: variable[", i, "] rank mismatch. Expected: ", (int)expected_rank, ", Actual: ", shape.NumDimensions()); @@ -258,6 +257,19 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog shape_uniforms.emplace_back(gsl::make_span(stride)); } } + return Status::OK(); + }; + + for (size_t i = 0; i < inputs.size(); i++) { + ORT_RETURN_IF_ERROR(append_shape_uniforms(i, + inputs[i].use_override_shape ? inputs[i].override_shape : inputs[i].tensor->Shape())); + } + for (size_t i = 0; i < outputs.size(); i++) { + ORT_RETURN_IF_ERROR(append_shape_uniforms(i + inputs.size(), + outputs[i].use_override_shape ? outputs[i].override_shape : outputs[i].tensor->Shape())); + } + for (size_t i = 0; i < program.Indices().size(); i++) { + ORT_RETURN_IF_ERROR(append_shape_uniforms(i + inputs.size() + outputs.size(), program.Indices()[i])); } const size_t uniform_count = shape_uniforms.size() + program.UniformVariables().size(); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 444f07e1664b..abd471578146 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -456,36 +456,36 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO(20, Gelu), // // binary - math - // KERNEL_CREATE_INFO_VERSIONED(7, 12, Add), - // KERNEL_CREATE_INFO_VERSIONED(13, 13, Add), - // KERNEL_CREATE_INFO(14, Add), - // KERNEL_CREATE_INFO_VERSIONED(7, 12, Sub), - // KERNEL_CREATE_INFO_VERSIONED(13, 13, Sub), - // KERNEL_CREATE_INFO(14, Sub), - // KERNEL_CREATE_INFO_VERSIONED(7, 12, Mul), - // KERNEL_CREATE_INFO_VERSIONED(13, 13, Mul), - // KERNEL_CREATE_INFO(14, Mul), - // KERNEL_CREATE_INFO_VERSIONED(7, 12, Div), - // KERNEL_CREATE_INFO_VERSIONED(13, 13, Div), - // KERNEL_CREATE_INFO(14, Div), - // KERNEL_CREATE_INFO_VERSIONED(7, 11, Pow), - // KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow), - // KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow), - // KERNEL_CREATE_INFO(15, Pow), - // KERNEL_CREATE_INFO_VERSIONED(7, 10, Equal), - // KERNEL_CREATE_INFO_VERSIONED(11, 12, Equal), - // KERNEL_CREATE_INFO_VERSIONED(13, 18, Equal), - // KERNEL_CREATE_INFO(19, Equal), - // KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater), - // KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater), - // KERNEL_CREATE_INFO(13, Greater), - // KERNEL_CREATE_INFO_VERSIONED(12, 15, GreaterOrEqual), - // KERNEL_CREATE_INFO(16, GreaterOrEqual), - // KERNEL_CREATE_INFO_VERSIONED(7, 8, Less), - // KERNEL_CREATE_INFO_VERSIONED(9, 12, Less), - // KERNEL_CREATE_INFO(13, Less), - // KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual), - // KERNEL_CREATE_INFO(16, LessOrEqual), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Add), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Add), + KERNEL_CREATE_INFO(14, Add), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Sub), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Sub), + KERNEL_CREATE_INFO(14, Sub), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Mul), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Mul), + KERNEL_CREATE_INFO(14, Mul), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Div), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Div), + KERNEL_CREATE_INFO(14, Div), + KERNEL_CREATE_INFO_VERSIONED(7, 11, Pow), + KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow), + KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow), + KERNEL_CREATE_INFO(15, Pow), + KERNEL_CREATE_INFO_VERSIONED(7, 10, Equal), + KERNEL_CREATE_INFO_VERSIONED(11, 12, Equal), + KERNEL_CREATE_INFO_VERSIONED(13, 18, Equal), + KERNEL_CREATE_INFO(19, Equal), + KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater), + KERNEL_CREATE_INFO(13, Greater), + KERNEL_CREATE_INFO_VERSIONED(12, 15, GreaterOrEqual), + KERNEL_CREATE_INFO(16, GreaterOrEqual), + KERNEL_CREATE_INFO_VERSIONED(7, 8, Less), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Less), + KERNEL_CREATE_INFO(13, Less), + KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual), + KERNEL_CREATE_INFO(16, LessOrEqual), // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 4ca915dd394c..4aa3e9c6b37a 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -369,6 +369,28 @@ TEST(MathOpTest, Add_Broadcast_3x2_3x1) { #endif } +TEST(MathOpTest, Add_Broadcast_2x2x2_1x2x2) { + OpTester test("Add"); + + test.AddInput("A", {2, 2, 2}, + {101.0f, 102.0f, + 103.0f, 104.0f, + + 201.0f, 202.0f, + 203.0f, 204.0f}); + test.AddInput("B", {1, 2, 2}, + {010.0f, 020.0f, + 030.0f, 040.0f}); + test.AddOutput("C", {2, 2, 2}, + {111.0f, 122.0f, + 133.0f, 144.0f, + + 211.0f, 222.0f, + 233.0f, 244.0f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + TEST(MathOpTest, Add_Broadcast_2x1x4_1x3x1) { OpTester test("Add"); From 2e91a8b1231d4b1a6bccbdab044cdc015700a597 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 17 Sep 2024 01:11:18 -0700 Subject: [PATCH 082/114] use f32 for pow anyway --- .../core/providers/webgpu/math/binary_elementwise_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 9d9eff2ccdde..8feb5daae3f5 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -277,7 +277,7 @@ WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 7, 12, Sub, WebGpuSupportedNumberTypes()) WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 13, 13, Sub, WebGpuSupportedNumberTypes()) WEBGPU_BINARY_KERNEL(Sub, 14, Sub, WebGpuSupportedNumberTypes()) -WEBGPU_BINARY_IMPL(Pow, "output_value_t(pow(output_value_t(a), output_value_t(b)))") +WEBGPU_BINARY_IMPL(Pow, "output_value_t(pow(vec4(a), vec4(b)))") WEBGPU_BINARY_VERSIONED_KERNEL(Pow, 7, 11, Pow, WebGpuSupportedNumberTypes()) WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 12, 12, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 13, 14, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) From 87f9edb5fcc23ac0b59503fbdadac938cc2f234c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:59:15 -0700 Subject: [PATCH 083/114] Cast operator --- .../core/providers/webgpu/tensor/cast.cc | 117 ++++++++++++++++++ .../core/providers/webgpu/tensor/cast.h | 39 ++++++ 2 files changed, 156 insertions(+) create mode 100644 onnxruntime/core/providers/webgpu/tensor/cast.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/cast.h diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc new file mode 100644 index 000000000000..8d59570de996 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/webgpu/tensor/cast.h" + +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { +const std::vector& CastOpTypeConstraints() { + // currently support boolean, integer and float types that explicitly allowed in WGSL: + // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section + // + static std::vector types{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + return types; +} +} // namespace + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 6, 8, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 9, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 13, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_KERNEL_EX( + Cast, + kOnnxDomain, + 19, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); + +Status Cast::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + auto* output_tensor = context.Output(0, input_tensor->Shape()); + int64_t size = input_tensor->Shape().Size(); + if (size == 0) { + return Status::OK(); + } + SafeInt vec_size = (size + 3) / 4; + + CastProgram program{to_}; + program + .AddInput({input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(vec_size)}, + }) + .CacheHint(std::to_string(to_)); + return context.RunProgram(program); +} + +Status CastProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& input = sh.AddInput("x", ShaderUsage::UseUniform); + const auto& output = sh.AddOutput("y", ShaderUsage::UseUniform); + std::string expression; + switch (to_) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + expression = "vec4(a)"; + break; + default: + ORT_NOT_IMPLEMENTED("Cast to type ", to_, " is not supported."); + } + sh.SetMainFunctionBody(sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + " let a = ", input.GetByOffset("global_idx"), ";\n ", + output.SetByOffset("global_idx", expression)); + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.h b/onnxruntime/core/providers/webgpu/tensor/cast.h new file mode 100644 index 000000000000..c7216c18500f --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/cast.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class CastProgram final : public Program { + public: + CastProgram(int32_t to) : Program{"Cast"}, to_{to} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + private: + int32_t to_; +}; + +class Cast final : public WebGpuKernel { + public: + Cast(const OpKernelInfo& info) : WebGpuKernel(info) { + int64_t to; + Status status = info.GetAttr("to", &to); + ORT_ENFORCE(status.IsOK(), "Attribute to is not set."); + to_ = SafeInt(to); + + // ignore attribute 'saturate' as float8 is not supported in WebGPU + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int32_t to_; +}; + +} // namespace webgpu +} // namespace onnxruntime From 19ee9f3745cb7306a9634023f3b986a646365cd9 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 17 Sep 2024 15:17:50 -0700 Subject: [PATCH 084/114] do not use virtual function for getting ProgramMetadata --- onnxruntime/core/providers/webgpu/program.cc | 3 +- onnxruntime/core/providers/webgpu/program.h | 40 +++++-------------- .../core/providers/webgpu/webgpu_context.cc | 2 +- 3 files changed, 14 insertions(+), 31 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 21c63f75d26d..75c3c9ee9608 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -233,8 +233,9 @@ ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dep use_override_shape{true}, override_shape{override_shape} {} -ProgramBase::ProgramBase(const std::string& name) +ProgramBase::ProgramBase(const std::string& name, ProgramMetadata&& metadata) : name_{name}, + metadata_{metadata}, dispatch_group_size_x_{0}, dispatch_group_size_y_{0}, dispatch_group_size_z_{0}, diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index c7199e2a57a6..f05ca9c2bf22 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -230,7 +230,11 @@ namespace detail { class ProgramWrapper; } -struct ProgramMetadata; +struct ProgramMetadata { + gsl::span constants; + gsl::span overridable_constants; + gsl::span uniform_variables; +}; class ProgramBase { public: @@ -295,30 +299,12 @@ class ProgramBase { virtual Status GenerateShaderCode(ShaderHelper& shader) const = 0; - // - // abstract methods for getting metadata - // - // A derived class may contain any of the following static members: - // - // \code{.cpp} - // // define a list of constant that used in the shader program - // static constexpr const ProgramConstant constants[] = { ... }; - // - // // define a list of overridable constant that used in the shader program - // static constexpr const ProgramOverridableConstantDefinition overridable_constants[] = { ... }; - // - // // define a list of uniform variables that used in the shader program - // static constexpr const ProgramUniformVariableDefinition uniform_variables[] = { ... }; - // \endcode - // - // If those static members exist, the value of them will be used to generate the metadata. - virtual ProgramMetadata GetMetadata() const = 0; - // // Properties Getters // inline const std::string& Name() const { return name_; } + inline const ProgramMetadata& Metadata() const { return metadata_; } inline const std::string& CacheHint() const { return cache_hint_; } inline const std::vector& Inputs() const { return inputs_; } inline const std::vector& Outputs() const { return outputs_; } @@ -338,9 +324,11 @@ class ProgramBase { private: // Make the constructor private to prevent direct instantiation or inheritance from this class // Use the Program template class as base class to create a new program class - explicit ProgramBase(const std::string& name); + explicit ProgramBase(const std::string& name, ProgramMetadata&& metadata); std::string name_; + ProgramMetadata metadata_; + std::string cache_hint_; std::vector inputs_; std::vector outputs_; @@ -500,19 +488,13 @@ static_assert(!TestTypeCheck::has_a_with_correct_type); } // namespace detail -struct ProgramMetadata { - gsl::span constants; - gsl::span overridable_constants; - gsl::span uniform_variables; -}; - template class Program : public detail::ProgramWrapper { public: template - Program(Args&&... args) : detail::ProgramWrapper{std::forward(args)...} {} + Program(Args&&... args) : detail::ProgramWrapper{std::forward(args)..., GetMetadata()} {} - virtual ProgramMetadata GetMetadata() const final { + static ProgramMetadata GetMetadata() { ProgramMetadata metadata; if constexpr (detail::DerivedProgramClassTypeCheck::has_constants) { constexpr const ProgramConstant* ptr = T::constants.data(); diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 66b1c2c7fafa..f2414c14f6f9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -145,7 +145,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog "All outputs must be tensors on WebGPU buffers."); } - const ProgramMetadata metadata = program.GetMetadata(); + const ProgramMetadata& metadata = program.Metadata(); // validate program metadata if (ValidationMode() >= ValidationMode::Basic) { From d9f7f1924f47cb91143c362450ca12d65cddef80 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 17 Sep 2024 17:21:14 -0700 Subject: [PATCH 085/114] reshape, squeeze and unsqueeze --- .../core/providers/webgpu/tensor/reshape.cc | 72 +++++++++++++++++++ .../core/providers/webgpu/tensor/reshape.h | 51 +++++++++++++ .../core/providers/webgpu/tensor/squeeze.cc | 44 ++++++++++++ .../core/providers/webgpu/tensor/squeeze.h | 52 ++++++++++++++ .../core/providers/webgpu/tensor/unsqueeze.cc | 44 ++++++++++++ .../core/providers/webgpu/tensor/unsqueeze.h | 53 ++++++++++++++ .../webgpu/webgpu_execution_provider.cc | 67 +++++++++-------- 7 files changed, 355 insertions(+), 28 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/reshape.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/reshape.h create mode 100644 onnxruntime/core/providers/webgpu/tensor/squeeze.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/squeeze.h create mode 100644 onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/unsqueeze.h diff --git a/onnxruntime/core/providers/webgpu/tensor/reshape.cc b/onnxruntime/core/providers/webgpu/tensor/reshape.cc new file mode 100644 index 000000000000..9ede015a0c99 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/reshape.cc @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/reshape.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + Reshape, + kOnnxDomain, + 21, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 14, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 13, 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 5, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/reshape.h b/onnxruntime/core/providers/webgpu/tensor/reshape.h new file mode 100644 index 000000000000..4629598d068f --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/reshape.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/data_transfer_manager.h" +#include "core/providers/cpu/tensor/reshape_helper.h" + +namespace onnxruntime { +namespace webgpu { + +class Reshape final : public OpKernel { + public: + Reshape(const OpKernelInfo& info) + : OpKernel{info}, + allow_zero_(info.GetAttrOrDefault("allowzero", static_cast(0)) == 1) { + } + + Status Compute(OpKernelContext* context) const override { + // Copy the second input tensor into the shape vector + const Tensor* shapeTensor = context->Input(1); + if (shapeTensor == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + if (shapeTensor->Shape().NumDimensions() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "A shape tensor must be a vector tensor, got ", shapeTensor->Shape().NumDimensions(), " dimensions"); + } + auto data_span = shapeTensor->template DataAsSpan(); + TensorShapeVector shape(data_span.begin(), data_span.end()); + const Tensor* X = context->Input(0); + if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + const TensorShape& X_shape = X->Shape(); + + ReshapeHelper helper(X_shape, shape, allow_zero_); + + Tensor* Y = context->Output(0, TensorShape(shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } + + private: + bool allow_zero_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/squeeze.cc b/onnxruntime/core/providers/webgpu/tensor/squeeze.cc new file mode 100644 index 000000000000..136a1ba9776a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/squeeze.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/squeeze.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + Squeeze, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Squeeze); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/squeeze.h b/onnxruntime/core/providers/webgpu/tensor/squeeze.h new file mode 100644 index 000000000000..bc80cb86d5e8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/squeeze.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/tensor/squeeze.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +class Squeeze final : public OpKernel, public SqueezeBase { + public: + explicit Squeeze(const OpKernelInfo& info) : OpKernel{info}, SqueezeBase(info) {} + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set"); + } + const TensorShape& X_shape = X->Shape(); + + TensorShapeVector axes; + size_t num_inputs = context->InputCount(); + if (num_inputs == 2) { // axes is an input + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a vector tensor."); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->Data(); + axes.assign(data, data + nDims); + } else { + axes.assign(axes_.begin(), axes_.end()); + } + + TensorShapeVector output_shape = ComputeOutputShape(X_shape, axes); + Tensor* Y = context->Output(0, TensorShape(output_shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc new file mode 100644 index 000000000000..4bcef4fd7929 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/unsqueeze.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Unsqueeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Unsqueeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Unsqueeze); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.h b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.h new file mode 100644 index 000000000000..38475af5e277 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/tensor/unsqueeze.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +class Unsqueeze final : public OpKernel, public UnsqueezeBase { + public: + explicit Unsqueeze(const OpKernelInfo& info) : OpKernel{info}, UnsqueezeBase(info) {} + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set"); + } + const TensorShape& X_shape = X->Shape(); + + TensorShapeVector axes; + size_t num_inputs = context->InputCount(); + if (num_inputs == 2) { // axes is an input + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 || + axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a scalar or a vector tensor."); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->Data(); + axes.assign(data, data + nDims); + } else { + axes.assign(axes_.begin(), axes_.end()); + } + + TensorShapeVector output_shape = ComputeOutputShape(X_shape, axes); + Tensor* Y = context->Output(0, TensorShape(output_shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index abd471578146..8539992e85cf 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -125,10 +125,12 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Cast); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Cast); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 10, Clip); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, Clip); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, Clip); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, float, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, float, Clip); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, Clip); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Clip); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, Elu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Relu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Relu); @@ -229,7 +231,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 5, 12, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 18, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Reshape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Squeeze); @@ -434,18 +438,20 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO(9, Atanh), KERNEL_CREATE_INFO_VERSIONED(6, 12, Tanh), KERNEL_CREATE_INFO(13, Tanh), - // KERNEL_CREATE_INFO(1, Not), + KERNEL_CREATE_INFO(1, Not), - // KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), - // KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), - // KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), - // KERNEL_CREATE_INFO(19, Cast), + KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), + KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), + KERNEL_CREATE_INFO(19, Cast), // // activations - // KERNEL_CREATE_INFO_VERSIONED(6, 10, Clip), - // KERNEL_CREATE_INFO_VERSIONED(11, 11, Clip), - // KERNEL_CREATE_INFO_VERSIONED(12, 12, Clip), - // KERNEL_CREATE_INFO(13, Clip), + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, KERNEL_CREATE_INFO(6, Elu), KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu), KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu), @@ -487,17 +493,22 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual), KERNEL_CREATE_INFO(16, LessOrEqual), - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -509,9 +520,9 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, From 07675cfef79afc2b49a9ca2c24b6a6e6bb104aa2 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 17 Sep 2024 18:02:51 -0700 Subject: [PATCH 086/114] fix Cast and Clip --- .../core/providers/webgpu/math/unary_elementwise_ops.cc | 8 ++++++-- onnxruntime/core/providers/webgpu/tensor/cast.h | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 870dd3df24c7..3b43c87fb0c8 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "core/providers/webgpu/math/unary_elementwise_ops.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -185,8 +186,11 @@ class Clip final : public UnaryElementwise { Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override { const auto* clip_min_tensor = context.Input(1); const auto* clip_max_tensor = context.Input(2); - const T attr[] = {clip_min_tensor->Data()[0], - clip_max_tensor->Data()[0]}; + + const T attr[] = {clip_min_tensor ? clip_min_tensor->Data()[0] + : std::numeric_limits::lowest(), + clip_max_tensor ? clip_max_tensor->Data()[0] + : std::numeric_limits::max()}; if constexpr (std::is_same_v) { // F16: stores span as a single float float encoded_value = *reinterpret_cast(attr); diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.h b/onnxruntime/core/providers/webgpu/tensor/cast.h index c7216c18500f..47e8e6412be4 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.h +++ b/onnxruntime/core/providers/webgpu/tensor/cast.h @@ -14,6 +14,8 @@ class CastProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + private: int32_t to_; }; From dfab3225f0d2b5239ba93fb079661387111bf5be Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 20 Sep 2024 08:30:55 +0800 Subject: [PATCH 087/114] [webgpu-native] Add where op (#22014) Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../core/providers/webgpu/tensor/where.cc | 192 ++++++++++++++++++ .../core/providers/webgpu/tensor/where.h | 35 ++++ .../webgpu/webgpu_execution_provider.cc | 4 +- 3 files changed, 229 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/where.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/where.h diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc new file mode 100644 index 000000000000..31806a0af174 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/where.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +// Compute where operator output shape based upon three way broad-casting. +Status ComputeOutputShape(const TensorShape& cond_shape, + const TensorShape& x_shape, const TensorShape& y_shape, TensorShape& output_shape) { + size_t cond_rank = cond_shape.NumDimensions(); + size_t x_rank = x_shape.NumDimensions(); + size_t y_rank = y_shape.NumDimensions(); + size_t output_rank = std::max(std::max(cond_rank, x_rank), y_rank); + + std::vector output_dims(output_rank, 0); + for (size_t i = 0; i < output_rank; ++i) { + int64_t cond_dim = 1; + if (i < cond_rank) + cond_dim = cond_shape[cond_rank - 1 - i]; + + int64_t x_dim = 1; + if (i < x_rank) + x_dim = x_shape[x_rank - 1 - i]; + + int64_t y_dim = 1; + if (i < y_rank) + y_dim = y_shape[y_rank - 1 - i]; + + int64_t output_dim = std::max(std::max(cond_dim, x_dim), y_dim); + // special case to handle a dim of 0 which can be broadcast with a 1 + if (output_dim == 1) + output_dim = std::min(std::min(cond_dim, x_dim), y_dim); + + const auto node_name = "Where"; + if (cond_dim != output_dim && cond_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": condition operand cannot broadcast on dim ", cond_rank - 1 - i, + " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); + if (x_dim != output_dim && x_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": X operand cannot broadcast on dim ", x_rank - 1 - i, + " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); + if (y_dim != output_dim && y_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": Y operand cannot broadcast on dim ", y_rank - 1 - i, + " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); + output_dims[output_rank - 1 - i] = output_dim; + } + + output_shape = TensorShape(output_dims); + return Status::OK(); +} + +Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& c_input = shader.AddInput("c_data", ShaderUsage::UseUniform); + const auto& a_input = shader.AddInput("a_data", ShaderUsage::UseUniform); + const auto& b_input = shader.AddInput("b_data", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output_data", ShaderUsage::UseUniform); + + auto expression = [](const std::string& a, const std::string& b, const std::string& c) -> const auto { + return "select(" + b + ", " + a + ", " + c + ")"; + }; + std::string assignment; + if (!is_broadcast_) { + assignment = output.SetByOffset( + "global_idx", + expression(a_input.GetByOffset("global_idx"), b_input.GetByOffset("global_idx"), c_input.GetByOffset("global_idx"))); + + } else { + const auto& c_indices = shader.AddIndices("c_indices"); + const auto& a_indices = shader.AddIndices("a_indices"); + const auto& b_indices = shader.AddIndices("b_indices"); + const auto& output_indices = shader.AddIndices("output_indices"); + + auto single_assignment = + [&expression, &output_indices, &a_indices, &b_indices, &c_indices]( + const std::string& rest_str, const std::string& x, const std::string& type_cast = "") + -> const auto { + const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]"; + const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]"; + const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))"; + + std::ostringstream ss; + ss.imbue(std::locale::classic()); + ss << "let output_indices" + x + " = " << output_indices.OffsetToIndices("global_idx * 4u + " + x + "u") << ";\n"; + ss << "let offset_a" + x + " = " + a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n"; + ss << "let offset_b" + x + " = " + b_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n"; + ss << "let offset_c" + x + " = " + c_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n"; + ss << "let index_a" + x + " = offset_a" + x + " / 4u;\n"; + ss << "let index_b" + x + " = offset_b" + x + " / 4u;\n"; + ss << "let index_c" + x + " = offset_c" + x + " / 4u;\n"; + ss << "let component_a" + x + " = offset_a" + x + " % 4u;\n"; + ss << "let component_b" + x + " = offset_b" + x + " % 4u;\n"; + ss << "let component_c" + x + " = offset_c" + x + " % 4u;\n"; + ss << rest_str + "[" + x + "] = " + type_cast + "(" + expression(a_expression, b_expression, c_expression) + ");\n"; + return ss.str(); + }; + + if (Outputs()[0].tensor->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_BOOL) { + assignment = + "var data = vec4(0); \n" + + single_assignment("data", "0", "u32") + + single_assignment("data", "1", "u32") + + single_assignment("data", "2", "u32") + + single_assignment("data", "3", "u32") + + "output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));\n"; + } else { + assignment = + single_assignment("output_data[global_idx]", "0") + + single_assignment("output_data[global_idx]", "1") + + single_assignment("output_data[global_idx]", "2") + + single_assignment("output_data[global_idx]", "3"); + } + } + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + assignment); + return Status::OK(); +} + +Status Where::ComputeInternal(ComputeContext& context) const { + const auto* cond_tensor = context.Input(0); + const auto* x_tensor = context.Input(1); + const auto* y_tensor = context.Input(2); + const auto& cond_shape = cond_tensor->Shape(); + const auto& x_shape = x_tensor->Shape(); + const auto& y_shape = y_tensor->Shape(); + + TensorShape output_shape; + ORT_RETURN_IF_ERROR(ComputeOutputShape(cond_shape, x_shape, y_shape, output_shape)); + auto* output_tensor = context.Output(0, output_shape); + const auto component = 4; + uint32_t vec_size = gsl::narrow_cast((output_shape.Size() + 3) / component); + const auto is_broadcast = !(x_shape == y_shape && + y_shape == cond_shape); + WhereProgram program{is_broadcast}; + program + .CacheHint(is_broadcast) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddInputs({{cond_tensor, ProgramTensorMetadataDependency::None, {(cond_shape.Size() + 3) / 4}, 4}, + {x_tensor, ProgramTensorMetadataDependency::None, {(x_shape.Size() + 3) / 4}, 4}, + {y_tensor, ProgramTensorMetadataDependency::None, {(y_shape.Size() + 3) / 4}, 4}}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddUniformVariables({ + {static_cast(vec_size)}, + }); + if (is_broadcast) { + program + .AddIndices(cond_shape) + .AddIndices(x_shape) + .AddIndices(y_shape) + .AddIndices(output_tensor->Shape()); + } + return context.RunProgram(program); +} + +namespace { +const std::vector& WhereOpTypeConstraints() { + // currently support boolean, integer and float types that explicitly allowed in WGSL: + // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section + // + static std::vector types{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + return types; +} +} // namespace + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Where, + kOnnxDomain, + 9, 15, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WhereOpTypeConstraints()), + Where); + +ONNX_OPERATOR_KERNEL_EX( + Where, + kOnnxDomain, + 16, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WhereOpTypeConstraints()), + Where); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/where.h b/onnxruntime/core/providers/webgpu/tensor/where.h new file mode 100644 index 000000000000..e46b24e9ba2e --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/where.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/transpose.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class WhereProgram final : public Program { + public: + WhereProgram(bool is_broadcast) : Program{"Where"}, is_broadcast_{is_broadcast} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + const bool is_broadcast_; +}; + +class Where final : public WebGpuKernel { + public: + Where(const OpKernelInfo& info) : WebGpuKernel{info} { + } + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 8539992e85cf..f5d66d6a2413 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -564,8 +564,8 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), - // KERNEL_CREATE_INFO(16, Where), + KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), + KERNEL_CREATE_INFO(16, Where), BuildKernelCreateInfo, BuildKernelCreateInfo, From cb9f3a4381eb21057890fbd44c57915a605df149 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 20 Sep 2024 14:39:59 -0700 Subject: [PATCH 088/114] fix linux build break --- .../core/providers/webgpu/math/binary_elementwise_ops.cc | 2 +- onnxruntime/core/providers/webgpu/tensor/transpose.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 8feb5daae3f5..bae7c6a73c4c 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -142,7 +142,7 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { } } - SafeInt vec_size = (size + 3) / 4; + uint32_t vec_size = SafeInt((size + 3) / 4); BinaryElementwiseProgram program{kernel_name_, expression_, is_broadcast, diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index b620e83843b2..0962d9191d78 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -52,7 +52,7 @@ const std::string AppendPermFunction(gsl::span perm) { ss.imbue(std::locale::classic()); ss << "fn perm(i: y_indices_t)->x_indices_t {\n" " var a: x_indices_t;\n"; - for (auto i = 0; i < perm.size(); ++i) { + for (size_t i = 0; i < perm.size(); ++i) { ss << " a[" << perm[i] << "] = i[" << i << "];\n"; } ss << " return a;\n" From 929725eec8adfd7d1bd05fa6dfda6d30f5462de5 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 24 Sep 2024 22:24:33 -0700 Subject: [PATCH 089/114] expose KernelContext --- onnxruntime/core/providers/webgpu/compute_context.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index c98480523ae6..b7ea8a58e232 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -43,7 +43,14 @@ class ComputeContext { } // - // Get the logger + // Get the kernel context. + // + inline OpKernelContext& KernelContext() { + return kernel_context_; + } + + // + // Get the logger. // inline const logging::Logger& Logger() const { return kernel_context_.Logger(); From c5e5af35f35c5b08cf14697abdeb23c97fa9f349 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 24 Sep 2024 23:45:21 -0700 Subject: [PATCH 090/114] revise fast gelu --- onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index 50debe26ce45..52459b0632d5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -3,6 +3,7 @@ #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/math/unary_elementwise_ops.h" #include "contrib_ops/webgpu/bert/fast_gelu.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" @@ -20,26 +21,25 @@ ONNX_OPERATOR_KERNEL_EX( FastGelu); Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& y = shader.AddOutput("y", ShaderUsage::UseUniform); std::string add_bias = ""; if (Inputs().size() > 1) { const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride); add_bias = bias_components_ == 1 ? " let bias_offset = global_idx * 4;\n" - " x += input_value_t(" + + " a += x_value_t(" + bias.GetByOffset("bias_offset % uniforms.bias_shape") + ", " + bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") + ", " + bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") + ", " + bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") + ");\n" - : " x += " + bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; + : " a += " + bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; } - + shader.AppendImplementation(TanhImpl); shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - " var x = ", input.GetByOffset("global_idx"), ";\n", + " var a = ", x.GetByOffset("global_idx"), ";\n", add_bias, - " let y = x * (0.5 + 0.5 * tanh(x * (0.035677408136300125 * x * x + 0.7978845608028654)));\n ", - output.SetByOffset("global_idx", "y")); + y.SetByOffset("global_idx", onnxruntime::webgpu::FastGeluExpr)); return Status::OK(); } From 82cd59e9aa6fe5ea787f9ec8d3c57cf4023c8d9b Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 25 Sep 2024 00:38:48 -0700 Subject: [PATCH 091/114] expose Rank in IndicesHelper --- onnxruntime/core/providers/webgpu/shader_variable.cc | 6 ++++++ onnxruntime/core/providers/webgpu/shader_variable.h | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index f2a5b049b477..7dcf2f352905 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -88,6 +88,12 @@ ShaderIndicesHelper::ShaderIndicesHelper(std::string_view name, ProgramVariableD element_type_alias_{name_ + "_element_t"}, indices_type_alias_{name_ + "_indices_t"} {} +inline int ShaderIndicesHelper::Rank() { + // getting the rank means the information is exposed to the shader. So we consider it as a usage of shape and stride. + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_; +} + ShaderVariableHelper::ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims) : ShaderIndicesHelper{name, type, usage, dims} { ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 326c6814410d..4b2d30478232 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -67,8 +67,12 @@ class ShaderIndicesHelper { public: ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims); + // get the number of components of the variable. inline int NumComponents() const { return num_components_; } + // get the rank of the indices. + inline int Rank(); + // create a WGSL expression ({varname}_indices_t) for getting indices from offset. // \param offset: a WGSL expression (u32) representing the offset. inline std::string OffsetToIndices(std::string_view offset_expr) const; From 2393dbf118f3bab7045cb1923e130279c049eb8c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 25 Sep 2024 11:43:16 -0700 Subject: [PATCH 092/114] fix: move inline impl to .h --- onnxruntime/core/providers/webgpu/shader_variable.cc | 6 ------ onnxruntime/core/providers/webgpu/shader_variable.h | 6 ++++++ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 7dcf2f352905..f2a5b049b477 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -88,12 +88,6 @@ ShaderIndicesHelper::ShaderIndicesHelper(std::string_view name, ProgramVariableD element_type_alias_{name_ + "_element_t"}, indices_type_alias_{name_ + "_indices_t"} {} -inline int ShaderIndicesHelper::Rank() { - // getting the rank means the information is exposed to the shader. So we consider it as a usage of shape and stride. - usage_ |= ShaderUsage::UseShapeAndStride; - return rank_; -} - ShaderVariableHelper::ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims) : ShaderIndicesHelper{name, type, usage, dims} { ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 4b2d30478232..dc1c1742fe0e 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -213,6 +213,12 @@ std::string pass_as_string(T&& v) { } } // namespace detail +inline int ShaderIndicesHelper::Rank() { + // getting the rank means the information is exposed to the shader. So we consider it as a usage of shape and stride. + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_; +} + inline std::string ShaderIndicesHelper::OffsetToIndices(std::string_view offset_expr) const { usage_ |= ShaderUsage::UseOffsetToIndices | ShaderUsage::UseShapeAndStride; return rank_ < 2 ? std::string{offset_expr} From 9bdbd85ff3ae50aa445701b245336493b8984a7d Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 25 Sep 2024 11:56:05 -0700 Subject: [PATCH 093/114] add const modifier --- onnxruntime/core/providers/webgpu/shader_variable.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index dc1c1742fe0e..2ddc9a6e8160 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -71,7 +71,7 @@ class ShaderIndicesHelper { inline int NumComponents() const { return num_components_; } // get the rank of the indices. - inline int Rank(); + inline int Rank() const; // create a WGSL expression ({varname}_indices_t) for getting indices from offset. // \param offset: a WGSL expression (u32) representing the offset. @@ -213,7 +213,7 @@ std::string pass_as_string(T&& v) { } } // namespace detail -inline int ShaderIndicesHelper::Rank() { +inline int ShaderIndicesHelper::Rank() const { // getting the rank means the information is exposed to the shader. So we consider it as a usage of shape and stride. usage_ |= ShaderUsage::UseShapeAndStride; return rank_; From 0101ce8fa32856cfc62a2402f7a8f5c5938782d7 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:38:02 -0700 Subject: [PATCH 094/114] remove toggle "disable_workgroup_init" --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index f2414c14f6f9..7dbccb532dd5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -387,7 +387,6 @@ std::vector WebGpuContext::GetEnabledDeviceToggles() const { constexpr const char* toggles[] = { "skip_validation", // only use "skip_validation" when ValidationMode is set to "Disabled" "disable_robustness", - "disable_workgroup_init", "d3d_disable_ieee_strictness", }; return std::vector(ValidationMode() >= ValidationMode::WGPUOnly From 38967062675ef6906824f5317732b5630054220c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:44:32 -0700 Subject: [PATCH 095/114] set backend type to D3D12 since we always uses dxc (win). --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 7dbccb532dd5..e9ae97369c6c 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -33,6 +33,9 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info wgpu::RequestAdapterOptions req_adapter_options = {}; wgpu::DawnTogglesDescriptor adapter_toggles_desc = {}; req_adapter_options.nextInChain = &adapter_toggles_desc; +#ifdef WIN32 + req_adapter_options.backendType = wgpu::BackendType::D3D12; +#endif auto enabled_adapter_toggles = GetEnabledAdapterToggles(); adapter_toggles_desc.enabledToggleCount = enabled_adapter_toggles.size(); From f02e85a3adefb62649e97c6fadd5851e6fa1ab2c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:53:41 -0700 Subject: [PATCH 096/114] update build configurations to webgpu EP (#22047) ### Description --------- Co-authored-by: Scott McKay --- .../external/onnxruntime_external_deps.cmake | 82 ++++++++++++++----- cmake/onnxruntime.cmake | 66 +++++++++++++-- cmake/onnxruntime_providers_webgpu.cmake | 12 +-- cmake/patches/dawn/dawn.patch | 66 +++++++++++++++ .../webgpu/webgpu_provider_factory.h | 14 ++++ .../main/java/ai/onnxruntime/OrtProvider.java | 4 +- .../platform/apple/logging/apple_log_sink.mm | 2 - .../webgpu/math/unary_elementwise_ops.cc | 12 +++ .../core/providers/webgpu/shader_variable.h | 3 + .../core/providers/webgpu/tensor/where.cc | 6 +- .../core/providers/webgpu/webgpu_context.cc | 5 ++ .../ios_package_uitest_cpp_api.mm | 23 +++++- .../macos_package_uitest_cpp_api.mm | 24 +++++- .../default_full_aar_build_settings.json | 1 + .../apple/build_and_assemble_apple_pods.py | 2 + ...t_full_apple_framework_build_settings.json | 1 + ...ult_full_ios_framework_build_settings.json | 2 + .../templates/mac-cpu-packing-jobs.yml | 6 +- 18 files changed, 277 insertions(+), 54 deletions(-) create mode 100644 cmake/patches/dawn/dawn.patch create mode 100644 include/onnxruntime/core/providers/webgpu/webgpu_provider_factory.h diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 370a2d5c7235..6f54ce1b4fac 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -575,10 +575,11 @@ if (onnxruntime_USE_MIMALLOC) onnxruntime_fetchcontent_makeavailable(mimalloc) endif() -#onnxruntime_EXTERNAL_LIBRARIES could contain onnx, onnx_proto,libprotobuf, cuda/cudnn, -# dnnl/mklml, onnxruntime_codegen_tvm, tvm and pthread -# pthread is always at the last -set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES_XNNPACK} ${WIL_TARGET} nlohmann_json::nlohmann_json onnx onnx_proto ${PROTOBUF_LIB} re2::re2 Boost::mp11 safeint_interface flatbuffers::flatbuffers ${GSL_TARGET} ${ABSEIL_LIBS} date::date ${ONNXRUNTIME_CLOG_TARGET_NAME}) +set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES_XNNPACK} ${WIL_TARGET} nlohmann_json::nlohmann_json + onnx onnx_proto ${PROTOBUF_LIB} re2::re2 Boost::mp11 safeint_interface + flatbuffers::flatbuffers ${GSL_TARGET} ${ABSEIL_LIBS} date::date + ${ONNXRUNTIME_CLOG_TARGET_NAME}) + # The source code of onnx_proto is generated, we must build this lib first before starting to compile the other source code that uses ONNX protobuf types. # The other libs do not have the problem. All the sources are already there. We can compile them in any order. set(onnxruntime_EXTERNAL_DEPENDENCIES onnx_proto flatbuffers::flatbuffers) @@ -638,33 +639,70 @@ if (onnxruntime_USE_WEBGPU) dawn URL ${DEP_URL_dawn} URL_HASH SHA1=${DEP_SHA1_dawn} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch ) - set(DAWN_FETCH_DEPENDENCIES ON) - set(DAWN_ENABLE_INSTALL ON) - set(TINT_BUILD_TESTS OFF) - set(DAWN_USE_BUILT_DXC ON) + + # use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size + set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE) + set(DAWN_BUILD_SAMPLES OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_NULL OFF CACHE BOOL "" FORCE) + set(DAWN_FETCH_DEPENDENCIES ON CACHE BOOL "" FORCE) + + # disable things we don't use set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF) - onnxruntime_fetchcontent_makeavailable(dawn) -endif() + set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE) + set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE) + set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE) + set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE) + set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE) + + set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving + set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled. + + # SPIR-V validation shouldn't be required given we're using Tint to create the SPIR-V. + if (NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + set(DAWN_ENABLE_SPIRV_VALIDATION OFF CACHE BOOL "" FORCE) + endif() -message(STATUS "Finished fetching external dependencies") + if (WIN32) + # building this requires the HLSL writer to be enabled in Tint. TBD if that we need either of these to be ON. + set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE) + set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE) -set(onnxruntime_LINK_DIRS ) + # Vulkan may optionally be included in a Windows build. Exclude until we have an explicit use case that requires it. + set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE) + endif() + onnxruntime_fetchcontent_makeavailable(dawn) + + list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native dawn::dawn_proc) +endif() + +set(onnxruntime_LINK_DIRS) if (onnxruntime_USE_CUDA) - find_package(CUDAToolkit REQUIRED) + find_package(CUDAToolkit REQUIRED) - if(onnxruntime_CUDNN_HOME) - file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) - set(CUDNN_PATH ${onnxruntime_CUDNN_HOME}) - endif() - include(cuDNN) + if(onnxruntime_CUDNN_HOME) + file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) + set(CUDNN_PATH ${onnxruntime_CUDNN_HOME}) + endif() + + include(cuDNN) endif() if(onnxruntime_USE_SNPE) - include(external/find_snpe.cmake) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${SNPE_NN_LIBS}) + include(external/find_snpe.cmake) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${SNPE_NN_LIBS}) endif() -FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} ORT_BINARY_DIR) -FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} ORT_SOURCE_DIR) +FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} ORT_BINARY_DIR) +FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} ORT_SOURCE_DIR) + +message(STATUS "Finished fetching external dependencies") diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 9b6acea876f9..b1d797ca16ad 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -89,10 +89,22 @@ elseif(onnxruntime_BUILD_APPLE_FRAMEWORK) # create Info.plist for the framework and podspec for CocoaPods (optional) set(MACOSX_FRAMEWORK_NAME "onnxruntime") set(MACOSX_FRAMEWORK_IDENTIFIER "com.microsoft.onnxruntime") - # Need to include CoreML as a weaklink for CocoaPods package if the EP is enabled + + # Setup weak frameworks for macOS/iOS. 'weak' as the CoreML or WebGPU EPs are optionally enabled. if(onnxruntime_USE_COREML) - set(APPLE_WEAK_FRAMEWORK "\\\"CoreML\\\"") + list(APPEND _weak_frameworks "\\\"CoreML\\\"") + endif() + + if(onnxruntime_USE_WEBGPU) + list(APPEND _weak_frameworks "\\\"QuartzCore\\\"") + list(APPEND _weak_frameworks "\\\"IOSurface\\\"") + list(APPEND _weak_frameworks "\\\"Metal\\\"") endif() + + if (_weak_frameworks) + string(JOIN ", " APPLE_WEAK_FRAMEWORK ${_weak_frameworks}) + endif() + set(INFO_PLIST_PATH "${CMAKE_CURRENT_BINARY_DIR}/Info.plist") configure_file(${REPO_ROOT}/cmake/Info.plist.in ${INFO_PLIST_PATH}) configure_file( @@ -364,16 +376,58 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) endif() endforeach() + # helper function that recurses to also handle static library dependencies of the ORT external libraries + set(_processed_libs) # keep track of processed libraries to skip any duplicate dependencies + function(add_symlink_for_static_lib_and_dependencies lib) + function(process cur_target) + # de-alias if applicable so a consistent target name is used + get_target_property(alias ${cur_target} ALIASED_TARGET) + if(TARGET ${alias}) + set(cur_target ${alias}) + endif() + + if(${cur_target} IN_LIST _processed_libs OR ${cur_target} IN_LIST lib_and_dependencies) + return() + endif() + + list(APPEND lib_and_dependencies ${cur_target}) + + get_target_property(link_libraries ${cur_target} LINK_LIBRARIES) + foreach(dependency ${link_libraries}) + if(TARGET ${dependency}) + process(${dependency}) + endif() + endforeach() + + set(lib_and_dependencies ${lib_and_dependencies} PARENT_SCOPE) + endfunction() + + set(lib_and_dependencies) + process(${lib}) + + foreach(_target ${lib_and_dependencies}) + get_target_property(type ${_target} TYPE) + if(${type} STREQUAL "STATIC_LIBRARY") + # message(STATUS "Adding symlink for ${_target}") + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E create_symlink + $ ${STATIC_LIB_DIR}/$) + endif() + endforeach() + + list(APPEND _processed_libs ${lib_and_dependencies}) + set(_processed_libs ${_processed_libs} PARENT_SCOPE) + endfunction() + # for external libraries we create a symlink to the .a file foreach(_LIB ${onnxruntime_EXTERNAL_LIBRARIES}) - if(NOT TARGET ${_LIB}) # if we didn't build from source. it may not a target + if(NOT TARGET ${_LIB}) # if we didn't build from source it may not be a target continue() endif() + GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE) if(_LIB_TYPE STREQUAL "STATIC_LIBRARY") - add_custom_command(TARGET onnxruntime POST_BUILD - COMMAND ${CMAKE_COMMAND} -E create_symlink - $ ${STATIC_LIB_DIR}/$) + add_symlink_for_static_lib_and_dependencies(${_LIB}) endif() endforeach() diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 587c4b2c1ff2..8d00ab5aa449 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -24,14 +24,8 @@ source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webgpu_cc_srcs}) onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn) - - # Copy webgpu_dawn.dll to the output directory - add_custom_command( - TARGET onnxruntime_providers_webgpu - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different "$" "$" - VERBATIM ) + onnxruntime_add_include_to_target(onnxruntime_providers_webgpu + onnxruntime_common dawn::dawncpp_headers dawn::dawn_headers onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) + target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native dawn::dawn_proc) set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/patches/dawn/dawn.patch b/cmake/patches/dawn/dawn.patch new file mode 100644 index 000000000000..d696d386452e --- /dev/null +++ b/cmake/patches/dawn/dawn.patch @@ -0,0 +1,66 @@ +diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt +index 9c0bd6fa4e..bf8a57aeac 100644 +--- a/src/dawn/native/CMakeLists.txt ++++ b/src/dawn/native/CMakeLists.txt +@@ -857,6 +857,11 @@ if (DAWN_ENABLE_SWIFTSHADER) + target_compile_definitions(dawn_native PRIVATE "DAWN_ENABLE_SWIFTSHADER") + endif() + ++if (IOS) ++ target_compile_options(dawn_native_objects PRIVATE -fno-objc-arc) ++ target_compile_options(dawn_native PRIVATE -fno-objc-arc) ++endif() ++ + if (DAWN_BUILD_MONOLITHIC_LIBRARY) + ############################################################################### + # Do the 'complete_lib' build. +diff --git a/src/dawn/native/Surface_metal.mm b/src/dawn/native/Surface_metal.mm +index ce55acbd43..baa4835362 100644 +--- a/src/dawn/native/Surface_metal.mm ++++ b/src/dawn/native/Surface_metal.mm +@@ -36,7 +36,13 @@ + namespace dawn::native { + + bool InheritsFromCAMetalLayer(void* obj) { +- id object = static_cast(obj); ++ id object = ++#if TARGET_OS_IOS ++ (__bridge id)obj; ++#else ++ static_cast(obj); ++#endif ++ + return [object isKindOfClass:[CAMetalLayer class]]; + } + +diff --git a/src/dawn/native/metal/SharedFenceMTL.mm b/src/dawn/native/metal/SharedFenceMTL.mm +index bde8bfea07..f2f6459e91 100644 +--- a/src/dawn/native/metal/SharedFenceMTL.mm ++++ b/src/dawn/native/metal/SharedFenceMTL.mm +@@ -40,7 +40,13 @@ ResultOrError> SharedFence::Create( + DAWN_INVALID_IF(descriptor->sharedEvent == nullptr, "MTLSharedEvent is missing."); + if (@available(macOS 10.14, iOS 12.0, *)) { + return AcquireRef(new SharedFence( +- device, label, static_cast>(descriptor->sharedEvent))); ++ device, label, ++#if TARGET_OS_IOS ++ (__bridge id)(descriptor->sharedEvent) ++#else ++ static_cast>(descriptor->sharedEvent) ++#endif ++ )); + } else { + return DAWN_INTERNAL_ERROR("MTLSharedEvent not supported."); + } +diff --git a/src/tint/api/BUILD.cmake b/src/tint/api/BUILD.cmake +index 0037d83276..6372c4ee77 100644 +--- a/src/tint/api/BUILD.cmake ++++ b/src/tint/api/BUILD.cmake +@@ -57,6 +57,7 @@ tint_target_add_dependencies(tint_api lib + tint_lang_wgsl_ast_transform + tint_lang_wgsl_common + tint_lang_wgsl_features ++ tint_lang_wgsl_inspector + tint_lang_wgsl_program + tint_lang_wgsl_sem + tint_lang_wgsl_writer_ir_to_program diff --git a/include/onnxruntime/core/providers/webgpu/webgpu_provider_factory.h b/include/onnxruntime/core/providers/webgpu/webgpu_provider_factory.h new file mode 100644 index 000000000000..0b45b847d651 --- /dev/null +++ b/include/onnxruntime/core/providers/webgpu/webgpu_provider_factory.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Dummy file to provide a signal in the ONNX Runtime C cocoapod as to whether the WebGPU EP was included in the build. +// If it was, this file will be included in the cocoapod, and a test like this can be used: +// +// #if __has_include() +// #define WEBGPU_EP_AVAILABLE 1 +// #else +// #define WEBGPU_EP_AVAILABLE 0 +// #endif + +// The WebGPU EP can be enabled via the generic SessionOptionsAppendExecutionProvider method, so no direct usage of +// the provider factory is required. diff --git a/java/src/main/java/ai/onnxruntime/OrtProvider.java b/java/src/main/java/ai/onnxruntime/OrtProvider.java index ae9cb9f90862..b06f884896ee 100644 --- a/java/src/main/java/ai/onnxruntime/OrtProvider.java +++ b/java/src/main/java/ai/onnxruntime/OrtProvider.java @@ -40,7 +40,9 @@ public enum OrtProvider { /** The XNNPACK execution provider. */ XNNPACK("XnnpackExecutionProvider"), /** The Azure remote endpoint execution provider. */ - AZURE("AzureExecutionProvider"); + AZURE("AzureExecutionProvider"), + /** The WebGPU execution provider */ + WEBGPU("WebGpuExecutionProvider"); private static final Map valueMap = new HashMap<>(values().length); diff --git a/onnxruntime/core/platform/apple/logging/apple_log_sink.mm b/onnxruntime/core/platform/apple/logging/apple_log_sink.mm index 00e691a8f9fd..6abbe76a7f15 100644 --- a/onnxruntime/core/platform/apple/logging/apple_log_sink.mm +++ b/onnxruntime/core/platform/apple/logging/apple_log_sink.mm @@ -7,8 +7,6 @@ #include -#include "date/date.h" - namespace onnxruntime { namespace logging { diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 3b43c87fb0c8..9e8117aa34a9 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -165,7 +165,19 @@ WEBGPU_ELEMENTWISE_KERNEL(Asinh, 9, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_IMPL(Acosh, "acosh(a)") WEBGPU_ELEMENTWISE_KERNEL(Acosh, 9, WebGpuSupportedFloatTypes()) +#if __APPLE__ +// Metal returns 0 for values >= 1.0. +// Need custom impl to return +inf for 1.0 (by dividing 1 by 0), and NaN for > 1.0 (by dividing 0 by 0) +WEBGPU_ELEMENTWISE_IMPL(Atanh, + "select(" + " select(x_value_t(1.0), x_value_t(0.0), a > x_value_t(1.0)) / x_value_t(0.0)," + " atanh(a)," + " a < x_value_t(1.0))", + "", + ShaderUsage::UseValueTypeAlias) +#else WEBGPU_ELEMENTWISE_IMPL(Atanh, "atanh(a)") +#endif WEBGPU_ELEMENTWISE_KERNEL(Atanh, 9, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_IMPL(Not, "!a") diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 2ddc9a6e8160..72f38aecb99c 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -67,6 +67,9 @@ class ShaderIndicesHelper { public: ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims); + ShaderIndicesHelper(ShaderIndicesHelper&&) = default; + ShaderIndicesHelper& operator=(ShaderIndicesHelper&&) = default; + // get the number of components of the variable. inline int NumComponents() const { return num_components_; } diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc index 31806a0af174..1d58538a7489 100644 --- a/onnxruntime/core/providers/webgpu/tensor/where.cc +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -59,7 +59,7 @@ Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& b_input = shader.AddInput("b_data", ShaderUsage::UseUniform); const auto& output = shader.AddOutput("output_data", ShaderUsage::UseUniform); - auto expression = [](const std::string& a, const std::string& b, const std::string& c) -> const auto { + const auto expression = [](const std::string& a, const std::string& b, const std::string& c) -> auto { return "select(" + b + ", " + a + ", " + c + ")"; }; std::string assignment; @@ -74,10 +74,10 @@ Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& b_indices = shader.AddIndices("b_indices"); const auto& output_indices = shader.AddIndices("output_indices"); - auto single_assignment = + const auto single_assignment = [&expression, &output_indices, &a_indices, &b_indices, &c_indices]( const std::string& rest_str, const std::string& x, const std::string& type_cast = "") - -> const auto { + -> auto { const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]"; const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]"; const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))"; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index e9ae97369c6c..bb4ae4f6dcce 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -4,6 +4,9 @@ #include #include +#include "dawn/dawn_proc.h" +#include "dawn/native/DawnNative.h" + #include "core/common/common.h" #include "core/providers/webgpu/compute_context.h" @@ -21,6 +24,8 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info std::call_once(init_flag_, [this, &webgpu_ep_info]() { // Initialization.Step.1 - Create wgpu::Instance if (instance_ == nullptr) { + dawnProcSetProcs(&dawn::native::GetProcs()); + wgpu::InstanceDescriptor instance_desc{}; instance_desc.features.timedWaitAnyEnable = true; instance_ = wgpu::CreateInstance(&instance_desc); diff --git a/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm b/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm index d145a00b1348..32b4b32e299d 100644 --- a/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm +++ b/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm @@ -13,15 +13,19 @@ #if __has_include() #define COREML_EP_AVAILABLE 1 +#include #else #define COREML_EP_AVAILABLE 0 #endif -#if COREML_EP_AVAILABLE -#include +#if __has_include() +#define WEBGPU_EP_AVAILABLE 1 +// WebGPU EP doesn't require including the header as it's enabled via AppendExecutionProvider +#else +#define WEBGPU_EP_AVAILABLE 0 #endif -void testSigmoid(const char* modelPath, bool useCoreML) { +void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU = false) { // This is an e2e test for ORT C++ API Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "testCppAPI"); @@ -38,6 +42,12 @@ void testSigmoid(const char* modelPath, bool useCoreML) { (void)useCoreML; #endif + if (useWebGPU) { + std::unordered_map provider_options; + // set provider options if needed. e.g. deviceId + session_options.AppendExecutionProvider("WebGPU", provider_options); + } + Ort::Session session(env, modelPath, session_options); size_t input_tensor_size = 3 * 4 * 5; @@ -96,7 +106,7 @@ - (NSString*)getFilePath { } - (void)testCppAPI_Basic { - testSigmoid([self getFilePath].UTF8String, false /* useCoreML */); + testSigmoid([self getFilePath].UTF8String); } #if COREML_EP_AVAILABLE @@ -105,4 +115,9 @@ - (void)testCppAPI_Basic_CoreML { } #endif +#if WEBGPU_EP_AVAILABLE +- (void)testCppAPI_Basic_WebGPU { + testSigmoid([self getFilePath].UTF8String, false /* useCoreML */, true /* useWebGPU */); +} +#endif @end diff --git a/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm b/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm index 613c6e545939..86001b6cb50a 100644 --- a/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm +++ b/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm @@ -13,15 +13,19 @@ #if __has_include() #define COREML_EP_AVAILABLE 1 +#include #else #define COREML_EP_AVAILABLE 0 #endif -#if COREML_EP_AVAILABLE -#include +#if __has_include() +#define WEBGPU_EP_AVAILABLE 1 +// WebGPU EP doesn't require including the header as it's enabled via AppendExecutionProvider +#else +#define WEBGPU_EP_AVAILABLE 0 #endif -void testSigmoid(const char* modelPath, bool useCoreML) { +void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU = false) { // This is an e2e test for ORT C++ API Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "testCppAPI"); @@ -38,6 +42,12 @@ void testSigmoid(const char* modelPath, bool useCoreML) { (void)useCoreML; #endif + if (useWebGPU) { + std::unordered_map provider_options; + // set provider options if needed. e.g. deviceId + session_options.AppendExecutionProvider("WebGPU", provider_options); + } + Ort::Session session(env, modelPath, session_options); size_t input_tensor_size = 3 * 4 * 5; @@ -96,7 +106,7 @@ - (NSString*)getFilePath { } - (void)testCppAPI_Basic { - testSigmoid([self getFilePath].UTF8String, false /* useCoreML */); + testSigmoid([self getFilePath].UTF8String); } #if COREML_EP_AVAILABLE @@ -105,4 +115,10 @@ - (void)testCppAPI_Basic_CoreML { } #endif +#if WEBGPU_EP_AVAILABLE +- (void)testCppAPI_Basic_WebGPU { + testSigmoid([self getFilePath].UTF8String, false /* useCoreML */, true /* useWebGPU */); +} +#endif + @end diff --git a/tools/ci_build/github/android/default_full_aar_build_settings.json b/tools/ci_build/github/android/default_full_aar_build_settings.json index b0eff7581267..f08f246748a5 100644 --- a/tools/ci_build/github/android/default_full_aar_build_settings.json +++ b/tools/ci_build/github/android/default_full_aar_build_settings.json @@ -16,6 +16,7 @@ "--build_shared_lib", "--use_nnapi", "--use_xnnpack", + "--use_webgpu", "--skip_tests" ] } diff --git a/tools/ci_build/github/apple/build_and_assemble_apple_pods.py b/tools/ci_build/github/apple/build_and_assemble_apple_pods.py index 71aeb9e7b030..dd037c17ae3b 100755 --- a/tools/ci_build/github/apple/build_and_assemble_apple_pods.py +++ b/tools/ci_build/github/apple/build_and_assemble_apple_pods.py @@ -133,6 +133,8 @@ def main(): str(build_dir / "framework_out"), "--variant", package_variant.name, + "--test_project_stage_dir", # use a specific directory so it's easier to debug + str(build_dir / "test_apple_packages_staging"), ] run(test_apple_packages_args) diff --git a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json index 84d7e355ed5b..6175ac3a0ad5 100644 --- a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json +++ b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json @@ -19,6 +19,7 @@ "--build_apple_framework", "--use_coreml", "--use_xnnpack", + "--use_webgpu", "--skip_tests", "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF" ], diff --git a/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json b/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json index e2d8f70c02cf..4c2c9442ab21 100644 --- a/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json +++ b/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json @@ -24,12 +24,14 @@ "--ios", "--use_xcode", "--use_xnnpack", + "--use_webgpu", "--apple_deploy_target=13.0" ], "iphonesimulator": [ "--ios", "--use_xcode", "--use_xnnpack", + "--use_webgpu", "--apple_deploy_target=13.0" ], "macabi":[ diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index 3b661d9eb2dc..c2b140652a2d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -96,7 +96,7 @@ jobs: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --use_coreml --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" BuildJava: false BuildNodejs: false WithCache: ${{ parameters.WithCache }} @@ -108,7 +108,7 @@ jobs: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 BuildJava: true BuildNodejs: true WithCache: ${{ parameters.WithCache }} @@ -120,7 +120,7 @@ jobs: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu BuildJava: true BuildNodejs: true WithCache: ${{ parameters.WithCache }} From e5233ce865bd70d64830698b93e587888ec459c1 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:01:49 -0700 Subject: [PATCH 097/114] enable build pipeline on Windows for WebGPU --- .../win-gpu-webgpu-ci-pipeline.yml | 58 +++++++++++++++++++ tools/ci_build/set-trigger-rules.py | 1 + 2 files changed, 59 insertions(+) create mode 100644 tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml new file mode 100644 index 000000000000..c4db7735aaf2 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml @@ -0,0 +1,58 @@ +##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +### please do rerun set-trigger-rules.py ### +trigger: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +pr: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +#### end trigger #### + +parameters: +- name: RunOnnxRuntimeTests + displayName: Run Tests? + type: boolean + default: true + +stages: +- stage: webgpu + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env_cuda.bat + buildArch: x64 + # add --enable_pybind and --build_java if necessary + additionalBuildFlags: >- + --build_nodejs + --use_webgpu + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + ORT_EP_NAME: WebGPU + EnablePython: false + WITH_CACHE: true + MachinePool: onnxruntime-Win2022-VS2022-webgpu-A10 diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py index fb6aa44cdf31..0e9cd514d8aa 100644 --- a/tools/ci_build/set-trigger-rules.py +++ b/tools/ci_build/set-trigger-rules.py @@ -40,6 +40,7 @@ "win-gpu-training-ci-pipeline.yml", "win-gpu-doc-gen-ci-pipeline.yml", "win-gpu-tensorrt-ci-pipeline.yml", + "win-gpu-webgpu-ci-pipeline.yml", "win-qnn-arm64-ci-pipeline.yml", "win-qnn-ci-pipeline.yml", ] From 0f7a5f6077f0885aa32b0ede324023419badb3c2 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 27 Sep 2024 13:49:09 +0800 Subject: [PATCH 098/114] [webgpu native] Add RotaryEmbedding op (#22194) ### Description ### Motivation and Context --- .../webgpu/bert/rotary_embedding.cc | 134 ++++++++++++++++++ .../webgpu/bert/rotary_embedding.h | 47 ++++++ .../webgpu/webgpu_contrib_kernels.cc | 2 +- .../contrib_ops/rotary_embedding_op_test.cc | 4 + 4 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc new file mode 100644 index 000000000000..eb5cfad87597 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/rotary_embedding.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + RotaryEmbedding, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); + const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform); + const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform); + const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + // TODO: remove output_indices. + const auto& output_indices = shader.AddIndices("output_indices", false); + const auto interleaved_str = interleaved_ ? "true" : "false"; + shader.SetMainFunctionBody( + " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n" + " let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n" + " let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n", + " if (global_idx >= size) { return; }\n" + " if (bsnh[3] < half_rotary_emb_dim) {\n" + " let position_ids_idx = " + + position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) + ";\n" + + " let position_id = u32(" + + position_ids.GetByOffset("position_ids_idx") + ")" + + " + select(0, bsnh[1], position_ids_idx == 0);\n" + " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " + + interleaved_str + + ");\n" + " let j = i + select(half_rotary_emb_dim, 1, " + + interleaved_str + + ");\n" + " let re = " + + input.GetByOffset("i") + " * " + cos_cache.GetByIndices("vec2(position_id, bsnh[3])") + "-" + + input.GetByOffset("j") + " * " + sin_cache.GetByIndices("vec2(position_id, bsnh[3])") + ";\n" + + " " + output.SetByOffset("i", "re") + "\n" + + " let im = " + input.GetByOffset("i") + " * " + + sin_cache.GetByIndices("vec2(position_id, bsnh[3])") + + "+ " + input.GetByOffset("j") + + " * " + cos_cache.GetByIndices("vec2(position_id, bsnh[3])") + + ";\n " + output.SetByOffset("j", "im") + + "\n" + " } else { \n" + " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" + + " " + output.SetByOffset("k", input.GetByOffset("k")) + + "\n" + " }"); + + return Status::OK(); +} + +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) { + scale_ = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim_ = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads_ = static_cast(info.GetAttrOrDefault("num_heads", 0)); + interleaved_ = (info.GetAttrOrDefault("interleaved", 0) == 1); + is_packed_batching_ = (info.GetAttrOrDefault("is_packed_batching", 0) == 1); +} + +Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto input_shape = input->Shape(); + const auto* position_ids = context.Input(1); + const auto* cos_cache = context.Input(2); + const auto* sin_cache = context.Input(3); + auto* output = context.Output(0, input_shape); + + const auto batch_size = gsl::narrow_cast(input->Shape()[0]); + const auto batch_stride = gsl::narrow_cast(input_shape.SizeFromDimension(1)); + const auto sequence_length = gsl::narrow_cast(input_shape[input_shape.NumDimensions() - 2]); + const auto hidden_size = batch_stride / sequence_length; + const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); + const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_; + + // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape + // [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy] + // to unfold the global index in shader. + const TensorShape global_shape({batch_size, + sequence_length, + hidden_size / head_size, + head_size - half_rotary_embedding_dim}); + + const auto rank = global_shape.NumDimensions(); + std::vector global_dims(rank); + std::vector global_strides(rank); + for (size_t j = 0; j < rank; ++j) { + global_dims[j] = gsl::narrow_cast(global_shape[j]); + global_strides[j] = gsl::narrow_cast(global_shape.SizeFromDimension(j + 1)); + } + + const auto output_size = gsl::narrow_cast(global_shape.Size()); + RotaryEmbeddingProgram program{interleaved_}; + const auto input_output_strides = + input_shape.NumDimensions() == 3 + ? std::vector({batch_stride, hidden_size, head_size, 1}) + : (input_shape.NumDimensions() == 4 + ? std::vector({batch_stride, head_size, sequence_length * head_size, 1}) + : std::vector({})); + + program + .CacheHint(interleaved_) + .AddInputs({{input, ProgramTensorMetadataDependency::Rank}, + {position_ids, ProgramTensorMetadataDependency::Rank}, + {cos_cache, ProgramTensorMetadataDependency::Rank}, + {sin_cache, ProgramTensorMetadataDependency::Rank}}) + .AddOutput({output, ProgramTensorMetadataDependency::None}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{scale_}, + {gsl::make_span(global_dims)}, + {gsl::make_span(global_strides)}, + {gsl::make_span(input_output_strides)}}) + .AddIndices(TensorShape{1, 1}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h new file mode 100644 index 000000000000..0d73b89fb62d --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class RotaryEmbeddingProgram final : public Program { + public: + RotaryEmbeddingProgram(bool interleaved) : Program{"RotaryEmbedding"}, interleaved_{interleaved} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"scale", ProgramUniformVariableDataType::Float32}, + {"global_shape", ProgramUniformVariableDataType::Uint32}, + {"global_stride", ProgramUniformVariableDataType::Uint32}, + {"input_output_stride", ProgramUniformVariableDataType::Uint32}); + + private: + const bool interleaved_; +}; + +class RotaryEmbedding final : public WebGpuKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(ComputeContext& context) const override; + + private: + float scale_; + int num_heads_; + int rotary_embedding_dim_; + bool interleaved_; + bool is_packed_batching_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index def104b6cb10..01c8a28d4506 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -47,7 +47,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo Date: Fri, 27 Sep 2024 14:57:21 +0800 Subject: [PATCH 099/114] [webgpu native] Add transpose shared (#22098) ### Description ### Motivation and Context --- .../core/providers/webgpu/tensor/transpose.cc | 91 +++++++++++++++---- .../core/providers/webgpu/tensor/transpose.h | 24 ++--- 2 files changed, 87 insertions(+), 28 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 0962d9191d78..e0a0113e1322 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -47,11 +47,11 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", WebGpuSupportedNumberTypes()), Transpose); -const std::string AppendPermFunction(gsl::span perm) { +const std::string AppendPermFunction(gsl::span perm) { std::ostringstream ss; ss.imbue(std::locale::classic()); - ss << "fn perm(i: y_indices_t)->x_indices_t {\n" - " var a: x_indices_t;\n"; + ss << "fn perm(i: output_indices_t)->a_indices_t {\n" + " var a: a_indices_t;\n"; for (size_t i = 0; i < perm.size(); ++i) { ss << " a[" << perm[i] << "] = i[" << i << "];\n"; } @@ -60,21 +60,52 @@ const std::string AppendPermFunction(gsl::span perm) { return ss.str(); } +auto SqueezeShape(const gsl::span& shape, const gsl::span& adjusted_perm, InlinedVector& new_shape, InlinedVector& new_perm) { + for (auto i = 0; i < shape.size(); ++i) { + if (shape[i] != 1) { + new_shape.push_back(shape[i]); + } + if (shape[adjusted_perm[i]] != 1) { + new_perm.push_back(adjusted_perm[i]); + } + } +}; + Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - shader.AppendImplementation(AppendPermFunction(this->perm_)); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), - " let indices = ", output.OffsetToIndices("global_idx"), - ";\n" - " let x_indices = perm(indices); \n" - " ", - output.SetByOffset("global_idx", input.GetByIndices("x_indices"))); + const auto& input = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + + if (use_shared_) { + shader.AppendImplementation("var tile : array, tile_size>;\n"); + shader.SetMainFunctionBody( + " let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n" + " let workgroup_id_x = workgroup_idx % stride;\n" + " let workgroup_id_y = workgroup_idx / stride;\n" + " let input_col = workgroup_id_y * tile_size + local_id.x;\n" + " let input_row = workgroup_id_x * tile_size + local_id.y;\n" + " if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n" + " tile[local_id.y][local_id.x] = " + + input.GetByIndices("a_indices_t(input_row, input_col)") + + ";\n" + " }\n" + " workgroupBarrier();\n" + " let output_col = workgroup_id_x * tile_size + local_id.x;\n" + " let output_row = workgroup_id_y * tile_size + local_id.y;\n" + " if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n " + + output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") + "\n }"); + } else { + shader.AppendImplementation(AppendPermFunction(this->perm_)); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), + " let indices = ", output.OffsetToIndices("global_idx"), + ";\n" + " let x_indices = perm(indices);\n", + " ", + output.SetByOffset("global_idx", input.GetByIndices("x_indices"))); + } return Status::OK(); } Status Transpose::ComputeInternal(ComputeContext& context) const { - // TODO: there is an optimized version of transpose to port. const auto* input_tensor = context.Input(0); const TensorShape& input_shape = input_tensor->Shape(); int32_t rank = gsl::narrow_cast(input_shape.NumDimensions()); @@ -86,16 +117,42 @@ Status Transpose::ComputeInternal(ComputeContext& context) const { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); + InlinedVector new_shape{}; + InlinedVector new_perm{}; + SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm); + const bool channels_last = new_perm == InlinedVector({2, 3, 1}); + const bool channels_first = new_perm == InlinedVector({3, 1, 2}); + const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first; + auto new_input_shape = input_shape; + TensorShape new_output_shape(output_dims); + if (use_shared) { + new_input_shape = channels_last + ? TensorShape({new_shape[0], new_shape[1] * new_shape[2]}) + : channels_first + ? TensorShape({new_shape[0] * new_shape[1], new_shape[2]}) + : new_shape; + new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]}); + } + uint32_t output_size = gsl::narrow_cast(input_tensor->Shape().Size()); - TransposeProgram program{*p_perm}; + TransposeProgram program{*p_perm, use_shared}; + if (use_shared) { + program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1); + } + program .CacheHint(absl::StrJoin(*p_perm, "-")) - .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutputs({output_tensor}) - .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) + .SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) .AddUniformVariables({ {static_cast(output_size)}, }); + + use_shared ? program.SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) + : program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); return context.RunProgram(program); } diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h index 3ca5674d5dfa..7cf5c1fe0865 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.h +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -11,26 +11,28 @@ namespace onnxruntime { namespace webgpu { +class Transpose final : public WebGpuKernel, public TransposeBase { + public: + Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { + } + Status ComputeInternal(ComputeContext& context) const override; + constexpr static uint32_t TILE_SIZE = 16; +}; + class TransposeProgram final : public Program { public: - TransposeProgram(const gsl::span& permutations) - : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()) { + TransposeProgram(const gsl::span& permutations, bool use_shared) + : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()), use_shared_(use_shared) { } Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_CONSTANTS({"tile_size", Transpose::TILE_SIZE}); private: - InlinedVector perm_; -}; - -class Transpose final : public WebGpuKernel, public TransposeBase { - public: - Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { - } - - Status ComputeInternal(ComputeContext& context) const override; + InlinedVector perm_; + const bool use_shared_; }; } // namespace webgpu From b1b5e1fd1a718c1dcae3a96233a32383c26a1db5 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 27 Sep 2024 15:19:28 +0800 Subject: [PATCH 100/114] [webgpu-native] Add gather (#22183) ### Description ### Motivation and Context --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../core/providers/webgpu/tensor/gather.cc | 82 +++++++++++++++++++ .../core/providers/webgpu/tensor/gather.h | 34 ++++++++ .../webgpu/webgpu_execution_provider.cc | 6 +- 3 files changed, 119 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/gather.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/gather.h diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc new file mode 100644 index 000000000000..31e0a9e88323 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/gather.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& data = shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + + std::ostringstream calc_data_indices; + calc_data_indices.imbue(std::locale::classic()); + calc_data_indices << " var indices_indices = input_indices_indices_t(0);\n"; + for (int i = 0; i < indices.Rank(); i++) { + calc_data_indices << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis_ + i)) << ";\n"; + } + calc_data_indices << " var idx = " << indices.GetByIndices("indices_indices") << ";\n" + << " if (idx < 0) {\n" + << " idx = idx + input_indices_value_t(uniforms.data_shape[" << axis_ << "]);\n" + << " }\n" + << " var data_indices : data_indices_t;\n"; + for (int i = 0, j = 0; i < data.Rank(); i++) { + if (i == SafeInt(axis_)) { + calc_data_indices << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; + j += indices.Rank(); + } else { + calc_data_indices << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n"; + j++; + } + } + + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), + " let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", + calc_data_indices.str(), " ", + output.SetByOffset("global_idx", data.GetByIndices("data_indices"))); + + return Status::OK(); +} + +Status Gather::ComputeInternal(ComputeContext& context) const { + Prepare p; + ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), p)); + uint32_t data_size = SafeInt(p.output_tensor->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + uint32_t axis = static_cast(p.axis); + GatherProgram program{axis}; + program + .AddInputs({{p.input_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {p.indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({p.output_tensor, ProgramTensorMetadataDependency::Rank}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .CacheHint(std::to_string(axis)) + .AddUniformVariables({{data_size}}); + return context.RunProgram(program); +} + +#define WEBGPU_GATHER_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), \ + KERNEL_CLASS); + +#define WEBGPU_GATHER_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), \ + KERNEL_CLASS); + +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 1, 10, Gather, WebGpuSupportedNumberTypes()) +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 11, 12, Gather, WebGpuSupportedNumberTypes()) +WEBGPU_GATHER_KERNEL(Gather, 13, Gather, WebGpuSupportedNumberTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.h b/onnxruntime/core/providers/webgpu/tensor/gather.h new file mode 100644 index 000000000000..bebe13519ce4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/tensor/gatherbase.h" + +namespace onnxruntime { +namespace webgpu { + +class GatherProgram final : public Program { + public: + GatherProgram(const uint32_t axis) : Program{"Gather"}, axis_{axis} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t axis_; +}; + +class Gather final : public WebGpuKernel, public GatherBase { + public: + Gather(const OpKernelInfo& info) : WebGpuKernel(info), GatherBase(info) {} + + protected: + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index f5d66d6a2413..df2a2caa0a1f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -641,9 +641,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, From 92a08e2d13f3498fd8eeb8ab572f096ceb2c86e7 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Fri, 27 Sep 2024 02:55:55 -0700 Subject: [PATCH 101/114] [Native-WebGPU] Add Concat (#22225) ### Description Add Concat operator support ### Motivation and Context Required for WebGPU EP --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../core/providers/webgpu/compute_context.h | 4 + .../core/providers/webgpu/tensor/concat.cc | 155 ++++++++++++++++++ .../core/providers/webgpu/tensor/concat.h | 36 ++++ .../webgpu/webgpu_execution_provider.cc | 8 +- 4 files changed, 199 insertions(+), 4 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/concat.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/concat.h diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index b7ea8a58e232..455eb4452f85 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -113,6 +113,10 @@ class ComputeContext { return webgpu_context_.Run(*this, program); } + inline OpKernelContext& GetKernelContext() { + return kernel_context_; + } + // // Push error scope. // diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc new file mode 100644 index 000000000000..671a6a1ed072 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/webgpu/tensor/concat.h" + +#include "core/common/inlined_containers.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +#define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + start, \ + end, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + Concat); + +#define WEBGPU_CONCAT_KERNEL(version) \ + ONNX_OPERATOR_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + version, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + Concat); + +WEBGPU_CONCAT_VERSIONED_KERNEL(1, 3) +WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) +WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) +WEBGPU_CONCAT_KERNEL(13) + +const std::string AppendCalCulateInputIndexFunction(size_t input_count) { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + ss << "fn calculate_input_index(index: u32) -> u32 {" << std::endl + << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {" << std::endl + << " if (index < uniforms.size_in_concat_axis[i]) {" << std::endl + << " return i;" << std::endl + << " }" << std::endl + << " }" << std::endl + << " return " << input_count << ";" << std::endl + << "}" << std::endl; + return ss.str(); +} + +const void AppendAssignOutput(std::ostringstream& ss, const ShaderVariableHelper& input, const ShaderVariableHelper& output) { + ss << output.SetByOffset("global_idx", input.GetByIndices("indices")) << ";" << std::endl; +} + +const std::string AppendAssignOutputDataFunction(gsl::span inputs, const ShaderVariableHelper& output) { + std::ostringstream ss; + size_t input_count = inputs.size(); + ss.imbue(std::locale::classic()); + ss << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {" << std::endl; + if (input_count == 0) { + AppendAssignOutput(ss, *inputs[0], output); + } else { + for (size_t i = 0; i < input_count; ++i) { + if (i == 0) { + ss << " if (input_index == 0u) {" << std::endl; + } else if (i == input_count - 1) { + ss << " } else {" << std::endl; + } else { + ss << " } else if (input_index == " << i << "u) {" << std::endl; + } + ss << " "; + AppendAssignOutput(ss, *inputs[i], output); + } + ss << " }" << std::endl; + } + ss << "}" << std::endl; + return ss.str(); +} +Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { + size_t input_count = Inputs().size(); + std::vector inputs; + inputs.reserve(input_count); + for (size_t i = 0; i < input_count; ++i) { + inputs.push_back(&shader.AddInput("input_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias)); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + shader.AppendImplementation(AppendCalCulateInputIndexFunction(input_count)); + shader.AppendImplementation(AppendAssignOutputDataFunction(gsl::make_span(inputs), output)); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), + " var indices = ", output.OffsetToIndices("global_idx"), ";\n", + " let indices_axis = ", output.IndicesGet("indices", axis_), ";\n", + " let input_index = calculate_input_index(indices_axis);\n", + " if (input_index != 0u) {\n", + " ", output.IndicesSet("indices", axis_, "indices_axis - uniforms.size_in_concat_axis[input_index - 1]"), ";\n", + " }\n", + " assign_output_data(global_idx, input_index, indices);\n"); + return Status::OK(); +} + +Status Concat::ComputeInternal(ComputeContext& context) const { + int input_count = context.InputCount(); + InlinedTensorsVector input_tensors; + input_tensors.reserve(input_count); + for (int i = 0; i < input_count; ++i) { + input_tensors.push_back(context.Input(i)); + } + + Prepare prepare; + ORT_RETURN_IF_ERROR(PrepareForCompute(&context.GetKernelContext(), input_tensors, prepare)); + if (prepare.output_num_elements == 0) { + return Status::OK(); + } + + uint32_t output_size = gsl::narrow_cast(prepare.output_tensor->Shape().Size()); + + ConcatProgram program{prepare.axis}; + + std::vector sizes_in_concat_axis; + sizes_in_concat_axis.reserve(input_count); + uint32_t sum = 0; + for (int i = 0; i < input_count; ++i) { + const auto& input = prepare.inputs[i]; + if (input.tensor->Shape().Size() == 0) { + continue; + } + program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); + + auto axis_size = input.tensor->Shape()[prepare.axis]; + sum += static_cast(axis_size); + sizes_in_concat_axis.push_back(sum); + } + + size_t non_empty_input_count = sizes_in_concat_axis.size(); + + if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) { + // TODO: support when input_count + 1 > maxStorageBuffersPerShaderStage, by raising the limit or run the program in multiple passes. + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The number of storage buffer (input=", + input_count, ", output=1) exceeds the limit (", + context.DeviceLimits().maxStorageBuffersPerShaderStage, ") of the device."); + } + + program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ",")) + .AddOutputs({prepare.output_tensor}) + .SetDispatchGroupSize((prepare.output_num_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), + output_size}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.h b/onnxruntime/core/providers/webgpu/tensor/concat.h new file mode 100644 index 000000000000..0f6e6dd327e3 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/concat.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/concatbase.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class ConcatProgram final : public Program { + public: + ConcatProgram(size_t axis) : Program{"Concat"}, axis_{axis} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"size_in_concat_axis", ProgramUniformVariableDataType::Uint32}, + {"output_size", ProgramUniformVariableDataType::Uint32}); + + private: + size_t axis_; +}; + +class Concat final : public WebGpuKernel, public ConcatBase { + public: + Concat(const OpKernelInfo& info) : WebGpuKernel(info), ConcatBase(info) { + } + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index df2a2caa0a1f..c1f13d652413 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -628,10 +628,10 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, From 8da1f7a157b554261da9482c3d7ffae2ae0617f6 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 27 Sep 2024 18:11:27 +0800 Subject: [PATCH 102/114] [webgpu-native] Add MatmulNBits (#22150) ### Description ### Motivation and Context --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../webgpu/quantization/matmul_nbits.cc | 294 ++++++++++++++++++ .../webgpu/quantization/matmul_nbits.h | 53 ++++ .../webgpu/webgpu_contrib_kernels.cc | 2 +- onnxruntime/core/providers/webgpu/program.cc | 90 ++++-- onnxruntime/core/providers/webgpu/program.h | 21 +- .../core/providers/webgpu/shader_helper.cc | 24 +- .../core/providers/webgpu/shader_variable.cc | 103 +++--- .../test/contrib_ops/matmul_4bits_test.cc | 16 +- 8 files changed, 501 insertions(+), 102 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc new file mode 100644 index 000000000000..b1f1a3a9ad8d --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -0,0 +1,294 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/quantization/matmul_nbits.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +namespace { +// Put it to a common place? +uint32_t GetMaxComponents(uint32_t size) { + // we cannot use vec3 type since it has alignment of 16 bytes + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + + return 1; +} + +std::string QuantizedDataType(int components) { + switch (components) { + case 1: + return "array"; + case 2: + return "mat4x2"; + case 4: + return "mat2x4"; + default: + return "array"; + } +} + +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", WebGpuSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + MatMulNBits); + +Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); + const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + + const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); + const int output_element_number = y.NumComponents() * SafeInt(output_number_); + std::ostringstream prepare_scale_and_zero_point; + prepare_scale_and_zero_point.imbue(std::locale::classic()); + prepare_scale_and_zero_point << " var col_index = col * " << y.NumComponents() << ";\n"; + if (has_zero_points_) { + const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); + prepare_scale_and_zero_point << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" + << " var zero_point_byte_count: u32;\n" + << " var zero_point_word_index: u32;\n" + << " var zero_point_byte_offset: u32;\n" + << " let zero_point_nibble_offset: u32 = block & 0x1u;\n" + << " var zero_point_bits_offset: u32;\n" + << " var zero_point_word: u32;\n"; + for (int c = 0; c < output_element_number; c++) { + prepare_scale_and_zero_point << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n"; + prepare_scale_and_zero_point << " zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);\n" + << " zero_point_word_index = zero_point_byte_count >> 0x2u;\n" + << " zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" + << " zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" + << " zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" + << " let zero_point" << c << " = output_element_t((zero_point_word) & 0xFu);\n"; + prepare_scale_and_zero_point << " col_index += 1;\n"; + } + } else { + prepare_scale_and_zero_point << " let zero_point = output_element_t(8.0);\n"; + for (int c = 0; c < output_element_number; c++) { + prepare_scale_and_zero_point << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n"; + prepare_scale_and_zero_point << " col_index += 1;\n"; + } + } + + std::ostringstream prepare_b_data; + prepare_b_data.imbue(std::locale::classic()); + prepare_b_data << " col_index = col * " << y.NumComponents() << ";\n"; + for (int c = 0; c < output_element_number; c++) { + prepare_b_data << " let b" << c << "_data = " << b.GetByIndices("input_b_indices_t(col_index, block, word)") << ";\n" + << " col_index += 1;\n"; + } + prepare_b_data << " var b_value : u32;\n" + << " let b_mask : u32 = 0x0F0F0F0Fu;\n" + << " var b_value_lower : vec4;\n" + << " var b_value_upper : vec4;\n" + << " var b_quantized_values : " << quantized_data_type << ";\n" + << " var b_dequantized_values : " << quantized_data_type << ";\n"; + + std::ostringstream process_one_word; + process_one_word.imbue(std::locale::classic()); + process_one_word << " var input_offset = " << a.IndicesToOffset("input_a_indices_t(batch, row, word_offset)") << ";\n" + << " var a_data: " << quantized_data_type << ";\n" + << " for (var j: u32 = 0; j < " << (8 / a.NumComponents()) << "; j++) {\n" + << " if (word_offset + j < uniforms.input_a_shape[2]) {\n" + << " a_data[j] = " << a.GetByOffset("input_offset") << ";\n" + << " input_offset++;\n" + << " } else {\n" + << " a_data[j] = input_a_value_t(0);\n" + << " }\n" + << " }\n"; + for (int c = 0; c < output_element_number; c++) { + process_one_word << " b_value = " << "b" << c << "_data"; + if (components_b_ > 1) { + process_one_word << "[i]"; + } + process_one_word << ";\n" + << " b_value_lower = unpack4xU8(b_value & b_mask);\n" + << " b_value_upper = unpack4xU8((b_value >> 4) & b_mask);\n" + << " b_quantized_values = " << quantized_data_type << "(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" + << " b_dequantized_values = "; + if (a.NumComponents() == 1) { + if (has_zero_points_) { + process_one_word << quantized_data_type << "((b_quantized_values[0] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[1] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[2] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[3] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[4] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[5] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[6] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[7] - zero_point" << c << ") * scale" << c << ");\n"; + } else { + process_one_word << quantized_data_type << "((b_quantized_values[0] - zero_point) * scale" << c << ", " + << "(b_quantized_values[1] - zero_point) * scale" << c << "," + << "(b_quantized_values[2] - zero_point) * scale" << c << "," + << "(b_quantized_values[3] - zero_point) * scale" << c << "," + << "(b_quantized_values[4] - zero_point) * scale" << c << "," + << "(b_quantized_values[5] - zero_point) * scale" << c << "," + << "(b_quantized_values[6] - zero_point) * scale" << c << "," + << "(b_quantized_values[7] - zero_point) * scale" << c << ");\n"; + } + } else { + process_one_word << "(b_quantized_values - " << quantized_data_type << "("; + for (int i = 0; i < 8; i++) { + if (has_zero_points_) { + process_one_word << "zero_point" << c; + } else { + process_one_word << "zero_point"; + } + if (i < 7) { + process_one_word << ", "; + } + } + process_one_word << ")) * scale" << c << ";\n"; + } + + process_one_word << " workgroup_shared[local_id.x * " << output_number_ << " + " << c / y.NumComponents() << "]"; + if (y.NumComponents() > 1) { + process_one_word << "[" << c % y.NumComponents() << "]"; + } + process_one_word << " += "; + if (a.NumComponents() == 1) { + process_one_word << "a_data[0] * b_dequantized_values[0] + " + << "a_data[1] * b_dequantized_values[1] + " + << "a_data[2] * b_dequantized_values[2] + " + << "a_data[3] * b_dequantized_values[3] + " + << "a_data[4] * b_dequantized_values[4] + " + << "a_data[5] * b_dequantized_values[5] + " + << "a_data[6] * b_dequantized_values[6] + " + << "a_data[7] * b_dequantized_values[7];\n"; + } else if (a.NumComponents() == 2) { + process_one_word << "dot(a_data[0], b_dequantized_values[0]) + " + << "dot(a_data[1], b_dequantized_values[1]) + " + << "dot(a_data[2], b_dequantized_values[2]) + " + << "dot(a_data[3], b_dequantized_values[3]);\n"; + } else if (a.NumComponents() == 4) { + process_one_word << "dot(a_data[0], b_dequantized_values[0]) + " + << "dot(a_data[1], b_dequantized_values[1]);\n"; + } + } + + const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; + std::string offset = "workgroup_idx * " + std::to_string(output_number_); + shader.AppendImplementation("var workgroup_shared : array;\n"); + shader.SetMainFunctionBody(" let output_indices = ", y.OffsetToIndices(offset), + ";\n" + " let col = output_indices[2];\n" + " let row = output_indices[1];\n" + " let batch = output_indices[0];\n" + " let n_blocks_per_col = uniforms.input_b_shape[1];\n" + " let blob_size = uniforms.input_b_shape[2]" + ";\n" + " for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size_x) {\n" + " var word_offset = block * uniforms.block_size / ", + a.NumComponents(), ";\n", + prepare_scale_and_zero_point.str(), + " for (var word: u32 = 0; word < blob_size; word += 1) {\n", + prepare_b_data.str(), + " for (var i: u32 = 0; i < ", components_b_, "; i++) {\n", + process_one_word.str(), + " word_offset += ", 8 / a.NumComponents(), + ";\n" + " }\n" + " }\n" + " }\n" + " workgroupBarrier();\n" + " if (local_id.x < ", + output_number_, + ") {\n" + " var output_value = output_value_t(0);\n" + " var workgroup_shared_offset = local_id.x;\n" + " let blocks_num = min(", + shared_memory_size, + ", n_blocks_per_col);\n" + " for (var b = 0u; b < blocks_num; b++) {\n" + " output_value += workgroup_shared[workgroup_shared_offset];\n" + " workgroup_shared_offset += ", + output_number_, + ";\n" + " }\n", + " ", + y.SetByIndices("output_indices_t(batch, row, col + local_id.x)", "output_value"), + "\n" + " }\n"); + + return Status::OK(); +} + +Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* a = context.Input(0); + const Tensor* b = context.Input(1); + const Tensor* scales = context.Input(2); + const Tensor* zero_points = context.Input(3); + const Tensor* g_idx = context.Input(4); + const Tensor* bias = context.Input(5); + + ORT_ENFORCE(g_idx == nullptr, "group_idx as input is not supported yet."); + ORT_ENFORCE(bias == nullptr, "bias as input is not supported yet."); + + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + auto* y = context.Output(0, helper.OutputShape()); + const uint32_t data_size = SafeInt(y->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + const uint32_t batch_count = SafeInt(helper.OutputOffsets().size()); + const uint32_t M = SafeInt(helper.M()); + const uint32_t N = SafeInt(helper.N()); + const uint32_t K = SafeInt(helper.K()); + const uint32_t block_size = SafeInt(block_size_); + const uint32_t nbits = 4; + + const uint32_t n_blocks_per_col = (K + block_size - 1) / block_size; + const uint32_t blob_size = (block_size / 8) * nbits; + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_a = GetMaxComponents(K); + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + const uint32_t components = GetMaxComponents(N); + // TODO: Support output_number > 1. Some cases are failed when output_number > 1. + // const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1; + const uint32_t output_number = 1; + + TensorShape reshaped_a_shape{batch_count, M, K / components_a}; + TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; + TensorShape reshaped_y_shape{batch_count, M, N / components}; + + const bool has_zero_points = zero_points != nullptr; + MatMulNBitsProgram program{output_number, SafeInt(components_b), has_zero_points}; + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, SafeInt(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, SafeInt(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, + {scales, ProgramTensorMetadataDependency::None}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, SafeInt(components)}) + .SetDispatchGroupSize(data_size / components / output_number) + .AddUniformVariable({block_size}) + .CacheHint(std::to_string(output_number)); + if (has_zero_points) { + program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + } + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h new file mode 100644 index 000000000000..7fec1423faf0 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class MatMulNBitsProgram final : public Program { + public: + MatMulNBitsProgram(uint32_t output_number, int components_b, bool has_zero_points) : Program{"MatMulNBits"}, + output_number_{output_number}, + components_b_{components_b}, + has_zero_points_{has_zero_points} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t output_number_; + int components_b_; + bool has_zero_points_; +}; + +class MatMulNBits final : public WebGpuKernel { + public: + MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) { + K_ = info.GetAttr("K"); + N_ = info.GetAttr("N"); + block_size_ = info.GetAttr("block_size"); + int64_t bits = info.GetAttr("bits"); + ORT_ENFORCE(bits == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + } + + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 01c8a28d4506..b5d7a90b9bbf 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -44,7 +44,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 75c3c9ee9608..25c0a4278be8 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -84,21 +84,24 @@ std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency dep) #ifndef NDEBUG constexpr std::string_view ProgramVariableDataTypeName[] = { - "f32", // f32 - "f32x2", // vec2f32 - "f32x4", // vec4f32 - "f16", // f16 - "f16x2", // vec2f16 - "f16x4", // vec4f16 - "i32", // i32 - "i32x2", // vec2i32 - "i32x4", // vec4i32 - "u32", // u32 - "u32x2", // vec2u32 - "u32x4", // vec4u32 - "i64", // int64 - "u64", // uint64 - "boolx4", // vec4bool + "f32", // Float32 + "f32x2", // Float32x2 + "f32x4", // Float32x4 + "f16", // Float16 + "f16x2", // Float16x2 + "f16x4", // Float16x4 + "i32", // Int32 + "i32x2", // Int32x2 + "i32x4", // Int32x4 + "u32", // Uint32 + "u32x2", // Uint32x2 + "u32x4", // Uint32x4 + "i64", // Int64 + "u64", // Uint64 + "boolx4", // Boolx4 + "u8x4", // Uint8x4 + "u8x8", // Uint8x8 + "u8x16", // Uint8x16 }; std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) { os << ProgramVariableDataTypeName[std::underlying_type::type(type)]; @@ -115,17 +118,22 @@ int NumberOfComponents(ProgramVariableDataType type) { case ProgramVariableDataType::Uint64: case ProgramVariableDataType::Float16: return 1; - case ProgramVariableDataType::Vec2Float32: - case ProgramVariableDataType::Vec2Int32: - case ProgramVariableDataType::Vec2Uint32: - case ProgramVariableDataType::Vec2Float16: + case ProgramVariableDataType::Float32x2: + case ProgramVariableDataType::Int32x2: + case ProgramVariableDataType::Uint32x2: + case ProgramVariableDataType::Float16x2: return 2; - case ProgramVariableDataType::Vec4Float32: - case ProgramVariableDataType::Vec4Int32: - case ProgramVariableDataType::Vec4Uint32: - case ProgramVariableDataType::Vec4Float16: - case ProgramVariableDataType::Vec4Bool: + case ProgramVariableDataType::Float32x4: + case ProgramVariableDataType::Int32x4: + case ProgramVariableDataType::Uint32x4: + case ProgramVariableDataType::Float16x4: + case ProgramVariableDataType::Boolx4: + case ProgramVariableDataType::Uint8x4: return 4; + case ProgramVariableDataType::Uint8x8: + return 8; + case ProgramVariableDataType::Uint8x16: + return 16; default: return -1; } @@ -152,28 +160,44 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp } else if (component == 2) { switch (element_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - return ProgramVariableDataType::Vec2Float32; + return ProgramVariableDataType::Float32x2; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - return ProgramVariableDataType::Vec2Float16; + return ProgramVariableDataType::Float16x2; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - return ProgramVariableDataType::Vec2Int32; + return ProgramVariableDataType::Int32x2; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - return ProgramVariableDataType::Vec2Uint32; + return ProgramVariableDataType::Uint32x2; default: return ProgramVariableDataType::InvalidType; } } else if (component == 4) { switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x4; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - return ProgramVariableDataType::Vec4Float32; + return ProgramVariableDataType::Float32x4; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - return ProgramVariableDataType::Vec4Float16; + return ProgramVariableDataType::Float16x4; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - return ProgramVariableDataType::Vec4Int32; + return ProgramVariableDataType::Int32x4; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - return ProgramVariableDataType::Vec4Uint32; + return ProgramVariableDataType::Uint32x4; case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - return ProgramVariableDataType::Vec4Bool; + return ProgramVariableDataType::Boolx4; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 8) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x8; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 16) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x16; default: return ProgramVariableDataType::InvalidType; } diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index f05ca9c2bf22..bd9a26b0fcfb 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -172,20 +172,23 @@ constexpr SafeInt WORKGROUP_SIZE = 64; enum class ProgramVariableDataType { InvalidType = -1, Float32, - Vec2Float32, - Vec4Float32, + Float32x2, + Float32x4, Float16, - Vec2Float16, - Vec4Float16, + Float16x2, + Float16x4, Int32, - Vec2Int32, - Vec4Int32, + Int32x2, + Int32x4, Uint32, - Vec2Uint32, - Vec4Uint32, + Uint32x2, + Uint32x4, Int64, Uint64, - Vec4Bool, + Boolx4, + Uint8x4, + Uint8x8, + Uint8x16 }; #ifndef NDEBUG std::ostream& operator<<(std::ostream& os, ProgramVariableDataType); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index c229e821cbf8..a88687fce18b 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -114,27 +114,27 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va switch (element_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float32 || - var_type == ProgramVariableDataType::Vec2Float32 || - var_type == ProgramVariableDataType::Vec4Float32, + var_type == ProgramVariableDataType::Float32x2 || + var_type == ProgramVariableDataType::Float32x4, "Unexpected program variable type ", int(var_type), " for float32 tensor"); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float16 || - var_type == ProgramVariableDataType::Vec2Float16 || - var_type == ProgramVariableDataType::Vec4Float16, + var_type == ProgramVariableDataType::Float16x2 || + var_type == ProgramVariableDataType::Float16x4, "Unexpected program variable type ", int(var_type), " for float16 tensor"); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || - var_type == ProgramVariableDataType::Vec2Int32 || - var_type == ProgramVariableDataType::Vec4Int32, + var_type == ProgramVariableDataType::Int32x2 || + var_type == ProgramVariableDataType::Int32x4, "Unexpected program variable type ", int(var_type), " for int32 tensor"); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint32 || - var_type == ProgramVariableDataType::Vec2Uint32 || - var_type == ProgramVariableDataType::Vec4Uint32, + var_type == ProgramVariableDataType::Uint32x2 || + var_type == ProgramVariableDataType::Uint32x4, "Unexpected program variable type ", int(var_type), " for uint32 tensor"); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: @@ -146,9 +146,15 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va "Unexpected program variable type ", int(var_type), " for uint64 tensor"); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Vec4Bool, + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Boolx4, "Unexpected program variable type ", int(var_type), " for bool tensor"); break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint8x4 || + var_type == ProgramVariableDataType::Uint8x8 || + var_type == ProgramVariableDataType::Uint8x16, + "Unexpected program variable type ", int(var_type), " for uint8 tensor"); + break; default: ORT_RETURN_IF(true, "Unsupported data type: ", element_type); // todo: add int4/uint4 diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index f2a5b049b477..cbc39c86e504 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -15,57 +15,66 @@ namespace webgpu { namespace { constexpr static const std::string_view STORAGE_TYPE[] = { - "f32", // f32 - "vec2", // vec2f32 - "vec4", // vec4f32 - "f16", // f16 - "vec2", // vec2f16 - "vec4", // vec4f16 - "i32", // i32 - "vec2", // vec2i32 - "vec4", // vec4i32 - "u32", // u32 - "vec2", // vec2u32 - "vec4", // vec4u32 - "vec2", // int64 - "vec2", // uint64 - "u32", // vec4bool + "f32", // Float32 + "vec2", // Float32x2 + "vec4", // Float32x4 + "f16", // Float16 + "vec2", // Float16x2 + "vec4", // Float16x4 + "i32", // Int32 + "vec2", // Int32x2 + "vec4", // Int32x4 + "u32", // Uint32 + "vec2", // Uint32x2 + "vec4", // Uint32x4 + "vec2", // Int64 + "vec2", // Uint64 + "u32", // Boolx4 + "u32", // Uint8x4 + "vec2", // Uint8x8 + "vec4", // Uint8x16 }; constexpr static const std::string_view VALUE_TYPE[] = { - "f32", // f32 - "vec2", // vec2f32 - "vec4", // vec4f32 - "f16", // f16 - "vec2", // vec2f16 - "vec4", // vec4f16 - "i32", // i32 - "vec2", // vec2i32 - "vec4", // vec4i32 - "u32", // u32 - "vec2", // vec2u32 - "vec4", // vec4u32 - "i32", // int64 (trancated to i32) - "u32", // uint64 (trancated to u32) - "vec4", // vec4bool + "f32", // Float32 + "vec2", // Float32x2 + "vec4", // Float32x4 + "f16", // Float16 + "vec2", // Float16x2 + "vec4", // Float16x4 + "i32", // Int32 + "vec2", // Int32x2 + "vec4", // Int32x4 + "u32", // Uint32 + "vec2", // Uint32x2 + "vec4", // Uint32x4 + "i32", // Int64 (trancated to i32) + "u32", // Uint64 (trancated to u32) + "vec4", // Boolx4 + "u32", // Uint8x4 (u32 as 4 elements of uint8) + "vec2", // Uint8x8 (vec2 as 2x4 elements of uint8) + "vec4", // Uint8x16 (vec4 as 4x4 elements of uint8) }; constexpr static const std::string_view ELEMENT_TYPE[] = { - "f32", // f32 - "f32", // vec2f32 - "f32", // vec4f32 - "f16", // f16 - "f16", // vec2f16 - "f16", // vec4f16 - "i32", // i32 - "i32", // vec2i32 - "i32", // vec4i32 - "u32", // u32 - "u32", // vec2u32 - "u32", // vec4u32 - "i32", // int64 - "u32", // uint64 - "bool", // vec4bool + "f32", // Float32 + "f32", // Float32x2 + "f32", // Float32x4 + "f16", // Float16 + "f16", // Float16x2 + "f16", // Float16x4 + "i32", // Int32 + "i32", // Int32x2 + "i32", // Int32x4 + "u32", // Uint32 + "u32", // Uint32x2 + "u32", // Uint32x4 + "i32", // Int64 + "u32", // Uint64 + "bool", // Boolx4 + "u32", // Uint8x4 + "u32", // Uint8x8 + "u32", // Uint8x16 }; inline std::string GetIndicesType(int rank) { @@ -263,7 +272,7 @@ std::string ShaderVariableHelper::GetByOffsetImpl(std::string_view offset) const case onnxruntime::webgpu::ProgramVariableDataType::Uint64: ss << ElementType() << "(" << name_ << "[" << offset << "].x)"; break; - case onnxruntime::webgpu::ProgramVariableDataType::Vec4Bool: + case onnxruntime::webgpu::ProgramVariableDataType::Boolx4: ss << "vec4(bool(" << name_ << "[" << offset << "] & 0xFFu), bool(" << name_ << "[" << offset << "] & 0xFF00u), bool(" @@ -291,7 +300,7 @@ std::string ShaderVariableHelper::SetByOffsetImpl(std::string_view offset, std:: case onnxruntime::webgpu::ProgramVariableDataType::Uint64: ss << name_ << "[" << offset << "]=vec2(u32(" << value << "), 0u);"; break; - case onnxruntime::webgpu::ProgramVariableDataType::Vec4Bool: + case onnxruntime::webgpu::ProgramVariableDataType::Boolx4: ss << name_ << "[" << offset << "]=dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(" << value << "));"; break; default: diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index fa7c6bce7c23..669beb055309 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -278,7 +278,11 @@ void TestMatMulNBitsTyped() { base_opts.output_abs_error = 0.1f; } else { if constexpr (std::is_same::value) { +#ifdef USE_WEBGPU + base_opts.output_abs_error = 0.03f; +#else base_opts.output_abs_error = 0.01f; +#endif } } @@ -293,7 +297,7 @@ void TestMatMulNBitsTyped() { RunTest(opts); } -#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) +#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) && !defined(USE_WEBGPU) { TestOptions opts = base_opts; opts.has_g_idx = true; @@ -324,7 +328,7 @@ void TestMatMulNBitsTyped() { opts.has_zero_point = true, opts.zp_is_4bit = false; RunTest(opts); } -#endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) +#endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) && !defined(USE_WEBGPU) { TestOptions opts = base_opts; @@ -358,7 +362,7 @@ TEST(MatMulNBits, Float16) { #endif #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) namespace { // Legacy test function. @@ -393,6 +397,9 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura #ifdef USE_DML execution_providers.push_back(DefaultDmlExecutionProvider()); #endif +#ifdef USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); +#endif RunTest(opts, std::move(execution_providers)); } else { @@ -437,6 +444,9 @@ TEST(MatMulNBits, Float16Large) { // absolute error of 0.08, but the A10 has errors going as high as 0.22. Ultimately, given the large number // of elements in this test, ULPs should probably be used instead of absolute/relative tolerances. float abs_error = 0.3f; +#elif USE_WEBGPU + // See Intel A770 to pass these tests with an absolute error of 0.08. + float abs_error = 0.08f; #else float abs_error = 0.05f; #endif From f9b6b7c9b3c6be447de410805c99a2aff6b792dd Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Sun, 29 Sep 2024 20:35:43 -0700 Subject: [PATCH 103/114] [WebGPU-Native] Tile Operator (#22239) Adds WebGPU implementation for Tile operator. --- .../core/providers/webgpu/tensor/tile.cc | 90 +++++++++++++++++++ .../core/providers/webgpu/tensor/tile.h | 30 +++++++ .../webgpu/webgpu_execution_provider.cc | 4 +- .../test/providers/cpu/tensor/tile_op_test.cc | 4 +- 4 files changed, 124 insertions(+), 4 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/tile.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/tile.h diff --git a/onnxruntime/core/providers/webgpu/tensor/tile.cc b/onnxruntime/core/providers/webgpu/tensor/tile.cc new file mode 100644 index 000000000000..2737b6dafea8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/tile.cc @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/tile.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Tile, + kOnnxDomain, + 6, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Tile); + +ONNX_OPERATOR_KERNEL_EX( + Tile, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Tile); + +Status TileProgram::GenerateShaderCode(ShaderHelper& shader) const { + const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + + std::ostringstream ss; + ss.imbue(std::locale::classic()); + + ss << "var input_indices: input_indices_t;\n"; + for (auto i = 0; i < input.Rank(); i++) { + std::string input_dim_i = "input_dim_" + std::to_string(i); + std::string input_dim_value = "input_dim_" + std::to_string(i) + "_value"; + ss << "let " << input_dim_i << " = " << input.IndicesGet("uniforms.input_shape", i) << ";\n"; + ss << "let " << input_dim_value << " = " << output.IndicesGet("output_indices", i) << " % " << input_dim_i << ";\n"; + ss << input.IndicesSet("input_indices", i, input_dim_value) << ";\n"; + } + + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), + "let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", + ss.str(), + output.SetByOffset("global_idx", input.GetByIndices("input_indices"))); + + return Status::OK(); +} + +Status Tile::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + size_t input_rank = input_shape.NumDimensions(); + + const auto* repeats_tensor = context.Input(1); + const auto* repeats_data = repeats_tensor->Data(); + std::vector repeats; + + for (size_t i = 0; i < static_cast(repeats_tensor->Shape().Size()); i++) { + repeats.push_back(static_cast(repeats_data[i])); + } + + auto output_dims = input_shape.AsShapeVector(); + for (size_t axis = 0; axis < input_rank; axis++) { + output_dims[axis] *= repeats[axis]; + } + + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + int64_t output_size = output_tensor->Shape().Size(); + + if (output_size == 0) { + return Status::OK(); + } + + TileProgram program{}; + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({output_tensor}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(output_size)}, + {repeats}}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/tile.h b/onnxruntime/core/providers/webgpu/tensor/tile.h new file mode 100644 index 000000000000..9b6ab420b325 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/tile.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class TileProgram final : public Program { + public: + TileProgram() : Program{"Tile"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"repeats", ProgramUniformVariableDataType::Uint32}); +}; + +class Tile final : public WebGpuKernel { + public: + Tile(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index c1f13d652413..a43428f55ce8 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -668,8 +668,8 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc index b517b1a2837f..5902fbe3ddd6 100644 --- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc @@ -142,7 +142,7 @@ void RunTestWrapper() { RunTest({2, 1, 3}, {2, 2, 1}); RunTest({2, 1, 3}, {2, 2, 1}, true); -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) // _TileMemcpyKernelFromInput, vectorized 4 RunTest({256, 512}, {3, 1}); @@ -253,7 +253,7 @@ TEST(TensorOpTest, TileStringType) { RunTestWrapper(); } TEST(TensorOpTest, TileBoolType) { RunTestWrapperForBool(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) TEST(TensorOpTest, TileMLFloat16Type) { RunTestWrapper(); } #endif From c1ae1fd88799420d0a3619243b588351aeda768c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 30 Sep 2024 02:28:29 -0700 Subject: [PATCH 104/114] use Abseil OStringStream in WebGPU EP string concat (#22241) ### Description use Abseil OStringStream in WebGPU EP string concat. Allows multiple calls to shader_helper.MainFunctionBody() --- .../contrib_ops/webgpu/bert/fast_gelu.cc | 27 +- .../webgpu/bert/rotary_embedding.cc | 49 ++-- .../webgpu/quantization/matmul_nbits.cc | 250 +++++++++--------- .../core/providers/webgpu/docs/Conventions.md | 23 +- .../webgpu/math/binary_elementwise_ops.cc | 123 +++++---- .../webgpu/math/unary_elementwise_ops.cc | 8 +- onnxruntime/core/providers/webgpu/program.cc | 2 +- onnxruntime/core/providers/webgpu/program.h | 8 +- .../providers/webgpu/program_cache_key.cc | 16 +- .../core/providers/webgpu/shader_helper.cc | 47 ++-- .../core/providers/webgpu/shader_helper.h | 32 +-- .../core/providers/webgpu/shader_macros.h | 66 ----- .../core/providers/webgpu/shader_variable.cc | 108 ++++---- .../core/providers/webgpu/shader_variable.h | 4 +- .../core/providers/webgpu/string_macros.h | 18 ++ .../core/providers/webgpu/string_utils.h | 46 ++++ .../core/providers/webgpu/tensor/cast.cc | 6 +- .../core/providers/webgpu/tensor/concat.cc | 83 +++--- .../core/providers/webgpu/tensor/expand.cc | 8 +- .../core/providers/webgpu/tensor/gather.cc | 27 +- .../core/providers/webgpu/tensor/tile.cc | 22 +- .../core/providers/webgpu/tensor/transpose.cc | 66 ++--- .../core/providers/webgpu/tensor/where.cc | 66 +++-- 23 files changed, 524 insertions(+), 581 deletions(-) delete mode 100644 onnxruntime/core/providers/webgpu/shader_macros.h create mode 100644 onnxruntime/core/providers/webgpu/string_macros.h create mode 100644 onnxruntime/core/providers/webgpu/string_utils.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index 52459b0632d5..d1e5f53d7f63 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -24,22 +24,23 @@ Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& y = shader.AddOutput("y", ShaderUsage::UseUniform); - std::string add_bias = ""; + shader.AdditionalImplementation() << TanhImpl; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " var a = " << x.GetByOffset("global_idx") << ";\n"; if (Inputs().size() > 1) { const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride); - add_bias = bias_components_ == 1 ? " let bias_offset = global_idx * 4;\n" - " a += x_value_t(" + - bias.GetByOffset("bias_offset % uniforms.bias_shape") + ", " + - bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") + ", " + - bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") + ", " + - bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") + ");\n" - : " a += " + bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; + if (bias_components_ == 1) { + shader.MainFunctionBody() << " let bias_offset = global_idx * 4;\n" + " a += x_value_t(" + << bias.GetByOffset("bias_offset % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") << ");\n"; + } else { + shader.MainFunctionBody() << " a += " << bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; + } } - shader.AppendImplementation(TanhImpl); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - " var a = ", x.GetByOffset("global_idx"), ";\n", - add_bias, - y.SetByOffset("global_idx", onnxruntime::webgpu::FastGeluExpr)); + shader.MainFunctionBody() << y.SetByOffset("global_idx", onnxruntime::webgpu::FastGeluExpr); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index eb5cfad87597..85ab94706b14 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -29,38 +29,23 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { // TODO: remove output_indices. const auto& output_indices = shader.AddIndices("output_indices", false); const auto interleaved_str = interleaved_ ? "true" : "false"; - shader.SetMainFunctionBody( - " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n" - " let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n" - " let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n", - " if (global_idx >= size) { return; }\n" - " if (bsnh[3] < half_rotary_emb_dim) {\n" - " let position_ids_idx = " + - position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) + ";\n" + - " let position_id = u32(" + - position_ids.GetByOffset("position_ids_idx") + ")" + - " + select(0, bsnh[1], position_ids_idx == 0);\n" - " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " + - interleaved_str + - ");\n" - " let j = i + select(half_rotary_emb_dim, 1, " + - interleaved_str + - ");\n" - " let re = " + - input.GetByOffset("i") + " * " + cos_cache.GetByIndices("vec2(position_id, bsnh[3])") + "-" + - input.GetByOffset("j") + " * " + sin_cache.GetByIndices("vec2(position_id, bsnh[3])") + ";\n" + - " " + output.SetByOffset("i", "re") + "\n" + - " let im = " + input.GetByOffset("i") + " * " + - sin_cache.GetByIndices("vec2(position_id, bsnh[3])") + - "+ " + input.GetByOffset("j") + - " * " + cos_cache.GetByIndices("vec2(position_id, bsnh[3])") + - ";\n " + output.SetByOffset("j", "im") + - "\n" - " } else { \n" - " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" + - " " + output.SetByOffset("k", input.GetByOffset("k")) + - "\n" - " }"); + shader.MainFunctionBody() << " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n" + " let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n" + " let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n" + " if (global_idx >= size) { return; }\n" + " if (bsnh[3] < half_rotary_emb_dim) {\n" + << " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) << ";\n" + << " let position_id = u32(" << position_ids.GetByOffset("position_ids_idx") << ") + select(0, bsnh[1], position_ids_idx == 0);\n" + << " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << ");\n" + << " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n" + << " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " " << output.SetByOffset("i", "re") << "\n" + << " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << " + " << input.GetByOffset("j") + " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " " << output.SetByOffset("j", "im") << "\n" + << " } else { \n" + " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" + << " " << output.SetByOffset("k", input.GetByOffset("k")) << "\n" + << " }"; return Status::OK(); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index b1f1a3a9ad8d..2057627c27c2 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -59,175 +59,161 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); const int output_element_number = y.NumComponents() * SafeInt(output_number_); - std::ostringstream prepare_scale_and_zero_point; - prepare_scale_and_zero_point.imbue(std::locale::classic()); - prepare_scale_and_zero_point << " var col_index = col * " << y.NumComponents() << ";\n"; + + const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; + std::string offset = "workgroup_idx * " + std::to_string(output_number_); + shader.AdditionalImplementation() << "var workgroup_shared : array;\n"; + shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" + << " let col = output_indices[2];\n" + " let row = output_indices[1];\n" + " let batch = output_indices[0];\n" + " let n_blocks_per_col = uniforms.input_b_shape[1];\n" + " let blob_size = uniforms.input_b_shape[2];\n" + " for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size_x) {\n" + << " var word_offset = block * uniforms.block_size / " << a.NumComponents() << ";\n"; + + // prepare scale and zero point + shader.MainFunctionBody() << " var col_index = col * " << y.NumComponents() << ";\n"; if (has_zero_points_) { const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); - prepare_scale_and_zero_point << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" - << " var zero_point_byte_count: u32;\n" - << " var zero_point_word_index: u32;\n" - << " var zero_point_byte_offset: u32;\n" - << " let zero_point_nibble_offset: u32 = block & 0x1u;\n" - << " var zero_point_bits_offset: u32;\n" - << " var zero_point_word: u32;\n"; + shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" + " var zero_point_byte_count: u32;\n" + " var zero_point_word_index: u32;\n" + " var zero_point_byte_offset: u32;\n" + " let zero_point_nibble_offset: u32 = block & 0x1u;\n" + " var zero_point_bits_offset: u32;\n" + " var zero_point_word: u32;\n"; for (int c = 0; c < output_element_number; c++) { - prepare_scale_and_zero_point << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n"; - prepare_scale_and_zero_point << " zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);\n" - << " zero_point_word_index = zero_point_byte_count >> 0x2u;\n" - << " zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" - << " zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" - << " zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" - << " let zero_point" << c << " = output_element_t((zero_point_word) & 0xFu);\n"; - prepare_scale_and_zero_point << " col_index += 1;\n"; + shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" + << " zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);\n" + " zero_point_word_index = zero_point_byte_count >> 0x2u;\n" + " zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" + " zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" + << " zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" + << " let zero_point" << c << " = output_element_t((zero_point_word) & 0xFu);\n" + << " col_index += 1;\n"; } } else { - prepare_scale_and_zero_point << " let zero_point = output_element_t(8.0);\n"; + shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; for (int c = 0; c < output_element_number; c++) { - prepare_scale_and_zero_point << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n"; - prepare_scale_and_zero_point << " col_index += 1;\n"; + shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" + << " col_index += 1;\n"; } } - std::ostringstream prepare_b_data; - prepare_b_data.imbue(std::locale::classic()); - prepare_b_data << " col_index = col * " << y.NumComponents() << ";\n"; + shader.MainFunctionBody() << " for (var word: u32 = 0; word < blob_size; word += 1) {\n"; + + // prepare b data + shader.MainFunctionBody() << " col_index = col * " << y.NumComponents() << ";\n"; for (int c = 0; c < output_element_number; c++) { - prepare_b_data << " let b" << c << "_data = " << b.GetByIndices("input_b_indices_t(col_index, block, word)") << ";\n" - << " col_index += 1;\n"; + shader.MainFunctionBody() << " let b" << c << "_data = " << b.GetByIndices("input_b_indices_t(col_index, block, word)") << ";\n" + << " col_index += 1;\n"; } - prepare_b_data << " var b_value : u32;\n" - << " let b_mask : u32 = 0x0F0F0F0Fu;\n" - << " var b_value_lower : vec4;\n" - << " var b_value_upper : vec4;\n" - << " var b_quantized_values : " << quantized_data_type << ";\n" - << " var b_dequantized_values : " << quantized_data_type << ";\n"; + shader.MainFunctionBody() << " var b_value : u32;\n" + " let b_mask : u32 = 0x0F0F0F0Fu;\n" + " var b_value_lower : vec4;\n" + " var b_value_upper : vec4;\n" + << " var b_quantized_values : " << quantized_data_type << ";\n" + << " var b_dequantized_values : " << quantized_data_type << ";\n"; + + shader.MainFunctionBody() << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; - std::ostringstream process_one_word; - process_one_word.imbue(std::locale::classic()); - process_one_word << " var input_offset = " << a.IndicesToOffset("input_a_indices_t(batch, row, word_offset)") << ";\n" - << " var a_data: " << quantized_data_type << ";\n" - << " for (var j: u32 = 0; j < " << (8 / a.NumComponents()) << "; j++) {\n" - << " if (word_offset + j < uniforms.input_a_shape[2]) {\n" - << " a_data[j] = " << a.GetByOffset("input_offset") << ";\n" - << " input_offset++;\n" - << " } else {\n" - << " a_data[j] = input_a_value_t(0);\n" - << " }\n" - << " }\n"; + // process one word + shader.MainFunctionBody() << " var input_offset = " << a.IndicesToOffset("input_a_indices_t(batch, row, word_offset)") << ";\n" + << " var a_data: " << quantized_data_type << ";\n" + << " for (var j: u32 = 0; j < " << (8 / a.NumComponents()) << "; j++) {\n" + << " if (word_offset + j < uniforms.input_a_shape[2]) {\n" + << " a_data[j] = " << a.GetByOffset("input_offset") << ";\n" + << " input_offset++;\n" + " } else {\n" + " a_data[j] = input_a_value_t(0);\n" + " }\n" + " }\n"; for (int c = 0; c < output_element_number; c++) { - process_one_word << " b_value = " << "b" << c << "_data"; + shader.MainFunctionBody() << " b_value = b" << c << "_data"; if (components_b_ > 1) { - process_one_word << "[i]"; + shader.MainFunctionBody() << "[i]"; } - process_one_word << ";\n" - << " b_value_lower = unpack4xU8(b_value & b_mask);\n" - << " b_value_upper = unpack4xU8((b_value >> 4) & b_mask);\n" - << " b_quantized_values = " << quantized_data_type << "(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - << " b_dequantized_values = "; + shader.MainFunctionBody() << ";\n" + " b_value_lower = unpack4xU8(b_value & b_mask);\n" + " b_value_upper = unpack4xU8((b_value >> 4) & b_mask);\n" + << " b_quantized_values = " << quantized_data_type << "(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" + << " b_dequantized_values = "; if (a.NumComponents() == 1) { if (has_zero_points_) { - process_one_word << quantized_data_type << "((b_quantized_values[0] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[1] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[2] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[3] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[4] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[5] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[6] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[7] - zero_point" << c << ") * scale" << c << ");\n"; + shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[1] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[2] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[3] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[4] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[5] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[6] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[7] - zero_point" << c << ") * scale" << c << ");\n"; } else { - process_one_word << quantized_data_type << "((b_quantized_values[0] - zero_point) * scale" << c << ", " - << "(b_quantized_values[1] - zero_point) * scale" << c << "," - << "(b_quantized_values[2] - zero_point) * scale" << c << "," - << "(b_quantized_values[3] - zero_point) * scale" << c << "," - << "(b_quantized_values[4] - zero_point) * scale" << c << "," - << "(b_quantized_values[5] - zero_point) * scale" << c << "," - << "(b_quantized_values[6] - zero_point) * scale" << c << "," - << "(b_quantized_values[7] - zero_point) * scale" << c << ");\n"; + shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point) * scale" << c << ", " + << "(b_quantized_values[1] - zero_point) * scale" << c << "," + << "(b_quantized_values[2] - zero_point) * scale" << c << "," + << "(b_quantized_values[3] - zero_point) * scale" << c << "," + << "(b_quantized_values[4] - zero_point) * scale" << c << "," + << "(b_quantized_values[5] - zero_point) * scale" << c << "," + << "(b_quantized_values[6] - zero_point) * scale" << c << "," + << "(b_quantized_values[7] - zero_point) * scale" << c << ");\n"; } } else { - process_one_word << "(b_quantized_values - " << quantized_data_type << "("; + shader.MainFunctionBody() << "(b_quantized_values - " << quantized_data_type << "("; for (int i = 0; i < 8; i++) { if (has_zero_points_) { - process_one_word << "zero_point" << c; + shader.MainFunctionBody() << "zero_point" << c; } else { - process_one_word << "zero_point"; + shader.MainFunctionBody() << "zero_point"; } if (i < 7) { - process_one_word << ", "; + shader.MainFunctionBody() << ", "; } } - process_one_word << ")) * scale" << c << ";\n"; + shader.MainFunctionBody() << ")) * scale" << c << ";\n"; } - process_one_word << " workgroup_shared[local_id.x * " << output_number_ << " + " << c / y.NumComponents() << "]"; + shader.MainFunctionBody() << " workgroup_shared[local_id.x * " << output_number_ << " + " << c / y.NumComponents() << "]"; if (y.NumComponents() > 1) { - process_one_word << "[" << c % y.NumComponents() << "]"; + shader.MainFunctionBody() << "[" << c % y.NumComponents() << "]"; } - process_one_word << " += "; + shader.MainFunctionBody() << " += "; if (a.NumComponents() == 1) { - process_one_word << "a_data[0] * b_dequantized_values[0] + " - << "a_data[1] * b_dequantized_values[1] + " - << "a_data[2] * b_dequantized_values[2] + " - << "a_data[3] * b_dequantized_values[3] + " - << "a_data[4] * b_dequantized_values[4] + " - << "a_data[5] * b_dequantized_values[5] + " - << "a_data[6] * b_dequantized_values[6] + " - << "a_data[7] * b_dequantized_values[7];\n"; + shader.MainFunctionBody() << "a_data[0] * b_dequantized_values[0] + " + "a_data[1] * b_dequantized_values[1] + " + "a_data[2] * b_dequantized_values[2] + " + "a_data[3] * b_dequantized_values[3] + " + "a_data[4] * b_dequantized_values[4] + " + "a_data[5] * b_dequantized_values[5] + " + "a_data[6] * b_dequantized_values[6] + " + "a_data[7] * b_dequantized_values[7];\n"; } else if (a.NumComponents() == 2) { - process_one_word << "dot(a_data[0], b_dequantized_values[0]) + " - << "dot(a_data[1], b_dequantized_values[1]) + " - << "dot(a_data[2], b_dequantized_values[2]) + " - << "dot(a_data[3], b_dequantized_values[3]);\n"; + shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " + "dot(a_data[1], b_dequantized_values[1]) + " + "dot(a_data[2], b_dequantized_values[2]) + " + "dot(a_data[3], b_dequantized_values[3]);\n"; } else if (a.NumComponents() == 4) { - process_one_word << "dot(a_data[0], b_dequantized_values[0]) + " - << "dot(a_data[1], b_dequantized_values[1]);\n"; + shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " + "dot(a_data[1], b_dequantized_values[1]);\n"; } } - const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; - std::string offset = "workgroup_idx * " + std::to_string(output_number_); - shader.AppendImplementation("var workgroup_shared : array;\n"); - shader.SetMainFunctionBody(" let output_indices = ", y.OffsetToIndices(offset), - ";\n" - " let col = output_indices[2];\n" - " let row = output_indices[1];\n" - " let batch = output_indices[0];\n" - " let n_blocks_per_col = uniforms.input_b_shape[1];\n" - " let blob_size = uniforms.input_b_shape[2]" - ";\n" - " for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size_x) {\n" - " var word_offset = block * uniforms.block_size / ", - a.NumComponents(), ";\n", - prepare_scale_and_zero_point.str(), - " for (var word: u32 = 0; word < blob_size; word += 1) {\n", - prepare_b_data.str(), - " for (var i: u32 = 0; i < ", components_b_, "; i++) {\n", - process_one_word.str(), - " word_offset += ", 8 / a.NumComponents(), - ";\n" - " }\n" - " }\n" - " }\n" - " workgroupBarrier();\n" - " if (local_id.x < ", - output_number_, - ") {\n" - " var output_value = output_value_t(0);\n" - " var workgroup_shared_offset = local_id.x;\n" - " let blocks_num = min(", - shared_memory_size, - ", n_blocks_per_col);\n" - " for (var b = 0u; b < blocks_num; b++) {\n" - " output_value += workgroup_shared[workgroup_shared_offset];\n" - " workgroup_shared_offset += ", - output_number_, - ";\n" - " }\n", - " ", - y.SetByIndices("output_indices_t(batch, row, col + local_id.x)", "output_value"), - "\n" - " }\n"); + shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" + << " }\n" + " }\n" + " }\n" + " workgroupBarrier();\n" + << " if (local_id.x < " << output_number_ << ") {\n" + << " var output_value = output_value_t(0);\n" + " var workgroup_shared_offset = local_id.x;\n" + << " let blocks_num = min(" << shared_memory_size << ", n_blocks_per_col);\n" + << " for (var b = 0u; b < blocks_num; b++) {\n" + " output_value += workgroup_shared[workgroup_shared_offset];\n" + << " workgroup_shared_offset += " << output_number_ << ";\n" + << " }\n" + << " " << y.SetByIndices("output_indices_t(batch, row, col + local_id.x)", "output_value") << "\n" + << " }\n"; return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/docs/Conventions.md b/onnxruntime/core/providers/webgpu/docs/Conventions.md index 1a86e508cdda..fecccc76a4db 100644 --- a/onnxruntime/core/providers/webgpu/docs/Conventions.md +++ b/onnxruntime/core/providers/webgpu/docs/Conventions.md @@ -10,20 +10,27 @@ Let's keep it "webgpu" for this folder for now. I have a very good reason to do And anyway, it's not hard to change it back to "wgpu" if we want to. (but it's harder to change it from "wgpu" to "webgpu") -### Use macros defined in shader_macros.h +### Use `OStringStream` defined in string_utils.h and macros defined in string_macros.h -Take `SS` as example. It's a macro defined in `shader_macros.h` and it's used to concatenate strings. It's just make the `std::ostream::operator<<` to be used in a function call style. +Type `onnxruntime::webgpu::OStringStream` is a type alias of Abseil's OStringStream. It's a lightweight implementation +of `std::ostream`. It's recommended to use `OStringStream` instead of `std::ostringstream` in the code base. -I prefer to use the macro because I feel like it's easier to read. Check the following code: +The macros defined in `string_macros.h` are used to make coding easier: ```cpp -ss << "vec4(" << type << ">(" << value1 << ", " << value2 << ", " << value3 << ", " << value4 << ")"; -``` +std::string MyFunction() { + SS(code /* name of the string stream */, 2048 /* initial capacity */); -vs. + code << "var my_var = "; -```cpp -SS("vec4<", type, ">(", value1, ", ", value2, ", ", value3, ", ", value4, ")"); + // function call style string append. equivalent to: + // + // code << "vec4(" << type << ">(" << value1 << ", " << value2 << ", " << value3 << ", " << value4 << ")"; + // + SS_APPEND(code, "vec4(", type, ">(", value1, ", ", value2, ", ", value3, ", ", value4, ")"); + + return SS_GET(code); // return the string +} ``` ### Use the subfolder for kernel implementation diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index bae7c6a73c4c..6077ef049906 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -12,15 +12,27 @@ Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - std::string common; - std::string get_a_data = is_lhs_scalar_ ? "let a = input_a_value_t(" + a.GetByOffset("0") + ".x" + ");\n" - : "let a = " + a.GetByOffset("global_idx") + ";\n"; - std::string get_b_data = is_rhs_scalar_ ? "let b = input_b_value_t(" + b.GetByOffset("0") + ".x" + ");\n" - : "let b = " + b.GetByOffset("global_idx") + ";\n"; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); + // check whether can use element-wise mode. // If either A or B is scalar, or A and B have the same shape, element-wise mode can be used. // In element-wise mode, no indices calculation is needed. - if (!is_lhs_scalar_ && !is_rhs_scalar_ && is_broadcast_) { + if (is_lhs_scalar_ || is_rhs_scalar_ || !is_broadcast_) { + // get A data + if (is_lhs_scalar_) { + shader.MainFunctionBody() << "let a = input_a_value_t(" << a.GetByOffset("0") << ".x);\n"; + } else { + shader.MainFunctionBody() << "let a = " << a.GetByOffset("global_idx") << ";\n"; + } + + // get B data + if (is_rhs_scalar_) { + shader.MainFunctionBody() << "let b = input_b_value_t(" << b.GetByOffset("0") << ".x);\n"; + } else { + shader.MainFunctionBody() << "let b = " << b.GetByOffset("global_idx") << ";\n"; + } + } else { const auto& c_indices = shader.AddIndices("bcast_indices"); // check whether can use vectorize mode. // If either last dimension of A or B is divisible by 4, or the shared dimension is divisible by 4, vectorize mode @@ -30,67 +42,54 @@ Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const if (vectorize_) { const auto& a_indices = shader.AddIndices("a_indices"); const auto& b_indices = shader.AddIndices("b_indices"); - common = "let outputIndices = " + c_indices.OffsetToIndices("global_idx * 4") + - ";\n" - "let offset_a = " + - a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "let offset_b = " + - b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) + ";\n"; - get_a_data = a.NumComponents() == 4 ? "let a = " + a.GetByOffset("offset_a / 4") + ";\n" - : "let a = input_b_value_t(" + a.GetByOffset("offset_a") + ");\n"; - get_b_data = b.NumComponents() == 4 ? "let b = " + b.GetByOffset("offset_b / 4") + ";\n" - : "let b = input_a_value_t(" + b.GetByOffset("offset_b") + ");\n"; + + shader.MainFunctionBody() << "let outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") << ";\n" + << "let offset_a = " << a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b = " << b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; + // get A data + if (a.NumComponents() == 4) { + shader.MainFunctionBody() << "let a = " << a.GetByOffset("offset_a / 4") << ";\n"; + } else { + shader.MainFunctionBody() << "let a = input_a_value_t(" << a.GetByOffset("offset_a") << ");\n"; + } + + // get B data + if (b.NumComponents() == 4) { + shader.MainFunctionBody() << "let b = " << b.GetByOffset("offset_b / 4") << ";\n"; + } else { + shader.MainFunctionBody() << "let b = input_b_value_t(" << b.GetByOffset("offset_b") << ");\n"; + } } else { // In broadcast mode, each element of the vec4 value of A and B will be loaded separately to calculate the output value. - common = "var outputIndices = " + c_indices.OffsetToIndices("global_idx * 4") + - ";\n" - "let offset_a0 = " + - a.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "let offset_b0 = " + - b.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "outputIndices = " + - c_indices.OffsetToIndices("global_idx * 4 + 1") + - ";\n" - "let offset_a1 = " + - a.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "let offset_b1 = " + - b.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "outputIndices = " + - c_indices.OffsetToIndices("global_idx * 4 + 2") + - ";\n" - "let offset_a2 = " + - a.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "let offset_b2 = " + - b.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "outputIndices = " + - c_indices.OffsetToIndices("global_idx * 4 + 3") + - ";\n" - "let offset_a3 = " + - a.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "let offset_b3 = " + - b.BroadcastedIndicesToOffset("outputIndices", c_indices) + ";\n"; - get_a_data = "let a = vec4(" + a.GetByOffset("offset_a0") + ", " + - a.GetByOffset("offset_a1") + ", " + - a.GetByOffset("offset_a2") + ", " + - a.GetByOffset("offset_a3") + ");\n"; - get_b_data = "let b = vec4(" + b.GetByOffset("offset_b0") + ", " + - b.GetByOffset("offset_b1") + ", " + - b.GetByOffset("offset_b2") + ", " + - b.GetByOffset("offset_b3") + ");\n"; + shader.MainFunctionBody() << "var outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") << ";\n" + << "let offset_a0 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b0 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 1") << ";\n" + << "let offset_a1 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b1 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 2") << ";\n" + << "let offset_a2 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b2 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 3") << ";\n" + << "let offset_a3 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b3 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; + + // get A data + shader.MainFunctionBody() << "let a = vec4(" + << a.GetByOffset("offset_a0") << ", " + << a.GetByOffset("offset_a1") << ", " + << a.GetByOffset("offset_a2") << ", " + << a.GetByOffset("offset_a3") << ");\n"; + // get B data + shader.MainFunctionBody() << "let b = vec4(" + << b.GetByOffset("offset_b0") << ", " + << b.GetByOffset("offset_b1") << ", " + << b.GetByOffset("offset_b2") << ", " + << b.GetByOffset("offset_b3") << ");\n"; } } - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - common, get_a_data, get_b_data, - c.SetByOffset("global_idx", expression_)); + shader.MainFunctionBody() << c.SetByOffset("global_idx", expression_); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 9e8117aa34a9..f6d6b18a3d36 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -12,10 +12,10 @@ namespace webgpu { Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | additional_usage_); const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform); - shader.AppendImplementation(additional_impl_); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - " let a = ", input.GetByOffset("global_idx"), ";\n ", - output.SetByOffset("global_idx", expression_)); + shader.AdditionalImplementation() << additional_impl_; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " let a = " << input.GetByOffset("global_idx") << ";\n " + << output.SetByOffset("global_idx", expression_); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 25c0a4278be8..d1d4c242c469 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -257,7 +257,7 @@ ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dep use_override_shape{true}, override_shape{override_shape} {} -ProgramBase::ProgramBase(const std::string& name, ProgramMetadata&& metadata) +ProgramBase::ProgramBase(std::string_view name, ProgramMetadata&& metadata) : name_{name}, metadata_{metadata}, dispatch_group_size_x_{0}, diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index bd9a26b0fcfb..45fa11715a9e 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -246,9 +246,9 @@ class ProgramBase { // // set the cache hint for the program - template - ProgramBase& CacheHint(T&& hint) { - cache_hint_ = std::forward(hint); + template + ProgramBase& CacheHint(T&&... hints) { + cache_hint_ = absl::StrJoin(std::forward_as_tuple(std::forward(hints)...), "|"); return *this; } @@ -327,7 +327,7 @@ class ProgramBase { private: // Make the constructor private to prevent direct instantiation or inheritance from this class // Use the Program template class as base class to create a new program class - explicit ProgramBase(const std::string& name, ProgramMetadata&& metadata); + explicit ProgramBase(std::string_view name, ProgramMetadata&& metadata); std::string name_; ProgramMetadata metadata_; diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index 6c7ef2bc89c6..a5c21563dbfc 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -3,14 +3,21 @@ #include "core/providers/webgpu/program_cache_key.h" -#include "core/providers/webgpu/shader_macros.h" +#include "core/providers/webgpu/string_macros.h" namespace onnxruntime { namespace webgpu { +// macro "D" - append to the ostream only in debug build +#ifndef NDEBUG // if debug build +#define D(str) << str +#else +#define D(str) +#endif + namespace { // append the info of an input or output to the cachekey -void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, +void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, bool& first) { if (first) { first = false; @@ -36,8 +43,7 @@ void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVaria } // namespace std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch) { - std::ostringstream ss; - ss.imbue(std::locale::classic()); + SS(ss, kStringInitialSizeCacheKey); // final key format: // =[]:::: @@ -100,7 +106,7 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp AppendTensorInfo(ss, *output.tensor, output.var_type, output.dependency, first); } - return ss.str(); + return SS_GET(ss); } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index a88687fce18b..d722bcb07cdb 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -10,6 +10,8 @@ #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/string_utils.h" +#include "core/providers/webgpu/string_macros.h" namespace onnxruntime { namespace webgpu { @@ -27,7 +29,9 @@ ShaderHelper::ShaderHelper(const ProgramBase& program, dispatch_group_size_y_{dispatch_group_size_y}, dispatch_group_size_z_{dispatch_group_size_z}, program_{program}, - program_metadata_{program_metadata} {} + program_metadata_{program_metadata}, + additional_implementation_ss_{&additional_implementation_}, + body_ss_{&body_} {} Status ShaderHelper::Init() { // dispatch group size is normalized so no need to validate it here @@ -50,31 +54,29 @@ Status ShaderHelper::Init() { // init body string stream bool is_1d_dispatch = dispatch_group_size_y_ == 1 && dispatch_group_size_z_ == 1; - body_.imbue(std::locale::classic()); + body_.reserve(4096); + additional_implementation_.reserve(1024); // append header for main function so it is ready for user to append main function body - body_ << "@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)\n" - "fn main(@builtin(global_invocation_id) global_id : vec3,\n" - " @builtin(workgroup_id) workgroup_id : vec3,\n" - " @builtin(local_invocation_id) local_id : vec3"; + body_ss_ << "@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)\n" + "fn main(@builtin(global_invocation_id) global_id : vec3,\n" + " @builtin(workgroup_id) workgroup_id : vec3,\n" + " @builtin(local_invocation_id) local_id : vec3"; if (!is_1d_dispatch) { - body_ << ",\n" - " @builtin(local_invocation_index) local_idx : u32,\n" - " @builtin(num_workgroups) num_workgroups : vec3"; + body_ss_ << ",\n" + " @builtin(local_invocation_index) local_idx : u32,\n" + " @builtin(num_workgroups) num_workgroups : vec3"; } - body_ << ") {\n"; + body_ss_ << ") {\n"; if (is_1d_dispatch) { - body_ << " let global_idx = global_id.x;\n" - " let local_idx = local_id.x;\n" - " let workgroup_idx = workgroup_id.x;\n"; + body_ss_ << " let global_idx = global_id.x;\n" + " let local_idx = local_id.x;\n" + " let workgroup_idx = workgroup_id.x;\n"; } else { - body_ << " let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;\n" - " let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; + body_ss_ << " let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;\n" + " let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; } - // init additional implementation string stream - additional_implementation_.imbue(std::locale::classic()); - return Status::OK(); } @@ -322,8 +324,7 @@ Status ShaderHelper::ValidateIndices() const { } Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const { - std::ostringstream ss; - ss.imbue(std::locale::classic()); + SS(ss, kStringInitialSizeShaderSourceCode); // // Section feature enabling @@ -513,16 +514,16 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha // // Additional Implementation // - ss << additional_implementation_.str(); + ss << additional_implementation_; // // Main Function Body // - ss << body_.str(); + ss << body_; ss << "\n" "}\n"; - code = ss.str(); + code = SS_GET(ss); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index bdc14669cfb5..5e60c1293ace 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -16,6 +16,7 @@ #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/string_utils.h" namespace onnxruntime { namespace webgpu { @@ -92,23 +93,14 @@ class ShaderHelper final { // Add an indices variable to the shader. const ShaderIndicesHelper& AddIndices(const std::string& name, bool use_uniform = true); - // Append additional implementation code to the shader. - // - // can be called multiple times. - template - inline ShaderHelper& AppendImplementation(Strs&&... impl) { - onnxruntime::detail::MakeStringImpl(additional_implementation_, std::forward(impl)...); - return *this; + // Get the string stream for additional implementation code to the shader. + inline OStringStream& AdditionalImplementation() { + return additional_implementation_ss_; } - // Set the main function body of the shader. - // - // can be called only once. - template - inline void SetMainFunctionBody(const Strs&... body) { - ORT_ENFORCE(!body_set_, "Main function body is already set"); - onnxruntime::detail::MakeStringImpl(body_, std::forward>(body)...); - body_set_ = true; + // Get the string stream for the main function body of the shader. + inline OStringStream& MainFunctionBody() { + return body_ss_; } std::string GuardAgainstOutOfBoundsWorkgroupSizes(std::string_view size) const { @@ -117,7 +109,7 @@ class ShaderHelper final { private: template // ConstantType is one of {ProgramConstant, ProgramOverridableConstantValue, ProgramOverridableConstantDefinition} - void WriteConstantValue(std::ostringstream& ss, const ConstantType& constant) const { + void WriteConstantValue(std::ostream& ss, const ConstantType& constant) const { switch (constant.type) { case ProgramConstantDataType::Float16: ss << constant.f16.ToFloat(); @@ -179,10 +171,10 @@ class ShaderHelper final { std::vector> input_vars_; std::vector> output_vars_; std::vector> indices_vars_; - std::ostringstream additional_implementation_; - std::ostringstream body_; - - bool body_set_ = false; + std::string additional_implementation_; + OStringStream additional_implementation_ss_; + std::string body_; + OStringStream body_ss_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/shader_macros.h b/onnxruntime/core/providers/webgpu/shader_macros.h deleted file mode 100644 index a1c61950e6a1..000000000000 --- a/onnxruntime/core/providers/webgpu/shader_macros.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -// macro "D": append to the ostream only in debug build -// -// Usage example: -// -// ss << "error code: " << err_code D(" (") << D(err_msg) D(")"); -// -// This resolves to: (debug build) -// ss << "error code: " << err_code << " (" << err_msg << ")"; -// -// This resolves to: (release build) -// ss << "error code: " << err_code; - -#ifdef D -#undef D -#endif - -#ifndef NDEBUG // if debug build -#define D(str) << str -#else -#define D(str) -#endif - -// macro "DSS" append to the ostream only in debug build -// (assume variable "ss" is in scope) -// -// Usage example: -// -// DSS << "detail error message: " << err_msg; -// -// This resolves to: (debug build) -// ss << "detail error message: " << err_msg; -// -// This resolves to: (release build) -// if constexpr (false) ss << "detail error message: " << err_msg; // no-op - -#ifdef DSS -#undef DSS -#endif - -#ifndef NDEBUG // if debug build -#define DSS ss -#else -#define DSS \ - if constexpr (false) ss -#endif - -// macro "SS" - use function call style to append to the ostream -// (assume variable "ss" is in scope) -// -// Usage example: -// -// SS("error code: ", err_code, " (", err_msg, ")"); -// -// This resolves to: -// ss << "error code: " << err_code << " (" << err_msg << ")"; - -#ifdef SS -#undef SS -#endif - -#define SS(...) ::onnxruntime::detail::MakeStringImpl(ss, __VA_ARGS__) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index cbc39c86e504..e60a06800851 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -8,7 +8,7 @@ #include "core/common/safeint.h" #include "core/providers/webgpu/shader_variable.h" -#include "core/providers/webgpu/shader_macros.h" +#include "core/providers/webgpu/string_macros.h" namespace onnxruntime { namespace webgpu { @@ -103,7 +103,7 @@ ShaderVariableHelper::ShaderVariableHelper(std::string_view name, ProgramVariabl ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_); } -void ShaderIndicesHelper::Impl(std::ostringstream& ss) const { +void ShaderIndicesHelper::Impl(std::ostream& ss) const { // Start generating code const std::string shape = (usage_ & ShaderUsage::UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; @@ -111,18 +111,18 @@ void ShaderIndicesHelper::Impl(std::ostringstream& ss) const { // Types if (usage_ & ShaderUsage::UseValueTypeAlias) { - SS("alias ", value_type_alias_, " = ", VALUE_TYPE[static_cast(type_)], ";\n"); + SS_APPEND(ss, "alias ", value_type_alias_, " = ", VALUE_TYPE[static_cast(type_)], ";\n"); } if (usage_ & ShaderUsage::UseIndicesTypeAlias) { - SS("alias ", indices_type_alias_, " = ", indices_type_, ";\n"); + SS_APPEND(ss, "alias ", indices_type_alias_, " = ", indices_type_, ";\n"); } if (usage_ & ShaderUsage::UseElementTypeAlias) { - SS("alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); + SS_APPEND(ss, "alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); } // Need shape and strides when (not use uniform) and (use shape and stride is enabled) if (!(usage_ & ShaderUsage::UseUniform) && (usage_ & ShaderUsage::UseShapeAndStride) && rank_ > 0) { - SS("const ", shape, " = ", IndicesType(), "("); + SS_APPEND(ss, "const ", shape, " = ", IndicesType(), "("); bool first = true; for (auto dim : dims_.GetDims()) { @@ -136,7 +136,7 @@ void ShaderIndicesHelper::Impl(std::ostringstream& ss) const { ss << ");\n"; if (rank_ > 1) { - SS("const ", stride, " = ", GetIndicesType(rank_ - 1), "("); + SS_APPEND(ss, "const ", stride, " = ", GetIndicesType(rank_ - 1), "("); first = true; for (int i = 1; i < rank_; i++) { if (!first) { @@ -152,32 +152,32 @@ void ShaderIndicesHelper::Impl(std::ostringstream& ss) const { // Implementation of "fn o2i_{name}" if (usage_ & ShaderUsage::UseOffsetToIndices) { if (rank_ >= 2) { - SS("fn o2i_", name_, "(offset : u32)->", IndicesType(), " {\n"); - SS(" var indices: ", IndicesType(), ";\n"); - SS(" var current = offset;\n"); + SS_APPEND(ss, "fn o2i_", name_, "(offset : u32)->", IndicesType(), " {\n"); + SS_APPEND(ss, " var indices: ", IndicesType(), ";\n"); + SS_APPEND(ss, " var current = offset;\n"); for (int i = 0; i < rank_ - 1; i++) { auto current_stride = GetElementAt(stride, i, rank_ - 1); - SS(" let dim", i, " = current / ", current_stride, ";\n"); - SS(" let rest", i, " = current % ", current_stride, ";\n"); - SS(" indices[", i, "] = dim", i, ";\n"); - SS(" current = rest", i, ";\n"); + SS_APPEND(ss, " let dim", i, " = current / ", current_stride, ";\n"); + SS_APPEND(ss, " let rest", i, " = current % ", current_stride, ";\n"); + SS_APPEND(ss, " indices[", i, "] = dim", i, ";\n"); + SS_APPEND(ss, " current = rest", i, ";\n"); } - SS(" indices[", rank_ - 1, "] = current;\n"); - SS(" return indices;\n"); - SS("}\n"); + SS_APPEND(ss, " indices[", rank_ - 1, "] = current;\n"); + SS_APPEND(ss, " return indices;\n"); + SS_APPEND(ss, "}\n"); } } // Implementation of "fn i2o_{name}" if (usage_ & ShaderUsage::UseIndicesToOffset) { if (rank_ >= 2) { - SS("fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); - SS(" return "); + SS_APPEND(ss, "fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); + SS_APPEND(ss, " return "); for (int i = 0; i < rank_ - 1; i++) { - SS("indices[", i, "] * ", GetElementAt(stride, i, rank_ - 1), " + "); + SS_APPEND(ss, "indices[", i, "] * ", GetElementAt(stride, i, rank_ - 1), " + "); } - SS("indices[", rank_ - 1, "];\n"); - SS("}\n"); + SS_APPEND(ss, "indices[", rank_ - 1, "];\n"); + SS_APPEND(ss, "}\n"); } } @@ -186,83 +186,82 @@ void ShaderIndicesHelper::Impl(std::ostringstream& ss) const { if (rank_ > 0) { for (const auto& broadcasted_result_ptr : broadcasted_to_) { const auto& broadcasted_result = *broadcasted_result_ptr; - SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); + SS_APPEND(ss, "fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); if (rank_ == 1) { - SS(" return ", broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", shape, ";\n"); + SS_APPEND(ss, " return ", broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", shape, ";\n"); } else { - SS(" return "); + SS_APPEND(ss, " return "); for (int i = 0; i < rank_ - 1; i++) { auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); std::string current_stride = rank_ == 2 ? stride : GetElementAt(stride, i, rank_ - 1); - SS(current_stride, " * (", idx, " % ", IndicesGet(shape, i), ") + "); + SS_APPEND(ss, current_stride, " * (", idx, " % ", IndicesGet(shape, i), ") + "); } - SS(broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); + SS_APPEND(ss, broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); } - SS("}\n"); + SS_APPEND(ss, "}\n"); } } } } -void ShaderVariableHelper::Impl(std::ostringstream& ss) const { +void ShaderVariableHelper::Impl(std::ostream& ss) const { ShaderIndicesHelper::Impl(ss); // Implementation of "fn set_{name}" if (usage_ & ShaderUsage::UseSet) { if (rank_ >= 2) { - SS("fn set_", name_, "(d0: u32"); + SS_APPEND(ss, "fn set_", name_, "(d0: u32"); for (int i = 1; i < rank_; i++) { - SS(", d", i, ": u32"); + SS_APPEND(ss, ", d", i, ": u32"); } - SS(", value: ", ValueType(), ") {\n"); - SS(" set_", name_, "_by_indices(d0"); + SS_APPEND(ss, ", value: ", ValueType(), ") {\n"); + SS_APPEND(ss, " set_", name_, "_by_indices(d0"); for (int i = 1; i < rank_; i++) { - SS(", d", i); + SS_APPEND(ss, ", d", i); } - SS(", value);\n"); - SS("}\n"); + SS_APPEND(ss, ", value);\n"); + SS_APPEND(ss, "}\n"); } } // Implementation of "fn set_{name}_by_indices" if (usage_ & ShaderUsage::UseSetByIndices) { if (rank_ >= 2) { - SS("fn set_", name_, "_by_indices(indices: ", IndicesType(), ", value: ", ValueType(), ") {\n"); - SS(" ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); - SS("}\n"); + SS_APPEND(ss, "fn set_", name_, "_by_indices(indices: ", IndicesType(), ", value: ", ValueType(), ") {\n"); + SS_APPEND(ss, " ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); + SS_APPEND(ss, "}\n"); } } // Implementation of "fn get_{name}" if (usage_ & ShaderUsage::UseGet) { if (rank_ >= 2) { - SS("fn get_", name_, "(d0: u32"); + SS_APPEND(ss, "fn get_", name_, "(d0: u32"); for (int i = 1; i < rank_; i++) { - SS(", d", i, ": u32"); + SS_APPEND(ss, ", d", i, ": u32"); } - SS(")->", ValueType(), " {\n"); - SS(" return get_", name_, "_by_indices(d0"); + SS_APPEND(ss, ")->", ValueType(), " {\n"); + SS_APPEND(ss, " return get_", name_, "_by_indices(d0"); for (int i = 1; i < rank_; i++) { - SS(", d", i); + SS_APPEND(ss, ", d", i); } - SS(");\n"); - SS("}\n"); + SS_APPEND(ss, ");\n"); + SS_APPEND(ss, "}\n"); } } // Implementation of "fn get_{name}_by_indices" if (usage_ & ShaderUsage::UseGetByIndices) { if (rank_ >= 2) { - SS("fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n"); - SS(" return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); - SS("}\n"); + SS_APPEND(ss, "fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n"); + SS_APPEND(ss, " return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); + SS_APPEND(ss, "}\n"); } } } std::string ShaderVariableHelper::GetByOffsetImpl(std::string_view offset) const { - std::ostringstream ss; - ss.imbue(std::locale::classic()); + SS(ss, kStringInitialSizeGetByOffsetImpl); switch (type_) { case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: @@ -283,12 +282,11 @@ std::string ShaderVariableHelper::GetByOffsetImpl(std::string_view offset) const ss << name_ << "[" << offset << "]"; } - return ss.str(); + return SS_GET(ss); } std::string ShaderVariableHelper::SetByOffsetImpl(std::string_view offset, std::string_view value) const { - std::ostringstream ss; - ss.imbue(std::locale::classic()); + SS(ss, kStringInitialSizeSetByOffsetImpl); switch (type_) { case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: @@ -307,7 +305,7 @@ std::string ShaderVariableHelper::SetByOffsetImpl(std::string_view offset, std:: ss << name_ << "[" << offset << "]=" << value << ";"; } - return ss.str(); + return SS_GET(ss); } std::string_view ShaderVariableHelper::StorageType() const { diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 72f38aecb99c..cad7b0ceb830 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -110,7 +110,7 @@ class ShaderIndicesHelper { protected: ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderIndicesHelper); - void Impl(std::ostringstream& ss) const; + void Impl(std::ostream& ss) const; std::string_view IndicesType() const; @@ -175,7 +175,7 @@ class ShaderVariableHelper : public ShaderIndicesHelper { private: ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper); - void Impl(std::ostringstream& ss) const; + void Impl(std::ostream& ss) const; std::string GetByOffsetImpl(std::string_view offset) const; std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const; diff --git a/onnxruntime/core/providers/webgpu/string_macros.h b/onnxruntime/core/providers/webgpu/string_macros.h new file mode 100644 index 000000000000..7821d9c49a17 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/string_macros.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/string_utils.h" + +// macro "SS" - declare an ostream variable and its string buffer +#define SS(ss, reserve_size) \ + std::string ss##_str; \ + ss##_str.reserve(reserve_size); \ + ::onnxruntime::webgpu::OStringStream ss(&ss##_str) + +// macro "SS_GET" - get the string from the ostream +#define SS_GET(ss) ss##_str + +// macro "SS_APPEND" - use function call style to append to the ostream +#define SS_APPEND(ss, ...) ::onnxruntime::webgpu::detail::OStringStreamAppend(ss, __VA_ARGS__) diff --git a/onnxruntime/core/providers/webgpu/string_utils.h b/onnxruntime/core/providers/webgpu/string_utils.h new file mode 100644 index 000000000000..e6d7097ad618 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/string_utils.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/make_string.h" +#include + +namespace onnxruntime { +namespace webgpu { + +constexpr const size_t kStringInitialSizeSetByOffsetImpl = 128; +constexpr const size_t kStringInitialSizeGetByOffsetImpl = 128; +constexpr const size_t kStringInitialSizeShaderSourceCode = 2048; +#ifndef NDEBUG +constexpr const size_t kStringInitialSizeCacheKey = 512; +#else +constexpr const size_t kStringInitialSizeCacheKey = 256; +#endif + +using OStringStream = absl::strings_internal::OStringStream; + +namespace detail { +inline void OStringStreamAppendImpl(std::ostream& /*ss*/) noexcept { +} + +template +inline void OStringStreamAppendImpl(std::ostream& ss, const T& t) noexcept { + ss << t; +} + +template +inline void OStringStreamAppendImpl(std::ostream& ss, const T& t, const Args&... args) noexcept { + OStringStreamAppendImpl(ss, t); + OStringStreamAppendImpl(ss, args...); +} + +template +inline void OStringStreamAppend(std::ostream& ss, const Args&... args) { + return OStringStreamAppendImpl(ss, ::onnxruntime::detail::if_char_array_make_ptr_t(args)...); +} + +} // namespace detail + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index 8d59570de996..06eae971309c 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -106,9 +106,9 @@ Status CastProgram::GenerateShaderCode(ShaderHelper& sh) const { default: ORT_NOT_IMPLEMENTED("Cast to type ", to_, " is not supported."); } - sh.SetMainFunctionBody(sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - " let a = ", input.GetByOffset("global_idx"), ";\n ", - output.SetByOffset("global_idx", expression)); + sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " let a = " << input.GetByOffset("global_idx") << ";\n " + << output.SetByOffset("global_idx", expression); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 671a6a1ed072..866f99b587bc 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -38,48 +38,33 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) WEBGPU_CONCAT_KERNEL(13) -const std::string AppendCalCulateInputIndexFunction(size_t input_count) { - std::ostringstream ss; - ss.imbue(std::locale::classic()); - ss << "fn calculate_input_index(index: u32) -> u32 {" << std::endl - << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {" << std::endl - << " if (index < uniforms.size_in_concat_axis[i]) {" << std::endl - << " return i;" << std::endl - << " }" << std::endl - << " }" << std::endl - << " return " << input_count << ";" << std::endl - << "}" << std::endl; - return ss.str(); +void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) { + os << "fn calculate_input_index(index: u32) -> u32 {\n" + << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" + << " if (index < uniforms.size_in_concat_axis[i]) {\n" + << " return i;\n" + << " }\n" + << " }\n" + << " return " << input_count << ";\n" + << "}\n"; } -const void AppendAssignOutput(std::ostringstream& ss, const ShaderVariableHelper& input, const ShaderVariableHelper& output) { - ss << output.SetByOffset("global_idx", input.GetByIndices("indices")) << ";" << std::endl; -} - -const std::string AppendAssignOutputDataFunction(gsl::span inputs, const ShaderVariableHelper& output) { - std::ostringstream ss; - size_t input_count = inputs.size(); - ss.imbue(std::locale::classic()); - ss << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {" << std::endl; - if (input_count == 0) { - AppendAssignOutput(ss, *inputs[0], output); - } else { - for (size_t i = 0; i < input_count; ++i) { - if (i == 0) { - ss << " if (input_index == 0u) {" << std::endl; - } else if (i == input_count - 1) { - ss << " } else {" << std::endl; - } else { - ss << " } else if (input_index == " << i << "u) {" << std::endl; - } - ss << " "; - AppendAssignOutput(ss, *inputs[i], output); +void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const ShaderVariableHelper& output) { + os << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n"; + for (size_t i = 0; i < inputs.size(); ++i) { + if (i == 0) { + os << " if (input_index == 0u) {\n"; + } else if (i == inputs.size() - 1) { + os << " } else {\n"; + } else { + os << " } else if (input_index == " << i << "u) {\n"; } - ss << " }" << std::endl; + os << " " << output.SetByOffset("global_idx", inputs[i]->GetByIndices("indices")) << ";\n"; } - ss << "}" << std::endl; - return ss.str(); + os << " }\n" + "}\n"; } + Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { size_t input_count = Inputs().size(); std::vector inputs; @@ -88,16 +73,20 @@ Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { inputs.push_back(&shader.AddInput("input_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias)); } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - shader.AppendImplementation(AppendCalCulateInputIndexFunction(input_count)); - shader.AppendImplementation(AppendAssignOutputDataFunction(gsl::make_span(inputs), output)); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), - " var indices = ", output.OffsetToIndices("global_idx"), ";\n", - " let indices_axis = ", output.IndicesGet("indices", axis_), ";\n", - " let input_index = calculate_input_index(indices_axis);\n", - " if (input_index != 0u) {\n", - " ", output.IndicesSet("indices", axis_, "indices_axis - uniforms.size_in_concat_axis[input_index - 1]"), ";\n", - " }\n", - " assign_output_data(global_idx, input_index, indices);\n"); + + // add implementation of fn calculate_input_index + AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count); + // add implementation of fn assign_output_data + AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " var indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " let indices_axis = " << output.IndicesGet("indices", axis_) << ";\n" + << " let input_index = calculate_input_index(indices_axis);\n" + " if (input_index != 0u) {\n" + << " " << output.IndicesSet("indices", axis_, "indices_axis - uniforms.size_in_concat_axis[input_index - 1]") << ";\n" + << " }\n" + " assign_output_data(global_idx, input_index, indices);\n"; return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index a10658365188..84cdb35d77f0 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -14,10 +14,10 @@ Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), - " let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", - " let input_offset = ", input.BroadcastedIndicesToOffset("output_indices", output), ";\n ", - output.SetByOffset("global_idx", input.GetByOffset("input_offset"))); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n " + << output.SetByOffset("global_idx", input.GetByOffset("input_offset")); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc index 31e0a9e88323..9d5875c4efb4 100644 --- a/onnxruntime/core/providers/webgpu/tensor/gather.cc +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -13,31 +13,28 @@ Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); - std::ostringstream calc_data_indices; - calc_data_indices.imbue(std::locale::classic()); - calc_data_indices << " var indices_indices = input_indices_indices_t(0);\n"; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " var indices_indices = input_indices_indices_t(0);\n"; for (int i = 0; i < indices.Rank(); i++) { - calc_data_indices << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis_ + i)) << ";\n"; + shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis_ + i)) << ";\n"; } - calc_data_indices << " var idx = " << indices.GetByIndices("indices_indices") << ";\n" - << " if (idx < 0) {\n" - << " idx = idx + input_indices_value_t(uniforms.data_shape[" << axis_ << "]);\n" - << " }\n" - << " var data_indices : data_indices_t;\n"; + shader.MainFunctionBody() << " var idx = " << indices.GetByIndices("indices_indices") << ";\n" + << " if (idx < 0) {\n" + << " idx = idx + input_indices_value_t(uniforms.data_shape[" << axis_ << "]);\n" + << " }\n" + << " var data_indices : data_indices_t;\n"; for (int i = 0, j = 0; i < data.Rank(); i++) { if (i == SafeInt(axis_)) { - calc_data_indices << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; + shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; j += indices.Rank(); } else { - calc_data_indices << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n"; + shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n"; j++; } } - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), - " let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", - calc_data_indices.str(), " ", - output.SetByOffset("global_idx", data.GetByIndices("data_indices"))); + shader.MainFunctionBody() << " " << output.SetByOffset("global_idx", data.GetByIndices("data_indices")); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/tile.cc b/onnxruntime/core/providers/webgpu/tensor/tile.cc index 2737b6dafea8..841c36724df3 100644 --- a/onnxruntime/core/providers/webgpu/tensor/tile.cc +++ b/onnxruntime/core/providers/webgpu/tensor/tile.cc @@ -30,22 +30,18 @@ Status TileProgram::GenerateShaderCode(ShaderHelper& shader) const { const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - std::ostringstream ss; - ss.imbue(std::locale::classic()); - - ss << "var input_indices: input_indices_t;\n"; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << "var input_indices: input_indices_t;\n"; for (auto i = 0; i < input.Rank(); i++) { - std::string input_dim_i = "input_dim_" + std::to_string(i); - std::string input_dim_value = "input_dim_" + std::to_string(i) + "_value"; - ss << "let " << input_dim_i << " = " << input.IndicesGet("uniforms.input_shape", i) << ";\n"; - ss << "let " << input_dim_value << " = " << output.IndicesGet("output_indices", i) << " % " << input_dim_i << ";\n"; - ss << input.IndicesSet("input_indices", i, input_dim_value) << ";\n"; + std::string input_dim_i = absl::StrCat("input_dim_", i); + std::string input_dim_value = absl::StrCat("input_dim_", i, "_value"); + shader.MainFunctionBody() << "let " << input_dim_i << " = " << input.IndicesGet("uniforms.input_shape", i) << ";\n" + << "let " << input_dim_value << " = " << output.IndicesGet("output_indices", i) << " % " << input_dim_i << ";\n" + << input.IndicesSet("input_indices", i, input_dim_value) << ";\n"; } - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), - "let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", - ss.str(), - output.SetByOffset("global_idx", input.GetByIndices("input_indices"))); + shader.MainFunctionBody() << output.SetByOffset("global_idx", input.GetByIndices("input_indices")); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index e0a0113e1322..adcee8b64fd8 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -47,19 +47,6 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", WebGpuSupportedNumberTypes()), Transpose); -const std::string AppendPermFunction(gsl::span perm) { - std::ostringstream ss; - ss.imbue(std::locale::classic()); - ss << "fn perm(i: output_indices_t)->a_indices_t {\n" - " var a: a_indices_t;\n"; - for (size_t i = 0; i < perm.size(); ++i) { - ss << " a[" << perm[i] << "] = i[" << i << "];\n"; - } - ss << " return a;\n" - "}\n"; - return ss.str(); -} - auto SqueezeShape(const gsl::span& shape, const gsl::span& adjusted_perm, InlinedVector& new_shape, InlinedVector& new_perm) { for (auto i = 0; i < shape.size(); ++i) { if (shape[i] != 1) { @@ -76,31 +63,36 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); if (use_shared_) { - shader.AppendImplementation("var tile : array, tile_size>;\n"); - shader.SetMainFunctionBody( - " let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n" - " let workgroup_id_x = workgroup_idx % stride;\n" - " let workgroup_id_y = workgroup_idx / stride;\n" - " let input_col = workgroup_id_y * tile_size + local_id.x;\n" - " let input_row = workgroup_id_x * tile_size + local_id.y;\n" - " if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n" - " tile[local_id.y][local_id.x] = " + - input.GetByIndices("a_indices_t(input_row, input_col)") + - ";\n" - " }\n" - " workgroupBarrier();\n" - " let output_col = workgroup_id_x * tile_size + local_id.x;\n" - " let output_row = workgroup_id_y * tile_size + local_id.y;\n" - " if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n " + - output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") + "\n }"); + shader.AdditionalImplementation() << "var tile : array, tile_size>;\n"; + shader.MainFunctionBody() << " let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n" + " let workgroup_id_x = workgroup_idx % stride;\n" + " let workgroup_id_y = workgroup_idx / stride;\n" + " let input_col = workgroup_id_y * tile_size + local_id.x;\n" + " let input_row = workgroup_id_x * tile_size + local_id.y;\n" + " if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n" + << " tile[local_id.y][local_id.x] = " << input.GetByIndices("a_indices_t(input_row, input_col)") << ";\n" + << " }\n" + " workgroupBarrier();\n" + " let output_col = workgroup_id_x * tile_size + local_id.x;\n" + " let output_row = workgroup_id_y * tile_size + local_id.y;\n" + " if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n" + << " " << output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") << "\n" + << " }"; } else { - shader.AppendImplementation(AppendPermFunction(this->perm_)); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), - " let indices = ", output.OffsetToIndices("global_idx"), - ";\n" - " let x_indices = perm(indices);\n", - " ", - output.SetByOffset("global_idx", input.GetByIndices("x_indices"))); + shader.AdditionalImplementation() << "fn perm(i: output_indices_t)->a_indices_t {\n" + " var a: a_indices_t;\n"; + for (size_t i = 0; i < perm_.size(); ++i) { + shader.AdditionalImplementation() << " a[" << perm_[i] << "] = i[" << i << "];\n"; + } + shader.AdditionalImplementation() << " return a;\n" + "}\n"; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " let indices = " << output.OffsetToIndices("global_idx") + << ";\n" + " let x_indices = perm(indices);\n" + " " + << output.SetByOffset("global_idx", input.GetByIndices("x_indices")); } return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc index 1d58538a7489..b37014eb05da 100644 --- a/onnxruntime/core/providers/webgpu/tensor/where.cc +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -59,12 +59,14 @@ Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& b_input = shader.AddInput("b_data", ShaderUsage::UseUniform); const auto& output = shader.AddOutput("output_data", ShaderUsage::UseUniform); - const auto expression = [](const std::string& a, const std::string& b, const std::string& c) -> auto { - return "select(" + b + ", " + a + ", " + c + ")"; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); + + const auto expression = [](std::string_view a, std::string_view b, std::string_view c) -> auto { + return absl::StrCat("select(", b, ", ", a, ", ", c, ")"); }; - std::string assignment; + if (!is_broadcast_) { - assignment = output.SetByOffset( + shader.MainFunctionBody() << output.SetByOffset( "global_idx", expression(a_input.GetByOffset("global_idx"), b_input.GetByOffset("global_idx"), c_input.GetByOffset("global_idx"))); @@ -75,47 +77,41 @@ Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& output_indices = shader.AddIndices("output_indices"); const auto single_assignment = - [&expression, &output_indices, &a_indices, &b_indices, &c_indices]( - const std::string& rest_str, const std::string& x, const std::string& type_cast = "") - -> auto { + [&expression, &shader, &output_indices, &a_indices, &b_indices, &c_indices]( + std::string_view rest_str, const std::string& x, std::string_view type_cast = "") + -> void { const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]"; const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]"; const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))"; - std::ostringstream ss; - ss.imbue(std::locale::classic()); - ss << "let output_indices" + x + " = " << output_indices.OffsetToIndices("global_idx * 4u + " + x + "u") << ";\n"; - ss << "let offset_a" + x + " = " + a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n"; - ss << "let offset_b" + x + " = " + b_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n"; - ss << "let offset_c" + x + " = " + c_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n"; - ss << "let index_a" + x + " = offset_a" + x + " / 4u;\n"; - ss << "let index_b" + x + " = offset_b" + x + " / 4u;\n"; - ss << "let index_c" + x + " = offset_c" + x + " / 4u;\n"; - ss << "let component_a" + x + " = offset_a" + x + " % 4u;\n"; - ss << "let component_b" + x + " = offset_b" + x + " % 4u;\n"; - ss << "let component_c" + x + " = offset_c" + x + " % 4u;\n"; - ss << rest_str + "[" + x + "] = " + type_cast + "(" + expression(a_expression, b_expression, c_expression) + ");\n"; - return ss.str(); + shader.MainFunctionBody() << "let output_indices" << x << " = " << output_indices.OffsetToIndices("global_idx * 4u + " + x + "u") << ";\n" + << "let offset_a" << x << " = " << a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let offset_b" << x << " = " << b_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let offset_c" << x << " = " << c_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let index_a" << x << " = offset_a" << x << " / 4u;\n" + << "let index_b" << x << " = offset_b" << x << " / 4u;\n" + << "let index_c" << x << " = offset_c" << x << " / 4u;\n" + << "let component_a" << x << " = offset_a" << x << " % 4u;\n" + << "let component_b" << x << " = offset_b" << x << " % 4u;\n" + << "let component_c" << x << " = offset_c" << x << " % 4u;\n" + << rest_str << "[" << x << "] = " << type_cast << "(" << expression(a_expression, b_expression, c_expression) << ");\n"; }; if (Outputs()[0].tensor->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_BOOL) { - assignment = - "var data = vec4(0); \n" + - single_assignment("data", "0", "u32") + - single_assignment("data", "1", "u32") + - single_assignment("data", "2", "u32") + - single_assignment("data", "3", "u32") + - "output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));\n"; + shader.MainFunctionBody() << "var data = vec4(0);\n"; + single_assignment("data", "0", "u32"); + single_assignment("data", "1", "u32"); + single_assignment("data", "2", "u32"); + single_assignment("data", "3", "u32"); + shader.MainFunctionBody() << "output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));\n"; } else { - assignment = - single_assignment("output_data[global_idx]", "0") + - single_assignment("output_data[global_idx]", "1") + - single_assignment("output_data[global_idx]", "2") + - single_assignment("output_data[global_idx]", "3"); + single_assignment("output_data[global_idx]", "0"); + single_assignment("output_data[global_idx]", "1"); + single_assignment("output_data[global_idx]", "2"); + single_assignment("output_data[global_idx]", "3"); } } - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - assignment); + return Status::OK(); } From b574f2c547ffb9737c8eea2985ee0f978cc57889 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 30 Sep 2024 02:28:53 -0700 Subject: [PATCH 105/114] Range --- .../core/providers/webgpu/generator/range.cc | 66 +++++++++++++++++++ .../core/providers/webgpu/generator/range.h | 31 +++++++++ .../webgpu/webgpu_execution_provider.cc | 6 +- 3 files changed, 101 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/generator/range.cc create mode 100644 onnxruntime/core/providers/webgpu/generator/range.h diff --git a/onnxruntime/core/providers/webgpu/generator/range.cc b/onnxruntime/core/providers/webgpu/generator/range.cc new file mode 100644 index 000000000000..f704dc4e2cf8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/generator/range.cc @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/generator/range.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +template +Status Range::ComputeInternal(ComputeContext& context) const { + T start = context.Input(0)->Data()[0]; + T limit = context.Input(1)->Data()[0]; + T delta = context.Input(2)->Data()[0]; + + int64_t n = static_cast(ceil((1.0 * (limit - start)) / delta)); + if (n <= 0) { + n = 0; + } + auto* output_tensor = context.Output(0, TensorShape{n}); + if (n == 0) { + return Status::OK(); + } + + uint32_t output_size = SafeInt(n); + RangeProgram program{}; + program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Type}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + output_size, + *reinterpret_cast(&start), + *reinterpret_cast(&delta), + }); + + return context.RunProgram(program); +} + +Status RangeProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + + sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " let value = bitcast(uniforms.start) + output_value_t(global_idx) * bitcast(uniforms.delta);\n" + << output.SetByOffset("global_idx", "value"); + + return Status(); +} + +#define WEBGPU_RANGE_KERNEL(TYPE) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Range, \ + kOnnxDomain, \ + 11, \ + TYPE, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 0) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Range); + +WEBGPU_RANGE_KERNEL(float) +WEBGPU_RANGE_KERNEL(int32_t) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/generator/range.h b/onnxruntime/core/providers/webgpu/generator/range.h new file mode 100644 index 000000000000..2f5812bb460a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/generator/range.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +template +class Range : public WebGpuKernel { + public: + explicit Range(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +class RangeProgram : public Program { + public: + RangeProgram() : Program{"Range"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"start", ProgramUniformVariableDataType::Uint32}, + {"delta", ProgramUniformVariableDataType::Uint32}); +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index a43428f55ce8..4600f89cc9c9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -354,7 +354,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 17, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, InstanceNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, Range); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, float, Range); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, int32_t, Range); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, float, Einsum); @@ -676,7 +677,8 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, From 14ea5dbcf0251a6c93273e0aa9a918fb62eedf9a Mon Sep 17 00:00:00 2001 From: xhcao Date: Tue, 1 Oct 2024 06:35:08 +0800 Subject: [PATCH 106/114] webgpu: support MultiHeadAttention operator (#22144) ### Description ### Motivation and Context --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../webgpu/bert/multihead_attention.cc | 493 ++++++++++++++++++ .../webgpu/bert/multihead_attention.h | 115 ++++ .../webgpu/webgpu_contrib_kernels.cc | 2 +- .../multihead_attention_op_test.cc | 99 ++-- 4 files changed, 675 insertions(+), 34 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc new file mode 100644 index 000000000000..d836c1ddf867 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -0,0 +1,493 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/multihead_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/providers/webgpu/webgpu_supported_types.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::multihead_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + MultiHeadAttention, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + MultiHeadAttention); + +Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("qkv_input", ShaderUsage::UseUniform); + const auto& qkv_output = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices); + + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << "let output_indices = " << qkv_output.OffsetToIndices("global_idx") << ";\n" + << "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *" + << " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n"; + if (has_bias_) { + shader.MainFunctionBody() << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n"; + } + shader.MainFunctionBody() << "qkv_output[global_idx] = qkv_input[input_offset_idx]"; + if (has_bias_) { + shader.MainFunctionBody() << " + bias[bias_offset_idx];\n"; + } else { + shader.MainFunctionBody() << ";\n"; + } + + return Status::OK(); +} + +Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, + int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) { + assert(input_tensor->Shape().GetDims().size() == 3); + assert(output_tensor->Shape().GetDims().size() == 4); + + uint32_t data_size = SafeInt(output_tensor->Shape().Size()); + const int batch_offset = num_heads * sequence_length * head_size; + const int sequence_offset = num_heads * head_size; + const int head_offset = head_size; + bool has_bias = bias != nullptr; + + TransferBSDToBNSHProgram program{has_bias}; + program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{data_size}, + {static_cast(batch_offset)}, + {static_cast(sequence_offset)}, + {static_cast(head_offset)}, + {static_cast(bias_offset)}}); + + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + + return context.RunProgram(program); +}; + +Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (feed_past_key_) { + shader.AddInput("past_key", ShaderUsage::UseUniform); + } + if (has_attention_bias_) { + shader.AddInput("attention_bias", ShaderUsage::UseUniform); + } + + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (has_present_key_) { + shader.AddOutput("present_key", ShaderUsage::UseUniform); + } + + shader.AdditionalImplementation() << "var tileQ: array;\n" + << "var tileK: array;\n" + << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; + + shader.MainFunctionBody() << "// x holds the N and y holds the M\n" + "let headIdx = workgroup_id.z;\n" + "let m = workgroup_id.y * TILE_SIZE;\n" + "let n = workgroup_id.x * TILE_SIZE;\n" + "let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;\n"; + + if (feed_past_key_ && has_present_key_) { + shader.MainFunctionBody() << "let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;\n" + << "let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;\n"; + } else { + shader.MainFunctionBody() << "let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;\n"; + } + + if (has_present_key_) { + shader.MainFunctionBody() << "let presentKeyOffset = headIdx * uniforms.N * uniforms.K;\n"; + } + + shader.MainFunctionBody() << "var value = f32_val_t(0);\n" + "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" + " }\n" + " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" + " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if (feed_past_key_ && has_present_key_) { + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.past_sequence_length) {\n" + " tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + " } else {\n" + " tileK[idx] = key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];\n" + " }\n"; + } else { + shader.MainFunctionBody() << " tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];\n"; + } + + if (has_present_key_) { + shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"; + } + + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + shader.MainFunctionBody() << "let headOffset = headIdx * uniforms.M * uniforms.N;\n" + << "if (global_id.y < uniforms.M && global_id.x < uniforms.N) {\n" + << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" + << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; + + shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; + if (has_attention_bias_) { + shader.MainFunctionBody() << " + attention_bias[outputIdx]"; + } + shader.MainFunctionBody() << ";\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, + const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, + AttentionParameters& parameters, int past_sequence_length, int total_sequence_length) { + const float alpha = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) + : parameters.scale; + + const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0; + const bool has_present_key = output_count > 1 && past_key; + const bool has_attention_bias = attention_bias != nullptr; + const int tile_size = 12; + const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1); + + AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, + components}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, + {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); + if (feed_past_key) { + program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components}); + } + if (has_attention_bias) { + program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); + if (has_present_key) { + program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); + } + + const uint32_t vectorized_head_size = parameters.head_size / components; + program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, + (parameters.sequence_length + tile_size - 1) / tile_size, + parameters.batch_size * parameters.num_heads) + .SetWorkgroupSize(tile_size, tile_size) + .CacheHint(std::to_string(tile_size)) + .AddUniformVariables({{static_cast(parameters.sequence_length)}, + {static_cast(vectorized_head_size)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.num_heads)}, + {static_cast(alpha)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); + + return context.RunProgram(program); +} + +Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AdditionalImplementation() << "var thread_max: array;\n" + << "var thread_sum: array;\n" + << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; + + shader.MainFunctionBody() << "let local_offset = local_idx * uniforms.elements_per_thread;\n" + << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.d_comp + local_offset;\n" + << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" + << "}\n" + << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var max_value = f32(-3.402823e+38f);\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << " max_value = max(thread_max[i], max_value);\n" + << "}\n" + << "var sum_vector = f32_val_t(0);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n" + << "}\n" + << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var sum: f32 = 0;\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << " sum += thread_sum[i]\n;" + << "}\n" + << "if (sum == 0) {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << " x[offset + i] = x_value_t(x_element_t(uniforms.d_inv));\n" + << " }\n" + << "} else {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << " var f32input = f32_val_t(x[offset + i]);\n" + << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" + << " }\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int n, int d) { + const int components = d % 4 == 0 ? 4 : (d % 2 == 0 ? 2 : 1); + int work_group_size = 64; + const int d_comp = d / components; + if (d_comp < work_group_size) { + work_group_size = 32; + } + const int elementsPerThread = (d_comp + work_group_size - 1) / work_group_size; + + InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components}; + program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) + .SetDispatchGroupSize(n) + .SetWorkgroupSize(work_group_size) + .AddUniformVariables({{static_cast(1.f / static_cast(d))}, + {static_cast(d_comp)}, + {static_cast(elementsPerThread)}}); + + return context.RunProgram(program); +} + +Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (feed_past_value_) { + shader.AddInput("past_value", ShaderUsage::UseUniform); + } + + shader.AddOutput("output", ShaderUsage::UseUniform); + if (has_present_value_) { + shader.AddOutput("present_value", ShaderUsage::UseUniform); + } + + shader.AdditionalImplementation() << "var tileQ: array;\n" + << "var tileK: array;\n"; + + shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n" + << "let m = global_id.y;\n" + << "let n = global_id.x;\n" + << "let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;\n"; + + if (feed_past_value_ && has_present_value_) { + shader.MainFunctionBody() << "let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;\n" + << "let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;\n"; + } else { + shader.MainFunctionBody() << "let offsetB = headIdx * uniforms.N * uniforms.K + n;\n"; + } + + if (has_present_value_) { + shader.MainFunctionBody() << "let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;\n"; + } + + shader.MainFunctionBody() << "var value = probs_element_t(0);\n" + << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n" + << " tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n" + << " }\n" + << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" + << " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if (feed_past_value_ && has_present_value_) { + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.past_sequence_length) {\n" + << " tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];\n" + << " } else {\n" + << " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n" + << " }\n"; + } else { + shader.MainFunctionBody() << " tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];\n"; + } + + if (has_present_value_) { + shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"; + } + + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << " value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n" + << "let batchIdx = workgroup_id.z / uniforms.num_heads;\n" + << "let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;\n" + << "if (m < uniforms.M && n < uniforms.N) {\n" + << " let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + " + << " m * uniforms.v_hidden_size + currentBatchHeadNumber * uniforms.N + n;\n" + << " output[outputIdx] = value;\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count, + const Tensor* probs, + const Tensor* V, + const Tensor* past_value, + Tensor* output, + Tensor* present_value, + AttentionParameters& parameters, + int past_sequence_length, + int total_sequence_length) { + const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0; + const bool has_present_value = output_count > 1 && past_value != nullptr; + const int tile_size = 12; + + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size}; + program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, + {V, ProgramTensorMetadataDependency::TypeAndRank}}); + if (feed_past_value) { + program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); + if (has_present_value) { + program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.SetDispatchGroupSize((parameters.v_head_size + tile_size - 1) / tile_size, + (parameters.sequence_length + tile_size - 1) / tile_size, + parameters.batch_size * parameters.num_heads) + .SetWorkgroupSize(tile_size, tile_size) + .AddUniformVariables({{static_cast(parameters.sequence_length)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.v_head_size)}, + {static_cast(parameters.num_heads)}, + {static_cast(parameters.v_hidden_size)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); + ; + + return context.RunProgram(program); +} + +Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, + AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); + const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length : 0; + const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length; + + const TensorShapeVector probs_dims({parameters.batch_size, parameters.num_heads, + parameters.sequence_length, total_sequence_length}); + const TensorShape probs_shape(probs_dims); + Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape); + ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key, + parameters, past_sequence_length, total_sequence_length)); + + ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, + parameters.batch_size * parameters.num_heads * parameters.sequence_length, total_sequence_length)); + + ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, + parameters, past_sequence_length, total_sequence_length)); + + return Status::OK(); +} + +MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) + : WebGpuKernel(info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + scale_ = info.GetAttrOrDefault("scale", 0.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support webgpu kernel"); +} + +Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* query = context.Input(0); + const Tensor* key = context.Input(1); + const Tensor* value = context.Input(2); + const Tensor* bias = context.Input(3); + const Tensor* key_padding_mask = context.Input(4); + const Tensor* attention_bias = context.Input(5); + const Tensor* past_key = context.Input(6); + const Tensor* past_value = context.Input(7); + + if (query->Shape().GetDims().size() == 5) { + ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu"); + } + if (key != nullptr && key->Shape().GetDims().size() == 5) { + ORT_NOT_IMPLEMENTED("Packed KV not implemented for webgpu"); + } + if (key_padding_mask) { + ORT_NOT_IMPLEMENTED("input `key_padding_mask` not implemented for webgpu"); + } + + AttentionParameters parameters; + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, + bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, ¶meters, + num_heads_, mask_filter_value_, scale_, is_unidirectional_, false, kMultiHeadAttention, + context.DeviceLimits().maxComputeInvocationsPerWorkgroup)); + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(parameters.batch_size); + output_shape[1] = static_cast(parameters.sequence_length); + output_shape[2] = static_cast(parameters.v_hidden_size); + Tensor* output = context.Output(0, output_shape); + + // If optional outputs aren't needed, present_key and present_value will be null + std::vector present_dims{ + parameters.batch_size, + parameters.num_heads, + parameters.total_sequence_length, + parameters.head_size, + }; + TensorShape present_shape(present_dims); + Tensor* present_key = context.Output(1, present_shape); + Tensor* present_value = context.Output(2, present_shape); + + TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads, + parameters.sequence_length, parameters.head_size}); + TensorShape q_new_shape(q_new_dims); + Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH( + context, parameters.num_heads, parameters.sequence_length, parameters.head_size, query, bias, 0, &Q)); + + if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format + return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context); + } + + TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads, + parameters.kv_sequence_length, parameters.head_size}); + TensorShape k_new_shape(k_new_dims); + Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length, + parameters.head_size, key, bias, parameters.hidden_size, &K)); + + TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads, + parameters.kv_sequence_length, parameters.v_head_size}); + TensorShape v_new_shape(v_new_dims); + Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length, + parameters.v_head_size, value, bias, 2 * parameters.hidden_size, &V)); + + // Compute the attention score and apply the score to V + return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h new file mode 100644 index 000000000000..36803e3027b4 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class TransferBSDToBNSHProgram final : public Program { + public: + TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}, + {"batch_offset", ProgramUniformVariableDataType::Uint32}, + {"sequence_offset", ProgramUniformVariableDataType::Uint32}, + {"head_offset", ProgramUniformVariableDataType::Uint32}, + {"bias_offset", ProgramUniformVariableDataType::Uint32}); + + private: + bool has_bias_; +}; + +class AttentionProbsProgram final : public Program { + public: + AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, + bool has_attention_bias, int tile_size, int components) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"alpha", ProgramUniformVariableDataType::Float32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}); + + WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); + + private: + bool feed_past_key_; + bool has_present_key_; + bool has_attention_bias_; + int tile_size_; + int components_; +}; + +class InPlaceSoftmaxProgram final : public Program { + public: + InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components) + : Program{kernel_name}, work_group_size_(work_group_size), components_(components) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"d_inv", ProgramUniformVariableDataType::Float32}, + {"d_comp", ProgramUniformVariableDataType::Uint32}, + {"elements_per_thread", ProgramUniformVariableDataType::Uint32}); + + private: + int work_group_size_; + int components_; +}; + +class VxAttentionScoreProgram final : public Program { + public: + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"v_hidden_size", ProgramUniformVariableDataType::Uint32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}); + + WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); + + private: + bool feed_past_value_; + bool has_present_value_; + int tile_size_; +}; + +class MultiHeadAttention final : public WebGpuKernel { + public: + MultiHeadAttention(const OpKernelInfo& info); + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + protected: + int num_heads_; + float mask_filter_value_; + float scale_; + bool is_unidirectional_{false}; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index b5d7a90b9bbf..93257d67c00a 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -45,7 +45,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it // BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo("num_heads", static_cast(num_heads)); tester.AddAttribute("mask_filter_value", static_cast(-10000.0f)); @@ -266,6 +268,12 @@ static void RunMultiHeadAttentionTest( execution_providers.push_back(DefaultDmlExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } + + if (enable_webgpu) { + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } } @@ -295,6 +303,7 @@ static void RunMultiHeadAttentionKernel( bool is_static_kv = true, bool disable_cpu = false, // some cases not supported in cpu right now. bool disable_cuda = false, + bool disable_webgpu = false, bool disable_rocm = DISABLE_ROCM, bool disable_dml = false) { if (kernel_type == AttentionKernelType::AttentionKernel_Default) { @@ -309,7 +318,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } @@ -325,7 +335,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } @@ -341,7 +352,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } @@ -358,7 +370,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } #endif @@ -376,7 +389,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); } if (kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { @@ -392,11 +406,30 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); } } -static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu = false, bool disable_cuda = false) { +enum RunMultiHeadAttentionTestToggles : uint32_t { + DISABLE_NONE = 0, + DISABLE_CPU = 1 << 0, + DISABLE_CUDA = 1 << 1, + DISABLE_WEBGPU = 1 << 2, +}; +inline RunMultiHeadAttentionTestToggles operator|(RunMultiHeadAttentionTestToggles a, RunMultiHeadAttentionTestToggles b) { + return static_cast(static_cast(a) | static_cast(b)); +} +inline RunMultiHeadAttentionTestToggles operator&(RunMultiHeadAttentionTestToggles a, RunMultiHeadAttentionTestToggles b) { + return static_cast(static_cast(a) & static_cast(b)); +} + +static void RunMultiHeadAttentionTests(AttentionTestData& data, + RunMultiHeadAttentionTestToggles toggles = DISABLE_NONE) { + bool disable_cpu = toggles & DISABLE_CPU; + bool disable_cuda = toggles & DISABLE_CUDA; + bool disable_webgpu = toggles & DISABLE_WEBGPU; + if (data.fp32_output_data.size() > 0) { constexpr bool use_float16 = false; @@ -407,7 +440,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -420,7 +453,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } } #endif @@ -431,7 +464,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } if (data.fp16_output_data.size() > 0) { @@ -443,7 +476,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention; @@ -453,7 +486,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -464,7 +497,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #endif @@ -475,7 +508,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } kernel_type = AttentionKernelType::AttentionKernel_Default; @@ -484,7 +517,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } } @@ -503,40 +536,40 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M ROCM_GTEST_SKIP("ROCm MHA does not support mask type of MASK_1D_KEY_SEQ_LEN"); AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, true); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) { AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, false); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) { AttentionTestData data; GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) { AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) { AttentionTestData data; GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } // This tests qk_head_size != v_head_size @@ -561,7 +594,7 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) { TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize8) { AttentionTestData data; GetCrossAttentionData_HeadSize8_NoBias(data); - RunMultiHeadAttentionTests(data, false, true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA); } // TODO (pavignol): Fix this regression @@ -571,7 +604,7 @@ TEST(MultiHeadAttentionTest, CrossAttentionWithPast) { ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetCrossAttentionDataWithPast(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } #endif @@ -579,27 +612,27 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithAttnBias_ForT5) { ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetSelfAttentionData_WithPast_WithAttnBias_ForT5(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU); } TEST(MultiHeadAttentionTest, AttentionCutlassAttnBias) { // ROCM_GTEST_SKIP("ROCm does not support cutlass"); AttentionTestData data; GetAttentionDataCutlassAttnBias(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) { // Whisper decoder cross attention without mask and different sequence lengths for Q and K/V AttentionTestData data; GetCrossAttentionData_DiffSequenceLengths(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); GetCrossAttentionData_DiffSequenceLengths_HeadSize8(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU); GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) { @@ -609,10 +642,10 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) RunMultiHeadAttentionTests(data); GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA); GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA); } // This test is disabled since it is not used in Whisper anymore, and it fails in ROCm. From c70441ec5b44278f79b7ea3383b4ed5340989479 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 1 Oct 2024 02:37:30 -0700 Subject: [PATCH 107/114] [webgpu-native] support for webgpu layernorms (#22249) adds webgpu support for LayerNormalization, SimplifiedLayerNormalization, SkipLayerNormalization, SkipSimplifiedLayerNormalization --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../contrib_ops/webgpu/bert/layer_norm.cc | 36 ++++ .../webgpu/bert/skip_layer_norm.cc | 168 ++++++++++++++++++ .../contrib_ops/webgpu/bert/skip_layer_norm.h | 62 +++++++ .../webgpu/webgpu_contrib_kernels.cc | 17 +- .../core/providers/webgpu/nn/layer_norm.cc | 155 ++++++++++++++++ .../core/providers/webgpu/nn/layer_norm.h | 68 +++++++ .../webgpu/webgpu_execution_provider.cc | 2 +- .../test/contrib_ops/layer_norm_op_test.cc | 10 +- .../test/contrib_ops/layer_norm_test.cc | 8 +- .../test/contrib_ops/skiplayernorm_op_test.cc | 12 +- .../providers/compare_provider_test_utils.cc | 2 + 11 files changed, 519 insertions(+), 21 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h create mode 100644 onnxruntime/core/providers/webgpu/nn/layer_norm.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/layer_norm.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc new file mode 100644 index 000000000000..8997e8698d96 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc @@ -0,0 +1,36 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/layer_norm.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + LayerNormalization, + kOnnxDomain, + 1, + 16, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + onnxruntime::webgpu::LayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + onnxruntime::webgpu::LayerNorm); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc new file mode 100644 index 000000000000..fb955b45f694 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -0,0 +1,168 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/skip_layer_norm.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +static uint32_t GetMaxComponents(int size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +static std::string SumVector(std::string x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("skip", ShaderUsage::UseUniform); + shader.AddInput("gamma", ShaderUsage::UseUniform); + if (hasBeta_) { + shader.AddInput("beta", ShaderUsage::UseUniform); + } + if (hasBias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + + int components = x.NumComponents(); + + std::string bias = (hasBias_) ? " + bias[offset1d + i] " : ""; + std::string simpl1 = (simplified_) ? "" : "- mean * mean "; + std::string simpl2 = (simplified_) ? "" : "- element_t(mean) "; + std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : ""; + + shader.AdditionalImplementation() + << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") + << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? "vec2" : "f32")) << ";\n" + << "var sum_shared : array;\n" + << "var sum_squared_shared : array;\n"; + + shader.MainFunctionBody() + << "let ix = local_idx;\n" + << "let iy = global_idx / workgroup_size_x;\n" + << "let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;\n" + << "var stride = hidden_size_vectorized / workgroup_size_x;\n" + << "let offset = ix * stride + iy * hidden_size_vectorized;\n" + << "let offset1d = stride * ix;\n" + << "if (ix == workgroup_size_x - 1) {\n" + << " stride = hidden_size_vectorized - stride * ix;\n" + << "}\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " let skip_value = skip[offset + i];\n" + << " let input_value = x[offset + i];\n" + << " let value = input_value + skip_value" << bias << ";\n" + << " output[offset + i] = value;\n" + << " let f32_value = f32_val_t(value);\n" + << " sum_shared[ix] += f32_value;\n" + << " sum_squared_shared[ix] += f32_value * f32_value;\n" + << "}\n" + << "workgroupBarrier();\n" + << "var reduce_size : u32 = workgroup_size_x;\n" + << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (ix < curr_size) {\n" + << " sum_shared[ix] += sum_shared[ix + reduce_size];\n" + << " sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n" + << "let sum = sum_shared[0];\n" + << "let square_sum = sum_squared_shared[0];\n" + << "let mean = " << SumVector("sum", components) << " / f32(uniforms.hidden_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.hidden_size) " << simpl1 << "+ uniforms.epsilon);\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " output[offset + i] = (output[offset + i] " << simpl2 << ") * element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n" + << "};\n"; + + return Status::OK(); +} + +template +Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* x = context.Input(0); + const Tensor* skip = context.Input(1); + const Tensor* gamma = context.Input(2); + // optional + const Tensor* beta = context.Input(3); + const Tensor* bias = context.Input(4); + + const auto x_shape = x->Shape(); + + auto* output = context.Output(0, x_shape); + + size_t data_size = x_shape.Size(); + if (data_size == 0) { + return Status::OK(); + } + + const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + const uint32_t hidden_size = SafeInt(x_shape[x_shape.NumDimensions() - 1]); + const int components = GetMaxComponents(hidden_size); + + SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, is_fp16, simplified}; + program + .CacheHint(simplified) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}}) + .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) + .SetDispatchGroupSize(SafeInt(ceil(1.0 * data_size / hidden_size))) + .AddUniformVariables({ + {static_cast(components)}, + }) + .AddUniformVariables({ + {static_cast(hidden_size)}, + }) + .AddUniformVariables({ + {static_cast(epsilon_)}, + }); + + if (beta != nullptr) { + program.AddInput({beta, ProgramTensorMetadataDependency::Type, components}); + } + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::Type, components}); + } + + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + SkipLayerNormalization, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + SkipLayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SkipSimplifiedLayerNormalization, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + SkipLayerNorm); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h new file mode 100644 index 000000000000..d9ef732e28af --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class SkipLayerNormProgram final : public Program { + public: + SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, uint32_t hidden_size, bool is_fp16, bool simplified) : Program{"SkipLayerNorm"} { + epsilon_ = epsilon; + hasBeta_ = hasBeta; + hasBias_ = hasBias; + epsilon_ = epsilon; + hidden_size_ = hidden_size; + simplified_ = simplified; + is_fp16_ = is_fp16; + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"components", ProgramUniformVariableDataType::Uint32}, + {"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + bool hasBeta_; + bool hasBias_; + float epsilon_; + uint32_t hidden_size_; + bool is_fp16_; + bool simplified_; +}; + +template +class SkipLayerNorm final : public WebGpuKernel { + public: + SkipLayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05f); + } + + Status ComputeInternal(ComputeContext& context) const override; + + protected: + std::string cache_hint; + + private: + float epsilon_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 93257d67c00a..04652aab2578 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -22,6 +22,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Ma class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); @@ -42,19 +43,15 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it - // BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo - }; + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc new file mode 100644 index 000000000000..5e4d5b7ad10e --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -0,0 +1,155 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/layer_norm.h" + +namespace onnxruntime { +namespace webgpu { + +static int GetMaxComponents(int64_t size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { + int64_t rank = static_cast(tensor_rank); + if (axis < -rank && axis >= rank) { + ORT_THROW("invalid axis: ", axis); + } + return SafeInt(axis < 0 ? axis + rank : axis); +} + +static std::string SumVector(std::string x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("scale", ShaderUsage::UseUniform); + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + + int components = x.NumComponents(); + std::string bias = (has_bias_) ? " + bias[j]" : ""; + std::string simpl1 = (simplified_) ? "" : " - mean * mean"; + std::string simpl2 = (simplified_) ? "" : " - mean"; + + shader.AdditionalImplementation() << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") + << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? "vec2" : "f32")) << ";\n"; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count") + << "let offset = global_idx * uniforms.norm_size_vectorized;\n" + << "var mean_vector = f32_val_t(0);\n" + << "var mean_square_vector = f32_val_t(0);\n" + << "for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {\n" + << " let value = f32_val_t(x[h + offset]);\n" + << " mean_vector += value;\n" + << " mean_square_vector += value * value;\n" + << "}\n" + << "let mean = " << SumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("mean_square_vector", components) << " / f32(uniforms.norm_size)" << simpl1 << " + uniforms.epsilon);\n" + << "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n" + << " let f32input = f32_val_t(x[j + offset]);\n" + << " let f32scale = f32_val_t(scale[j]);\n" + << " output[j + offset] = x_value_t((f32input" << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n" + << "}\n"; + + return Status::OK(); +} + +template +Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* x = context.Input(0); + const auto* scale = context.Input(1); + const auto* bias = context.Input(2); + + const auto x_shape = x->Shape(); + + auto* output = context.Output(0, x_shape); + + size_t data_size = x_shape.Size(); + if (data_size == 0) { + return Status::OK(); + } + + const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + + const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); + const uint32_t norm_count = SafeInt(x_shape.SizeToDimension(axis)); + const int64_t norm_size = x_shape.SizeFromDimension(axis); + const int components = GetMaxComponents(norm_size); + const uint32_t norm_size_vectorized = SafeInt((norm_size + components - 1) / components); + + const auto scale_size = scale->Shape().Size(); + const auto bias_size = (bias) ? bias->Shape().Size() : 0; + if (scale_size != norm_size || (bias && bias_size != norm_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Size of X.shape()[axis:] == ", norm_size, + ". Size of scale and bias (if provided) must match this. Got scale size of ", + scale_size, " and bias size of ", bias_size); + } + + LayerNormProgram program{axis_, epsilon_, stash_type_, bias != nullptr, data_size, is_fp16, simplified}; + + program + .CacheHint(simplified) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{scale, ProgramTensorMetadataDependency::Type, components}}) + .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) + .SetDispatchGroupSize((norm_count + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(norm_count)}, + }) + .AddUniformVariables({ + {static_cast(norm_size)}, + }) + .AddUniformVariables({ + {static_cast(norm_size_vectorized)}, + }) + .AddUniformVariables({ + {static_cast(epsilon_)}, + }); + + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::Type, components}); + } + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + LayerNormalization, + kOnnxDomain, + 17, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + LayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + LayerNorm); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.h b/onnxruntime/core/providers/webgpu/nn/layer_norm.h new file mode 100644 index 000000000000..e7014a1b80e2 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class LayerNormProgram final : public Program { + public: + LayerNormProgram(int64_t axis, + float epsilon, + int64_t stash_type, + bool has_bias, + size_t x_size, + bool is_fp16, + bool simplified) : Program{"LayerNorm"}, + axis_{axis}, + epsilon_{epsilon}, + stash_type_{stash_type}, + has_bias_{has_bias}, + x_size_{x_size}, + is_fp16_{is_fp16}, + simplified_{simplified} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"norm_count", ProgramUniformVariableDataType::Uint32}, + {"norm_size", ProgramUniformVariableDataType::Uint32}, + {"norm_size_vectorized", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + int64_t axis_; + float epsilon_; + int64_t stash_type_; + bool has_bias_; + size_t x_size_; + bool is_fp16_; + bool simplified_; +}; + +template +class LayerNorm final : public WebGpuKernel { + public: + LayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + info.GetAttrOrDefault("axis", &axis_, -1); + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05f); + info.GetAttrOrDefault("stash_type", &stash_type_, 1); + } + + Status ComputeInternal(ComputeContext& context) const override; + + protected: + std::string cache_hint; + + private: + int64_t axis_; + float epsilon_; + int64_t stash_type_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 4600f89cc9c9..16378d1d58cd 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -672,7 +672,7 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 655c4951f262..7fbaf9182424 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -120,7 +120,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16Input) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { @@ -134,7 +134,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) { @@ -178,7 +178,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16Input) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider, - kOpenVINOExecutionProvider, kNnapiExecutionProvider, kCoreMLExecutionProvider}); + kOpenVINOExecutionProvider, kNnapiExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { @@ -193,7 +193,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { @@ -208,7 +208,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } // LayerNormalization became an ONNX operator in opset 17. It uses the same implementation so this is a sanity check. diff --git a/onnxruntime/test/contrib_ops/layer_norm_test.cc b/onnxruntime/test/contrib_ops/layer_norm_test.cc index 438a1100ca95..46082e1b0cd3 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_test.cc @@ -6,7 +6,7 @@ namespace onnxruntime { namespace test { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) constexpr auto k_epsilon_default = 1e-5f; constexpr auto k_random_data_min = -10.0f; constexpr auto k_random_data_max = 10.0f; @@ -65,8 +65,8 @@ static void TestLayerNorm(const std::vector& x_dims, std::vector Y_data = FillZeros(n_x_m_dims); test.AddOutput("output", n_x_m_dims, Y_data); -#ifndef USE_DML - // DML doesn't support more than one output for these ops yet +#if !defined(USE_DML) && !defined(USE_WEBGPU) + // DML and WebGPU don't support more than one output for these ops yet const std::vector& stats_dims = keep_dims ? n_and_ones_dims : n_dims; std::vector mean_data = FillZeros(stats_dims); std::vector var_data = FillZeros(stats_dims); @@ -84,6 +84,8 @@ static void TestLayerNorm(const std::vector& x_dims, test.CompareWithCPU(kRocmExecutionProvider); #elif USE_DML test.CompareWithCPU(kDmlExecutionProvider); +#elif USE_WEBGPU + test.CompareWithCPU(kWebGpuExecutionProvider); #endif } diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index edf9064bb43c..b9ca55073d41 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -62,6 +62,8 @@ static void RunOneTest( auto rocm_ep = DefaultRocmExecutionProvider(); auto dml_ep = DefaultDmlExecutionProvider(); auto cpu_ep = DefaultCpuExecutionProvider(); + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + std::vector> execution_providers; if (!use_float16) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); @@ -95,10 +97,14 @@ static void RunOneTest( if (cpu_ep != nullptr) { execution_providers.push_back(DefaultCpuExecutionProvider()); } + if (webgpu_ep != nullptr) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) || dml_ep != nullptr || - rocm_ep != nullptr) { + rocm_ep != nullptr || + webgpu_ep != nullptr) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); test.AddInput("input", input_dims, ToFloat16(input_data)); test.AddInput("skip", skip_dims, ToFloat16(skip_data)); @@ -132,7 +138,9 @@ static void RunOneTest( ToFloat16(sum_output_data)); } - if (dml_ep != nullptr) { + if (webgpu_ep != nullptr) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } else if (dml_ep != nullptr) { execution_providers.push_back(DefaultDmlExecutionProvider()); } else if (rocm_ep != nullptr) { execution_providers.push_back(DefaultRocmExecutionProvider()); diff --git a/onnxruntime/test/providers/compare_provider_test_utils.cc b/onnxruntime/test/providers/compare_provider_test_utils.cc index 3ef74259e27b..386a5656d8a0 100644 --- a/onnxruntime/test/providers/compare_provider_test_utils.cc +++ b/onnxruntime/test/providers/compare_provider_test_utils.cc @@ -36,6 +36,8 @@ std::unique_ptr GetExecutionProvider(const std::string& prov execution_provider = DefaultRocmExecutionProvider(); else if (provider_type == onnxruntime::kDmlExecutionProvider) execution_provider = DefaultDmlExecutionProvider(); + else if (provider_type == onnxruntime::kWebGpuExecutionProvider) + execution_provider = DefaultWebGpuExecutionProvider(); // skip if execution provider is disabled if (execution_provider == nullptr) { return nullptr; From 468c72086c0d4cb3eaab59ab4c588d1819263c4f Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 1 Oct 2024 02:36:17 -0700 Subject: [PATCH 108/114] nodejs binding support webgpu --- cmake/onnxruntime_nodejs.cmake | 5 ++- js/node/CMakeLists.txt | 4 +++ js/node/lib/backend.ts | 4 ++- js/node/lib/binding.ts | 34 +++++++++++++++++++- js/node/script/build.ts | 5 +++ js/node/src/inference_session_wrap.cc | 24 ++++++++++++-- js/node/src/inference_session_wrap.h | 11 +++++++ js/node/src/session_options_helper.cc | 45 +++++++++++++++++++++++++++ 8 files changed, 127 insertions(+), 5 deletions(-) diff --git a/cmake/onnxruntime_nodejs.cmake b/cmake/onnxruntime_nodejs.cmake index f11928c11cf1..376d895be34a 100644 --- a/cmake/onnxruntime_nodejs.cmake +++ b/cmake/onnxruntime_nodejs.cmake @@ -67,6 +67,9 @@ endif() if (onnxruntime_USE_DML) set(NODEJS_BINDING_USE_DML "--use_dml") endif() +if (onnxruntime_USE_WEBGPU) + set(NODEJS_BINDING_USE_WEBGPU "--use_webgpu") +endif() if (onnxruntime_USE_TENSORRT) set(NODEJS_BINDING_USE_TENSORRT "--use_tensorrt") endif() @@ -92,7 +95,7 @@ add_custom_target(js_common_npm_ci ALL add_custom_target(nodejs_binding_wrapper ALL COMMAND ${NPM_CLI} ci COMMAND ${NPM_CLI} run build -- --onnxruntime-build-dir=${CMAKE_CURRENT_BINARY_DIR} --config=${CMAKE_BUILD_TYPE} --onnxruntime-generator=${CMAKE_GENERATOR} - --arch=${NODEJS_BINDING_ARCH} ${NODEJS_BINDING_USE_CUDA} ${NODEJS_BINDING_USE_DML} ${NODEJS_BINDING_USE_TENSORRT} + --arch=${NODEJS_BINDING_ARCH} ${NODEJS_BINDING_USE_CUDA} ${NODEJS_BINDING_USE_DML} ${NODEJS_BINDING_USE_WEBGPU} ${NODEJS_BINDING_USE_TENSORRT} ${NODEJS_BINDING_USE_COREML} ${NODEJS_BINDING_USE_QNN} WORKING_DIRECTORY ${JS_NODE_ROOT} COMMENT "Using cmake-js to build OnnxRuntime Node.js binding") diff --git a/js/node/CMakeLists.txt b/js/node/CMakeLists.txt index 1ce6d66881c3..5d83790dc273 100644 --- a/js/node/CMakeLists.txt +++ b/js/node/CMakeLists.txt @@ -34,6 +34,7 @@ include_directories(${CMAKE_SOURCE_DIR}/node_modules/node-addon-api) # optional providers option(USE_DML "Build with DirectML support" OFF) +option(USE_WEBGPU "Build with WebGPU support" OFF) option(USE_CUDA "Build with CUDA support" OFF) option(USE_TENSORRT "Build with TensorRT support" OFF) option(USE_COREML "Build with CoreML support" OFF) @@ -42,6 +43,9 @@ option(USE_QNN "Build with QNN support" OFF) if(USE_DML) add_compile_definitions(USE_DML=1) endif() +if(USE_WEBGPU) + add_compile_definitions(USE_WEBGPU=1) +endif() if(USE_CUDA) add_compile_definitions(USE_CUDA=1) endif() diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index 46f8b83b0c5c..d4b4ce159044 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -3,12 +3,14 @@ import { Backend, InferenceSession, InferenceSessionHandler, SessionHandler } from 'onnxruntime-common'; -import { Binding, binding } from './binding'; +import { Binding, binding, initOrt } from './binding'; class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; constructor(pathOrBuffer: string | Uint8Array, options: InferenceSession.SessionOptions) { + initOrt(); + this.#inferenceSession = new binding.InferenceSession(); if (typeof pathOrBuffer === 'string') { this.#inferenceSession.loadModel(pathOrBuffer, options); diff --git a/js/node/lib/binding.ts b/js/node/lib/binding.ts index d6d592a1665b..50f32f613e5b 100644 --- a/js/node/lib/binding.ts +++ b/js/node/lib/binding.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { InferenceSession, OnnxValue } from 'onnxruntime-common'; +import { InferenceSession, OnnxValue, env } from 'onnxruntime-common'; type SessionOptions = InferenceSession.SessionOptions; type FeedsType = { @@ -48,4 +48,36 @@ export const binding = // eslint-disable-next-line @typescript-eslint/naming-convention InferenceSession: Binding.InferenceSessionConstructor; listSupportedBackends: () => Binding.SupportedBackend[]; + initOrtOnce: (logLevel: number) => void; }; + +let ortInitialized = false; +export const initOrt = (): void => { + if (!ortInitialized) { + ortInitialized = true; + if (env.logLevel) { + switch (env.logLevel) { + case 'verbose': + binding.initOrtOnce(0); + break; + case 'info': + binding.initOrtOnce(1); + break; + case 'warning': + binding.initOrtOnce(2); + break; + case 'error': + binding.initOrtOnce(3); + break; + case 'fatal': + binding.initOrtOnce(4); + break; + default: + throw new Error(`Unsupported log level: ${env.logLevel}`); + } + } else { + // default log level = warning + binding.initOrtOnce(2); + } + } +}; diff --git a/js/node/script/build.ts b/js/node/script/build.ts index 133d1a0d981a..dcdcb93377b4 100644 --- a/js/node/script/build.ts +++ b/js/node/script/build.ts @@ -29,6 +29,8 @@ const ONNXRUNTIME_GENERATOR = buildArgs['onnxruntime-generator']; const REBUILD = !!buildArgs.rebuild; // --use_dml const USE_DML = !!buildArgs.use_dml; +// --use_webgpu +const USE_WEBGPU = !!buildArgs.use_webgpu; // --use_cuda const USE_CUDA = !!buildArgs.use_cuda; // --use_tensorrt @@ -65,6 +67,9 @@ if (ONNXRUNTIME_GENERATOR && typeof ONNXRUNTIME_GENERATOR === 'string') { if (USE_DML) { args.push('--CDUSE_DML=ON'); } +if (USE_WEBGPU) { + args.push('--CDUSE_WEBGPU=ON'); +} if (USE_CUDA) { args.push('--CDUSE_CUDA=ON'); } diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index 057066507621..fbf6b99bd0e0 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -23,8 +23,7 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { Ort::Global::api_ == nullptr, env, "Failed to initialize ONNX Runtime API. It could happen when this nodejs binding was built with a higher version " "ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library)."); - auto ortEnv = new Ort::Env{ORT_LOGGING_LEVEL_WARNING, "onnxruntime-node"}; - env.SetInstanceData(ortEnv); + // initialize binding Napi::HandleScope scope(env); @@ -42,9 +41,27 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { Napi::Function listSupportedBackends = Napi::Function::New(env, InferenceSessionWrap::ListSupportedBackends); exports.Set("listSupportedBackends", listSupportedBackends); + Napi::Function initOrtOnce = Napi::Function::New(env, InferenceSessionWrap::InitOrtOnce); + exports.Set("initOrtOnce", initOrtOnce); + return exports; } +Napi::Value InferenceSessionWrap::InitOrtOnce(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + Napi::HandleScope scope(env); + + int log_level = info[0].As().Int32Value(); + + Ort::Env* ortEnv = env.GetInstanceData(); + if (ortEnv == nullptr) { + ortEnv = new Ort::Env{OrtLoggingLevel(log_level), "onnxruntime-node"}; + env.SetInstanceData(ortEnv); + } + + return env.Undefined(); +} + InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info), initialized_(false), disposed_(false), session_(nullptr), defaultRunOptions_(nullptr) {} @@ -242,6 +259,9 @@ Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo #ifdef USE_DML result.Set(result.Length(), createObject("dml", true)); #endif +#ifdef USE_WEBGPU + result.Set(result.Length(), createObject("webgpu", true)); +#endif #ifdef USE_CUDA result.Set(result.Length(), createObject("cuda", false)); #endif diff --git a/js/node/src/inference_session_wrap.h b/js/node/src/inference_session_wrap.h index effdd83e3aa0..4a9adb4f065f 100644 --- a/js/node/src/inference_session_wrap.h +++ b/js/node/src/inference_session_wrap.h @@ -15,6 +15,17 @@ class InferenceSessionWrap : public Napi::ObjectWrap { InferenceSessionWrap(const Napi::CallbackInfo& info); private: + /** + * [sync] initialize ONNX Runtime once. + * + * This function must be called before any other functions. + * + * @param arg0 a number specifying the log level. + * + * @returns undefined + */ + static Napi::Value InitOrtOnce(const Napi::CallbackInfo& info); + /** * [sync] list supported backend list * @returns array with objects { "name": "cpu", requirementsInstalled: true } diff --git a/js/node/src/session_options_helper.cc b/js/node/src/session_options_helper.cc index 0ed1ba08e6bf..f0628de79c8a 100644 --- a/js/node/src/session_options_helper.cc +++ b/js/node/src/session_options_helper.cc @@ -15,6 +15,9 @@ #ifdef USE_DML #include "core/providers/dml/dml_provider_factory.h" #endif +#ifdef USE_WEBGPU +#include "core/providers/webgpu/webgpu_provider_factory.h" +#endif #ifdef USE_TENSORRT #include "core/providers/tensorrt/tensorrt_provider_options.h" #endif @@ -77,6 +80,10 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess } else if (name == "dml") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(sessionOptions, deviceId)); #endif +#ifdef USE_WEBGPU + } else if (name == "webgpu") { + Ort::ThrowOnError(Ort::GetApi().SessionOptionsAppendExecutionProvider(sessionOptions, "WebGPU", nullptr, nullptr, 0)); +#endif #ifdef USE_COREML } else if (name == "coreml") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions, coreMlFlags)); @@ -195,4 +202,42 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio sessionOptions.SetLogSeverityLevel(static_cast(logLevelNumber)); } + + // external data + if (options.Has("externalData")) { + auto externalDataValue = options.Get("externalData"); + ORT_NAPI_THROW_TYPEERROR_IF(!externalDataValue.IsArray(), options.Env(), + "Invalid argument: sessionOptions.externalData must be an array."); + auto externalData = externalDataValue.As(); + std::vector paths; + std::vector buffs; + std::vector sizes; + + for (const auto& kvp : externalData) { + Napi::Value value = kvp.second; + ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), options.Env(), + "Invalid argument: sessionOptions.externalData value must be an object in Node.js binding."); + Napi::Object obj = value.As(); + ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("path") || !obj.Get("path").IsString(), options.Env(), + "Invalid argument: sessionOptions.externalData value must have a 'path' property of type string in Node.js binding."); + auto path = obj.Get("path").As().Utf16Value(); + paths.push_back(std::wstring{path.begin(), path.end()}); + ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("data") || + !obj.Get("data").IsBuffer() || + !(obj.Get("data").IsTypedArray() && obj.Get("data").As().TypedArrayType() == napi_uint8_array), + options.Env(), + "Invalid argument: sessionOptions.externalData value must have an 'data' property of type buffer or typed array in Node.js binding."); + + auto data = obj.Get("data"); + if (data.IsBuffer()) { + buffs.push_back(data.As>().Data()); + sizes.push_back(data.As>().Length()); + } else { + auto typedArray = data.As(); + buffs.push_back(reinterpret_cast(typedArray.ArrayBuffer().Data()) + typedArray.ByteOffset()); + sizes.push_back(typedArray.ByteLength()); + } + } + sessionOptions.AddExternalInitializersFromFilesInMemory(paths, buffs, sizes); + } } From cbf106e7ed695ace68985ed636ac408f6ec01e81 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 1 Oct 2024 02:36:30 -0700 Subject: [PATCH 109/114] fix where --- .../core/providers/webgpu/tensor/where.cc | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc index b37014eb05da..dada446b4bd4 100644 --- a/onnxruntime/core/providers/webgpu/tensor/where.cc +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -84,16 +84,16 @@ Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]"; const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))"; - shader.MainFunctionBody() << "let output_indices" << x << " = " << output_indices.OffsetToIndices("global_idx * 4u + " + x + "u") << ";\n" + shader.MainFunctionBody() << "let output_indices" << x << " = " << output_indices.OffsetToIndices("global_idx * 4 + " + x) << ";\n" << "let offset_a" << x << " = " << a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" << "let offset_b" << x << " = " << b_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" << "let offset_c" << x << " = " << c_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" - << "let index_a" << x << " = offset_a" << x << " / 4u;\n" - << "let index_b" << x << " = offset_b" << x << " / 4u;\n" - << "let index_c" << x << " = offset_c" << x << " / 4u;\n" - << "let component_a" << x << " = offset_a" << x << " % 4u;\n" - << "let component_b" << x << " = offset_b" << x << " % 4u;\n" - << "let component_c" << x << " = offset_c" << x << " % 4u;\n" + << "let index_a" << x << " = offset_a" << x << " / 4;\n" + << "let index_b" << x << " = offset_b" << x << " / 4;\n" + << "let index_c" << x << " = offset_c" << x << " / 4;\n" + << "let component_a" << x << " = offset_a" << x << " % 4;\n" + << "let component_b" << x << " = offset_b" << x << " % 4;\n" + << "let component_c" << x << " = offset_c" << x << " % 4;\n" << rest_str << "[" << x << "] = " << type_cast << "(" << expression(a_expression, b_expression, c_expression) << ");\n"; }; @@ -134,9 +134,9 @@ Status Where::ComputeInternal(ComputeContext& context) const { program .CacheHint(is_broadcast) .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddInputs({{cond_tensor, ProgramTensorMetadataDependency::None, {(cond_shape.Size() + 3) / 4}, 4}, - {x_tensor, ProgramTensorMetadataDependency::None, {(x_shape.Size() + 3) / 4}, 4}, - {y_tensor, ProgramTensorMetadataDependency::None, {(y_shape.Size() + 3) / 4}, 4}}) + .AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Rank, {(cond_shape.Size() + 3) / 4}, 4}, + {x_tensor, ProgramTensorMetadataDependency::Rank, {(x_shape.Size() + 3) / 4}, 4}, + {y_tensor, ProgramTensorMetadataDependency::Rank, {(y_shape.Size() + 3) / 4}, 4}}) .AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) .AddUniformVariables({ {static_cast(vec_size)}, From 5086c7cbcc87c26e9ff8d4f46e2e978b5001f389 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 1 Oct 2024 03:06:13 -0700 Subject: [PATCH 110/114] revert some changes that are not necessary --- cmake/onnxruntime_providers_webnn.cmake | 2 +- .../core/session/onnxruntime_c_api.h | 32 ----------- .../core/session/onnxruntime_cxx_api.h | 3 - .../core/session/onnxruntime_cxx_inline.h | 19 ------- onnxruntime/core/session/ort_apis.h | 7 --- .../core/session/provider_registration.cc | 56 ------------------- 6 files changed, 1 insertion(+), 118 deletions(-) diff --git a/cmake/onnxruntime_providers_webnn.cmake b/cmake/onnxruntime_providers_webnn.cmake index 39ca476810f4..05c63c22244d 100644 --- a/cmake/onnxruntime_providers_webnn.cmake +++ b/cmake/onnxruntime_providers_webnn.cmake @@ -22,4 +22,4 @@ add_dependencies(onnxruntime_providers_webnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_webnn PROPERTIES FOLDER "ONNXRuntime") - set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX) + set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX) \ No newline at end of file diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b3e4b9fc5712..39e0361b7ff4 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -625,38 +625,6 @@ typedef struct OrtMIGraphXProviderOptions { bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false } OrtMIGraphXProviderOptions; -/** \brief WebGPU Execution Provider Options - * - * When a user wants to use WebGPU as the execution provider, there are 2 ways to specify the WebGPU device: - * - * 1. Use the default WebGPU device. The default WebGPU device is managed by WebGPU EP internally. The user doesn't - * need to provide any device information in this case. All the fields should be set to nullptr or 0. - * - * 2. Use a custom WebGPU device. The user should create their own handles of `WGPUInstance`, `WGPUAdapter`, and - * `WGPUDevice` and use arbitrary number in [1..65536) as the device id. The user should provide the handles - * and the device id in the options. - * - * When specifying an existing Device ID, the user should provide the handles of `WGPUInstance`, `WGPUAdapter`, and - * `WGPUDevice` in the options. The device id should be the same as the one used previously. - * - * It's user's responsibility to manage the lifecycle of the handles and ensure the handles are valid during the - * lifetime of the inference session. - * - * About DawnProcTable: - * - * When using an ONNX Runtime build that is not directly linked dawn during the build, a pointer to the runtime memory - * address of the DawnProcTable should be provided. Otherwise, keep it as nullptr. - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_WGPU - */ -typedef struct OrtWGPUProviderOptions { - int device_id; // WebGPU device id. - void* instance_handle; // WebGPU instance handle. - void* adapter_handle; // WebGPU adapter handle. - void* device_handle; // WebGPU device handle. - void* dawn_proc_table; // DawnProcTable pointer. -} OrtWGPUProviderOptions; - /** \brief OpenVINO Provider Options * * \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index f534f53a796a..12a6a5c87c0a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -925,9 +925,6 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_WGPU - SessionOptionsImpl& AppendExecutionProvider_WGPU(const OrtWGPUProviderOptions& wgpu_options, - const std::unordered_map& string_options = {}); /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK. SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, const std::unordered_map& provider_options = {}); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 49f8242249d8..7401cb243812 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -857,25 +857,6 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_MIG return *this; } -template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_WGPU(const OrtWGPUProviderOptions& wgpu_options, - const std::unordered_map& string_options) { - auto num_entries = string_options.size(); - std::vector keys, values; - if (num_entries > 0) { - keys.reserve(num_entries); - values.reserve(num_entries); - - for (const auto& entry : string_options) { - keys.push_back(entry.first.c_str()); - values.push_back(entry.second.c_str()); - } - } - - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_WGPU(this->p_, &wgpu_options, keys.data(), values.data(), num_entries)); - return *this; -} - template inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) { ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options)); diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 47691cf24af2..905424687323 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -384,13 +384,6 @@ ORT_API_STATUS_IMPL(InvokeOp, ORT_API(void, ReleaseOp, _Frees_ptr_opt_ OrtOp* op); -ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_WGPU, - _In_ OrtSessionOptions* options, - _In_ const OrtWGPUProviderOptions* wgpu_options, - _In_reads_(num_keys) const char* const* string_options_keys, - _In_reads_(num_keys) const char* const* string_options_values, - _In_ size_t num_keys); - ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* provider_name, diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 08222acf8209..8c512c561ea8 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include -#include #include "core/common/common.h" #include "core/common/logging/logging.h" @@ -167,61 +166,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_WGPU, - _In_ OrtSessionOptions* options, - _In_ const OrtWGPUProviderOptions* webgpu_options, - _In_reads_(num_keys) const char* const* string_options_keys, - _In_reads_(num_keys) const char* const* string_options_values, - _In_ size_t num_keys) { - API_IMPL_BEGIN - std::vector options_keys; - options_keys.reserve(num_keys + 4); - std::vector options_values; - options_values.reserve(num_keys + 4); - - // the following code uses std::to_chars() to convert int/size_t to string. - // unlike std::to_string(), std::to_chars() is guaranteed locale-independent. - // - // uint64_t to string is no more than 20 characters, and - // int32_t to string is no more than 11 characters. - static_assert(sizeof(size_t) == 4 || sizeof(size_t) == 8); - char buffer[sizeof(size_t) == 4 ? 11 : 20]; - - auto res = std::to_chars(buffer, buffer + sizeof(buffer), webgpu_options->device_id); - ORT_ENFORCE(res.ec == std::errc(), "Failed to convert device_id to string"); - std::string device_id(buffer, res.ptr - buffer); - options_keys.push_back("deviceId"); - options_values.push_back(device_id.c_str()); - - res = std::to_chars(buffer, buffer + sizeof(buffer), reinterpret_cast(webgpu_options->instance_handle)); - ORT_ENFORCE(res.ec == std::errc(), "Failed to convert instance_handle to string"); - std::string instance_handle(buffer, res.ptr - buffer); - options_keys.push_back("webgpuInstance"); - options_values.push_back(instance_handle.c_str()); - - res = std::to_chars(buffer, buffer + sizeof(buffer), reinterpret_cast(webgpu_options->adapter_handle)); - ORT_ENFORCE(res.ec == std::errc(), "Failed to convert adapter_handle to string"); - std::string adapter_handle(buffer, res.ptr - buffer); - options_keys.push_back("webgpuAdapter"); - options_values.push_back(adapter_handle.c_str()); - - res = std::to_chars(buffer, buffer + sizeof(buffer), reinterpret_cast(webgpu_options->device_handle)); - ORT_ENFORCE(res.ec == std::errc(), "Failed to convert device_handle to string"); - std::string device_handle(buffer, res.ptr - buffer); - options_keys.push_back("webgpuDevice"); - options_values.push_back(device_handle.c_str()); - - // TODO: dawn proc table - - for (size_t i = 0; i != num_keys; ++i) { - options_keys.push_back(string_options_keys[i]); - options_values.push_back(string_options_values[i]); - } - - return OrtApis::SessionOptionsAppendExecutionProvider(options, "WebGPU", options_keys.data(), options_values.data(), options_keys.size()); - API_IMPL_END -} - #if defined(__APPLE__) || defined(ORT_MINIMAL_BUILD) static OrtStatus* CreateNotEnabledStatus(const std::string& ep) { return OrtApis::CreateStatus(ORT_FAIL, (ep + " execution provider is not enabled in this build. ").c_str()); From fe7d3e477b5f616813a8fccb50bc03195dcc0616 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 1 Oct 2024 03:21:08 -0700 Subject: [PATCH 111/114] revise perftest help msg --- onnxruntime/test/perftest/command_args_parser.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 1a8d33a80f95..42b73ec384cf 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -38,7 +38,7 @@ namespace perftest { "\t-A: Disable memory arena\n" "\t-I: Generate tensor input binding (Free dimensions are treated as 1.)\n" "\t-c [parallel runs]: Specifies the (max) number of runs to invoke simultaneously. Default:1.\n" - "\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|qnn|snpe|rocm|migraphx|xnnpack|vitisai:webgpu]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', " + "\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|qnn|snpe|rocm|migraphx|xnnpack|vitisai|webgpu]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', " "'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack', 'vitisai' or 'webgpu'. " "Default:'cpu'.\n" "\t-b [tf|ort]: backend to use. Default:ort\n" From d219bb77a9848c64b826b5a7f0425982fee00109 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Tue, 1 Oct 2024 11:46:44 -0700 Subject: [PATCH 112/114] [webgpu-native] Fix a few build errors on Linux (#22286) 1. On Linux path strings are not std::wstring. They should be std::string. 2. "auto i =0" means i will be int. But sometimes we want to it to be size_t. --- js/node/src/session_options_helper.cc | 7 ++++++- onnxruntime/core/providers/webgpu/tensor/transpose.cc | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/js/node/src/session_options_helper.cc b/js/node/src/session_options_helper.cc index f0628de79c8a..262d78acc3f4 100644 --- a/js/node/src/session_options_helper.cc +++ b/js/node/src/session_options_helper.cc @@ -209,7 +209,7 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio ORT_NAPI_THROW_TYPEERROR_IF(!externalDataValue.IsArray(), options.Env(), "Invalid argument: sessionOptions.externalData must be an array."); auto externalData = externalDataValue.As(); - std::vector paths; + std::vector> paths; std::vector buffs; std::vector sizes; @@ -220,8 +220,13 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio Napi::Object obj = value.As(); ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("path") || !obj.Get("path").IsString(), options.Env(), "Invalid argument: sessionOptions.externalData value must have a 'path' property of type string in Node.js binding."); +#ifdef _WIN32 auto path = obj.Get("path").As().Utf16Value(); paths.push_back(std::wstring{path.begin(), path.end()}); +#else + auto path = obj.Get("path").As().Utf8Value(); + paths.push_back(path); +#endif ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("data") || !obj.Get("data").IsBuffer() || !(obj.Get("data").IsTypedArray() && obj.Get("data").As().TypedArrayType() == napi_uint8_array), diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index adcee8b64fd8..c40ec43dd000 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -48,7 +48,7 @@ ONNX_OPERATOR_KERNEL_EX( Transpose); auto SqueezeShape(const gsl::span& shape, const gsl::span& adjusted_perm, InlinedVector& new_shape, InlinedVector& new_perm) { - for (auto i = 0; i < shape.size(); ++i) { + for (size_t i = 0; i < shape.size(); ++i) { if (shape[i] != 1) { new_shape.push_back(shape[i]); } From 7f7d6dadc44a0d7ec9f4117addaeea7eef137948 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 1 Oct 2024 11:58:10 -0700 Subject: [PATCH 113/114] format --- js/node/src/inference_session_wrap.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/node/src/inference_session_wrap.h b/js/node/src/inference_session_wrap.h index 4a9adb4f065f..15dcd1bf9d72 100644 --- a/js/node/src/inference_session_wrap.h +++ b/js/node/src/inference_session_wrap.h @@ -17,11 +17,11 @@ class InferenceSessionWrap : public Napi::ObjectWrap { private: /** * [sync] initialize ONNX Runtime once. - * + * * This function must be called before any other functions. - * + * * @param arg0 a number specifying the log level. - * + * * @returns undefined */ static Napi::Value InitOrtOnce(const Napi::CallbackInfo& info); From 27640e3dec141730e80ee0dba96afcb99055a36d Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 1 Oct 2024 14:35:03 -0700 Subject: [PATCH 114/114] fix issues for e2e phi3 (#22287) - support for input_skip_bias_sum in SkipLayerNorm - use GetElementAt in concat --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../contrib_ops/webgpu/bert/skip_layer_norm.cc | 15 ++++++++++++--- .../contrib_ops/webgpu/bert/skip_layer_norm.h | 4 +++- .../contrib_ops/webgpu/webgpu_contrib_kernels.cc | 2 +- .../core/providers/webgpu/shader_variable.h | 6 ++++-- .../core/providers/webgpu/tensor/concat.cc | 8 ++++---- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc index fb955b45f694..254dd26b8a14 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -43,6 +43,9 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("bias", ShaderUsage::UseUniform); } shader.AddOutput("output", ShaderUsage::UseUniform); + if (has_input_skip_bias_sum_) { + shader.AddOutput("input_skip_bias_sum", ShaderUsage::UseUniform); + } int components = x.NumComponents(); @@ -50,6 +53,7 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { std::string simpl1 = (simplified_) ? "" : "- mean * mean "; std::string simpl2 = (simplified_) ? "" : "- element_t(mean) "; std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : ""; + std::string input_skip_bias_sum = (has_input_skip_bias_sum_) ? "input_skip_bias_sum[offset + i] = value;\n" : ""; shader.AdditionalImplementation() << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") @@ -72,6 +76,7 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { << " let input_value = x[offset + i];\n" << " let value = input_value + skip_value" << bias << ";\n" << " output[offset + i] = value;\n" + << input_skip_bias_sum << " let f32_value = f32_val_t(value);\n" << " sum_shared[ix] += f32_value;\n" << " sum_squared_shared[ix] += f32_value * f32_value;\n" @@ -109,6 +114,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo const auto x_shape = x->Shape(); auto* output = context.Output(0, x_shape); + auto* input_skip_bias_sum = context.Output(3, x_shape); size_t data_size = x_shape.Size(); if (data_size == 0) { @@ -118,10 +124,11 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; const uint32_t hidden_size = SafeInt(x_shape[x_shape.NumDimensions() - 1]); const int components = GetMaxComponents(hidden_size); + const bool has_input_skip_bias_sum = input_skip_bias_sum != nullptr; - SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, is_fp16, simplified}; + SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, is_fp16, simplified}; program - .CacheHint(simplified) + .CacheHint(simplified, has_input_skip_bias_sum) .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}}) @@ -143,7 +150,9 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo if (bias != nullptr) { program.AddInput({bias, ProgramTensorMetadataDependency::Type, components}); } - + if (has_input_skip_bias_sum) { + program.AddOutputs({{input_skip_bias_sum, ProgramTensorMetadataDependency::None, components}}); + } return context.RunProgram(program); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h index d9ef732e28af..03de1a4b568b 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h @@ -15,12 +15,13 @@ using onnxruntime::webgpu::ComputeContext; class SkipLayerNormProgram final : public Program { public: - SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, uint32_t hidden_size, bool is_fp16, bool simplified) : Program{"SkipLayerNorm"} { + SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, uint32_t hidden_size, bool has_input_skip_bias_sum, bool is_fp16, bool simplified) : Program{"SkipLayerNorm"} { epsilon_ = epsilon; hasBeta_ = hasBeta; hasBias_ = hasBias; epsilon_ = epsilon; hidden_size_ = hidden_size; + has_input_skip_bias_sum_ = has_input_skip_bias_sum; simplified_ = simplified; is_fp16_ = is_fp16; } @@ -37,6 +38,7 @@ class SkipLayerNormProgram final : public Program { bool hasBias_; float epsilon_; uint32_t hidden_size_; + bool has_input_skip_bias_sum_; bool is_fp16_; bool simplified_; }; diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 04652aab2578..4006006a76ba 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -46,7 +46,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index cad7b0ceb830..4d4655925c98 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -13,8 +13,10 @@ namespace onnxruntime { namespace webgpu { -template -std::string GetElementAt(std::string_view var, const TIdx& idx, int rank, bool is_f16 = false) { +template || std::is_same_v>> +std::string GetElementAt(std::string_view var, const TIdx& idx, TRank rank, bool is_f16 = false) { // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. if (var.rfind("uniforms.", 0) == 0) { if (rank > 4) { diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 866f99b587bc..be1b971f6abe 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -41,7 +41,7 @@ WEBGPU_CONCAT_KERNEL(13) void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) { os << "fn calculate_input_index(index: u32) -> u32 {\n" << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" - << " if (index < uniforms.size_in_concat_axis[i]) {\n" + << " if (index < " << GetElementAt("uniforms.size_in_concat_axis", "i", input_count) << ") {\n" << " return i;\n" << " }\n" << " }\n" @@ -78,13 +78,13 @@ Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count); // add implementation of fn assign_output_data AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output); - + const std::string size_in_concat_axis = GetElementAt("uniforms.size_in_concat_axis", "input_index - 1", input_count); shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") << " var indices = " << output.OffsetToIndices("global_idx") << ";\n" << " let indices_axis = " << output.IndicesGet("indices", axis_) << ";\n" << " let input_index = calculate_input_index(indices_axis);\n" - " if (input_index != 0u) {\n" - << " " << output.IndicesSet("indices", axis_, "indices_axis - uniforms.size_in_concat_axis[input_index - 1]") << ";\n" + << " if (input_index != 0u) {\n" + << " " << output.IndicesSet("indices", axis_, "indices_axis - " + size_in_concat_axis) << ";\n" << " }\n" " assign_output_data(global_idx, input_index, indices);\n"; return Status::OK();