diff --git a/.github/actions/rust-toolchain-setup/action.yml b/.github/actions/rust-toolchain-setup/action.yml deleted file mode 100644 index bf73fede16c7..000000000000 --- a/.github/actions/rust-toolchain-setup/action.yml +++ /dev/null @@ -1,44 +0,0 @@ -# yaml-language-server: $schema=https://json.schemastore.org/github-action.json - -name: 'Rust toolchain setup' -description: 'Common setup steps for GitHub workflows for Rust projects' - -runs: - using: composite - steps: - - uses: dtolnay/rust-toolchain@1.71.0 - with: - components: clippy, rustfmt - - uses: extractions/setup-just@v1 - with: - just-version: '1.15.0' # optional semver specification, otherwise latest - - ### - ### Linux setup - ### - - name: rustup - # We need to use the nightly rust tool change to enable registry-auth / to connect to ADO feeds. - if: ${{ (runner.os == 'Linux') }} - run: | - rustup set profile minimal - rustup install - shell: bash - # - name: Cargo login - # if: ${{ (runner.os == 'Linux') }} - # run: just cargo-login-ci - # shell: bash - - ### - ### Windows setup - ### - - name: rustup - # We need to use the nightly rust tool change to enable registry-auth / to connect to ADO feeds. - if: ${{ (runner.os == 'Windows') }} - run: | - rustup set profile minimal - rustup install - shell: pwsh - # - name: Cargo login - # if: ${{ (runner.os == 'Windows') }} - # run: just cargo-login-ci-windows - # shell: pwsh diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 4a5b87b3e69e..e4d1b91bab73 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -47,6 +47,14 @@ jobs: # Details on CodeQL's query packs refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs queries: security-extended,security-and-quality + # Setup Java to use a version that is not too old for the project + - if: ${{ matrix.language == 'java' }} + name: Setup Java 11 + uses: actions/setup-java@v4 + with: + java-version: '11' + distribution: 'microsoft' + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - if: ${{ matrix.language != 'cpp' }} diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 03ea773a2513..73df5e31fda6 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -11,4 +11,4 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: gradle/wrapper-validation-action@v1 + - uses: gradle/wrapper-validation-action@v3 diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index ce8fb3160954..a196226a4b83 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -3,11 +3,14 @@ on: issues: types: [opened, edited] +permissions: + issues: write + jobs: triage: runs-on: ubuntu-latest steps: - - uses: github/issue-labeler@v3.3 + - uses: github/issue-labeler@v3.4 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" configuration-path: .github/labeler.yml diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index c03399f4693b..5bc21595bf88 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -37,7 +37,7 @@ jobs: wget https://github.com/dotnet/docfx/releases/download/v${DOCFXVERSION}/docfx-linux-x64-v${DOCFXVERSION}.zip -O build/docfx/docfx.zip unzip build/docfx/docfx.zip -d build/docfx - name: Install NuGet - uses: nuget/setup-nuget@v1 + uses: nuget/setup-nuget@v2 - name: Build Documentation run: | build/docfx/docfx metadata csharp/ApiDocs/docfx.json diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 708842e59f9f..3e553049a186 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -30,7 +30,7 @@ jobs: java-version: '11' distribution: 'adopt' - name: Build with Gradle - uses: gradle/gradle-build-action@v2 + uses: gradle/gradle-build-action@v3 with: build-root-directory: java gradle-executable: java/gradlew diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index b9f3c0b9a398..ebacd38f1f88 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -21,7 +21,7 @@ permissions: jobs: build: name: Generate Objective-C API docs - runs-on: macos-13 + runs-on: macos-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml deleted file mode 100644 index 725c40c2ded5..000000000000 --- a/.github/workflows/rust-ci.yml +++ /dev/null @@ -1,132 +0,0 @@ -name: Rust - -on: [pull_request] - -env: - CARGO_TERM_COLOR: always - RUST_LOG: onnxruntime=debug,onnxruntime-sys=debug - RUST_BACKTRACE: 1 - MANIFEST_PATH: ${{ github.workspace }}/rust/Cargo.toml - -jobs: - fmt: - name: Rustfmt - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: ./.github/actions/rust-toolchain-setup - - name: vendor onnxruntime source - run: just vendor - - name: fmt - run: cargo fmt --all -- --check - - download: - name: Download prebuilt ONNX Runtime archive from build.rs - runs-on: ubuntu-latest - env: - ORT_RUST_STRATEGY: download - steps: - - uses: actions/checkout@v4 - - uses: ./.github/actions/rust-toolchain-setup - - run: rustup target install x86_64-unknown-linux-gnu - - run: rustup target install x86_64-apple-darwin - - run: rustup target install i686-pc-windows-msvc - - run: rustup target install x86_64-pc-windows-msvc - # ****************************************************************** - - name: Download prebuilt archive (CPU, x86_64-unknown-linux-gnu) - run: cargo build --target x86_64-unknown-linux-gnu --manifest-path ${{ env.MANIFEST_PATH }} - - name: Verify prebuilt archive downloaded (CPU, x86_64-unknown-linux-gnu) - run: ls -lh target/x86_64-unknown-linux-gnu/debug/build/onnxruntime-sys-*/out/onnxruntime-linux-x64-1.*.tgz - # ****************************************************************** - - name: Download prebuilt archive (CPU, x86_64-apple-darwin) - run: cargo build --target x86_64-apple-darwin --manifest-path ${{ env.MANIFEST_PATH }} - - name: Verify prebuilt archive downloaded (CPU, x86_64-apple-darwin) - run: ls -lh target/x86_64-apple-darwin/debug/build/onnxruntime-sys-*/out/onnxruntime-osx-x64-1.*.tgz - # ****************************************************************** - - name: Download prebuilt archive (CPU, i686-pc-windows-msvc) - run: cargo build --target i686-pc-windows-msvc --manifest-path ${{ env.MANIFEST_PATH }} - - name: Verify prebuilt archive downloaded (CPU, i686-pc-windows-msvc) - run: ls -lh target/i686-pc-windows-msvc/debug/build/onnxruntime-sys-*/out/onnxruntime-win-x86-1.*.zip - # ****************************************************************** - - name: Download prebuilt archive (CPU, x86_64-pc-windows-msvc) - run: cargo build --target x86_64-pc-windows-msvc --manifest-path ${{ env.MANIFEST_PATH }} - - name: Verify prebuilt archive downloaded (CPU, x86_64-pc-windows-msvc) - run: ls -lh target/x86_64-pc-windows-msvc/debug/build/onnxruntime-sys-*/out/onnxruntime-win-x64-1.*.zip - # ****************************************************************** - - name: Download prebuilt archive (GPU, x86_64-unknown-linux-gnu) - env: - ORT_USE_CUDA: "yes" - run: cargo build --target x86_64-unknown-linux-gnu --manifest-path ${{ env.MANIFEST_PATH }} - - name: Verify prebuilt archive downloaded (GPU, x86_64-unknown-linux-gnu) - run: ls -lh target/x86_64-unknown-linux-gnu/debug/build/onnxruntime-sys-*/out/onnxruntime-linux-x64-gpu-1.*.tgz - # ****************************************************************** - - name: Download prebuilt archive (GPU, x86_64-pc-windows-msvc) - env: - ORT_USE_CUDA: "yes" - run: cargo build --target x86_64-pc-windows-msvc --manifest-path ${{ env.MANIFEST_PATH }} - - name: Verify prebuilt archive downloaded (GPU, x86_64-pc-windows-msvc) - run: ls -lh target/x86_64-pc-windows-msvc/debug/build/onnxruntime-sys-*/out/onnxruntime-win-gpu-x64-1.*.zip - - test: - name: Test Suite - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - target: - [ - x86_64-unknown-linux-gnu, - x86_64-apple-darwin, - x86_64-pc-windows-msvc, - i686-pc-windows-msvc, - ] - include: - - target: x86_64-unknown-linux-gnu - os: ubuntu-latest - - target: x86_64-apple-darwin - os: macos-latest - - target: x86_64-pc-windows-msvc - os: windows-latest - - target: i686-pc-windows-msvc - os: windows-latest - env: - CARGO_BUILD_TARGET: ${{ matrix.target }} - steps: - - uses: actions/checkout@v4 - - uses: ./.github/actions/rust-toolchain-setup - - name: vendor onnxruntime source - run: just vendor - - run: rustup target install ${{ matrix.target }} - - name: Install additional packages (macOS) - if: contains(matrix.target, 'x86_64-apple-darwin') - run: brew install libomp - - name: Build (cargo build) - run: cargo build --all --manifest-path ${{ env.MANIFEST_PATH }} - - name: Build tests (cargo test) - run: cargo test --no-run --manifest-path ${{ env.MANIFEST_PATH }} - - name: Build onnxruntime with 'model-fetching' feature - run: cargo build --manifest-path ${{ env.MANIFEST_PATH }} --features model-fetching - - name: Test onnxruntime-sys - run: cargo build --package onnxruntime-sys -- --test-threads=1 --nocapture - - name: Test onnxruntime - run: cargo test --manifest-path ${{ env.MANIFEST_PATH }} --features model-fetching -- --test-threads=1 --nocapture - - clippy: - name: Clippy - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: ./.github/actions/rust-toolchain-setup - - name: vendor onnxruntime source - run: just vendor - - run: clippy --all-features --manifest-path ${{ env.MANIFEST_PATH }} -- -D warnings - - package-sys: - name: Package onnxruntime-sys - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: ./.github/actions/rust-toolchain-setup - - name: vendor onnxruntime source - run: just vendor - - run: cargo package --allow-dirty --package onnxruntime-sys diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index c94e3fa5bcb8..181f3fb17d33 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v9.0.0 + - uses: actions/stale@v8 with: # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: contributions welcome, feature request, regression diff --git a/.gitmodules b/.gitmodules index 7bb49e98bfec..f874660971d4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,4 +7,4 @@ [submodule "cmake/external/emsdk"] path = cmake/external/emsdk url = https://github.com/emscripten-core/emsdk.git - branch = 3.1.44 + branch = 3.1.51 diff --git a/.lintrunner.toml b/.lintrunner.toml index 4e5d077b08ff..be95e03479cf 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -132,6 +132,7 @@ exclude_patterns = [ 'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code 'onnxruntime/core/graph/contrib_ops/quantization_defs.cc', 'onnxruntime/core/mlas/**', # Contains assembly code + 'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS lib recommends NO automatic code formatting 'winml/lib/Api.Image/shaders/**', # Contains data chunks ] command = [ diff --git a/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml b/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml index b9de1b79e1d5..fd3b7266d30f 100644 --- a/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml +++ b/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml @@ -29,6 +29,8 @@ extends: git: submodules: false globalSdl: # https://aka.ms/obpipelines/sdl + asyncSdl: + enabled: false tsa: enabled: true prefast: @@ -53,10 +55,6 @@ extends: BuildArch: x86 PythonPackageName: pythonx86 - - template: .pipelines/windowsai-steps.yml@self - parameters: - BuildArch: arm - - template: .pipelines/windowsai-steps.yml@self parameters: BuildArch: arm64 @@ -72,11 +70,6 @@ extends: PythonPackageName: pythonx86 Runtime: static - - template: .pipelines/windowsai-steps.yml@self - parameters: - BuildArch: arm - Runtime: static - - template: .pipelines/windowsai-steps.yml@self parameters: BuildArch: arm64 @@ -94,11 +87,9 @@ extends: dependsOn: - Windows_Packaging_x64_dynamic - Windows_Packaging_x86_dynamic - - Windows_Packaging_arm_dynamic - Windows_Packaging_arm64_dynamic - Windows_Packaging_x64_static - Windows_Packaging_x86_static - - Windows_Packaging_arm_static - Windows_Packaging_arm64_static condition: succeeded() steps: @@ -120,12 +111,6 @@ extends: artifactName: 'drop_Windows_Build_Windows_Packaging_arm64_dynamic' targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm64' - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet DirectML arm' - inputs: - artifactName: 'drop_Windows_Build_Windows_Packaging_arm_dynamic' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm' - - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - NuGet DirectML x64 StaticRuntime' inputs: @@ -144,12 +129,6 @@ extends: artifactName: 'drop_Windows_Build_Windows_Packaging_arm64_static' targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm64-static-runtime' - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet DirectML arm StaticRuntime' - inputs: - artifactName: 'drop_Windows_Build_Windows_Packaging_arm_static' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm-static-runtime' - - task: PowerShell@2 displayName: 'Bundle NuGet and other binaries' inputs: @@ -194,17 +173,7 @@ extends: $arm64_static_runtime_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm64_static_runtime_nupkg_unzipped_directory_root, 'binaries', [System.IO.Path]::GetFileNameWithoutExtension($arm64_static_runtime_nuget_package)) [System.IO.Compression.ZipFile]::ExtractToDirectory($arm64_static_runtime_nuget_package, $arm64_static_runtime_nupkg_unzipped_directory) - $nupkgs = (Get-ChildItem ..\nuget-artifact-arm -Filter Microsoft.AI.MachineLearning*.nupkg -Recurse) - $arm_nuget_package = $nupkgs[0].FullName - $arm_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName - $arm_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm_nupkg_unzipped_directory_root, 'binaries', [System.IO.Path]::GetFileNameWithoutExtension($arm_nuget_package)) - [System.IO.Compression.ZipFile]::ExtractToDirectory($arm_nuget_package, $arm_nupkg_unzipped_directory) - - $nupkgs = (Get-ChildItem ..\nuget-artifact-arm-static-runtime -Filter Microsoft.AI.MachineLearning*.nupkg -Recurse) - $arm_static_runtime_nuget_package = $nupkgs[0].FullName - $arm_static_runtime_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName - $arm_static_runtime_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm_static_runtime_nupkg_unzipped_directory_root, 'binaries', [System.IO.Path]::GetFileNameWithoutExtension($arm_static_runtime_nuget_package)) - [System.IO.Compression.ZipFile]::ExtractToDirectory($arm_static_runtime_nuget_package, $arm_static_runtime_nupkg_unzipped_directory) + $x64_static_runtime_path_old = [System.IO.Path]::Combine($x64_static_runtime_nupkg_unzipped_directory, 'runtimes', 'win-x64', '_native') $x64_static_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-x64', '_native', 'static') @@ -216,10 +185,7 @@ extends: $arm64_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') $arm64_static_runtime_path_old = [System.IO.Path]::Combine($arm64_static_runtime_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') $arm64_static_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native', 'static') - $arm_runtime_path_old = [System.IO.Path]::Combine($arm_nupkg_unzipped_directory, 'runtimes', 'win-arm', '_native') - $arm_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm', '_native') - $arm_static_runtime_path_old = [System.IO.Path]::Combine($arm_static_runtime_nupkg_unzipped_directory, 'runtimes', 'win-arm', '_native') - $arm_static_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm', '_native', 'static') + $uap_build_path_old = [System.IO.Path]::Combine($x64_static_runtime_nupkg_unzipped_directory, 'build', 'native') $uap_build_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'build', 'uap10.0') @@ -228,8 +194,6 @@ extends: New-Item -Path $x86_static_runtime_path_new -ItemType Directory New-Item -Path $arm64_runtime_path_new -ItemType Directory New-Item -Path $arm64_static_runtime_path_new -ItemType Directory - New-Item -Path $arm_runtime_path_new -ItemType Directory - New-Item -Path $arm_static_runtime_path_new -ItemType Directory Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'onnxruntime.dll')) $x86_runtime_path_new Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'onnxruntime.lib')) $x86_runtime_path_new @@ -241,11 +205,6 @@ extends: Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'microsoft.ai.machinelearning.dll')) $arm64_runtime_path_new Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'microsoft.ai.machinelearning.lib')) $arm64_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($arm_runtime_path_old, 'onnxruntime.dll')) $arm_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($arm_runtime_path_old, 'onnxruntime.lib')) $arm_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($arm_runtime_path_old, 'microsoft.ai.machinelearning.dll')) $arm_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($arm_runtime_path_old, 'microsoft.ai.machinelearning.lib')) $arm_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($x64_static_runtime_path_old, 'onnxruntime.dll')) ([System.IO.Path]::Combine($x64_static_runtime_path_new, 'onnxruntime.dll')) Copy-Item ([System.IO.Path]::Combine($x64_static_runtime_path_old, 'onnxruntime.lib')) ([System.IO.Path]::Combine($x64_static_runtime_path_new, 'onnxruntime.lib')) Copy-Item ([System.IO.Path]::Combine($x64_static_runtime_path_old, 'microsoft.ai.machinelearning.dll')) ([System.IO.Path]::Combine($x64_static_runtime_path_new, 'microsoft.ai.machinelearning.dll')) @@ -261,11 +220,6 @@ extends: Copy-Item ([System.IO.Path]::Combine($arm64_static_runtime_path_old, 'microsoft.ai.machinelearning.dll')) ([System.IO.Path]::Combine($arm64_static_runtime_path_new, 'microsoft.ai.machinelearning.dll')) Copy-Item ([System.IO.Path]::Combine($arm64_static_runtime_path_old, 'microsoft.ai.machinelearning.lib')) ([System.IO.Path]::Combine($arm64_static_runtime_path_new, 'microsoft.ai.machinelearning.lib')) - Copy-Item ([System.IO.Path]::Combine($arm_static_runtime_path_old, 'onnxruntime.dll')) ([System.IO.Path]::Combine($arm_static_runtime_path_new, 'onnxruntime.dll')) - Copy-Item ([System.IO.Path]::Combine($arm_static_runtime_path_old, 'onnxruntime.lib')) ([System.IO.Path]::Combine($arm_static_runtime_path_new, 'onnxruntime.lib')) - Copy-Item ([System.IO.Path]::Combine($arm_static_runtime_path_old, 'microsoft.ai.machinelearning.dll')) ([System.IO.Path]::Combine($arm_static_runtime_path_new, 'microsoft.ai.machinelearning.dll')) - Copy-Item ([System.IO.Path]::Combine($arm_static_runtime_path_old, 'microsoft.ai.machinelearning.lib')) ([System.IO.Path]::Combine($arm_static_runtime_path_new, 'microsoft.ai.machinelearning.lib')) - Copy-Item -Recurse $uap_build_path_old $uap_build_path_new $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'merged') @@ -304,22 +258,13 @@ extends: $arm64_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm64_nupkg_unzipped_directory_root, 'symbols', [System.IO.Path]::GetFileNameWithoutExtension($arm64_nuget_package)) [System.IO.Compression.ZipFile]::ExtractToDirectory($arm64_nuget_package, $arm64_nupkg_unzipped_directory) - $nupkgs = (Get-ChildItem ..\nuget-artifact-arm -Filter Microsoft.AI.MachineLearning*.snupkg -Recurse) - $arm_nuget_package = $nupkgs[0].FullName - $arm_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName - $arm_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm_nupkg_unzipped_directory_root, 'symbols', [System.IO.Path]::GetFileNameWithoutExtension($arm_nuget_package)) - [System.IO.Compression.ZipFile]::ExtractToDirectory($arm_nuget_package, $arm_nupkg_unzipped_directory) - $x86_runtime_path_old = [System.IO.Path]::Combine($x86_nupkg_unzipped_directory, 'runtimes', 'win-x86', '_native') $x86_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-x86', '_native') $arm64_runtime_path_old = [System.IO.Path]::Combine($arm64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') $arm64_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') - $arm_runtime_path_old = [System.IO.Path]::Combine($arm_nupkg_unzipped_directory, 'runtimes', 'win-arm', '_native') - $arm_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm', '_native') - + New-Item -Path $x86_runtime_path_new -ItemType Directory New-Item -Path $arm64_runtime_path_new -ItemType Directory - New-Item -Path $arm_runtime_path_new -ItemType Directory Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'onnxruntime.pdb')) $x86_runtime_path_new Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'microsoft.ai.machinelearning.pdb')) $x86_runtime_path_new @@ -327,9 +272,6 @@ extends: Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'onnxruntime.pdb')) $arm64_runtime_path_new Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'microsoft.ai.machinelearning.pdb')) $arm64_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($arm_runtime_path_old, 'onnxruntime.pdb')) $arm_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($arm_runtime_path_old, 'microsoft.ai.machinelearning.pdb')) $arm_runtime_path_new - $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'merged') if (!(Test-Path $merged_nuget_path)) { New-Item -Path $merged_nuget_path -ItemType Directory diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config index 2ac650b0e6dc..b862dec5e1c8 100644 --- a/.pipelines/nuget_config/x64/packages.config +++ b/.pipelines/nuget_config/x64/packages.config @@ -1,6 +1,6 @@  - + diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config index f80f96194a23..c348dd3e9cda 100644 --- a/.pipelines/nuget_config/x86/packages.config +++ b/.pipelines/nuget_config/x86/packages.config @@ -1,6 +1,6 @@  - + diff --git a/.pipelines/windowsai-steps.yml b/.pipelines/windowsai-steps.yml index 292ce60c6b6c..855573de753b 100644 --- a/.pipelines/windowsai-steps.yml +++ b/.pipelines/windowsai-steps.yml @@ -80,11 +80,11 @@ jobs: # must call vsdevcmd first to add cmake to PATH - script: | - curl -O -L https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-windows-x86_64.zip - 7z x cmake-3.26.3-windows-x86_64.zip + curl -O -L https://github.com/Kitware/CMake/releases/download/v3.28.3/cmake-3.28.3-windows-x86_64.zip + 7z x cmake-3.28.3-windows-x86_64.zip set PYTHONHOME=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools set PYTHONPATH=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools - $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe + $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" --cmake_path $(Build.BinariesDirectory)\cmake-3.28.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.28.3-windows-x86_64\bin\ctest.exe workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Generate cmake config' diff --git a/.vscode/settings.json b/.vscode/settings.json index 2f2adc78f6de..98d23090fd47 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -11,7 +11,7 @@ // Auto sort imports "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.organizeImports": "explicit" }, "editor.defaultFormatter": "ms-python.black-formatter" }, @@ -21,5 +21,8 @@ "cpplint.filters": [ "-build/include_subdir", "-runtime/references" - ] + ], + "files.associations": { + "span": "cpp" + } } diff --git a/CITATION.cff b/CITATION.cff index 82bcac5a7b75..10b7290022ae 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -3,8 +3,7 @@ title: ONNX Runtime message: "Please use this information to cite ONNX Runtime in research or other publications." authors: - - affiliation: Microsoft Corporation - given-names: ONNX Runtime developers + - name: ONNX Runtime developers date-released: 2018-11-29 url: "https://onnxruntime.ai" repository-code: "https://github.com/microsoft/onnxruntime" diff --git a/README.md b/README.md index 33bce867e3bd..24c3e191c115 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ * **General Information**: [onnxruntime.ai](https://onnxruntime.ai) -* **Usage documention and tutorials**: [onnxruntime.ai/docs](https://onnxruntime.ai/docs) +* **Usage documentation and tutorials**: [onnxruntime.ai/docs](https://onnxruntime.ai/docs) * **YouTube video tutorials**: [youtube.com/@ONNXRuntime](https://www.youtube.com/@ONNXRuntime) diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 700206180dec..8ec770da2215 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -1829,7 +1829,7 @@ Zbigniew Skowron _____ -HalidelR +HalideIR Copyright (c) 2016 HalideIR contributors Copyright (c) 2012-2014 MIT CSAIL, Google Inc., and other contributors @@ -6299,3 +6299,210 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +_____ + +neural-speed + +https://github.com/intel/neural-speed + + Apache License + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + ============================================================================ + + Copyright 2016-2019 Intel Corporation + Copyright 2018 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This distribution includes third party software ("third party programs"). + This third party software, even if included with the distribution of + the Intel software, may be governed by separate license terms, including + without limitation, third party license terms, other Intel software license + terms, and open source software license terms. These separate license terms + govern your use of the third party programs as set forth in the + "THIRD-PARTY-PROGRAMS" file. diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 092afa15df4d..84cc529467b0 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.17.0 +1.18.0 diff --git a/build_arm64x.bat b/build_arm64x.bat index fbcdd373086a..1ed268ae94a4 100644 --- a/build_arm64x.bat +++ b/build_arm64x.bat @@ -5,7 +5,6 @@ setlocal set PATH=C:\Program Files\Git\usr\bin;%PATH% -set LINK_REPRO_NAME=/mylink.rsp rem Requires a Python install to be available in your PATH python "%~dp0\tools\ci_build\build.py" --arm64 --buildasx --build_dir "%~dp0\build\arm64-x" %* diff --git a/cgmanifests/cgmanifest.json b/cgmanifests/cgmanifest.json index e8dbc9cf9eff..cf245e63a3a5 100644 --- a/cgmanifests/cgmanifest.json +++ b/cgmanifests/cgmanifest.json @@ -469,7 +469,7 @@ "type": "pip", "pip": { "Name": "transformers", - "Version": "2.11.0" + "Version": "4.36.0" }, "comments": "Installed in the training docker image" } @@ -570,7 +570,7 @@ "git": { "commitHash": "e7248b26a1ed53fa030c5c459f7ea095dfd276ac", "repositoryUrl": "https://gitlab.com/libeigen/eigen.git" - } + } } } ], diff --git a/cgmanifests/generate_cgmanifest.py b/cgmanifests/generate_cgmanifest.py index 81181d3ccfb2..3cecbb0cc977 100644 --- a/cgmanifests/generate_cgmanifest.py +++ b/cgmanifests/generate_cgmanifest.py @@ -115,8 +115,8 @@ def normalize_path_separators(path): submodule_lines = proc.stdout.splitlines() for submodule_line in submodule_lines: (absolute_path, url, commit) = submodule_line.split(" ") - git_deps[GitDep(commit, url)] = "git submodule at {}".format( - normalize_path_separators(os.path.relpath(absolute_path, REPO_DIR)) + git_deps[GitDep(commit, url)] = ( + f"git submodule at {normalize_path_separators(os.path.relpath(absolute_path, REPO_DIR))}" ) with open(os.path.join(SCRIPT_DIR, "..", "cmake", "deps.txt")) as f: diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 137ea8a50c01..b26455379b96 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -6,7 +6,7 @@ "component": { "type": "git", "git": { - "commitHash": "a896e3d066448b3530dbcaa48869fafefd738f57", + "commitHash": "4e2496141eda15040c44e9bbf237a1326368e34c", "repositoryUrl": "https://github.com/emscripten-core/emsdk.git" }, "comments": "git submodule at cmake/external/emsdk" @@ -26,7 +26,7 @@ "component": { "type": "git", "git": { - "commitHash": "b86cc54efce19530fb953e4b21f57e6b3888534c", + "commitHash": "990217f043af7222348ca8f0301e17fa7b841781", "repositoryUrl": "https://github.com/onnx/onnx.git" }, "comments": "git submodule at cmake/external/onnx" @@ -36,12 +36,22 @@ "component": { "type": "git", "git": { - "commitHash": "dcd5bd5fd593e31465af3d9ef291d26c646b0a4f", + "commitHash": "4a2c63365eff8823a5221db86ef490e828306f9d", "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" }, "comments": "abseil_cpp" } }, + { + "component": { + "type": "git", + "git": { + "commitHash": "dbb0094fd0cb936469e35320bf37e866ef7a1da4", + "repositoryUrl": "https://github.com/apple/coremltools.git" + }, + "comments": "coremltools" + } + }, { "component": { "type": "git", @@ -76,7 +86,7 @@ "component": { "type": "git", "git": { - "commitHash": "6df40a2471737b27271bdd9b900ab5f3aec746c7", + "commitHash": "0100f6a5779831fa7a651e4b67ef389a8752bd9b", "repositoryUrl": "https://github.com/google/flatbuffers.git" }, "comments": "flatbuffers" @@ -106,7 +116,7 @@ "component": { "type": "git", "git": { - "commitHash": "361e8d1cfe0c6c36d30b39f1b61302ece5507320", + "commitHash": "344117638c8ff7e239044fd0fa7085839fc03021", "repositoryUrl": "https://github.com/google/benchmark.git" }, "comments": "google_benchmark" @@ -196,7 +206,17 @@ "component": { "type": "git", "git": { - "commitHash": "a43ce67187bab219520fd80f21af8bbd4354bc8c", + "commitHash": "150e7527d5286ddd3a995c228dedf8d76a7a86bc", + "repositoryUrl": "https://github.com/intel/neural-speed.git" + }, + "comments": "neural_speed" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "bacfaaa951653cd4e72efe727a543567cb38f7de", "repositoryUrl": "https://github.com/onnx/onnx-tensorrt.git" }, "comments": "onnx_tensorrt" @@ -321,6 +341,16 @@ }, "comments": "composable_kernel" } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "de28d93dfa9ebf3e473127c1c657e1920a5345ee", + "repositoryUrl": "https://github.com/microsoft/DirectX-Headers.git" + }, + "comments": "directx_headers" + } } ] } diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 73c974f20c25..87355c94223a 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -76,9 +76,10 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) # Enable ONNX Runtime CUDA EP's internal unit tests that directly access the EP's internal functions instead of through # OpKernels. When the option is ON, we will have two copies of GTest library in the same process. It is not a typical # use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead. -cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF) +cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS" OFF) option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF) +option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF) option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF) option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) option(onnxruntime_USE_COREML "Build with CoreML support" OFF) @@ -87,7 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) -option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON) +option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) @@ -96,7 +97,6 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) -cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF) cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF) option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) @@ -117,9 +117,7 @@ option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF) option(onnxruntime_GCOV_COVERAGE "Compile with options necessary to run code coverage" OFF) option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) -#It's preferred to turn it OFF when onnxruntime is dynamically linked to PROTOBUF. But Tensort always required the full version of protobuf. -cmake_dependent_option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF "NOT onnxruntime_USE_TENSORRT" ON) -option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir") +option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) cmake_dependent_option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS_ENABLE_DUMP_TO_SQLDB "Build dump debug information about node inputs and outputs with support for sql database." OFF "onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS" OFF) @@ -131,6 +129,7 @@ option(onnxruntime_USE_ACL_1902 "Build with ACL version 1902 support" OFF) option(onnxruntime_USE_ACL_1905 "Build with ACL version 1905 support" OFF) option(onnxruntime_USE_ACL_1908 "Build with ACL version 1908 support" OFF) option(onnxruntime_USE_ACL_2002 "Build with ACL version 2002 support" OFF) +option(onnxruntime_USE_ACL_2308 "Build with ACL version 2308 support" OFF) option(onnxruntime_USE_ARMNN "Build with ArmNN support" OFF) option(onnxruntime_ARMNN_RELU_USE_CPU "Use the CPU implementation for the Relu operator for the ArmNN EP" ON) option(onnxruntime_ARMNN_BN_USE_CPU "Use the CPU implementation for the Batch Normalization operator for the ArmNN EP" ON) @@ -324,17 +323,29 @@ if (onnxruntime_USE_ROCM) endif() # replicate strategy used by pytorch to get ROCM_VERSION - # https://github.com/pytorch/pytorch/blob/8eb21488fdcdb8b0e6fa2e46179b5fa6c42e75af/cmake/public/LoadHIP.cmake#L153-L173 - file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW) - string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW}) - if (ROCM_VERSION_DEV_MATCH) + # https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake + # with modification + if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version") + file(READ "${onnxruntime_ROCM_HOME}/.info/version" ROCM_VERSION_DEV_RAW) + string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + endif() + + if (ROCM_VERSION_MATCH) set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") + else() + message(FATAL_ERROR "Cannot determine ROCm version string") endif() - message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version-dev ****\n") + message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version ****\n") message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") @@ -354,13 +365,7 @@ if (onnxruntime_USE_ROCM) endif() endif() -if (APPLE) - if (NOT CMAKE_OSX_ARCHITECTURES) - message("Building ONNX Runtime for ${CMAKE_HOST_SYSTEM_PROCESSOR}") - endif() -elseif (NOT WIN32 AND NOT APPLE) - message("Building ONNX Runtime for ${CMAKE_SYSTEM_PROCESSOR}") -endif() + # Single output director for all binaries set(RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin CACHE PATH "Single output directory for all binaries.") @@ -493,6 +498,14 @@ endif() include(adjust_global_compile_flags.cmake) +if (APPLE) + if (NOT CMAKE_OSX_ARCHITECTURES) + message("Building ONNX Runtime for ${CMAKE_HOST_SYSTEM_PROCESSOR} CPU ARCH") + endif() +elseif (NOT WIN32 AND NOT APPLE) + message("Building ONNX Runtime for ${onnxruntime_target_platform} CPU ARCH") +endif() + # We need to link with libatomic on systems that do not have built-in atomics, or # don't have built-in support for 8 byte atomics # Derived from https://github.com/protocolbuffers/protobuf/blob/master/cmake/CMakeLists.txt @@ -638,8 +651,18 @@ else() check_cxx_compiler_flag(-Wunused-but-set-variable HAS_UNUSED_BUT_SET_VARIABLE) check_cxx_compiler_flag(-Wunused-variable HAS_UNUSED_VARIABLE) check_cxx_compiler_flag(-Wuseless-cast HAS_USELESS_CAST) + check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) check_function_exists(reallocarray HAS_REALLOCARRAY) - + if (NOT APPLE AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_target_platform STREQUAL "aarch64") + check_cxx_compiler_flag(-march=armv8.2-a+bf16 HAS_ARM64_BFLOAT16) + if(NOT HAS_ARM64_BFLOAT16) + message(FATAL_ERROR "The compiler doesn't support BFLOAT16!!!") + endif() + check_cxx_compiler_flag(-march=armv8.2-a+fp16 HAS_ARM64_FLOAT16) + if(NOT HAS_ARM64_FLOAT16) + message(FATAL_ERROR "The compiler doesn't support FLOAT16!!!") + endif() + endif() if (HAS_TAUTOLOGICAL_POINTER_COMPARE) #we may have extra null pointer checkings in debug build, it's not an issue list(APPEND ORT_WARNING_FLAGS -Wno-tautological-pointer-compare) @@ -694,20 +717,19 @@ if (onnxruntime_USE_CUDA) enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") + if (onnxruntime_DISABLE_CONTRIB_OPS) + set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) + endif() if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) - message( STATUS "Turn off cutlass since CUDA compiler version < 11.6") - set(onnxruntime_USE_CUTLASS OFF) + message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") + set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) + endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) + message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4") endif() else() - set(onnxruntime_USE_CUTLASS OFF) -endif() - -if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS) - if (onnxruntime_DISABLE_CONTRIB_OPS) - message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled") - else() - message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled") - endif() set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() @@ -727,8 +749,8 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1) endif() - endif() + if (onnxruntime_USE_VITISAI) list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VITISAI=1) @@ -769,6 +791,38 @@ if (onnxruntime_USE_QNN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_QNN=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_QNN=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES qnn) + if (NOT QNN_ARCH_ABI) + string(TOLOWER ${onnxruntime_target_platform} GEN_PLATFORM) + if(MSVC) + message(STATUS "Building MSVC for architecture ${CMAKE_SYSTEM_PROCESSOR} with CMAKE_GENERATOR_PLATFORM as ${GEN_PLATFORM}") + if (${GEN_PLATFORM} STREQUAL "arm64") + set(QNN_ARCH_ABI aarch64-windows-msvc) + else() + set(QNN_ARCH_ABI x86_64-windows-msvc) + endif() + else() + if (${CMAKE_SYSTEM_NAME} STREQUAL "Android") + set(QNN_ARCH_ABI aarch64-android-clang6.0) + elseif (${CMAKE_SYSTEM_NAME} STREQUAL "Linux") + if (${GEN_PLATFORM} STREQUAL "x86_64") + set(QNN_ARCH_ABI x86_64-linux-clang) + else() + set(QNN_ARCH_ABI aarch64-android) + endif() + endif() + endif() + endif() + + if (MSVC OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") + file(GLOB QNN_LIB_FILES LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/libQnn*.so" "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/Qnn*.dll") + if (${QNN_ARCH_ABI} STREQUAL "aarch64-windows-msvc") + file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/hexagon-v68/unsigned/libQnnHtpV68Skel.so" + "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so" + "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libqnnhtpv73.cat") + list(APPEND QNN_LIB_FILES ${EXTRA_HTP_LIB}) + endif() + message(STATUS "QNN lib files: " ${QNN_LIB_FILES}) + endif() endif() if (onnxruntime_USE_SNPE) list(APPEND ORT_PROVIDER_FLAGS -DUSE_SNPE=1) @@ -893,8 +947,8 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() - if (onnxruntime_USE_CUTLASS) - target_compile_definitions(${target_name} PRIVATE USE_CUTLASS) + if(USE_NEURAL_SPEED) + target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED) endif() set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON) @@ -976,9 +1030,12 @@ function(onnxruntime_set_compile_flags target_name) foreach(FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options ${FLAG}>") endforeach() - if ((NVCC_HAS_STRICT_ALIASING AND "${target_name}" MATCHES "cuda") OR (HAS_STRICT_ALIASING AND NOT "${target_name}" MATCHES "cuda")) + if (NVCC_HAS_STRICT_ALIASING AND "${target_name}" MATCHES "cuda") target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") endif() + if (HAS_STRICT_ALIASING AND NOT "${target_name}" MATCHES "cuda") + target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") + endif() endif() if (onnxruntime_USE_ROCM) # flags are detected with CXX language mode, some flags are not supported with hipclang @@ -1099,7 +1156,7 @@ function(onnxruntime_add_include_to_target dst_target) endfunction() # ACL -if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908 OR onnxruntime_USE_ACL_2002) +if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908 OR onnxruntime_USE_ACL_2002 OR onnxruntime_USE_ACL_2308) set(onnxruntime_USE_ACL ON) if (onnxruntime_USE_ACL_1902) add_definitions(-DACL_1902=1) @@ -1110,7 +1167,11 @@ if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 if (onnxruntime_USE_ACL_2002) add_definitions(-DACL_2002=1) else() - add_definitions(-DACL_1905=1) + if (onnxruntime_USE_ACL_2308) + add_definitions(-DACL_2308=1) + else() + add_definitions(-DACL_1905=1) + endif() endif() endif() endif() @@ -1177,14 +1238,10 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() -set(USE_JBLAS FALSE) -if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD) - if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") - add_compile_definitions(MLAS_JBLAS) - set(USE_JBLAS TRUE) - elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") - add_compile_definitions(MLAS_JBLAS) - set(USE_JBLAS TRUE) +if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_TVM) + include(neural_speed) + if (USE_NEURAL_SPEED) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla) endif() endif() @@ -1228,17 +1285,15 @@ if (onnxruntime_USE_TVM) $) set(onnxruntime_tvm_libs onnxruntime_providers_tvm) - - # needs to link with stdc++fs in Linux - if (UNIX) - if (NOT APPLE) - set(FS_STDLIB stdc++fs) - endif() - endif() - list(APPEND onnxruntime_EXTERNAL_LIBRARIES tvm ${FS_STDLIB}) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES tvm) list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES tvm) endif() +# needs to link with stdc++fs in Linux +if (UNIX AND "${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS 9) + set(FS_STDLIB stdc++fs) +endif() +list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${FS_STDLIB}) # onnxruntime-extensions if (onnxruntime_USE_EXTENSIONS) @@ -1248,11 +1303,7 @@ endif() #Dependencies end. In the next we'll enable "treat warning as error" #Adjust warning flags -if (onnxruntime_USE_CUDA) - set_msvc_c_cpp_compiler_warning_level(3) -else() - set_msvc_c_cpp_compiler_warning_level(4) -endif() +set_msvc_c_cpp_compiler_warning_level(4) set(onnxruntime_DELAYLOAD_FLAGS "") @@ -1271,34 +1322,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DUSE_OPENVINO=1) - if (EXISTS "$ENV{INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/version.txt") - file(READ $ENV{INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/version.txt VER) - endif() - - if (NOT DEFINED ENV{INTEL_OPENVINO_DIR}) - message(FATAL_ERROR "[Couldn't locate OpenVINO] OpenVINO may not have been initialized") - endif() - - # Check OpenVINO version for support - if ($ENV{INTEL_OPENVINO_DIR} MATCHES "2022.3") - set(OPENVINO_VERSION "2022.3") - add_definitions(-DOPENVINO_2022_3=1) - elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.0") - set(OPENVINO_VERSION "2023.0") - add_definitions(-DOPENVINO_2023_0=1) - elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.1") - set(OPENVINO_VERSION "2023.1") - add_definitions(-DOPENVINO_2023_1=1) - elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.2") - set(OPENVINO_VERSION "2023.2") - add_definitions(-DOPENVINO_2023_1=1) - elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "openvino") - set(OPENVINO_VERSION "2023.2") - add_definitions(-DOPENVINO_2023_2=1) - else() - message(FATAL_ERROR "Unsupported OpenVINO version: ${INTEL_OPENVINO_DIR}") - endif() - if (onnxruntime_USE_OPENVINO_GPU_FP32) add_definitions(-DOPENVINO_CONFIG_GPU_FP32=1) endif() @@ -1315,6 +1338,10 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_CONFIG_CPU_FP16=1) endif() + if (onnxruntime_USE_OPENVINO_NPU) + add_definitions(-DOPENVINO_CONFIG_NPU=1) + endif() + if (onnxruntime_USE_OPENVINO_GPU_FP32_NP) add_definitions(-DOPENVINO_CONFIG_GPU_FP32=1) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) @@ -1335,6 +1362,11 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) endif() + if (onnxruntime_USE_OPENVINO_NPU_NP) + add_definitions(-DOPENVINO_CONFIG_NPU=1) + add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) + endif() + if (onnxruntime_USE_OPENVINO_HETERO) add_definitions(-DOPENVINO_CONFIG_HETERO=1) add_definitions(-DDEVICE_NAME="${onnxruntime_USE_OPENVINO_DEVICE}") @@ -1389,6 +1421,10 @@ endif() if (onnxruntime_USE_CUDA) set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) set(CMAKE_CUDA_STANDARD 17) + if(onnxruntime_CUDA_HOME) + file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME}) + endif() + find_package(CUDAToolkit REQUIRED) if(onnxruntime_CUDNN_HOME) file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) endif() @@ -1430,6 +1466,11 @@ if (onnxruntime_USE_CUDA) if (NOT WIN32) list(APPEND CUDA_NVCC_FLAGS --compiler-options -fPIC) endif() + if(MSVC) + if(CUDA_NVCC_FLAGS MATCHES "Zi") + list(APPEND CUDA_NVCC_FLAGS "-Xcompiler" "-FS") + endif() + endif() # Options passed to cudafe set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=bad_friend_decl\"") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=unsigned_compare_with_zero\"") @@ -1589,7 +1630,7 @@ if (UNIX AND onnxruntime_USE_NCCL) else() set(onnxruntime_USE_NCCL OFF) set(onnxruntime_USE_MPI OFF) -message( WARNING "MPI and NCCL disabled on Win build." ) + message( WARNING "MPI and NCCL are disabled because build is on Windows or USE_NCCL is set to OFF." ) endif() if (onnxruntime_USE_MPI) @@ -1718,14 +1759,12 @@ if(onnxruntime_BUILD_KERNEL_EXPLORER) endif() # When GDK_PLATFORM is set then WINAPI_FAMILY is defined in gdk_toolchain.cmake (along with other relevant flags/definitions). -if (WIN32 AND NOT GDK_PLATFORM) +if (WIN32 AND NOT GDK_PLATFORM AND NOT CMAKE_CROSSCOMPILING) if (NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) # On onecore, link to the onecore build of the MSVC runtime get_filename_component(msvc_path "${CMAKE_C_COMPILER}/../../../.." ABSOLUTE) link_directories(BEFORE "${msvc_path}/lib/onecore/${onnxruntime_target_platform}") - # The .lib files in the MSVC runtime have a DEFAULITLIB entry for onecore.lib, which in turn links to reverse forwarders. - # We ignore that entry and use onecore_apiset.lib instead, since system components must not rely on reverse forwarders. - add_link_options("/NODEFAULTLIB:onecore.lib") + # The .lib files in the MSVC runtime have a DEFAULITLIB entry for onecore.lib, but it shold not cause any conflict with onecoreuap.lib endif() endif() diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 3085beb37927..74d6418ac541 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -8,6 +8,15 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Android") string(APPEND CMAKE_ASM_FLAGS_RELEASE " -O3") endif() +# Suggested by https://gitlab.kitware.com/cmake/cmake/-/issues/20132 +# MacCatalyst is not well supported in CMake +# The error that can emerge without this flag can look like: +# "clang : error : overriding '-mmacosx-version-min=11.0' option with '-target x86_64-apple-ios14.0-macabi' [-Werror,-Woverriding-t-option]" +if (PLATFORM_NAME STREQUAL "macabi") + add_compile_options(-Wno-overriding-t-option) + add_link_options(-Wno-overriding-t-option) +endif() + # Enable space optimization for gcc/clang # Cannot use "-ffunction-sections -fdata-sections" if we enable bitcode (iOS) if (NOT MSVC AND NOT onnxruntime_ENABLE_BITCODE) @@ -16,9 +25,7 @@ if (NOT MSVC AND NOT onnxruntime_ENABLE_BITCODE) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - string(APPEND CMAKE_C_FLAGS " -s STRICT=1 -s DEFAULT_TO_CXX=1") - string(APPEND CMAKE_CXX_FLAGS " -s STRICT=1 -s DEFAULT_TO_CXX=1") - set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -s ALLOW_UNIMPLEMENTED_SYSCALLS=1") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -s ALLOW_UNIMPLEMENTED_SYSCALLS=1 -s DEFAULT_TO_CXX=1") # Enable LTO for release single-thread build if (NOT CMAKE_BUILD_TYPE STREQUAL "Debug") @@ -74,11 +81,6 @@ if (onnxruntime_MINIMAL_BUILD) endif() if (MSVC) - # turn on LTO (which adds some compiler flags and turns on LTCG) unless it's a Debug build to minimize binary size - if (NOT CMAKE_BUILD_TYPE STREQUAL "Debug") - set(onnxruntime_ENABLE_LTO ON) - endif() - # undocumented internal flag to allow analysis of a minimal build binary size if (ADD_DEBUG_INFO_TO_MINIMAL_BUILD) string(APPEND CMAKE_CXX_FLAGS " /Zi") @@ -99,7 +101,7 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() -# enable stream for all the non-minimal build +# Enable stream for all the non-minimal build if (NOT onnxruntime_MINIMAL_BUILD) add_compile_definitions(ORT_ENABLE_STREAM) endif() @@ -130,6 +132,11 @@ if (onnxruntime_DISABLE_RTTI) add_compile_options("$<$:/GR->" "$<$:/we4541>") else() add_compile_options("$<$:-fno-rtti>") + if (onnxruntime_USE_WEBNN) + # Avoid unboundTypeError for WebNN EP since unbound type names are illegal with RTTI disabled + # in Embind API, relevant issue: https://github.com/emscripten-core/emscripten/issues/7001 + add_compile_options("$<$:-DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0>") + endif() endif() else() #MSVC RTTI flag /GR is not added to CMAKE_CXX_FLAGS by default. But, anyway VC++2019 treats "/GR" default on. @@ -207,7 +214,7 @@ endif() macro(check_nvcc_compiler_flag _FLAG _RESULT) - execute_process(COMMAND ${onnxruntime_CUDA_HOME}/bin/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR) + execute_process(COMMAND ${CUDAToolkit_BIN_DIR}/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR) message("NVCC_ERROR = ${NVCC_ERROR}") message("NVCC_OUT = ${NVCC_OUT}") if ("${NVCC_OUT}" MATCHES "0") @@ -267,39 +274,38 @@ if (MSVC) string(APPEND CMAKE_C_FLAGS " /arch:AVX512") endif() - if (NOT GDK_PLATFORM) - add_compile_definitions(WINAPI_FAMILY=100) # Desktop app - message("Building ONNX Runtime for Windows 10 and newer") - add_compile_definitions(WINVER=0x0A00 _WIN32_WINNT=0x0A00 NTDDI_VERSION=0x0A000000) - endif() if (onnxruntime_ENABLE_LTO AND NOT onnxruntime_USE_CUDA) set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /Gw /GL") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /Gw /GL") set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /Gw /GL") endif() - - # The WinML build tool chain builds ARM/ARM64, and the internal tool chain does not have folders for spectre mitigation libs. - # WinML performs spectre mitigation differently. - if (NOT DEFINED onnxruntime_DISABLE_QSPECTRE_CHECK) - check_cxx_compiler_flag(-Qspectre HAS_QSPECTRE) - if (HAS_QSPECTRE) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Qspectre") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Qspectre") - endif() - endif() - set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DYNAMICBASE") - check_cxx_compiler_flag(-guard:cf HAS_GUARD_CF) - if (HAS_GUARD_CF) - set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /guard:cf") - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /guard:cf") - set(CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO} /guard:cf") - set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /guard:cf") - set(CMAKE_C_FLAGS_MINSIZEREL "${CMAKE_C_FLAGS_MINSIZEREL} /guard:cf") - set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /guard:cf") - set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /guard:cf") - endif() else() if (NOT APPLE) + #XXX: Sometimes the value of CMAKE_SYSTEM_PROCESSOR is set but it's wrong. For example, if you run an armv7 docker + #image on an aarch64 machine with an aarch64 Ubuntu host OS, in the docker instance cmake may still report + # CMAKE_SYSTEM_PROCESSOR as aarch64 by default. Given compiling this code may need more than 2GB memory, we do not + # support compiling for ARM32 natively(only support cross-compiling), we will ignore this issue for now. + if(NOT CMAKE_SYSTEM_PROCESSOR) + message(WARNING "CMAKE_SYSTEM_PROCESSOR is not set. Please set it in your toolchain cmake file.") + # Try to detect it + if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" OR "${CMAKE_C_COMPILER_ID}" STREQUAL "Clang") + execute_process( + COMMAND "${CMAKE_C_COMPILER}" -dumpmachine + OUTPUT_VARIABLE GCC_DUMP_MACHINE_OUT OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_VARIABLE _err + RESULT_VARIABLE _res + ) + if(NOT _res EQUAL 0) + message(SEND_ERROR "Failed to run 'gcc -dumpmachine':\n ${_res}") + endif() + string(REPLACE "-" ";" GCC_DUMP_MACHINE_OUT_LIST "${GCC_DUMP_MACHINE_OUT}") + list(LENGTH GCC_DUMP_MACHINE_OUT_LIST GCC_TRIPLET_LEN) + if(GCC_TRIPLET_LEN EQUAL 4) + list(GET GCC_DUMP_MACHINE_OUT_LIST 0 CMAKE_SYSTEM_PROCESSOR) + message("Setting CMAKE_SYSTEM_PROCESSOR to ${CMAKE_SYSTEM_PROCESSOR}") + endif() + endif() + endif() set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) endif() if (onnxruntime_BUILD_FOR_NATIVE_MACHINE) @@ -353,16 +359,9 @@ else() endif() -if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - #For Mac compliance - message("Adding flags for Mac builds") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector-strong") -elseif (WIN32) - # parallel build - # These compiler opitions cannot be forwarded to NVCC, so cannot use add_compiler_options - string(APPEND CMAKE_CXX_FLAGS " /MP") +if (WIN32) # required to be set explicitly to enable Eigen-Unsupported SpecialFunctions string(APPEND CMAKE_CXX_FLAGS " -DEIGEN_HAS_C99_MATH") -else() +elseif(LINUX) add_compile_definitions("_GNU_SOURCE") endif() diff --git a/cmake/deps.txt b/cmake/deps.txt index ff0780301307..d0f455167168 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -12,7 +12,8 @@ # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI. # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # -abseil_cpp;https://github.com/abseil/abseil-cpp/archive/dcd5bd5fd593e31465af3d9ef291d26c646b0a4f.zip;6cc204586014e189f5c0fe3274f83162fa7c700c +abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240116.0.zip;bc2cec6baaad67fcb6c0c38972b687d4797927e9 +coremltools;https://github.com/apple/coremltools/archive/refs/tags/7.1.zip;f1bab0f30966f2e217d8e01207d518f230a1641a cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 @@ -22,10 +23,10 @@ dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b3132 # Until the 3.4.1 release this is the best option we have. # Issue link: https://gitlab.com/libeigen/eigen/-/issues/2744 eigen;https://gitlab.com/libeigen/eigen/-/archive/e7248b26a1ed53fa030c5c459f7ea095dfd276ac/eigen-e7248b26a1ed53fa030c5c459f7ea095dfd276ac.zip;be8be39fdbc6e60e94fa7870b280707069b5b81a -flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;ba0a75fd12dbef8f6557a74e611b7a3d0c5fe7bf +flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v23.5.26.zip;59422c3b5e573dd192fead2834d25951f1c1670c fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 -google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908 +google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.3.zip;bf9870756ee3f8d2d3b346b24ee3600a41c74d3d google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034 googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 @@ -34,9 +35,10 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 -#use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 +neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/v0.3.zip;5ec64e3071edc7347ebd8a81679cf06e2bb9b851 +onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.16.0.zip;a6d8b619459fb4657f8bec7d1c6d95ad6d4c069d +#use the commit of Final DDS removal. DDS output is now supported by ORT TRT. +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/bacfaaa951653cd4e72efe727a543567cb38f7de.zip;26434329612e804164ab7baa6ae629ada56c1b26 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 @@ -55,3 +57,4 @@ cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a79 utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 +directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e \ No newline at end of file diff --git a/cmake/deps_update_and_upload.py b/cmake/deps_update_and_upload.py index d357284d9122..63df3f6f0386 100644 --- a/cmake/deps_update_and_upload.py +++ b/cmake/deps_update_and_upload.py @@ -1,56 +1,109 @@ -# in case deps.txt is updated, run this file to update and upload the dependencies so that CI can use them. -# Before running the script, increase the version number found at: +# If deps.txt is updated, run this file to update and upload the dependencies so that CI can use them. +# +# Before running the script, find the latest version number at: # https://aiinfra.visualstudio.com/Lotus/_artifacts/feed/Lotus/UPack/onnxruntime_build_dependencies/versions +# Increment it to obtain a new version number to use. +# # Run without --do-upload once to verify downloading. Use --do-upload when you are ready to publish. -# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --do-upload -# update version number in tools\ci_build\github\azure-pipelines\templates\download-deps.yml +# E.g.: +# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 +# # check contents of C:/temp/onnxruntime_deps +# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --no-download --do-upload +# +# Next, update the version number in tools/ci_build/github/azure-pipelines/templates/download-deps.yml. + +import argparse +import contextlib +import pathlib import re import subprocess -import os -import argparse import tempfile +script_dir = pathlib.Path(__file__).parent + parser = argparse.ArgumentParser(description="Update dependencies and publish to Azure Artifacts") parser.add_argument( - "--root-path", type=str, default=tempfile.gettempdir(), help="Target root path for downloaded files" + "--root-path", + type=pathlib.Path, + help="Target root path for downloaded files. If not provided, a temporary directory is used.", +) +parser.add_argument( + "--version", + type=str, + help="Package version to publish", +) +parser.add_argument( + "--do-upload", + action="store_true", + dest="upload", + help="Upload the package to Azure Artifacts", +) +parser.add_argument( + "--no-download", + action="store_false", + dest="download", + help="Skip downloading the dependency files. " + "Use with '--do-upload' and '--root-path' to upload the package from existing dependency files.", ) -parser.add_argument("--version", type=str, default="1.0.82", help="Package version to publish") -parser.add_argument("--do-upload", action="store_true", help="Upload the package to Azure Artifacts") args = parser.parse_args() -with open("cmake/deps.txt") as file: +if args.upload: + assert args.version is not None, "'--version' must be specified if uploading." + +if args.upload != args.download: + assert args.root_path is not None, "'--root-path' must be specified if only downloading or uploading." + +deps_path = script_dir / "deps.txt" +with open(deps_path) as file: text = file.read() lines = [line for line in text.split("\n") if not line.startswith("#") and ";" in line] -root_path = args.root_path - -for line in lines: - url = re.sub("^[^;]+?;https://([^;]+?);.*", r"https://\1", line) - filename = re.sub("^[^;]+?;https://([^;]+?);.*", r"\1", line) - full_path = os.path.join(root_path, filename) - subprocess.run(["curl", "-sSL", "--create-dirs", "-o", full_path, url]) # noqa: PLW1510 - -package_name = "onnxruntime_build_dependencies" -version = args.version - -# Check if the user is logged in to Azure -result = subprocess.run("az account show", shell=True, capture_output=True, text=True) # noqa: PLW1510 -if "No subscriptions found" in result.stderr: - # Prompt the user to log in to Azure - print("You are not logged in to Azure. Please log in to continue.") - subprocess.run("az login", shell=True) # noqa: PLW1510 - -# Publish the package to Azure Artifacts if --no-upload is not specified - -cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' -if args.do_upload: - subprocess.run(cmd, shell=True) # noqa: PLW1510 -else: - print("would have run: " + cmd) - -cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' -if args.do_upload: - subprocess.run(cmd, shell=True) # noqa: PLW1510 -else: - print("would have run: " + cmd) +with contextlib.ExitStack() as context_stack: + if args.root_path is not None: + root_path = args.root_path.resolve() + root_path.mkdir(parents=True, exist_ok=True) + else: + temp_dir_name = context_stack.enter_context(tempfile.TemporaryDirectory()) + root_path = pathlib.Path(temp_dir_name) + + if args.download: + print(f"Downloading dependencies to directory: {root_path}") + + dep_pattern = re.compile(r"^[^;]+;https://([^;]+);.*$") + + for line in lines: + match = dep_pattern.fullmatch(line) + if match is None: + continue + + dep_path = match[1] + url = f"https://{dep_path}" + full_path = root_path / dep_path + + subprocess.run(["curl", "-sSL", "--create-dirs", "-o", str(full_path), url], check=True) + + package_name = "onnxruntime_build_dependencies" + version = args.version if args.version is not None else "VERSION_PLACEHOLDER" + + if args.upload: + # Check if the user is logged in to Azure + result = subprocess.run("az account show", shell=True, capture_output=True, text=True, check=False) + if "No subscriptions found" in result.stderr: + # Prompt the user to log in to Azure + print("You are not logged in to Azure. Please log in to continue.") + subprocess.run("az login", shell=True, check=True) + + # Publish the package to Azure Artifacts if --do-upload is specified + + cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' + if args.upload: + subprocess.run(cmd, shell=True, check=True) + else: + print("would have run: " + cmd) + + cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' + if args.upload: + subprocess.run(cmd, shell=True, check=True) + else: + print("would have run: " + cmd) diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index 3bcd4109e288..57cfbee4644e 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -19,7 +19,7 @@ if(WIN32 AND NOT Patch_FOUND) set(ABSL_ENABLE_INSTALL ON) endif() # NB! Advancing Abseil version changes its internal namespace, -# currently absl::lts_20230125 which affects abseil-cpp.natvis debugger +# currently absl::lts_20240116 which affects abseil-cpp.natvis debugger # visualization file, that must be adjusted accordingly, unless we eliminate # that namespace at build time. FetchContent_Declare( diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis index 1e5a36fb9efb..a4fb63b6a837 100644 --- a/cmake/external/abseil-cpp.natvis +++ b/cmake/external/abseil-cpp.natvis @@ -1,6 +1,6 @@ - + @@ -24,7 +24,7 @@ - + @@ -51,7 +51,7 @@ - + *($T1 *){value} (*($T1 *){value}) @@ -60,7 +60,7 @@ - + *($T1 *)this (*($T1 *)this) @@ -68,7 +68,7 @@ - + {value.first}, {value.second} ({value.first}, {value.second}) diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index efc708bd681c..f04f4bec76cd 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,13 +1,11 @@ -if (onnxruntime_USE_CUTLASS) - include(FetchContent) - FetchContent_Declare( - cutlass - URL ${DEP_URL_cutlass} - URL_HASH SHA1=${DEP_SHA1_cutlass} - ) +include(FetchContent) +FetchContent_Declare( + cutlass + URL ${DEP_URL_cutlass} + URL_HASH SHA1=${DEP_SHA1_cutlass} +) - FetchContent_GetProperties(cutlass) - if(NOT cutlass_POPULATED) - FetchContent_Populate(cutlass) - endif() +FetchContent_GetProperties(cutlass) +if(NOT cutlass_POPULATED) + FetchContent_Populate(cutlass) endif() diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake index 5d25b9529e03..8f18059ffdfe 100644 --- a/cmake/external/dml.cmake +++ b/cmake/external/dml.cmake @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config) set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config) get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE) - set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.12.1) + set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.1) # Restore nuget packages, which will pull down the DirectML redist package. add_custom_command( @@ -72,12 +72,11 @@ else() if (dml_EXTERNAL_PROJECT) set(dml_preset_config $,debug,release>) set(dml_preset_name ${onnxruntime_target_platform}-win-redist-${dml_preset_config}) - target_compile_definitions(DirectML INTERFACE DML_TARGET_VERSION_USE_LATEST=1) include(ExternalProject) ExternalProject_Add( directml_repo GIT_REPOSITORY https://dev.azure.com/microsoft/WindowsAI/_git/DirectML - GIT_TAG d460f0f46967bea878786f1bed69487692c779bf + GIT_TAG a5312f72c51864b4d705ac62d25d08bcd88c4fb1 GIT_SHALLOW OFF # not allowed when GIT_TAG is a commit SHA, which is preferred (it's stable, unlike branches) GIT_PROGRESS ON BUILD_IN_SOURCE ON @@ -94,8 +93,20 @@ else() target_link_libraries(DirectML INTERFACE ${directml_install_path}/lib/DirectML.lib) add_dependencies(DirectML directml_repo-install) include_directories(BEFORE ${directml_install_path}/include) + target_compile_definitions(DirectML INTERFACE DML_TARGET_VERSION_USE_LATEST=1) else() include_directories(BEFORE ${dml_INCLUDE_DIR}) set(DML_PACKAGE_DIR ${dml_INCLUDE_DIR}/..) endif() endif() + +FetchContent_Declare( + directx_headers + URL ${DEP_URL_directx_headers} + URL_HASH SHA1=${DEP_SHA1_directx_headers} +) + +FetchContent_Populate(directx_headers) +set(directx_headers_INCLUDE_DIRS "${directx_headers_SOURCE_DIR}/include") + +include_directories(BEFORE ${directx_headers_INCLUDE_DIRS}) diff --git a/cmake/external/dnnl.cmake b/cmake/external/dnnl.cmake index d7b70640781d..9eb5fed7a1af 100644 --- a/cmake/external/dnnl.cmake +++ b/cmake/external/dnnl.cmake @@ -2,7 +2,7 @@ include (ExternalProject) set(DNNL_URL https://github.com/oneapi-src/onednn.git) # If DNNL_TAG is updated, check if MKLML_VERSION and platform.cmake.patch need to be updated. -set(DNNL_TAG v3.0) +set(DNNL_TAG v3.0.1) if(WIN32) set(DNNL_SHARED_LIB dnnl.dll) diff --git a/cmake/external/emsdk b/cmake/external/emsdk index a896e3d06644..4e2496141eda 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit a896e3d066448b3530dbcaa48869fafefd738f57 +Subproject commit 4e2496141eda15040c44e9bbf237a1326368e34c diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake new file mode 100644 index 000000000000..3fe9c660f89d --- /dev/null +++ b/cmake/external/neural_speed.cmake @@ -0,0 +1,16 @@ +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") + set(USE_NEURAL_SPEED TRUE) +elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") + set(USE_NEURAL_SPEED TRUE) +endif() + +if(USE_NEURAL_SPEED) + FetchContent_Declare( + neural_speed + URL ${DEP_URL_neural_speed} + URL_HASH SHA1=${DEP_SHA1_neural_speed} + PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/neural_speed/150e7527d5286ddd3a995c228dedf8d76a7a86bc.patch + ) + set(BTLA_USE_OPENMP OFF) + onnxruntime_fetchcontent_makeavailable(neural_speed) +endif() diff --git a/cmake/external/onnx b/cmake/external/onnx index b86cc54efce1..990217f043af 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit b86cc54efce19530fb953e4b21f57e6b3888534c +Subproject commit 990217f043af7222348ca8f0301e17fa7b841781 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 78f63227c839..8839dbc8fda4 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -14,6 +14,16 @@ foreach(ONNXRUNTIME_DEP IN LISTS ONNXRUNTIME_DEPS_LIST) set(DEP_URL_${ONNXRUNTIME_DEP_NAME} ${ONNXRUNTIME_DEP_URL}) # The third column is SHA1 hash value set(DEP_SHA1_${ONNXRUNTIME_DEP_NAME} ${ONNXRUNTIME_DEP}) + + if(ONNXRUNTIME_DEP_URL MATCHES "^https://") + # Search a local mirror folder + string(REGEX REPLACE "^https://" "${REPO_ROOT}/mirror/" LOCAL_URL "${ONNXRUNTIME_DEP_URL}") + + if(EXISTS "${LOCAL_URL}") + cmake_path(ABSOLUTE_PATH LOCAL_URL) + set(DEP_URL_${ONNXRUNTIME_DEP_NAME} "${LOCAL_URL}") + endif() + endif() endif() endforeach() @@ -37,8 +47,13 @@ if (onnxruntime_BUILD_UNIT_TESTS) set(gtest_disable_pthreads ON) endif() set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) - if (CMAKE_SYSTEM_NAME STREQUAL "iOS") - # Needs to update onnxruntime/test/xctest/xcgtest.mm + if (IOS OR ANDROID) + # on mobile platforms the absl flags class dumps the flag names (assumably for binary size), which breaks passing + # any args to gtest executables, such as using --gtest_filter to debug a specific test. + # Processing of compile definitions: + # https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/config.h#L21 + # If set, this code throws away the flag and does nothing on registration, which results in no flags being known: + # https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/flag.h#L205-L217 set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) else() set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE) @@ -104,45 +119,18 @@ FetchContent_Declare( URL ${DEP_URL_flatbuffers} URL_HASH SHA1=${DEP_SHA1_flatbuffers} PATCH_COMMAND ${ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND} - FIND_PACKAGE_ARGS 1.12.0...<2.0.0 NAMES Flatbuffers + FIND_PACKAGE_ARGS 23.5.9 NAMES Flatbuffers ) # Download a protoc binary from Internet if needed -if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) +if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) # This part of code is only for users' convenience. The code couldn't handle all cases. Users always can manually # download protoc from Protobuf's Github release page and pass the local path to the ONNX_CUSTOM_PROTOC_EXECUTABLE # variable. - message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") - if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") - if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64}) - FetchContent_Populate(protoc_binary) - elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32}) - FetchContent_Populate(protoc_binary) - endif() - if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") - set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) - endif() - elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") - if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64}) - FetchContent_Populate(protoc_binary) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86}) - FetchContent_Populate(protoc_binary) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_aarch64}) - FetchContent_Populate(protoc_binary) - endif() - if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") - set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) - endif() - elseif ((CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin") + if (CMAKE_HOST_APPLE) + # Using CMAKE_CROSSCOMPILING is not recommended for Apple target devices. + # https://cmake.org/cmake/help/v3.26/variable/CMAKE_CROSSCOMPILING.html + # To keep it simple, just download and use the universal protoc binary for all Apple host builds. FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal}) FetchContent_Populate(protoc_binary) if(protoc_binary_SOURCE_DIR) @@ -150,6 +138,38 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() + elseif (CMAKE_CROSSCOMPILING) + message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") + if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64}) + FetchContent_Populate(protoc_binary) + elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32}) + FetchContent_Populate(protoc_binary) + endif() + if(protoc_binary_SOURCE_DIR) + message("Use prebuilt protoc") + set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + endif() + elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") + if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64}) + FetchContent_Populate(protoc_binary) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86}) + FetchContent_Populate(protoc_binary) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_aarch64}) + FetchContent_Populate(protoc_binary) + endif() + if(protoc_binary_SOURCE_DIR) + message("Use prebuilt protoc") + set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + endif() + endif() endif() endif() @@ -184,9 +204,9 @@ FetchContent_Declare( ) set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE) -#TODO: we'd better to turn the following option off. However, it will cause +#TODO: we'd better to turn the following option off. However, it will cause # ".\build.bat --config Debug --parallel --skip_submodule_sync --update" fail with an error message: -# install(EXPORT "ONNXTargets" ...) includes target "onnx_proto" which requires target "libprotobuf-lite" that is +# install(EXPORT "ONNXTargets" ...) includes target "onnx_proto" which requires target "libprotobuf-lite" that is # not in any export set. #set(protobuf_INSTALL OFF CACHE BOOL "Install protobuf binaries and files" FORCE) set(protobuf_USE_EXTERNAL_GTEST ON CACHE BOOL "" FORCE) @@ -219,8 +239,6 @@ FetchContent_Declare( URL_HASH SHA1=${DEP_SHA1_mp11} ) -set(JSON_BuildTests OFF CACHE INTERNAL "") -set(JSON_Install OFF CACHE INTERNAL "") set(JSON_BuildTests OFF CACHE INTERNAL "") set(JSON_Install OFF CACHE INTERNAL "") @@ -253,14 +271,7 @@ if (onnxruntime_ENABLE_CPUINFO) set(CPUINFO_SUPPORTED TRUE) endif() if (WIN32) - # Exclude Windows ARM build and Windows Store - if (${onnxruntime_target_platform} MATCHES "^(ARM.*|arm.*)$" ) - message(WARNING "Cpuinfo not included for compilation problems with Windows ARM.") - set(CPUINFO_SUPPORTED FALSE) - elseif (WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) - message(WARNING "Cpuinfo not included non-Desktop builds") - set(CPUINFO_SUPPORTED FALSE) - endif() + set(CPUINFO_SUPPORTED TRUE) elseif (NOT ${onnxruntime_target_platform} MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$") message(WARNING "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo. " @@ -304,13 +315,23 @@ if (CPUINFO_SUPPORTED) set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_BENCHMARKS OFF CACHE INTERNAL "") - - FetchContent_Declare( - pytorch_cpuinfo - URL ${DEP_URL_pytorch_cpuinfo} - URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} - FIND_PACKAGE_ARGS NAMES cpuinfo - ) + if(onnxruntime_target_platform STREQUAL "ARM64EC") + message("Applying a patch for Windows ARM64EC in cpuinfo") + FetchContent_Declare( + pytorch_cpuinfo + URL ${DEP_URL_pytorch_cpuinfo} + URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} + PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch + FIND_PACKAGE_ARGS NAMES cpuinfo + ) + else() + FetchContent_Declare( + pytorch_cpuinfo + URL ${DEP_URL_pytorch_cpuinfo} + URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} + FIND_PACKAGE_ARGS NAMES cpuinfo + ) + endif() set(ONNXRUNTIME_CPUINFO_PROJ pytorch_cpuinfo) endif() @@ -536,22 +557,32 @@ if(onnxruntime_ENABLE_TRAINING OR (onnxruntime_ENABLE_TRAINING_APIS AND onnxrunt onnxruntime_fetchcontent_makeavailable(cxxopts) endif() +if (onnxruntime_USE_COREML) + FetchContent_Declare( + coremltools + URL ${DEP_URL_coremltools} + URL_HASH SHA1=${DEP_SHA1_coremltools} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/coremltools/crossplatformbuild.patch + ) + # we don't build directly so use Populate. selected files are built from onnxruntime_providers_coreml.cmake + FetchContent_Populate(coremltools) +endif() + message("Finished fetching external dependencies") set(onnxruntime_LINK_DIRS ) if (onnxruntime_USE_CUDA) #TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same + find_package(CUDAToolkit REQUIRED) if (WIN32) if(onnxruntime_CUDNN_HOME) list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64) endif() - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/x64/lib64) else() if(onnxruntime_CUDNN_HOME) list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64) endif() - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/lib64) endif() endif() @@ -562,4 +593,3 @@ endif() FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} ORT_BINARY_DIR) FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} ORT_SOURCE_DIR) - diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake index e661aa51bfc1..41f02ce6f22b 100644 --- a/cmake/external/xnnpack.cmake +++ b/cmake/external/xnnpack.cmake @@ -6,10 +6,14 @@ set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") set(PTHREADPOOL_BUILD_TESTS OFF CACHE INTERNAL "") set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE INTERNAL "") +if(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") + set(XNNPACK_USE_SYSTEM_LIBS OFF) +endif() + # BF16 instructions cause ICE in Android NDK compiler if(CMAKE_ANDROID_ARCH_ABI STREQUAL armeabi-v7a) set(XNNPACK_ENABLE_ARM_BF16 OFF) -ENDIF() +endif() # fp16 depends on psimd FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd}) diff --git a/cmake/maccatalyst_prepare_objects_for_prelink.py b/cmake/maccatalyst_prepare_objects_for_prelink.py new file mode 100644 index 000000000000..34664b4e0523 --- /dev/null +++ b/cmake/maccatalyst_prepare_objects_for_prelink.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +import sys + + +# Note: This script is mainly used for sanity checking/validating the files in the .a library equal to the .o files +# in the source dir to handle the case of source files having duplicate names under different subdirectories for +# each onnxruntime library. (Only applicable when doing a Mac Catalyst build.) +def main(): + source_dir = sys.argv[1] + dest_dir = sys.argv[2] + files_from_static_lib = sys.argv[3] + files_from_source_dir = [] + for subdir, _, files in os.walk(source_dir): + for file_name in files: + if file_name.endswith(".o"): + files_from_source_dir.append(file_name.strip()) + dest_name_without_extension, _ = os.path.splitext(file_name) + counter = 0 + + dest_file = f"{dest_name_without_extension}.o" + while os.path.exists(os.path.join(dest_dir, dest_file)): + print("Duplicate file name from source: " + os.path.join(source_dir, subdir, file_name)) + counter += 1 + dest_file = f"{dest_name_without_extension}_{counter}.o" + print("Renamed file name in destination: " + os.path.join(dest_dir, dest_file)) + + destination_path = os.path.join(dest_dir, dest_file) + source_file = os.path.join(source_dir, subdir, file_name) + shutil.copy(source_file, destination_path) + + # Sanity check to ensure the number of .o object from the original cmake source directory matches with the number + # of .o files extracted from each .a onnxruntime library + file_lists_from_static_lib = [] + with open(files_from_static_lib) as file: + filenames = file.readlines() + for filename in filenames: + file_lists_from_static_lib.append(filename.strip()) + + sorted_list1 = sorted(file_lists_from_static_lib) + sorted_list2 = sorted(files_from_source_dir) + + if len(sorted_list1) != len(sorted_list2): + print( + "Caught a mismatch in the number of .o object files from the original cmake source directory: ", + len(sorted_list1), + "the number of .o files extracted from the static onnxruntime lib: ", + len(sorted_list2), + "for: ", + os.path.basename(source_dir), + ) + + if sorted_list1 == sorted_list2: + print( + "Sanity check passed: object files from original source directory matches with files extracted " + "from static library for: ", + os.path.basename(source_dir), + ) + else: + print( + "Error: Mismatch between object files from original source directory " + "and the .o files extracted from static library for: ", + os.path.basename(source_dir), + ) + + +if __name__ == "__main__": + main() diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index c900f4d4b09a..e15c8a046dc2 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -189,7 +189,6 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_SNPE} ${PROVIDERS_TVM} ${PROVIDERS_RKNPU} - ${PROVIDERS_VITISAI} ${PROVIDERS_XNNPACK} ${PROVIDERS_WEBNN} ${PROVIDERS_AZURE} @@ -282,7 +281,13 @@ endif() # Assemble the Apple static framework (iOS and macOS) if(onnxruntime_BUILD_APPLE_FRAMEWORK) - set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT}) + # when building for mac catalyst, the CMAKE_OSX_SYSROOT is set to MacOSX as well, to avoid duplication, + # we specify as `-macabi` in the name of the output static apple framework directory. + if (PLATFORM_NAME STREQUAL "macabi") + set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-macabi) + else() + set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT}) + endif() # Setup the various directories required. Remove any existing ones so we start with a clean directory. set(STATIC_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/static_libraries) @@ -300,18 +305,34 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) # to enforce symbol visibility. doing it this way limits the symbols included from the .a files to symbols used # by the ORT .o files. - # If it's an onnxruntime library, extract .o files to a separate directory for each library to avoid any clashes - # with filenames (e.g. utils.o) + # If it's an onnxruntime library, extract .o files from the original cmake build path to a separate directory for + # each library to avoid any clashes with filenames (e.g. utils.o) foreach(_LIB ${onnxruntime_INTERNAL_LIBRARIES} ) GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE) if(_LIB_TYPE STREQUAL "STATIC_LIBRARY") set(CUR_STATIC_LIB_OBJ_DIR ${STATIC_LIB_TEMP_DIR}/$) add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${CUR_STATIC_LIB_OBJ_DIR}) - - add_custom_command(TARGET onnxruntime POST_BUILD - COMMAND ar ARGS -x $ - WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR}) + if (PLATFORM_NAME STREQUAL "macabi") + # There exists several duplicate names for source files under different subdirectories within + # each onnxruntime library. (e.g. onnxruntime/contrib_ops/cpu/element_wise_ops.o + # vs. onnxruntime/providers/core/cpu/math/element_wise_ops.o) + # In that case, using 'ar ARGS -x' to extract the .o files from .a lib would possibly cause duplicate naming files being overwritten + # and lead to missing undefined symbol error in the generated binary. + # So we use the below python script as a sanity check to do a recursive find of all .o files in ${CUR_TARGET_CMAKE_SOURCE_LIB_DIR} + # and verifies that matches the content of the .a, and then copy from the source dir. + # TODO: The copying action here isn't really necessary. For future fix, consider using the script extracts from the ar with the rename to potentially + # make both maccatalyst and other builds do the same thing. + set(CUR_TARGET_CMAKE_SOURCE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/${_LIB}.dir) + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ar -t $ | grep "\.o$" > ${_LIB}.object_file_list.txt + COMMAND ${CMAKE_COMMAND} -E env python3 ${CMAKE_CURRENT_SOURCE_DIR}/maccatalyst_prepare_objects_for_prelink.py ${CUR_TARGET_CMAKE_SOURCE_LIB_DIR} ${CUR_STATIC_LIB_OBJ_DIR} ${CUR_STATIC_LIB_OBJ_DIR}/${_LIB}.object_file_list.txt + WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR}) + else() + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ar ARGS -x $ + WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR}) + endif() endif() endforeach() diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 43d5fa9bdee3..69d8f5fa138c 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -129,7 +129,7 @@ target_include_directories(onnxruntime_common ${OPTIONAL_LITE_INCLUDE_DIR}) -target_link_libraries(onnxruntime_common PUBLIC safeint_interface ${GSL_TARGET} ${ABSEIL_LIBS}) +target_link_libraries(onnxruntime_common PUBLIC safeint_interface ${GSL_TARGET} ${ABSEIL_LIBS} date::date) add_dependencies(onnxruntime_common ${onnxruntime_EXTERNAL_DEPENDENCIES}) @@ -189,6 +189,8 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(ARM TRUE) elseif(dumpmachine_output MATCHES "^aarch64.*") set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") + set(RISCV64 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") set(X86 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") @@ -198,11 +200,7 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() -if (ARM64 OR ARM OR X86 OR X64 OR X86_64) - if((WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) OR ((ARM64 OR ARM) AND MSVC)) - # msvc compiler report syntax error with cpuinfo arm source files - # and cpuinfo does not have code for getting arm uarch info under windows - else() +if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) # Link cpuinfo if supported # Using it mainly in ARM with Android. # Its functionality in detecting x86 cpu features are lacking, so is support for Windows. @@ -210,7 +208,6 @@ if (ARM64 OR ARM OR X86 OR X64 OR X86_64) onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo ${ONNXRUNTIME_CLOG_TARGET_NAME}) endif() - endif() endif() if (NOT onnxruntime_BUILD_SHARED_LIB) diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 3f532ec2c326..4d51325b8414 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -7,8 +7,26 @@ file(GLOB_RECURSE onnxruntime_graph_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/graph/*.cc" ) -# create empty list for any excludes +# start with empty training srcs list +set(orttraining_graph_src) + +if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING) + set(orttraining_graph_src + "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc" + "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h" + ) +endif() + +if (onnxruntime_ENABLE_TRAINING) + file(GLOB_RECURSE orttraining_graph_src CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/core/graph/*.h" + "${ORTTRAINING_SOURCE_DIR}/core/graph/*.cc" + ) +endif() + +# create empty lists for any excludes set(onnxruntime_graph_src_exclude_patterns) +set(orttraining_graph_src_exclude_patterns) if (onnxruntime_MINIMAL_BUILD) # remove schema registration support @@ -22,11 +40,18 @@ if (onnxruntime_MINIMAL_BUILD) "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.cc" "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.h" "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.cc" + "${ONNXRUNTIME_ROOT}/core/graph/dml_ops/dml_defs.h" + "${ONNXRUNTIME_ROOT}/core/graph/dml_ops/dml_defs.cc" "${ONNXRUNTIME_ROOT}/core/graph/function_template.h" "${ONNXRUNTIME_ROOT}/core/graph/function_utils.h" "${ONNXRUNTIME_ROOT}/core/graph/function_utils.cc" ) + list(APPEND orttraining_graph_src_exclude_patterns + "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h" + "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc" + ) + # no Function support initially list(APPEND onnxruntime_graph_src_exclude_patterns "${ONNXRUNTIME_ROOT}/core/graph/function*" @@ -64,30 +89,12 @@ endif() file(GLOB onnxruntime_graph_src_exclude ${onnxruntime_graph_src_exclude_patterns}) list(REMOVE_ITEM onnxruntime_graph_src ${onnxruntime_graph_src_exclude}) -file(GLOB_RECURSE onnxruntime_ir_defs_src CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/defs/*.cc" -) - -if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING) - set(orttraining_graph_src - "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc" - "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h" - ) -endif() - -if (onnxruntime_ENABLE_TRAINING) - file(GLOB_RECURSE orttraining_graph_src CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/core/graph/*.h" - "${ORTTRAINING_SOURCE_DIR}/core/graph/*.cc" - ) -endif() - -set(onnxruntime_graph_lib_src ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src}) if (onnxruntime_ENABLE_TRAINING_OPS) - list(APPEND onnxruntime_graph_lib_src ${orttraining_graph_src}) + file(GLOB orttraining_graph_src_exclude ${orttraining_graph_src_exclude_patterns}) + list(REMOVE_ITEM orttraining_graph_src ${orttraining_graph_src_exclude}) endif() -onnxruntime_add_static_library(onnxruntime_graph ${onnxruntime_graph_lib_src}) +onnxruntime_add_static_library(onnxruntime_graph ${onnxruntime_graph_src} ${orttraining_graph_src}) add_dependencies(onnxruntime_graph onnx_proto flatbuffers::flatbuffers) onnxruntime_add_include_to_target(onnxruntime_graph onnxruntime_common ${WIL_TARGET} onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers safeint_interface Boost::mp11) @@ -120,7 +127,7 @@ endif() set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_graph PROPERTIES LINKER_LANGUAGE CXX) -source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src}) +source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_graph_src}) if (onnxruntime_ENABLE_TRAINING_OPS) source_group(TREE ${ORTTRAINING_ROOT} FILES ${orttraining_graph_src}) endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index bee83ff07c74..f7103c3b00a3 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -1,7 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib) +set(MLAS_ROOT ${ONNXRUNTIME_ROOT}/core/mlas) +set(MLAS_SRC_DIR ${MLAS_ROOT}/lib) +set(MLAS_INC_DIR ${MLAS_ROOT}/inc) # # All hardware agnostic source files here @@ -9,6 +11,7 @@ set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib) # multi-target build # onnxruntime_add_static_library(onnxruntime_mlas + ${MLAS_SRC_DIR}/mlasi.h ${MLAS_SRC_DIR}/platform.cpp ${MLAS_SRC_DIR}/threading.cpp ${MLAS_SRC_DIR}/sgemm.cpp @@ -33,9 +36,18 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/sqnbitgemm.h ${MLAS_SRC_DIR}/sqnbitgemm.cpp ) +target_sources(onnxruntime_mlas PRIVATE + ${MLAS_INC_DIR}/mlas_float16.h + ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h + ${MLAS_INC_DIR}/mlas_q4.h + ${MLAS_INC_DIR}/mlas_qnbit.h + ${MLAS_INC_DIR}/mlas.h +) + if (NOT onnxruntime_ORT_MINIMAL_BUILD) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/q4_dq.cpp @@ -45,15 +57,6 @@ endif() set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) -function(add_jblas) - add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas) - target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/jblas_gemm.cpp - ) - set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF) -endfunction() - #TODO: set MASM flags properly function(setup_mlas_source_for_windows) @@ -143,10 +146,6 @@ function(setup_mlas_source_for_windows) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/arm/sgemmc.cpp ) - # it should be removed after Visual Stuio is upgraded to 17.7 - if (MSVC) - add_compile_options("-d2SSAOptimizer-") - endif() elseif(onnxruntime_target_platform STREQUAL "x64") file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS @@ -198,6 +197,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/amd64/sgemma.asm ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm ${MLAS_SRC_DIR}/amd64/TransKernelAvx512F.asm ${MLAS_SRC_DIR}/amd64/LogisticKernelFma3.asm @@ -300,8 +300,8 @@ else() if(APPLE) get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) endif() - list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH) - if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH GREATER 1) + list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) + if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) endif() #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below @@ -348,25 +348,31 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") if (NOT APPLE) set(mlas_platform_srcs ${mlas_platform_srcs} ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S + ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) @@ -531,6 +537,7 @@ else() ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp @@ -612,15 +619,13 @@ else() target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) endif() -if(USE_JBLAS) - add_jblas() -endif() - foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR}) + target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET}) + + set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") endforeach() -set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime") + if (WIN32) target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") if (onnxruntime_ENABLE_STATIC_ANALYSIS) @@ -628,6 +633,12 @@ if (WIN32) endif() endif() +if (PLATFORM_NAME STREQUAL "macabi") + # Needed for maccatalyst C compilation + # i.e. the flags below add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections" + target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS}) +endif() + if (NOT onnxruntime_BUILD_SHARED_LIB) install(TARGETS onnxruntime_mlas ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} @@ -636,6 +647,21 @@ if (NOT onnxruntime_BUILD_SHARED_LIB) FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() +# set up source group for MLAS source files +block() + set(source_group_srcs) + foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) + get_target_property(mlas_target_srcs ${mlas_target} SOURCES) + foreach(mlas_target_src ${mlas_target_srcs}) + cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) + if(in_mlas_root) + list(APPEND source_group_srcs ${mlas_target_src}) + endif() + endforeach() + endforeach() + source_group(TREE ${MLAS_ROOT} FILES ${source_group_srcs}) +endblock() + if (NOT onnxruntime_ORT_MINIMAL_BUILD) @@ -647,7 +673,7 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD) onnxruntime_add_executable(onnxruntime_mlas_q4dq ${MLAS_SRC_DIR}/q4_dq_cli.cpp ) - target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR}) + target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) diff --git a/cmake/onnxruntime_nodejs.cmake b/cmake/onnxruntime_nodejs.cmake index 6053b9d1088c..555baac6f1a5 100644 --- a/cmake/onnxruntime_nodejs.cmake +++ b/cmake/onnxruntime_nodejs.cmake @@ -88,7 +88,7 @@ add_custom_target(js_common_npm_ci ALL add_custom_target(nodejs_binding_wrapper ALL COMMAND ${NPM_CLI} ci - COMMAND ${NPM_CLI} run build -- --onnxruntime-build-dir=${CMAKE_CURRENT_BINARY_DIR} --config=${CMAKE_BUILD_TYPE} + COMMAND ${NPM_CLI} run build -- --onnxruntime-build-dir=${CMAKE_CURRENT_BINARY_DIR} --config=${CMAKE_BUILD_TYPE} --onnxruntime-generator=${CMAKE_GENERATOR} --arch=${NODEJS_BINDING_ARCH} ${NODEJS_BINDING_USE_CUDA} ${NODEJS_BINDING_USE_DML} ${NODEJS_BINDING_USE_TENSORRT} ${NODEJS_BINDING_USE_COREML} WORKING_DIRECTORY ${JS_NODE_ROOT} diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index 6f09583199ff..f15d5b8dd6f8 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -130,3 +130,7 @@ if (NOT onnxruntime_BUILD_SHARED_LIB) RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() + +if (onnxruntime_USE_ROCM) + add_dependencies(onnxruntime_optimizer generate_hipified_files) +endif() diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 2cf0a6b2b9bd..6c5369ca3be3 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -67,7 +67,7 @@ if(onnxruntime_USE_CUDA) endif() if(onnxruntime_USE_COREML) if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "visionOS") - set(PROVIDERS_COREML onnxruntime_providers_coreml onnxruntime_coreml_proto) + set(PROVIDERS_COREML onnxruntime_providers_coreml coreml_proto) else() set(PROVIDERS_COREML onnxruntime_providers_coreml) endif() diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake index 7c712fc40064..b8ebc4ca5323 100644 --- a/cmake/onnxruntime_providers_coreml.cmake +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -1,107 +1,220 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) - message(FATAL_ERROR "CoreML EP can not be used in a basic minimal build. Please build with '--minimal_build extended'") - endif() +if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "CoreML EP can not be used in a basic minimal build. Please build with '--minimal_build extended'") +endif() + +add_compile_definitions(USE_COREML=1) - add_compile_definitions(USE_COREML=1) - - # Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "visionOS") - set(COREML_PROTO_ROOT ${PROJECT_SOURCE_DIR}/../onnxruntime/core/providers/coreml/mlmodel_format) - file(GLOB coreml_proto_srcs - "${COREML_PROTO_ROOT}/*.proto" - ) - onnxruntime_add_static_library(onnxruntime_coreml_proto ${coreml_proto_srcs}) - target_include_directories(onnxruntime_coreml_proto PUBLIC $ "${CMAKE_CURRENT_BINARY_DIR}") - target_compile_definitions(onnxruntime_coreml_proto PUBLIC $) - set_target_properties(onnxruntime_coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") - set_target_properties(onnxruntime_coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") - set(_src_sub_dir "coreml/") - onnxruntime_protobuf_generate( - APPEND_PATH - GEN_SRC_SUB_DIR ${_src_sub_dir} - IMPORT_DIRS ${COREML_PROTO_ROOT} - TARGET onnxruntime_coreml_proto - ) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_coreml_proto - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - endif() +# Check if we can build the coremltools code for creating an mlpackage with an mlprogram. +# The coremltools source requires std::filesystem::path which is only available from iOS 13 on. +set(_enable_ML_PROGRAM ON) +if (IOS AND CMAKE_OSX_DEPLOYMENT_TARGET VERSION_LESS 13.0) + message(WARNING "CoreML ML Program is not supported on iOS < 13.0. Excluding ML Program support from build.") + set(_enable_ML_PROGRAM OFF) +elseif(LINUX) + # uuid-dev is required. we don't bother installing on CIs as it's really for manual developer testing. + find_library(LibUUID_LIBRARY NAMES uuid) + find_path(LibUUID_INCLUDE_DIR NAMES uuid/uuid.h) + if (NOT LibUUID_INCLUDE_DIR) + message(STATUS "uuid/uuid.h was not found as is required for ML Program support. " + "Run `sudo apt install uuid-dev` if you need to test ML Program related CoreML EP code. ") + set(_enable_ML_PROGRAM OFF) endif() +endif() + +if (_enable_ML_PROGRAM) + add_compile_definitions(COREML_ENABLE_MLPROGRAM=1) +endif() + +# Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml_proto +set(COREML_PROTO_ROOT ${coremltools_SOURCE_DIR}/mlmodel/format) +file(GLOB coreml_proto_srcs "${COREML_PROTO_ROOT}/*.proto") + +onnxruntime_add_static_library(coreml_proto ${coreml_proto_srcs}) +target_include_directories(coreml_proto + PUBLIC $ + "${CMAKE_CURRENT_BINARY_DIR}") +target_compile_definitions(coreml_proto + PUBLIC $) +set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") +set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") - # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML - file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" +set(_src_sub_dir "coreml_proto/") +onnxruntime_protobuf_generate( + APPEND_PATH + GEN_SRC_SUB_DIR ${_src_sub_dir} + IMPORT_DIRS ${COREML_PROTO_ROOT} + TARGET coreml_proto +) + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS coreml_proto + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR} ) +endif() + +# Add the .proto and generated .cc/.h files to the External/coreml_proto folder in Visual Studio. +# Separate source_group for each as the .proto files are in the repo and the .cc/.h files are generated in the build +# output directory. +set_target_properties(coreml_proto PROPERTIES FOLDER "External") +source_group(TREE ${COREML_PROTO_ROOT} PREFIX coreml_proto FILES ${coreml_proto_srcs}) + +# filter to the generated .cc/.h files +get_target_property(coreml_proto_generated_srcs coreml_proto SOURCES) +list(FILTER coreml_proto_generated_srcs INCLUDE REGEX "\.pb\.(h|cc)$") +source_group(TREE ${CMAKE_CURRENT_BINARY_DIR} PREFIX coreml_proto_generated FILES ${coreml_proto_generated_srcs}) + +# These are shared utils, +# TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML +file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" +) +file(GLOB onnxruntime_providers_coreml_public_headers CONFIGURE_DEPENDS + "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/coreml/*.h" +) + +file(GLOB + onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.cc" +) + +# Add builder source code +file(GLOB_RECURSE + onnxruntime_providers_coreml_cc_srcs_nested CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" +) + +if(_enable_ML_PROGRAM) + # Add helpers to create mlpackage weights. limit to just the files we need to minimize the changes to make them + # build on Windows and Linux. file(GLOB - onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.cc" + onnxruntime_providers_coreml_milblob_cc_srcs CONFIGURE_DEPENDS + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.cpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Util/*.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/BlobDataType.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageFormat.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/FileWriter.?pp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageWriter.?pp" ) - # Add builder source code - file(GLOB_RECURSE - onnxruntime_providers_coreml_cc_srcs_nested CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" + # Add helpers to create mlpackage + file(GLOB + onnxruntime_providers_coreml_modelpackage_cc_srcs CONFIGURE_DEPENDS + "${coremltools_SOURCE_DIR}/modelpackage/src/ModelPackage.?pp" + "${coremltools_SOURCE_DIR}/modelpackage/src/utils/JsonMap.?pp" ) - if (NOT CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "visionOS") - list(REMOVE_ITEM onnxruntime_providers_coreml_cc_srcs_nested - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.cc" - ) - endif() - # Add CoreML objective c++ source code - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "visionOS") - file(GLOB - onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.mm" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" - ) - endif() - - set(onnxruntime_providers_coreml_cc_srcs - ${onnxruntime_providers_coreml_cc_srcs_top} - ${onnxruntime_providers_coreml_cc_srcs_nested} - ${onnxruntime_providers_shared_utils_cc_srcs} + set(coremltools_srcs + ${onnxruntime_providers_coreml_milblob_cc_srcs} + ${onnxruntime_providers_coreml_modelpackage_cc_srcs} ) - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_coreml_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_coreml - ${onnxruntime_providers_coreml_cc_srcs} ${onnxruntime_providers_coreml_objcc_srcs} + source_group(TREE ${coremltools_SOURCE_DIR} PREFIX coremltools FILES ${coremltools_srcs}) +endif() + +# Add CoreML objective c++ source code +if (APPLE) + file(GLOB + onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.mm" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" ) - onnxruntime_add_include_to_target(onnxruntime_providers_coreml - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface +else() + # add the Model implementation that uses the protobuf types but excludes any actual CoreML dependencies + # by using stub implementations on non-Apple platforms. + file(GLOB + onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils_stub.cc" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model_stub.cc" ) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "visionOS") - onnxruntime_add_include_to_target(onnxruntime_providers_coreml onnxruntime_coreml_proto) - target_link_libraries(onnxruntime_providers_coreml PRIVATE onnxruntime_coreml_proto "-framework Foundation" "-framework CoreML") - add_dependencies(onnxruntime_providers_coreml onnxruntime_coreml_proto) - endif() - add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) - - set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON) - set_target_properties(onnxruntime_providers_coreml PROPERTIES FOLDER "ONNXRuntime") - target_include_directories(onnxruntime_providers_coreml PRIVATE ${ONNXRUNTIME_ROOT} ${coreml_INCLUDE_DIRS}) - set_target_properties(onnxruntime_providers_coreml PROPERTIES LINKER_LANGUAGE CXX) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_coreml - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() + +set(onnxruntime_providers_coreml_cc_srcs + ${onnxruntime_providers_coreml_cc_srcs_top} + ${onnxruntime_providers_coreml_cc_srcs_nested} + ${onnxruntime_providers_shared_utils_cc_srcs} + ${onnxruntime_providers_coreml_objcc_srcs} +) + +source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_coreml_cc_srcs}) +source_group(TREE ${ONNXRUNTIME_INCLUDE_DIR} FILES ${onnxruntime_providers_coreml_public_headers}) + +onnxruntime_add_static_library(onnxruntime_providers_coreml + ${onnxruntime_providers_coreml_public_headers} + ${onnxruntime_providers_coreml_cc_srcs} + ${coremltools_srcs} +) + +onnxruntime_add_include_to_target(onnxruntime_providers_coreml + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 + safeint_interface +) + +onnxruntime_add_include_to_target(onnxruntime_providers_coreml coreml_proto) +target_link_libraries(onnxruntime_providers_coreml PRIVATE coreml_proto) +add_dependencies(onnxruntime_providers_coreml coreml_proto) + +if (APPLE) + target_compile_definitions(onnxruntime_providers_coreml PRIVATE __APPLE__) +endif() + +if (_enable_ML_PROGRAM) + # Setup coremltools fp16 and json dependencies for creating an mlpackage. + # + # These are also used by external/xnnpack.cmake. fp16 depends on psimd + FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd}) + onnxruntime_fetchcontent_makeavailable(psimd) + set(PSIMD_SOURCE_DIR ${psimd_SOURCE_DIR}) + FetchContent_Declare(fp16 URL ${DEP_URL_fp16} URL_HASH SHA1=${DEP_SHA1_fp16}) + set(FP16_BUILD_TESTS OFF CACHE INTERNAL "") + set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") + onnxruntime_fetchcontent_makeavailable(fp16) + + # need to tweak the include paths to match what the coreml source code expects + target_include_directories(onnxruntime_providers_coreml PRIVATE + ${fp16_SOURCE_DIR}/include + ${nlohmann_json_SOURCE_DIR}/single_include/nlohmann + ${coremltools_SOURCE_DIR} + ${coremltools_SOURCE_DIR}/mlmodel/src/ + ${coremltools_SOURCE_DIR}/modelpackage/src/ + ) + + add_dependencies(onnxruntime_providers_coreml nlohmann_json::nlohmann_json fp16) + + if (LINUX) + target_link_libraries(onnxruntime_providers_coreml PRIVATE uuid) endif() +endif() + +if (APPLE) + target_link_libraries(onnxruntime_providers_coreml PRIVATE "-framework Foundation" "-framework CoreML") +endif() + +add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) + +set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON) +set_target_properties(onnxruntime_providers_coreml PROPERTIES FOLDER "ONNXRuntime") +target_include_directories(onnxruntime_providers_coreml PRIVATE ${ONNXRUNTIME_ROOT} ${coreml_INCLUDE_DIRS}) +set_target_properties(onnxruntime_providers_coreml PROPERTIES LINKER_LANGUAGE CXX) + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_coreml + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index 397ef5b5b50e..b211c02f712b 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -60,6 +60,15 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" ) endif() + set(onnxruntime_cpu_neural_speed_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h" + ) + if(NOT USE_NEURAL_SPEED) + list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs}) + endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) @@ -144,6 +153,12 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL) target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical") endif() +if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if(USE_NEURAL_SPEED) + onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla) + endif() +endif() + if (MSVC) target_compile_options(onnxruntime_providers PRIVATE "/bigobj") # if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 84d1376f99d5..1346a9ce968c 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -1,10 +1,25 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" - ) + + if (onnxruntime_CUDA_MINIMAL) + file(GLOB onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.cc" + ) + # Remove pch files + list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs + "${ONNXRUNTIME_ROOT}/core/providers/cuda/integer_gemm.cc" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/triton_kernel.h" + ) + else() + file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" + ) + endif() # Remove pch files list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h" @@ -16,11 +31,16 @@ "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) - file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh" - ) + + if (onnxruntime_CUDA_MINIMAL) + set(onnxruntime_providers_cuda_shared_srcs "") + else() + file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh" + ) + endif() source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) @@ -102,7 +122,7 @@ endif() if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) # cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and - # add to the lib onnxruntime_providers_cuda separatedly. + # added to the lib onnxruntime_providers_cuda separately. # onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc. set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc) list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src}) @@ -121,18 +141,22 @@ if (HAS_GUARD_CF) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /guard:cf>") endif() + if (HAS_QSPECTRE) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /Qspectre>") endif() + foreach(ORT_FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler \"${ORT_FLAG}\">") endforeach() + # CUDA 11.3+ supports parallel compilation # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver-threads if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.3) option(onnxruntime_NVCC_THREADS "Number of threads that NVCC can use for compilation." 1) target_compile_options(${target} PRIVATE "$<$:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">") endif() + if (UNIX) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-reorder>" "$<$>:-Wno-reorder>") @@ -142,6 +166,13 @@ #mutex.cuh(91): warning C4834: discarding return value of function with 'nodiscard' attribute target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4834>") target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4127>") + if (MSVC) + # the VS warnings for 'Conditional Expression is Constant' are spurious as they don't handle multiple conditions + # e.g. `if (std::is_same_v && not_a_const)` will generate the warning even though constexpr cannot + # be used due to `&& not_a_const`. This affects too many places for it to be reasonable to disable at a finer + # granularity. + target_compile_options(${target} PRIVATE "$<$:/wd4127>") + endif() endif() onnxruntime_add_include_to_target(${target} onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers) @@ -156,10 +187,16 @@ endif() add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) - if(onnxruntime_CUDNN_HOME) - target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) - target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) + if(onnxruntime_CUDA_MINIMAL) + target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL) + target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart) + else() + target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand CUDA::cufft CUDA::cudart + ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + if(onnxruntime_CUDNN_HOME) + target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) + target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) + endif() endif() if (onnxruntime_USE_TRITON_KERNEL) @@ -171,25 +208,24 @@ target_include_directories(${target} PRIVATE ${triton_kernel_header_dir}) target_link_libraries(${target} PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive) # lib cuda needed by cuLaunchKernel - target_link_libraries(${target} PRIVATE cuda) + target_link_libraries(${target} PRIVATE CUDA::cuda_driver) endif() include(cutlass) - target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) + target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) - target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} + PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA) set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime") if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling - target_include_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/include) - target_link_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/lib64) - target_link_libraries(${target} PRIVATE cupti) + target_link_libraries(${target} PRIVATE CUDA::cupti) endif() - if (onnxruntime_ENABLE_NVTX_PROFILE AND NOT WIN32) - target_link_libraries(${target} PRIVATE nvToolsExt) + if (onnxruntime_ENABLE_NVTX_PROFILE) + target_link_libraries(${target} PRIVATE CUDA::nvtx3) endif() if (onnxruntime_ENABLE_TRAINING_OPS) diff --git a/cmake/onnxruntime_providers_nnapi.cmake b/cmake/onnxruntime_providers_nnapi.cmake index 5ac25a3b76ef..b718a976eb26 100644 --- a/cmake/onnxruntime_providers_nnapi.cmake +++ b/cmake/onnxruntime_providers_nnapi.cmake @@ -49,12 +49,10 @@ endif() # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML + # TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML list(APPEND onnxruntime_provider_nnapi_cc_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" ) file(GLOB onnxruntime_providers_nnapi_cc_srcs CONFIGURE_DEPENDS ${onnxruntime_provider_nnapi_cc_src_patterns}) @@ -81,4 +79,4 @@ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index e26f0bfc0b75..5876b2b5c448 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -16,23 +16,19 @@ endif() # Header paths - find_package(InferenceEngine REQUIRED) - find_package(ngraph REQUIRED) - - if (OPENVINO_2022_1 OR OPENVINO_2022_2) find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX) - list (OV_20_LIBS openvino::frontend::onnx openvino::runtime) + if(OpenVINO_VERSION VERSION_LESS 2023.0) + message(FATAL_ERROR "OpenVINO 2023.0 and newer are supported. Please, latest OpenVINO release") endif() if (WIN32) unset(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO) endif() + list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime ${PYTHON_LIBRARIES}) if ((DEFINED ENV{OPENCL_LIBS}) AND (DEFINED ENV{OPENCL_INCS})) add_definitions(-DIO_BUFFER_ENABLED=1) - list(APPEND OPENVINO_LIB_LIST $ENV{OPENCL_LIBS} ${OV_20_LIBS} ${InferenceEngine_LIBRARIES} ${NGRAPH_LIBRARIES} ngraph::onnx_importer ${PYTHON_LIBRARIES}) - else() - list(APPEND OPENVINO_LIB_LIST ${OV_20_LIBS} ${InferenceEngine_LIBRARIES} ${NGRAPH_LIBRARIES} ngraph::onnx_importer ${PYTHON_LIBRARIES}) + list(APPEND OPENVINO_LIB_LIST $ENV{OPENCL_LIBS}) endif() source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_openvino_cc_srcs}) @@ -75,7 +71,14 @@ message(FATAL_ERROR "onnxruntime_providers_openvino unknown platform, need to specify shared library exports for it") endif() - install(TARGETS onnxruntime_providers_openvino - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file + if (CMAKE_OPENVINO_LIBRARY_INSTALL_DIR) + install(TARGETS onnxruntime_providers_openvino + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_OPENVINO_LIBRARY_INSTALL_DIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + else() + install(TARGETS onnxruntime_providers_openvino + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() diff --git a/cmake/onnxruntime_providers_qnn.cmake b/cmake/onnxruntime_providers_qnn.cmake index a93a06e960c8..b68d84c23bb3 100644 --- a/cmake/onnxruntime_providers_qnn.cmake +++ b/cmake/onnxruntime_providers_qnn.cmake @@ -4,12 +4,10 @@ add_compile_definitions(USE_QNN=1) # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than QNN, NNAPI and CoreML - file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS + # TODO, move to a separate lib when used by EPs other than QNN, NNAPI and CoreML + file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" ) file(GLOB_RECURSE @@ -42,4 +40,4 @@ # ignore the warning unknown-pragmas on "pragma region" if(NOT MSVC) target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas") - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 686a993de3a4..15ffc29e79ff 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -8,7 +8,7 @@ set(BUILD_LIBRARY_ONLY 1) add_definitions("-DONNX_ML=1") add_definitions("-DONNX_NAMESPACE=onnx") - set(CUDA_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + set(CUDA_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIRS}) set(TENSORRT_ROOT ${onnxruntime_TENSORRT_HOME}) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) set(PROTOBUF_LIBRARY ${PROTOBUF_LIB}) @@ -58,7 +58,7 @@ URL_HASH SHA1=${DEP_SHA1_onnx_tensorrt} ) if (NOT CUDA_INCLUDE_DIR) - set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # onnx-tensorrt repo needs this variable to build + set(CUDA_INCLUDE_DIR ${CUDAToolkit_INCLUDE_DIRS}) # onnx-tensorrt repo needs this variable to build endif() # The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses # unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose. @@ -102,11 +102,12 @@ onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS}) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart) else() - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS}) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart) endif() - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} + PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) if(onnxruntime_CUDNN_HOME) target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include) endif() diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake index 0951c2d02664..183a3e196af4 100644 --- a/cmake/onnxruntime_providers_vitisai.cmake +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -14,14 +14,19 @@ "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto) - target_link_libraries(onnxruntime_providers_vitisai PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json) - if(NOT MSVC) - target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) - endif(NOT MSVC) + onnxruntime_add_shared_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_vitisai ${ONNXRUNTIME_PROVIDERS_SHARED} nlohmann_json::nlohmann_json safeint_interface flatbuffers::flatbuffers) + target_link_libraries(onnxruntime_providers_vitisai PRIVATE ${ONNXRUNTIME_PROVIDERS_SHARED}) + if(MSVC) + onnxruntime_add_include_to_target(onnxruntime_providers_vitisai dbghelp) + set_property(TARGET onnxruntime_providers_vitisai APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/vitisai/symbols.def") + else(MSVC) + set_property(TARGET onnxruntime_providers_vitisai APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/vitisai/version_script.lds -Xlinker --gc-sections") + endif(MSVC) target_include_directories(onnxruntime_providers_vitisai PRIVATE "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include" ${XRT_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/VitisAI) if(MSVC) @@ -30,17 +35,18 @@ target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4251") # for unused formal parameter target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4100") + # for type name first seen using 'class' now seen using 'struct' + target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4099") else(MSVC) + target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) target_compile_options(onnxruntime_providers_vitisai PRIVATE -Wno-unused-parameter) endif(MSVC) set_target_properties(onnxruntime_providers_vitisai PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_providers_vitisai PROPERTIES LINKER_LANGUAGE CXX) - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_vitisai - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + install(TARGETS onnxruntime_providers_vitisai + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_providers_xnnpack.cmake b/cmake/onnxruntime_providers_xnnpack.cmake index 9c00703ca084..796536ac9d12 100644 --- a/cmake/onnxruntime_providers_xnnpack.cmake +++ b/cmake/onnxruntime_providers_xnnpack.cmake @@ -7,9 +7,6 @@ "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/xnnpack/*.h" "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.h" "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.cc" - # utils for handling QDQ models - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" ) source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs}) @@ -19,6 +16,12 @@ flatbuffers::flatbuffers Boost::mp11 safeint_interface ) + # TODO fix stringop-overflow warnings + # Add compile option to suppress stringop-overflow error in Flatbuffers. + if (HAS_STRINGOP_OVERFLOW) + target_compile_options(onnxruntime_providers_xnnpack PRIVATE -Wno-error=stringop-overflow) + endif() + add_dependencies(onnxruntime_providers_xnnpack onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_xnnpack PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 86c1071dba98..17e0f1c5f3fb 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -170,7 +170,6 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_session ${onnxruntime_libs} ${PROVIDERS_TVM} - ${PROVIDERS_VITISAI} ${PROVIDERS_NNAPI} ${PROVIDERS_XNNPACK} ${PROVIDERS_COREML} @@ -283,10 +282,7 @@ if (WIN32) get_filename_component(CUDNN_DLL_NAME ${CUDNN_DLL_PATH} NAME_WE) string(REPLACE "cudnn64_" "" CUDNN_VERSION "${CUDNN_DLL_NAME}") if(NOT onnxruntime_CUDA_VERSION) - message("Reading json file ${onnxruntime_CUDA_HOME}/version.json") - set(CUDA_SDK_JSON_FILE_PATH "${onnxruntime_CUDA_HOME}/version.json") - file(READ ${CUDA_SDK_JSON_FILE_PATH} CUDA_SDK_JSON_CONTENT) - string(JSON onnxruntime_CUDA_VERSION GET ${CUDA_SDK_JSON_CONTENT} "cuda" "version") + set(onnxruntime_CUDA_VERSION ${CUDAToolkit_VERSION}) message("onnxruntime_CUDA_VERSION=${onnxruntime_CUDA_VERSION}") endif() file(APPEND "${VERSION_INFO_FILE}" @@ -354,9 +350,6 @@ if (onnxruntime_ENABLE_TRAINING) file(GLOB onnxruntime_python_optim_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/optim/*.py" ) - file(GLOB onnxruntime_python_torchdynamo_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/python/training/torchdynamo/*.py" - ) file(GLOB onnxruntime_python_ortmodule_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/*.py" ) @@ -477,6 +470,9 @@ file(GLOB onnxruntime_python_transformers_models_llama_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py" ) +file(GLOB onnxruntime_python_transformers_models_phi2_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/phi2/*.py" +) file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py" ) @@ -547,6 +543,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/gpt2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/llama COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/longformer + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/phi2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/stable_diffusion COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/t5 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/whisper @@ -650,6 +647,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_longformer_src} $/onnxruntime/transformers/models/longformer/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_phi2_src} + $/onnxruntime/transformers/models/phi2/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_stable_diffusion_src} $/onnxruntime/transformers/models/stable_diffusion/ @@ -746,7 +746,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/experimental COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/experimental/gradient_graph COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/optim - COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/torchdynamo COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/experimental COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/experimental/json_config @@ -777,9 +776,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_optim_srcs} $/onnxruntime/training/optim/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_torchdynamo_srcs} - $/onnxruntime/training/torchdynamo/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_srcs} $/onnxruntime/training/ortmodule/ @@ -859,6 +855,16 @@ if (onnxruntime_USE_DNNL) ) endif() +if (onnxruntime_USE_VITISAI) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${DNNL_DLL_PATH} $ + $ + $/onnxruntime/capi/ + ) +endif() + if (onnxruntime_USE_TENSORRT) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD @@ -995,6 +1001,15 @@ if (onnxruntime_USE_COREML) ) endif() +if (onnxruntime_USE_QNN) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${QNN_LIB_FILES} + $/onnxruntime/capi/ + ) +endif() + endif() if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) include(onnxruntime_language_interop_ops.cmake) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index f70961a66329..0051f241e4f9 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -20,10 +20,6 @@ set(contrib_ops_excluded_files "bert/fastertransformer_decoder_attention/*" "bert/multihead_attention.cc" "bert/multihead_attention.h" - "bert/fast_gelu_impl.cu" - "bert/fast_gelu_impl.h" - "bert/fast_gelu.cc" - "bert/fast_gelu.h" "bert/relative_attn_bias.cc" "bert/relative_attn_bias.h" "bert/relative_attn_bias_impl.cu" @@ -44,9 +40,7 @@ set(contrib_ops_excluded_files "bert/packed_multihead_attention.cc" "bert/packed_multihead_attention_impl.h" "bert/packed_multihead_attention_impl.cu" - "diffusion/group_norm.cc" "diffusion/group_norm_impl.cu" - "diffusion/group_norm_impl.h" "diffusion/nhwc_conv.cc" "math/gemm_float8.cc" "math/gemm_float8.cu" @@ -66,6 +60,8 @@ set(contrib_ops_excluded_files "quantization/matmul_nbits.cc" "quantization/matmul_nbits.cuh" "quantization/matmul_nbits.cu" + "quantization/moe_quantization.h" + "quantization/moe_quantization.cc" "quantization/quantize_dequantize_linear.cc" "quantization/qordered_ops/qordered_attention_impl.cu" "quantization/qordered_ops/qordered_attention_impl.h" @@ -100,26 +96,18 @@ set(contrib_ops_excluded_files "bert/group_query_attention.cc" "bert/group_query_attention_impl.h" "bert/group_query_attention_impl.cu" + "collective/distributed_*" + "collective/shard*" ) -if (NOT onnxruntime_ENABLE_ATEN) - list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc") -endif() if (NOT onnxruntime_USE_NCCL) # Those are string patterns to exclude. Do NOT use stars such as # collective/*.cc or *.h. list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc") - list(APPEND contrib_ops_excluded_files "collective/sharded_moe.h") - list(APPEND contrib_ops_excluded_files "collective/sharded_moe.cc") - list(APPEND contrib_ops_excluded_files "collective/sharding.cc") - list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_unsqueeze.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_squeeze.cc") +endif() + +if (NOT onnxruntime_ENABLE_ATEN) + list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc") endif() set(provider_excluded_files diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 6991081f1b0d..fce60090b81f 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") +if (IOS) find_package(XCTest REQUIRED) endif() @@ -18,7 +18,7 @@ function(AddTest) cmake_parse_arguments(_UT "DYN" "TARGET" "LIBS;SOURCES;DEPENDS;TEST_ARGS" ${ARGN}) list(REMOVE_DUPLICATES _UT_SOURCES) - if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + if (IOS) onnxruntime_add_executable(${_UT_TARGET} ${TEST_SRC_DIR}/xctest/orttestmain.m) else() onnxruntime_add_executable(${_UT_TARGET} ${_UT_SOURCES}) @@ -67,7 +67,7 @@ function(AddTest) if(onnxruntime_USE_CUDA) #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, # otherwise it will impact when CUDA DLLs can be unloaded. - target_link_libraries(${_UT_TARGET} PRIVATE cudart) + target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart) endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -111,7 +111,9 @@ function(AddTest) target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options -Wno-error=sign-compare>" "$<$>:-Wno-error=sign-compare>") - target_compile_options(${_UT_TARGET} PRIVATE "-Wno-error=uninitialized") + if (${HAS_NOERROR}) + target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Wno-error=uninitialized>") + endif() endif() set(TEST_ARGS ${_UT_TEST_ARGS}) @@ -127,7 +129,7 @@ function(AddTest) endif() endif(onnxruntime_GENERATE_TEST_REPORTS) - if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + if (IOS) # target_sources(${_UT_TARGET} PRIVATE ${TEST_SRC_DIR}/xctest/orttestmain.m) set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest" MACOSX_BUNDLE_BUNDLE_NAME ${_UT_TARGET} @@ -565,11 +567,7 @@ if(onnxruntime_USE_ROCM) endif() if(onnxruntime_USE_COREML) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml onnxruntime_coreml_proto) - else() - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) - endif() + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_ACL) @@ -591,7 +589,6 @@ set(ONNXRUNTIME_TEST_LIBS # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime ${PROVIDERS_NNAPI} ${PROVIDERS_JS} - ${PROVIDERS_VITISAI} ${PROVIDERS_QNN} ${PROVIDERS_SNPE} ${PROVIDERS_RKNPU} @@ -675,15 +672,9 @@ endif() if(onnxruntime_USE_COREML) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/coreml/*) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml onnxruntime_coreml_proto) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml onnxruntime_coreml_proto) - list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml onnxruntime_coreml_proto) - else() - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) - list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml) - endif() + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) + list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_XNNPACK) @@ -743,34 +734,37 @@ target_include_directories(onnxruntime_test_utils PUBLIC "${TEST_SRC_DIR}/util/i set_target_properties(onnxruntime_test_utils PROPERTIES FOLDER "ONNXRuntimeTest") source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_test_utils_src}) -set(onnx_test_runner_src_dir ${TEST_SRC_DIR}/onnx) -file(GLOB onnx_test_runner_common_srcs CONFIGURE_DEPENDS - ${onnx_test_runner_src_dir}/*.h - ${onnx_test_runner_src_dir}/*.cc) +if(NOT IOS) + set(onnx_test_runner_src_dir ${TEST_SRC_DIR}/onnx) + file(GLOB onnx_test_runner_common_srcs CONFIGURE_DEPENDS + ${onnx_test_runner_src_dir}/*.h + ${onnx_test_runner_src_dir}/*.cc) -list(REMOVE_ITEM onnx_test_runner_common_srcs ${onnx_test_runner_src_dir}/main.cc) + list(REMOVE_ITEM onnx_test_runner_common_srcs ${onnx_test_runner_src_dir}/main.cc) -onnxruntime_add_static_library(onnx_test_runner_common ${onnx_test_runner_common_srcs}) -if(MSVC) - target_compile_options(onnx_test_runner_common PRIVATE "$<$:SHELL:--compiler-options /utf-8>" - "$<$>:/utf-8>") -else() - target_compile_definitions(onnx_test_runner_common PUBLIC -DNSYNC_ATOMIC_CPP11) - target_include_directories(onnx_test_runner_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) - onnxruntime_add_include_to_target(onnx_test_runner_common nsync::nsync_cpp) -endif() -if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) - #TODO: fix the warnings, they are dangerous - target_compile_options(onnx_test_runner_common PRIVATE "/wd4244") -endif() -onnxruntime_add_include_to_target(onnx_test_runner_common onnxruntime_common onnxruntime_framework - onnxruntime_test_utils onnx onnx_proto re2::re2 flatbuffers::flatbuffers Boost::mp11 safeint_interface) + onnxruntime_add_static_library(onnx_test_runner_common ${onnx_test_runner_common_srcs}) + if(MSVC) + target_compile_options(onnx_test_runner_common PRIVATE "$<$:SHELL:--compiler-options /utf-8>" + "$<$>:/utf-8>") + else() + target_compile_definitions(onnx_test_runner_common PUBLIC -DNSYNC_ATOMIC_CPP11) + target_include_directories(onnx_test_runner_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) + onnxruntime_add_include_to_target(onnx_test_runner_common nsync::nsync_cpp) + endif() + if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) + #TODO: fix the warnings, they are dangerous + target_compile_options(onnx_test_runner_common PRIVATE "/wd4244") + endif() + onnxruntime_add_include_to_target(onnx_test_runner_common onnxruntime_common onnxruntime_framework + onnxruntime_test_utils onnx onnx_proto re2::re2 flatbuffers::flatbuffers Boost::mp11 safeint_interface) -add_dependencies(onnx_test_runner_common onnx_test_data_proto ${onnxruntime_EXTERNAL_DEPENDENCIES}) -target_include_directories(onnx_test_runner_common PRIVATE ${eigen_INCLUDE_DIRS} - ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) + add_dependencies(onnx_test_runner_common onnx_test_data_proto ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_include_directories(onnx_test_runner_common PRIVATE ${eigen_INCLUDE_DIRS} + ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) -set_target_properties(onnx_test_runner_common PROPERTIES FOLDER "ONNXRuntimeTest") + set_target_properties(onnx_test_runner_common PROPERTIES FOLDER "ONNXRuntimeTest") + set(onnx_test_runner_common_lib onnx_test_runner_common) +endif() set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantiztion_src}) @@ -783,7 +777,15 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) + target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) + if (MSVC) + # Cutlass code has an issue with the following: + # warning C4100: 'magic': unreferenced formal parameter + target_compile_options(onnxruntime_providers_cuda_ut PRIVATE "$<$:SHELL:--compiler-options /wd4100>" + "$<$>:/wd4100>") + endif() + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut) endif() @@ -824,6 +826,17 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") "${TEST_SRC_DIR}/providers/memcpy_test.cc" ) endif() + list(REMOVE_ITEM all_tests "${TEST_SRC_DIR}/providers/cpu/reduction/reduction_ops_test.cc" + "${TEST_SRC_DIR}/providers/cpu/tensor/grid_sample_test.cc") +endif() + +if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR IOS) + # Because we do not run these model tests in our web or iOS CI build pipelines, and some test code uses C++17 + # filesystem functions that are not available in the iOS version we target. + message("Disable model tests in onnxruntime_test_all") + list(REMOVE_ITEM all_tests + "${TEST_SRC_DIR}/providers/cpu/model_tests.cc" + ) endif() set(test_all_args) @@ -843,7 +856,7 @@ AddTest( TARGET onnxruntime_test_all SOURCES ${all_tests} ${onnxruntime_unittest_main_src} LIBS - onnx_test_runner_common ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs} + ${onnx_test_runner_common_lib} ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs} onnx_test_data_proto DEPENDS ${all_dependencies} TEST_ARGS ${test_all_args} @@ -881,7 +894,7 @@ endif() # the default logger tests conflict with the need to have an overall default logger # so skip in this type of target_compile_definitions(onnxruntime_test_all PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS) -if (CMAKE_SYSTEM_NAME STREQUAL "iOS") +if (IOS) target_compile_definitions(onnxruntime_test_all_xc PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS) endif() if(onnxruntime_RUN_MODELTEST_IN_DEBUG_MODE) @@ -906,7 +919,7 @@ if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js) - set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") + set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s INITIAL_MEMORY=536870912 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") endif() @@ -969,39 +982,11 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (onnxruntime_USE_QNN) - if (NOT QNN_ARCH_ABI) - string(TOLOWER ${onnxruntime_target_platform} GEN_PLATFORM) - if(MSVC) - message(STATUS "Building MSVC for architecture ${CMAKE_SYSTEM_PROCESSOR} with CMAKE_GENERATOR_PLATFORM as ${GEN_PLATFORM}") - if (${GEN_PLATFORM} STREQUAL "arm64") - set(QNN_ARCH_ABI aarch64-windows-msvc) - else() - set(QNN_ARCH_ABI x86_64-windows-msvc) - endif() - else() - if (${CMAKE_SYSTEM_NAME} STREQUAL "Android") - set(QNN_ARCH_ABI aarch64-android-clang6.0) - elseif (${CMAKE_SYSTEM_NAME} STREQUAL "Linux") - if (${GEN_PLATFORM} STREQUAL "x86_64") - set(QNN_ARCH_ABI x86_64-linux-clang) - else() - set(QNN_ARCH_ABI aarch64-android) - endif() - endif() - endif() - endif() - if (MSVC OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") - file(GLOB QNN_LIB_FILES LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/*.so" "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/*.dll") - if (${QNN_ARCH_ABI} STREQUAL "aarch64-windows-msvc") - file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/hexagon-v68/unsigned/libQnnHtpV68Skel.so" "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so") - list(APPEND QNN_LIB_FILES ${EXTRA_HTP_LIB}) - endif() - message(STATUS "QNN lib files: " ${QNN_LIB_FILES}) - add_custom_command( - TARGET ${test_data_target} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} $ - ) + add_custom_command( + TARGET ${test_data_target} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} $ + ) endif() endif() @@ -1052,45 +1037,42 @@ if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) list(APPEND onnx_test_libs onnxruntime_language_interop onnxruntime_pyop) endif() -onnxruntime_add_executable(onnx_test_runner ${onnx_test_runner_src_dir}/main.cc) -if(MSVC) - target_compile_options(onnx_test_runner PRIVATE "$<$:SHELL:--compiler-options /utf-8>" - "$<$>:/utf-8>") -endif() -if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set_target_properties(onnx_test_runner PROPERTIES - XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" - ) -endif() -if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_target_properties(onnx_test_runner PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") - else() - set_target_properties(onnx_test_runner PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1") - endif() -endif() +if (NOT IOS) + onnxruntime_add_executable(onnx_test_runner ${onnx_test_runner_src_dir}/main.cc) + if(MSVC) + target_compile_options(onnx_test_runner PRIVATE "$<$:SHELL:--compiler-options /utf-8>" + "$<$>:/utf-8>") + endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + set_target_properties(onnx_test_runner PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") + else() + set_target_properties(onnx_test_runner PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1") + endif() + endif() -target_link_libraries(onnx_test_runner PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs} nlohmann_json::nlohmann_json) -target_include_directories(onnx_test_runner PRIVATE ${ONNXRUNTIME_ROOT}) -if (onnxruntime_USE_ROCM) - target_include_directories(onnx_test_runner PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) -endif() -if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) - target_link_libraries(onnx_test_runner PRIVATE Python::Python) -endif() -set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest") + target_link_libraries(onnx_test_runner PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs} nlohmann_json::nlohmann_json) + target_include_directories(onnx_test_runner PRIVATE ${ONNXRUNTIME_ROOT}) + if (onnxruntime_USE_ROCM) + target_include_directories(onnx_test_runner PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) + endif() + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + target_link_libraries(onnx_test_runner PRIVATE Python::Python) + endif() + set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest") -if (onnxruntime_USE_TVM) - if (WIN32) - target_link_options(onnx_test_runner PRIVATE "/STACK:4000000") - endif() -endif() + if (onnxruntime_USE_TVM) + if (WIN32) + target_link_options(onnx_test_runner PRIVATE "/STACK:4000000") + endif() + endif() -install(TARGETS onnx_test_runner - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - BUNDLE DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + install(TARGETS onnx_test_runner + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + BUNDLE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if(onnxruntime_BUILD_BENCHMARKS) @@ -1171,90 +1153,80 @@ endif() if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) - #perf test runner - set(onnxruntime_perf_test_src_dir ${TEST_SRC_DIR}/perftest) - set(onnxruntime_perf_test_src_patterns - "${onnxruntime_perf_test_src_dir}/*.cc" - "${onnxruntime_perf_test_src_dir}/*.h") + if(NOT IOS) + #perf test runner + set(onnxruntime_perf_test_src_dir ${TEST_SRC_DIR}/perftest) + set(onnxruntime_perf_test_src_patterns + "${onnxruntime_perf_test_src_dir}/*.cc" + "${onnxruntime_perf_test_src_dir}/*.h") - if(WIN32) - list(APPEND onnxruntime_perf_test_src_patterns - "${onnxruntime_perf_test_src_dir}/windows/*.cc" - "${onnxruntime_perf_test_src_dir}/windows/*.h" ) - else () - list(APPEND onnxruntime_perf_test_src_patterns - "${onnxruntime_perf_test_src_dir}/posix/*.cc" - "${onnxruntime_perf_test_src_dir}/posix/*.h" ) - endif() + if(WIN32) + list(APPEND onnxruntime_perf_test_src_patterns + "${onnxruntime_perf_test_src_dir}/windows/*.cc" + "${onnxruntime_perf_test_src_dir}/windows/*.h" ) + else () + list(APPEND onnxruntime_perf_test_src_patterns + "${onnxruntime_perf_test_src_dir}/posix/*.cc" + "${onnxruntime_perf_test_src_dir}/posix/*.h" ) + endif() - file(GLOB onnxruntime_perf_test_src CONFIGURE_DEPENDS - ${onnxruntime_perf_test_src_patterns} - ) - onnxruntime_add_executable(onnxruntime_perf_test ${onnxruntime_perf_test_src} ${ONNXRUNTIME_ROOT}/core/platform/path_lib.cc) - if(MSVC) - target_compile_options(onnxruntime_perf_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" + file(GLOB onnxruntime_perf_test_src CONFIGURE_DEPENDS + ${onnxruntime_perf_test_src_patterns} + ) + onnxruntime_add_executable(onnxruntime_perf_test ${onnxruntime_perf_test_src} ${ONNXRUNTIME_ROOT}/core/platform/path_lib.cc) + if(MSVC) + target_compile_options(onnxruntime_perf_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") - endif() - target_include_directories(onnxruntime_perf_test PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT} + endif() + target_include_directories(onnxruntime_perf_test PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir} ${CMAKE_CURRENT_BINARY_DIR}) - if (onnxruntime_USE_ROCM) - target_include_directories(onnxruntime_perf_test PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) - endif() - if (WIN32) - target_compile_options(onnxruntime_perf_test PRIVATE ${disabled_warnings}) - if (NOT DEFINED SYS_PATH_LIB) - set(SYS_PATH_LIB shlwapi) + if (onnxruntime_USE_ROCM) + target_include_directories(onnxruntime_perf_test PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) + endif() + if (WIN32) + target_compile_options(onnxruntime_perf_test PRIVATE ${disabled_warnings}) + if (NOT DEFINED SYS_PATH_LIB) + set(SYS_PATH_LIB shlwapi) + endif() endif() - endif() - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set_target_properties(onnxruntime_perf_test PROPERTIES - XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" - ) - endif() - if (onnxruntime_BUILD_SHARED_LIB) - #It will dynamically link to onnxruntime. So please don't add onxruntime_graph/onxruntime_framework/... here. - #onnxruntime_common is kind of ok because it is thin, tiny and totally stateless. - set(onnxruntime_perf_test_libs + if (onnxruntime_BUILD_SHARED_LIB) + #It will dynamically link to onnxruntime. So please don't add onxruntime_graph/onxruntime_framework/... here. + #onnxruntime_common is kind of ok because it is thin, tiny and totally stateless. + set(onnxruntime_perf_test_libs onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime onnxruntime_flatbuffers onnx_test_data_proto ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) - if(NOT WIN32) - list(APPEND onnxruntime_perf_test_libs nsync::nsync_cpp) - if(onnxruntime_USE_SNPE) - list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe) + if(NOT WIN32) + list(APPEND onnxruntime_perf_test_libs nsync::nsync_cpp) + if(onnxruntime_USE_SNPE) + list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe) + endif() endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") + list(APPEND onnxruntime_perf_test_libs ${android_shared_libs}) + endif() + target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads) + if(WIN32) + target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32) + endif() + else() + target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs}) endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") - list(APPEND onnxruntime_perf_test_libs ${android_shared_libs}) - endif() - target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads) - if(WIN32) - target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32) - endif() - if(tensorflow_C_PACKAGE_PATH) - target_include_directories(onnxruntime_perf_test PRIVATE ${tensorflow_C_PACKAGE_PATH}/include) - target_link_directories(onnxruntime_perf_test PRIVATE ${tensorflow_C_PACKAGE_PATH}/lib) - target_link_libraries(onnxruntime_perf_test PRIVATE tensorflow) - target_compile_definitions(onnxruntime_perf_test PRIVATE HAVE_TENSORFLOW) - endif() - else() - target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs}) - endif() - set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") + set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") - if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS AND NOT onnxruntime_BUILD_SHARED_LIB) - target_link_libraries(onnxruntime_perf_test PRIVATE onnxruntime_language_interop onnxruntime_pyop) - endif() + if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS AND NOT onnxruntime_BUILD_SHARED_LIB) + target_link_libraries(onnxruntime_perf_test PRIVATE onnxruntime_language_interop onnxruntime_pyop) + endif() - if (onnxruntime_USE_TVM) - if (WIN32) - target_link_options(onnxruntime_perf_test PRIVATE "/STACK:4000000") + if (onnxruntime_USE_TVM) + if (WIN32) + target_link_options(onnxruntime_perf_test PRIVATE "/STACK:4000000") + endif() endif() endif() - # shared lib if (onnxruntime_BUILD_SHARED_LIB) onnxruntime_add_static_library(onnxruntime_mocked_allocator ${TEST_SRC_DIR}/util/test_allocator.cc) @@ -1275,7 +1247,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo) endif() if (onnxruntime_USE_CUDA) - list(APPEND onnxruntime_shared_lib_test_LIBS cudart) + list(APPEND onnxruntime_shared_lib_test_LIBS CUDA::cudart) + endif() + if (onnxruntime_USE_ROCM) + list(APPEND onnxruntime_shared_lib_test_LIBS hip::host) endif() if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) @@ -1294,6 +1269,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) endif() + if (onnxruntime_USE_ROCM) + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include) + target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__) + endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" @@ -1302,7 +1281,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_compile_definitions(onnxruntime_shared_lib_test PRIVATE USE_DUMMY_EXA_DEMANGLE=1) endif() - if (CMAKE_SYSTEM_NAME STREQUAL "iOS") + if (IOS) add_custom_command( TARGET onnxruntime_shared_lib_test POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory @@ -1389,7 +1368,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") endif() - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + if(IOS) set_target_properties(onnxruntime_mlas_test PROPERTIES XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" ) @@ -1590,7 +1569,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") DEPENDS ${all_dependencies} ) - if (CMAKE_SYSTEM_NAME STREQUAL "iOS") + if (IOS) add_custom_command( TARGET onnxruntime_customopregistration_test POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory @@ -1662,6 +1641,38 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUI ${ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG}) endif() +if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_MINIMAL_BUILD_CUSTOM_OPS)) + + file(GLOB_RECURSE custom_op_local_function_test_library_src + "${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.cc" + "${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.h" + "${TEST_SRC_DIR}/testdata/custom_op_local_function/dummy_gemm.cc" + "${TEST_SRC_DIR}/testdata/custom_op_local_function/dummy_gemm.h" + ) + + onnxruntime_add_shared_library_module(custom_op_local_function ${custom_op_local_function_test_library_src}) + + onnxruntime_add_include_to_target(custom_op_local_function onnxruntime_common GTest::gtest GTest::gmock) + target_include_directories(custom_op_local_function PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session + ${REPO_ROOT}/include/onnxruntime/core/common) + + if(UNIX) + if (APPLE) + set(ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG "-Xlinker -dead_strip") + else() + string(CONCAT ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG + "-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.lds " + "-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") + endif() + else() + set(ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG + "-DEF:${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.def") + endif() + + set_property(TARGET custom_op_local_function APPEND_STRING PROPERTY LINK_FLAGS + ${ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG}) +endif() + if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD) set (onnxruntime_logging_apis_test_SRC ${ONNXRUNTIME_LOGGING_APIS_TEST_SRC_DIR}/test_logging_apis.cc) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 9014089cb611..546d50c1ca2d 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -225,6 +225,7 @@ else() "SHELL:-s EXPORT_ALL=0" "SHELL:-s VERBOSE=0" "SHELL:-s FILESYSTEM=0" + "SHELL:-s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm,mainScriptUrlOrBlob]" ${WASM_API_EXCEPTION_CATCHING} --no-entry ) @@ -267,7 +268,10 @@ else() endif() if (onnxruntime_USE_WEBNN) - set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT") + set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT") + if (onnxruntime_DISABLE_RTTI) + set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -fno-rtti -DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0") + endif() endif() # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. @@ -281,6 +285,7 @@ else() target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s EXPORT_NAME=ortWasmThreaded" "SHELL:-s DEFAULT_PTHREAD_STACK_SIZE=131072" + "SHELL:-s PTHREAD_POOL_SIZE=Module[\\\"numThreads\\\"]" ) else() target_link_options(onnxruntime_webassembly PRIVATE diff --git a/cmake/patches/abseil/absl_windows.patch b/cmake/patches/abseil/absl_windows.patch index 66ef0c5125a7..584c49d61229 100644 --- a/cmake/patches/abseil/absl_windows.patch +++ b/cmake/patches/abseil/absl_windows.patch @@ -25,17 +25,91 @@ index a6efc98e..8c4de8e7 100644 "/wd4800", ] diff --git a/absl/copts/copts.py b/absl/copts/copts.py -index 0d6c1ec3..75fd935f 100644 +index e6e11949..0aa7d868 100644 --- a/absl/copts/copts.py +++ b/absl/copts/copts.py -@@ -132,10 +132,6 @@ COPT_VARS = { - "/wd4068", # unknown pragma - # qualifier applied to function type has no meaning; ignored - "/wd4180", -- # conversion from 'type1' to 'type2', possible loss of data -- "/wd4244", -- # conversion from 'size_t' to 'type', possible loss of data -- "/wd4267", - # The decorated name was longer than the compiler limit - "/wd4503", - # forcing value to bool 'true' or 'false' (performance warning) +@@ -115,10 +115,6 @@ MSVC_WARNING_FLAGS = [ + "/wd4068", # unknown pragma + # qualifier applied to function type has no meaning; ignored + "/wd4180", +- # conversion from 'type1' to 'type2', possible loss of data +- "/wd4244", +- # conversion from 'size_t' to 'type', possible loss of data +- "/wd4267", + # The decorated name was longer than the compiler limit + "/wd4503", + # forcing value to bool 'true' or 'false' (performance warning) +diff --git a/absl/debugging/symbolize_win32.inc b/absl/debugging/symbolize_win32.inc +index 53a099a1..34d210d6 100644 +--- a/absl/debugging/symbolize_win32.inc ++++ b/absl/debugging/symbolize_win32.inc +@@ -35,15 +35,15 @@ ABSL_NAMESPACE_BEGIN + + static HANDLE process = NULL; + +-void InitializeSymbolizer(const char*) { +- if (process != nullptr) { +- return; +- } ++namespace { ++void InitializeSymbolizerImpl() { ++ + process = GetCurrentProcess(); + + // Symbols are not loaded until a reference is made requiring the + // symbols be loaded. This is the fastest, most efficient way to use + // the symbol handler. ++ + SymSetOptions(SYMOPT_DEFERRED_LOADS | SYMOPT_UNDNAME); + if (!SymInitialize(process, nullptr, true)) { + // GetLastError() returns a Win32 DWORD, but we assign to +@@ -54,6 +54,36 @@ void InitializeSymbolizer(const char*) { + } + } + ++bool LookupAndInitialize(const void* pc, SYMBOL_INFO* symbol) { ++ auto hProcess = (process != NULL) ? process : GetCurrentProcess(); ++ if (SymFromAddr(hProcess, reinterpret_cast(pc), nullptr, symbol) != TRUE) { ++ if (GetLastError() == ERROR_INVALID_HANDLE && process == NULL) { ++ InitializeSymbolizerImpl(); ++ if (SymFromAddr(process, reinterpret_cast(pc), nullptr, symbol) != TRUE) { ++ return false; ++ } ++ } else { ++ return false; ++ } ++ return false; ++ } ++ return true; ++} ++} ++ ++void InitializeSymbolizer(const char*) { ++ if (process != nullptr) { ++ return; ++ } ++ ++ alignas(SYMBOL_INFO) char buf[sizeof(SYMBOL_INFO) + MAX_SYM_NAME]; ++ SYMBOL_INFO* symbol = reinterpret_cast(buf); ++ symbol->SizeOfStruct = sizeof(SYMBOL_INFO); ++ symbol->MaxNameLen = MAX_SYM_NAME; ++ ++ static_cast(LookupAndInitialize(reinterpret_cast(&InitializeSymbolizer), symbol)); ++} ++ + bool Symbolize(const void* pc, char* out, int out_size) { + if (out_size <= 0) { + return false; +@@ -62,9 +92,11 @@ bool Symbolize(const void* pc, char* out, int out_size) { + SYMBOL_INFO* symbol = reinterpret_cast(buf); + symbol->SizeOfStruct = sizeof(SYMBOL_INFO); + symbol->MaxNameLen = MAX_SYM_NAME; +- if (!SymFromAddr(process, reinterpret_cast(pc), nullptr, symbol)) { ++ ++ if(!LookupAndInitialize(pc, symbol)) { + return false; + } ++ + const size_t out_size_t = static_cast(out_size); + strncpy(out, symbol->Name, out_size_t); + if (out[out_size_t - 1] != '\0') { diff --git a/cmake/patches/coremltools/crossplatformbuild.patch b/cmake/patches/coremltools/crossplatformbuild.patch new file mode 100644 index 000000000000..7f2268f50c82 --- /dev/null +++ b/cmake/patches/coremltools/crossplatformbuild.patch @@ -0,0 +1,155 @@ +diff --git a/mlmodel/src/MILBlob/Blob/FileWriter.cpp b/mlmodel/src/MILBlob/Blob/FileWriter.cpp +index adc7bfcf..7b2bf9cc 100644 +--- a/mlmodel/src/MILBlob/Blob/FileWriter.cpp ++++ b/mlmodel/src/MILBlob/Blob/FileWriter.cpp +@@ -8,8 +8,12 @@ + + #include + #include ++ ++// ORT_EDIT: Exclude mmap on Windows. Not used in this file anyway. ++#if !defined(_WIN32) + #include + #include ++#endif + + using namespace MILBlob; + using namespace MILBlob::Blob; +diff --git a/mlmodel/src/MILBlob/Fp16.cpp b/mlmodel/src/MILBlob/Fp16.cpp +index ae1e71a1..77a7161f 100644 +--- a/mlmodel/src/MILBlob/Fp16.cpp ++++ b/mlmodel/src/MILBlob/Fp16.cpp +@@ -5,6 +5,8 @@ + + #include "MILBlob/Fp16.hpp" + ++// ORT_EDIT: Exclude clang specific pragmas from other builds ++#if defined(__clang__) + // fp16 lib code has some conversion warnings we don't want to globally ignore + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wincompatible-pointer-types" +@@ -12,6 +14,9 @@ + #pragma clang diagnostic ignored "-Wconversion" + #include "fp16/fp16.h" + #pragma clang diagnostic pop ++#else ++#include "fp16/fp16.h" ++#endif + + using namespace MILBlob; + +diff --git a/modelpackage/src/ModelPackage.cpp b/modelpackage/src/ModelPackage.cpp +index 8fee56b9..99e0d8d6 100644 +--- a/modelpackage/src/ModelPackage.cpp ++++ b/modelpackage/src/ModelPackage.cpp +@@ -26,7 +26,14 @@ namespace std { + #else + #error "missing required header " + #endif ++ ++// ORT_EDIT: Use UuidCreate on Windows. ++#if defined(_WIN32) ++#pragma comment(lib, "rpcrt4.lib") // UuidCreate ++#include ++#else + #include ++#endif + #include + + #if defined(__cplusplus) +@@ -187,7 +194,10 @@ public: + ModelPackageItemInfo createFile(const std::string& name, const std::string& author, const std::string& description); + }; + ++// ORT_EDIT: pragma only available on APPLE platforms ++#if defined(__APPLE__) + #pragma mark ModelPackageImpl ++#endif + + ModelPackageImpl::ModelPackageImpl(const std::filesystem::path& path, bool createIfNecessary, bool readOnly) + : m_packagePath(path), +@@ -372,6 +382,20 @@ std::filesystem::path ModelPackageImpl::getItemPath(const std::string& name, con + } + + std::string ModelPackageImpl::generateIdentifier() const { ++// ORT_EDIT: Use built-in UUID generation on Windows ++#if defined(_WIN32) ++ UUID uuid; ++ UuidCreate(&uuid); ++ ++ RPC_CSTR uuidStr; ++ UuidToStringA(&uuid, &uuidStr); ++ ++ std::string uuidStrCpp(reinterpret_cast(uuidStr)); ++ ++ RpcStringFreeA(&uuidStr); ++ ++ return uuidStrCpp; ++#else + uuid_t uuid; + + // uuid_unparse generates a 36-character null-terminated string (37 bytes). +@@ -383,6 +407,7 @@ std::string ModelPackageImpl::generateIdentifier() const { + uuid_unparse(uuid, buf); + + return std::string(buf); ++#endif + } + + ModelPackageItemInfo ModelPackageImpl::createFile(const std::string& name, const std::string& author, const std::string& description) { +@@ -468,7 +493,13 @@ std::shared_ptr ModelPackageImpl::findItem(const std::stri + auto author = itemInfoEntry->getString(kModelPackageItemInfoAuthorKey); + auto description = itemInfoEntry->getString(kModelPackageItemInfoDescriptionKey); + ++// ORT_EDIT: need to use path.string() on Windows ++#if defined(_WIN32) ++ return std::make_shared(std::make_shared(identifier, path.string(), name, author, description)); ++ ++#else + return std::make_shared(std::make_shared(identifier, path, name, author, description)); ++#endif + } + + std::shared_ptr ModelPackageImpl::findItem(const std::string& name, const std::string& author) const +@@ -514,7 +545,9 @@ void ModelPackageImpl::removeItem(const std::string& identifier) + } + + auto path = m_packageDataDirPath / itemInfoEntry->getString(kModelPackageItemInfoPathKey); +- if (0 != std::remove(path.c_str())) { ++ // ORT_EDIT: std::remove doesn't work on Windows. Use std::filesystem::remove instead. ++ // if (0 != std::remove(path.c_str())) { ++ if (!std::filesystem::remove(path)) { + throw std::runtime_error("Failed to remove file at path: " + path.string()); + } + +@@ -525,13 +558,16 @@ bool ModelPackageImpl::isValid(const std::filesystem::path& path) + { + try { + ModelPackageImpl(path, false, true); +- } catch (std::runtime_error& e) { ++ } catch (std::runtime_error& /*e*/) { // ORT_EDIT: comment out unused variable + return false; + } + return true; + } + ++// ORT_EDIT: pragma only available on APPLE platforms ++#if defined(__APPLE__) + #pragma mark ModelPackage ++#endif + + ModelPackage::ModelPackage(const std::string& packagePath, bool createIfNecessary, bool readOnly) + : m_modelPackageImpl(std::make_shared(packagePath, createIfNecessary, readOnly)) +@@ -544,7 +580,12 @@ ModelPackage::~ModelPackage() + + std::string ModelPackage::path() const + { ++// ORT_EDIT: Windows doesn't automatically convert to std::string as the native format could be char or wchar. ++#if defined(_WIN32) ++ return m_modelPackageImpl->path().string(); ++#else + return m_modelPackageImpl->path(); ++#endif + } + + std::string ModelPackage::setRootModel(const std::string& path, const std::string& name, const std::string& author, const std::string& description) diff --git a/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch b/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch new file mode 100644 index 000000000000..afb19a45ce0f --- /dev/null +++ b/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch @@ -0,0 +1,22 @@ +diff --git a/include/cpuinfo.h b/include/cpuinfo.h +index c46b65e..8b83a64 100644 +--- a/include/cpuinfo.h ++++ b/include/cpuinfo.h +@@ -18,7 +18,7 @@ + #define CPUINFO_ARCH_X86 1 + #endif + +-#if defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) || defined(_M_AMD64) ++#if defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)) || (defined(_M_AMD64) && !defined(_M_ARM64EC)) + #define CPUINFO_ARCH_X86_64 1 + #endif + +@@ -26,7 +26,7 @@ + #define CPUINFO_ARCH_ARM 1 + #endif + +-#if defined(__aarch64__) || defined(_M_ARM64) ++#if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC) + #define CPUINFO_ARCH_ARM64 1 + #endif + diff --git a/cmake/patches/flatbuffers/flatbuffers.patch b/cmake/patches/flatbuffers/flatbuffers.patch index fb2678ef1bdc..fbe8db37ecb0 100644 --- a/cmake/patches/flatbuffers/flatbuffers.patch +++ b/cmake/patches/flatbuffers/flatbuffers.patch @@ -2,35 +2,11 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt index 3987eac9..5e5462f1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -223,7 +223,7 @@ elseif(CMAKE_COMPILER_IS_GNUCXX) - "${CMAKE_CXX_FLAGS} -std=c++0x") - endif(CYGWIN) - set(CMAKE_CXX_FLAGS -- "${CMAKE_CXX_FLAGS} -Wall -pedantic -Werror -Wextra -Werror=shadow") -+ "${CMAKE_CXX_FLAGS} -Wall -pedantic -Werror -Wextra -Werror=shadow -Wno-error=stringop-overflow") - set(FLATBUFFERS_PRIVATE_CXX_FLAGS "-Wold-style-cast") - if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.4) - if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) -diff --git a/src/idl_gen_rust.cpp b/src/idl_gen_rust.cpp -index 55b8439b..dc03e8a8 100644 ---- a/src/idl_gen_rust.cpp -+++ b/src/idl_gen_rust.cpp -@@ -406,7 +406,8 @@ class RustGenerator : public BaseGenerator { - // example: f(A, D::E) -> super::D::E - // does not include leaf object (typically a struct type). - -- size_t i = 0; -+ // fix unused but set variable warning -+ //size_t i = 0; - std::stringstream stream; - - auto s = src->components.begin(); -@@ -417,7 +418,7 @@ class RustGenerator : public BaseGenerator { - if (*s != *d) { break; } - ++s; - ++d; -- ++i; -+ //++i; - } - - for (; s != src->components.end(); ++s) { stream << "super::"; } +@@ -279,5 +279,5 @@ + # Append FLATBUFFERS_CXX_FLAGS to CMAKE_CXX_FLAGS. + if(DEFINED FLATBUFFERS_CXX_FLAGS) + message(STATUS "extend CXX_FLAGS with ${FLATBUFFERS_CXX_FLAGS}") +- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLATBUFFERS_CXX_FLAGS}") ++ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLATBUFFERS_CXX_FLAGS} -Wno-error=stringop-overflow") + endif() + message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") diff --git a/cmake/patches/neural_speed/150e7527d5286ddd3a995c228dedf8d76a7a86bc.patch b/cmake/patches/neural_speed/150e7527d5286ddd3a995c228dedf8d76a7a86bc.patch new file mode 100644 index 000000000000..e503a512a74f --- /dev/null +++ b/cmake/patches/neural_speed/150e7527d5286ddd3a995c228dedf8d76a7a86bc.patch @@ -0,0 +1,30 @@ +diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h +index 99f3ccc..a11de9d 100644 +--- a/bestla/bestla/bestla_prologue_b.h ++++ b/bestla/bestla/bestla_prologue_b.h +@@ -456,9 +456,8 @@ class WeightKBlockNInteger { + auto tmpscales = tmp; + auto tmpzeropoints = reinterpret_cast(tmpscales + N * blks); + if (scales) { +- for (size_t i = 0; i < N * blks; i += 2) { ++ for (size_t i = 0; i < N * blks; i ++) { + tmpscales[i] = scales[i] / 16; +- tmpscales[i + 1] = scales[i + 1] / 16; + } + } + if (zero_points) { +diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h +index 6783ee8..59822e5 100644 +--- a/bestla/bestla/kernel_avx512f.h ++++ b/bestla/bestla/kernel_avx512f.h +@@ -673,8 +673,8 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 + zmm1 = _mm512_sllv_epi32(zmm1, zmm_shift); // int3_clip => int8 + zmm2 = _mm512_sllv_epi32(zmm2, zmm_shift); // int3_clip => int8 + +- _mm512_storeu_epi8((__m512i*)dst, zmm1); +- _mm512_storeu_epi8((__m512i*)(dst + 64), zmm2); ++ _mm512_storeu_si512((__m512i*)dst, zmm1); ++ _mm512_storeu_si512((__m512i*)(dst + 64), zmm2); + }; + + assert(head_ignore_num % 8 == 0); diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index a2d7672a3d48..fe8d6622bcc0 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -1,8 +1,8 @@ -diff --git a/CMakeLists.txt b/CMakeLists.txt -index 4dd56b6e..018da488 100644 +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 6d7ca846..69aa622f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -397,6 +397,7 @@ if (MSVC) +@@ -499,6 +499,7 @@ if (MSVC) endif() else() # On non-Windows, hide all symbols we don't need @@ -10,7 +10,7 @@ index 4dd56b6e..018da488 100644 set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)") set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden) set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1) -@@ -548,20 +549,9 @@ endif() +@@ -653,20 +654,9 @@ endif() if(MSVC) target_compile_options(onnx_proto PRIVATE /MP @@ -31,14 +31,72 @@ index 4dd56b6e..018da488 100644 ${EXTRA_FLAGS}) if(ONNX_USE_PROTOBUF_SHARED_LIBS) target_compile_options(onnx_proto +diff --git a/onnx/common/file_utils.h b/onnx/common/file_utils.h +index b847798e..a6c31904 100644 +--- a/onnx/common/file_utils.h ++++ b/onnx/common/file_utils.h +@@ -6,7 +6,6 @@ + + #pragma once + +-#include + #include + #include + +@@ -17,8 +16,7 @@ namespace ONNX_NAMESPACE { + + template + void LoadProtoFromPath(const std::string proto_path, T& proto) { +- std::filesystem::path proto_u8_path = std::filesystem::u8path(proto_path); +- std::fstream proto_stream(proto_u8_path, std::ios::in | std::ios::binary); ++ std::fstream proto_stream(proto_path, std::ios::in | std::ios::binary); + if (!proto_stream.good()) { + fail_check("Unable to open proto file: ", proto_path, ". Please check if it is a valid proto. "); + } +diff --git a/onnx/defs/quantization/defs.cc b/onnx/defs/quantization/defs.cc +index 70b4a4db..98c11545 100644 +--- a/onnx/defs/quantization/defs.cc ++++ b/onnx/defs/quantization/defs.cc +@@ -200,6 +200,9 @@ ONNX_OPERATOR_SET_SCHEMA( + .SetDoc(DequantizeLinear_ver21_doc) + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 1, 0); ++ if (!hasInputShape(ctx, 0)) { ++ return; ++ } + auto& input_shape = getInputShape(ctx, 0); + updateOutputShape(ctx, 0, input_shape); + })); +diff --git a/onnx/defs/quantization/old.cc b/onnx/defs/quantization/old.cc +index 3f2d6384..d2f7cfd8 100644 +--- a/onnx/defs/quantization/old.cc ++++ b/onnx/defs/quantization/old.cc +@@ -130,6 +130,9 @@ ONNX_OPERATOR_SET_SCHEMA( + .SetDoc(DequantizeLinear_ver19_doc) + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 1, 0); ++ if (!hasInputShape(ctx, 0)) { ++ return; ++ } + auto& input_shape = getInputShape(ctx, 0); + updateOutputShape(ctx, 0, input_shape); + })); +@@ -181,7 +184,6 @@ ONNX_OPERATOR_SET_SCHEMA( + if (!hasInputShape(ctx, 0)) { + return; + } +- + auto& input_shape = getInputShape(ctx, 0); + updateOutputShape(ctx, 0, input_shape); + })); diff --git a/onnx/onnx_pb.h b/onnx/onnx_pb.h -index 0aab3e26..0f859267 100644 +index 0aab3e26..398ac2d6 100644 --- a/onnx/onnx_pb.h +++ b/onnx/onnx_pb.h @@ -47,10 +47,28 @@ #define ONNX_API ONNX_IMPORT #endif - + +#if defined(__GNUC__) +#pragma GCC diagnostic push + @@ -58,9 +116,61 @@ index 0aab3e26..0f859267 100644 #else #include "onnx/onnx.pb.h" #endif - + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + #endif // ! ONNX_ONNX_PB_H +diff --git a/onnx/shape_inference/implementation.cc b/onnx/shape_inference/implementation.cc +index fab1faf2..8723dcd4 100644 +--- a/onnx/shape_inference/implementation.cc ++++ b/onnx/shape_inference/implementation.cc +@@ -488,29 +488,29 @@ class ShapeInferenceImplBase { + ProcessCall(n, *(iter->second), ctx); + } else { + has_unsupported_op = true; ++ return; + } + } else { + has_unsupported_op = true; ++ return; + } +- if (!has_unsupported_op) { +- for (int i = 0; i < n.output_size(); ++i) { +- // skip type and shape propagation for missing optional outputs. +- if (!n.output(i).empty()) +- UpdateType(n.output(i), ctx.getOutputType(i)); +- } +- // Constant values are tracked to improve inference/checking for subsequent nodes. +- ProcessConstant(n); +- // If data-propagation is enabled, partial-evaluation (aka data-propagation) is performed +- // to improve inference/checking for subsequent nodes. +- if (options.enable_data_propagation && schema && schema->has_data_propagation_function()) { +- if (generated_shape_data_by_name == nullptr) { +- fail_shape_inference( +- "Container for generated shape data cannot be nullptr when enable_data_propagation option is set."); +- } +- DataPropagationContextImpl data_propagation_ctx( +- n, value_types_by_name, input_data_by_name, *generated_shape_data_by_name); +- schema->GetDataPropagationFunction()(data_propagation_ctx); ++ for (int i = 0; i < n.output_size(); ++i) { ++ // skip type and shape propagation for missing optional outputs. ++ if (!n.output(i).empty()) ++ UpdateType(n.output(i), ctx.getOutputType(i)); ++ } ++ // Constant values are tracked to improve inference/checking for subsequent nodes. ++ ProcessConstant(n); ++ // If data-propagation is enabled, partial-evaluation (aka data-propagation) is performed ++ // to improve inference/checking for subsequent nodes. ++ if (options.enable_data_propagation && schema && schema->has_data_propagation_function()) { ++ if (generated_shape_data_by_name == nullptr) { ++ fail_shape_inference( ++ "Container for generated shape data cannot be nullptr when enable_data_propagation option is set."); + } ++ DataPropagationContextImpl data_propagation_ctx( ++ n, value_types_by_name, input_data_by_name, *generated_shape_data_by_name); ++ schema->GetDataPropagationFunction()(data_propagation_ctx); + } + } + ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) { diff --git a/cmake/riscv64.toolchain.cmake b/cmake/riscv64.toolchain.cmake new file mode 100644 index 000000000000..0fda239f9a62 --- /dev/null +++ b/cmake/riscv64.toolchain.cmake @@ -0,0 +1,35 @@ +# Copyright (c) 2024 SiFive, Inc. All rights reserved. +# Copyright (c) 2024, Phoebe Chen +# Licensed under the MIT License. + +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR riscv64) + +list(APPEND CMAKE_TRY_COMPILE_PLATFORM_VARIABLES RISCV_TOOLCHAIN_ROOT) + +if(NOT RISCV_TOOLCHAIN_ROOT) + message(FATAL_ERROR "RISCV_TOOLCHAIN_ROOT is not defined. Please set the RISCV_TOOLCHAIN_ROOT variable.") +endif() + +set(CMAKE_C_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc") +set(CMAKE_ASM_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc") +set(CMAKE_CXX_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-g++") + +set(CMAKE_FIND_ROOT_PATH ${RISCV_TOOLCHAIN_ROOT}) +set(CMAKE_SYSROOT "${RISCV_TOOLCHAIN_ROOT}/sysroot") +set(CMAKE_INCLUDE_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/include/") +set(CMAKE_LIBRARY_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/lib/") +set(CMAKE_PROGRAM_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/bin/") + +if(RISCV_QEMU_PATH) + message(STATUS "RISCV_QEMU_PATH=${RISCV_QEMU_PATH} is defined during compilation.") + set(CMAKE_CROSSCOMPILING_EMULATOR "${RISCV_QEMU_PATH};-L;${CMAKE_SYSROOT}") +endif() + +set(CMAKE_CROSSCOMPILING TRUE) + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) + diff --git a/cmake/wcos_rules_override.cmake b/cmake/wcos_rules_override.cmake index f3d8093629a4..ec2303b073d5 100644 --- a/cmake/wcos_rules_override.cmake +++ b/cmake/wcos_rules_override.cmake @@ -1,2 +1,2 @@ -set(CMAKE_C_STANDARD_LIBRARIES_INIT onecoreuap_apiset.lib) -set(CMAKE_CXX_STANDARD_LIBRARIES_INIT onecoreuap_apiset.lib) +set(CMAKE_C_STANDARD_LIBRARIES_INIT onecoreuap.lib) +set(CMAKE_CXX_STANDARD_LIBRARIES_INIT onecoreuap.lib) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 268ee3960e75..d74250b96262 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -836,6 +836,13 @@ if (winml_is_inbox) target_include_directories(${new_target} PRIVATE ${include_directories}) target_link_libraries(${new_target} PRIVATE ${link_libraries}) target_link_options(${new_target} PRIVATE ${link_options}) + + # Attempt to copy linker flags + get_target_property(link_flags ${target} LINK_FLAGS) + + if (NOT link_flags MATCHES ".*NOTFOUND") + set_property(TARGET ${new_target} PROPERTY LINK_FLAGS "${link_flags}") + endif() endfunction() if (WAI_ARCH STREQUAL x64 OR WAI_ARCH STREQUAL arm64) diff --git a/csharp/ApiDocs/ApiDocs.csproj b/csharp/ApiDocs/ApiDocs.csproj index 994e57913cf4..6081c444ba1a 100644 --- a/csharp/ApiDocs/ApiDocs.csproj +++ b/csharp/ApiDocs/ApiDocs.csproj @@ -7,7 +7,7 @@ - + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj b/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj index 3d35de1dfc6a..b268079e2cca 100644 --- a/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj +++ b/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj @@ -7,8 +7,8 @@ - - + + diff --git a/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj b/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj index af8fa611a501..647c0bbe6a24 100644 --- a/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj +++ b/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj @@ -7,8 +7,8 @@ - - + + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 4128524b3048..8a8426a0b305 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -362,6 +362,7 @@ static NativeMethods() OrtDisableMemPattern = (DOrtDisableMemPattern)Marshal.GetDelegateForFunctionPointer(api_.DisableMemPattern, typeof(DOrtDisableMemPattern)); OrtEnableCpuMemArena = (DOrtEnableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.EnableCpuMemArena, typeof(DOrtEnableCpuMemArena)); OrtDisableCpuMemArena = (DOrtDisableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.DisableCpuMemArena, typeof(DOrtDisableCpuMemArena)); + OrtDisablePerSessionThreads = (DOrtDisablePerSessionThreads)Marshal.GetDelegateForFunctionPointer(api_.DisablePerSessionThreads, typeof(DOrtDisablePerSessionThreads)); OrtSetSessionLogId = (DOrtSetSessionLogId)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogId, typeof(DOrtSetSessionLogId)); OrtSetSessionLogVerbosityLevel = (DOrtSetSessionLogVerbosityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogVerbosityLevel, typeof(DOrtSetSessionLogVerbosityLevel)); OrtSetSessionLogSeverityLevel = (DOrtSetSessionLogSeverityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogSeverityLevel, typeof(DOrtSetSessionLogSeverityLevel)); @@ -992,6 +993,10 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public delegate IntPtr /*(OrtStatus*)*/ DOrtDisableCpuMemArena(IntPtr /* OrtSessionOptions* */ options); public static DOrtDisableCpuMemArena OrtDisableCpuMemArena; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtDisablePerSessionThreads(IntPtr /* OrtSessionOptions* */ options); + public static DOrtDisablePerSessionThreads OrtDisablePerSessionThreads; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */ logId); public static DOrtSetSessionLogId OrtSetSessionLogId; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 7a68246c9b67..30d005b3c423 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -696,6 +696,15 @@ public bool EnableCpuMemArena } private bool _enableCpuMemArena = true; + /// + /// Disables the per session threads. Default is true. + /// This makes all sessions in the process use a global TP. + /// + public void DisablePerSessionThreads() + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtDisablePerSessionThreads(handle)); + } + /// /// Log Id to be used for the session. Default is empty string. /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 68a399f8b967..7fe16f4156ef 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -65,10 +65,10 @@ static NativeTrainingMethods() DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi)); // TODO: Make this save the pointer, and not copy the whole structure across - api_ = (OrtApi)OrtGetApi(17 /*ORT_API_VERSION*/); + api_ = (OrtApi)OrtGetApi(18 /*ORT_API_VERSION*/); OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi)); - trainingApiPtr = OrtGetTrainingApi(17 /*ORT_API_VERSION*/); + trainingApiPtr = OrtGetTrainingApi(18 /*ORT_API_VERSION*/); if (trainingApiPtr != IntPtr.Zero) { trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi)); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 877677dcad57..fec0d46e96df 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -282,6 +282,48 @@ public IDisposableReadOnlyCollection TrainStep( } } + /// + /// This function performs a training step that computes the outputs of the training model and the gradients + /// of the trainable parameters for the given OrtValue inputs. The train step is performed based on the training model + /// that was provided to the training session. + /// The TrainStep method is equivalent of running forward propagation and backward propagation in a single + /// step. + /// The gradients computed are stored inside the training session state so they can be later consumed + /// by the OptimizerStep function. + /// The gradients can be lazily reset by invoking the LazyResetGrad function. + /// Example usage: + /// + /// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...); + /// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...); + /// List inputValues = new List { x, label }; + /// using (var loss = trainingSession.TrainStep(inputValues)) + /// { + /// // process output values + /// } + /// + /// + /// Specify a collection of that indicates the input values to the training model. + /// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output. + public IDisposableReadOnlyCollection TrainStep(IReadOnlyCollection inputValues) + { + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues); + IntPtr[] outputValuesArray = new IntPtr[(int)_trainOutputCount]; + + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count, + inputValuesArray, (UIntPtr)_trainOutputCount, outputValuesArray)); + + + var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray); + try + { + return CreateDisposableResult(disposableHandles); + } + finally + { + disposableHandles.Dispose(); + } + } + /// /// Convert native OrtValue handles to OrtValue instances /// in an exceptions safe manner. @@ -370,6 +412,42 @@ public void EvalStep( inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); } + /// + /// This function performs an eval step that computes the outputs of the eval model for the given inputs. + /// Inputs are expected to be of type OrtValue. The eval step is performed based on the eval model that was + /// provided to the training session. + /// Example usage: + /// + /// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...); + /// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...); + /// List inputValues = new List { x, label }; + /// using (var loss = trainingSession.EvalSteps(inputValues)) + /// { + /// // process output values + /// } + /// + /// + /// Specify a collection of that indicates the input values to the eval model. + public IDisposableReadOnlyCollection EvalStep(IReadOnlyCollection inputValues) + { + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues); + IntPtr[] outputValuesArray = new IntPtr[(int)_evalOutputCount]; + + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count, + inputValuesArray, (UIntPtr)_evalOutputCount, outputValuesArray)); + + + var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray); + try + { + return CreateDisposableResult(disposableHandles); + } + finally + { + disposableHandles.Dispose(); + } + } + /// /// Sets the learning rate for this training session. @@ -702,6 +780,35 @@ private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection v return valuesArray; } + private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection inputValues) + { + var valuesArray = new IntPtr[inputValues.Count]; + for (int index = 0; index < inputValues.Count; ++index) + { + valuesArray[index] = inputValues.ElementAt(index).Handle; + } + return valuesArray; + } + + private static IDisposableReadOnlyCollection CreateDisposableResult(DisposableOrtValueHandleArray disposableHandles) + { + var outputValues = new DisposableList(disposableHandles.Span.Length); + try + { + for (int i = 0; i < disposableHandles.Span.Length; i++) + { + outputValues.Add(new OrtValue(disposableHandles.Span[i])); + disposableHandles.Span[i] = IntPtr.Zero; + } + return outputValues; + } + catch (Exception) + { + outputValues.Dispose(); + throw; + } + } + private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection names, DisposableList cleanupList) { cleanupList.Capacity += names.Count; diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/EndToEndTests.Mobile.Automation.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/EndToEndTests.Mobile.Automation.csproj index b90929ad6d1c..7bda34d26629 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/EndToEndTests.Mobile.Automation.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/EndToEndTests.Mobile.Automation.csproj @@ -6,7 +6,7 @@ - + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj index 1c9827c5bac6..5ff924bcf82f 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj @@ -37,10 +37,10 @@ - + - + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index fd8feda359f9..d6a6b9627f41 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -55,6 +55,9 @@ public void TestSessionOptions() Assert.Equal(0, opt.InterOpNumThreads); Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_ALL, opt.GraphOptimizationLevel); + // No get, so no verify + opt.DisablePerSessionThreads(); + // try setting options opt.ExecutionMode = ExecutionMode.ORT_PARALLEL; Assert.Equal(ExecutionMode.ORT_PARALLEL, opt.ExecutionMode); @@ -98,7 +101,7 @@ public void TestSessionOptions() Assert.Contains("[ErrorCode:InvalidArgument] Config key is empty", ex.Message); // SessionOptions.RegisterOrtExtensions can be manually tested by referencing the - // Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw. + // Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw. ex = Assert.Throws(() => { opt.RegisterOrtExtensions(); }); Assert.Contains("Microsoft.ML.OnnxRuntime.Extensions NuGet package must be referenced", ex.Message); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj index ee81ab77432d..ab27d62c3bf3 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj @@ -119,8 +119,8 @@ - - + + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index 68b1d5bcc614..9b7232620132 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -612,6 +612,81 @@ public void TestUpdateParameter() } } + [Fact(DisplayName = "TestTrainingSessionTrainStepWithOrtValues")] + public void TestTrainingSessionTrainStepWithOrtValues() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, optimizerPath); + cleanUp.Add(trainingSession); + + float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out"); + var expectedOutputDimensions = new int[] { 1 }; + float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in"); + long[] inputShape = { 2, 784 }; + Int32[] labelsData = { 1, 1 }; + long[] labelsShape = { 2 }; + + using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape); + using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory(labelsData, labelsShape); + var inputValues = new List { inputOrtValue, labelsOrtValue }; + + using (var results = trainingSession.TrainStep(inputValues)) + { + Assert.Single(results); + var outputOrtValue = results[0]; + Assert.True(outputOrtValue.IsTensor); + var resultSpan = outputOrtValue.GetTensorDataAsSpan().ToArray(); + Assert.Equal(expectedOutput, resultSpan, new FloatComparer()); + } + } + } + + [Fact(DisplayName = "TestTrainingSessionEvalStepWithOrtValues")] + public void TestTrainingSessionEvalStepWithOrtValues() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); + cleanUp.Add(trainingSession); + + float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out"); + var expectedOutputDimensions = new int[] { 1 }; + float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in"); + long[] inputShape = { 2, 784 }; + Int32[] labelsData = { 1, 1 }; + long[] labelsShape = { 2 }; + + using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape); + using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory(labelsData, labelsShape); + var inputValues = new List { inputOrtValue, labelsOrtValue }; + + using (var results = trainingSession.EvalStep(inputValues)) + { + Assert.Single(results); + var outputOrtValue = results[0]; + Assert.True(outputOrtValue.IsTensor); + var resultSpan = outputOrtValue.GetTensorDataAsSpan().ToArray(); + Assert.Equal(expectedOutput, resultSpan, new FloatComparer()); + } + } + } + internal class FloatComparer : IEqualityComparer { private float atol = 1e-3f; diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/Microsoft.ML.OnnxRuntime.Tests.Devices.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/Microsoft.ML.OnnxRuntime.Tests.Devices.csproj index 37e83be5e33a..40f6d453c6a9 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/Microsoft.ML.OnnxRuntime.Tests.Devices.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/Microsoft.ML.OnnxRuntime.Tests.Devices.csproj @@ -11,6 +11,6 @@ - + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Droid/Microsoft.ML.OnnxRuntime.Tests.Droid.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Droid/Microsoft.ML.OnnxRuntime.Tests.Droid.csproj index 11855032584a..ef7e0825e919 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Droid/Microsoft.ML.OnnxRuntime.Tests.Droid.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Droid/Microsoft.ML.OnnxRuntime.Tests.Droid.csproj @@ -134,7 +134,7 @@ 5.0.0.2083 - + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index 715aed7e1d64..7f3d5d6624b0 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -145,7 +145,7 @@ private void TestCUDAProviderOptions() private void CanRunInferenceOnAModelWithTensorRT() { string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx"); - + int deviceId = 0; string deviceIdStr = System.Environment.GetEnvironmentVariable("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); if (!string.IsNullOrEmpty(deviceIdStr) && int.TryParse(deviceIdStr, out int parsedValue) && parsedValue >= 0) diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.iOS/Microsoft.ML.OnnxRuntime.Tests.iOS.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.iOS/Microsoft.ML.OnnxRuntime.Tests.iOS.csproj index 352de5db0092..56e65833724f 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.iOS/Microsoft.ML.OnnxRuntime.Tests.iOS.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.iOS/Microsoft.ML.OnnxRuntime.Tests.iOS.csproj @@ -99,7 +99,7 @@ 2.4.1 - + 5.0.0.2083 diff --git a/csharp/tools/MauiModelTester/MauiModelTester.csproj b/csharp/tools/MauiModelTester/MauiModelTester.csproj index b0a17978328c..39e688ce6c1b 100644 --- a/csharp/tools/MauiModelTester/MauiModelTester.csproj +++ b/csharp/tools/MauiModelTester/MauiModelTester.csproj @@ -1,8 +1,8 @@  - net6.0-android;net6.0-ios - $(TargetFrameworks);net6.0-windows10.0.19041.0 + net8.0-ios;net8.0-android34.0 + $(TargetFrameworks);net8.0-windows10.0.19041.0 Exe MauiModelTester true @@ -21,7 +21,7 @@ 1 12.0 - 21.0 + 29.0 10.0.17763.0 10.0.17763.0 true @@ -51,7 +51,7 @@ - + diff --git a/csharp/tools/MauiModelTester/Platforms/Android/AndroidManifest.xml b/csharp/tools/MauiModelTester/Platforms/Android/AndroidManifest.xml index cc320dab474a..2ef2296d7441 100644 --- a/csharp/tools/MauiModelTester/Platforms/Android/AndroidManifest.xml +++ b/csharp/tools/MauiModelTester/Platforms/Android/AndroidManifest.xml @@ -4,5 +4,5 @@ - + \ No newline at end of file diff --git a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Microsoft.ML.OnnxRuntime.PerfTool.csproj b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Microsoft.ML.OnnxRuntime.PerfTool.csproj index 24f0d14ad990..e0420a6ed045 100644 --- a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Microsoft.ML.OnnxRuntime.PerfTool.csproj +++ b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Microsoft.ML.OnnxRuntime.PerfTool.csproj @@ -80,7 +80,7 @@ - + diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx index bc513a8e8ba6..c3541a8bd342 100644 --- a/dockerfiles/Dockerfile.migraphx +++ b/dockerfiles/Dockerfile.migraphx @@ -5,57 +5,22 @@ # Dockerfile to run ONNXRuntime with MIGraphX integration #-------------------------------------------------------------------------- -FROM ubuntu:20.04 +FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main -ARG ROCM_VERSION=5.4 -# MIGraphX version should be the same as ROCm version -ARG MIGRAPHX_VERSION=rocm-5.4.0 -ENV DEBIAN_FRONTEND noninteractive -ENV MIGRAPHX_DISABLE_FAST_GELU=1 -RUN apt-get clean && apt-get update && apt-get install -y locales -RUN locale-gen en_US.UTF-8 -RUN update-locale LANG=en_US.UTF-8 -ENV LC_ALL C.UTF-8 -ENV LANG C.UTF-8 +ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} -# Install rocm -RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \ - curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/${ROCM_VERSION}/ ubuntu main > /etc/apt/sources.list.d/rocm.list' - -RUN apt-get update &&\ - apt-get install -y sudo git bash build-essential rocm-dev python3-dev python3-pip miopen-hip \ - rocblas half aria2 libnuma-dev pkg-config - -RUN aria2c -q -d /tmp -o cmake-3.27.3-linux-x86_64.tar.gz \ -https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-x86_64.tar.gz &&\ -tar -zxf /tmp/cmake-3.27.3-linux-x86_64.tar.gz --strip=1 -C /usr - -# Install rbuild -RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz numpy yapf==0.28.0 - -ENV PATH /opt/miniconda/bin:/code/cmake-3.27.3-linux-x86_64/bin:${PATH} - -# Install MIGraphX from source -RUN mkdir -p /migraphx -RUN cd /migraphx && git clone --depth=1 --branch ${MIGRAPHX_VERSION} https://github.com/ROCmSoftwarePlatform/AMDMIGraphX src -RUN cd /migraphx && rbuild package --cxx /opt/rocm/llvm/bin/clang++ -d /migraphx/deps -B /migraphx/build -S /migraphx/src/ -DPYTHON_EXECUTABLE=/usr/bin/python3 -RUN dpkg -i /migraphx/build/*.deb -RUN rm -rf /migraphx - -# Install rocm ep dependencies RUN apt-get update &&\ - apt-get install -y rocrand rccl hipsparse hipfft hipcub hipblas rocthrust + apt-get install -y migraphx WORKDIR /code # Prepare onnxruntime repository & build onnxruntime RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ - cd onnxruntime &&\ + cd onnxruntime && pip install --upgrade pip &&\ /bin/sh ./build.sh --allow_running_as_root --cmake_extra_defines ONNXRUNTIME_VERSION=`cat ./VERSION_NUMBER` --config Release --parallel \ --skip_tests --build_wheel --use_rocm --rocm_version=${ROCM_VERSION} --rocm_home /opt/rocm --use_migraphx &&\ pip install /code/onnxruntime/build/Linux/Release/dist/*.whl diff --git a/dockerfiles/Dockerfile.openvino b/dockerfiles/Dockerfile.openvino index 78d04a51ba16..049916fac92f 100644 --- a/dockerfiles/Dockerfile.openvino +++ b/dockerfiles/Dockerfile.openvino @@ -1,9 +1,9 @@ #------------------------------------------------------------------------- -# Copyright(C) 2021-2023 Intel Corporation. +# Copyright(C) 2021-2024 Intel Corporation. # SPDX-License-Identifier: MIT #-------------------------------------------------------------------------- -ARG OPENVINO_VERSION=2023.0.0 +ARG OPENVINO_VERSION=2024.0.0 # Build stage @@ -17,7 +17,7 @@ ARG DEVICE=CPU_FP32 ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime.git ARG ONNXRUNTIME_BRANCH=main -ENV InferenceEngine_DIR=${INTEL_OPENVINO_DIR}/runtime/cmake +ENV OpenVINO_DIR=${INTEL_OPENVINO_DIR}/runtime/cmake USER root RUN apt update; apt install -y git protobuf-compiler libprotobuf-dev diff --git a/dockerfiles/Dockerfile.openvino-centos7 b/dockerfiles/Dockerfile.openvino-centos7 deleted file mode 100755 index 697db44801e3..000000000000 --- a/dockerfiles/Dockerfile.openvino-centos7 +++ /dev/null @@ -1,105 +0,0 @@ -#------------------------------------------------------------------------- -# Copyright(C) 2021 Intel Corporation. -# SPDX-License-Identifier: MIT -#-------------------------------------------------------------------------- - -FROM centos:7.8.2003 - -WORKDIR /code - -ARG MY_ROOT=/code -ARG YUM_OV_PACKAGE=intel-openvino-runtime-centos7-2021.4.752.x86_64 -ARG DEVICE=CPU_FP32 -ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime -ARG ONNXRUNTIME_BRANCH=main - -ENV INTEL_OPENVINO_DIR=/opt/intel/openvino_2021.4.752 -ENV InferenceEngine_DIR=${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/share -ENV IE_PLUGINS_PATH=${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/lib/intel64 -ENV ngraph_DIR=${INTEL_OPENVINO_DIR}/deployment_tools/ngraph/cmake -ENV LD_LIBRARY_PATH=/opt/intel/opencl:${INTEL_OPENVINO_DIR}/inference_engine/external/gna/lib:${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/mkltiny_lnx/lib:$INTEL_OPENVINO_DIR/deployment_tools/ngraph/lib:${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/omp/lib:${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/tbb/lib:${IE_PLUGINS_PATH}:${LD_LIBRARY_PATH} -ENV OpenCV_DIR=${INTEL_OPENVINO_DIR}/opencv/share/OpenCV -ENV LD_LIBRARY_PATH=${INTEL_OPENVINO_DIR}/opencv/lib:${INTEL_OPENVINO_DIR}/opencv/share/OpenCV/3rdparty/lib:${LD_LIBRARY_PATH} -ENV HDDL_INSTALL_DIR=${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/hddl -ENV LD_LIBRARY_PATH=${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/hddl/lib:$LD_LIBRARY_PATH -ENV LD_LIBRARY_PATH=/usr/local/lib:/usr/lib:/usr/local/lib64:/usr/lib64:/lib64:$LD_LIBRARY_PATH - -# Install packages -RUN yum update -y && \ - yum groupinstall "Development Tools" -y && \ - yum install -y yum-utils autoconf automake libtool unzip udev wget zlib-devel libffi-devel openssl-devel boost-devel-1.53.0 && \ - yum clean packages && yum clean all && rm -rf /var/cache/yum && \ -# Install cmake - cd $MY_ROOT && \ - wget https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3.tar.gz && \ - tar -zxvf cmake-3.27.3.tar.gz && rm -rf cmake-3.27.3.tar.gz && \ - cd cmake-3.27.3 && \ - ./bootstrap && \ - make && \ - make install && \ - cd $MY_ROOT && \ -# libusb1.0.22 - cd /opt/ && wget https://github.com/libusb/libusb/archive/v1.0.22.zip && \ - unzip v1.0.22.zip && rm -rf v1.0.22.zip && cd /opt/libusb-1.0.22 && \ -# bootstrap steps - ./bootstrap.sh && \ - ./configure --disable-udev --enable-shared && \ - make -j4 && \ -# configure libusb1.0.22 - cd /opt/libusb-1.0.22/libusb && \ - /bin/mkdir -p '/usr/local/lib' && \ - /bin/bash ../libtool --mode=install /usr/bin/install -c libusb-1.0.la '/usr/local/lib' && \ - /bin/mkdir -p '/usr/local/include/libusb-1.0' && \ - /usr/bin/install -c -m 644 libusb.h '/usr/local/include/libusb-1.0' && \ - /bin/mkdir -p '/usr/local/lib/pkgconfig' && \ -# Install openvino - yum-config-manager --add-repo https://yum.repos.intel.com/openvino/2021/setup/intel-openvino-2021.repo && \ - rpm --import https://yum.repos.intel.com/openvino/2021/setup/RPM-GPG-KEY-INTEL-OPENVINO-2021 && \ - yum update -y && yum list intel-openvino* && \ - yum install -y $YUM_OV_PACKAGE && \ - cd ${INTEL_OPENVINO_DIR}/install_dependencies/ && ./install_openvino_dependencies.sh -y && \ - printf "\nexport LD_LIBRARY_PATH=\${LD_LIBRARY_PATH}:/usr/local/lib\n" >> /opt/intel/openvino_2021.4.752/bin/setupvars.sh && \ - cd /opt/libusb-1.0.22 && \ - /usr/bin/install -c -m 644 libusb-1.0.pc '/usr/local/lib/pkgconfig' && \ - cp /opt/intel/openvino_2021/deployment_tools/inference_engine/external/97-myriad-usbboot.rules /etc/udev/rules.d/ && \ - ldconfig && \ -# Install GPU runtime and drivers - cd ${MY_ROOT} && \ - mkdir /tmp/opencl && \ - cd /tmp/opencl && \ - yum install -y epel-release && \ - yum install -y ocl-icd ocl-icd-devel && \ - wget -O intel-igc-core-1.0.2597-1.el7.x86_64.rpm https://sourceforge.net/projects/intel-compute-runtime/files/19.41.14441/centos-7/intel-igc-core-1.0.2597-1.el7.x86_64.rpm/download && \ - wget -O intel-opencl-19.41.14441-1.el7.x86_64.rpm https://sourceforge.net/projects/intel-compute-runtime/files/19.41.14441/centos-7/intel-opencl-19.41.14441-1.el7.x86_64.rpm/download && \ - wget -O intel-igc-opencl-devel-1.0.2597-1.el7.x86_64.rpm https://sourceforge.net/projects/intel-compute-runtime/files/19.41.14441/centos-7/intel-igc-opencl-devel-1.0.2597-1.el7.x86_64.rpm/download && \ - wget -O intel-igc-opencl-1.0.2597-1.el7.x86_64.rpm https://sourceforge.net/projects/intel-compute-runtime/files/19.41.14441/centos-7/intel-igc-opencl-1.0.2597-1.el7.x86_64.rpm/download && \ - wget -O intel-gmmlib-19.3.2-1.el7.x86_64.rpm https://sourceforge.net/projects/intel-compute-runtime/files/19.41.14441/centos-7/intel-gmmlib-19.3.2-1.el7.x86_64.rpm/download && \ - wget -O intel-gmmlib-devel-19.3.2-1.el7.x86_64.rpm https://sourceforge.net/projects/intel-compute-runtime/files/19.41.14441/centos-7/intel-gmmlib-devel-19.3.2-1.el7.x86_64.rpm/download && \ - rpm -i /tmp/opencl/*.rpm && \ - ldconfig && \ - rm -rf /tmp/opencl && \ -# Installing gcc-10 - yum install -y centos-release-scl && \ - yum install -y devtoolset-10-gcc* && \ - echo 'source scl_source enable devtoolset-10' >> ~/.bashrc && \ -# python installation - source scl_source enable devtoolset-10 && \ - cd /code/ && \ - wget https://www.python.org/ftp/python/3.8.3/Python-3.8.3.tgz && tar xvf Python-3.8.3.tgz && \ - cd Python-3.8*/ && ./configure && make && make install && \ - cd ../ && mkdir -p /usr/bin/Python38 && ln -s Python-3.8.3/ /usr/bin/Python38 && \ -# installing dependancies - yum install -y python3-lxml python3-six libusb.x86_64 && \ - yum clean packages && yum clean all && rm -rf /var/cache/yum && \ -# Build onnxruntime - cd $MY_ROOT && \ - pip3 install numpy wheel setuptools cython && \ - git clone --recursive -b ${ONNXRUNTIME_BRANCH} ${ONNXRUNTIME_REPO} && \ - pip3 install onnx && \ - cd /code/onnxruntime && ./build.sh --allow_running_as_root --config Release --update --build --parallel --use_openvino ${DEVICE} --build_shared_lib --build_wheel && \ - pip3 install /code/onnxruntime/build/Linux/Release/dist/*-linux_x86_64.whl && \ -# Clean up - cd $MY_ROOT && rm -rf onnxruntime Python-3* && \ - cd ${MY_ROOT}/ && rm -rf cmake* && \ - cd /usr/share/ && rm -rf gcc* && cd /usr/lib/ && rm -rf gcc cd && rm -rf .cache && \ - cd ${INTEL_OPENVINO_DIR}/ && rm -rf documentation data_processing && cd deployment_tools/ && rm -rf tools diff --git a/dockerfiles/Dockerfile.openvino-csharp b/dockerfiles/Dockerfile.openvino-csharp deleted file mode 100644 index 2529ef4b7320..000000000000 --- a/dockerfiles/Dockerfile.openvino-csharp +++ /dev/null @@ -1,90 +0,0 @@ -#------------------------------------------------------------------------- -# Copyright(C) 2021-2023 Intel Corporation. -# SPDX-License-Identifier: MIT -#-------------------------------------------------------------------------- - -ARG OPENVINO_VERSION=2023.0.0 - -# Build stage -FROM openvino/ubuntu20_runtime:${OPENVINO_VERSION} AS base - -ENV WORKDIR_PATH=/home/openvino -WORKDIR $WORKDIR_PATH -ENV DEBIAN_FRONTEND noninteractive - -USER root -RUN apt update; apt install -y --no-install-recommends wget gnupg && \ - rm -rf /var/lib/apt/lists/* - -# Install Mono -RUN wget http://download.mono-project.com/repo/xamarin.gpg && apt-key add xamarin.gpg && rm xamarin.gpg && \ - echo "deb https://download.mono-project.com/repo/ubuntu stable-bionic main" | tee /etc/apt/sources.list.d/mono-official-stable.list && \ - apt update -y && \ - apt install -y mono-devel - -# Install nuget.exe -RUN wget https://dist.nuget.org/win-x86-commandline/latest/nuget.exe && \ - mv nuget.exe /usr/local/bin/nuget.exe && \ - echo 'mono /usr/local/bin/nuget.exe $@' > /usr/local/bin/nuget && \ - chmod a+x /usr/local/bin/nuget - -# Install .NET core -RUN wget https://packages.microsoft.com/config/ubuntu/20.04/packages-microsoft-prod.deb -O packages-microsoft-prod.deb && \ - dpkg -i packages-microsoft-prod.deb && \ - apt-get update -y &&\ - apt-get install -y apt-transport-https && \ - apt-get update -y && \ - apt-get install -y dotnet-sdk-5.0 - -# Build stage -FROM base AS builder - -ENV WORKDIR_PATH=/home/openvino -WORKDIR $WORKDIR_PATH -ENV DEBIAN_FRONTEND noninteractive - -ARG DEVICE=CPU_FP32 -ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime.git -ARG ONNXRUNTIME_BRANCH=main - -ENV InferenceEngine_DIR=${INTEL_OPENVINO_DIR}/runtime/cmake -ENV LANG en_US.UTF-8 - -USER root -RUN apt update; apt install -y --no-install-recommends git protobuf-compiler libprotobuf-dev ca-certificates unattended-upgrades && \ - unattended-upgrade && \ - rm -rf /var/lib/apt/lists/* - -RUN git clone --recursive -b ${ONNXRUNTIME_BRANCH} ${ONNXRUNTIME_REPO} -RUN /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh -RUN ln -s cmake-* cmake-dir -RUN python3 -m pip install wheel -ENV PATH=${WORKDIR_PATH}/cmake-dir/bin:$PATH -RUN pip3 install onnx -RUN ln -s /usr/bin/python3 /usr/bin/python -RUN apt install locales && \ - locale-gen en_US en_US.UTF-8 && \ - dpkg-reconfigure locales -RUN cd onnxruntime && ./build.sh --allow_running_as_root --config Release --update --build --parallel --use_openvino ${DEVICE} --build_nuget --build_shared_lib -RUN cp /home/openvino/onnxruntime/build/Linux/Release/Microsoft.ML.OnnxRuntime.Managed* /home/openvino/onnxruntime/build/Linux/Release/nuget-artifacts - -# Deploy stage -FROM base - -ENV DEBIAN_FRONTEND noninteractive -USER root - -RUN apt update; apt install -y unattended-upgrades fonts-freefont-ttf && \ - unattended-upgrade -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -RUN usermod -a -G video,users ${BUILD_USER} -ENV WORKDIR_PATH /home/${BUILD_USER} -WORKDIR ${WORKDIR_PATH} -COPY --from=builder /home/openvino/onnxruntime/build/Linux/Release/nuget-artifacts ${WORKDIR_PATH}/nuget-artifacts - -USER ${BUILD_USER} -ENV PATH=${WORKDIR_PATH}/miniconda/bin:${WORKDIR_PATH}/cmake-dir/bin:$PATH -ENV IE_PLUGINS_PATH=${INTEL_OPENVINO_DIR}/runtime/lib/intel64 -ENV LD_LIBRARY_PATH=/opt/intel/opencl:${INTEL_OPENVINO_DIR}/runtime/3rdparty/tbb/lib:${IE_PLUGINS_PATH}:${LD_LIBRARY_PATH} diff --git a/dockerfiles/Dockerfile.openvino-rhel8 b/dockerfiles/Dockerfile.openvino-rhel8 deleted file mode 100644 index 5c504cfa553a..000000000000 --- a/dockerfiles/Dockerfile.openvino-rhel8 +++ /dev/null @@ -1,87 +0,0 @@ -# Build stage -FROM registry.access.redhat.com/ubi8/ubi:8.4 - -WORKDIR /code - -ARG MY_ROOT=/code -ARG DEVICE=CPU_FP32 -ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime -ARG ONNXRUNTIME_BRANCH=main - -ENV INTEL_OPENVINO_DIR=/opt/intel/openvino_2022.3.0 - -ENV InferenceEngine_DIR=${INTEL_OPENVINO_DIR}/runtime/cmake -ENV IE_PLUGINS_PATH=${INTEL_OPENVINO_DIR}/runtime/lib/intel64/ -ENV ngraph_DIR=${INTEL_OPENVINO_DIR}/runtime/cmake -ENV LD_LIBRARY_PATH=${INTEL_OPENVINO_DIR}/runtime/3rdparty/tbb/lib/:${IE_PLUGINS_PATH}:${LD_LIBRARY_PATH} -ENV OpenCV_DIR=${INTEL_OPENVINO_DIR}/extras/opencv/cmake -ENV LD_LIBRARY_PATH=${INTEL_OPENVINO_DIR}/extras/opencv/lib:${LD_LIBRARY_PATH} -ENV LD_LIBRARY_PATH=/usr/local/lib:/usr/lib:/usr/local/lib64:/usr/lib64:/lib64:${LD_LIBRARY_PATH} -ENV PATH=${MY_ROOT}/cmake-dir/bin:$PATH - -# Install packages -RUN yum install -y yum-utils autoconf automake libtool unzip udev wget zlib-devel libffi-devel openssl-devel git make gcc && \ - yum clean packages && yum clean all && rm -rf /var/cache/yum && \ -# Install python 3.8 - cd $MY_ROOT && \ - wget https://www.python.org/ftp/python/3.8.9/Python-3.8.9.tgz && tar xvf Python-3.8.9.tgz && rm -rf Python-3.8.9.tgz && \ - cd Python-3.8*/ && ./configure && make && make install && \ - cd ../ && mkdir -p /usr/bin/Python38 && ln -s Python-3.8.9/ /usr/bin/Python38 && ln -s /usr/bin/pip3 /usr/bin/pip && \ -# libusb1.0.22 - cd /opt/ && wget https://github.com/libusb/libusb/archive/v1.0.22.zip && \ - unzip v1.0.22.zip && rm -rf v1.0.22.zip && cd /opt/libusb-1.0.22 && \ -# bootstrap steps - ./bootstrap.sh && \ - ./configure --disable-udev --enable-shared && \ - make -j4 && \ -# configure libusb1.0.22 - cd /opt/libusb-1.0.22/libusb && \ - /bin/mkdir -p '/usr/local/lib' && \ - /bin/bash ../libtool --mode=install /usr/bin/install -c libusb-1.0.la '/usr/local/lib' && \ - /bin/mkdir -p '/usr/local/include/libusb-1.0' && \ - /usr/bin/install -c -m 644 libusb.h '/usr/local/include/libusb-1.0' && \ - /bin/mkdir -p '/usr/local/lib/pkgconfig' && \ -# Install openvino - cd /opt/ && mkdir intel/ && cd intel && \ - wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2022.3/linux/l_openvino_toolkit_rhel8_2022.3.0.9052.9752fafe8eb_x86_64.tgz && \ - tar xvf l_openvino_toolkit_rhel8_2022.3.0.9052.9752fafe8eb_x86_64.tgz && \ - rm -rf l_openvino_toolkit_rhel8_2022.3.0.9052.9752fafe8eb_x86_64.tgz && \ - mv l_openvino_toolkit_rhel8_2022.3.0.9052.9752fafe8eb_x86_64 openvino_2022.3.0 && \ - cd ${INTEL_OPENVINO_DIR}/install_dependencies/ && ./install_openvino_dependencies.sh -y && ./install_NEO_OCL_driver.sh -y && \ - printf "\nexport LD_LIBRARY_PATH=\${LD_LIBRARY_PATH}:/usr/local/lib\n" >> /opt/intel/openvino_2022.3.0/setupvars.sh && \ - cd /opt/libusb-1.0.22 && \ - /usr/bin/install -c -m 644 libusb-1.0.pc '/usr/local/lib/pkgconfig' && \ - # MYRIAD plugins are not available for openvino 2022.3.0 release - #cp /opt/intel/openvino_2022.3.0/install_dependencies/97-myriad-usbboot.rules /etc/udev/rules.d/ && \ - ldconfig && \ -#Install protobuf - cd $MY_ROOT && \ - git clone https://github.com/protocolbuffers/protobuf.git && \ - cd protobuf && \ - git checkout v3.16.0 && \ - git submodule update --init --recursive && \ - mkdir build_source && cd build_source && \ - cmake ../cmake -DCMAKE_INSTALL_LIBDIR=lib64 -Dprotobuf_BUILD_SHARED_LIBS=OFF -DCMAKE_INSTALL_PREFIX=/usr -DCMAKE_INSTALL_SYSCONFDIR=/etc -DCMAKE_POSITION_INDEPENDENT_CODE=ON -Dprotobuf_BUILD_TESTS=OFF -DCMAKE_BUILD_TYPE=Release && \ - make -j$(nproc) && \ - make install && \ -# Build onnxruntime - cd $MY_ROOT && \ - pip3 install numpy wheel setuptools cython onnx && \ - git clone --recursive -b ${ONNXRUNTIME_BRANCH} ${ONNXRUNTIME_REPO} && \ - bash onnxruntime/dockerfiles/scripts/install_common_deps.sh && \ - ln -s cmake-* cmake-dir && \ - source /opt/intel/openvino_2022.3.0/setupvars.sh && \ - cd /code/onnxruntime && ./build.sh --allow_running_as_root --config Release --update --build --parallel --use_openvino ${DEVICE} --build_shared_lib --build_wheel && \ - pip3 install /code/onnxruntime/build/Linux/Release/dist/*-linux_x86_64.whl && \ -# Clean up - cd ${MY_ROOT} && rm -rf onnxruntime && rm -rf Python-3.8.9 && rm -rf protobuf - -# Deploy stage -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -RUN usermod -a -G video,users,render ${BUILD_USER} -ENV WORKDIR_PATH /home/${BUILD_USER} - -WORKDIR ${WORKDIR_PATH} -USER ${BUILD_USER} diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm index 35a676383337..c242933f677f 100644 --- a/dockerfiles/Dockerfile.rocm +++ b/dockerfiles/Dockerfile.rocm @@ -5,14 +5,14 @@ # Dockerfile to run ONNXRuntime with ROCm integration #-------------------------------------------------------------------------- -FROM rocm/pytorch:rocm5.4_ubuntu20.04_py3.7_pytorch_1.12.1 +FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main WORKDIR /code -ENV PATH /opt/miniconda/bin:/code/cmake-3.27.3-linux-x86_64/bin:${PATH} +ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} # Prepare onnxruntime repository & build onnxruntime RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ diff --git a/dockerfiles/README.md b/dockerfiles/README.md index f226ebfe8b19..a2e99d66d465 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -277,7 +277,7 @@ Nothing else from ONNX Runtime source tree will be copied/installed to the image Note: When running the container you built in Docker, please either use 'nvidia-docker' command instead of 'docker', or use Docker command-line options to make sure NVIDIA runtime will be used and appropiate files mounted from host. Otherwise, CUDA libraries won't be found. You can also [set NVIDIA runtime as default in Docker](https://github.com/dusty-nv/jetson-containers#docker-default-runtime). ## MIGraphX -**Ubuntu 20.04, ROCm5.4, AMDMIGraphX v1.2** +**Ubuntu 20.04, ROCm6.0, MIGraphX** 1. Build the docker image from the Dockerfile in this repository. ``` @@ -291,7 +291,7 @@ Note: When running the container you built in Docker, please either use 'nvidia- ``` ## ROCm -**Ubuntu 20.04, ROCm5.4** +**Ubuntu 20.04, ROCm6.0** 1. Build the docker image from the Dockerfile in this repository. ``` diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 131db5d8d9b3..3d984a54c049 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -41,6 +41,7 @@ Do not modify directly.* * com.microsoft.Gelu * com.microsoft.GemmFastGelu * com.microsoft.GemmFloat8 + * com.microsoft.GemmaRotaryEmbedding * com.microsoft.GreedySearch * com.microsoft.GridSample * com.microsoft.GroupNorm @@ -78,6 +79,7 @@ Do not modify directly.* * com.microsoft.QLinearSigmoid * com.microsoft.QLinearSoftmax * com.microsoft.QLinearWhere + * com.microsoft.QMoE * com.microsoft.QOrderedAttention * com.microsoft.QOrderedGelu * com.microsoft.QOrderedLayerNormalization @@ -155,6 +157,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)
qkv_hidden_sizes : list of ints
Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size
+
rotary_embedding_dim : int
+
Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
unidirectional : int
@@ -459,7 +463,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -1586,6 +1590,8 @@ This version of the operator has been available since version 1 of the 'com.micr
payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0.
ep_sdk_version : string
(Optional) SDK version used to convert the model.
+
hardware_architecture : string
+
(Optional) Hardware architecture.
main_context : int
Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.
notes : string
@@ -2205,6 +2211,69 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GemmaRotaryEmbedding** + + GemmaRotaryEmbedding is the implementation of below part of rotary positional embeddings (RoPE). It implements below from modeling_gemma.py. + + Here's onnxscript that was tested + + from onnxscript import FLOAT, FLOAT16, script + from onnxscript import opset18 as op + + @script() + def gemma_rotary_embedding(emb: FLOAT["bs", "seq_len", "dim"], q: FLOAT16["bs", "num_heads", "seq_len", "dim"], q_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"], k: FLOAT16["bs", "num_heads", "seq_len", "dim"], k_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"]): + sin_val = op.Sin(emb) + casted_sin = op.Cast(sin_val, to=10) # for fp16 mix-precision training. Other types are not supported. + cos_val = op.Cos(emb) + casted_cos = op.Cast(cos_val, to=10) + unsqueezed_sin = op.Unsqueeze(casted_sin, [1]) + unsqueezed_cos = op.Unsqueeze(casted_cos, [1]) + q_embed = (q * casted_cos) + (q_rot * casted_sin) + k_embed = (k * casted_cos) + (k_rot * casted_sin) + return q_embed, k_embed + + onnx_model = gemma_rotary_embedding.to_model_proto() + + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Inputs + +
+
emb : U
+
embeddding - 3D tensor with shape (batch_size, seq_len, dim)
+
q : T
+
q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)
+
q_rot : T
+
half rotated q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)
+
k : T
+
k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)
+
k_rot : T
+
k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)
+
+ +#### Outputs + +
+
output1 : T
+
4D tensor with shape (batch_size, num_heads, seq_len, dim)
+
output2 : T
+
4D tensor with shape (batch_size, num_heads, seq_len, dim)
+
+ +#### Type Constraints + +
+
T : tensor(float16)
+
Constrain input and output types to float16 tensors.
+
U : tensor(float)
+
Constrain input 0 type to float tensors
+
+ + ### **com.microsoft.GreedySearch** Greedy Search for text generation. @@ -2248,7 +2317,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -2394,24 +2463,28 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
do_rotary : int
+
Whether to use rotary position embedding. Default value is 0.
kv_num_heads : int (required)
Number of attention heads for k and v
local_window_size : int
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
+
rotary_interleaved : int
+
Rotate using interleaved pattern. Default value is 0 (False).
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
-#### Inputs +#### Inputs (7 - 9)
query : T
-
Query with shape (batch_size, sequence_length, hidden_size)
-
key : T
+
Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).
+
key (optional) : T
Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
-
value : T
+
value (optional) : T
Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
past_key (optional) : T
past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
@@ -2421,6 +2494,10 @@ This version of the operator has been available since version 1 of the 'com.micr
1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
total_sequence_length : M
Scalar tensor of total sequence length (past + new).
+
cos_cache (optional) : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
sin_cache (optional) : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
#### Outputs @@ -2437,7 +2514,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float16)
+
T : tensor(float16), tensor(bfloat16)
Constrain input and output to float tensors.
M : tensor(int32)
Constrain mask to int tensor.
@@ -2783,7 +2860,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Constrain input A data type to 8-bit integer tensor.
T2 : tensor(int8), tensor(uint8)
Constrain input B data type to 8-bit integer tensor.
-
T3 : tensor(float)
+
T3 : tensor(float), tensor(float16)
Constrain input a_scale, b_scale and output Y data type as float tensor.
@@ -2796,22 +2873,23 @@ This version of the operator has been available since version 1 of the 'com.micr And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. - Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - - n_blocks_per_col = (K + block_size - 1) / block_size - - blob_size = block_size / 8 * bits + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) + For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. + - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. + 4bit example: + |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) + - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. + 3bit example: + |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. + The last uint_8 may have some bits unused. - For a block blob. It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] - Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: - - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - - [N * n_blocks_per_col] if bits > 4 - + Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. + - [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)] + If zero_points has same type as A, it's not packed and has the same shape as Scales. #### Version @@ -2832,17 +2910,19 @@ This version of the operator has been available since version 1 of the 'com.micr
number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
-#### Inputs (3 - 4) +#### Inputs (3 - 5)
A : T1
The input tensor, not quantized
B : T2
-
1-dimensional data blob
+
1 or 2 dimensional data blob
scales : T1
quantization scale
-
zero_points (optional) : T2
+
zero_points (optional) : T3
quantization zero points
+
g_idx (optional) : T4
+
group_idx
#### Outputs @@ -2857,8 +2937,12 @@ This version of the operator has been available since version 1 of the 'com.micr
T1 : tensor(float), tensor(float16)
Constrain input and output types to float/half_float tensors.
-
T2 : tensor(uint8)
-
Constrain quantized weight types to uint8.
+
T2 : tensor(uint8), tensor(int32)
+
Constrain quantized weight types to uint8/int32.
+
T3 : tensor(uint8), tensor(int32), tensor(float16), tensor(float)
+
Constrain quantized zero point types to uint8/int32/float16/float.
+
T4 : tensor(int32)
+
the index tensor.
@@ -2912,8 +2996,8 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.MoE** Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, - GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) - usually uses top 32 experts. + GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) + usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral). #### Version @@ -2927,9 +3011,11 @@ This version of the operator has been available since version 1 of the 'com.micr
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
k : int
Number of top experts to select from expert pool
+
normalize_routing_weights : int
+
Whether to normalize routing weights
-#### Inputs (4 - 6) +#### Inputs (5 - 8)
input : T
@@ -2938,12 +3024,16 @@ This version of the operator has been available since version 1 of the 'com.micr
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T
3D input tensor with shape (num_experts, hidden_size, inter_size)
-
fc2_experts_weights : T
-
3D input tensor with shape (num_experts, inter_size, hidden_size)
fc1_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, inter_size)
+
fc2_experts_weights : T
+
3D input tensor with shape (num_experts, inter_size, hidden_size)
fc2_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, hidden_size)
+
fc3_experts_weights (optional) : T
+
3D optional input tensor with shape (num_experts, hidden_size, inter_size)
+
fc3_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, inter_size)
#### Outputs @@ -3027,6 +3117,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+
unidirectional : int
+
Whether every token can only attend to previous tokens. Default value is 0.
#### Inputs (1 - 8) @@ -4234,6 +4326,69 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.QMoE** + + Int4 MoE + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation_type : string
+
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
k : int
+
Number of top experts to select from expert pool
+
normalize_routing_weights : int
+
Whether to normalize routing weights
+
+ +#### Inputs (7 - 11) + +
+
input : T
+
2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+
router_probs : T
+
2D input tensor with shape (num_rows, num_experts)
+
fc1_experts_weights : T1
+
3D input tensor with shape (num_experts, hidden_size, inter_size / 2)
+
fc1_scales : T
+
2D input tensor with shape (num_experts, inter_size)
+
fc1_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, inter_size)
+
fc2_experts_weights : T1
+
3D input tensor with shape (num_experts, inter_size, hidden_size / 2)
+
fc2_scales : T
+
2D input tensor with shape (num_experts, hidden_size)
+
fc2_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, hidden_size)
+
fc3_experts_weights (optional) : T1
+
3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)
+
fc3_scales (optional) : T
+
2D optional input tensor with shape (num_experts, inter_size)
+
fc3_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, inter_size)
+
+ +#### Outputs + +
+
output : T
+
2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+
+ +#### Type Constraints + +
+
T : tensor(float16)
+
Constrain input and output types to float or float16 tensors.
+
T1 : tensor(uint8)
+
Constrain weights type to uint8 tensors.
+
+ + ### **com.microsoft.QOrderedAttention** Quantized version of simplified Multi-Head Self Attention(using int8 with specific matrix Layout). @@ -5017,6 +5172,10 @@ This version of the operator has been available since version 1 of the 'com.micr
interleaved : int
Rotate using interleaved pattern. Default value is 0 (False).
+
num_heads : int
+
Number of attention heads. Default value is 0. Must use with rotary_embedding_dim
+
rotary_embedding_dim : int
+
Rotary embedding dimension. Default value is 0.
scale : float
Custom scale will be used if specified. Default value is 1.0
@@ -5029,9 +5188,9 @@ This version of the operator has been available since version 1 of the 'com.micr
position_ids : M
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
cos_cache : T
-
2D tensor with shape (max_sequence_length, head_size / 2).
+
2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
sin_cache : T
-
2D tensor with shape (max_sequence_length, head_size / 2).
+
2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
#### Outputs @@ -5044,7 +5203,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float), tensor(float16)
+
T : tensor(float), tensor(float16), tensor(bfloat16)
Constrain input and output types to float tensors.
M : tensor(int64)
Constrain input and output types to integer tensors
@@ -5136,7 +5295,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5725,12 +5884,14 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
beginning_timestamp_token_id : int
+
The id of the first timestamp
decoder : graph (required)
Decoder subgraph to execute in a loop.
decoder_output_cross_qk : int
If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.
decoder_start_token_id : int
-
The id of the token that indicates decoding starts.
+
The id of the token that indicates decoding starts (i.e. the start of transcription token id)
early_stopping : int
early stop or not
encoder : graph
@@ -5743,15 +5904,23 @@ This version of the operator has been available since version 1 of the 'com.micr
Must be 2 for whisper
no_repeat_ngram_size : int
no repeat ngrams size
-
no_speech_token : int
+
no_speech_token_id : int
The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.
+
no_timestamps_token_id : int
+
The id of the token that indicates no timestamps
pad_token_id : int (required)
The id of the padding token
+
start_of_lm_token_id : int
+
The id of the token that indicates LM starts
+
transcribe_token_id : int
+
The id of the transcribe task
+
translate_token_id : int
+
The id of the translate task
vocab_size : int
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
-#### Inputs (5 - 14) +#### Inputs (5 - 15)
input_ids : F
@@ -5765,11 +5934,11 @@ This version of the operator has been available since version 1 of the 'com.micr
num_return_sequences : I
The number of returned sequences in the batch. Shape is (1)
length_penalty (optional) : T
-
Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
+
Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. Shape is (1,)
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5779,9 +5948,11 @@ This version of the operator has been available since version 1 of the 'com.micr
logits_processor (optional) : I
Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)
cross_qk_layer_head (optional) : I
-
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
+
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
+
temperature (optional) : T
+
Temperature value to apply to logits processing during this execution's decoding. Shape is (1)
#### Outputs (1 - 5) @@ -5792,11 +5963,11 @@ This version of the operator has been available since version 1 of the 'com.micr
sequences_scores (optional) : T
Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
scores (optional) : T
-
Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
cross_qk (optional) : V
-
Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
+
Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
non_speech_probs (optional) : T
-
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]
+
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]
#### Type Constraints diff --git a/docs/How_To_Update_ONNX_Dev_Notes.md b/docs/How_To_Update_ONNX_Dev_Notes.md index fd787b017617..264c620a8e69 100644 --- a/docs/How_To_Update_ONNX_Dev_Notes.md +++ b/docs/How_To_Update_ONNX_Dev_Notes.md @@ -17,9 +17,12 @@ git add onnx 1. Update [cgmanifests/generated/cgmanifest.json](/cgmanifests/generated/cgmanifest.json). This file should be generated. See [cgmanifests/README](/cgmanifests/README.md) for instructions. -1. Update [tools/ci_build/github/linux/docker/scripts/requirements.txt](/tools/ci_build/github/linux/docker/scripts/requirements.txt) - and [tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt](/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt). - Update the commit hash for `git+http://github.com/onnx/onnx.git@targetonnxcommithash#egg=onnx`. +1. Update Python requirements files with the updated ONNX version (e.g., `onnx==1.16.0`) or commit hash if building from source (e.g., `git+http://github.com/onnx/onnx.git@targetonnxcommithash#egg=onnx`). +- [onnxruntime/test/python/requirements.txt](/onnxruntime/test/python/requirements.txt) +- [tools/ci_build/github/linux/docker/scripts/requirements.txt](/tools/ci_build/github/linux/docker/scripts/requirements.txt) +- [tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt](/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt) +- [tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt](/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt) +- Run `git grep -rn "onnx==1" .` to find other locations and update this document if necessary. 1. If there is any change to `cmake/external/onnx/onnx/*.in.proto`, you need to regenerate OnnxMl.cs. [Building onnxruntime with Nuget](https://onnxruntime.ai/docs/build/inferencing.html#build-nuget-packages) will do diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md index 97f7e7ff2c14..d08ba7b8f83c 100644 --- a/docs/Memory_Optimizer.md +++ b/docs/Memory_Optimizer.md @@ -30,10 +30,10 @@ Integrate models using `ORTModule`. ``` There are two modes to enable the memory optimizations: -- Aggressively Recompute All Within Each Transformer Layer, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. This will recompute all detected subgraphs within each Transformer Attention+MLP layer. It is easy to enable, but be noted this recompute plan may NOT be the best one. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected. -- User Specified Subgraph Recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=,,...`. This is an advanced usage, that allows users to find the most suitable graphs to recompute, at the cost of overhead to look for the best plans. +- Transformer layerwise recompute, e.g. aggressively recompute all supported nodes within each transformer layer (usually including attention and mlp sublayers), enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected. +- Manual selected subgraph recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=,,...`. This is an advanced usage, that allows users to find the most suitable graphs to recompute, at the cost of overhead to look for the best plans. -### Mode 1 - Simple Usage (Aggressively Recompute All Within Each Transformer Layer) +### Mode 1 - Simple Usage (Transformer Layerwise Recompute) 1. Set memory optimization level to be TRANSFORMER_LAYERWISE_RECOMPUTE, by `export ORTMODULE_MEMORY_OPT_LEVEL=1` @@ -51,6 +51,7 @@ There are two modes to enable the memory optimizations: - Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1 ``` 3. As shown above, `Config` is a string representative for a re-computable subgraph. All are enabled for recompute in this case. +4. By `export ORTMODULE_MEMORY_OPT_LEVEL=2`, all plans including compromised recomptable subgraphs will also be enabled. ### Mode 2 - Advanced Usage (User Selected Subgraph Recompute) diff --git a/docs/ORTModule_Convergence_Notes.md b/docs/ORTModule_Convergence_Notes.md index 791b6c32c9b4..2374e7b7c538 100644 --- a/docs/ORTModule_Convergence_Notes.md +++ b/docs/ORTModule_Convergence_Notes.md @@ -89,7 +89,7 @@ The limitation of `GlobalSubscriberManager` is, only 'nn.Module's forward output dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example: ```diff -+ from onnxruntime.training.utils import inspect_activation ++ from onnxruntime.training.utils.hooks import inspect_activation class BloomForCausalLM(BloomPreTrainedModel): def __init__(self, config: BloomConfig): ... diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index bede16204d42..54137937ad56 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -246,7 +246,7 @@ to standard outputs. #### ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER - **Feature Area**: *ORTMODULE/Optimizations* -- **Description**: By default, this is disabled. This env var can be used for enabling or disabling the embedding input +- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the embedding input data sparsity based performance optimizations. ```bash @@ -287,12 +287,25 @@ A classical usage of disabling the deep copy: when the deep copy before module e #### ORTMODULE_MEMORY_OPT_LEVEL - **Feature Area**: *ORTMODULE/Optimizations* -- **Description**: By default, the level is 0. This env var can be used for enabling recomputation for reducing memory peak requirement. Setting the level to be 0 means all detected subgraphs with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. When level is not 0, check Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details. +- **Description**: By default, the level is 0. This env var can be used for enabling recomputation for reducing memory peak requirement. + - Setting the level to be 1 means all detected recomputable subgraphs (NOT including compromised recomputable graphs) with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. + - Setting the level to be 2 means all detected recomputable subgraphs (including compromised recomputable graphs) with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. + - When the level is 0, check Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details. ```bash export ORTMODULE_MEMORY_OPT_LEVEL=0 ``` +#### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT + +- **Feature Area**: *ORTMODULE/Optimizations* +- **Description**: By default, the memory-efficient gradient management is turned off. The gradient after it is computed in ONNX Runtime, will trigger the corresponding parameter's backward function through `PythonOpGrad` operator. This would help release the gradient buffer managed in ONNX Runtime, which originally is released once all backward computation finishes. + + ```bash + export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1 # Enable + export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable + ``` + ### 2.2 Memory Optimization Q: *Want to run a bigger batch size?* diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 1ce9b3254d91..5bae5ea62657 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -51,7 +51,8 @@ Do not modify directly.* |BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BlackmanWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Cast|*in* input:**T1**
*out* output:**T2**|19+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|21+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| @@ -68,7 +69,8 @@ Do not modify directly.* |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[4, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|20+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|21+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||20|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[9, 19]|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float)| |||[1, 10]|**T** = tensor(float)| @@ -85,7 +87,8 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|19+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|||[19, 20]|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||[13, 18]|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||[10, 12]|**T** = tensor(int32), tensor(int8), tensor(uint8)| |Det|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float)| @@ -111,7 +114,8 @@ Do not modify directly.* |Expand|*in* input:**T**
*in* shape:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[8, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |EyeLike|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint64)
**T2** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint64)| -|Flatten|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Flatten|*in* input:**T**
*out* output:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 8]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -127,6 +131,7 @@ Do not modify directly.* |GatherND|*in* data:**T**
*in* indices:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| |||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| |||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| +|Gelu|*in* X:**T**
*out* Y:**T**|20+|**T** = tensor(float)| |Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[9, 10]|**T** = tensor(double), tensor(float)| @@ -147,21 +152,23 @@ Do not modify directly.* |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float)| |||[11, 12]|**T** = tensor(float)| |||[1, 10]|**T** = tensor(float)| -|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|21+|**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[19, 20]|**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[16, 18]|**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[14, 15]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|If|*in* cond:**B**
*out* outputs:**V**|19+|**B** = tensor(bool)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|If|*in* cond:**B**
*out* outputs:**V**|21+|**B** = tensor(bool)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[19, 20]|**B** = tensor(bool)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[16, 18]|**B** = tensor(bool)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 15]|**B** = tensor(bool)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)| -|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| -|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)| @@ -182,7 +189,8 @@ Do not modify directly.* |LogSoftmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| -|Loop|*in* M:**I**
*in* cond:**B**
*in* v_initial:**V**
*out* v_final_and_scan_outputs:**V**|19+|**B** = tensor(bool)
**I** = tensor(int64)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Loop|*in* M:**I**
*in* cond:**B**
*in* v_initial:**V**
*out* v_final_and_scan_outputs:**V**|21+|**B** = tensor(bool)
**I** = tensor(int64)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[19, 20]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[16, 18]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 15]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -238,7 +246,8 @@ Do not modify directly.* |PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(float)| |||[9, 15]|**T** = tensor(float)| |||[7, 8]|**T** = tensor(float)| -|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| +|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[19, 20]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||18|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -249,8 +258,9 @@ Do not modify directly.* |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 11]|**T** = tensor(double), tensor(float)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| -|QLinearMatMul|*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|19+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)| +|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)| +|||[19, 20]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)| |||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |||[10, 12]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(float)
**T1** = tensor(int32)| @@ -278,7 +288,8 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| -|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|20+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[18, 19]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||11|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| @@ -287,7 +298,8 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32)| -|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|20+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[18, 19]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||11|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| @@ -303,10 +315,12 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|RegexFullMatch|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(string)
**T2** = tensor(bool)| |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8)| |||13|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| -|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|19+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| +|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| +|||[19, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[14, 18]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| @@ -323,7 +337,8 @@ Do not modify directly.* |STFT|*in* signal:**T1**
*in* frame_step:**T2**
*in* window:**T1**
*in* frame_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |Scale|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| -|Scan|*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**

or

*in* sequence_lens:**I**
*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**|19+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Scan|*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**

or

*in* sequence_lens:**I**
*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**|21+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[19, 20]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[16, 18]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 15]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -344,7 +359,8 @@ Do not modify directly.* |SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|Shape|*in* data:**T**
*out* shape:**T1**|19+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Shape|*in* data:**T**
*out* shape:**T1**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||[19, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[15, 18]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[13, 14]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| @@ -356,7 +372,8 @@ Do not modify directly.* |SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float)| |Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float)| -|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Size|*in* data:**T**
*out* size:**T1**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||[19, 20]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[1, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| @@ -377,10 +394,13 @@ Do not modify directly.* |SplitToSequence|*in* input:**T**
*in* split:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)| |Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| -|Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|StringConcat|*in* X:**T**
*in* Y:**T**
*out* Z:**T**|20+|**T** = tensor(string)| |StringNormalizer|*in* X:**tensor(string)**
*out* Y:**tensor(string)**|10+|**X** = tensor(string)| +|StringSplit|*in* X:**T1**
*out* Y:**T2**
*out* Z:**T3**|20+|**T1** = tensor(string)
**T2** = tensor(string)
**T3** = tensor(int64)| |Sub|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| @@ -398,11 +418,13 @@ Do not modify directly.* |TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float)| |||[1, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float)| -|Transpose|*in* data:**T**
*out* transposed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)| |Unique|*in* X:**T**
*out* Y:**T**
*out* indices:**tensor(int64)**
*out* inverse_indices:**tensor(int64)**
*out* counts:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(int64), tensor(int8), tensor(string)| -|Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**

or

*in* data:**T**
*out* expanded:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**

or

*in* data:**T**
*out* expanded:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Upsample|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**|9|**T** = tensor(float), tensor(int32), tensor(int8), tensor(uint8)| @@ -420,7 +442,8 @@ Do not modify directly.* |DictVectorizer|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = map(int64,tensor(double)), map(int64,tensor(float)), map(int64,tensor(string)), map(string,tensor(double)), map(string,tensor(float)), map(string,tensor(int64))
**T2** = tensor(double), tensor(float), tensor(int64), tensor(string)| |FeatureVectorizer|*in* X:**T1**
*out* Y:**tensor(float)**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Imputer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(int64)| -|LabelEncoder|*in* X:**T1**
*out* Y:**T2**|2+|**T1** = tensor(float), tensor(int64), tensor(string)
**T2** = tensor(float), tensor(int64), tensor(string)| +|LabelEncoder|*in* X:**T1**
*out* Y:**T2**|4+|**T1** = tensor(double), tensor(float), tensor(int64), tensor(string)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int64), tensor(string)| +|||[2, 3]|**T1** = tensor(float), tensor(int64), tensor(string)
**T2** = tensor(float), tensor(int64), tensor(string)| |||1|**T1** = tensor(int64), tensor(string)
**T2** = tensor(int64), tensor(string)| |LinearClassifier|*in* X:**T1**
*out* Y:**T2**
*out* Z:**tensor(float)**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)| |LinearRegressor|*in* X:**T**
*out* Y:**tensor(float)**|1+|**T** = tensor(float)| @@ -463,7 +486,7 @@ Do not modify directly.* |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(uint8)
**T4** = tensor(int32)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| @@ -493,7 +516,7 @@ Do not modify directly.* |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)| |Unique|*in* x:**T**
*out* y:**T**
*out* idx:**tensor(int64)**
*out* counts:**tensor(int64)**|1+|**T** = tensor(float)| -|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| |WordConvEmbedding|*in* Sequence:**T**
*in* W:**T1**
*in* B:**T1**
*in* C:**T1**
*out* Y:**T1**|1+|**T** = tensor(int32)
**T1** = tensor(float)| | | | | @@ -600,6 +623,7 @@ Do not modify directly.* |GatherND|*in* data:**T**
*in* indices:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**indices** = tensor(int64)| |||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**indices** = tensor(int64)| |||11|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**indices** = tensor(int64)| +|Gelu|*in* X:**T**
*out* Y:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16)| |Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -611,6 +635,7 @@ Do not modify directly.* |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| |||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| +|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| |Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -622,6 +647,11 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| +|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[13, 19]|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| +|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| @@ -676,7 +706,8 @@ Do not modify directly.* |PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(double), tensor(float), tensor(float16)| |||[9, 15]|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)| |ParametricSoftplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| @@ -724,7 +755,8 @@ Do not modify directly.* |||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|18+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|||[13, 17]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -737,7 +769,9 @@ Do not modify directly.* |||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||8|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Scatter|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -756,7 +790,7 @@ Do not modify directly.* |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Sign|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float), tensor(float16)| +|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| |Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| @@ -775,7 +809,7 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -834,22 +868,24 @@ Do not modify directly.* |GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)| +|GemmaRotaryEmbedding|*in* emb:**U**
*in* q:**T**
*in* q_rot:**T**
*in* k:**T**
*in* k_rot:**T**
*out* output1:**T**
*out* output2:**T**|1+|**T** = tensor(float16)
**U** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| -|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |PackedAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| +|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float16)
**T1** = tensor(uint8)| |QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* relative_position_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| @@ -862,7 +898,7 @@ Do not modify directly.* |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -870,7 +906,7 @@ Do not modify directly.* |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | @@ -903,7 +939,8 @@ Do not modify directly.* |Asinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| |Atan|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| |Atanh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| -|AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| +|AveragePool|*in* X:**T**
*out* Y:**T**|19+|**T** = tensor(float), tensor(float16)| +|||11+|**T** = tensor(float), tensor(float16)| |||10+|**T** = tensor(float), tensor(float16)| |||7+|**T** = tensor(float), tensor(float16)| |BatchNormalization|*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* input_mean:**U**
*in* input_var:**U**
*out* Y:**T**
*out* running_mean:**U**
*out* running_var:**U**

or

*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* mean:**T**
*in* var:**T**
*out* Y:**T**
*out* mean:**T**
*out* var:**T**
*out* saved_mean:**T**
*out* saved_var:**T**

or

*in* X:**T**
*in* scale:**T1**
*in* B:**T1**
*in* input_mean:**T2**
*in* input_var:**T2**
*out* Y:**T**
*out* running_mean:**T2**
*out* running_var:**T2**|15+|**T** = tensor(float), tensor(float16)| @@ -915,10 +952,12 @@ Do not modify directly.* |BitwiseNot|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Cast|*in* input:**T1**
*out* output:**T2**|13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| |Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float), tensor(float16)| @@ -945,16 +984,18 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||7+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Dropout|*in* data:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T1**|7+|**T** = tensor(float), tensor(float16)| -|DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(uint8)| +|DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(float), tensor(float16)| |Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| -|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||11+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||7+|**T** = tensor(float), tensor(float16)
**T1** = tensor(bool)| |Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| @@ -997,7 +1038,8 @@ Do not modify directly.* |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||14+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1030,7 +1072,8 @@ Do not modify directly.* |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| |LpNormalization|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|LpPool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| +|LpPool|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(float), tensor(float16)| +|||11+|**T** = tensor(float), tensor(float16)| |||2+|**T** = tensor(float), tensor(float16)| |MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||9+|**T** = tensor(float), tensor(float16)| @@ -1090,8 +1133,9 @@ Do not modify directly.* |||12+|**T** = tensor(float), tensor(float16), tensor(int32)
**T1** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| |||7+|**T** = tensor(float), tensor(float16)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| -|QLinearMatMul|*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| +|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| +|||13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |||10+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(float), tensor(float16)| |||7+|**T** = tensor(float), tensor(float16)| @@ -1142,11 +1186,12 @@ Do not modify directly.* |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| |||13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| -|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)| -|||11+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)| +|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|||11+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||10+|**T** = tensor(float), tensor(float16)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| @@ -1170,7 +1215,8 @@ Do not modify directly.* |SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Shape|*in* data:**T**
*out* shape:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| @@ -1180,7 +1226,8 @@ Do not modify directly.* |||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| |Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| -|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| @@ -1239,14 +1286,21 @@ Do not modify directly.* |BiasSplitGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)| +|FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| +|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| +|QLinearAveragePool|*in* X:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| +|QLinearConcat|*in* Y_scale:**TF**
*in* Y_zero_point:**T8**
*in* inputs:**TV**
*out* Y:**T8**|1+|**T8** = tensor(int8), tensor(uint8)
**TF** = tensor(float)
**TV** = tensor(float), tensor(int8), tensor(uint8)| +|QLinearGlobalAveragePool|*in* X:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/docs/python/README.rst b/docs/python/README.rst index 32bb3729e01d..bbc8571fe3f1 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime ; +// TODO: When other compilers support std::chrono::operator<<, update this. +// TODO: Check support for other compilers' version before enable C++20 for other compilers. +// Xcode added support for C++20's std::chrono::operator<< in SDK version 14.4. +#if __cplusplus >= 202002L && __MAC_OS_X_VERSION_MAX_ALLOWED >= 140400L +namespace timestamp_ns = std::chrono; +#else +namespace timestamp_ns = ::date; +#endif + #ifndef NDEBUG ORT_ATTRIBUTE_UNUSED static bool vlog_enabled = true; // Set directly based on your needs. #else @@ -75,6 +86,21 @@ struct Category { // TODO: What other high level categories are meaningful? Model? Optimizer? Execution? }; +/// +/// ORT TraceLogging keywords for categories of dynamic logging enablement +/// +enum class ORTTraceLoggingKeyword : uint64_t { + Session = 0x1, // ORT Session TraceLoggingWrite + Logs = 0x2, // LOGS() Macro ORT logs. Pair with an appropriate level depending on detail required + Reserved1 = 0x4, // Reserved if we want to add some specific sub-categories instead of just LOGS() or other uses + Reserved2 = 0x8, + Reserved3 = 0x10, + Reserved4 = 0x20, + Reserved5 = 0x40, + Reserved6 = 0x80, + Profiling = 0x100 // Enables profiling. At higher levels >5 can impact inference performance +}; + class ISink; class Logger; class Capture; @@ -333,5 +359,17 @@ unsigned int GetThreadId(); */ unsigned int GetProcessId(); +/** + If the ONNXRuntimeTraceLoggingProvider ETW Provider is enabled, then adds to the existing logger. +*/ +std::unique_ptr EnhanceLoggerWithEtw(std::unique_ptr existingLogger, logging::Severity originalSeverity, + logging::Severity etwSeverity); + +/** + If the ONNXRuntimeTraceLoggingProvider ETW Provider is enabled, then can override the logging level. + But this overrided level only applies to the ETW sink. The original logger(s) retain their original logging level +*/ +Severity OverrideLevelWithEtw(Severity originalSeverity); + } // namespace logging } // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 9015b23296e0..097873c5e365 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -80,7 +80,6 @@ class IAllocator { virtual void Free(void* p) = 0; - // TODO: Find a better name than Reserve() and update in all places. // Reserve() is an interface exposed for an implementation of IAllocator // to optionally implement some allocation logic that by-passes any arena-based // logic that may be housed in the Alloc() implementation. diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index fbeee8a2aedc..3a3b5cb6888f 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -305,7 +305,7 @@ class CallableDispatchableHelper { return 0; } - void CheckCalledOnce() { + void CheckCalledOnce() const { ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_); } }; diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index ea4f52f99649..16ad943a5f47 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -33,6 +33,8 @@ class Node; #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" +struct OrtRunOptions; + namespace onnxruntime { /** @@ -51,6 +53,8 @@ struct NodeComputeInfo { DestroyFunctionStateFunc release_state_func; }; +using RunOptions = ::OrtRunOptions; + enum class DataLayout { NCHW, NHWC, @@ -59,14 +63,11 @@ enum class DataLayout { class IExecutionProvider { protected: - IExecutionProvider(const std::string& type, bool use_metadef_id_creator = false) - : IExecutionProvider(type, OrtDevice(), use_metadef_id_creator) {} + IExecutionProvider(const std::string& type) + : IExecutionProvider(type, OrtDevice()) {} - IExecutionProvider(const std::string& type, OrtDevice device, bool use_metadef_id_creator = false) + IExecutionProvider(const std::string& type, OrtDevice device) : default_device_(device), type_{type} { - if (use_metadef_id_creator) { - metadef_id_generator_ = std::make_unique(); - } } /* @@ -187,7 +188,7 @@ class IExecutionProvider { Run may not be finished on device This function should be regarded as the point after which a new Run would start to submit commands from CPU */ - virtual common::Status OnRunStart() { return Status::OK(); } + virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } /** Called when InferenceSession::Run ended @@ -195,25 +196,27 @@ class IExecutionProvider { may not be finished on device This function should be regarded as the point that all commands of current Run has been submmited by CPU */ - virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); } + virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) { + return Status::OK(); + } /** Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for - the provider. Currently only CUDA execution provider supports it. + the provider. */ virtual bool IsGraphCaptureEnabled() const { return false; } /** - Indicate whether the graph has been captured and instantiated. Currently - only CUDA execution provider supports it. + Indicate whether the graph has been captured and instantiated. */ - virtual bool IsGraphCaptured() const { return false; } + virtual bool IsGraphCaptured(int /*graph_annotation_id*/) const { return false; } /** - Run the instantiated graph. Currently only CUDA execution provider supports - it. + Run the instantiated graph. */ - virtual common::Status ReplayGraph() { return Status::OK(); } + virtual common::Status ReplayGraph(int /*graph_annotation_id*/) { + return Status::OK(); + } /** Called when session creation is complete @@ -274,19 +277,6 @@ class IExecutionProvider { return logger_; } - /** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance. - The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models. - @param graph_viewer[in] Graph viewer that GetCapability was called with. Can be for the main graph or nested graph. - @param model_hash[out] Returns the hash for the main (i.e. top level) graph in the model. - This is created using the model path if available, - or the model input names and the output names from all nodes in the main graph. - @remarks e.g. the TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches - compiled kernels, so the name must be unique and deterministic across models and sessions. - NOTE: Ideally this would be a protected method, but to work across the EP bridge it has to be public and - virtual, and ModelMetadefIdGenerator but be defined in the header as well. - */ - virtual int GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const; - virtual std::unique_ptr GetProfiler() { return {}; } @@ -326,23 +316,19 @@ class IExecutionProvider { */ virtual std::vector CreatePreferredAllocators() { return std::vector(); }; + /** + * Get the array of pointers for EPContext nodes + * EP needs to implement this if has the requirement to generate the context cache model. Otherwise leave it. + * Default return an empty vector if not provided by the Execution Provider + */ + virtual const InlinedVector GetEpContextNodes() const { + return InlinedVector(); + } + private: const std::string type_; // It will be set when this object is registered to a session const logging::Logger* logger_ = nullptr; - - // helper to generate ids that are unique to model and deterministic, even if the execution provider is shared across - // multiple sessions. - class ModelMetadefIdGenerator { - public: - int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash); - - private: - std::unordered_map main_graph_hash_; // map graph instance hash to model contents hash - std::unordered_map model_metadef_id_; // current unique id for model - }; - - std::unique_ptr metadef_id_generator_; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/op_kernel_info.h b/include/onnxruntime/core/framework/op_kernel_info.h index b31c85e32f80..a0bbfe50a700 100644 --- a/include/onnxruntime/core/framework/op_kernel_info.h +++ b/include/onnxruntime/core/framework/op_kernel_info.h @@ -28,7 +28,8 @@ class OpKernelInfo : public OpNodeProtoHelper { const std::unordered_map& constant_initialized_tensors, const OrtValueNameIdxMap& mlvalue_name_idx_map, const DataTransferManager& data_transfer_mgr, - const AllocatorMap& allocators = {}); + const AllocatorMap& allocators, + const ConfigOptions& config_options); OpKernelInfo(const OpKernelInfo& other); @@ -50,6 +51,8 @@ class OpKernelInfo : public OpNodeProtoHelper { const AllocatorMap& GetAllocators() const { return allocators_; } + const ConfigOptions& GetConfigOptions() const { return config_options_; } + private: ORT_DISALLOW_MOVE(OpKernelInfo); ORT_DISALLOW_ASSIGNMENT(OpKernelInfo); @@ -64,6 +67,7 @@ class OpKernelInfo : public OpNodeProtoHelper { const DataTransferManager& data_transfer_mgr_; ProtoHelperNodeContext proto_helper_context_; const AllocatorMap& allocators_; + const ConfigOptions& config_options_; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h index 5444c825d799..789c3b13f2c3 100644 --- a/include/onnxruntime/core/framework/run_options.h +++ b/include/onnxruntime/core/framework/run_options.h @@ -45,5 +45,5 @@ struct OrtRunOptions { }; namespace onnxruntime { -using RunOptions = OrtRunOptions; +using RunOptions = ::OrtRunOptions; } // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/stream_handles.h b/include/onnxruntime/core/framework/stream_handles.h index c235ee904762..26d78133b52f 100644 --- a/include/onnxruntime/core/framework/stream_handles.h +++ b/include/onnxruntime/core/framework/stream_handles.h @@ -100,6 +100,8 @@ class Stream { return nullptr; } + virtual WaitNotificationFn GetWaitNotificationFn() const { return nullptr; } + private: StreamHandle handle_; const OrtDevice& device_; diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 9b26ba914c7d..8e04050d089a 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -31,6 +31,7 @@ constexpr size_t kMaxExecutionProviderNameLen = 30; constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; +constexpr const char* kCudaNHWCExecutionProvider = "CUDANHWCExecutionProvider"; constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider"; diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 22827d43b200..3b417a362d2c 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -21,7 +21,7 @@ #pragma warning(pop) #endif -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/gsl.h" @@ -621,6 +621,22 @@ class Node { // Reference to the function template defined in the model. const FunctionTemplate* func_template_ = nullptr; + + // set/clear NodeProto that the Node was created from. + // Set by Graph ctor when loading a model from file. + // Cleared after first call to onnx::check_node in VerifyNodeAndOpMatch when the first Graph::Resolve runs. + void SetOriginalNodeProto(const ONNX_NAMESPACE::NodeProto* node_proto) { + original_node_proto_ = node_proto; + } + + const ONNX_NAMESPACE::NodeProto* GetOriginalNodeProto() const { + return original_node_proto_; + } + + // NodeProto that the Node was created from. We temporarily set this as a performance optimization to avoid calling + // Node::ToProto when running onnx::check_node in the first Graph::Resolve. At that point we know all the nodes are + // unchanged from the original model. + const ONNX_NAMESPACE::NodeProto* original_node_proto_ = nullptr; #endif // Execution priority, lower value for higher priority @@ -753,7 +769,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned. @param check_outer_scope If true and the graph is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'. - @remarks check_outer_scope of true is not supported in a minimal build */ const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const; diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 3cdbb07099ca..1023d5031018 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -165,7 +165,8 @@ class GraphViewer { if a const initializer is part of the underlying Graph but not part of this GraphViewer, it will still be returned instead of nullptr */ - const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const; + const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, + bool check_outer_scope = true) const; /** Get the Node containing this Graph if IsSubgraph is true. Returns nullptr otherwise. */ const Node* ParentNode() const noexcept { return graph_->ParentNode(); } diff --git a/include/onnxruntime/core/providers/cann/cann_provider_options.h b/include/onnxruntime/core/providers/cann/cann_provider_options.h index ac60fbe4a293..51b423e68110 100644 --- a/include/onnxruntime/core/providers/cann/cann_provider_options.h +++ b/include/onnxruntime/core/providers/cann/cann_provider_options.h @@ -16,6 +16,7 @@ struct OrtCANNProviderOptions { int enable_cann_graph; // Flag indicating if prioritizing the use of // CANN's graph-running capabilities int dump_graphs; // Flag indicating if dumping graphs + int dump_om_model; // Flag indicating if dumping om model std::string precision_mode; // Operator Precision Mode std::string op_select_impl_mode; // Operator-level model compilation options: // Mode selection diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h index 03715eb5b78b..55abb90b981f 100644 --- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h +++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h @@ -28,9 +28,12 @@ enum COREMLFlags { // dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes. COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008, + // Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. + COREML_FLAG_CREATE_MLPROGRAM = 0x010, + // Keep COREML_FLAG_LAST at the end of the enum definition // And assign the last COREMLFlag to it - COREML_FLAG_LAST = COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES, + COREML_FLAG_LAST = COREML_FLAG_CREATE_MLPROGRAM, }; #ifdef __cplusplus diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index d73d551920d4..7104e70c3a8a 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -16,9 +16,10 @@ #include "core/providers/custom_op_context.h" #include #include +#ifndef USE_CUDA_MINIMAL #include #include - +#endif namespace Ort { namespace Custom { @@ -28,38 +29,47 @@ struct CudaContext : public CustomOpContext { cudnnHandle_t cudnn_handle = {}; cublasHandle_t cublas_handle = {}; OrtAllocator* deferred_cpu_allocator = {}; + // below are cuda ep options + int16_t device_id = 0; + int32_t arena_extend_strategy = 0; + int32_t cudnn_conv_algo_search = 0; + bool cudnn_conv_use_max_workspace = true; + bool cudnn_conv1d_pad_to_nc1d = false; + bool enable_skip_layer_norm_strict_mode = false; + bool prefer_nhwc = false; + bool use_tf32 = true; void Init(const OrtKernelContext& kernel_ctx) { - const auto& ort_api = Ort::GetApi(); - void* resource = {}; - OrtStatus* status = nullptr; - - status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cuda_stream_t, &resource); - if (status) { - ORT_CXX_API_THROW("failed to fetch cuda stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - cuda_stream = reinterpret_cast(resource); - - resource = {}; - status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cudnn_handle_t, &resource); - if (status) { - ORT_CXX_API_THROW("failed to fetch cudnn handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - cudnn_handle = reinterpret_cast(resource); + cuda_stream = FetchResource(kernel_ctx, CudaResource::cuda_stream_t); + cudnn_handle = FetchResource(kernel_ctx, CudaResource::cudnn_handle_t); + cublas_handle = FetchResource(kernel_ctx, CudaResource::cublas_handle_t); + deferred_cpu_allocator = FetchResource(kernel_ctx, CudaResource::deferred_cpu_allocator_t); + + device_id = FetchResource(kernel_ctx, CudaResource::device_id_t); + arena_extend_strategy = FetchResource(kernel_ctx, CudaResource::arena_extend_strategy_t); + cudnn_conv_algo_search = FetchResource(kernel_ctx, CudaResource::cudnn_conv_algo_search_t); + cudnn_conv_use_max_workspace = FetchResource(kernel_ctx, CudaResource::cudnn_conv_use_max_workspace_t); + + cudnn_conv1d_pad_to_nc1d = FetchResource(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t); + enable_skip_layer_norm_strict_mode = FetchResource(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t); + prefer_nhwc = FetchResource(kernel_ctx, CudaResource::prefer_nhwc_t); + use_tf32 = FetchResource(kernel_ctx, CudaResource::use_tf32_t); + } - resource = {}; - status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cublas_handle_t, &resource); - if (status) { - ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + template + T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) { + if constexpr (sizeof(T) > sizeof(void*)) { + ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT); } - cublas_handle = reinterpret_cast(resource); - - resource = {}; - status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource); + const auto& ort_api = Ort::GetApi(); + void* resource = {}; + OrtStatus* status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, resource_type, &resource); if (status) { - ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resouce type: " + std::to_string(resource_type), OrtErrorCode::ORT_RUNTIME_EXCEPTION); } - deferred_cpu_allocator = reinterpret_cast(resource); + T t = {}; + memcpy(&t, &resource, sizeof(T)); + return t; } void* AllocDeferredCpuMem(size_t size) const { diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index 82bb8ba83be4..6d53760ab60b 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -37,4 +37,5 @@ struct OrtCUDAProviderOptionsV2 { // The strict mode has better accuracy but lower performance. int prefer_nhwc = 0; // make the CUDA EP NHWC preferred int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not + int use_tf32 = 1; // use TF32 }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index 8c3ed46ade6a..00e7dec5727d 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -3,11 +3,20 @@ #include "core/providers/resource.h" -#define ORT_CUDA_RESOUCE_VERSION 2 +#define ORT_CUDA_RESOUCE_VERSION 3 enum CudaResource : int { - cuda_stream_t = cuda_resource_offset, + cuda_stream_t = cuda_resource_offset, // 10000 cudnn_handle_t, cublas_handle_t, deferred_cpu_allocator_t, -}; \ No newline at end of file + // below are cuda ep options + device_id_t, // 10004 + arena_extend_strategy_t, + cudnn_conv_algo_search_t, + cudnn_conv_use_max_workspace_t, + cudnn_conv1d_pad_to_nc1d_t, + enable_skip_layer_norm_strict_mode_t, + prefer_nhwc_t, + use_tf32_t, +}; diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index 7d7f05193f48..33b98edf3bf4 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -27,14 +27,8 @@ typedef struct IDMLDevice IDMLDevice; #include "onnxruntime_c_api.h" #ifdef __cplusplus -extern "C" { -#endif -enum OrtDmlPerformancePreference { - Default = 0, - HighPerformance = 1, - MinimumPower = 2 -}; +extern "C" { enum OrtDmlDeviceFilter : uint32_t { #ifdef ENABLE_NPU_ADAPTER_ENUMERATION @@ -54,11 +48,33 @@ inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a &= (int)b); } inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a ^= (int)b); } +#else + +typedef enum OrtDmlDeviceFilter { +#ifdef ENABLE_NPU_ADAPTER_ENUMERATION + Any = 0xffffffff, + Gpu = 1 << 0, + Npu = 1 << 1, +#else + Gpu = 1 << 0, +#endif +} OrtDmlDeviceFilter; + +#endif + +typedef enum OrtDmlPerformancePreference { + Default = 0, + HighPerformance = 1, + MinimumPower = 2 +} OrtDmlPerformancePreference; + struct OrtDmlDeviceOptions { OrtDmlPerformancePreference Preference; OrtDmlDeviceFilter Filter; }; +typedef struct OrtDmlDeviceOptions OrtDmlDeviceOptions; + /** * [[deprecated]] * This export is deprecated. diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 680ce1cc5b9a..32a9f06464ac 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -11,6 +11,8 @@ /// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions. ///
struct OrtTensorRTProviderOptionsV2 { + OrtTensorRTProviderOptionsV2& operator=(const OrtTensorRTProviderOptionsV2& other); // copy assignment operator + int device_id{0}; // cuda device id. int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream. void* user_compute_stream{nullptr}; // user specified CUDA compute stream. @@ -46,4 +48,26 @@ struct OrtTensorRTProviderOptionsV2 { const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT + + /* + * Please note that there are rules for using following context model related provider options: + * + * 1. In the case of dumping the context model and loading the context model, + * for security reason, TRT EP doesn't allow the "ep_cache_context" node attribute of EP context node to be + * the absolute path or relative path that is outside of context model directory. + * It means engine cache needs to be in the same directory or sub-directory of context model. + * + * 2. In the case of dumping the context model, the engine cache path will be changed to the relative path of context model directory. + * For example: + * If "trt_dump_ep_context_model" is enabled and "trt_engine_cache_enable" is enabled, + * if "trt_ep_context_file_path" is "./context_model_dir", + * - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir" + * - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir" + * + */ + int trt_dump_ep_context_model{0}; // Dump EP context node model + const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path. + int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data + + const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix }; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index dbd5ad41255f..e7b8f1487112 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -29,15 +29,16 @@ */ #pragma once -#include +#include #include +#include #include /** \brief The API version defined in this header * * This value is used by some API functions to behave as this version of the header expects. */ -#define ORT_API_VERSION 17 +#define ORT_API_VERSION 18 #ifdef __cplusplus extern "C" { @@ -318,6 +319,12 @@ typedef struct OrtAllocator { void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p); ///< Free a block of memory previously allocated with OrtAllocator::Alloc const struct OrtMemoryInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_); ///< Return a pointer to an ::OrtMemoryInfo that describes this allocator + /** + * @brief Optional allocation function to use for memory allocations made during session initialization. + * Use this function if you want to separate allocations made by ORT during Run() calls from + * those made during session initialization. This allows for separate memory management strategies for these allocations. + */ + void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes } OrtAllocator; typedef void(ORT_API_CALL* OrtLoggingFunction)( @@ -495,6 +502,7 @@ typedef struct OrtROCMProviderOptions { has_user_compute_stream{}, user_compute_stream{}, default_memory_arena_cfg{}, + enable_hip_graph{false}, tunable_op_enable{false}, tunable_op_tuning_enable{false}, tunable_op_max_tuning_duration_ms{} {} @@ -547,6 +555,8 @@ typedef struct OrtROCMProviderOptions { */ OrtArenaCfg* default_memory_arena_cfg; + int enable_hip_graph; + /** \brief Enable TunableOp for using. * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. * This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. @@ -1833,14 +1843,28 @@ struct OrtApi { /** \brief Used for custom operators, get an input of a kernel * - * \see ::OrtCustomOp + * The function attempts fetches the input of the kernel. If the input is optional + * and not present, the function returns success and out is set to nullptr. + * + * \param[in] context ::OrtKernelContext instance + * \param[in] input index. See KernelContext_GetInputCount for boundaries check. + * \param[in, out] returns a ptr to OrtValue if the input is present + * + * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out); /** \brief Used for custom operators, get an output of a kernel * - * \see ::OrtCustomOp + * The function attempts fetches the output of the kernel. If the output is optional + * and not present, the function returns success and out is set to nullptr. + * + * \param[in] context ::OrtKernelContext instance + * \param[in] output index. See KernelContext_GetOutputCount for boundaries check. + * \param[in, out] returns a ptr to OrtValue if the output is present + * + * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out); @@ -3594,10 +3618,11 @@ struct OrtApi { * QNN supported keys: * "backend_path": file path to QNN backend library. * "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off. + * "profiling_file_path": QNN profiling file path if ETW not enabled. * "rpc_control_latency": QNN RPC control latency. * "vtcm_mb": QNN VTCM size in MB. default to 0(not set). * "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance", - * "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". + * "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and * may alter model/EP partitioning. Use only for debugging. @@ -3607,6 +3632,18 @@ struct OrtApi { * - "1": Faster preparation time, less optimal graph. * - "2": Longer preparation time, more optimal graph. * - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details. + * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown). + * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options: + * - "0": Default (none). + * - "68" + * - "69" + * - "73" + * - "75" + * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). + "enable_htp_fp16_precision": Only used for float32 model. + Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. + - "0": Default. With fp32 precision. + - "1": With fp16 precision. * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", @@ -4417,7 +4454,7 @@ struct OrtApi { ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr); /** - * Get a EP resoure. + * Get a EP resource. * E.g. a cuda stream or a cublas handle * * \param context - Kernel context @@ -4515,6 +4552,85 @@ struct OrtApi { * \since Version 1.17. */ ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); + + /** \brief Set whether to use deterministic compute. + * + * Default is false. If set to true, this will enable deterministic compute for GPU kernels where possible. + * Note that this most likely will have a performance cost. + * + * \param[in] options + * \param[in] value + * + * \since Version 1.17. + */ + ORT_API2_STATUS(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value); + + /** + * Run fn in parallel + * + * \param[in] context + * \param[in] fn Function accepting usr_data and an integer as iterator + * \param[in] total The number of times fn is to be invoked + * \param[in] num_batch Number of batches by which the "total" is to be divided in maximum. When zero, there is no limit + * \param[in] usr_data User data to be passed back to fn + * + * \since Version 1.17. + */ + ORT_API2_STATUS(KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data); + + /** \brief Append OpenVINO execution provider to the session options + * + * If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail. + * + * \param[in] options + * \param[in] provider_options_keys + * \param[in] provider_options_values + * \param[in] num_keys + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2, + _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** \brief Append VitisAI provider to session options + * + * If VitisAI is not available (due to a non VitisAI enabled build, or if VitisAI is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] provider_options_keys + * \param[in] provider_options_values + * \param[in] num_keys + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** \brief Get scratch buffer from the corresponding allocator under the sepcific OrtMemoryInfo object. + * NOTE: callers are responsible to release this scratch buffer from the corresponding allocator + * \param[in] context OrtKernelContext instance + * \param[in] mem_info OrtMemoryInfo instance + * \param[in] count_or_bytes How many bytes is this scratch buffer + * \param[out] out A pointer to the scrach buffer + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); + + /** \brief Get allocator from KernelInfo for a specific memory type. Please use C API ReleaseAllocator to release out object + * + * \param[in] info OrtKernelInfo instance + * \param[in] mem_type OrtMemType object + * \param[out] out A pointer to OrtAllocator + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); }; /* @@ -4612,6 +4728,21 @@ struct OrtCustomOp { // Get start range int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); + + // Get the inplace_map that defines which output can reuse which input + // Callers will provide 2 raw int* and pass in their address, this function will fill these 2 arrays + // when return, output (*output_index)[i] may reuse the input (*input_index[i]). + // The return value is the size of these 2 arrays. + // Callers are responsible to delete these 2 arrays after use by calling OrtCustomOp::ReleaseMayInplace(). + size_t(ORT_API_CALL* GetMayInplace)(_Out_ int** input_index, _Out_ int** output_index); + + // Release the pointer input_index and output_index allocated from GetMayInplace() function. + // If GetMayInplace() is defined, this function MUST be defined as well. + void(ORT_API_CALL* ReleaseMayInplace)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); + + // Same as GetMayInplace() and ReleaseMayInplace() + size_t(ORT_API_CALL* GetAliasMap)(_Out_ int** input_index, _Out_ int** output_index); + void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 92c25d8688b6..fd0e3490426a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -845,6 +845,7 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel + SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena @@ -873,10 +874,12 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer SessionOptionsImpl& AddExternalInitializers(const std::vector& names, const std::vector& ort_values); ///< Wraps OrtApi::AddExternalInitializers - SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA - SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2 - SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM - SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO + SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA + SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2 + SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM + SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2 + SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map& provider_options = {}); SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX @@ -898,6 +901,9 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {}); SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction + + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI + SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options = {}); }; } // namespace detail @@ -2049,13 +2055,18 @@ struct KernelContext { explicit KernelContext(OrtKernelContext* context); size_t GetInputCount() const; size_t GetOutputCount() const; + // If input is optional and is not present, the method returns en empty ConstValue + // which can be compared to nullptr. ConstValue GetInput(size_t index) const; + // If outout is optional and is not present, the method returns en empty UnownedValue + // which can be compared to nullptr. UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; UnownedValue GetOutput(size_t index, const std::vector& dims) const; void* GetGPUComputeStream() const; Logger GetLogger() const; OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const; OrtKernelContext* GetOrtKernelContext() const { return ctx_; } + void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const; private: OrtKernelContext* ctx_; @@ -2290,6 +2301,11 @@ struct CustomOpBase : OrtCustomOp { OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) { return static_cast(this_)->end_ver_; }; + + OrtCustomOp::GetMayInplace = nullptr; + OrtCustomOp::ReleaseMayInplace = nullptr; + OrtCustomOp::GetAliasMap = nullptr; + OrtCustomOp::ReleaseAliasMap = nullptr; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 860a27fc73f7..9d1e8c944308 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -7,17 +7,27 @@ // These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter // the main C++ file with implementation details. -#include +#include #include - -#define RETURN_ON_API_FAIL(expression) \ - { \ - auto err = (expression); \ - if (err) { \ - return Status(err); \ - } \ +#include +#include + +// Convert OrtStatus to Ort::Status and return +// instead of throwing +#define ORT_CXX_RETURN_ON_API_FAIL(expression) \ + { \ + auto ort_status = (expression); \ + if (ort_status) { \ + return Ort::Status(ort_status); \ + } \ } +#ifdef __cpp_if_constexpr +#define ORT_CXX_IF_CONSTEXPR if constexpr +#else +#define ORT_CXX_IF_CONSTEXPR if +#endif + namespace Ort { namespace detail { @@ -656,6 +666,12 @@ inline SessionOptionsImpl& SessionOptionsImpl::SetGraphOptimizationLevel(G return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::SetDeterministicCompute(bool value) { + ThrowOnError(GetApi().SetDeterministicCompute(this->p_, value)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) { ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath)); @@ -859,6 +875,45 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_Ope return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_OpenVINO_V2(const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO_V2(this->p_, + keys.data(), values.data(), num_entries)); + + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_VitisAI(this->p_, keys.data(), values.data(), num_entries)); + + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs) { @@ -1652,6 +1707,10 @@ inline Logger KernelContext::GetLogger() const { return Logger{out}; } +inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const { + ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data)); +} + inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) { Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_)); } @@ -1918,7 +1977,7 @@ inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api, inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape) { OrtTensorTypeAndShapeInfo* info = {}; - RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info)); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info)); using InfoPtr = std::unique_ptr>; @@ -1942,9 +2001,9 @@ inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shap } } - RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size())); - RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size())); - RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info)); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size())); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size())); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info)); return Status{nullptr}; } diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index 0c0af16d4e20..ee60f25da115 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -862,6 +862,11 @@ struct OrtLiteCustomOp : public OrtCustomOp { auto self = reinterpret_cast(op); return self->end_ver_; }; + + OrtCustomOp::GetMayInplace = {}; + OrtCustomOp::ReleaseMayInplace = {}; + OrtCustomOp::GetAliasMap = {}; + OrtCustomOp::ReleaseAliasMap = {}; } const std::string op_name_; @@ -1111,4 +1116,4 @@ OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, } } // namespace Custom -} // namespace Ort \ No newline at end of file +} // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h index 1f5fcd50e185..c80b8c0c164b 100644 --- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h @@ -30,3 +30,22 @@ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memor // Per default it will be set to '0' // Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream. static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; + +// Set HTP performance mode for QNN HTP backend before session run. +// options for HTP performance mode: "burst", "balanced", "default", "high_performance", +// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", +// "sustained_high_performance". Default to "default". +static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode"; + +// Set HTP performance mode for QNN HTP backend post session run. +static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run"; + +// Set RPC control latency for QNN HTP backend +static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency"; + +// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true. +// The value should be an integer. If the value is not set, the default value is 0 and +// ORT session only captures one cuda graph before another capture is requested. +// If the value is set to -1, cuda graph capture/replay is disabled in that run. +// User are not expected to set the value to 0 as it is reserved for internal use. +static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id"; diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index df79cb6e5b21..bb5e0344895e 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -93,6 +93,15 @@ static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimizatio static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config"; #endif +// This setting if set should contain a comma separated list of optimizers names that should be disabled. +// Optimizers may take time to execute and affect model loading time. If you feel that a specific optimizer +// does not provider runtime benefits, but affects your model loading time you may disable it using this config +// entry. This option is not enabled in ORT_MINIMAL_BUILD build. +// A list of optimizes is available in onnxruntime/core/optimizer/graph_transformer_utils.cc +// +// Default is an empty string which means no optimizers are disabled. +static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers"; + // Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0". // Using device allocators means the memory allocation is made using malloc/new. static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers"; @@ -236,7 +245,7 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = "session.optimized_model_external_initializers_min_size_in_bytes"; -// Enable EP context feature to dump the partitioned graph which include the EP context into Onnx file. +// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file. // The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead. // "0": disable. (default) // "1": enable. @@ -249,4 +258,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p // Flag to specify whether to dump the EP context into the Onnx model. // "0": dump the EP context into separate file, keep the file name in the Onnx model. // "1": dump the EP context into the Onnx model. (default). -static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; \ No newline at end of file +static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; + +// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. +// Option values: +// - "0": Gemm FastMath mode is not enabled. [DEFAULT] +// - "1": Gemm FastMath mode is enabled. +static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; diff --git a/java/README.md b/java/README.md index 2ce9a8bf62e4..5c5baeb43a27 100644 --- a/java/README.md +++ b/java/README.md @@ -14,7 +14,7 @@ Use the main project's [build instructions](https://www.onnxruntime.ai/docs/how- #### Requirements -JDK version 8 or later is required. +Java 11 or later is required to build the library. The compiled jar file will run on Java 8 or later. The [Gradle](https://gradle.org/) build system is used here to manage the Java project's dependency management, compilation, testing, and assembly. In particular, the Gradle [wrapper](https://docs.gradle.org/current/userguide/gradle_wrapper.html) at `java/gradlew[.bat]` is used, locking the Gradle version to the one specified in the `java/gradle/wrapper/gradle-wrapper.properties` configuration. @@ -35,6 +35,7 @@ This allows the CMake system to ensure all of the C/C++ compilation is achieved The Java build depends on C/C++ onnxruntime shared library and a C JNI shared library (source located in the `src/main/native` directory). The JNI shared library is the glue that allows for Java to call functions in onnxruntime shared library. Given the fact that CMake injects native dependencies during CMake builds, some gradle tasks (primarily, `build`, `test`, and `check`) may fail. +To run the Java build independently of CMake supply `-DcmakeBuildDir=`, though this will only succeed after an initial build of the native libraries has completed. When running the build script, CMake will compile the `onnxruntime` target and the JNI glue `onnxruntime4j_jni` target and expose the resulting libraries in a place where Gradle can ingest them. Upon successful compilation of those targets, a special Gradle task to build will be executed. The results will be placed in the output directory stated above. @@ -61,4 +62,4 @@ Then the corresponding C files in `./src/main/native/ai_onnxruntime*.c` may be u ### Dependencies -The Java API does not have any runtime or compile dependencies currently. +The Java API does not have any runtime or compile dependencies. diff --git a/java/build.gradle b/java/build.gradle index c0a75f8165f7..fd66ec220b78 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -3,7 +3,7 @@ plugins { id 'maven-publish' id 'signing' id 'jacoco' - id "com.diffplug.spotless" version "6.13.0" + id "com.diffplug.spotless" version "6.25.0" } allprojects { @@ -185,7 +185,7 @@ test { if (cmakeBuildDir != null) { workingDir cmakeBuildDir } - systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'USE_COREML', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS']) + systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'USE_COREML', 'USE_DML', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS']) testLogging { events "passed", "skipped", "failed" showStandardStreams = true diff --git a/java/gradle/wrapper/gradle-wrapper.jar b/java/gradle/wrapper/gradle-wrapper.jar index ccebba7710de..d64cd4917707 100644 Binary files a/java/gradle/wrapper/gradle-wrapper.jar and b/java/gradle/wrapper/gradle-wrapper.jar differ diff --git a/java/gradle/wrapper/gradle-wrapper.properties b/java/gradle/wrapper/gradle-wrapper.properties index f396aaac2d31..4baf5a11d45a 100644 --- a/java/gradle/wrapper/gradle-wrapper.properties +++ b/java/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,8 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=1b6b558be93f29438d3df94b7dfee02e794b94d9aca4611a92cdb79b6b88e909 -distributionUrl=https\://services.gradle.org/distributions/gradle-8.0.1-bin.zip +distributionSha256Sum=9631d53cf3e74bfa726893aee1f8994fee4e060c401335946dba2156f440f24c +distributionUrl=https\://services.gradle.org/distributions/gradle-8.6-bin.zip networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/java/gradlew b/java/gradlew index 79a61d421cc4..1aa94a426907 100755 --- a/java/gradlew +++ b/java/gradlew @@ -83,10 +83,8 @@ done # This is normally unused # shellcheck disable=SC2034 APP_BASE_NAME=${0##*/} -APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD=maximum @@ -133,10 +131,13 @@ location of your Java installation." fi else JAVACMD=java - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the location of your Java installation." + fi fi # Increase the maximum file descriptors if we can. @@ -144,7 +145,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then case $MAX_FD in #( max*) # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. - # shellcheck disable=SC3045 + # shellcheck disable=SC2039,SC3045 MAX_FD=$( ulimit -H -n ) || warn "Could not query maximum file descriptor limit" esac @@ -152,7 +153,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then '' | soft) :;; #( *) # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. - # shellcheck disable=SC3045 + # shellcheck disable=SC2039,SC3045 ulimit -n "$MAX_FD" || warn "Could not set maximum file descriptor limit to $MAX_FD" esac @@ -197,11 +198,15 @@ if "$cygwin" || "$msys" ; then done fi -# Collect all arguments for the java command; -# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of -# shell script including quotes and variable substitutions, so put them in -# double quotes to make sure that they get re-expanded; and -# * put everything else in single quotes, so that it's not re-expanded. + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. set -- \ "-Dorg.gradle.appname=$APP_BASE_NAME" \ diff --git a/java/src/main/java/ai/onnxruntime/OnnxJavaType.java b/java/src/main/java/ai/onnxruntime/OnnxJavaType.java index 24bf6ad4b95f..6f3ca13984f4 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxJavaType.java +++ b/java/src/main/java/ai/onnxruntime/OnnxJavaType.java @@ -45,8 +45,10 @@ public enum OnnxJavaType { /** The native value of the enum. */ public final int value; + /** The Java side type used as the carrier. */ public final Class clazz; + /** The number of bytes used by a single value of this type. */ public final int size; diff --git a/java/src/main/java/ai/onnxruntime/OnnxMap.java b/java/src/main/java/ai/onnxruntime/OnnxMap.java index 354ebec61274..68d91d0d9e74 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxMap.java +++ b/java/src/main/java/ai/onnxruntime/OnnxMap.java @@ -8,6 +8,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.logging.Logger; /** * A container for a map returned by {@link OrtSession#run(Map)}. @@ -16,6 +17,7 @@ * values: String, Long, Float, Double. */ public class OnnxMap implements OnnxValue { + private static final Logger logger = Logger.getLogger(OnnxMap.class.getName()); static { try { @@ -37,6 +39,7 @@ public enum OnnxMapValueType { FLOAT(3), /** A 64-bit floating point value. */ DOUBLE(4); + /** The native enum value. */ final int value; @@ -107,6 +110,8 @@ public static OnnxMapValueType mapFromOnnxJavaType(OnnxJavaType type) { private final OnnxMapValueType valueType; + private boolean closed; + /** * Constructs an OnnxMap containing a reference to the native map along with the type information. * @@ -122,6 +127,7 @@ public static OnnxMapValueType mapFromOnnxJavaType(OnnxJavaType type) { this.info = info; this.stringKeys = info.keyType == OnnxJavaType.STRING; this.valueType = OnnxMapValueType.mapFromOnnxJavaType(info.valueType); + this.closed = false; } /** @@ -146,6 +152,7 @@ public OnnxValueType getType() { */ @Override public Map getValue() throws OrtException { + checkClosed(); Object[] keys = getMapKeys(); Object[] values = getMapValues(); HashMap map = new HashMap<>(OrtUtil.capacityFromSize(keys.length)); @@ -222,10 +229,27 @@ public String toString() { return "ONNXMap(size=" + size() + ",info=" + info.toString() + ")"; } + @Override + public synchronized boolean isClosed() { + return closed; + } + /** Closes this map, releasing the native memory backing it and it's elements. */ @Override - public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed map."); + } + } + + /** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */ + protected void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OnnxValue"); + } } private native String[] getStringKeys(long apiHandle, long nativeHandle, long allocatorHandle) diff --git a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java index ed739dd9729d..f552badd4f83 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java +++ b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java @@ -54,19 +54,25 @@ final class OnnxRuntime { /** The short name of the ONNX runtime shared library */ static final String ONNXRUNTIME_LIBRARY_NAME = "onnxruntime"; + /** The short name of the ONNX runtime JNI shared library */ static final String ONNXRUNTIME_JNI_LIBRARY_NAME = "onnxruntime4j_jni"; /** The short name of the ONNX runtime shared provider library */ static final String ONNXRUNTIME_LIBRARY_SHARED_NAME = "onnxruntime_providers_shared"; + /** The short name of the ONNX runtime CUDA provider library */ static final String ONNXRUNTIME_LIBRARY_CUDA_NAME = "onnxruntime_providers_cuda"; + /** The short name of the ONNX runtime ROCM provider library */ static final String ONNXRUNTIME_LIBRARY_ROCM_NAME = "onnxruntime_providers_rocm"; + /** The short name of the ONNX runtime DNNL provider library */ static final String ONNXRUNTIME_LIBRARY_DNNL_NAME = "onnxruntime_providers_dnnl"; + /** The short name of the ONNX runtime OpenVINO provider library */ static final String ONNXRUNTIME_LIBRARY_OPENVINO_NAME = "onnxruntime_providers_openvino"; + /** The short name of the ONNX runtime TensorRT provider library */ static final String ONNXRUNTIME_LIBRARY_TENSORRT_NAME = "onnxruntime_providers_tensorrt"; diff --git a/java/src/main/java/ai/onnxruntime/OnnxSequence.java b/java/src/main/java/ai/onnxruntime/OnnxSequence.java index 93e1be21588b..7722514b913b 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxSequence.java +++ b/java/src/main/java/ai/onnxruntime/OnnxSequence.java @@ -8,6 +8,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.logging.Logger; /** * A sequence of {@link OnnxValue}s all of the same type. @@ -24,6 +25,7 @@ * */ public class OnnxSequence implements OnnxValue { + private static final Logger logger = Logger.getLogger(OnnxSequence.class.getName()); static { try { @@ -40,6 +42,8 @@ public class OnnxSequence implements OnnxValue { private final SequenceInfo info; + private boolean closed; + /** * Creates the wrapper object for a native sequence. * @@ -53,6 +57,7 @@ public class OnnxSequence implements OnnxValue { this.nativeHandle = nativeHandle; this.allocatorHandle = allocatorHandle; this.info = info; + this.closed = false; } @Override @@ -76,6 +81,7 @@ public OnnxValueType getType() { */ @Override public List getValue() throws OrtException { + checkClosed(); if (info.sequenceOfMaps) { OnnxMap[] maps = getMaps(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle); return Collections.unmodifiableList(Arrays.asList(maps)); @@ -110,10 +116,27 @@ public String toString() { return "OnnxSequence(info=" + info.toString() + ")"; } + @Override + public synchronized boolean isClosed() { + return closed; + } + /** Closes this sequence, releasing the native memory backing it and it's elements. */ @Override - public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed sequence."); + } + } + + /** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */ + protected void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OnnxValue"); + } } private native OnnxMap[] getMaps(long apiHandle, long nativeHandle, long allocatorHandle) diff --git a/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java b/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java index 53bd4c7f9b3e..8400ef53ff6d 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java @@ -14,6 +14,7 @@ import java.nio.LongBuffer; import java.nio.ShortBuffer; import java.util.Arrays; +import java.util.logging.Logger; /** * A Java object wrapping an OnnxSparseTensor. @@ -22,6 +23,7 @@ * different static inner class representing each type. */ public final class OnnxSparseTensor extends OnnxTensorLike { + private static final Logger logger = Logger.getLogger(OnnxSparseTensor.class.getName()); private final SparseTensorType sparseTensorType; // Held to prevent deallocation while used in native code. @@ -198,6 +200,7 @@ public OnnxValueType getType() { @Override public SparseTensor getValue() throws OrtException { + checkClosed(); Buffer buffer = getValuesBuffer(); long[] indicesShape = getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle); switch (sparseTensorType) { @@ -234,8 +237,13 @@ public SparseTensor getValue() throws OrtException { } @Override - public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed OnnxSparseTensor."); + } } /** @@ -257,6 +265,7 @@ public SparseTensorType getSparseTensorType() { * @return The indices. */ public Buffer getIndicesBuffer() { + checkClosed(); switch (sparseTensorType) { case COO: case CSRC: @@ -295,6 +304,7 @@ public Buffer getIndicesBuffer() { * @return The inner indices. */ public LongBuffer getInnerIndicesBuffer() { + checkClosed(); if (sparseTensorType == SparseTensorType.CSRC) { LongBuffer buf = getInnerIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle) @@ -320,6 +330,7 @@ public LongBuffer getInnerIndicesBuffer() { * @return The data buffer. */ public Buffer getValuesBuffer() { + checkClosed(); ByteBuffer buffer = getValuesBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder()); switch (info.type) { @@ -396,6 +407,7 @@ public Buffer getValuesBuffer() { * @return The indices shape. */ public long[] getIndicesShape() { + checkClosed(); return getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle); } @@ -405,6 +417,7 @@ public long[] getIndicesShape() { * @return The indices shape. */ public long[] getInnerIndicesShape() { + checkClosed(); if (sparseTensorType == SparseTensorType.CSRC) { return getInnerIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle); } else { @@ -420,6 +433,7 @@ public long[] getInnerIndicesShape() { * @return The values shape. */ public long[] getValuesShape() { + checkClosed(); return getValuesShape(OnnxRuntime.ortApiHandle, nativeHandle); } @@ -623,6 +637,7 @@ public abstract static class SparseTensor { /** The buffer holding the indices. */ final T indices; + /** The buffer holding the values. */ final Buffer values; diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 0078adb6402f..e1ee2c14fd9d 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -14,12 +14,14 @@ import java.nio.LongBuffer; import java.nio.ShortBuffer; import java.util.Optional; +import java.util.logging.Logger; /** * A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be * returned as outputs. */ public class OnnxTensor extends OnnxTensorLike { + private static final Logger logger = Logger.getLogger(OnnxTensor.class.getName()); /** * This reference is held for OnnxTensors backed by a java.nio.Buffer to ensure the buffer does @@ -97,6 +99,7 @@ public OnnxValueType getType() { */ @Override public Object getValue() throws OrtException { + checkClosed(); if (info.isScalar()) { switch (info.type) { case FLOAT: @@ -144,16 +147,21 @@ public Object getValue() throws OrtException { @Override public String toString() { - return "OnnxTensor(info=" + info.toString() + ")"; + return "OnnxTensor(info=" + info.toString() + ",closed=" + closed + ")"; } /** - * Closes the tensor, releasing it's underlying memory (if it's not backed by an NIO buffer). If - * it is backed by a buffer then the memory is released when the buffer is GC'd. + * Closes the tensor, releasing its underlying memory (if it's not backed by an NIO buffer). If it + * is backed by a buffer then the memory is released when the buffer is GC'd. */ @Override - public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed tensor."); + } } /** @@ -165,6 +173,7 @@ public void close() { * @return A ByteBuffer copy of the OnnxTensor. */ public ByteBuffer getByteBuffer() { + checkClosed(); if (info.type != OnnxJavaType.STRING) { ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle); ByteBuffer output = ByteBuffer.allocate(buffer.capacity()); @@ -183,6 +192,7 @@ public ByteBuffer getByteBuffer() { * @return A FloatBuffer copy of the OnnxTensor. */ public FloatBuffer getFloatBuffer() { + checkClosed(); if (info.type == OnnxJavaType.FLOAT) { // if it's fp32 use the efficient copy. FloatBuffer buffer = getBuffer().asFloatBuffer(); @@ -212,6 +222,7 @@ public FloatBuffer getFloatBuffer() { * @return A DoubleBuffer copy of the OnnxTensor. */ public DoubleBuffer getDoubleBuffer() { + checkClosed(); if (info.type == OnnxJavaType.DOUBLE) { DoubleBuffer buffer = getBuffer().asDoubleBuffer(); DoubleBuffer output = DoubleBuffer.allocate(buffer.capacity()); @@ -230,6 +241,7 @@ public DoubleBuffer getDoubleBuffer() { * @return A ShortBuffer copy of the OnnxTensor. */ public ShortBuffer getShortBuffer() { + checkClosed(); if ((info.type == OnnxJavaType.INT16) || (info.type == OnnxJavaType.FLOAT16) || (info.type == OnnxJavaType.BFLOAT16)) { @@ -250,6 +262,7 @@ public ShortBuffer getShortBuffer() { * @return An IntBuffer copy of the OnnxTensor. */ public IntBuffer getIntBuffer() { + checkClosed(); if (info.type == OnnxJavaType.INT32) { IntBuffer buffer = getBuffer().asIntBuffer(); IntBuffer output = IntBuffer.allocate(buffer.capacity()); @@ -268,6 +281,7 @@ public IntBuffer getIntBuffer() { * @return A LongBuffer copy of the OnnxTensor. */ public LongBuffer getLongBuffer() { + checkClosed(); if (info.type == OnnxJavaType.INT64) { LongBuffer buffer = getBuffer().asLongBuffer(); LongBuffer output = LongBuffer.allocate(buffer.capacity()); diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java b/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java index c2989fe296dc..bbfd4e981ece 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java @@ -28,6 +28,9 @@ public abstract class OnnxTensorLike implements OnnxValue { /** The size and shape information for this tensor. */ protected final TensorInfo info; + /** Is this value closed? */ + protected boolean closed; + /** * Constructs a tensor-like (the base class of OnnxTensor and OnnxSparseTensor). * @@ -39,6 +42,7 @@ public abstract class OnnxTensorLike implements OnnxValue { this.nativeHandle = nativeHandle; this.allocatorHandle = allocatorHandle; this.info = info; + this.closed = false; } /** @@ -59,4 +63,16 @@ long getNativeHandle() { public TensorInfo getInfo() { return info; } + + @Override + public synchronized boolean isClosed() { + return closed; + } + + /** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */ + protected void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OnnxValue"); + } + } } diff --git a/java/src/main/java/ai/onnxruntime/OnnxValue.java b/java/src/main/java/ai/onnxruntime/OnnxValue.java index 752a0e74267d..e829bc80f09f 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxValue.java +++ b/java/src/main/java/ai/onnxruntime/OnnxValue.java @@ -64,7 +64,14 @@ public enum OnnxValueType { */ public ValueInfo getInfo(); - /** Closes the OnnxValue, freeing it's native memory. */ + /** + * Checks if this value is closed (i.e., the native object has been released). + * + * @return True if the value is closed and the native object has been released. + */ + public boolean isClosed(); + + /** Closes the OnnxValue, freeing its native memory. */ @Override public void close(); diff --git a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java index 39a5121fad7a..70af10ff8cd7 100644 --- a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java @@ -5,11 +5,14 @@ package ai.onnxruntime; import java.io.IOException; +import java.util.logging.Logger; /** An abstract base class for execution provider options classes. */ // Note this lives in ai.onnxruntime to allow subclasses to access the OnnxRuntime.ortApiHandle // package private field. public abstract class OrtProviderOptions implements AutoCloseable { + private static final Logger logger = Logger.getLogger(OrtProviderOptions.class.getName()); + static { try { OnnxRuntime.init(); @@ -21,6 +24,9 @@ public abstract class OrtProviderOptions implements AutoCloseable { /** The native pointer. */ protected final long nativeHandle; + /** Is the native object closed? */ + protected boolean closed; + /** * Constructs a OrtProviderOptions wrapped around a native pointer. * @@ -28,6 +34,7 @@ public abstract class OrtProviderOptions implements AutoCloseable { */ protected OrtProviderOptions(long nativeHandle) { this.nativeHandle = nativeHandle; + this.closed = false; } /** @@ -46,9 +53,30 @@ protected static long getApiHandle() { */ public abstract OrtProvider getProvider(); + /** + * Is the native object closed? + * + * @return True if the native object has been released. + */ + public synchronized boolean isClosed() { + return closed; + } + @Override public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed tensor."); + } + } + + /** Checks if the OrtProviderOptions is closed, if so throws {@link IllegalStateException}. */ + protected void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OrtProviderOptions"); + } } /** diff --git a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java index 49ddf29c2233..eeede3a1bed0 100644 --- a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java @@ -12,6 +12,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.logging.Logger; /** * Wraps an ONNX training model and allows training and inference calls. @@ -1049,8 +1050,12 @@ private native void exportModelForInference( /** Wrapper class for the checkpoint state. */ static final class OrtCheckpointState implements AutoCloseable { + private static final Logger logger = Logger.getLogger(OrtCheckpointState.class.getName()); + final long nativeHandle; + private boolean closed; + /** * Wraps an object around the checkpoint native handle. * @@ -1058,6 +1063,7 @@ static final class OrtCheckpointState implements AutoCloseable { */ OrtCheckpointState(long nativeHandle) { this.nativeHandle = nativeHandle; + this.closed = false; } /** @@ -1097,6 +1103,7 @@ static OrtCheckpointState loadCheckpoint(String checkpoint) throws OrtException * @throws OrtException If the checkpoint failed to save. */ public void saveCheckpoint(Path outputPath, boolean saveOptimizer) throws OrtException { + checkClosed(); Objects.requireNonNull(outputPath, "checkpoint path must not be null"); String outputStr = outputPath.toString(); saveCheckpoint( @@ -1115,6 +1122,7 @@ public void saveCheckpoint(Path outputPath, boolean saveOptimizer) throws OrtExc * @throws OrtException If the call failed. */ public void addProperty(String name, float value) throws OrtException { + checkClosed(); addProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value); } @@ -1127,6 +1135,7 @@ public void addProperty(String name, float value) throws OrtException { * @throws OrtException If the call failed. */ public void addProperty(String name, int value) throws OrtException { + checkClosed(); addProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value); } @@ -1139,6 +1148,7 @@ public void addProperty(String name, int value) throws OrtException { * @throws OrtException If the call failed. */ public void addProperty(String name, String value) throws OrtException { + checkClosed(); addProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value); } @@ -1152,6 +1162,7 @@ public void addProperty(String name, String value) throws OrtException { * @throws OrtException If the property does not exist, or is of the wrong type. */ public float getFloatProperty(OrtAllocator allocator, String name) throws OrtException { + checkClosed(); return getFloatProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -1169,6 +1180,7 @@ public float getFloatProperty(OrtAllocator allocator, String name) throws OrtExc * @throws OrtException If the property does not exist, or is of the wrong type. */ public int getIntProperty(OrtAllocator allocator, String name) throws OrtException { + checkClosed(); return getIntProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -1186,6 +1198,7 @@ public int getIntProperty(OrtAllocator allocator, String name) throws OrtExcepti * @throws OrtException If the property does not exist, or is of the wrong type. */ public String getStringProperty(OrtAllocator allocator, String name) throws OrtException { + checkClosed(); return getStringProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -1194,9 +1207,25 @@ public String getStringProperty(OrtAllocator allocator, String name) throws OrtE name); } + /** Checks if the OrtCheckpointState is closed, if so throws {@link IllegalStateException}. */ + private void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OrtCheckpointState"); + } + } + + public synchronized boolean isClosed() { + return closed; + } + @Override - public void close() { - close(OnnxRuntime.ortTrainingApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortTrainingApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing a checkpoint twice"); + } } /* diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 69ccb954e8af..1c21387b5045 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -7,6 +7,7 @@ import java.lang.reflect.Array; import java.nio.Buffer; import java.util.Arrays; +import java.util.stream.Collectors; /** Describes an {@link OnnxTensor}, including it's size, shape and element type. */ public class TensorInfo implements ValueInfo { @@ -159,6 +160,12 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { /** The shape of the tensor. */ final long[] shape; + /** The names of the unbound dimensions. */ + final String[] dimensionNames; + + /** If there are non-empty dimension names */ + private final boolean hasNames; + /** The Java type of this tensor. */ public final OnnxJavaType type; @@ -177,6 +184,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { */ TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) { this.shape = shape; + this.dimensionNames = new String[shape.length]; + Arrays.fill(dimensionNames, ""); + this.hasNames = false; this.type = type; this.onnxType = onnxType; this.numElements = elementCount(shape); @@ -188,10 +198,20 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { *

Called from JNI. * * @param shape The tensor shape. + * @param names The dimension names. * @param typeInt The native type int. */ - TensorInfo(long[] shape, int typeInt) { + TensorInfo(long[] shape, String[] names, int typeInt) { this.shape = shape; + this.dimensionNames = names; + boolean hasNames = false; + for (String s : names) { + if (!s.isEmpty()) { + hasNames = true; + break; + } + } + this.hasNames = hasNames; this.onnxType = OnnxTensorType.mapFromInt(typeInt); this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType); this.numElements = elementCount(shape); @@ -206,15 +226,42 @@ public long[] getShape() { return Arrays.copyOf(shape, shape.length); } + /** + * Get a copy of the tensor's named dimensions. + * + * @return A copof the tensor's named dimensions. + */ + public String[] getDimensionNames() { + return Arrays.copyOf(dimensionNames, dimensionNames.length); + } + @Override public String toString() { - return "TensorInfo(javaType=" - + type.toString() - + ",onnxType=" - + onnxType.toString() - + ",shape=" - + Arrays.toString(shape) - + ")"; + String output = + "TensorInfo(javaType=" + + type.toString() + + ",onnxType=" + + onnxType.toString() + + ",shape=" + + Arrays.toString(shape); + if (hasNames) { + output = + output + + ",dimNames=[" + + Arrays.stream(dimensionNames) + .map( + a -> { + if (a.isEmpty()) { + return "\"\""; + } else { + return a; + } + }) + .collect(Collectors.joining(",")) + + "]"; + } + output = output + ")"; + return output; } /** diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java index eb124decf75f..cec3fadf446c 100644 --- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -14,7 +14,18 @@ public enum CoreMLFlags implements OrtFlags { /** Enables CoreML on subgraphs. */ ENABLE_ON_SUBGRAPH(2), // COREML_FLAG_ENABLE_ON_SUBGRAPH(0x002) /** Only enable usage of CoreML if the device has an Apple Neural Engine. */ - ONLY_ENABLE_DEVICE_WITH_ANE(4); // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004), + ONLY_ENABLE_DEVICE_WITH_ANE(4), // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004) + /** + * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also + * allow inputs with dynamic shapes. However, the performance may be negatively impacted if inputs + * have dynamic shapes. + */ + ONLY_ALLOW_STATIC_INPUT_SHAPES(8), // COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES(0x008) + /** + * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or + * later. + */ + CREATE_MLPROGRAM(16); // COREML_FLAG_CREATE_MLPROGRAM(0x010) /** The native value of the enum. */ public final int value; diff --git a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java index 02207b2949e5..961163035c9a 100644 --- a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java @@ -32,6 +32,7 @@ protected StringConfigProviderOptions(long nativeHandle) { * @throws OrtException If the addition failed. */ public void add(String key, String value) throws OrtException { + checkClosed(); Objects.requireNonNull(key, "Key must not be null"); Objects.requireNonNull(value, "Value must not be null"); options.put(key, value); diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 879ba8a31061..7b2629158139 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -342,7 +342,6 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT if (code != ORT_OK) { return NULL; } - //printf("numDim %d\n",numDim); int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim); code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim)); if (code != ORT_OK) { @@ -358,12 +357,31 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT free(dimensions); dimensions = NULL; + // Create the string array for the names. + const char** dimensionNames = (const char**) malloc(sizeof(char*)*numDim); + if (dimensionNames == NULL) { + throwOrtException(jniEnv, 1, "Not enough memory"); + return NULL; + } + code = checkOrtStatus(jniEnv, api, api->GetSymbolicDimensions(info, dimensionNames, numDim)); + if (code != ORT_OK) { + // extraction failed, exception has been thrown, return to Java. + free(dimensionNames); + return NULL; + } + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String"); + jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL); + for (size_t i = 0; i < numDim; i++) { + jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]); + (*jniEnv)->SetObjectArrayElement(jniEnv, names, safecast_size_t_to_jsize(i), javaName); + } + free(dimensionNames); + // Create the TensorInfo object static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo"; jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName); - jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([JI)V"); - //printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor); - jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt); + jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([J[Ljava/lang/String;I)V"); + jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, names, onnxTypeInt); return tensorInfo; } diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 3a1c0d1bb8fa..337f4c1921c6 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -8,7 +8,7 @@ #include "onnxruntime/core/session/onnxruntime_c_api.h" #include "OrtJniUtil.h" #include "ai_onnxruntime_OrtSession_SessionOptions.h" -#ifdef WIN32 +#ifdef _WIN32 #include #else #include @@ -318,7 +318,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeC // Iterate the handles, calling the appropriate close function for (jint i = 0; i < numHandles; i++) { -#ifdef WIN32 +#ifdef _WIN32 FreeLibrary((void*)handles[i]); #else dlclose((void*)handles[i]); @@ -630,7 +630,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addMIG JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addDirectML (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint deviceID) { (void)jobj; - #ifdef USE_DIRECTML + #ifdef USE_DML checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_DML((OrtSessionOptions*) handle, deviceID)); #else (void)apiHandle;(void)handle;(void)deviceID; // Parameters used when DirectML is defined. diff --git a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c index 9f7b8d3a3dcf..464234c34798 100644 --- a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c @@ -66,7 +66,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_createTrainingSes } } wchar_t* optimizerStr = NULL; - if (optimizerPath == NULL) { + if (optimizerPath != NULL) { optimizerStr = copyAndPad(jniEnv, optimizerPath); if (optimizerStr == NULL) { // exception has been thrown in Java, go to cleanup and return null. diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index e975117fb75b..ac65cbab146b 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -69,7 +69,9 @@ public void environmentTest() { // Checks that the environment instance is the same. OrtEnvironment otherEnv = OrtEnvironment.getEnvironment(); assertSame(env, otherEnv); + TestHelpers.quietLogger(OrtEnvironment.class); otherEnv = OrtEnvironment.getEnvironment("test-name"); + TestHelpers.loudLogger(OrtEnvironment.class); assertSame(env, otherEnv); } @@ -588,6 +590,12 @@ public void testSymbolicDimensionAssignment() throws OrtException { Map infoMap = session.getInputInfo(); TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo(); assertArrayEquals(new long[] {-1, 2}, aInfo.shape); + assertEquals(2, aInfo.dimensionNames.length); + assertEquals("n", aInfo.dimensionNames[0]); + assertEquals("", aInfo.dimensionNames[1]); + TensorInfo bInfo = (TensorInfo) infoMap.get("B").getInfo(); + assertEquals(1, bInfo.dimensionNames.length); + assertEquals("m", bInfo.dimensionNames[0]); } } // Check that when the options are assigned it overrides the symbolic dimension @@ -643,6 +651,12 @@ public void testCoreML() throws OrtException { runProvider(OrtProvider.CORE_ML); } + @Test + @EnabledIfSystemProperty(named = "USE_DML", matches = "1") + public void testDirectML() throws OrtException { + runProvider(OrtProvider.DIRECT_ML); + } + private void runProvider(OrtProvider provider) throws OrtException { EnumSet providers = OrtEnvironment.getAvailableProviders(); assertTrue(providers.size() > 1); @@ -665,7 +679,7 @@ private void runProvider(OrtProvider provider) throws OrtException { // CoreML gives slightly different answers on a 2020 13" M1 MBP assertArrayEquals(expectedOutput, resultArray, 1e-2f); } else { - assertArrayEquals(expectedOutput, resultArray, 1e-6f); + assertArrayEquals(expectedOutput, resultArray, 1e-5f); } } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); @@ -1918,6 +1932,8 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet provid options.addNnapi(); break; case DIRECT_ML: + options.setMemoryPatternOptimization(false); + options.setExecutionMode(ExecutionMode.SEQUENTIAL); options.addDirectML(0); break; case ACL: diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index a5f285ba86a1..c060cf73ecf1 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -4,6 +4,10 @@ */ package ai.onnxruntime; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; + import ai.onnxruntime.platform.Fp16Conversions; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -97,8 +101,8 @@ public void testBufferCreation() throws OrtException { float[] arrValues = new float[] {0, 1, 2, 3, 4}; try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) { // array creation isn't backed by buffers - Assertions.assertFalse(t.ownsBuffer()); - Assertions.assertFalse(t.getBufferRef().isPresent()); + assertFalse(t.ownsBuffer()); + assertFalse(t.getBufferRef().isPresent()); FloatBuffer buf = t.getFloatBuffer(); float[] output = new float[arrValues.length]; buf.get(output); @@ -146,7 +150,7 @@ public void testBufferCreation() throws OrtException { directBuffer.rewind(); try (OnnxTensor t = OnnxTensor.createTensor(env, directBuffer, new long[] {1, 5})) { // direct buffers don't trigger a copy - Assertions.assertFalse(t.ownsBuffer()); + assertFalse(t.ownsBuffer()); // tensors backed by buffers can get the buffer ref back out Assertions.assertTrue(t.getBufferRef().isPresent()); FloatBuffer buf = t.getFloatBuffer(); @@ -428,4 +432,21 @@ public void testBf16RoundTrip() { } } } + + @Test + public void testClose() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + long[] input = new long[] {1, 2, 3, 4, 5}; + OnnxTensor value = OnnxTensor.createTensor(env, input); + assertFalse(value.isClosed()); + long[] output = (long[]) value.getValue(); + assertArrayEquals(input, output); + value.close(); + // check use after close throws + assertThrows(IllegalStateException.class, value::getValue); + // check double close doesn't crash (emits warning) + TestHelpers.quietLogger(OnnxTensor.class); + value.close(); + TestHelpers.loudLogger(OnnxTensor.class); + } } diff --git a/java/src/test/java/ai/onnxruntime/TestHelpers.java b/java/src/test/java/ai/onnxruntime/TestHelpers.java index 55d8169434d4..c13cdf222b15 100644 --- a/java/src/test/java/ai/onnxruntime/TestHelpers.java +++ b/java/src/test/java/ai/onnxruntime/TestHelpers.java @@ -22,6 +22,8 @@ import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; import java.util.regex.Pattern; import org.junit.jupiter.api.Assertions; @@ -258,6 +260,16 @@ static void flattenStringBase(String[] input, List output) { output.addAll(Arrays.asList(input)); } + static void loudLogger(Class loggerClass) { + Logger l = Logger.getLogger(loggerClass.getName()); + l.setLevel(Level.INFO); + } + + static void quietLogger(Class loggerClass) { + Logger l = Logger.getLogger(loggerClass.getName()); + l.setLevel(Level.OFF); + } + public static Path getResourcePath(String path) { return new File(TestHelpers.class.getResource(path).getFile()).toPath(); } diff --git a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java index 1ed883ace36e..0e3bc15ba9c7 100644 --- a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java +++ b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java @@ -96,7 +96,7 @@ private static void runProvider(OrtProvider provider, OrtSession.SessionOptions OnnxValue resultTensor = result.get(0); float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue()); assertEquals(expectedOutput.length, resultArray.length); - assertArrayEquals(expectedOutput, resultArray, 1e-6f); + assertArrayEquals(expectedOutput, resultArray, 1e-5f); } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); } diff --git a/java/src/test/java/sample/ScoreMNIST.java b/java/src/test/java/sample/ScoreMNIST.java index 5587b58e17f5..6ecbc5cd56d1 100644 --- a/java/src/test/java/sample/ScoreMNIST.java +++ b/java/src/test/java/sample/ScoreMNIST.java @@ -30,6 +30,7 @@ public class ScoreMNIST { private static final Logger logger = Logger.getLogger(ScoreMNIST.class.getName()); + /** Pattern for splitting libsvm format files. */ private static final Pattern splitPattern = Pattern.compile("\\s+"); diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index 3e1e833addb9..e90efd7b97c2 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {Backend} from './backend.js'; +import {InferenceSession} from './inference-session.js'; interface BackendInfo { backend: Backend; @@ -10,6 +11,7 @@ interface BackendInfo { initPromise?: Promise; initialized?: boolean; aborted?: boolean; + error?: string; } const backends: Map = new Map(); @@ -60,43 +62,100 @@ export const registerBackend = (name: string, backend: Backend, priority: number }; /** - * Resolve backend by specified hints. + * Try to resolve and initialize a backend. * - * @param backendHints - a list of execution provider names to lookup. If omitted use registered backends as list. - * @returns a promise that resolves to the backend. + * @param backendName - the name of the backend. + * @returns the backend instance if resolved and initialized successfully, or an error message if failed. + */ +const tryResolveAndInitializeBackend = async(backendName: string): Promise => { + const backendInfo = backends.get(backendName); + if (!backendInfo) { + return 'backend not found.'; + } + + if (backendInfo.initialized) { + return backendInfo.backend; + } else if (backendInfo.aborted) { + return backendInfo.error!; + } else { + const isInitializing = !!backendInfo.initPromise; + try { + if (!isInitializing) { + backendInfo.initPromise = backendInfo.backend.init(backendName); + } + await backendInfo.initPromise; + backendInfo.initialized = true; + return backendInfo.backend; + } catch (e) { + if (!isInitializing) { + backendInfo.error = `${e}`; + backendInfo.aborted = true; + } + return backendInfo.error!; + } finally { + delete backendInfo.initPromise; + } + } +}; + +/** + * Resolve execution providers from the specific session options. + * + * @param options - the session options object. + * @returns a promise that resolves to a tuple of an initialized backend instance and a session options object with + * filtered EP list. * * @ignore */ -export const resolveBackend = async(backendHints: readonly string[]): Promise => { - const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints; - const errors = []; - for (const backendName of backendNames) { - const backendInfo = backends.get(backendName); - if (backendInfo) { - if (backendInfo.initialized) { - return backendInfo.backend; - } else if (backendInfo.aborted) { - continue; // current backend is unavailable; try next - } +export const resolveBackendAndExecutionProviders = async(options: InferenceSession.SessionOptions): + Promise<[backend: Backend, options: InferenceSession.SessionOptions]> => { + // extract backend hints from session options + const eps = options.executionProviders || []; + const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); + const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints; - const isInitializing = !!backendInfo.initPromise; - try { - if (!isInitializing) { - backendInfo.initPromise = backendInfo.backend.init(backendName); + // try to resolve and initialize all requested backends + let backend: Backend|undefined; + const errors = []; + const availableBackendNames = new Set(); + for (const backendName of backendNames) { + const resolveResult = await tryResolveAndInitializeBackend(backendName); + if (typeof resolveResult === 'string') { + errors.push({name: backendName, err: resolveResult}); + } else { + if (!backend) { + backend = resolveResult; + } + if (backend === resolveResult) { + availableBackendNames.add(backendName); + } } - await backendInfo.initPromise; - backendInfo.initialized = true; - return backendInfo.backend; - } catch (e) { - if (!isInitializing) { - errors.push({name: backendName, err: e}); + } + + // if no backend is available, throw error. + if (!backend) { + throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`); + } + + // for each explicitly requested backend, if it's not available, output warning message. + for (const {name, err} of errors) { + if (backendHints.includes(name)) { + // eslint-disable-next-line no-console + console.warn(`removing requested execution provider "${ + name}" from session options because it is not available: ${err}`); } - backendInfo.aborted = true; - } finally { - delete backendInfo.initPromise; } - } - } - throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`); -}; + const filteredEps = eps.filter(i => availableBackendNames.has(typeof i === 'string' ? i : i.name)); + + return [ + backend, new Proxy(options, { + get: (target, prop) => { + if (prop === 'executionProviders') { + return filteredEps; + } + return Reflect.get(target, prop); + } + }) + ]; + }; diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index 9bfcb1220605..8c07bdd5c5c4 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -58,7 +58,7 @@ export interface TrainingSessionHandler extends SessionHandler { options: InferenceSession.RunOptions): Promise; getParametersSize(trainableOnly: boolean): Promise; - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; getContiguousParameters(trainableOnly: boolean): Promise; } @@ -77,8 +77,8 @@ export interface Backend { Promise; createTrainingSessionHandler? - (checkpointStateUriOrBuffer: TrainingSession.URIorBuffer, trainModelUriOrBuffer: TrainingSession.URIorBuffer, - evalModelUriOrBuffer: TrainingSession.URIorBuffer, optimizerModelUriOrBuffer: TrainingSession.URIorBuffer, + (checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer, trainModelUriOrBuffer: TrainingSession.UriOrBuffer, + evalModelUriOrBuffer: TrainingSession.UriOrBuffer, optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer, options: InferenceSession.SessionOptions): Promise; } diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 0cded7e5edbc..c8df1613b326 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -33,6 +33,14 @@ export declare namespace Env { */ simd?: boolean; + /** + * set or get a boolean value indicating whether to enable trace. + * + * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored. + * @defaultValue `false` + */ + trace?: boolean; + /** * Set or get a number specifying the timeout for initialization of WebAssembly backend, in milliseconds. A zero * value indicates no timeout is set. @@ -103,6 +111,7 @@ export declare namespace Env { kernelId: number; kernelType: string; kernelName: string; + programName: string; startTime: number; endTime: number; } @@ -134,13 +143,52 @@ export declare namespace Env { */ ondata?: (data: WebGpuProfilingData) => void; }; + /** + * Set or get the power preference. + * + * Setting this property only has effect before the first WebGPU inference session is created. The value will be + * used as options for `navigator.gpu.requestAdapter()`. + * + * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details. + * + * @defaultValue `undefined` + */ + powerPreference?: 'low-power'|'high-performance'; + /** + * Set or get the force fallback adapter flag. + * + * Setting this property only has effect before the first WebGPU inference session is created. The value will be + * used as options for `navigator.gpu.requestAdapter()`. + * + * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details. + * + * @defaultValue `undefined` + */ + forceFallbackAdapter?: boolean; + /** + * Set or get the adapter for WebGPU. + * + * Setting this property only has effect before the first WebGPU inference session is created. The value will be + * used as the GPU adapter for the underlying WebGPU backend to create GPU device. + * + * If this property is not set, it will be available to get after the first WebGPU inference session is created. The + * value will be the GPU adapter that created by the underlying WebGPU backend. + * + * When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types". + * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type. + * + * see comments on {@link Tensor.GpuBufferType} + */ + adapter: unknown; /** * Get the device for WebGPU. * + * This property is only available after the first WebGPU inference session is created. + * * When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types". * Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type. * - * see comments on {@link GpuBufferType} for more details about why not use types defined in "@webgpu/types". + * see comments on {@link Tensor.GpuBufferType} for more details about why not use types defined in "@webgpu/types". */ readonly device: unknown; /** @@ -159,6 +207,7 @@ export interface Env { * @defaultValue `'warning'` */ logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal'; + /** * Indicate whether run in debug mode. * @@ -166,6 +215,13 @@ export interface Env { */ debug?: boolean; + /** + * set or get a boolean value indicating whether to enable trace. + * + * @defaultValue `false` + */ + trace?: boolean; + /** * Get version of the current package. */ diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts index 9cbfcc4e8bcd..3ed56b3c2e81 100644 --- a/js/common/lib/index.ts +++ b/js/common/lib/index.ts @@ -11,7 +11,7 @@ * - [onnxruntime-react-native](https://www.npmjs.com/package/onnxruntime-react-native) * * See also: - * - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript.html) + * - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript/) * - [Inference examples](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/js) * * @packageDocumentation @@ -21,5 +21,9 @@ export * from './backend.js'; export * from './env.js'; export * from './inference-session.js'; export * from './tensor.js'; +export * from './tensor-conversion.js'; +export * from './tensor-factory.js'; +export * from './trace.js'; +export * from './onnx-model.js'; export * from './onnx-value.js'; export * from './training-session.js'; diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts index 9bc2088f2088..ab4c6a3e0c46 100644 --- a/js/common/lib/inference-session-impl.ts +++ b/js/common/lib/inference-session-impl.ts @@ -1,11 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {resolveBackend} from './backend-impl.js'; +import {resolveBackendAndExecutionProviders} from './backend-impl.js'; import {InferenceSessionHandler} from './backend.js'; import {InferenceSession as InferenceSessionInterface} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; import {Tensor} from './tensor.js'; +import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from './trace.js'; type SessionOptions = InferenceSessionInterface.SessionOptions; type RunOptions = InferenceSessionInterface.RunOptions; @@ -20,6 +21,7 @@ export class InferenceSession implements InferenceSessionInterface { run(feeds: FeedsType, options?: RunOptions): Promise; run(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; async run(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { + TRACE_FUNC_BEGIN(); const fetches: {[name: string]: OnnxValue|null} = {}; let options: RunOptions = {}; // check inputs @@ -117,6 +119,7 @@ export class InferenceSession implements InferenceSessionInterface { } } } + TRACE_FUNC_END(); return returnValue; } @@ -132,6 +135,7 @@ export class InferenceSession implements InferenceSessionInterface { static async create( arg0: string|ArrayBufferLike|Uint8Array, arg1?: SessionOptions|number, arg2?: number, arg3?: SessionOptions): Promise { + TRACE_FUNC_BEGIN(); // either load from a file or buffer let filePathOrUint8Array: string|Uint8Array; let options: SessionOptions = {}; @@ -191,11 +195,10 @@ export class InferenceSession implements InferenceSessionInterface { throw new TypeError('Unexpected argument[0]: must be \'path\' or \'buffer\'.'); } - // get backend hints - const eps = options.executionProviders || []; - const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); - const backend = await resolveBackend(backendHints); - const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options); + // resolve backend, update session options with validated EPs, and create session handler + const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options); + const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, optionsWithValidatedEPs); + TRACE_FUNC_END(); return new InferenceSession(handler); } diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index c7760692eed0..14db5c59d972 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {InferenceSession as InferenceSessionImpl} from './inference-session-impl.js'; +import {OnnxModelOptions} from './onnx-model.js'; import {OnnxValue, OnnxValueDataLocation} from './onnx-value.js'; /* eslint-disable @typescript-eslint/no-redeclare */ @@ -43,7 +44,7 @@ export declare namespace InferenceSession { /** * A set of configurations for session behavior. */ - export interface SessionOptions { + export interface SessionOptions extends OnnxModelOptions { /** * An array of execution provider options. * @@ -110,7 +111,7 @@ export declare namespace InferenceSession { optimizedModelFilePath?: string; /** - * Wether enable profiling. + * Whether enable profiling. * * This setting is a placeholder for a future use. */ @@ -153,6 +154,12 @@ export declare namespace InferenceSession { */ preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation}; + /** + * Whether enable graph capture. + * This setting is available only in ONNXRuntime Web for WebGPU EP. + */ + enableGraphCapture?: boolean; + /** * Store configurations for a session. See * https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/ @@ -179,22 +186,22 @@ export declare namespace InferenceSession { // #region execution providers // Currently, we have the following backends to support execution providers: - // Backend Node.js binding: supports 'cpu' and 'cuda'. - // Backend WebAssembly: supports 'cpu', 'wasm', 'xnnpack' and 'webnn'. + // Backend Node.js binding: supports 'cpu', 'dml' (win32), 'coreml' (macOS) and 'cuda' (linux). + // Backend WebAssembly: supports 'cpu', 'wasm', 'webgpu' and 'webnn'. // Backend ONNX.js: supports 'webgl'. // Backend React Native: supports 'cpu', 'xnnpack', 'coreml' (iOS), 'nnapi' (Android). interface ExecutionProviderOptionMap { + coreml: CoreMLExecutionProviderOption; cpu: CpuExecutionProviderOption; - coreml: CoreMlExecutionProviderOption; cuda: CudaExecutionProviderOption; dml: DmlExecutionProviderOption; + nnapi: NnapiExecutionProviderOption; tensorrt: TensorRtExecutionProviderOption; wasm: WebAssemblyExecutionProviderOption; webgl: WebGLExecutionProviderOption; - xnnpack: XnnpackExecutionProviderOption; webgpu: WebGpuExecutionProviderOption; webnn: WebNNExecutionProviderOption; - nnapi: NnapiExecutionProviderOption; + xnnpack: XnnpackExecutionProviderOption; } type ExecutionProviderName = keyof ExecutionProviderOptionMap; @@ -212,10 +219,6 @@ export declare namespace InferenceSession { readonly name: 'cuda'; deviceId?: number; } - export interface CoreMlExecutionProviderOption extends ExecutionProviderOption { - readonly name: 'coreml'; - coreMlFlags?: number; - } export interface DmlExecutionProviderOption extends ExecutionProviderOption { readonly name: 'dml'; deviceId?: number; @@ -240,14 +243,45 @@ export declare namespace InferenceSession { } export interface WebNNExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webnn'; - deviceType?: 'cpu'|'gpu'; + deviceType?: 'cpu'|'gpu'|'npu'; numThreads?: number; powerPreference?: 'default'|'low-power'|'high-performance'; } export interface CoreMLExecutionProviderOption extends ExecutionProviderOption { readonly name: 'coreml'; + /** + * The bit flags for CoreML execution provider. + * + * ``` + * COREML_FLAG_USE_CPU_ONLY = 0x001 + * COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002 + * COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004 + * COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008 + * COREML_FLAG_CREATE_MLPROGRAM = 0x010 + * ``` + * + * See include/onnxruntime/core/providers/coreml/coreml_provider_factory.h for more details. + * + * This flag is available only in ONNXRuntime (Node.js binding). + */ + coreMlFlags?: number; + /** + * Specify whether to use CPU only in CoreML EP. + * + * This setting is available only in ONNXRuntime (react-native). + */ useCPUOnly?: boolean; + /** + * Specify whether to enable CoreML EP on subgraph. + * + * This setting is available only in ONNXRuntime (react-native). + */ enableOnSubgraph?: boolean; + /** + * Specify whether to only enable CoreML EP for Apple devices with ANE (Apple Neural Engine). + * + * This setting is available only in ONNXRuntime (react-native). + */ onlyEnableDeviceWithANE?: boolean; } export interface NnapiExecutionProviderOption extends ExecutionProviderOption { diff --git a/js/common/lib/onnx-model.ts b/js/common/lib/onnx-model.ts new file mode 100644 index 000000000000..1cd3eedb6fcc --- /dev/null +++ b/js/common/lib/onnx-model.ts @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/** + * A string that represents a file's URL or path. + * + * Path is vailable only in onnxruntime-node or onnxruntime-web running in Node.js. + */ +export type FileUrlOrPath = string; + +/** + * A Blob object that represents a file. + */ +export type FileBlob = Blob; + +/** + * A Uint8Array, ArrayBuffer or SharedArrayBuffer object that represents a file content. + * + * When it is an ArrayBuffer or SharedArrayBuffer, the whole buffer is assumed to be the file content. + */ +export type FileData = Uint8Array|ArrayBufferLike; + +/** + * Represents a file that can be loaded by the ONNX Runtime JavaScript API. + */ +export type FileType = FileUrlOrPath|FileBlob|FileData; + +/** + * Represents an external data file. + */ +export interface ExternalDataFileDescription { + /** + * Specify the external data file. + */ + data: FileType; + /** + * Specify the file path. + */ + path: string; +} + +/** + * Represents an external data file. + * + * When using a string, it should be a file URL or path that in the same directory as the model file. + */ +export type ExternalDataFileType = ExternalDataFileDescription|FileUrlOrPath; + +/** + * Options for model loading. + */ +export interface OnnxModelOptions { + /** + * Specifying a list of files that represents the external data. + */ + externalData?: readonly ExternalDataFileType[]; +} diff --git a/js/common/lib/onnx-value.ts b/js/common/lib/onnx-value.ts index a16a30d25d83..72369ce8b420 100644 --- a/js/common/lib/onnx-value.ts +++ b/js/common/lib/onnx-value.ts @@ -3,7 +3,7 @@ import {Tensor} from './tensor.js'; -type NonTensorType = never; +export type NonTensorType = never; /** * Type OnnxValue Represents both tensors and non-tensors value for model's inputs/outputs. diff --git a/js/common/lib/tensor-conversion-impl.ts b/js/common/lib/tensor-conversion-impl.ts index 22397321e8c6..b1de48a10c0e 100644 --- a/js/common/lib/tensor-conversion-impl.ts +++ b/js/common/lib/tensor-conversion-impl.ts @@ -8,10 +8,11 @@ import {Tensor} from './tensor.js'; * implementation of Tensor.toDataURL() */ export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions): string => { - const canvas = document.createElement('canvas'); + const canvas = typeof document !== 'undefined' ? document.createElement('canvas') : (new OffscreenCanvas(1, 1)); canvas.width = tensor.dims[3]; canvas.height = tensor.dims[2]; - const pixels2DContext = canvas.getContext('2d'); + const pixels2DContext = + canvas.getContext('2d') as (CanvasRenderingContext2D | OffscreenCanvasRenderingContext2D | null); if (pixels2DContext != null) { // Default values for height and width & format @@ -88,7 +89,11 @@ export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions pixels2DContext.fillRect(j, i, 1, 1); } } - return canvas.toDataURL(); + if ('toDataURL' in canvas) { + return canvas.toDataURL(); + } else { + throw new Error('toDataURL is not supported'); + } } else { throw new Error('Can not access image data'); } @@ -98,7 +103,9 @@ export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions * implementation of Tensor.toImageData() */ export const tensorToImageData = (tensor: Tensor, options?: TensorToImageDataOptions): ImageData => { - const pixels2DContext = document.createElement('canvas').getContext('2d'); + const pixels2DContext = typeof document !== 'undefined' ? + document.createElement('canvas').getContext('2d') : + new OffscreenCanvas(1, 1).getContext('2d') as OffscreenCanvasRenderingContext2D; let image: ImageData; if (pixels2DContext != null) { // Default values for height and width & format diff --git a/js/common/lib/tensor-factory-impl.ts b/js/common/lib/tensor-factory-impl.ts index 7228c4a97055..19c62cb54bfe 100644 --- a/js/common/lib/tensor-factory-impl.ts +++ b/js/common/lib/tensor-factory-impl.ts @@ -110,13 +110,31 @@ export const tensorFromImage = async( let data: Uint8ClampedArray|undefined; let bufferToTensorOptions: BufferToTensorOptions = options ?? {}; + const createCanvas = () => { + if (typeof document !== 'undefined') { + return document.createElement('canvas'); + } else if (typeof OffscreenCanvas !== 'undefined') { + return new OffscreenCanvas(1, 1); + } else { + throw new Error('Canvas is not supported'); + } + }; + const createCanvasContext = (canvas: HTMLCanvasElement|OffscreenCanvas) => { + if (canvas instanceof HTMLCanvasElement) { + return canvas.getContext('2d'); + } else if (canvas instanceof OffscreenCanvas) { + return canvas.getContext('2d') as OffscreenCanvasRenderingContext2D; + } else { + return null; + } + }; // filling and checking image configuration options if (isHTMLImageEle) { // HTMLImageElement - image object - format is RGBA by default - const canvas = document.createElement('canvas'); + const canvas = createCanvas(); canvas.width = image.width; canvas.height = image.height; - const pixels2DContext = canvas.getContext('2d'); + const pixels2DContext = createCanvasContext(canvas); if (pixels2DContext != null) { let height = image.height; @@ -166,12 +184,12 @@ export const tensorFromImage = async( bufferToTensorOptions.width = width; if (options !== undefined) { - const tempCanvas = document.createElement('canvas'); + const tempCanvas = createCanvas(); tempCanvas.width = width; tempCanvas.height = height; - const pixels2DContext = tempCanvas.getContext('2d'); + const pixels2DContext = createCanvasContext(tempCanvas); if (pixels2DContext != null) { pixels2DContext.putImageData(image, 0, 0); @@ -188,10 +206,10 @@ export const tensorFromImage = async( throw new Error('Please provide image config with format for Imagebitmap'); } - const canvas = document.createElement('canvas'); + const canvas = createCanvas(); canvas.width = image.width; canvas.height = image.height; - const pixels2DContext = canvas.getContext('2d'); + const pixels2DContext = createCanvasContext(canvas); if (pixels2DContext != null) { const height = image.height; @@ -206,8 +224,8 @@ export const tensorFromImage = async( } } else if (isString) { return new Promise((resolve, reject) => { - const canvas = document.createElement('canvas'); - const context = canvas.getContext('2d'); + const canvas = createCanvas(); + const context = createCanvasContext(canvas); if (!image || !context) { return reject(); } diff --git a/js/common/lib/tensor-factory.ts b/js/common/lib/tensor-factory.ts index 6e19d7fb898a..431de4c3635c 100644 --- a/js/common/lib/tensor-factory.ts +++ b/js/common/lib/tensor-factory.ts @@ -253,7 +253,7 @@ export interface TensorFactory { /** * create a tensor from an ImageBitmap object * - * @param bitMap - the ImageBitmap object to create tensor from + * @param bitmap - the ImageBitmap object to create tensor from * @param options - An optional object representing options for creating tensor from URL. * * The following default settings will be applied: diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts index c4a43ea27fea..b29cb8cbd6d3 100644 --- a/js/common/lib/tensor-impl-type-mapping.ts +++ b/js/common/lib/tensor-impl-type-mapping.ts @@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map { - if (!isBigIntChecked) { - isBigIntChecked = true; - const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function'; - const isBigUint64ArrayAvailable = - typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + +// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for +// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array +// polyfill if available. +let isTypedArrayChecked = false; +export const checkTypedArray = () => { + if (!isTypedArrayChecked) { + isTypedArrayChecked = true; + const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from; + const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from; + const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from; if (isBigInt64ArrayAvailable) { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array); @@ -53,5 +58,12 @@ export const checkBigInt = () => { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array); NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64'); } + if (isFloat16ArrayAvailable) { + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array); + NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16'); + } else { + // if Float16Array is not available, use 'Uint16Array' to store the data. + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array); + } } }; diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index e3e2b9c72855..56682ef98e11 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js'; import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js'; import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js'; import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js'; -import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; +import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; import {calculateSize, tensorReshape} from './tensor-utils-impl.js'; import {Tensor as TensorInterface} from './tensor.js'; @@ -67,8 +67,8 @@ export class Tensor implements TensorInterface { arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters| TextureConstructorParameters|GpuBufferConstructorParameters, arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) { - // perform one-time check for BigInt support - checkBigInt(); + // perform one-time check for BigInt/Float16Array support + checkTypedArray(); let type: TensorType; let dims: readonly number[]; @@ -103,7 +103,7 @@ export class Tensor implements TensorInterface { } case 'gpu-buffer': { if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' && - type !== 'bool')) { + type !== 'uint8' && type !== 'bool')) { throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`); } this.gpuBufferData = arg0.gpuBuffer; @@ -142,7 +142,9 @@ export class Tensor implements TensorInterface { throw new TypeError(`Unsupported tensor type: ${arg0}.`); } if (Array.isArray(arg1)) { - if (arg0 === 'float16') { + if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) { + // When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array. + // // Throw error here because when user try to use number array as data, // e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call // Uint16Array.from(arg1) which generates wrong data. diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 6c08d1fe8e05..20319ebb800c 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -135,7 +135,7 @@ export declare namespace Tensor { /** * supported data types for constructing a tensor from a WebGPU buffer */ - export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool'; + export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool'; /** * represent where the tensor data is stored @@ -160,7 +160,7 @@ export interface Tensor extends TypedTensorBase, TypedTensorUtils { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { + return; + } + // eslint-disable-next-line no-console + console.timeStamp(`${deviceType}::ORT::${label}`); +}; + +const TRACE_FUNC = (msg: string, extraMsg?: string) => { + const stack = new Error().stack?.split(/\r\n|\r|\n/g) || []; + let hasTraceFunc = false; + for (let i = 0; i < stack.length; i++) { + if (hasTraceFunc && !stack[i].includes('TRACE_FUNC')) { + let label = `FUNC_${msg}::${stack[i].trim().split(' ')[1]}`; + if (extraMsg) { + label += `::${extraMsg}`; + } + TRACE('CPU', label); + return; + } + if (stack[i].includes('TRACE_FUNC')) { + hasTraceFunc = true; + } + } +}; + +/** + * @ignore + */ +export const TRACE_FUNC_BEGIN = (extraMsg?: string) => { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { + return; + } + TRACE_FUNC('BEGIN', extraMsg); +}; + +/** + * @ignore + */ +export const TRACE_FUNC_END = (extraMsg?: string) => { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { + return; + } + TRACE_FUNC('END', extraMsg); +}; diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index 23bd4421ae67..bae38b0dfda5 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {resolveBackend} from './backend-impl.js'; +import {resolveBackendAndExecutionProviders} from './backend-impl.js'; import {SessionHandler, TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; @@ -55,13 +55,12 @@ export class TrainingSession implements TrainingSessionInterface { const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || ''; const options: SessionOptions = sessionOptions || {}; - // get backend hints - const eps = options.executionProviders || []; - const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); - const backend = await resolveBackend(backendHints); + // resolve backend, update session options with validated EPs, and create session handler + const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options); if (backend.createTrainingSessionHandler) { const handler = await backend.createTrainingSessionHandler( - trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); + trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, + optionsWithValidatedEPs); return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel); } else { throw new Error(noBackendErrMsg); diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index e54aed90e702..f9de77e3ac7d 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -11,7 +11,7 @@ export declare namespace TrainingSession { /** * Either URI file path (string) or Uint8Array containing model or checkpoint information. */ - type URIorBuffer = string|Uint8Array; + type UriOrBuffer = string|Uint8Array; } /** @@ -98,13 +98,13 @@ export interface TrainingSession { getParametersSize(trainableOnly: boolean): Promise; /** - * Copies parameter values from the given array to the training state. Currently, only supporting models with + * Copies parameter values from the given buffer to the training state. Currently, only supporting models with * parameters of type Float32. * - * @param buffer - Float32 buffer containing parameters converted to a Uint8Array. + * @param buffer - A Uint8Array representation of Float32 parameters. * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true. */ - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; /** * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning. @@ -157,19 +157,19 @@ export interface TrainingSessionCreateOptions { /** * URI or buffer for a .ckpt file that contains the checkpoint for the training model. */ - checkpointState: TrainingSession.URIorBuffer; + checkpointState: TrainingSession.UriOrBuffer; /** * URI or buffer for the .onnx training file. */ - trainModel: TrainingSession.URIorBuffer; + trainModel: TrainingSession.UriOrBuffer; /** * Optional. URI or buffer for the .onnx optimizer model file. */ - optimizerModel?: TrainingSession.URIorBuffer; + optimizerModel?: TrainingSession.UriOrBuffer; /** * Optional. URI or buffer for the .onnx eval model file. */ - evalModel?: TrainingSession.URIorBuffer; + evalModel?: TrainingSession.UriOrBuffer; } /** diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts index 96c2361cceab..40f970ddf02a 100644 --- a/js/common/lib/version.ts +++ b/js/common/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.17.0'; +export const version = '1.18.0'; diff --git a/js/common/package-lock.json b/js/common/package-lock.json index 84f6dba83fa5..3988ac80707e 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -1,21 +1,21 @@ { "name": "onnxruntime-common", - "version": "1.17.0", + "version": "1.18.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-common", - "version": "1.17.0", + "version": "1.18.0", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "node_modules/balanced-match": { @@ -34,9 +34,9 @@ } }, "node_modules/jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "node_modules/lunr": { @@ -46,9 +46,9 @@ "dev": true }, "node_modules/marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true, "bin": { "marked": "bin/marked.js" @@ -58,24 +58,24 @@ } }, "node_modules/minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "dependencies": { "brace-expansion": "^2.0.1" }, "engines": { - "node": ">=10" + "node": ">=16 || 14 >=14.17" }, "funding": { "url": "https://github.com/sponsors/isaacs" } }, "node_modules/shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "dependencies": { "ansi-sequence-parser": "^1.1.0", @@ -85,30 +85,30 @@ } }, "node_modules/typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "dependencies": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" }, "bin": { "typedoc": "bin/typedoc" }, "engines": { - "node": ">= 14.14" + "node": ">= 16" }, "peerDependencies": { - "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x" + "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x || 5.0.x || 5.1.x || 5.2.x || 5.3.x" } }, "node_modules/typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true, "bin": { @@ -116,7 +116,7 @@ "tsserver": "bin/tsserver" }, "engines": { - "node": ">=4.2.0" + "node": ">=14.17" } }, "node_modules/vscode-oniguruma": { @@ -134,9 +134,9 @@ }, "dependencies": { "ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "balanced-match": { @@ -155,9 +155,9 @@ } }, "jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "lunr": { @@ -167,24 +167,24 @@ "dev": true }, "marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true }, "minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "requires": { "brace-expansion": "^2.0.1" } }, "shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "requires": { "ansi-sequence-parser": "^1.1.0", @@ -194,21 +194,21 @@ } }, "typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "requires": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" } }, "typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true }, diff --git a/js/common/package.json b/js/common/package.json index beab7d29be26..cd2612aab498 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -2,14 +2,14 @@ "license": "MIT", "type": "module", "name": "onnxruntime-common", - "version": "1.17.0", + "version": "1.18.0", "repository": { "url": "https://github.com/Microsoft/onnxruntime.git", "type": "git" }, "author": "fs-eire", "scripts": { - "build:cjs": "tsc --module commonjs --outDir ./dist/cjs", + "build:cjs": "tsc --module commonjs --moduleResolution node10 --outDir ./dist/cjs", "build:esm": "tsc", "build:bundles": "webpack", "build": "node ./build.js", @@ -18,7 +18,7 @@ "test": "mocha ./test/**/*.js --timeout 30000" }, "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" }, "main": "dist/cjs/index.js", "exports": { diff --git a/js/common/test/tsconfig.json b/js/common/test/tsconfig.json index 2e4927ac3b32..e9068ad837a8 100644 --- a/js/common/test/tsconfig.json +++ b/js/common/test/tsconfig.json @@ -2,7 +2,7 @@ "extends": "../../tsconfig.tools.json", "exclude": ["type-tests/**/*.ts"], "compilerOptions": { - "module": "ES2022", + "module": "Node16", "sourceMap": true } } diff --git a/js/node/CMakeLists.txt b/js/node/CMakeLists.txt index c3898fbad740..8157df288eeb 100644 --- a/js/node/CMakeLists.txt +++ b/js/node/CMakeLists.txt @@ -66,9 +66,17 @@ if(MSVC AND CMAKE_JS_NODELIB_DEF AND CMAKE_JS_NODELIB_TARGET) execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS}) endif() +if (WIN32) + if (${ONNXRUNTIME_GENERATOR} MATCHES "Ninja") + set(ONNXRUNTIME_WIN_BIN_DIR ${ONNXRUNTIME_BUILD_DIR}) + else() + set(ONNXRUNTIME_WIN_BIN_DIR ${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE}) + endif() + message(STATUS "onnxruntime dist dir: ${ONNXRUNTIME_WIN_BIN_DIR}") +endif() # add libraries if (WIN32) - target_link_directories(onnxruntime_binding PRIVATE ${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE}) + target_link_directories(onnxruntime_binding PRIVATE ${ONNXRUNTIME_WIN_BIN_DIR}) else() target_link_directories(onnxruntime_binding PRIVATE ${ONNXRUNTIME_BUILD_DIR}) endif() @@ -95,14 +103,14 @@ if (WIN32) add_custom_command( TARGET onnxruntime_binding POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy - ${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE}/onnxruntime.dll + ${ONNXRUNTIME_WIN_BIN_DIR}/onnxruntime.dll ${dist_folder} ) if (USE_DML) add_custom_command( TARGET onnxruntime_binding POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy - ${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE}/DirectML.dll + ${ONNXRUNTIME_WIN_BIN_DIR}/DirectML.dll ${dist_folder} ) endif () @@ -110,7 +118,7 @@ if (WIN32) add_custom_command( TARGET onnxruntime_binding POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy - ${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE}/onnxruntime.pdb + ${ONNXRUNTIME_WIN_BIN_DIR}/onnxruntime.pdb ${dist_folder} COMMAND ${CMAKE_COMMAND} -E copy $/onnxruntime_binding.pdb ${dist_folder} ) diff --git a/js/node/README.md b/js/node/README.md index 98b2ea66de2a..234eaa111a22 100644 --- a/js/node/README.md +++ b/js/node/README.md @@ -22,7 +22,7 @@ Following platforms are supported with pre-built binaries: - Linux x64 CPU NAPI_v3 - MacOS x64 CPU NAPI_v3 -To use on platforms without pre-built binaries, you can build Node.js binding from source and consume it by `npm install /js/node/`. See also [instructions](https://www.onnxruntime.ai/docs/how-to/build.html#apis-and-language-bindings) for building ONNX Runtime Node.js binding locally. +To use on platforms without pre-built binaries, you can build Node.js binding from source and consume it by `npm install /js/node/`. See also [instructions](https://onnxruntime.ai/docs/build/inferencing.html#apis-and-language-bindings) for building ONNX Runtime Node.js binding locally. # GPU Support diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index e8eb0e9babf5..927953b4f1dd 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -36,7 +36,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise { return new Promise((resolve, reject) => { - process.nextTick(() => { + setImmediate(() => { try { resolve(this.#inferenceSession.run(feeds, fetches, options)); } catch (e) { @@ -56,7 +56,7 @@ class OnnxruntimeBackend implements Backend { async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { return new Promise((resolve, reject) => { - process.nextTick(() => { + setImmediate(() => { try { resolve(new OnnxruntimeSessionHandler(pathOrBuffer, options || {})); } catch (e) { diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts index 96c2361cceab..40f970ddf02a 100644 --- a/js/node/lib/version.ts +++ b/js/node/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.17.0'; +export const version = '1.18.0'; diff --git a/js/node/package-lock.json b/js/node/package-lock.json index c1cf8af4bb80..62b47698a143 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-node", - "version": "1.17.0", + "version": "1.18.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-node", - "version": "1.17.0", + "version": "1.18.0", "license": "MIT", "os": [ "win32", @@ -27,10 +27,10 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.17.0", + "version": "1.18.0", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/@protobufjs/aspromise": { @@ -336,9 +336,9 @@ "dev": true }, "node_modules/follow-redirects": { - "version": "1.15.2", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.2.tgz", - "integrity": "sha512-VQLG33o04KaQ8uYi2tVNbdrWp1QWxNNea+nmIB4EVM28v0hmP17z7aG1+wAkNzVq4KeXTq3221ye5qTJP91JwA==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true, "funding": [ { @@ -1242,9 +1242,9 @@ "dev": true }, "follow-redirects": { - "version": "1.15.2", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.2.tgz", - "integrity": "sha512-VQLG33o04KaQ8uYi2tVNbdrWp1QWxNNea+nmIB4EVM28v0hmP17z7aG1+wAkNzVq4KeXTq3221ye5qTJP91JwA==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true }, "form-data": { @@ -1503,7 +1503,7 @@ "onnxruntime-common": { "version": "file:../common", "requires": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "parse-json": { diff --git a/js/node/package.json b/js/node/package.json index 8e591d8f46b9..026840742e29 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -13,7 +13,7 @@ 3 ] }, - "version": "1.17.0", + "version": "1.18.0", "dependencies": { "onnxruntime-common": "file:../common" }, diff --git a/js/node/script/build.ts b/js/node/script/build.ts index dfa88821a8d0..cc5950717908 100644 --- a/js/node/script/build.ts +++ b/js/node/script/build.ts @@ -23,6 +23,8 @@ if (ARCH !== 'x64' && ARCH !== 'ia32' && ARCH !== 'arm64' && ARCH !== 'arm') { } // --onnxruntime-build-dir= const ONNXRUNTIME_BUILD_DIR = buildArgs['onnxruntime-build-dir']; +// --onnxruntime-generator= +const ONNXRUNTIME_GENERATOR = buildArgs['onnxruntime-generator']; // --rebuild const REBUILD = !!buildArgs.rebuild; // --use_dml @@ -55,6 +57,9 @@ const args = [ if (ONNXRUNTIME_BUILD_DIR && typeof ONNXRUNTIME_BUILD_DIR === 'string') { args.push(`--CDONNXRUNTIME_BUILD_DIR=${ONNXRUNTIME_BUILD_DIR}`); } +if (ONNXRUNTIME_GENERATOR && typeof ONNXRUNTIME_GENERATOR === 'string') { + args.push(`--CDONNXRUNTIME_GENERATOR=${ONNXRUNTIME_GENERATOR}`); +} if (USE_DML) { args.push('--CDUSE_DML=ON'); } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java index fd085f953380..707a356b949a 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java @@ -199,6 +199,12 @@ private WritableMap loadModelImpl(String uri, byte[] modelData, ReadableMap opti if (modelData != null && modelData.length > 0) { // load model via model data array ortSession = ortEnvironment.createSession(modelData, sessionOptions); + } else if (uri.startsWith("file://") || uri.startsWith("/")) { + // load model from local + if (uri.startsWith("file://")) { + uri = uri.substring(7); + } + ortSession = ortEnvironment.createSession(uri, sessionOptions); } else { // load model via model path string uri InputStream modelStream = diff --git a/js/react_native/e2e/yarn.lock b/js/react_native/e2e/yarn.lock index 9e20a286c4e2..6f05faf04609 100644 --- a/js/react_native/e2e/yarn.lock +++ b/js/react_native/e2e/yarn.lock @@ -3351,9 +3351,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-accessor-descriptor@^0.1.6: version "0.1.6" diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts index 96c2361cceab..40f970ddf02a 100644 --- a/js/react_native/lib/version.ts +++ b/js/react_native/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.17.0'; +export const version = '1.18.0'; diff --git a/js/react_native/package.json b/js/react_native/package.json index 39e6cb08bb06..47324a76fe55 100644 --- a/js/react_native/package.json +++ b/js/react_native/package.json @@ -36,7 +36,7 @@ "registry": "https://registry.npmjs.org/" }, "source": "lib/index", - "version": "1.17.0", + "version": "1.18.0", "main": "dist/commonjs/index", "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", "files": [ diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index ff9be7fbe3a5..bbb0c4f3d1e2 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -3701,9 +3701,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-absolute@^1.0.0: version "1.0.0" @@ -5254,7 +5254,7 @@ onetime@^5.1.0, onetime@^5.1.2: mimic-fn "^2.1.0" "onnxruntime-common@file:../common": - version "1.17.0" + version "1.18.0" open@^6.2.0: version "6.4.0" diff --git a/js/web/README.md b/js/web/README.md index c75a40ad6da2..906c78a1b7ec 100644 --- a/js/web/README.md +++ b/js/web/README.md @@ -12,7 +12,7 @@ The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard f With ONNX Runtime Web, web developers can score models directly on browsers with various benefits including reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience. -ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web complies the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. +ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web compiles the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports. @@ -22,7 +22,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun ## Documents -### Developement +### Development Refer to the following links for development information: diff --git a/js/web/docs/webgl-operators.md b/js/web/docs/webgl-operators.md index 7c129b66bfa3..cd25819a2069 100644 --- a/js/web/docs/webgl-operators.md +++ b/js/web/docs/webgl-operators.md @@ -29,7 +29,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [BitwiseOr](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BitwiseOr) | | | [BitwiseXor](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BitwiseXor) | | | [BlackmanWindow](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BlackmanWindow) | | -| [Cast](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast) | [6-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-6), [9-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-9), [13-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-13), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-19) | +| [Cast](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast) | [6-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-6), [9-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-9), [13-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-13), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-21) | | [CastLike](https://github.com/onnx/onnx/blob/main/docs/Operators.md#CastLike) | | | [Ceil](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Ceil) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Ceil-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Ceil-13) | | [Celu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Celu) | | @@ -62,7 +62,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Exp](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Exp) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-13) | | [Expand](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand) | | | [EyeLike](https://github.com/onnx/onnx/blob/main/docs/Operators.md#EyeLike) | | -| [Flatten](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Flatten) | [1-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-1), [9-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-9), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-13) | +| [Flatten](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Flatten) | [1-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-1), [9-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-9), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-11), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-13), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-21) | | [Floor](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Floor) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Floor-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Floor-13) | | [GRU](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GRU) | | | [Gather](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13) | @@ -82,7 +82,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [HardSigmoid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSigmoid) | | | [HardSwish](https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSwish) | | | [Hardmax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Hardmax) | | -| [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19) | +| [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-21) | | [If](https://github.com/onnx/onnx/blob/main/docs/Operators.md#If) | | | [ImageDecoder](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ImageDecoder) | | | [InstanceNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#InstanceNormalization) | [6+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#InstanceNormalization-6) | @@ -124,7 +124,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [OptionalHasElement](https://github.com/onnx/onnx/blob/main/docs/Operators.md#OptionalHasElement) | | | [Or](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Or) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Or-7) | | [PRelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#PRelu) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-7), [9-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-9), [16+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-16) | -| [Pad](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-19) | +| [Pad](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-18), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-21) | | [Pow](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pow) | [7-11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-7), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-12), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-13), [15+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-15) | | [QLinearConv](https://github.com/onnx/onnx/blob/main/docs/Operators.md#QLinearConv) | | | [QLinearMatMul](https://github.com/onnx/onnx/blob/main/docs/Operators.md#QLinearMatMul) | | @@ -148,7 +148,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [ReduceSumSquare](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSumSquare) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-18) | | [RegexFullMatch](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RegexFullMatch) | | | [Relu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Relu) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-6), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-14) | -| [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19) | +| [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-21) | | [Resize](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize) | [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-10), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-19) | | [ReverseSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReverseSequence) | | | [RoiAlign](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RoiAlign) | | @@ -166,7 +166,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [SequenceInsert](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceInsert) | | | [SequenceLength](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceLength) | | | [SequenceMap](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceMap) | | -| [Shape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-1), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-13), [15-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-15), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-19) | +| [Shape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-1), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-13), [15-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-15), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-21) | | [Shrink](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shrink) | | | [Sigmoid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sigmoid) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sigmoid-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sigmoid-13) | | [Sign](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sign) | | @@ -182,7 +182,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Split](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Split-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Split-11) | | [SplitToSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SplitToSequence) | | | [Sqrt](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-13) | -| [Squeeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-13) | +| [Squeeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-11), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-13), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-21) | | [StringConcat](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringConcat) | | | [StringNormalizer](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringNormalizer) | | | [StringSplit](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringSplit) | | @@ -194,10 +194,10 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [ThresholdedRelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ThresholdedRelu) | | | [Tile](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Tile) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tile-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tile-13) | | [TopK](https://github.com/onnx/onnx/blob/main/docs/Operators.md#TopK) | | -| [Transpose](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Transpose) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-1), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-13) | +| [Transpose](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Transpose) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-1), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-13), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-21) | | [Trilu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Trilu) | | | [Unique](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Unique) | | -| [Unsqueeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Unsqueeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-13) | +| [Unsqueeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Unsqueeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-11), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-13), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-21) | | [Upsample](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Upsample) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Upsample-7), [9](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Upsample-9) | | [Where](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Where) | | | [Xor](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Xor-7) | diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 2f510308d930..c93f4f3cce68 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -34,6 +34,7 @@ Do not modify directly.* | Cos | ai.onnx(7+) | | | Cosh | ai.onnx(9+) | | | CumSum | ai.onnx(11-13,14+) | | +| DepthToSpace | ai.onnx(11-12,13+); com.ms.internal.nhwc(11-12,13+) | | | Div | ai.onnx(7-12,13,14+) | | | Einsum | ai.onnx(12+) | | | Elu | ai.onnx(6+) | | @@ -41,6 +42,7 @@ Do not modify directly.* | Erf | ai.onnx(9-12,13+) | | | Exp | ai.onnx(6-12,13+) | | | Expand | ai.onnx(8-12,13+) | | +| FastGelu | com.microsoft(1+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | | FusedConv | com.microsoft(1+) | | @@ -52,6 +54,7 @@ Do not modify directly.* | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | | GreaterOrEqual | ai.onnx(12-15,16+) | | +| HardSigmoid | ai.onnx(6+) | | | If | ai.onnx(1-10,11-12,13-18,19+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(17+) | | @@ -60,6 +63,7 @@ Do not modify directly.* | LessOrEqual | ai.onnx(12-15,16+) | | | Log | ai.onnx(6-12,13+) | | | MatMul | ai.onnx(1-12,13+) | | +| MatMulNBits | com.microsoft(1+) | | | MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation | | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | @@ -84,11 +88,14 @@ Do not modify directly.* | Relu | ai.onnx(6-12,13,14+) | | | Reshape | ai.onnx(5-12,13,14+) | no GPU kernel | | Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | +| RotaryEmbedding | com.microsoft(1+) | | | Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix | | Sigmoid | ai.onnx(6-12,13+) | | +| SimplifiedLayerNormalization | ai.onnx(1+) | | | Sin | ai.onnx(7+) | | | Sinh | ai.onnx(9+) | | | SkipLayerNormalization | com.microsoft(1+) | | +| SkipSimplifiedLayerNormalization | com.microsoft(1+) | | | Slice | ai.onnx(1-9,10,11-12,13+) | | | Softmax | ai.onnx(1-10,11-12,13+) | | | Split | ai.onnx(1,2-10,11-12,13-17,18+) | | diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js index 8fce79843f61..507da0de2b4a 100644 --- a/js/web/karma.conf.js +++ b/js/web/karma.conf.js @@ -9,6 +9,8 @@ const karmaPlugins = args['karma-plugins'] || undefined; const timeoutMocha = args['timeout-mocha'] || 60000; const forceLocalHost = !!args['force-localhost']; +// user data directory; will be passed to the Edge/Chrome/ChromeCanary/Firefox launchers +const userDataDir = args['user-data-dir']; // parse chromium flags let chromiumFlags = args['chromium-flags']; if (!chromiumFlags) { @@ -86,11 +88,12 @@ module.exports = function(config) { hostname, listenAddress, customLaunchers: { - // the following flags are used to make sure Edge on CI agents to initialize WebGPU correctly. - EdgeTest: {base: 'Edge', flags: chromiumFlags}, - ChromeTest: {base: 'Chrome', flags: chromiumFlags}, - ChromeTestHeadless: {base: 'ChromeHeadless', flags: chromiumFlags}, - ChromeCanaryTest: {base: 'ChromeCanary', flags: chromiumFlags}, + // Chromium-based browsers + EdgeTest: {base: 'Edge', flags: chromiumFlags, edgeDataDir: userDataDir}, + ChromeTest: {base: 'Chrome', flags: chromiumFlags, chromeDataDir: userDataDir}, + ChromeCanaryTest: {base: 'ChromeCanary', flags: chromiumFlags, chromeDataDir: userDataDir}, + FirefoxTest: {base: 'Firefox', profile: userDataDir}, + // // ==== BrowserStack browsers ==== // diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 2d123cdb7129..31ecffb07e40 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -26,7 +26,17 @@ export const initializeFlags = (): void => { env.wasm.proxy = false; } + if (typeof env.wasm.trace !== 'boolean') { + env.wasm.trace = false; + } + if (typeof env.wasm.numThreads !== 'number' || !Number.isInteger(env.wasm.numThreads) || env.wasm.numThreads <= 0) { + // Web: when crossOriginIsolated is false, SharedArrayBuffer is not available so WebAssembly threads will not work. + // Node.js: onnxruntime-web does not support multi-threads in Node.js. + if ((typeof self !== 'undefined' && !self.crossOriginIsolated) || + (typeof process !== 'undefined' && process.versions && process.versions.node)) { + env.wasm.numThreads = 1; + } const numCpuLogicalCores = typeof navigator === 'undefined' ? cpus().length : navigator.hardwareConcurrency; env.wasm.numThreads = Math.min(4, Math.ceil((numCpuLogicalCores || 1) / 2)); } diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index fb714bf5996f..2c9cd88a375b 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -19,7 +19,7 @@ interface BuildDefinitions { */ readonly DISABLE_WEBGPU: boolean; /** - * defines whether to disable the whole WebAssembly backend in the build. + * defines whether to disable the whole WebNN backend in the build. */ readonly DISABLE_WASM: boolean; /** diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 499327741c82..b212c0f49df3 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -23,13 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) { require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU) { registerBackend('webgpu', wasmBackend, 5); + registerBackend('webnn', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - if (BUILD_DEFS.DISABLE_TRAINING) { - registerBackend('xnnpack', wasmBackend, 9); - registerBackend('webnn', wasmBackend, 9); - } } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/onnxjs/model.ts b/js/web/lib/onnxjs/model.ts index f9a1b6e76089..8e689626011b 100644 --- a/js/web/lib/onnxjs/model.ts +++ b/js/web/lib/onnxjs/model.ts @@ -16,6 +16,7 @@ export class Model { constructor() {} load(buf: Uint8Array, graphInitializer?: Graph.Initializer, isOrtFormat?: boolean): void { + let onnxError: Error|undefined; if (!isOrtFormat) { // isOrtFormat === false || isOrtFormat === undefined try { @@ -25,10 +26,19 @@ export class Model { if (isOrtFormat !== undefined) { throw e; } + onnxError = e; } } - this.loadFromOrtFormat(buf, graphInitializer); + try { + this.loadFromOrtFormat(buf, graphInitializer); + } catch (e) { + if (isOrtFormat !== undefined) { + throw e; + } + // Tried both formats and failed (when isOrtFormat === undefined) + throw new Error(`Failed to load model as ONNX format: ${onnxError}\nas ORT format: ${e}`); + } } private loadFromOnnxFormat(buf: Uint8Array, graphInitializer?: Graph.Initializer): void { diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts index 96c2361cceab..40f970ddf02a 100644 --- a/js/web/lib/version.ts +++ b/js/web/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.17.0'; +export const version = '1.18.0'; diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 00431a4e86d5..56925b728e9a 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -13,25 +13,105 @@ export declare namespace JSEP { type ReleaseKernelFunction = (kernel: number) => void; type RunFunction = (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; + type CaptureBeginFunction = () => void; + type CaptureEndFunction = () => void; + type ReplayFunction = () => void; + + export interface Module extends WebGpuModule { + /** + * Mount the external data file to an internal map, which will be used during session initialization. + * + * @param externalDataFilePath - specify the relative path of the external data file. + * @param externalDataFileData - specify the content data. + */ + mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void; + /** + * Unmount all external data files from the internal map. + */ + unmountExternalData(): void; + + /** + * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per + * backend. This function initializes Asyncify support. If name is 'webgpu', also initializes WebGPU backend and + * registers a few callbacks that will be called in C++ code. + */ + jsepInit(name: 'webgpu', initParams: [ + backend: BackendType, alloc: AllocFunction, free: FreeFunction, upload: UploadFunction, + download: DownloadFunction, createKernel: CreateKernelFunction, releaseKernel: ReleaseKernelFunction, + run: RunFunction, captureBegin: CaptureBeginFunction, captureEnd: CaptureEndFunction, replay: ReplayFunction + ]): void; + jsepInit(name: 'webnn', initParams?: never): void; + } + + export interface WebGpuModule { + /** + * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). + * + * @param context - specify the kernel context pointer. + * @param index - specify the index of the output. + * @param data - specify the pointer to encoded data of type and dims. + */ + _JsepOutput(context: number, index: number, data: number): number; + /** + * [exported from wasm] Get name of an operator node. + * + * @param kernel - specify the kernel pointer. + * @returns the pointer to a C-style UTF8 encoded string representing the node name. + */ + _JsepGetNodeName(kernel: number): number; + + /** + * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output. + * + * @param sessionId - specify the session ID. + * @param index - specify an integer to represent which input/output it is registering for. For input, it is the + * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index + * corresponding to the session's ouputNames. + * @param buffer - specify the GPU buffer to register. + * @param size - specify the original data size in byte. + * @returns the GPU data ID for the registered GPU buffer. + */ + jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; + /** + * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. + * + * @param dataId - specify the GPU data ID + * @returns the GPU buffer. + */ + jsepGetBuffer: (dataId: number) => GPUBuffer; + /** + * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor. + * + * @param gpuBuffer - specify the GPU buffer + * @param size - specify the original data size in byte. + * @param type - specify the tensor type. + * @returns the generated downloader function. + */ + jsepCreateDownloader: + (gpuBuffer: GPUBuffer, size: number, + type: Tensor.GpuBufferDataTypes) => () => Promise; + /** + * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before + * _OrtRun[WithBinding]() is called. + * @param sessionId - specify the session ID. + */ + jsepOnRunStart: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is + * called. + * @param sessionId - specify the session ID. + * @returns + */ + jsepOnReleaseSession: (sessionId: number) => void; + } } -export interface OrtWasmModule extends EmscriptenModule { - // #region emscripten functions - stackSave(): number; - stackRestore(stack: number): void; - stackAlloc(size: number): number; - - UTF8ToString(offset: number, maxBytesToRead?: number): string; - lengthBytesUTF8(str: string): number; - stringToUTF8(str: string, offset: number, maxBytes: number): void; - // #endregion - - // #region ORT APIs +export interface OrtInferenceAPIs { _OrtInit(numThreads: number, loggingLevel: number): number; _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void; - _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number; + _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise; _OrtReleaseSession(sessionHandle: number): void; _OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; _OrtGetInputName(sessionHandle: number, index: number): number; @@ -71,112 +151,61 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtReleaseRunOptions(runOptionsHandle: number): void; _OrtEndProfiling(sessionHandle: number): number; - // #endregion +} - // #region ORT Training APIs - _OrtTrainingLoadCheckpoint?(dataOffset: number, dataLength: number): number; +export interface OrtTrainingAPIs { + _OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number; - _OrtTrainingReleaseCheckpoint?(checkpointHandle: number): void; + _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void; - _OrtTrainingCreateSession? - (sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number, - evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number; + _OrtTrainingCreateSession( + sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number, + evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number; - _OrtTrainingLazyResetGrad?(trainingHandle: number): number; + _OrtTrainingLazyResetGrad(trainingHandle: number): number; - _OrtTrainingRunTrainStep? - (trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, - runOptionsHandle: number): number; + _OrtTrainingRunTrainStep( + trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, + runOptionsHandle: number): number; - _OrtTrainingOptimizerStep?(trainingHandle: number, runOptionsHandle: number): number; + _OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number; - _OrtTrainingEvalStep? - (trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, - runOptionsHandle: number): number; + _OrtTrainingEvalStep( + trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, + runOptionsHandle: number): number; - _OrtTrainingGetParametersSize?(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number; + _OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number; - _OrtTrainingCopyParametersToBuffer? - (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingCopyParametersToBuffer( + trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingCopyParametersFromBuffer? - (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingCopyParametersFromBuffer( + trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingGetModelInputOutputCount? - (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; - _OrtTrainingGetModelInputOutputName? - (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; + _OrtTrainingGetModelInputOutputCount( + trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; + _OrtTrainingGetModelInputOutputName(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): + number; - _OrtTrainingReleaseSession?(trainingHandle: number): void; + _OrtTrainingReleaseSession(trainingHandle: number): void; +} + +export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial, + Partial { + // #region emscripten functions + stackSave(): number; + stackRestore(stack: number): void; + stackAlloc(size: number): number; + + UTF8ToString(offset: number, maxBytesToRead?: number): string; + lengthBytesUTF8(str: string): number; + stringToUTF8(str: string, offset: number, maxBytes: number): void; // #endregion // #region config + numThreads?: number; mainScriptUrlOrBlob?: string|Blob; // #endregion - - // #region JSEP - /** - * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime. - * This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code. - */ - jsepInit? - (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, - download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, - releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; - - /** - * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). - * - * @param context - specify the kernel context pointer. - * @param index - specify the index of the output. - * @param data - specify the pointer to encoded data of type and dims. - */ - _JsepOutput(context: number, index: number, data: number): number; - /** - * [exported from wasm] Get name of an operator node. - * - * @param kernel - specify the kernel pointer. - * @returns the pointer to a C-style UTF8 encoded string representing the node name. - */ - _JsepGetNodeName(kernel: number): number; - - /** - * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output. - * - * @param sessionId - specify the session ID. - * @param index - specify an integer to represent which input/output it is registering for. For input, it is the - * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index - * corresponding to the session's ouputNames. - * @param buffer - specify the GPU buffer to register. - * @param size - specify the original data size in byte. - * @returns the GPU data ID for the registered GPU buffer. - */ - jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; - /** - * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. - * - * @param sessionId - specify the session ID. - */ - jsepUnregisterBuffers?: (sessionId: number) => void; - /** - * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. - * - * @param dataId - specify the GPU data ID - * @returns the GPU buffer. - */ - jsepGetBuffer: (dataId: number) => GPUBuffer; - /** - * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor. - * - * @param gpuBuffer - specify the GPU buffer - * @param size - specify the original data size in byte. - * @param type - specify the tensor type. - * @returns the generated downloader function. - */ - jsepCreateDownloader: - (gpuBuffer: GPUBuffer, size: number, - type: Tensor.GpuBufferDataTypes) => () => Promise; - // #endregion } declare const moduleFactory: EmscriptenModuleFactory; diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 6c3d22352772..1b421029cc7a 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -1,14 +1,37 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Env, Tensor} from 'onnxruntime-common'; +import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; + +import {DataType, tensorDataTypeEnumToString} from '../wasm-common'; import {configureLogger, LOG_DEBUG} from './log'; import {createView, TensorView} from './tensor-view'; import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; -import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency} from './webgpu/types'; +import {AdapterInfo, ComputeContext, GpuArchitecture, GpuData, GpuVendor, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; + +interface CommandInfo { + readonly kernelId: number; + readonly computePipeline: GPUComputePipeline; + readonly bindGroup: GPUBindGroup; + readonly dispatchGroup: [number, number, number]; +} + +interface KernelInfo { + readonly kernelType: string; + readonly kernelName: string; + readonly kernelEntry: RunFunction; + readonly attributes: [((attribute: unknown) => unknown)|undefined, unknown]; +} + +interface PendingKernelInfo { + readonly kernelId: number; + readonly programName: string; + readonly inputTensorViews: readonly TensorView[]; + readonly outputTensorViews: readonly TensorView[]; +} const getProgramInputTensorInfoDependencyKey = (inputTensors: readonly TensorView[], inputDependencies: readonly ProgramInputTensorInfoDependency[]): string => { @@ -71,11 +94,32 @@ const getProgramInfoUniqueKey = return key; }; +class AdapterInfoImpl implements AdapterInfo { + readonly architecture?: string; + readonly vendor?: string; + + constructor(adapterInfo: GPUAdapterInfo) { + if (adapterInfo) { + this.architecture = adapterInfo.architecture; + this.vendor = adapterInfo.vendor; + } + } + + isArchitecture(architecture: GpuArchitecture): boolean { + return this.architecture === architecture; + } + + isVendor(vendor: GpuVendor): boolean { + return this.vendor === vendor; + } +} + /** * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as * the first parameter so that it is stored for future use. */ export class WebGpuBackend { + adapterInfo: AdapterInfoImpl; device: GPUDevice; /** * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping @@ -87,6 +131,13 @@ export class WebGpuBackend { */ programManager: ProgramManager; + /** + * representing the session ID of which is currently being run. + * `null` means no session is being run. + * only valid when session.run is executed. + */ + currentSessionId: number|null = null; + /** * representing the kernel ID of which is currently being computed (CPU code perspective). * `null` means no kernel is being computed. @@ -122,22 +173,33 @@ export class WebGpuBackend { return data; } - /** - * a KernelID -> kernel info mapping. value is - * [ op_type, name, run function, [optional] preprocess_attribute_once function ] - */ - kernels: Map unknown) | undefined, unknown]]>; - + // KernelID -> kernelInfo mapping + kernels: Map; private commandEncoder: GPUCommandEncoder|null = null; private computePassEncoder: GPUComputePassEncoder|null = null; + maxDispatchNumber = 16; pendingDispatchNumber = 0; - queryData?: GpuData; - querySet?: GPUQuerySet; - querySetCount = 2; - queryTimeBase?: bigint; + // info of kernels pending submission for a single batch + private pendingKernels: PendingKernelInfo[] = []; + // queryReadBuffer -> pendingKernels mapping for all the batches + private pendingQueries: Map = new Map(); + private queryResolveBuffer?: GPUBuffer; + private querySet?: GPUQuerySet; + private queryTimeBase?: bigint; + queryType: TimestampQuery; env: Env; + sessionStatus: SessionState = 'default'; + /** + * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session. + */ + capturedCommandList: Map = new Map(); + + /** + * a SessionID -> PendingKernelInfo[] mapping for profiling. + */ + private capturedPendingKernels: Map = new Map(); /** * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. @@ -161,7 +223,9 @@ export class WebGpuBackend { requiredFeatures, }; - if (adapter.features.has('timestamp-query')) { + if (adapter.features.has('chromium-experimental-timestamp-query-inside-passes')) { + requiredFeatures.push('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName); + } else if (adapter.features.has('timestamp-query')) { requiredFeatures.push('timestamp-query'); } if (adapter.features.has('shader-f16')) { @@ -169,6 +233,7 @@ export class WebGpuBackend { } this.device = await adapter.requestDevice(deviceDescriptor); + this.adapterInfo = new AdapterInfoImpl(await adapter.requestAdapterInfo()); this.gpuDataManager = createGpuDataManager(this); this.programManager = new ProgramManager(this); this.kernels = new Map(); @@ -187,7 +252,13 @@ export class WebGpuBackend { } }; - Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); + Object.defineProperty( + this.env.webgpu, 'device', {value: this.device, writable: false, enumerable: true, configurable: false}); + Object.defineProperty( + this.env.webgpu, 'adapter', {value: adapter, writable: false, enumerable: true, configurable: false}); + + // init queryType, which is necessary for InferenceSession.create + this.setQueryType(); } dispose(): void { @@ -206,22 +277,18 @@ export class WebGpuBackend { getComputePassEncoder(): GPUComputePassEncoder { if (!this.computePassEncoder) { + const commandEncoder = this.getCommandEncoder(); const computePassDescriptor: GPUComputePassDescriptor = {}; - if (this.isQueryEnabled()) { - if (typeof this.querySet === 'undefined') { - this.querySet = this.device.createQuerySet({ - type: 'timestamp', - count: this.querySetCount, - }); - } + + if (this.queryType === 'at-passes') { computePassDescriptor.timestampWrites = { - querySet: this.querySet, - beginningOfPassWriteIndex: 0, - endOfPassWriteIndex: 1, + querySet: this.querySet!, + beginningOfPassWriteIndex: this.pendingDispatchNumber * 2, + endOfPassWriteIndex: this.pendingDispatchNumber * 2 + 1, }; } - this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor); + this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor); } return this.computePassEncoder; } @@ -234,19 +301,95 @@ export class WebGpuBackend { } flush(): void { - if (this.commandEncoder) { - this.endComputePass(); - this.device.queue.submit([this.getCommandEncoder().finish()]); - this.gpuDataManager.refreshPendingBuffers(); - this.commandEncoder = null; - this.pendingDispatchNumber = 0; + if (!this.commandEncoder) { + return; } - } - isQueryEnabled(): boolean { - return this.device.features.has('timestamp-query') && - (this.env.webgpu.profiling?.mode === 'default' || - (!this.env.webgpu.profiling?.mode && this.env.webgpu.profilingMode === 'default')); + TRACE_FUNC_BEGIN(); + + this.endComputePass(); + let queryReadBuffer: GPUBuffer; + if (this.queryType !== 'none') { + this.commandEncoder.resolveQuerySet( + this.querySet!, 0, this.pendingDispatchNumber * 2, this.queryResolveBuffer!, 0); + + queryReadBuffer = this.device.createBuffer( + // eslint-disable-next-line no-bitwise + {size: this.pendingDispatchNumber * 2 * 8, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST}); + + this.pendingQueries.set(queryReadBuffer, this.pendingKernels); + this.pendingKernels = []; + this.commandEncoder.copyBufferToBuffer( + this.queryResolveBuffer!, 0, queryReadBuffer, 0, this.pendingDispatchNumber * 2 * 8); + } + + this.device.queue.submit([this.commandEncoder.finish()]); + this.gpuDataManager.refreshPendingBuffers(); + this.commandEncoder = null; + this.pendingDispatchNumber = 0; + + if (this.queryType !== 'none') { + void queryReadBuffer!.mapAsync(GPUMapMode.READ).then(() => { + const mappedData = new BigUint64Array(queryReadBuffer.getMappedRange()); + const pendingKernels = this.pendingQueries.get(queryReadBuffer)!; + for (let i = 0; i < mappedData.length / 2; i++) { + const pendingKernelInfo = pendingKernels[i]; + const kernelId = pendingKernelInfo.kernelId; + const kernelInfo = this.kernels.get(kernelId)!; + const kernelType = kernelInfo.kernelType; + const kernelName = kernelInfo.kernelName; + const programName = pendingKernelInfo.programName; + const inputTensorViews = pendingKernelInfo.inputTensorViews; + const outputTensorViews = pendingKernelInfo.outputTensorViews; + const startTimeU64 = mappedData[i * 2]; + const endTimeU64 = mappedData[i * 2 + 1]; + + if (typeof this.queryTimeBase === 'undefined') { + this.queryTimeBase = startTimeU64; + } + + const startTime = Number(startTimeU64 - this.queryTimeBase); + const endTime = Number(endTimeU64 - this.queryTimeBase); + + if (!Number.isSafeInteger(startTime) || !Number.isSafeInteger(endTime)) { + throw new RangeError('incorrect timestamp range'); + } + + if (this.env.webgpu.profiling?.ondata) { + this.env.webgpu.profiling.ondata({ + version: 1, + inputsMetadata: inputTensorViews.map( + value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})), + outputsMetadata: outputTensorViews.map( + value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})), + kernelId, + kernelType, + kernelName, + programName, + startTime, + endTime, + }); + } else { + // if no callback is provided, print the profiling message to console + let inputShapes = ''; + inputTensorViews.forEach((value, i) => { + inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; + }); + let outputShapes = ''; + outputTensorViews.forEach((value, i) => { + outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; + }); + // eslint-disable-next-line no-console + console.log(`[profiling] kernel "${kernelId}|${kernelType}|${kernelName}|${programName}" ${inputShapes}${ + outputShapes}execution time: ${endTime - startTime} ns`); + } + TRACE('GPU', `${programName}::${startTimeU64}::${endTimeU64}`); + } + queryReadBuffer.unmap(); + this.pendingQueries.delete(queryReadBuffer); + }); + } + TRACE_FUNC_END(); } /** @@ -263,14 +406,20 @@ export class WebGpuBackend { run(program: ProgramInfo, inputTensorViews: readonly TensorView[], outputIndices: readonly number[], createKernelOutput: (index: number, dataType: number, dims: readonly number[]) => TensorView, createIntermediateOutput: (dataType: number, dims: readonly number[]) => TensorView): TensorView[] { + TRACE_FUNC_BEGIN(program.name); // create info for inputs const inputDatas: GpuData[] = []; for (let i = 0; i < inputTensorViews.length; ++i) { - const gpuData = this.gpuDataManager.get(inputTensorViews[i].data); + const data = inputTensorViews[i].data; + // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. + if (data === 0) { + continue; + } + const gpuData = this.gpuDataManager.get(data); if (!gpuData) { - throw new Error(`no GPU data for input: ${inputTensorViews[i].data}`); + throw new Error(`no GPU data for input: ${data}`); } - inputDatas[i] = gpuData; + inputDatas.push(gpuData); } const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews); @@ -300,6 +449,11 @@ export class WebGpuBackend { const tensorView = (isTemporary || isPersistent) ? createIntermediateOutput(outputs[i].dataType, outputs[i].dims) : createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); + outputTensorViews.push(tensorView); + // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. + if (tensorView.data === 0) { + continue; + } const gpuData = this.gpuDataManager.get(tensorView.data); if (!gpuData) { throw new Error(`no GPU data for output: ${tensorView.data}`); @@ -315,10 +469,24 @@ export class WebGpuBackend { } persistentData.push(gpuData); } - outputTensorViews.push(tensorView); outputDatas.push(gpuData); } + // when there are any zero-sized tensor in the inputs or outputs, we should report error unless all outputs are + // zero-sized tensors. + if (inputDatas.length !== inputTensorViews.length || outputDatas.length !== outputTensorViews.length) { + // if all outputs are zero-sized tensors, there is no need to run the program. + if (outputDatas.length === 0) { + TRACE_FUNC_END(program.name); + return outputTensorViews; + } + // if some outputs are zero-sized tensors, report an error. + // + // TODO: so far we don't see any use case that outputs include both zero-sized tensors and non-zero-sized tensors. + // If we see such use case, we need to make a change here to support it. + throw new Error( + `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`); + } // load uniforms // TODO: add cache for uniform (is it necessary?) @@ -334,13 +502,26 @@ export class WebGpuBackend { return; } // https://www.w3.org/TR/WGSL/#alignof - const baseAlignment = data.length <= 2 ? data.length * 4 : 16; + const sizeOfElement = v.type === DataType.float16 ? 2 : 4; + let sizeOfVecOrMat; + let baseAlignment; + if (v.type === DataType.float16) { + baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); + sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length; + } else { + baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16; + sizeOfVecOrMat = 16; + } currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment; offsets.push(currentOffset); - // When data.length > 4, the uniform variable is of type array,N>, where N = - // Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * - // SizeOf(vec4). - currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4; + // For non-float16 type, when data.length > 4, the uniform variable is of type array,N>, where + // N = Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * + // SizeOf(vec4). For float16 type, when data.length > 4, the uniform variable is of type + // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte + // length is N * SizeOf(mat2x4). + const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4; + currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : + data.length * sizeOfElement; }); // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set @@ -351,12 +532,17 @@ export class WebGpuBackend { programUniforms.forEach((v, i) => { const offset = offsets[i]; const data = typeof v.data === 'number' ? [v.data] : v.data; - if (v.type === 'int32') { + if (v.type === DataType.int32) { new Int32Array(arrayBuffer, offset, data.length).set(data); - } else if (v.type === 'uint32') { + } else if (v.type === DataType.uint32) { new Uint32Array(arrayBuffer, offset, data.length).set(data); - } else { + } else if (v.type === DataType.float16) { + // TODO: use Float16Array. + new Uint16Array(arrayBuffer, offset, data.length).set(data); + } else if (v.type === DataType.float) { new Float32Array(arrayBuffer, offset, data.length).set(data); + } else { + throw new Error(`Unsupported uniform type: ${tensorDataTypeEnumToString(v.type)}`); } }); @@ -379,14 +565,47 @@ export class WebGpuBackend { LOG_DEBUG('info', () => `[artifact] key: ${key}, programName: ${program.name}`); } + // validate uniform variables + if (programUniforms && artifact.uniformVariablesInfo) { + if (programUniforms.length !== artifact.uniformVariablesInfo.length) { + throw new Error(`Uniform variables count mismatch: expect ${artifact.uniformVariablesInfo.length}, got ${ + programUniforms.length} in program "${artifact.programInfo.name}".`); + } + for (let i = 0; i < programUniforms.length; i++) { + const uniform = programUniforms[i]; + const actualType = uniform.type; + const actualLength = typeof uniform.data === 'number' ? 1 : uniform.data.length; + const [type, length] = artifact.uniformVariablesInfo[i]; + if (actualType !== type || actualLength !== length) { + throw new Error(`Uniform variable ${i} mismatch: expect type ${type} with size ${length}, got type ${ + actualType} with size ${actualLength} in program "${artifact.programInfo.name}".`); + } + } + } + LOG_DEBUG( 'info', () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); - this.programManager.run( - artifact, inputTensorViews, outputTensorViews, inputDatas, outputDatas, normalizedDispatchGroup, - uniformBufferBinding); + if (this.queryType !== 'none' || this.sessionStatus === 'capturing') { + const pendingKernelInfo: PendingKernelInfo = { + kernelId: this.currentKernelId!, + programName: artifact.programInfo.name, + inputTensorViews, + outputTensorViews, + }; + this.pendingKernels.push(pendingKernelInfo); + + if (this.sessionStatus === 'capturing') { + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + sessionPendingKernels!.push(pendingKernelInfo); + } + } + + this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding); + + TRACE_FUNC_END(program.name); return outputTensorViews; } @@ -412,13 +631,19 @@ export class WebGpuBackend { return this.gpuDataManager.release(ptr); } - createKernel(opType: string, kernelId: number, attribute: unknown, nodeName: string): void { - const op = WEBGPU_OP_RESOLVE_RULES.get(opType); + createKernel(kernelType: string, kernelId: number, attribute: unknown, kernelName: string): void { + const op = WEBGPU_OP_RESOLVE_RULES.get(kernelType); if (!op) { - throw new Error(`kernel not implemented: ${opType}`); + throw new Error(`kernel not implemented: ${kernelType}`); } - this.kernels.set(kernelId, [opType, nodeName, op[0], [op[1], attribute]]); + const kernelInfo: KernelInfo = { + kernelType, + kernelName, + kernelEntry: op[0], + attributes: [op[1], attribute], + }; + this.kernels.set(kernelId, kernelInfo); } releaseKernel(kernelId: number): void { @@ -439,9 +664,12 @@ export class WebGpuBackend { if (!kernel) { throw new Error(`kernel not created: ${kernelId}`); } - const [opType, nodeName, kernelEntry, attributes] = kernel; + const kernelType = kernel.kernelType; + const kernelName = kernel.kernelName; + const kernelEntry = kernel.kernelEntry; + const attributes = kernel.attributes; if (this.currentKernelId !== null) { - throw new Error(`kernel "[${opType}] ${nodeName}" is not allowed to be called recursively`); + throw new Error(`kernel "[${kernelType}] ${kernelName}" is not allowed to be called recursively`); } this.currentKernelId = kernelId; @@ -451,7 +679,7 @@ export class WebGpuBackend { attributes[0] = undefined; } - LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "[${opType}] ${nodeName}"...`); + LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "[${kernelType}] ${kernelName}"...`); const useErrorScope = this.env.debug; @@ -464,12 +692,12 @@ export class WebGpuBackend { kernelEntry(context, attributes[1]); return 0; // ORT_OK } catch (e) { - errors.push(Promise.resolve(`[WebGPU] Kernel "[${opType}] ${nodeName}" failed. ${e}`)); + errors.push(Promise.resolve(`[WebGPU] Kernel "[${kernelType}] ${kernelName}" failed. ${e}`)); return 1; // ORT_FAIL } finally { if (useErrorScope) { errors.push(this.device.popErrorScope().then( - err => err ? `GPU validation error for kernel "[${opType}] ${nodeName}": ${err.message}` : null)); + err => err ? `GPU validation error for kernel "[${kernelType}] ${kernelName}": ${err.message}` : null)); } for (const data of this.temporaryData) { @@ -515,4 +743,98 @@ export class WebGpuBackend { }; } // #endregion + writeTimestamp(index: number): void { + if (this.queryType !== 'inside-passes') { + return; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (this.computePassEncoder as any).writeTimestamp(this.querySet, index); + } + setQueryType(): void { + this.queryType = 'none'; + if (this.env.webgpu.profiling?.mode === 'default' || + (typeof this.env.trace === 'undefined' ? this.env.wasm.trace : this.env.trace)) { + if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) { + this.queryType = 'inside-passes'; + } else if (this.device.features.has('timestamp-query')) { + this.queryType = 'at-passes'; + } + + if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { + this.querySet = this.device.createQuerySet({ + type: 'timestamp', + count: this.maxDispatchNumber * 2, + }); + this.queryResolveBuffer = this.device.createBuffer( + // eslint-disable-next-line no-bitwise + {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE}); + } + } + } + + captureBegin(): void { + LOG_DEBUG('info', 'captureBegin'); + if (!this.capturedCommandList.get(this.currentSessionId!)) { + this.capturedCommandList.set(this.currentSessionId!, []); + } + if (!this.capturedPendingKernels.get(this.currentSessionId!)) { + this.capturedPendingKernels.set(this.currentSessionId!, []); + } + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'capturing'; + } + captureEnd(): void { + LOG_DEBUG('info', 'captureEnd'); + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'default'; + } + replay(): void { + LOG_DEBUG('info', 'replay'); + this.sessionStatus = 'replaying'; + const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + const length = sessionCommandList!.length; + this.pendingKernels = []; + for (let i = 0; i < length; i++) { + const computePassEncoder = this.getComputePassEncoder(); + const command = sessionCommandList![i]; + this.writeTimestamp(this.pendingDispatchNumber * 2); + computePassEncoder.setPipeline(command.computePipeline); + computePassEncoder.setBindGroup(0, command.bindGroup); + computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); + this.pendingDispatchNumber++; + if (this.queryType !== 'none') { + this.pendingKernels.push(sessionPendingKernels![i]); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') { + this.endComputePass(); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber) { + this.flush(); + } + } + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'default'; + } + + onReleaseSession(sessionId: number): void { + this.unregisterBuffers(sessionId); + if (this.capturedCommandList.has(sessionId)) { + this.capturedCommandList.delete(sessionId); + } + if (this.capturedPendingKernels.has(sessionId)) { + this.capturedPendingKernels.delete(sessionId); + } + this.gpuDataManager.onReleaseSession(sessionId); + } + + onRunStart(sessionId: number): void { + this.currentSessionId = sessionId; + this.setQueryType(); + } } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 3c6edf3ebb35..1ceae2394f46 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -10,7 +10,7 @@ import {WebGpuBackend} from './backend-webgpu'; import {LOG_DEBUG} from './log'; import {TensorView} from './tensor-view'; import {ShapeUtil} from './util'; -import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types'; +import {AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types'; /* eslint-disable no-bitwise */ @@ -54,6 +54,7 @@ class TensorViewImpl implements TensorView { } class ComputeContextImpl implements ComputeContext { + readonly adapterInfo: AdapterInfo; readonly opKernelContext: number; readonly inputs: readonly TensorView[]; readonly outputCount: number; @@ -66,6 +67,7 @@ class ComputeContextImpl implements ComputeContext { private customDataOffset = 0; private customDataSize = 0; constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) { + this.adapterInfo = backend.adapterInfo; const heapU32 = module.HEAPU32; // extract context data @@ -90,6 +92,17 @@ class ComputeContextImpl implements ComputeContext { this.inputs = inputs; } + getMaxComputeWorkgroupSizes(): [number, number, number] { + return [ + this.backend.device.limits.maxComputeWorkgroupSizeX, this.backend.device.limits.maxComputeWorkgroupSizeY, + this.backend.device.limits.maxComputeWorkgroupSizeZ + ]; + } + + getMaxComputeWorkgroupStoragesize(): number { + return this.backend.device.limits.maxComputeWorkgroupStorageSize; + } + compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] { // prepare inputs. inputs should always be valid data. const mappedInputs = @@ -104,7 +117,8 @@ class ComputeContextImpl implements ComputeContext { throw new Error(`Unsupported data type: ${dataType}`); } const bufferSize = elementSize * ShapeUtil.size(dims); - return new TensorViewImpl(this.module, dataType, this.backend.gpuDataManager.create(bufferSize).id, dims); + const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0; + return new TensorViewImpl(this.module, dataType, gpuDataId, dims); }; return this.backend.run(program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput); } @@ -118,7 +132,7 @@ class ComputeContextImpl implements ComputeContext { for (let i = 0; i < dims.length; i++) { this.module.HEAPU32[offset++] = dims[i]; } - return this.module._JsepOutput(this.opKernelContext, index, data); + return this.module._JsepOutput!(this.opKernelContext, index, data); } catch (e) { throw new Error( `Failed to generate kernel's output[${index}] with dims [${dims}]. ` + @@ -133,27 +147,39 @@ class ComputeContextImpl implements ComputeContext { /** * Initialize JSEP with WebGPU backend. * - * This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called). - * This function expects: + * This function will be called after the WebAssembly module is loaded and initialized ("_OrtInit" is called), once for + * each of the following EPs if they are specified: + * - "webgpu" + * - "webnn" + * + * For WebGPU, this function expects: * - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false). * - WebGPU is available in current environment. (a valid GPUAdapter is passed in) + * + * For WebNN, this function expects: + * - WebNN is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false). + * - WebNN is available in current environment. (navigator.ml is not undefined) + * * If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate - * 'webgpu' backend. + * 'webgpu'/'webnn' backend. * + * @param name - the name of the EP, either "webgpu" or "webnn" * @param module - the ORT WebAssembly module * @param env - the ORT environment variable (ort.env) * @param gpuAdapter - the pre-created GPU adapter */ -export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise => { +export const init = + async(name: 'webgpu'|'webnn', module: OrtWasmModule, env: Env, gpuAdapter?: GPUAdapter): Promise => { const jsepInit = module.jsepInit; if (!jsepInit) { throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.'); } - const backend = new WebGpuBackend(); - await backend.initialize(env, gpuAdapter); + if (name === 'webgpu') { + const backend = new WebGpuBackend(); + await backend.initialize(env, gpuAdapter!); - jsepInit( + jsepInit('webgpu', [ // backend backend, @@ -170,7 +196,7 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte backend.memcpy(src, dst); } else { LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`); - const data = module.HEAPU8.subarray(src, src + size); + const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size); backend.upload(dst, data); } }, @@ -182,13 +208,13 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte 'verbose', () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`); - await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size)); + await backend.download( + gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size)); }, // jsepCreateKernel - (name: string, kernel: number, attribute: unknown) => backend.createKernel( - name, kernel, attribute, - env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`), + (kernelType: string, kernelId: number, attribute: unknown) => backend.createKernel( + kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))), // jsepReleaseKernel (kernel: number) => backend.releaseKernel(kernel), @@ -201,5 +227,15 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); return backend.computeKernel(kernel, context, errors); - }); + }, + // jsepCaptureBegin + () => backend.captureBegin(), + // jsepCaptureEnd + () => backend.captureEnd(), + // jsepReplay + () => backend.replay() + ]); + } else { + jsepInit('webnn'); + } }; diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index 6922d7ff5df6..9a1d5463f784 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -56,7 +56,16 @@ export class BroadcastUtil { if (aLen !== bLen && aLen > 1 && bLen > 1) { return undefined; } - cdims[crank - i] = Math.max(aLen, bLen); + const max = Math.max(aLen, bLen); + if (aLen && bLen) { + cdims[crank - i] = Math.max(aLen, bLen); + } else { + // when either aLen or bLen is 0, the other should be either 0 or 1, otherwise it is not broadcastable. + if (max > 1) { + return undefined; + } + cdims[crank - i] = 0; + } } return cdims; @@ -92,6 +101,34 @@ export class ShapeUtil { return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length); } + /** + * convert dims corresponding to type change to pack. ex. uint8 data to uint32 + */ + static convertShape(dims: readonly number[], size = 4): readonly number[] { + const rank = dims.length; + if (rank === 0) { + return []; + } + const newDims = new Array(rank); + let i = rank - 1; + while (i >= 0) { + if (dims[i] % size === 0) { + newDims[i] = dims[i] / size; + break; + } + if (size % dims[i] !== 0) { + throw new Error('cannot convert shape'); + } + newDims[i] = 1; + size /= dims[i]; + i--; + } + for (i--; i >= 0; i--) { + newDims[i] = dims[i]; + } + return newDims; + } + /** * calculate the size (number of elements) from the given axis (inclusive) */ diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 6f3d9a52d9f5..c17bd1e1477e 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -60,9 +60,15 @@ export interface GpuDataManager { unregisterExternalBuffer(buffer: GPUBuffer): void; /** - * destroy all gpu buffers. Call this when the session.release is called. + * destroy all gpu buffers. */ dispose(): void; + + /** + * release session related data. + * @param sessionId - specify the session ID. + */ + onReleaseSession(sessionId: number): void; } interface StorageCacheValue { @@ -139,6 +145,10 @@ class GpuDataManagerImpl implements GpuDataManager { // The external buffers registered users for IO Binding. private externalBuffers: Map; + // The pendingBuffers for capture graph. + // a SessionID -> GPUBuffer[] mapping. + private capturedPendingBuffers: Map; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); @@ -146,6 +156,7 @@ class GpuDataManagerImpl implements GpuDataManager { this.buffersForUploadingPending = []; this.buffersPending = []; this.externalBuffers = new Map(); + this.capturedPendingBuffers = new Map(); } upload(id: GpuDataId, data: Uint8Array): void { @@ -220,6 +231,9 @@ class GpuDataManagerImpl implements GpuDataManager { () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ id}, buffer is the same, skip.`); return id; + } else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) { + throw new Error(`Registering a different external buffer under graph capture mode is not supported yet. + Please use the previous external buffer!`); } this.externalBuffers.delete(previousBuffer); } else { @@ -312,20 +326,39 @@ class GpuDataManagerImpl implements GpuDataManager { buffer.destroy(); } this.buffersForUploadingPending = []; - for (const buffer of this.buffersPending) { - // eslint-disable-next-line no-bitwise - if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { - // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. - this.freeBuffers.get(buffer.size)!.push(buffer); + + if (this.buffersPending.length === 0) { + return; + } + + if (this.backend.sessionStatus === 'default') { + for (const buffer of this.buffersPending) { // eslint-disable-next-line no-bitwise - } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { - // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. - this.freeUniformBuffers.get(buffer.size)!.push(buffer); - } else { - buffer.destroy(); + if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { + // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. + this.freeBuffers.get(buffer.size)!.push(buffer); + // eslint-disable-next-line no-bitwise + } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { + // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. + this.freeUniformBuffers.get(buffer.size)!.push(buffer); + } else { + buffer.destroy(); + } + } + this.buffersPending = []; + } else { + // Don't release intermediate tensors in non-default mode. + // TODO: reuse the storage buffers in non-default mode. + let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!); + if (!capturedBuffers) { + capturedBuffers = []; + this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers); } + for (const buffer of this.buffersPending) { + capturedBuffers.push(buffer); + } + this.buffersPending = []; } - this.buffersPending = []; } dispose() { @@ -344,9 +377,26 @@ class GpuDataManagerImpl implements GpuDataManager { storage.gpuData.buffer.destroy(); }); + this.capturedPendingBuffers.forEach((buffers) => { + buffers.forEach(buffer => { + buffer.destroy(); + }); + }); this.storageCache = new Map(); this.freeBuffers = new Map(); this.freeUniformBuffers = new Map(); + this.capturedPendingBuffers = new Map(); + } + + onReleaseSession(sessionId: number) { + // release the captured pending buffers. + const pendingBuffers = this.capturedPendingBuffers.get(sessionId); + if (pendingBuffers) { + pendingBuffers.forEach(buffer => { + buffer.destroy(); + }); + this.capturedPendingBuffers.delete(sessionId); + } } } diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 8e1ec782079b..5627365100d9 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; -import {attention, parseAttentionAttributes} from './ops/attention'; +import {attention} from './ops/attention'; import {batchNorm} from './ops/batch-norm'; import {biasAdd} from './ops/bias-add'; import {biasSplitGelu} from './ops/bias-split-gelu'; @@ -11,21 +11,25 @@ import {concat, parseConcatAttributes} from './ops/concat'; import {conv, parseConvAttributes} from './ops/conv'; import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'; import {cumsum, parseCumSumAttributes} from './ops/cumsum'; +import {depthToSpace, parseDepthToSpaceAttributes} from './ops/depth-to-space'; import {einsum, parseEinsumAttributes} from './ops/einsum'; import {expand} from './ops/expand'; +import {fastGelu} from './ops/fast-gelu'; import {gather, parseGatherAttributes} from './ops/gather'; import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; import {gemm, parseGemmAttributes} from './ops/gemm'; -import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; -import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; +import {instanceNorm} from './ops/instance-norm'; +import {layerNorm} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {matMulNBits, parseMatMulNBitsAttributes} from './ops/matmulnbits'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; -import {pad, parsePadAttributes} from './ops/pad'; +import {pad} from './ops/pad'; import * as pool from './ops/pool'; import {range} from './ops/range'; import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; -import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; +import {rotaryEmbedding} from './ops/rotary-embedding'; +import {skipLayerNorm} from './ops/skip-layer-norm'; import {parseSliceAttributes, slice} from './ops/slice'; import {parseSoftmaxAttributes, softmax} from './ops/softmax'; import {parseSplitAttributes, split} from './ops/split'; @@ -50,7 +54,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Asinh', [unaryOps.asinh]], ['Atan', [unaryOps.atan]], ['Atanh', [unaryOps.atanh]], - ['Attention', [attention, parseAttentionAttributes]], + ['Attention', [attention]], // TODO: support new attributes for AveragePool-10 ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], ['BatchNormalization', [batchNorm]], @@ -65,6 +69,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Cos', [unaryOps.cos]], ['Cosh', [unaryOps.cosh]], ['CumSum', [cumsum, parseCumSumAttributes]], + ['DepthToSpace', [depthToSpace, parseDepthToSpaceAttributes]], ['Div', [binaryOps.div]], ['Einsum', [einsum, parseEinsumAttributes]], ['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]], @@ -72,6 +77,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Erf', [unaryOps.erf]], ['Exp', [unaryOps.exp]], ['Expand', [expand]], + ['FastGelu', [fastGelu]], ['Floor', [unaryOps.floor]], ['FusedConv', [conv, parseConvAttributes]], ['Gather', [gather, parseGatherAttributes]], @@ -82,20 +88,22 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], - ['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]], - ['LayerNormalization', [layerNorm, parseLayerNormAttributes]], + ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], + ['InstanceNormalization', [instanceNorm]], + ['LayerNormalization', [layerNorm]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], ['Less', [binaryOps.less]], ['LessOrEqual', [binaryOps.lessOrEqual]], ['Log', [unaryOps.log]], ['MatMul', [matMul]], + ['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]], // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], ['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]], ['Neg', [unaryOps.neg]], ['Not', [unaryOps.not]], - ['Pad', [pad, parsePadAttributes]], + ['Pad', [pad]], ['Pow', [binaryOps.pow]], ['Range', [range]], ['Reciprocal', [unaryOps.reciprocal]], @@ -111,11 +119,12 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['ReduceSumSquare', [reduceSumSquare]], ['Relu', [unaryOps.relu]], ['Resize', [resize, parseResizeAttributes]], + ['RotaryEmbedding', [rotaryEmbedding]], ['Sigmoid', [unaryOps.sigmoid]], ['Sin', [unaryOps.sin]], ['Sinh', [unaryOps.sinh]], ['Slice', [slice, parseSliceAttributes]], - ['SkipLayerNormalization', [skipLayerNorm, parseSkipLayerNormAttributes]], + ['SkipLayerNormalization', [skipLayerNorm]], ['Split', [split, parseSplitAttributes]], ['Sqrt', [unaryOps.sqrt]], ['Softmax', [softmax, parseSoftmaxAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 3638938df7db..24006d393592 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -19,12 +19,13 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; -import {getActivationSnippet} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; import {biasSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; @@ -88,10 +89,10 @@ const conv2dCommonSnippet = let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; - let WRow = ${col} / (filterDims[1] * inChannels); - let WCol = ${col} / inChannels % filterDims[1]; - let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; - let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; + let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels); + let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]); + let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0]; + let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1]; let xCh = ${col} % inChannels; var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0); // The bounds checking is always needed since we use it to pad zero for @@ -108,7 +109,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : @@ -117,7 +118,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); @@ -129,9 +130,8 @@ const conv2dCommonSnippet = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType); + const applyActivation = getActivationSnippet(attributes, resType, dataType); const userCode = ` - ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { ${isChannelsLast ? sampleX : sampleW} } @@ -142,7 +142,7 @@ const conv2dCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) { let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueIn; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; @@ -181,31 +181,40 @@ export const createConv2DMatMulProgramInfo = LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; - const tileAOuter = workGroupSize[1] * elementsPerThread[1]; const tileBOuter = workGroupSize[0] * elementsPerThread[0]; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); - const fitAOuter = dimAOuter % tileAOuter === 0; const fitBOuter = dimBOuter % tileBOuter === 0; const fitInner = dimInner % tileInner === 0; - const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; - const t = tensorTypeToWsglStorageType(inputs[0].dataType); - // TODO: support component 2, 3. - const components = isVec4 ? 4 : 1; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = - inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); - const inputVariables = [x, w]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]}, + {type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations} + ]; + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, + {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2}, + {name: 'dilation', type: 'i32', length: 2} + ]; + appendActivationUniforms(attributes, uniforms); - let declareFunctions = ` + // TODO: support component 2, 3. + const components = isVec4 ? 4 : 1; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); + let declareFunctions = ` fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); } @@ -213,51 +222,50 @@ export const createConv2DMatMulProgramInfo = let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; - if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); - - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - - declareFunctions += ` + const x = inputVariable( + 'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); + const inputVariables = [x, w]; + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; - } - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - programUniforms.push(...createTensorShapeVariables(outputShape)); - return { - name: 'Conv2DMatMul', - shaderCache: {hint: attributes.cacheKey}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms, - }), - getShaderSource: (shaderHelper: ShaderHelper) => ` + } + + return ` ${utilFunctions('uniforms.result_strides')} //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; - ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .declareVariables(...inputVariables, output)} - const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); - const pad : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); - const stride : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} ${ conv2dCommonSnippet( isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1], elementsSize[2], t)} - ${ + ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, - sequentialAccessByThreads)}` + sequentialAccessByThreads)}`; + }; + return { + name: 'Conv2DMatMul', + shaderCache: { + hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${ + tileAOuter};${tileBOuter};${tileInner}`, + inputDependencies + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms, + }), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index d425155857e1..080b24a2432a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -19,20 +19,21 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; -import {getActivationSnippet} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; -import {biasSnippet, typeSnippet} from './activation_util'; +import {biasSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dTransposeCommonSnippet = - (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => { - const type = typeSnippet(innerElementSize, 'f32'); + (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string, + innerElementSize = 4): string => { const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: @@ -46,7 +47,7 @@ const conv2dTransposeCommonSnippet = let v1 = w[getIndexFromCoords4D(coord1, vec4(uniforms.w_shape))]; let v2 = w[getIndexFromCoords4D(coord2, vec4(uniforms.w_shape))]; let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))]; - return vec4(v0, v1, v2, v3); + return ${type}(v0, v1, v2, v3); `; default: throw new Error(`innerElementSize ${innerElementSize} is not supported.`); @@ -74,21 +75,21 @@ const conv2dTransposeCommonSnippet = col % outWidth); `; - const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]'; - const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; + const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; const row = isChannelsLast ? 'row' : 'col'; const col = isChannelsLast ? 'col' : 'row'; const readASnippet = ` - let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; - let WRow = ${col} / (filterDims[1] * inChannels); - let WCol = ${col} / inChannels % filterDims[1]; - let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); - let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); + let WRow = ${col} / (uniforms.filter_dims[1] * inChannels); + let WCol = ${col} / inChannels % uniforms.filter_dims[1]; + let xR = f32(outRow - uniforms.pads[0] + uniforms.dilations[0] * WRow) / f32(uniforms.strides[0]); + let xC = f32(outCol - uniforms.pads[1] + uniforms.dilations[1] * WCol) / f32(uniforms.strides[1]); if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { return ${type}(0.0); } @@ -103,25 +104,25 @@ const conv2dTransposeCommonSnippet = const sampleA = isChannelsLast ? ` let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readASnippet} } return ${type}(0.0);` : ` let col = colIn * ${innerElementSize}; - if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readASnippet} } return ${type}(0.0);`; const sampleW = ` let col = colIn * ${innerElementSize}; - let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; - let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); - let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; + let coordX = uniforms.filter_dims[0] - 1 - row / (uniforms.filter_dims[1] * inChannels); + let coordY = uniforms.filter_dims[1] - 1 - (row / inChannels) % uniforms.filter_dims[1]; if (${ - isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' : - 'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) { + isChannelsLast ? 'row < uniforms.dim_inner && col < uniforms.dim_b_outer' : + 'row < uniforms.dim_inner && col < uniforms.dim_a_outer'} && coordX >= 0 && coordY >= 0) { let rowInner = row % inChannels; let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} @@ -129,9 +130,8 @@ const conv2dTransposeCommonSnippet = return ${type}(0.0); `; - const {activationFunction, applyActivation} = getActivationSnippet(attributes, type); + const applyActivation = getActivationSnippet(attributes, type); const userCode = ` - ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleA : sampleW} } @@ -142,7 +142,7 @@ const conv2dTransposeCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueInput; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; ${coordResSnippet} @@ -164,17 +164,14 @@ export const createConv2DTransposeMatMulProgramInfo = const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; - const isVec4 = - isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0; + // TODO: enable vec4 for NCHW + const isVec4 = isChannelsLast && (inChannels % 4 === 0 && inChannels % 3) && outChannels % 4 === 0; // TODO: fine tune size const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = isVec4 ? - [8, 8, 1] : - [(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; - const elementsPerThread = - isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1]; + const workGroupSize: [number, number, number] = [8, 8, 1]; + const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; const dispatch = [ Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), @@ -186,72 +183,82 @@ export const createConv2DTransposeMatMulProgramInfo = const innerElementSize = isVec4 ? 4 : 1; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); const components = isVec4 ? 4 : 1; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - const inputVariables = [x, w]; - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + const filterDims = + [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const effectiveFilterDims = [ + filterDims[0] + (attributes.dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + (attributes.dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (attributes.dilations[1] - 1)) + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2) + ]; + + const programUniforms: ProgramUniform[] = [ + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: attributes.strides}, + {type: DataType.int32, data: attributes.dilations}, {type: DataType.int32, data: filterDims}, + {type: DataType.int32, data: pads} + ]; + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); - let declareFunctions = ''; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - - declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { - return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; - }`; + inputDependencies.push('rank'); } - programUniforms.push(...createTensorShapeVariables(outputShape)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const inputVariables = [x, w]; + + let declareFunctions = ''; + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${bias.type.value} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; + }`; + } + + const uniforms: UniformsArrayType = [ + {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, + {name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2}, + {name: 'filter_dims', type: 'i32', length: filterDims.length}, + {name: 'pads', type: 'i32', length: pads.length} + ]; + appendActivationUniforms(attributes, uniforms); + const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1); + if (elemType !== 'f16' && elemType !== 'f32') { + throw new Error(`elemType ${elemType} is not supported.`); + } + return ` + ${utilFunctions('uniforms.result_strides')} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; + ${declareFunctions} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)} + ${ + isVec4 ? makeMatMulPackedVec4Source( + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false, + undefined, sequentialAccessByThreads)}`; + }; + return { name: 'Conv2DTransposeMatMul', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: + {hint: `${attributes.cacheKey};${elementsPerThread};${workGroupSize};${isVec4}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, programUniforms }), - getShaderSource: (shaderHelper: ShaderHelper) => ` - ${utilFunctions('uniforms.result_strides')} - ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .declareVariables(...inputVariables, output)}; - const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ - attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const effectiveFilterDims : vec2 = filterDims + vec2( - ${ - attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, - ${ - attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); - const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ - attributes.pads[0] + attributes.pads[2]})/2, - i32(effectiveFilterDims[1]) - 1 - (${ - attributes.pads[1] + attributes.pads[3]})/2); - const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const dimAOuter : i32 = ${dimAOuter}; - const dimBOuter : i32 = ${dimBOuter}; - const dimInner : i32 = ${dimInner}; - ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} - ${ - isVec4 ? makeMatMulPackedVec4Source( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : - makeMatMulPackedSource( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, - undefined, sequentialAccessByThreads)}` + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 2e6392aada45..45c89406e173 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -17,27 +17,22 @@ // sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = - (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, - dataType: string): string => { - const isChannelsLast = attributes.format === 'NHWC'; + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], outputShape: readonly number[], hasBias: boolean, + is1DimensionDispatch: boolean, isVec4 = false, dataType: string, uniforms: UniformsArrayType, + isChannelsLast = false): string => { const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; const channelDim = isChannelsLast ? 3 : 1; - const outputSize = ShapeUtil.size(outputShape); const workPerThread = isVec4 ? 2 : 1; - const group = attributes.group; - const wShape = inputs[1].dims; - const inputChannelsPerGroup = wShape[0] / group; - const outputChannelsPerGroup = wShape[1]; let declareFunctions = ` fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { @@ -50,20 +45,21 @@ const createConvTranspose2DOpProgramShaderSource = }`; } const components = isVec4 ? 4 : 1; - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims, components); - const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims, components); + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); + const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components); const inputVariables = [dy, w]; if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]], components)); + inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); } - const output = outputVariable('result', inputs[0].dataType, outputShape, components); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const codeSnippet4 = `{ - let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / outShape[1]; - let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % outShape[1]; + let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1]; + let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1]; let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4; - let dyCorner = vec2(i32(r), i32(c)) - vec2(pads); + let dyCorner = vec2(i32(r), i32(c)) - vec2(uniforms.pads); // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. @@ -71,29 +67,29 @@ const createConvTranspose2DOpProgramShaderSource = for (var i = 0; i < ${workPerThread}; i++) { dotProd[i] = vec4<${dataType}>(0.0); } - for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); - let wRPerm = filterDims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || + for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) { + var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x); + let wRPerm = uniforms.filter_dims[0] - 1 - wR; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); - for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); - let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); - let wCPerm = filterDims[1] - 1 - wC; + for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) { + let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims[1] - 1 - wC; if (wCPerm < 0) { continue; } var bDyCVal = true; var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) || fract(dyC) > 0.0) { bDyCVal = false; } - if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || + if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) || fract(dyC2) > 0.0) { bDyCVal2 = false; } @@ -101,7 +97,7 @@ const createConvTranspose2DOpProgramShaderSource = let idyC: u32 = u32(dyC); let idyC2: u32 = u32(dyC2); if (bDyCVal && bDyCVal2) { - let d2Length = outBackprop[3]; + let d2Length = uniforms.Dy_shape[3]; for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -123,7 +119,7 @@ const createConvTranspose2DOpProgramShaderSource = dot(xValue, wValue3)); } } else if (bDyCVal) { - let d2Length = outBackprop[${channelDim}]; + let d2Length = uniforms.Dy_shape[${channelDim}]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -138,7 +134,7 @@ const createConvTranspose2DOpProgramShaderSource = dotProd[0] = dotProd[0] + tmpval; } } else if (bDyCVal2) { - let d2Length = outBackprop[3]; + let d2Length = uniforms.Dy_shape[3]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -157,7 +153,7 @@ const createConvTranspose2DOpProgramShaderSource = } for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) { - let value = dotProd[i] + ${hasBias ? 'bias[c+i]' : '0.0'}; + let value = dotProd[i] + ${hasBias ? 'bias[c+i]' : `vec4<${dataType}>(0.0)`}; ${output.set('batch', 'r', 'c + i', 'd1', 'value')}; } }`; @@ -167,39 +163,39 @@ const createConvTranspose2DOpProgramShaderSource = let d1 = ${output.indicesGet('outputIndices', channelDim)}; let r = ${output.indicesGet('outputIndices', rowDim)}; let c = ${output.indicesGet('outputIndices', colDim)}; - let dyCorner = vec2(i32(r), i32(c)) - pads; + let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; let dyRCorner = dyCorner.x; let dyCCorner = dyCorner.y; - let groupId = d1 / ${outputChannelsPerGroup}; - let wOutChannel = d1 - groupId * ${outputChannelsPerGroup}; + let groupId = d1 / uniforms.output_channels_per_group; + let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. - var dotProd = 0.0; - for (var wR: u32 = 0; wR < effectiveFilterDims.x; wR = wR + 1) { - if (wR % dilations.x != 0) { + var dotProd = ${dataType}(0.0); + for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { + if (wR % uniforms.dilations.x != 0) { continue; } - let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); - let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); + let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); - for (var wC: u32 = 0; wC < effectiveFilterDims.y; wC = wC + 1) { - if (wC % dilations.y != 0) { + for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { + if (wC % uniforms.dilations.y != 0) { continue; } - let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y); - let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } let idyC: u32 = u32(dyC); - var inputChannel = groupId * ${inputChannelsPerGroup}; - for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { + var inputChannel = groupId * uniforms.input_channels_per_group; + for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { let xValue = ${ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; @@ -209,32 +205,16 @@ const createConvTranspose2DOpProgramShaderSource = } } } - let value = dotProd + ${hasBias ? 'bias[d1]' : '0.0'}; + let value = dotProd + ${hasBias ? 'bias[d1]' : `${dataType}(0.0)`}; ${output.setByOffset('global_idx', 'value')}; `; return ` - ${shaderHelper.declareVariables(...inputVariables, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ - attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const dilations : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const effectiveFilterDims : vec2 = filterDims + vec2( - ${ - attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, - ${ - attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); - const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2, - i32(effectiveFilterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2); + ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; ${isVec4 ? codeSnippet4 : codeSnippet}}`; }; @@ -257,19 +237,73 @@ export const createConvTranspose2DProgramInfo = ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const isChannelsLast = attributes.format === 'NHWC'; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + const strides = [attributes.strides[0], attributes.strides[1]]; + const filterDims = + [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const dilations = [attributes.dilations[0], attributes.dilations[1]]; + const effectiveFilterDims = [ + filterDims[0] + + (attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + + (attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)) + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2 + ]; + + const isVec4 = false; + const group = attributes.group; + const wShape = inputs[1].dims; + const inputChannelsPerGroup = wShape[0] / group; + const outputChannelsPerGroup = wShape[1]; + + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: strides}, + {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations}, + {type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads}, + {type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup}, + ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims) + ]; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); + + const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'strides', type: 'u32', length: strides.length}, + {name: 'filter_dims', type: 'u32', length: filterDims.length}, + {name: 'dilations', type: 'u32', length: filterDims.length}, + {name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length}, + {name: 'pads', type: 'i32', length: pads.length}, {name: 'input_channels_per_group', type: 'u32'}, + {name: 'output_channels_per_group', type: 'u32'} + ]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return `${ + createConvTranspose2DOpProgramShaderSource( + shaderHelper, inputs, outputShape, hasBias, is1DimensionDispatch, isVec4, dataType, uniforms, + isChannelsLast)}`; + }; return { name: 'ConvTranspose2D', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: `${attributes.cacheKey};`, inputDependencies}, getRunData: () => ({ dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, outputs: [{ dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, dataType: inputs[0].dataType - }] + }], + programUniforms }), - getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, - dataType), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 47ec16a29671..29c7941e6bd3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -19,11 +19,12 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; -import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; +import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -112,14 +113,14 @@ fn main(@builtin(local_invocation_id) localId : vec3, ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + let num_tiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc: array, rowPerThread>; // Loop over shared dimension. let tileRowB = localRow * ${rowPerThreadB}; - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { let inputRow = tileRow + innerRow; @@ -204,7 +205,7 @@ export const makeMatMulPackedSource = let globalColStart = i32(workgroupId.x) * ${tileBOuter}; // Loop over shared dimension. - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var inputRow = localRow; inputRow < ${tileAHight}; inputRow = inputRow + ${workgroupSize[1]}) { for (var inputCol = localCol; inputCol < ${tileAWidth}; inputCol = inputCol + ${workgroupSize[0]}) { @@ -260,7 +261,7 @@ let tileRowA = i32(localId.y) * ${rowPerThreadA}; let tileColA = i32(localId.x) * ${colPerThreadA}; let tileRowB = i32(localId.y) * ${rowPerThreadB}; // Loop over shared dimension. -for (var t = 0; t < numTiles; t = t + 1) { +for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var innerRow = 0; innerRow < ${rowPerThreadA}; innerRow = innerRow + 1) { for (var innerCol = 0; innerCol < ${colPerThreadA}; innerCol = innerCol + 1) { @@ -322,7 +323,8 @@ fn main(@builtin(local_invocation_id) localId : vec3, @builtin(workgroup_id) workgroupId : vec3) { let batch = ${splitK ? '0' : 'i32(globalId.z)'}; ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + let num_tiles = ${ + splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc : array, rowPerThread>; @@ -379,7 +381,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < uniforms.dimAOuter && col < uniforms.dimInner) + if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${getAIndices()} value = ${aVariable.getByIndices('aIndices')}; @@ -391,7 +393,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < uniforms.dimInner && col < uniforms.dimBOuter) + if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${getBIndices()} value = ${bVariable.getByIndices('bIndices')}; @@ -401,7 +403,7 @@ const matMulReadWriteFnSource = fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueIn; let coords = vec3(batch, row, colIn); ${ @@ -422,16 +424,10 @@ export const createMatmulProgramInfo = isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { const aShape = inputs[0].dims; const bShape = inputs[1].dims; - const outerDimsA = aShape.slice(0, -2); const outerDimsB = bShape.slice(0, -2); - const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const enableBatchUniforms = enableShapesUniforms(outerDims.length); - const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims; - const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); const batchSize = ShapeUtil.size(outerDims); - const dimAOuter = aShape[aShape.length - 2]; const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; @@ -446,72 +442,62 @@ export const createMatmulProgramInfo = Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) ]; - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; - const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; - const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length); - const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp; - + const aRank = aShapeTemp.length; const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; - const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length); - const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp; - + const bRank = bShapeTemp.length; const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; - - const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); - const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); - const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); - const inputVariables = [A, B]; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - if (enableBatchUniforms) { - programUniforms.push(...createTensorShapeVariables(outerDims)); - } - if (enableAShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(aShapeTemp)); - } - if (enableBShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(bShapeTemp)); - } - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims'); - inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims'); + const programUniforms: ProgramUniform[] = [ + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner} + ]; + appendActivationUniformsData(activationAttributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(outerDims, aShapeTemp, bShapeTemp)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const hasBias = inputs.length > 2; - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); - const declareFunctions = matMulReadWriteFnSource( - components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], - isChannelsLast); if (hasBias) { - const biasComponents = isChannelsLast ? components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - inputDependencies.push('rank'); } programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => { + const batchRank = outerDims.length; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchRank, 1); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + + const A = inputVariable('a', inputs[0].dataType, aRank, components); + const B = inputVariable('b', inputs[1].dataType, bRank, components); + const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); + const inputVariables = [A, B]; + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + } + const uniforms: UniformsArrayType = + [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; + appendActivationUniforms(activationAttributes, uniforms); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); + const declareFunctions = matMulReadWriteFnSource( + components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], + isChannelsLast); + return ` ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${activationFunction} + shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( + ...inputVariables, output)} ${declareFunctions} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : + makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} `; - // TODO: turn clipMax and clipMin to uniforms. + }; return { name: 'MatMul', shaderCache: { - hint: activationAttributes.activationCacheKey + `${elementsPerThread}` + - `${isVec4}` + - `${isChannelsLast}`, + hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${isChannelsLast}`, inputDependencies }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index e1f2a47301bf..37606232a726 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; -import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType} from '../types'; +import {ComputeContext, GpuDataType, ProgramUniform} from '../types'; -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common'; export const enum AttentionQkvFormat { unknown, // enum value not set, or depends on qkv projection implementation details @@ -231,20 +231,8 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte }; }; -export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => - createAttributeWithCacheKey({...attributes}); - export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, n: number, d: number) => { const components = getMaxComponents(d); - const inputHelper = outputVariable('x', input.dataType, input.dims, components); - - let threadMaxValue = 'threadMaxVector'; - if (components === 2) { - threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)'; - } else if (components === 4) { - threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))'; - } - const dataType = tensorTypeToWsglStorageType(input.dataType); let WG = 64; const dComp = d / components; if (dComp < WG) { @@ -253,25 +241,42 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView WG = Math.ceil(dComp / 8); } const elementsPerWG = Math.ceil(d / components / WG); + const programUniforms: ProgramUniform[] = [ + {type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp}, + {type: DataType.uint32, data: elementsPerWG} + ]; + const dataType = tensorTypeToWsglStorageType(input.dataType, components); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const inputHelper = outputVariable('x', input.dataType, input.dims, components); + let threadMaxValue = 'thread_max_vector'; + if (components === 2) { + threadMaxValue = 'max(thread_max_vector.x, thread_max_vector.y)'; + } else if (components === 4) { + threadMaxValue = + 'max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))'; + } + const elemValueType = tensorTypeToWsglValueType(input.dataType); + const uniforms: UniformsArrayType = [ + {name: 'd_inv', type: elemValueType as UniformDataElementType}, {name: 'd_comp', type: 'u32'}, + {name: 'elements_per_wg', type: 'u32'} + ]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const dInv: ${dataType} = 1 / ${d}; - const dComp = ${d / components}; + return ` var wgMax: array; var wgSum: array; - - ${shaderHelper.declareVariables(inputHelper)} - @compute @workgroup_size(${WG}, 1, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_index) local_index : u32) { - let localOffset = local_index * ${elementsPerWG}; - let offset: u32 = workgroup_id.x * dComp + localOffset; - - var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')}; - for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - threadMaxVector = max(${castToF32(dataType, components, 'x[offset + i]')}, threadMaxVector); + ${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)} + ${shaderHelper.mainStart([ + WG, 1, 1 + ])} + let localOffset = local_idx * uniforms.elements_per_wg; + let offset: u32 = workgroup_id.x * uniforms.d_comp + localOffset; + + var thread_max_vector = ${fillVector('f32', components, '-3.402823e+38f')}; + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { + thread_max_vector = max(${castToF32(elemValueType, components, 'x[offset + i]')}, thread_max_vector); } - wgMax[local_index] = ${threadMaxValue}; + wgMax[local_idx] = ${threadMaxValue}; workgroupBarrier(); var maxValue = -3.402823e+38f; @@ -280,10 +285,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView } var sumVector = ${fillVector('f32', components, '0')}; - for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - sumVector += exp(${castToF32(dataType, components, 'x[offset + i]')} - maxValue); + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { + sumVector += exp(${castToF32(elemValueType, components, 'x[offset + i]')} - maxValue); } - wgSum[local_index] = ${sumVector('sumVector', components)}; + wgSum[local_idx] = ${sumVector('sumVector', components)}; workgroupBarrier(); var sum: f32 = 0; @@ -292,26 +297,24 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView } if (sum == 0) { - for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - x[offset + i] = ${fillVector(dataType, components, 'dInv')}; + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { + x[offset + i] = ${fillVector(elemValueType, components, 'uniforms.d_inv')}; } } else { - for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - let f32input = ${castToF32(dataType, components, 'x[offset + i]')}; + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { + let f32input = ${castToF32(elemValueType, components, 'x[offset + i]')}; x[offset + i] = ${inputHelper.type.value}(exp(f32input - maxValue) / sum); } } }`; + }; context.compute( { name: 'AttentionProbsSoftmax', - shaderCache: {hint: `${d}`}, + shaderCache: {hint: `${WG};${dataType};${components}`}, getShaderSource, - getRunData: () => ({ - outputs: [], - dispatchGroup: {x: n}, - }), + getRunData: () => ({outputs: [], dispatchGroup: {x: n}, programUniforms}), }, {inputs: [input], outputs: []}); }; @@ -326,88 +329,82 @@ const computeAttentionProbs = // TODO: handle mask const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; - - const dataType = tensorTypeToWsglStorageType(q.dataType); - const components = getMaxComponents(parameters.headSize); - const qInput = inputVariable('q', q.dataType, q.dims, components); - const kInput = inputVariable('key', key.dataType, key.dims, components); - const output = outputVariable('output', q.dataType, probsShape); - const vectorizedHeadSize = parameters.headSize / components; - const M = parameters.sequenceLength; - const N = parameters.totalSequenceLength; - const K = vectorizedHeadSize; - const TILE_SIZE = 12; - const dispatch = { x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE), y: Math.ceil(parameters.sequenceLength / TILE_SIZE), z: parameters.batchSize * parameters.numHeads }; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize}, + {type: DataType.uint32, data: parameters.totalSequenceLength}, + {type: DataType.uint32, data: parameters.kvSequenceLength}, {type: q.dataType, data: alpha} + ]; const inputs = [q, key]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const M: u32 = ${M}u; - const N: u32 = ${N}u; - const K: u32 = ${K}u; - const alpha: ${dataType} = ${alpha}; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const qInput = inputVariable('q', q.dataType, q.dims, components); + const kInput = inputVariable('key', key.dataType, key.dims, components); + const output = outputVariable('output', q.dataType, probsShape); + const dataType = tensorTypeToWsglStorageType(q.dataType); + + const uniforms: UniformsArrayType = [ + {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'kv_sequence_length', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType} + ]; + return ` const beta: ${dataType} = 1.0; const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; var tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; - - ${shaderHelper.declareVariables(qInput, kInput, output)} - - @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { - let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + - workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - + ${shaderHelper.registerUniforms(uniforms).declareVariables(qInput, kInput, output)} + ${shaderHelper.mainStart([ + TILE_SIZE, TILE_SIZE, 1 + ])} // x holds the N and y holds the M let headIdx = workgroup_id.z; let m = workgroup_id.y * TILE_SIZE; let n = workgroup_id.x * TILE_SIZE; - let lm = m + local_id.y; - let ln = n + local_id.x; - - let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K; - let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K; + let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K; + let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx + n * uniforms.K; var value = ${fillVector(dataType, components)}; - for (var w: u32 = 0u; w < K; w += TILE_SIZE) { - if (m + local_id.y < M && w + local_id.x < K) { - tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x]; + for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { + if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x]; } - if (n + local_id.y < N && w + local_id.x < K) { - tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * K + w + local_id.x]; + if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * uniforms.K + w + local_id.x]; } workgroupBarrier(); - for (var k: u32 = 0u; k ({ outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}], dispatchGroup: dispatch, + programUniforms }), getShaderSource, }, @@ -423,78 +420,76 @@ const computeAttentionProbs = const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => { const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize]; - - const probsHelper = inputVariable('probs', probs.dataType, probs.dims); - const vHelper = inputVariable('v', v.dataType, v.dims); - const output = outputVariable('output', probs.dataType, outputShape); - - const dataType = tensorTypeToWsglStorageType(probs.dataType); - const TILE_SIZE = 12; const dispatch = { x: Math.ceil(params.vHeadSize / TILE_SIZE), y: Math.ceil(params.sequenceLength / TILE_SIZE), z: params.batchSize * params.numHeads }; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength}, + {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads}, + {type: DataType.uint32, data: params.vHiddenSize} + ]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const M: u32 = ${params.sequenceLength}u; - const N: u32 = ${params.vHeadSize}u; - const K: u32 = ${params.totalSequenceLength}u; - const numHeads: u32 = ${params.numHeads}u; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const probsHelper = inputVariable('probs', probs.dataType, probs.dims); + const vHelper = inputVariable('v', v.dataType, v.dims); + const output = outputVariable('output', probs.dataType, outputShape); + const uniforms: UniformsArrayType = [ + {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'} + ]; + return ` const TILE_SIZE = ${TILE_SIZE}u; - - var tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; - var tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; - - ${shaderHelper.declareVariables(probsHelper, vHelper, output)} - - @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { - let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + - workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - + var tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; + ${shaderHelper.registerUniforms(uniforms).declareVariables(probsHelper, vHelper, output)} + ${shaderHelper.mainStart([ + TILE_SIZE, TILE_SIZE, 1 + ])} let headIdx = workgroup_id.z; - let m = workgroup_id.y * TILE_SIZE + local_id.y; - let n = workgroup_id.x * TILE_SIZE + local_id.x; + let m = global_id.y; + let n = global_id.x; - let offsetA = headIdx * (M * K) + m * K; - let offsetB = headIdx * (N * K) + n; + let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K; + let offsetB = headIdx * (uniforms.N * uniforms.K) + n; - var value = ${dataType}(0); - for (var w: u32 = 0u; w < K; w += TILE_SIZE) { - if (m < M && w + local_id.x < K) { + var value = ${probsHelper.type.storage}(0); + for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { + if (m < uniforms.M && w + local_id.x < uniforms.K) { tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; } - if (n < N && w + local_id.y < K) { - tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N]; + if (n < uniforms.N && w + local_id.y < uniforms.K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * uniforms.N]; } workgroupBarrier(); - for (var k: u32 = 0u; k ({ outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}], dispatchGroup: dispatch, + programUniforms }), getShaderSource, }, @@ -517,71 +512,71 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { parameters.sequenceLength, parameters.headSize, ]; - - const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); - const M = parameters.sequenceLength; const K = parameters.inputHiddenSize; const N = parameters.headSize; - const TILE_SIZE = 12; const dispatch = { x: Math.ceil(parameters.headSize / TILE_SIZE), y: Math.ceil(parameters.sequenceLength / TILE_SIZE), z: parameters.batchSize * parameters.numHeads }; + const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: M}, {type: DataType.uint32, data: K}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.headSize}, + {type: DataType.uint32, data: parameters.hiddenSize}, + {type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} + ]; - const getShaderSource = () => ` - const M: u32 = ${M}u; - const K: u32 = ${K}u; - const N: u32 = ${N}u; - const numHeads: u32 = ${parameters.numHeads}; - const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const outputQ = outputVariable('output_q', inputs[0].dataType, outputShape); + const outputK = outputVariable('output_k', inputs[0].dataType, outputShape); + const outputV = outputVariable('output_v', inputs[0].dataType, outputShape); + const input = inputVariable('input', inputs[0].dataType, inputs[0].dims); + const weight = inputVariable('weight', inputs[1].dataType, inputs[1].dims); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); + const dataType = input.type.storage; + + const uniforms: UniformsArrayType = [ + {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'num_heads', type: 'u32'}, + {name: 'head_size', type: 'u32'}, {name: 'hidden_size', type: 'u32'}, {name: 'ldb', type: 'u32'} + ]; + return ` const TILE_SIZE = ${TILE_SIZE}u; - var tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; var tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; var tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; var tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; - - @group(0) @binding(0) var input: array<${dataType}>; - @group(0) @binding(1) var weight: array<${dataType}>; - @group(0) @binding(2) var bias: array<${dataType}>; - @group(0) @binding(3) var outputQ: array<${dataType}>; - @group(0) @binding(4) var outputK: array<${dataType}>; - @group(0) @binding(5) var outputV: array<${dataType}>; - - @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { - let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + - workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - - let batchIndex = workgroup_id.z / ${parameters.numHeads}; - let headNumber = workgroup_id.z % ${parameters.numHeads}; - let m = workgroup_id.y * TILE_SIZE + local_id.y; - let n = workgroup_id.x * TILE_SIZE + local_id.x; - - let inputOffset = batchIndex * (M * K) + m * K; - let biasOffsetQ = headNumber * ${parameters.headSize}; - let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ; - let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK; + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, weight, bias, outputQ, outputK, outputV)} + ${shaderHelper.mainStart([ + TILE_SIZE, TILE_SIZE, 1 + ])} + let batchIndex = workgroup_id.z / uniforms.num_heads; + let headNumber = workgroup_id.z % uniforms.num_heads; + let m = global_id.y; + let n = global_id.x; + + let inputOffset = batchIndex * (uniforms.M * uniforms.K) + m * uniforms.K; + let biasOffsetQ = headNumber * uniforms.head_size; + let biasOffsetK = uniforms.hidden_size + biasOffsetQ; + let biasOffsetV = uniforms.hidden_size + biasOffsetK; var valueQ = ${dataType}(0); var valueK = ${dataType}(0); var valueV = ${dataType}(0); - for (var w: u32 = 0u; w < K; w += TILE_SIZE) { - if (m < M && w + local_id.x < K) { + for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { + if (m < uniforms.M && w + local_id.x < uniforms.K) { tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x]; } - if (n < N && w + local_id.y < K) { - let offset = n + (w + local_id.y) * ldb; + if (n < uniforms.N && w + local_id.y < uniforms.K) { + let offset = n + (w + local_id.y) * uniforms.ldb; tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset]; tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset]; tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset]; } workgroupBarrier(); - for (var k: u32 = 0u; k { workgroupBarrier(); } - let headOffset = (m * N + n) % ${parameters.headSize}; + let headOffset = (m * uniforms.N + n) % uniforms.head_size; valueQ += bias[headOffset + biasOffsetQ]; valueK += bias[headOffset + biasOffsetK]; valueV += bias[headOffset + biasOffsetV]; - let offset = workgroup_id.z * M * N; - if (m < M && n < N) { - let outputIdx = offset + m * N + n; - outputQ[outputIdx] = valueQ; - outputK[outputIdx] = valueK; - outputV[outputIdx] = valueV; + let offset = workgroup_id.z * uniforms.M * uniforms.N; + if (m < uniforms.M && n < uniforms.N) { + let outputIdx = offset + m * uniforms.N + n; + output_q[outputIdx] = valueQ; + output_k[outputIdx] = valueK; + output_v[outputIdx] = valueV; } }`; - - const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]]; + }; return context.compute( { name: 'AttentionPrepare', - shaderCache: {hint: JSON.stringify(parameters)}, + shaderCache: {inputDependencies: ['type', 'type', 'type']}, getRunData: () => ({ outputs: [ {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, @@ -619,6 +613,7 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, ], dispatchGroup: dispatch, + programUniforms }), getShaderSource, }, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts index ec9da2613f40..39b932375891 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts @@ -3,12 +3,13 @@ import {env} from 'onnxruntime-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; export interface BatchNormAttributes extends AttributeWithCacheKey { readonly epsilon: number; @@ -61,7 +62,7 @@ const createBatchNormInferenceProgramInfo = const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1; const outputSize = ShapeUtil.size(yShape) / components; // Only support uniforms for opset version >= 9 (spatial = true). - const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial; + const useShapesUniforms = spatial; const shapeOrRank = useShapesUniforms ? yShape.length : yShape; const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components); const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents); @@ -108,7 +109,7 @@ const createBatchNormInferenceProgramInfo = let inputMean = ${inputMean.getByOffset('cOffset')}; let inputVar = ${inputVar.getByOffset('cOffset')}; let x = ${x.getByOffset('global_idx')}; - let value = (x - inputMean) / sqrt(inputVar + epsilon) * scale + bias; + let value = (x - inputMean) * inverseSqrt(inputVar + epsilon) * scale + bias; ${y.setByOffset('global_idx', 'value')} }`; return { @@ -123,11 +124,11 @@ const createBatchNormInferenceProgramInfo = dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: useShapesUniforms ? [ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(yShape), ] : [ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ], }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts index a81a7a8f1df5..089fecd758e3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts @@ -43,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI ${shaderHelper.declareVariables(input, bias, output)} - ${erfImpl(`vec4<${dataType}>`, dataType)} + ${erfImpl(dataType)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index c033c0ba0535..a094fffe239c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; type BuiltinFunctionName = string; type BinaryCustomExpression = (expressionA: string, expressionB: string) => string; @@ -18,8 +18,7 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ const createBinaryOpProgramShader = (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[], vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall, - typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean, - additionalImplementation?: string) => { + typeA: number, typeB: number, typeOutput: number, additionalImplementation?: string) => { let expressionScalar: BinaryCustomExpression; let expressionVector: BinaryCustomExpression; if (typeof funcCall === 'string') { @@ -31,12 +30,9 @@ const createBinaryOpProgramShader = expressionVector = funcCall.vector; } - const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA; - const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB; - const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput; - const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4); - const a = inputVariable('aData', typeA, inputAShapeOrRank, 4); - const b = inputVariable('bData', typeB, inputBShapeOrRank, 4); + const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4); + const a = inputVariable('aData', typeA, dimsA.length, 4); + const b = inputVariable('bData', typeB, dimsB.length, 4); let assignment: string; if (vectorize) { @@ -169,30 +165,23 @@ const createBinaryOpProgramInfo = vectorize = true; } cacheKeyAux.push(vectorize); - const useShapesUniforms = enableShapesUniforms(a.dims.length) && enableShapesUniforms(b.dims.length) && - enableShapesUniforms(outputShape.length); + return { name, shaderCache: { hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'), - inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'], + inputDependencies: ['rank', 'rank'], }, getShaderSource: (shaderHelper) => createBinaryOpProgramShader( shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall, - a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation), + a.dataType, b.dataType, outputDataType, additionalImplementation), getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, - programUniforms: useShapesUniforms ? - [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, - ...createTensorShapeVariables(a.dims), - ...createTensorShapeVariables(b.dims), - ...createTensorShapeVariables(outputShape), - ] : - [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, - ], + programUniforms: [ + {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, + ...createTensorShapeVariables(a.dims, b.dims, outputShape) + ], }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 3ce114c5d388..17ac814c4403 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -3,7 +3,7 @@ import {DataType} from '../../../wasm-common'; import {ShapeUtil} from '../../util'; -import {ProgramUniform} from '../types'; +import {ProgramUniform, ProgramUniformVariableInfo} from '../types'; /** * constant value for a workgroup size. @@ -259,8 +259,16 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = return typeof mappedType === 'string' ? mappedType : mappedType[1]; }; -export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => - dims.length === 0 ? [] : [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}]; +export const createTensorShapeVariables = (...dims: ReadonlyArray): ProgramUniform[] => { + const programUniforms: ProgramUniform[] = []; + dims.forEach(dim => { + if (dim.length !== 0) { + programUniforms.push( + {type: DataType.uint32, data: dim}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dim)}); + } + }); + return programUniforms; +}; /** * A helper function to get maximum vector size for specified data length @@ -330,18 +338,28 @@ export const sumVector = (name: string, components: number) => { * @param name - the name of variable. * @param index - the index of variable element. * @param length - the length of variable. + * @param type - the type of variable, optional. */ -export const getElementAt = (name: string, index: number|string, length: number): string => { - if (name.startsWith('uniforms.') && length > 4) { - if (typeof (index) === 'string') { - return `${name}[(${index}) / 4][(${index}) % 4]`; - } else { - return `${name}[${Math.floor(index / 4)}][${index % 4}]`; - } - } else { - return length > 1 ? `${name}[${index}]` : name; - } -}; +export const getElementAt = + (name: string, index: number|string, length: number, type?: UniformDataElementType): string => { + if (name.startsWith('uniforms.') && length > 4) { + if (typeof (index) === 'string') { + if (type === 'f16') { + return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`; + } else { + return `${name}[(${index}) / 4][(${index}) % 4]`; + } + } else { + if (type === 'f16') { + return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`; + } else { + return `${name}[${Math.floor(index / 4)}][${index % 4}]`; + } + } + } else { + return length > 1 ? `${name}[${index}]` : name; + } + }; /** * A helper function to get a IndicesHelper for a given input or output. @@ -688,7 +706,7 @@ export const internalVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'internal', components); -export type UniformDataElementType = 'u32'|'f32'|'i32'; +export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32'; export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>; /** @@ -765,7 +783,7 @@ export interface ShaderHelper { } class ShaderHelperImpl implements ShaderHelper { - constructor(private normalizedDispatchGroup: [number, number, number]) {} + constructor(private normalizedDispatchGroup: [number, number, number], private limits: GPUSupportedLimits) {} guardAgainstOutOfBoundsWorkgroupSizes(size: number|string): string { // Guard against out-of-bounds work group sizes @@ -778,10 +796,27 @@ class ShaderHelperImpl implements ShaderHelper { const workgroupSizeY = typeof workgroupSize === 'number' ? 1 : workgroupSize[1]; const workgroupSizeZ = typeof workgroupSize === 'number' ? 1 : workgroupSize[2]; + if (workgroupSizeX > this.limits.maxComputeWorkgroupSizeX || + workgroupSizeY > this.limits.maxComputeWorkgroupSizeY || + workgroupSizeZ > this.limits.maxComputeWorkgroupSizeZ) { + throw new Error(`workgroup size [${workgroupSizeX}, ${workgroupSizeY}, ${ + workgroupSizeZ}] exceeds the maximum workgroup size [${this.limits.maxComputeWorkgroupSizeX}, ${ + this.limits.maxComputeWorkgroupSizeY}, ${this.limits.maxComputeWorkgroupSizeZ}].`); + } + + if (workgroupSizeX * workgroupSizeY * workgroupSizeZ > this.limits.maxComputeInvocationsPerWorkgroup) { + throw new Error(`workgroup size [${workgroupSizeX}, ${workgroupSizeY}, ${ + workgroupSizeZ}] exceeds the maximum workgroup invocations ${ + this.limits.maxComputeInvocationsPerWorkgroup}.`); + } + const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1; const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3, + @builtin(workgroup_id) workgroup_id : vec3, @builtin(local_invocation_id) local_id : vec3` : - `@builtin(local_invocation_index) local_idx : u32, + `@builtin(global_invocation_id) global_id : vec3, + @builtin(local_invocation_id) local_id : vec3, + @builtin(local_invocation_index) local_idx : u32, @builtin(workgroup_id) workgroup_id : vec3, @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch ? @@ -859,7 +894,11 @@ class ShaderHelperImpl implements ShaderHelper { const uniformSnippets: string[] = []; for (const {name, type, length} of this.uniforms) { if (length && length > 4) { - uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + if (type === 'f16') { + uniformSnippets.push(`@align(16) ${name}:array, ${Math.ceil(length / 8)}>`); + } else { + uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + } } else { const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`; uniformSnippets.push(`${name}:${typeTemp}`); @@ -879,9 +918,24 @@ class ShaderHelperImpl implements ShaderHelper { return this.uniformDeclaration() + this.variables.map(i => i.impl()).join('\n') + this.internalVariables.map(i => i.impl()).join('\n'); } + + /** + * Get the variable info of the shader program. + */ + get variablesInfo(): ProgramUniformVariableInfo[]|undefined { + if (this.uniforms.length === 0) { + return undefined; + } + + const uniformWgslTypeToDataType = (type: UniformDataElementType) => + ([DataType.uint32, DataType.float16, DataType.float, + DataType.int32][['u32', 'f16', 'f32', 'i32'].indexOf(type)]); + return this.uniforms.map(u => ([uniformWgslTypeToDataType(u.type), u.length ?? 1])); + } } -export const createShaderHelper = (dispatchGroup: [number, number, number]) => new ShaderHelperImpl(dispatchGroup); +export const createShaderHelper = (dispatchGroup: [number, number, number], limits: GPUSupportedLimits) => + new ShaderHelperImpl(dispatchGroup, limits); /** * This function comes from https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/ops/broadcast_util.ts#L18-L40 @@ -906,6 +960,3 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly } return dims; }; - -// TODO: remove this when all related uses have been removed. -export const enableShapesUniforms = (_rank: number): boolean => true; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 43cc4a4c080b..010ee589c44f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -1,36 +1,44 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; } -const validateInputs = (inputs: readonly TensorView[]): void => { +const validateInputs = (inputs: readonly TensorView[], axis: number): void => { if (!inputs || inputs.length < 1) { throw new Error('too few inputs'); } - - const inputType = inputs[0].dataType; - const inputDimensionality = inputs[0].dims.length; - - for (const input of inputs) { + const referenceIndex = 0; + const referenceInput = inputs[referenceIndex]; + const inputType = referenceInput.dataType; + const inputRank = referenceInput.dims.length; + inputs.forEach((input, i) => { + if (i === referenceIndex) { + return; + } // make sure types of all inputs match if (input.dataType !== inputType) { throw new Error('input tensors should be one type'); } - // make sure the dimensionality of all inputs are the same - if (input.dims.length !== inputDimensionality) { + if (input.dims.length !== inputRank) { throw new Error('input tensors should have the same shape'); } - } + input.dims.forEach((dim, i) => { + if (i !== axis && dim !== referenceInput.dims[i]) { + throw new Error('non concat dimensions must match'); + } + }); + }); }; const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => ` @@ -63,75 +71,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe return codeLines.join('\n'); }; -const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => { - const inputShape = inputs[0].dims.slice(); - if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { - throw new Error('axis specified for concat doesn\'t match input dimensionality'); - } - const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis; - // ensure all of the non-concatenated axes match each other - // calculate the shape of the output tensor while we do that - const outputShape = inputShape.slice(0); - for (let i = 1; i < inputs.length; i++) { - const dataNShape = inputs[i].dims.slice(); - for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { - // add to the placeholder for computing output shape - if (axisIndex === adjustedAxis) { - outputShape[adjustedAxis] += dataNShape[axisIndex]; +const createConcatProgramInfo = + (inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => { + const outputSize = ShapeUtil.size(outputShape); + + const sizeInConcatAxis = new Array(inputs.length); + const inputVars = new Array(inputs.length); + + let previousSum = 0; + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + const inputRanks = []; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; + for (let i = 0; i < inputs.length; ++i) { + previousSum += inputs[i].dims[adjustedAxis]; + sizeInConcatAxis[i] = previousSum; + inputRanks.push(inputs[i].dims.length); + inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); + inputDependencies.push('rank'); + programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); } - // ensure all non-cancatenated axes match each other - else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { - throw new Error('non concat dimensions must match'); + for (let i = 0; i < inputs.length; ++i) { + programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); } - } - } - - const outputSize = ShapeUtil.size(outputShape); - - const sizeInConcatAxis = new Array(inputs.length); - const inputVars = new Array(inputs.length); - const dataType = inputs[0].dataType; - - let previousSum = 0; - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - const inputShapeOrRanks = []; - const enableInputShapesUniforms = []; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; - for (let i = 0; i < inputs.length; ++i) { - previousSum += inputs[i].dims[adjustedAxis]; - sizeInConcatAxis[i] = previousSum; - enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length)); - inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims); - inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]); - inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims'); - programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); - } - for (let i = 0; i < inputs.length; ++i) { - if (enableInputShapesUniforms[i]) { - programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); - } - } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); - if (enableOutputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } - - const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; - const output = outputVariable('output', dataType, outputShapeOrRank); - - const indicesAxis = output.indicesGet('indices', adjustedAxis); - const sizeInConcatAxisStr = - Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const output = outputVariable('output', dataType, outputShape.length); + const indicesAxis = output.indicesGet('indices', adjustedAxis); + const sizeInConcatAxisStr = + Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${(() => { - shaderHelper.registerUniform('outputSize', 'u32'); - for (let i = 0; i < inputs.length; i++) { - shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); - } - return shaderHelper.declareVariables(...inputVars, output); - })()} + shaderHelper.registerUniform('outputSize', 'u32'); + for (let i = 0; i < inputs.length; i++) { + shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); + } + return shaderHelper.declareVariables(...inputVars, output); + })()} ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)} @@ -149,21 +125,30 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P ${assignOutputData(inputVars, output)} }`; - return { - name: 'Concat', - shaderCache: {hint: `${axis}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms, - }), - getShaderSource, - }; -}; + return { + name: 'Concat', + shaderCache: {hint: `${adjustedAxis}`, inputDependencies}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms, + }), + getShaderSource, + }; + }; export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { - validateInputs(context.inputs); - context.compute(createConcatProgramInfo(context.inputs, attributes.axis)); + const inputs = context.inputs; + const inputShape = inputs[0].dims; + const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); + validateInputs(inputs, adjustedAxis); + const outputShape = inputShape.slice(); + outputShape[adjustedAxis] = + inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0); + // 0 length tensors are valid for concat, remove them + const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0); + context.compute( + createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs}); }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 14482272bad3..924030125c42 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ProgramInfo} from '../types'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; -import {getActivationSnippet} from './fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from './fuse-utils'; /** * naive grouped conv implementation, supports 1d/2d conv @@ -27,52 +28,70 @@ export const createGroupedConvProgramInfo = xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast); const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', inputs[0].dataType, outputShape); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); - const x = inputVariable('x', inputs[0].dataType, xShape); - const w = inputVariable('w', inputs[1].dataType, wShape); - const inputVars = [x, w]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.dilations}, + {type: DataType.uint32, data: [attributes.strides[0], attributes.strides[1]]}, + {type: DataType.uint32, data: [attributes.pads[0], attributes.pads[1]]}, + {type: DataType.uint32, data: outputChannelsPerGroup} + ]; + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(xShape, wShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { - inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const strides: vec2 = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u); - const pads: vec2 = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u); - - ${shaderHelper.declareVariables(...inputVars, output)} + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); + const x = inputVariable('x', inputs[0].dataType, xShape.length); + const w = inputVariable('w', inputs[1].dataType, wShape.length); + const inputVars = [x, w]; + if (hasBias) { + inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims.length)); + } - ${activationFunction} + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'dilations', type: 'u32', length: attributes.dilations.length}, + {name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2}, + {name: 'output_channels_per_group', type: 'u32'} + ]; + appendActivationUniforms(attributes, uniforms); + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let outputIndices = ${output.offsetToIndices('global_idx')}; let batch: u32 = outputIndices[0]; let output_channel: u32 = outputIndices[${isChannelLast ? 3 : 1}]; let xRCCorner: vec2 = vec2(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${ - isChannelLast ? 2 : 3}]) * strides - pads; - let group_id: u32 = output_channel / ${outputChannelsPerGroup}u; + isChannelLast ? 2 : 3}]) * uniforms.strides - uniforms.pads; + let group_id: u32 = output_channel / uniforms.output_channels_per_group; var value: ${output.type.value} = ${output.type.value}(0); - for (var wInChannel: u32 = 0u; wInChannel < ${wShape[1]}u; wInChannel++) { - let input_channel = group_id * ${wShape[1]}u + wInChannel; - for (var wHeight: u32 = 0u; wHeight < ${wShape[2]}u; wHeight++) { - let xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]}u; + for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) { + let input_channel = group_id * uniforms.w_shape[1] + wInChannel; + for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) { + let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0]; - if (xHeight < 0u || xHeight >= ${xShape[isChannelLast ? 1 : 2]}u) { + if (xHeight < 0u || xHeight >= uniforms.x_shape[${isChannelLast ? 1 : 2}]) { continue; } - for (var wWidth: u32 = 0u; wWidth < ${wShape[3]}u; wWidth++) { - let xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]}u; - if (xWidth < 0u || xWidth >= ${xShape[isChannelLast ? 2 : 3]}u) { + for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) { + let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1]; + if (xWidth < 0u || xWidth >= uniforms.x_shape[${isChannelLast ? 2 : 3}]) { continue; } let xVal = ${ - isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : - x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; + isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : + x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')}; value += xVal*wVal; } @@ -82,15 +101,115 @@ export const createGroupedConvProgramInfo = ${applyActivation} ${output.setByOffset('global_idx', 'value')} }`; + }; return { name: 'GroupedConv', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: attributes.cacheKey, inputDependencies}, getRunData: () => ({ outputs: [{ dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, dataType: inputs[0].dataType }], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms + }), + getShaderSource, + }; + }; + +export const createGroupedConvVectorizeProgramInfo = + (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[]): ProgramInfo => { + const hasBias = inputs.length > 2; + const components = getMaxComponents(outputShape[3]); + const outputNumber = getMaxComponents(outputShape[2]); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; + const xShape = [inputs[0].dims[0], inputs[0].dims[1], inputs[0].dims[2], inputs[0].dims[3] / components]; + const wShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[1].dims[3] / components]; + const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; + + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, + {type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]]}, + {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]} + ]; + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShapeInShader)); + const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); + const x = inputVariable('x', inputs[0].dataType, xShape.length, components); + const w = inputVariable('w', inputs[1].dataType, wShape.length, components); + const inputVars = [x, w]; + if (hasBias) { + inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components)); + } + const processBias = hasBias ? 'value += b[output_channel];' : ''; + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, + {name: 'strides', type: 'i32', length: 2}, + {name: 'pads', type: 'i32', length: 2}, + ]; + appendActivationUniforms(attributes, uniforms); + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let width0 = uniforms.output_shape[3]; + let output_channel = global_idx % width0; + var index1 = global_idx / width0; + let width1 = uniforms.output_shape[2] / ${outputNumber}u; + let col = (index1 % width1) * ${outputNumber}u; + index1 = index1 / width1; + let row = index1 % uniforms.output_shape[1]; + let batch = index1 / uniforms.output_shape[1]; + + let x_corner = vec2(i32(row), i32(col)) * uniforms.strides - uniforms.pads; + + var x_vals: array<${x.type.value}, ${xNumber}>; + var values: array<${output.type.value}, ${outputNumber}>; + let input_channel = output_channel; + // Use constant instead of uniform can give better performance for w's height/width. + for (var w_height: u32 = 0u; w_height < ${wShape[0]}; w_height++) { + let x_height = x_corner.x + i32(w_height); + if (x_height >= 0 && u32(x_height) < uniforms.x_shape[1]) { + for (var i = 0; i < ${xNumber}; i++) { + let x_width = x_corner.y + i; + if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) { + x_vals[i] = ${x.get('batch', 'u32(x_height)', 'u32(x_width)', 'input_channel')}; + } else { + x_vals[i] = ${x.type.value}(0); + } + } + for (var w_width: u32 = 0u; w_width < ${wShape[1]}; w_width++) { + let w_val = ${w.get('w_height', 'w_width', '0', 'output_channel')}; + for (var i = 0u; i < ${outputNumber}u; i++) { + values[i] = fma(x_vals[i * u32(uniforms.strides[1]) + w_width], w_val, values[i]); + } + } + } + } + + for (var i = 0u; i < ${outputNumber}u; i++) { + var value = values[i]; + ${processBias} + ${applyActivation} + ${output.set('batch', 'row', 'col + i', 'output_channel', 'value')}; + } + }`; + }; + + return { + name: 'GroupedConv-Vectorize', + shaderCache: { + hint: `${attributes.cacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, + inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'] + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 32b1d52ed94c..41bd1d5326dc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -2,7 +2,6 @@ // Licensed under the MIT License. import {TensorView} from '../../tensor-view'; -import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; @@ -59,7 +58,6 @@ export interface ConvTransposeAttributes extends ConvAttributes { readonly outputShape: readonly number[]; } - const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); @@ -96,11 +94,7 @@ const getAdjustedConvTransposeAttributes = // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - const cacheKey = attributes.cacheKey + [ - kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), - dilations.join(',') - ].join('_'); - Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); + Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides}); return newAttributes; }; @@ -119,7 +113,7 @@ export const parseConvTransposeAttributes = (attributes: Record const wIsConst = (attributes.wIsConst as () => boolean)(); const outputPadding = attributes.outputPadding as [number, number, number, number]; const outputShape = attributes.outputShape as [number, number]; - return createAttributeWithCacheKey({ + return { autoPad, format, dilations, @@ -130,8 +124,9 @@ export const parseConvTransposeAttributes = (attributes: Record pads, strides, wIsConst, - ...activationAttributes - }); + ...activationAttributes, + cacheKey: `${attributes.format};${activationAttributes.activation};` + }; }; const validateInputs = (inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { @@ -273,7 +268,7 @@ const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttri //[FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kW] -> [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kH=1, kW] context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]) ]; - if (inputs.length === 3) { + if (context.inputs.length === 3) { inputs.push(context.inputs[2]); } let kernelShape = attributes.kernelShape; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 33a5db7ff6b2..b68d4dcae4cb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -3,12 +3,12 @@ import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createGroupedConvProgramInfo} from './conv-grouped'; +import {createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo} from './conv-grouped'; import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils'; import {createNaiveMatmulProgramInfo} from './matmul'; import {createTransposeProgramInfo} from './transpose'; @@ -110,7 +110,7 @@ const getAdjustedConvAttributes = (attributes: T, inpu // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign(newAttributes, {kernelShape, pads, cacheKey: attributes.cacheKey}); + Object.assign(newAttributes, {kernelShape, pads}); return newAttributes; }; @@ -126,8 +126,18 @@ export const parseConvAttributes = (attributes: Record): ConvAt const strides = attributes.strides as [number, number]; const wIsConst = (attributes.w_is_const as () => boolean)(); - return createAttributeWithCacheKey( - {autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes}); + return { + autoPad, + format, + dilations, + group, + kernelShape, + pads, + strides, + wIsConst, + ...activationAttributes, + cacheKey: `${attributes.format};${activationAttributes.activation};` + }; }; const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => { @@ -136,12 +146,37 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // check attributes // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */ + const isChannelsLast = attributes.format === 'NHWC'; if (attributes.group !== 1) { - context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes)); + // NVIDIA GPU with ampere architecture fails with below 2 cases, but we couldn't repro them with any other + // GPUs. So just disable vectorize on NVIDIA ampere to ensure always correct outputs. + // [webgpu]Conv - conv - vectorize group - B + // [webgpu]Conv - conv - vectorize group - D + const enableGroupedConvVectorize = !context.adapterInfo.isArchitecture('ampere'); + if (enableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group && + inputs[1].dims[1] === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1) { + const outputShape = calculateOutputShape( + inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides, + isChannelsLast); + const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute( + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + const convInputs = [inputs[0], transposedWeight]; + if (inputs.length === 3) { + convInputs.push(inputs[2]); + } + context.compute( + createGroupedConvVectorizeProgramInfo(convInputs, adjustedAttributes, outputShape), {inputs: convInputs}); + } else { + context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes)); + } return; } - const isChannelsLast = attributes.format === 'NHWC'; const hasBias = inputs.length === 3; const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2]; const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3]; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts index 2ff909c30e62..6080301d9946 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts @@ -54,8 +54,8 @@ const createCumsumProgramInfo = outputs: [{dims: inputShape, dataType: inputType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: axis}, - ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape) + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis}, + ...createTensorShapeVariables(inputShape, inputShape) ] }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/depth-to-space.ts b/js/web/lib/wasm/jsep/webgpu/ops/depth-to-space.ts new file mode 100644 index 000000000000..83809b3d5de6 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/depth-to-space.ts @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; + +export interface FormatAttributes { + readonly format: 'NHWC'|'NCHW'; +} + +export interface DepthToSpaceAttributes extends FormatAttributes, AttributeWithCacheKey { + readonly blocksize: number; + readonly mode: string; +} + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 1) { + throw new Error('DepthToSpace requires 1 input.'); + } + if (inputs[0].dims.length !== 4) { + throw new Error('DepthToSpace requires 4D input.'); + } +}; + +const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, output: IndicesHelper): string => { + const reverseFunc = []; + reverseFunc.push(`fn perm(i: ${output.type.indices}) -> ${input.type.indices} { + var a: ${input.type.indices};`); + for (let i = 0; i < rank; ++i) { + reverseFunc.push(input.indicesSet('a', perm[i], `i[${i}]`)); + } + reverseFunc.push('return a;}'); + return reverseFunc.join('\n'); +}; + +const createDepthToSpaceProgramInfo = (inputTensor: TensorView, attributes: DepthToSpaceAttributes): ProgramInfo => { + let n: number, h: number, w: number, c: number; + let shape: number[]; + let perm: number[]; + const isChannelLast = attributes.format === 'NHWC'; + const blocksize = attributes.blocksize; + const isDCRmode = attributes.mode === 'DCR'; + if (isChannelLast) { + [n, h, w, c] = inputTensor.dims; + shape = isDCRmode ? [n, h, w, blocksize, blocksize, c / (blocksize ** 2)] : + [n, h, w, c / (blocksize ** 2), blocksize, blocksize]; + perm = isDCRmode ? [0, 1, 3, 2, 4, 5] : [0, 1, 4, 2, 5, 3]; + } else { + [n, h, w, c] = [inputTensor.dims[0], inputTensor.dims[2], inputTensor.dims[3], inputTensor.dims[1]]; + shape = isDCRmode ? [n, blocksize, blocksize, c / (blocksize ** 2), h, w] : + [n, c / (blocksize ** 2), blocksize, blocksize, h, w]; + perm = isDCRmode ? [0, 3, 4, 1, 5, 2] : [0, 1, 4, 2, 5, 3]; + } + const reshapedInputTensor = inputTensor.reshape(shape); + const reshapedInputRank = reshapedInputTensor.dims.length; + const inputDataType = inputTensor.dataType; + + const reshapedInput = inputVariable('a', inputDataType, reshapedInputRank); + const permedOutput = outputVariable('output', inputDataType, reshapedInputRank); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(reshapedInput, permedOutput)} + + ${permFunctionBody(perm, reshapedInputRank, reshapedInput, permedOutput)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + + let indices = ${permedOutput.offsetToIndices('global_idx')}; + let aIndices = perm(indices); + + ${permedOutput.setByOffset('global_idx', reshapedInput.getByIndices('aIndices'))} + }`; + + return { + name: 'DepthToSpace', + shaderCache: {hint: `${inputTensor.dims};${attributes.blocksize};${attributes.mode}`, inputDependencies: ['rank']}, + getRunData: (inputs) => { + const outputShape = isChannelLast ? [n, h * blocksize, w * blocksize, c / (blocksize ** 2)] : + [n, c / (blocksize ** 2), h * blocksize, w * blocksize]; + const outputSize = ShapeUtil.size(outputShape); + const shapeBeforePerm = reshapedInputTensor.dims; + const shapeAfterPerm = ShapeUtil.sortBasedOnPerm(shapeBeforePerm, perm); + return { + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(shapeBeforePerm, shapeAfterPerm)], + }; + }, + getShaderSource, + }; +}; + +export const depthToSpace = (context: ComputeContext, attributes: DepthToSpaceAttributes): void => { + validateInputs(context.inputs); + context.compute(createDepthToSpaceProgramInfo(context.inputs[0], attributes)); +}; + +export const parseDepthToSpaceAttributes = (attributes: Record): DepthToSpaceAttributes => + createAttributeWithCacheKey({ + blocksize: attributes.blocksize as number, + mode: attributes.mode as string, + format: attributes.format as 'NHWC' | 'NCHW' + }); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index 4db7c04ad67b..19a009c2eb79 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; - +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; export interface EinsumAttributes extends AttributeWithCacheKey { readonly equation: string; @@ -181,14 +181,12 @@ class EinsumEquation { const appendMax = (name: string): string => name + '_max'; const createEinsumProgramInfo = - (enableInputShapesUniforms: readonly boolean[], inputShapes: Array, dataType: number, - einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => { - const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims); - const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank)); + (inputShapes: Array, dataType: number, einsumEquation: EinsumEquation, + outputShape: readonly number[]): ProgramInfo => { + const ranks = inputShapes.map((dims) => dims.length); + const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank)); const outputSize = ShapeUtil.size(outputShape); - const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); - const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; - const output = outputVariable('output', dataType, outputShapeOrRank); + const output = outputVariable('output', dataType, outputShape.length); const uniformsSymbols = [...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol)); const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -269,24 +267,20 @@ const createEinsumProgramInfo = }; return { name: 'Einsum', - shaderCache: { - hint: einsumEquation.equation, - inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims') - }, + shaderCache: {hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank')}, getRunData: () => { // The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The // filter is added to make sure that dimValue is never 0. const programUniformsInit: ProgramUniform[] = uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol)) - .map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); - programUniformsInit.push({type: 'uint32', data: outputSize}); + .map( + (symbol) => + ({type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); + programUniformsInit.push({type: DataType.uint32, data: outputSize}); const programUniforms: ProgramUniform[] = - inputShapes.filter((_, index) => enableInputShapesUniforms[index]) - .map((dims, _) => [...createTensorShapeVariables(dims)]) + inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)]) .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); - if (enableOutputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } + programUniforms.push(...createTensorShapeVariables(outputShape)); return ({ outputs: [{dims: outputShape, dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, @@ -299,11 +293,9 @@ const createEinsumProgramInfo = export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => { const einsumEquation = new EinsumEquation(context.inputs, attributes.equation); - const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length)); const outputShape = einsumEquation.outputDims; const inputShapes = context.inputs.map((input, _) => input.dims); - context.compute(createEinsumProgramInfo( - enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); + context.compute(createEinsumProgramInfo(inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); }; export const parseEinsumAttributes = (attributes: Record): EinsumAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 3dc4e957e0fe..80ee906423e1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -47,17 +47,11 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const outputShape: number[] = calculateOutputShape(inputShape, shape); const dataType = inputs[0].dataType; const components = dataType === DataType.bool ? 4 : 1; - const outputSize = ShapeUtil.size(outputShape) / components; - - const enableInputShapeUniform = enableShapesUniforms(inputShape.length); - const enableOutputShapeUniform = enableShapesUniforms(outputShape.length); - + const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); const getShaderSource = (shaderHelper: ShaderHelper) => { - const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape; - const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape; - const input = inputVariable('input', dataType, inputShapeOrRank, components); - const output = outputVariable('output', dataType, outputShapeOrRank, components); + const input = inputVariable('input', dataType, inputShape.length, components); + const output = outputVariable('output', dataType, outputShape.length, components); let assignment: string; if (dataType === DataType.bool) { const singleAssignment = (resStr: string, x: number, typeCast = '') => ` @@ -90,16 +84,11 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => ${assignment}`; }; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; - if (enableInputShapeUniform) { - programUniforms.push(...createTensorShapeVariables(inputShape)); - } - if (enableOutputShapeUniform) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } + const programUniforms: ProgramUniform[] = + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)]; return { name: 'Expand', - shaderCache: {hint: `${outputShape.length}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']}, + shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']}, getShaderSource, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts new file mode 100644 index 000000000000..f50a6a3f011f --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType, UniformsArrayType, WORKGROUP_SIZE} from './common'; +import * as unary from './unary-op'; + +// GELU is defined as Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X)), where X may pre-add a bias. + +const createFastGeluProgramInfo = (inputTensors: readonly TensorView[]): ProgramInfo => { + const dataType = inputTensors[0].dataType; + const outputSize = ShapeUtil.size(inputTensors[0].dims); + const biasLength = ShapeUtil.size(inputTensors[1].dims); + // can only use vec4 when bias length is multiple of 4 + const useVec4 = biasLength % 4 === 0; + const getShaderSource = (shaderHelper: ShaderHelper): string => { + const x = inputVariable('x', dataType, [1], 4); + const bias = inputVariable('bias', dataType, [1], 4); + const y = outputVariable('y', dataType, [1], 4); + + const uniforms: UniformsArrayType = [{name: 'output_vec_size', type: 'u32'}, {name: 'bias_size', type: 'u32'}]; + + const singleElementBias = (i: 0|1|2|3) => ` + let bias${i}_offset: u32 = (global_idx * 4 + ${i}) % uniforms.bias_size; + let bias${i} = ${bias.getByOffset(`bias${i}_offset / 4`)}[bias${i}_offset % 4];`; + const biasGetExpression = useVec4 ? + ` + let bias = ${bias.getByOffset('global_idx % (uniforms.bias_size / 4)')};` : + `${singleElementBias(0)}${singleElementBias(1)}${singleElementBias(2)}${singleElementBias(3)} + let bias = ${x.type.value}(bias0, bias1, bias2, bias3);`; + + return `${shaderHelper.registerUniforms(uniforms).declareVariables(x, bias, y)} + + ${unary.fastGeluImpl(tensorTypeToWsglValueType(dataType))} + + ${shaderHelper.mainStart(WORKGROUP_SIZE)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_vec_size')} + + let x = ${x.getByOffset('global_idx')}; + ${biasGetExpression} + let x_in = x + bias; + ${y.setByOffset('global_idx', unary.fastGeluExpression('x_in'))} + }`; + }; + + return { + name: 'FastGeluWithBias', + shaderCache: {hint: `${useVec4}`, inputDependencies: ['type', 'type']}, + getShaderSource, + getRunData: (inputs) => ({ + outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], + programUniforms: + [{type: DataType.uint32, data: Math.ceil(outputSize / 4)}, {type: DataType.uint32, data: biasLength}], + dispatchGroup: {x: Math.ceil(outputSize / WORKGROUP_SIZE / 4)} + }) + }; +}; + +export const fastGelu = (context: ComputeContext): void => { + if (context.inputs.length < 2 || ShapeUtil.size(context.inputs[1].dims) === 0) { + unary.fastGelu(context); + } else { + context.compute(createFastGeluProgramInfo(context.inputs)); + } +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 0b5c0db2b511..6e66abacf347 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -1,44 +1,78 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {MAX_CLIP, MIN_CLIP} from '../../util'; +import {ProgramUniform} from '../types'; + +import {UniformsArrayType} from './common'; export interface InternalActivationAttributes { readonly activation: string; readonly clipMin?: number; readonly clipMax?: number; - readonly activationCacheKey: string; + readonly alpha?: number; + readonly beta?: number; } -export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): - {activationFunction: string; applyActivation: string} => { +export const getActivationSnippet = + (attributes: InternalActivationAttributes, valueType: string, baseType = 'f32'): string => { switch (attributes.activation) { case 'Relu': - return {activationFunction: '', applyActivation: `value = max(value, ${valueType}(0.0));`}; + return `value = max(value, ${valueType}(0.0));`; case 'Sigmoid': - return { - activationFunction: '', - applyActivation: `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));` - }; + return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; case 'Clip': - return { - activationFunction: `const clip_min_=${valueType}(${attributes.clipMin!});const clip_max_=${valueType}(${ - attributes.clipMax!});`, - applyActivation: 'value = clamp(value, clip_min_, clip_max_);' - }; - // TODO: adding other activations that can be fused. + return `value = clamp(value, ${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${ + baseType}(uniforms.clip_max)));`; + case 'HardSigmoid': + return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${baseType}(uniforms.alpha) * value + ${ + baseType}(uniforms.beta)));`; + case 'LeakyRelu': + return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`; + case '': + return ''; + // TODO: adding other activations that can be fused. default: - return {activationFunction: '', applyActivation: ''}; + throw new Error(`Unsupported activation ${attributes.activation}`); + } + }; + +export const appendActivationUniformsData = + (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => { + if (attributes.activation === 'Clip') { + programUniform.push( + {type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!}); + } else if (attributes.activation === 'HardSigmoid') { + programUniform.push( + {type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!}); + } else if (attributes.activation === 'LeakyRelu') { + programUniform.push({type: DataType.float, data: attributes.alpha!}); } }; +export const appendActivationUniforms = (attributes: InternalActivationAttributes, uniforms: UniformsArrayType) => { + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } else if (attributes.activation === 'HardSigmoid') { + uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'}); + } else if (attributes.activation === 'LeakyRelu') { + uniforms.push({name: 'alpha', type: 'f32'}); + } +}; + export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { const activation = attributes?.activation as string || ''; - - if (activation === 'Clip') { + if (activation === 'HardSigmoid') { + const [alpha, beta] = attributes?.activation_params as [number, number] || [0.2, 0.5]; + return {activation, alpha, beta}; + } else if (activation === 'Clip') { const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; - return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`}; + return {activation, clipMax, clipMin}; + } else if (activation === 'LeakyRelu') { + const [alpha] = attributes?.activation_params as [number] || [0.01]; + return {activation, alpha}; } - return {activation, activationCacheKey: activation}; + return {activation}; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index a945954adcaa..4ab6c175a67e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -46,11 +47,11 @@ const createGatherElementsProgramInfo = const output = outputVariable('output', inputOutputDataType, outputShape.length); - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; - programUniforms.push(...createTensorShapeVariables(inputShape)); - programUniforms.push(...createTensorShapeVariables(indicesShape)); - programUniforms.push(...createTensorShapeVariables(outputShape)); + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, + {type: DataType.uint32, data: axis} + ]; + programUniforms.push(...createTensorShapeVariables(inputShape, indicesShape, outputShape)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 53ca094abfd6..d48bb909f7f8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; export interface GatherAttributes extends AttributeWithCacheKey { axis: number; @@ -31,35 +31,17 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const axisDimLimit = inputShape[axis]; const components = inputs[0].dataType === DataType.bool ? 4 : 1; - const outputSize = ShapeUtil.size(outputShape) / components; + const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); - const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length); - const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims; - const enableIndicesShapesUniforms = enableShapesUniforms(inputs[1].dims.length); - const indicesShapeOrRank = enableIndicesShapesUniforms ? inputs[1].dims.length : inputs[1].dims; - const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); - const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; - - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; - if (enableInputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - } - if (enableIndicesShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); - } - if (enableOutputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } - - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims'); - inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims'); + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, + {type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, outputShape) + ]; const getShaderSource = (shaderHelper: ShaderHelper) => { - const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank, components); - const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank); - const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank, components); + const data = inputVariable('data', inputs[0].dataType, inputs[0].dims.length, components); + const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims.length); + const output = outputVariable('output', inputs[0].dataType, outputShape.length, components); const calcDataIndices = (x: number|string): string => { const indicesRank = indicesShape.length; @@ -73,7 +55,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath if (idx${x} < 0) { idx${x} = idx${x} + uniforms.axisDimLimit; } - var dataIndices${x} = ${data.type.indices}(0); + var dataIndices${x} : ${data.type.indices}; `; for (let i = 0, j = 0; i < inputRank; i++) { if (i === axis) { @@ -127,7 +109,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath }; return { name: 'Gather', - shaderCache: {hint: attributes.cacheKey, inputDependencies}, + shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank', 'rank']}, getRunData: () => ({ outputs: [ {dims: outputShape, dataType: inputs[0].dataType}, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 1c5d28e4b8e3..76302e1af2e5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -1,12 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {GemmUtil, ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs) { @@ -34,25 +35,6 @@ export interface GemmAttributes extends AttributeWithCacheKey { beta: number; } -const offsetC = (m: number, n: number, dims: readonly number[]): string => { - if (dims.length === 0) { - return '0u'; - } - - const broadcastM = (dims.length === 1 && m !== 1) || (dims.length === 2 && dims[0] !== m); - const broadcastN = dims[dims.length - 1] !== n; - - let offset = '0u'; - if (!broadcastM) { - offset += `+ m * ${dims[dims.length - 1]}u`; - } - if (!broadcastN) { - offset += '+n'; - } - - return offset; -}; - const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAttributes): ProgramInfo => { const aShape = inputs[0].dims.slice(); const bShape = inputs[1].dims.slice(); @@ -63,68 +45,93 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt throw new Error('Can\'t use gemm on the given tensors'); } const outputSize = ShapeUtil.size(outputShape); - let line = ''; - if (attributes.transA && attributes.transB) { - line = 'value += a[k * M + m] * b[n * K + k];'; - } else if (attributes.transA && !attributes.transB) { - line = 'value += a[k * M + m] * b[k * N + n];'; - } else if (!attributes.transA && attributes.transB) { - line = 'value += a[m * K + k] * b[n * K + k];'; - } else if (!attributes.transA && !attributes.transB) { - line = 'value += a[m * K + k] * b[k * N + n];'; - } - - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= alpha;'; - const calculateC = inputs.length === 3 ? `value += beta * c[${offsetC(M, N, inputs[2].dims)}];` : ''; - const inputStorageBuffersDeclarations = [ - `@group(0) @binding(0) var a : array<${dataType}>;`, - `@group(0) @binding(1) var b : array<${dataType}>;` + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: K}, {type: DataType.float, data: attributes.alpha}, + {type: DataType.float, data: attributes.beta} ]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; if (inputs.length === 3) { - inputStorageBuffersDeclarations.push(`@group(0) @binding(2) var c : array<${dataType}>;`); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); } - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const M: u32 = ${M}u; - const N: u32 = ${N}u; - const K: u32 = ${K}u; - const alpha = ${dataType}(${attributes.alpha}); - const beta = ${dataType}(${attributes.beta}); + programUniforms.push(...createTensorShapeVariables(outputShape)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + let line = ''; + if (attributes.transA && attributes.transB) { + line = 'value += a[k * uniforms.M + m] * b[n * uniforms.K + k];'; + } else if (attributes.transA && !attributes.transB) { + line = 'value += a[k * uniforms.M + m] * b[k * uniforms.N + n];'; + } else if (!attributes.transA && attributes.transB) { + line = 'value += a[m * uniforms.K + k] * b[n * uniforms.K + k];'; + } else if (!attributes.transA && !attributes.transB) { + line = 'value += a[m * uniforms.K + k] * b[k * uniforms.N + n];'; + } - ${inputStorageBuffersDeclarations.join('\n')} - @group(0) @binding(${inputs.length}) var output : array<${dataType}>; + const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= uniforms.alpha;'; + const a = inputVariable('a', inputs[0].dataType, inputs[0].dims); + const b = inputVariable('b', inputs[1].dataType, inputs[1].dims); + const dataType = a.type.value; + let c: IndicesHelper|null = null; + const variables = [a, b]; + if (inputs.length === 3) { + c = inputVariable('c', inputs[2].dataType, inputs[2].dims.length); + variables.push(c); + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + variables.push(output); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'K', type: 'u32'}, + {name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'} + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - let m = global_idx / N; - let n = global_idx % N; + let m = global_idx / uniforms.N; + let n = global_idx % uniforms.N; var value = ${dataType}(0); - for (var k: u32 = 0u; k<${K}u; k++) { + for (var k: u32 = 0u; k < uniforms.K; k++) { ${line} } ${calculateAlpha} - ${calculateC} + ${(() => { + if (c != null) { + return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += ${ + dataType}(uniforms.beta) * ${c.getByOffset('cOffset')};`; + } + return ''; + })()} output[global_idx] = value; - }`; + }; + return { name: 'Gemm', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: `${attributes.cacheKey}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }; }; +export const parseGemmAttributes = (attributes: Record): GemmAttributes => { + const transA = attributes.transA as boolean; + const transB = attributes.transB as boolean; + const alpha = attributes.alpha as number; + const beta = attributes.beta as number; + return {transA, transB, alpha, beta, cacheKey: `${attributes.transA};${attributes.transB};${attributes.alpha === 1}`}; +}; + export const gemm = (context: ComputeContext, attributes: GemmAttributes): void => { validateInputs(context.inputs); context.compute(createGemmProgramInfo(context.inputs, attributes)); }; - -export const parseGemmAttributes = (attributes: Record): GemmAttributes => - createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 3a84844544c9..c1d762e62aaa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -4,58 +4,56 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; +import {createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; -export interface InstanceNormAttributes extends AttributeWithCacheKey { +export interface InstanceNormAttributes { epsilon: number; format: 'NHWC'|'NCHW'; } -const metadata = { - name: 'InstanceNormalization' -}; - const createInstanceNormProgramInfo = (inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => { const xShape = inputs[0].dims; - const outputShape = xShape; const axis = 2; const normCount = ShapeUtil.sizeToDimension(xShape, axis); const normSize = ShapeUtil.sizeFromDimension(xShape, axis); const components = getMaxComponents(normSize); const normPackedSize = normSize / components; - const C = xShape[1]; - const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components); - const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims); - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); - const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components); - const variables = [x, scale, bias, output]; - const dataType = x.type.value; - const f32Type = components === 1 ? 'f32' : `vec${components}`; - const workgroupSize = 64; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - - const C: u32 = ${C}; - const normSize: u32 = ${normSize}; - const epsilon: f32 = ${attributes.epsilon}; + const inputShape = [xShape[0], xShape[1], normPackedSize]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; + const programUniforms: ProgramUniform[] = + [{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}]; + programUniforms.push(...createTensorShapeVariables(inputShape, inputShape)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); + const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); + const output = outputVariable('output', inputs[0].dataType, inputShape.length, components); + const variables = [x, scale, bias, output]; + const dataType = x.type.value; + const f32Type = components === 1 ? 'f32' : `vec${components}`; + const workgroupSize = 64; + + const uniforms: UniformsArrayType = [{name: 'normSize', type: 'u32'}, {name: 'normPackedSize', type: 'u32'}]; + return ` var meanShared : f32; var squaredNormShared : f32; var workgroupShared : array<${f32Type}, ${workgroupSize}>; const workgroupSize = ${workgroupSize}u; - ${shaderHelper.declareVariables(...variables)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} ${shaderHelper.mainStart(workgroupSize)} let norm = global_idx / workgroupSize; - let batch = norm / C; - let channel = norm % C; + let batch = norm / uniforms.x_shape[1]; + let channel = norm % uniforms.x_shape[1]; let localIndex = local_id.x; // initialize workgroup memory var initial = ${f32Type}(0); - for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')}); } workgroupShared[localIndex] = initial; @@ -69,13 +67,13 @@ const createInstanceNormProgramInfo = workgroupBarrier(); } if (localIndex == 0) { - meanShared = ${sumVector('workgroupShared[0]', components)} / f32(normSize); + meanShared = ${sumVector('workgroupShared[0]', components)} / f32(uniforms.normSize); } workgroupBarrier(); // reinitialize workgroup memory. initial = ${f32Type}(0); - for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared); initial = initial + deviation * deviation; } @@ -94,23 +92,26 @@ const createInstanceNormProgramInfo = } workgroupBarrier(); - let invStdDev = 1 / sqrt(squaredNormShared / f32(normSize) + epsilon); + let invStdDev = inverseSqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon})); let channelScale = invStdDev * f32(${scale.getByOffset('channel')}); let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale; - for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${ - f32Type}(channelShift)); + f32Type}(channelShift)); ${output.set('batch', 'channel', 'h', 'value')}; } }`; + }; return { - ...metadata, - shaderCache: {hint: attributes.cacheKey}, + ...{name: 'InstanceNormalization'}, + // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. + shaderCache: {hint: `${attributes.epsilon};${components}`, inputDependencies}, getRunData: () => ({ outputs: [ {dims: outputShape, dataType: inputs[0].dataType}, ], - dispatchGroup: {x: normCount} + dispatchGroup: {x: normCount}, + programUniforms }), getShaderSource, }; @@ -120,10 +121,6 @@ const computeMean = (context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number, epsilon: number) => { const components = getMaxComponents(c); - const inputHelper = inputVariable('input', input.dataType, input.dims, components); - const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); - const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); - const WG = 64; // we will store channel scale and channel shift in [2, components] matrix // or in vec2 when components == 1 @@ -133,90 +130,107 @@ const computeMean = const unitsOfWork = n * c / components; const wgSize = Math.ceil(h / WG); - const getMeanShaderSource = (shaderHelper: ShaderHelper) => ` - const H: u32 = ${h}; - const C: u32 = ${c / components}; - const imageSize: u32 = ${h * c / components}; + const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type']; + const meanProgramUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: wgSize}, {type: DataType.uint32, data: h}, + {type: DataType.uint32, data: Math.floor(c / components)}, + {type: DataType.uint32, data: Math.floor(h * c / components)} + ]; + const getMeanShaderSource = (shaderHelper: ShaderHelper) => { + const inputHelper = inputVariable('input', input.dataType, input.dims, components); + return ` ${shaderHelper.declareVariables(inputHelper)} @group(0) @binding(1) var output : array<${outputType}>; + struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32}; + @group(0) @binding(2) var uniforms: Uniforms; ${shaderHelper.mainStart(WG)} - let currentImageNumber = global_idx / ${WG} / C; - let currentChannelNumber = (global_idx / ${WG}) % C; - let wgId = global_idx % ${WG}; - let wgOffset = wgId * ${wgSize}; - if (wgOffset >= H) { + let currentImageNumber = global_idx / ${WG} / uniforms.C; + let currentChannelNumber = (global_idx / ${WG}) % uniforms.C; + let wgOffset = local_id.x * uniforms.wg_size; + if (wgOffset >= uniforms.H) { return; } - let wgMax = min(wgOffset + ${wgSize}, H); + let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H); - let offset = currentImageNumber * imageSize + currentChannelNumber; + let offset = currentImageNumber * uniforms.image_size + currentChannelNumber; var sum = ${fillVector('f32', components)}; var squaredSum = ${fillVector('f32', components)}; for (var i: u32 = wgOffset; i < wgMax; i++) { - let value = ${sumCastType}(input[offset + i * C]); + let value = ${sumCastType}(input[offset + i * uniforms.C]); sum += value; squaredSum += value * value; } output[global_idx] = ${setOutputValue('sum', 'squaredSum')}; }`; + }; const meanValues = context.compute( { name: 'InstanceNormComputeMean', - shaderCache: {hint: JSON.stringify({components, n, h, c})}, + shaderCache: {hint: `${components}`, inputDependencies: meanInputDependencies}, getRunData: () => ({ outputs: [ {dims: [n, c, WG, 2], dataType: DataType.float}, ], dispatchGroup: {x: n * c / components}, + programUniforms: meanProgramUniforms }), getShaderSource: getMeanShaderSource, }, {inputs: [input], outputs: [-1]})[0]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const H: u32 = ${h}; - const C: u32 = ${c / components}; - const imageSize: u32 = ${WG * c / components}; - const epsilon: f32 = ${epsilon}; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: unitsOfWork}, {type: DataType.uint32, data: h}, + {type: DataType.uint32, data: Math.floor(c / components)}, + {type: DataType.uint32, data: Math.floor(WG * c / components)} + ]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type']; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); + const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); + return ` @group(0) @binding(0) var input : array<${outputType}>; @group(0) @binding(1) var scale : array<${scaleHelper.type.storage}>; @group(0) @binding(2) var bias : array<${biasHelper.type.storage}>; @group(0) @binding(3) var output : array<${outputType}>; + struct Uniforms {units_of_work : u32, H: u32, C : u32, image_size : u32}; + @group(0) @binding(4) var uniforms: Uniforms; ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)} - let currentImageNumber = global_idx / C; - let currentChannelNumber = global_idx % C; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.units_of_work')} + let currentImageNumber = global_idx / uniforms.C; + let currentChannelNumber = global_idx % uniforms.C; - let offset = currentImageNumber * imageSize; + let offset = currentImageNumber * uniforms.image_size; var sum = ${fillVector('f32', components)}; var squaredSum = ${fillVector('f32', components)}; - for (var i: u32 = 0; i < ${WG}; i++) { + for (var i: u32 = 0; i < min(${WG}, uniforms.H); i++) { let value = input[offset + i + currentChannelNumber * ${WG}]; sum += value[0]; squaredSum += value[1]; } - sum = sum / f32(H); - squaredSum = squaredSum / f32(H); - let invStdDev = 1 / sqrt(squaredSum - sum * sum + epsilon); + sum = sum / f32(uniforms.H); + squaredSum = squaredSum / f32(uniforms.H); + let invStdDev = inverseSqrt(squaredSum - sum * sum + f32(${epsilon})); let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]); let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale; output[global_idx] = ${setOutputValue('channelScale', 'channelShift')}; }`; - + }; return context.compute( { name: 'InstanceNormComputeChannelScaleShift', - shaderCache: {hint: JSON.stringify({components, n, h, c, epsilon})}, + // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. + shaderCache: {hint: `${components};${epsilon}`, inputDependencies}, getRunData: () => ({ outputs: [ {dims: [n, c, 2], dataType: DataType.float}, ], dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }, @@ -230,50 +244,51 @@ const createInstanceNormNHWCProgramInfo = const N = xShape[0]; const C = xShape[xShape.length - 1]; const H = ShapeUtil.sizeFromDimension(xShape, 1) / C; - const components = getMaxComponents(C); const outputSize = ShapeUtil.size(outputShape) / components; - const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); - const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); - - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`; - const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`; + const programUniforms: ProgramUniform[] = + [{type: DataType.uint32, data: H}, {type: DataType.uint32, data: Math.floor(C / components)}]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; // first compute mean const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`; + const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const H: u32 = ${H}; - const C: u32 = ${C / components}; + const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); + const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); + return ` @group(0) @binding(0) var input : array<${inputHelper.type.storage}>; @group(0) @binding(1) var scaleInput : array<${scaleType}>; @group(0) @binding(2) var output : array<${outputHelper.type.storage}>; + struct Uniforms {H: u32, C : u32}; + @group(0) @binding(3) var uniforms: Uniforms; ${shaderHelper.mainStart()} - let currentImageNumber = global_idx / (C * H); - let currentChannelNumber = global_idx % C; + let currentImageNumber = global_idx / (uniforms.C * uniforms.H); + let currentChannelNumber = global_idx % uniforms.C; - let scaleOffset = currentImageNumber * C + currentChannelNumber; + let scaleOffset = currentImageNumber * uniforms.C + currentChannelNumber; let scale = scaleInput[scaleOffset]; output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1])); }`; + }; context.compute( { - name: 'InstanceNormalization', - shaderCache: {hint: `${attributes.cacheKey}`}, + name: 'InstanceNormalizationNHWC', + shaderCache: {hint: `${components}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }, {inputs: [inputs[0], channelScaleShift]}); }; -export const parseInstanceNormAttributes = (attributes: InstanceNormAttributes): InstanceNormAttributes => - createAttributeWithCacheKey({epsilon: attributes.epsilon, format: attributes.format}); - export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => { if (attributes.format === 'NHWC') { createInstanceNormNHWCProgramInfo(context, context.inputs, attributes); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 8a9eeecf2c68..b2a1bbe2bea4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -4,12 +4,12 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType,} from './common'; -export interface LayerNormAttributes extends AttributeWithCacheKey { +interface LayerNormAttributes { + simplified: boolean; axis: number; epsilon: number; } @@ -22,9 +22,11 @@ const validateInputs = (inputs: readonly TensorView[]): void => { const createLayerNormProgramInfo = (inputs: readonly TensorView[], attributes: LayerNormAttributes, outputCount: number): ProgramInfo => { + const simplified = attributes.simplified; + const xShape = inputs[0].dims; const scale = inputs[1]; - const bias = inputs[2]; + const bias = !simplified && inputs[2]; const outputShape = xShape; const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length); @@ -39,7 +41,7 @@ const createLayerNormProgramInfo = Got scale size of ${scaleSize} and bias size of ${biasSize}`); } - const meanInvStdDevDim = []; + const meanInvStdDevDim: number[] = []; for (let i = 0; i < xShape.length; ++i) { if (i < axis) { meanInvStdDevDim.push(xShape[i]); @@ -47,60 +49,69 @@ const createLayerNormProgramInfo = meanInvStdDevDim.push(1); } } - const components = getMaxComponents(normSize); - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('scale', scale.dataType, scale.dims, components), + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: normCount}, {type: DataType.float, data: normSize}, + {type: DataType.uint32, data: Math.floor(normSize / components)}, + {type: DataType.float, data: attributes.epsilon} ]; if (bias) { - variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); + inputDependencies.push('type'); } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); - const hasMeanDataOutput = outputCount > 1; const hasInvStdOutput = outputCount > 2; - if (hasMeanDataOutput) { - variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInvStdOutput) { - variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); - } - - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const normSize: f32 = ${normSize}; - const normSizeVectorized: u32 = ${normSize / components}; - const epsilon: f32 = ${attributes.epsilon}; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('scale', scale.dataType, scale.dims, components), + ]; + if (bias) { + variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanDataOutput) { + variables.push(outputVariable('mean_data_output', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdOutput) { + variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); + } - ${shaderHelper.declareVariables(...variables)} + const uniforms: UniformsArrayType = [ + {name: 'norm_count', type: 'u32'}, {name: 'norm_size', type: 'f32'}, + {name: 'norm_size_vectorized', type: 'u32'}, {name: 'epsilon', type: 'f32'} + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(normCount)} - let offset = global_idx * normSizeVectorized; - var meanVector = ${fillVector('f32', components)}; - var meanSquareVector = ${fillVector('f32', components)}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')} + let offset = global_idx * uniforms.norm_size_vectorized; + var mean_vector = ${fillVector('f32', components)}; + var mean_square_vector = ${fillVector('f32', components)}; - for (var h: u32 = 0u; h < normSizeVectorized; h++) { + for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) { let value = ${castToF32(dataType, components, 'x[h + offset]')}; - meanVector += value; - meanSquareVector += value * value; + mean_vector += value; + mean_square_vector += value * value; } - let mean = ${sumVector('meanVector', components)} / normSize; - let meanSquare = sqrt(${sumVector('meanSquareVector', components)} - / normSize - mean * mean + epsilon); + let mean = ${sumVector('mean_vector', components)} / uniforms.norm_size; + let inv_std_dev = inverseSqrt(${sumVector('mean_square_vector', components)} / uniforms.norm_size ${ + simplified ? '' : '- mean * mean'} + uniforms.epsilon); - for (var j: u32 = 0; j < normSizeVectorized; j++) { + for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; let f32scale = ${castToF32(dataType, components, 'scale[j]')}; - output[j + offset] = ${variables[0].type.value}((f32input - mean) / meanSquare * f32scale + output[j + offset] = ${variables[0].type.value}((f32input ${simplified ? '' : '- mean'}) * inv_std_dev * f32scale ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''} ); } - ${hasMeanDataOutput ? 'meanDataOutput[global_idx] = mean' : ''}; - ${hasInvStdOutput ? 'invStdOutput[global_idx] = 1 / meanSquare' : ''}; + ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''}; + ${hasInvStdOutput ? 'inv_std_output[global_idx] = inv_std_dev' : ''}; }`; + }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; if (hasMeanDataOutput) { outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); @@ -111,15 +122,13 @@ const createLayerNormProgramInfo = return { name: 'LayerNormalization', - shaderCache: {hint: `${attributes.cacheKey}|${outputCount}|${inputs.length}`}, - getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}}), + shaderCache: {hint: `${components};${outputCount};${simplified}`, inputDependencies}, + getRunData: () => + ({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}, programUniforms}), getShaderSource, }; }; -export const parseLayerNormAttributes = (attributes: LayerNormAttributes): LayerNormAttributes => - createAttributeWithCacheKey({axis: attributes.axis, epsilon: attributes.epsilon}); - export const layerNorm = (context: ComputeContext, attributes: LayerNormAttributes): void => { validateInputs(context.inputs); context.compute(createLayerNormProgramInfo(context.inputs, attributes, context.outputCount)); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index de9309d1e436..1a92d861002f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common'; -import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; +import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; export const createNaiveMatmulProgramInfo = (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], @@ -27,11 +28,13 @@ export const createNaiveMatmulProgramInfo = const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); const batchSize = ShapeUtil.size(outerDims); const outputShapeInShader = [batchSize, M, N]; + const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, - {type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), - ...createTensorShapeVariables(bShape) + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: K} ]; + appendActivationUniformsData(activationAttributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape)); if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); } @@ -42,7 +45,8 @@ export const createNaiveMatmulProgramInfo = const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); const b = inputVariable('b', inputs[1].dataType, bShape.length, components); const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); const inputVariables = [a, b]; let processBias = ''; if (hasBias) { @@ -57,6 +61,12 @@ export const createNaiveMatmulProgramInfo = const outerDimsB = bShape.slice(0, -2); const broadCastADims = getBroadcastDims(outerDimsA, outerDims); const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'K', type: 'u32'} + ]; + appendActivationUniforms(activationAttributes, uniforms); + const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { const rank = variable.rank; const name = variable.name; @@ -96,15 +106,10 @@ export const createNaiveMatmulProgramInfo = return ` ${ - shaderHelper.registerUniform('outputSize', 'u32') - .registerUniform('M', 'u32') - .registerUniform('N', 'u32') - .registerUniform('K', 'u32') - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${activationFunction} + shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( + ...inputVariables, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let col = (global_idx % (uniforms.N / ${components})) * ${components}; var index1 = global_idx / (uniforms.N / ${components}); let stride1 = uniforms.M / ${outputNumber}; @@ -134,8 +139,7 @@ export const createNaiveMatmulProgramInfo = return { name: 'MatMulNaive', shaderCache: { - hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${ - isChannelsLast}`, + hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'] }, getRunData: () => ({ @@ -166,9 +170,8 @@ export const matMul = (context: ComputeContext): void => { const N = outputShape[outputShape.length - 1]; const K = context.inputs[0].dims[context.inputs[0].dims.length - 1]; if (N < 8 && K < 8) { - context.compute( - createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createNaiveMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); } else { - context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts new file mode 100644 index 000000000000..7f1a5b96863f --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -0,0 +1,304 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType, getTensorElementSize} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; + +// TODO support quantization bits not equal to 4 +export interface MatMulNBitsAttributes extends AttributeWithCacheKey { + k: number; + n: number; + accuracyLevel: number; + bits: number; + blockSize: number; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => { + if (inputs.length < 3 || inputs.length > 4) { + throw new Error('MatMulNBits requires 3 or 4 inputs'); + } + const a = inputs[0]; + const aRank = a.dims.length; + if (a.dims[aRank - 1] !== attributes.k) { + throw new Error('The last dim of input shape does not match the k value'); + } + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const b = inputs[1]; + if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) { + throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize'); + } + const scales = inputs[2]; + const scalesShape = scales.dims; + if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) { + throw new Error('scales input size error.'); + } + if (inputs.length === 4) { + const zeroPoints = inputs[3]; + const zeroPointsShape = zeroPoints.dims; + const expectedZeroPointsSize = + attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2); + if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) { + throw new Error('zeroPoints input size error.'); + } + } +}; + +export const createMatMulNBitsProgramInfo = + (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes, + maxComputeWorkgroupSizes: [number, number, number], maxComputeWorkgroupStorageSize: number): ProgramInfo => { + const inputShape = inputs[0].dims; + const aRank = inputShape.length; + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const dimAOuter = inputShape[aRank - 2]; + const dimInner = attributes.k; + const dimBOuter = attributes.n; + const batchDims = inputShape.slice(0, aRank - 2); + const batchSize = ShapeUtil.size(batchDims); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const blobSizeInWords = blobSize / 4; + const dataType = inputs[0].dataType; + const outputNumber = getMaxComponents(dimAOuter); + const aComponents = getMaxComponents(attributes.k); + const bComponents = getMaxComponents(blobSizeInWords); + const elementSize = getTensorElementSize(dataType)!; + const workgroupOutputSize = dimAOuter * nBlocksPerCol * elementSize; + const maxNumberOfComponents = Math.floor(maxComputeWorkgroupStorageSize / workgroupOutputSize); + const useBlockwiseMatMulNBits = nBlocksPerCol <= maxComputeWorkgroupSizes[0] && maxNumberOfComponents > 0; + const components = (!useBlockwiseMatMulNBits || maxNumberOfComponents >= 4) ? getMaxComponents(dimBOuter) : + ((maxNumberOfComponents >= 2) && getMaxComponents(dimBOuter) >= 2) ? 2 : + 1; + const outputShape = batchDims.concat([dimAOuter, dimBOuter]); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; + + const programUniforms: ProgramUniform[] = useBlockwiseMatMulNBits ? + [] : + [{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.blockSize}]; + const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents]; + const bShape = ShapeUtil.convertShape(inputs[1].dims).slice(); + bShape.splice(-1, 1, blobSizeInWords / bComponents); + programUniforms.push(...createTensorShapeVariables(inputShapeTemp)); + programUniforms.push(...createTensorShapeVariables(bShape)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + if (inputs.length === 4) { + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); + } + const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; + programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const inputRank = inputShapeTemp.length; + const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents); + const b = inputVariable('b', DataType.uint32, bShape.length, bComponents); + const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); + const inputVariables = [a, b, scales]; + const zeroPoints = + inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined; + if (zeroPoints) { + inputVariables.push(zeroPoints); + } + const outputRank = outputShapeTemp.length; + const output = outputVariable('output', inputs[0].dataType, outputRank, components); + const uniforms: UniformsArrayType = [{name: 'output_size', type: 'u32'}, {name: 'block_size', type: 'u32'}]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + + const qDqDataType = (() => { + switch (aComponents) { + case 1: + return `array<${dataType}, 8>`; + case 2: + return `mat4x2<${dataType}>`; + case 4: + return `mat2x4<${dataType}>`; + default: + throw new Error(`${aComponents}-component is not supported.`); + } + })(); + + const processOneBlock = ` + for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) { + ${b.indicesSet('b_indices', '2', 'word')}; + let b_data = ${b.getByIndices('b_indices')}; + for (var i: u32 = 0; i < ${bComponents}; i++) { + let b_value: u32 = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'}; + let b_mask: u32 = 0x0F0F0F0Fu; + let b_value_lower: vec4 = unpack4xU8(b_value & b_mask); + let b_value_upper: vec4 = unpack4xU8((b_value >> 4) & b_mask); + let b_quantized_values = ${qDqDataType}(${ + Array.from({length: 4}, (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`) + .join(', ')}); + let b_dequantized_values = ${(() => { + if (aComponents === 1) { + return `${qDqDataType}(${ + Array.from({length: 8}, (_, i) => `(b_quantized_values[${i}] - zero_point) * scale`).join(', ')});`; + } else { + return `(b_quantized_values - ${qDqDataType}(${Array(8).fill('zero_point').join(',')})) * scale;`; + } + })()}; + // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 + for (var m: u32 = 0; m < ${useBlockwiseMatMulNBits ? dimAOuter : outputNumber}u; m++) { + ${a.indicesSet('a_indices', inputRank - 2, useBlockwiseMatMulNBits ? 'm' : `row * ${outputNumber} + m`)}; + ${a.indicesSet('a_indices', inputRank - 1, 'word_offset')}; + var input_offset = ${a.indicesToOffset('a_indices')}; + var a_data: ${qDqDataType}; + for (var j: u32 = 0; j < ${8 / aComponents}; j++) { + a_data[j] = ${a.getByOffset('input_offset')}; + input_offset++; + } + ${useBlockwiseMatMulNBits ? 'workgroup_shared[workgroup_shared_offset + m]' : 'output_values[m]'}${ + components > 1 ? '[c]' : ''} += ${ + Array + .from( + {length: 8 / aComponents}, + (_, i) => `${ + aComponents === 1 ? `a_data[${i}] * b_dequantized_values[${i}]` : + `dot(a_data[${i}], b_dequantized_values[${i}])`}`) + .join(' + ')}; + } + word_offset += ${8 / aComponents}; + } + }`; + const updateZeroPointIndex = zeroPoints ? ` + zero_point_offset += 4; + if (zero_point_offset == 32) { + zero_point_offset = 0; + zero_point_index++; + zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; + }` : + ''; + + return useBlockwiseMatMulNBits ? ` + var workgroup_shared: array<${output.type.value}, ${dimAOuter * nBlocksPerCol}>; + ${shaderHelper.declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart([ + nBlocksPerCol, 1, 1 + ])} + var a_indices: ${a.type.indices}; + var block = local_id.x; + var col = workgroup_id.y; + var batch = workgroup_id.z; + ${a.indicesSet('a_indices', '0', 'batch')}; + // Two zero points are packed into one byte when uniforms.bits is 4. + for (var c: u32 = 0; c < ${components}; c++) { + let col_times_components_plus_c = col * ${components} + c; + ${ + zeroPoints ? ` + var zero_point_bytes_per_col: u32 = (${nBlocksPerCol} + 1) / 2; + var zero_point_byte_count: u32 = col_times_components_plus_c * zero_point_bytes_per_col + (block >> 0x1u); + var zero_point_word_index: u32 = zero_point_byte_count >> 0x2u; + var zero_point_byte_offset: u32 = zero_point_byte_count & 0x3u; + var zero_point_nibble_offset: u32 = block & 0x1u; + var zero_point_bits_offset: u32 = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2); + var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;` : + ''} + var b_indices: ${b.type.indices}; + ${b.indicesSet('b_indices', '0', 'col_times_components_plus_c')}; + // The scale and zero points are computed per block. + var scales_index = col_times_components_plus_c * ${nBlocksPerCol} + block; + let scale = ${scales.getByOffset('scales_index')}; + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point = ${dataType}(${zeroPoints ? '(zero_point_word) & 0xFu' : 8.0}); + ${b.indicesSet('b_indices', '1', 'block')}; + var word_offset: u32 = block * ${attributes.blockSize / aComponents}; + var workgroup_shared_offset: u32 = block * ${dimAOuter}; + ${processOneBlock} + } + workgroupBarrier(); + if (local_id.x == 0u) { + var output_indices: ${output.type.indices}; + ${output.indicesSet('output_indices', '0', 'batch')}; + ${output.indicesSet('output_indices', outputRank - 1, 'col')}; + ${output.indicesSet('output_indices', outputRank - 2, '0')}; + var output_offset = ${output.indicesToOffset('output_indices')}; + for (var m: u32 = 0u; m < ${dimAOuter}u; m++) { + var output_value: ${output.type.value} = ${output.type.value}(0); + var workgroup_shared_offset: u32 = m; + for (var b: u32 = 0u; b < ${nBlocksPerCol}u; b++) { + output_value += workgroup_shared[workgroup_shared_offset]; + workgroup_shared_offset += ${dimAOuter}; + } + ${output.setByOffset('output_offset', 'output_value')}; + output_offset += ${dimBOuter / components}; + } + } + }` : + ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + var output_values: array<${output.type.value}, ${outputNumber}>; + var output_indices = ${output.offsetToIndices('global_idx')}; + var col = ${output.indicesGet('output_indices', outputRank - 1)}; + var row = ${output.indicesGet('output_indices', outputRank - 2)}; + var a_indices: ${a.type.indices} = output_indices; + // Two zero points are packed into one byte because uniforms.bits <= 4. + // zero_point_offset is either 0 or 4. It is bit offset within one byte. + // TODO support zero_point_offset for bits > 4 + ${ + zeroPoints ? ` + var zero_point_abs_offset = col * ${components} * ((${nBlocksPerCol} + 1) / 2); + var zero_point_index: u32 = zero_point_abs_offset / 4; + var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; + var zero_point_offset: u32 = (zero_point_abs_offset % 4) * 8;` : + ''} + var scale_index = col * ${nBlocksPerCol * components}; + var b_indices: ${b.type.indices}; + for (var c: u32 = 0; c < ${components}; c++) { + ${b.indicesSet('b_indices', '0', `col * ${components} + c`)}; + var block_offset: u32 = 0; + for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { + // The scale and zero points are computed per block. + let scale = ${scales.getByOffset('scale_index')}; + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0}); + ${b.indicesSet('b_indices', '1', 'block')}; + var word_offset: u32 = block_offset; + ${processOneBlock} + scale_index++; + ${updateZeroPointIndex} + block_offset += uniforms.block_size / ${aComponents}; + } + // Drop the trailing 4 bits if the zero_poit_offset is not a byte boundary to align with the next byte. + ${ + zeroPoints ? `if (zero_point_offset % 8 > 0) { + ${updateZeroPointIndex} + }` : + ''} + } + for (var k: u32 = 0u; k < ${outputNumber}u; k++) { + ${output.indicesSet('output_indices', outputRank - 2, `${outputNumber} * row + k`)}; + ${output.setByIndices('output_indices', 'output_values[k]')} + } + }`; + }; + return { + name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits', + shaderCache: { + hint: `${attributes.cacheKey};${dimAOuter};${dataType};${inputs.length}`, + inputDependencies: Array(inputs.length).fill('rank') + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType}], + name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits', + dispatchGroup: useBlockwiseMatMulNBits ? {x: 1, y: Math.ceil(dimBOuter / components), z: batchSize} : + {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms + }), + getShaderSource + }; + }; + +export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { + validateInputs(context.inputs, attributes); + const maxComputeWorkgroupSizes: [number, number, number] = context.getMaxComputeWorkgroupSizes(); + const maxComputeWorkgroupStorageSize = context.getMaxComputeWorkgroupStoragesize(); + context.compute(createMatMulNBitsProgramInfo( + context.inputs, attributes, maxComputeWorkgroupSizes, maxComputeWorkgroupStorageSize)); +}; + +export const parseMatMulNBitsAttributes = (attributes: Record): MatMulNBitsAttributes => + createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts index b7726a36bcaa..5c5c849d9981 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType} from '../types'; +import {ComputeContext, GpuDataType, ProgramUniform} from '../types'; import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; import {createTransposeProgramInfo, TransposeAttributes} from './transpose'; const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { @@ -228,7 +229,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr }; }; - export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => createAttributeWithCacheKey({...attributes}); @@ -239,30 +239,37 @@ const addBiasTranspose = hiddenSize: number, biasOffset: number) => { const outputShape = [batchSize, sequenceLength, hiddenSize]; const outputSize = ShapeUtil.size(outputShape); - - const dataType = tensorTypeToWsglStorageType(qkv.dataType); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const biasOffset = ${biasOffset}u; - const hiddenSize = ${hiddenSize}u; - - @group(0) @binding(0) var qkv: array<${dataType}>; - @group(0) @binding(1) var bias: array<${dataType}>; - @group(0) @binding(2) var qkv_with_bias: array<${dataType}>; - + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: biasOffset}, + {type: DataType.uint32, data: hiddenSize} + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('qkv_with_bias', qkv.dataType, outputShape); + const qkvInput = inputVariable('qkv', qkv.dataType, outputShape); + const biasInput = inputVariable('bias', bias.dataType, outputShape); + + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'bias_offset', type: 'u32'}, {name: 'hidden_size', type: 'u32'} + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(qkvInput, biasInput, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - let biasOffsetIdx = (global_idx % hiddenSize) + biasOffset; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let bias_offset_idx = (global_idx % uniforms.hidden_size) + uniforms.bias_offset; - qkv_with_bias[global_idx] = qkv[global_idx] + bias[biasOffsetIdx]; + qkv_with_bias[global_idx] = qkv[global_idx] + bias[bias_offset_idx]; }`; + }; return context.compute( { name: 'MultiHeadAttentionAddBias', - shaderCache: {hint: JSON.stringify({batchSize, sequenceLength, hiddenSize, biasOffset})}, + shaderCache: {inputDependencies: ['type', 'type']}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index 18859e253aa0..d649d3d220ae 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -4,12 +4,11 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformDataElementType, UniformsArrayType} from './common'; -export interface PadAttributes extends AttributeWithCacheKey { +interface PadAttributes { // 0-constant, 1-reflect, 2-edge, 3-wrap readonly mode: number; readonly value: number; @@ -20,8 +19,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length < 1) { throw new Error('Too few inputs'); } - if (inputs[0].dataType !== DataType.float) { - throw new Error('Input type must be float.'); + if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) { + throw new Error('Input type must be float or float16.'); } if (inputs.length >= 2) { @@ -35,27 +34,23 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const getPadConstant = - (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[], - dataType: string, constantValue: number): string => { - const inputRank = inputDims.length; - - let block = ''; - for (let i = inputRank - 1; i >= 0; --i) { - block += ` - k = i32(${output.indicesGet('indices', i)}) - ${pads[i]}; +const getPadConstant = (output: IndicesHelper, inputRank: number, padsLength: number): string => { + let block = ''; + for (let i = inputRank - 1; i >= 0; --i) { + block += ` + k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)}; if (k < 0) { break; } - if (k >= ${inputDims[i]}) { + if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) { break; } - offset += k * ${inputStrides[i]}; + offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)}); `; - } + } - return ` - value = ${dataType}(${constantValue}); + return ` + value = ${output.type.value}(uniforms.constant_value); for (var i = 0; i < 1; i++) { var offset = 0; var k = 0; @@ -63,143 +58,142 @@ const getPadConstant = value = x[offset]; } `; - }; - -const getPadReflect = - (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { - const inputRank = inputDims.length; +}; - let block = ''; - for (let i = inputRank - 1; i >= 0; --i) { - block += ` - k = i32(${output.indicesGet('indices', i)}) - ${pads[i]}; +const getPadReflect = (output: IndicesHelper, inputRank: number, padsLength: number): string => { + let block = ''; + for (let i = inputRank - 1; i >= 0; --i) { + block += ` + k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)}; if (k < 0) { k = -k; } { - let _2n_1 = ${2 * (inputDims[i] - 1)}; + let _2n_1 = 2 * (i32(${getElementAt('uniforms.x_shape', i, inputRank)}) - 1); k = k % _2n_1; - if(k >= ${inputDims[i]}) { + if(k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) { k = _2n_1 - k; } } - offset += k * ${inputStrides[i]}; + offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)}); `; - } + } - return ` + return ` var offset = 0; var k = 0; ${block} value = x[offset]; `; - }; - -const getPadEdge = - (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { - const inputRank = inputDims.length; +}; - let block = ''; - for (let i = inputRank - 1; i >= 0; --i) { - block += ` - k = i32(${output.indicesGet('indices', i)}) - ${pads[i]}; +const getPadEdge = (output: IndicesHelper, inputRank: number, padsLength: number): string => { + let block = ''; + for (let i = inputRank - 1; i >= 0; --i) { + block += ` + k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)}; if (k < 0) { k = 0; } - if (k >= ${inputDims[i]}) { - k = ${inputDims[i] - 1}; + if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) { + k = i32(${getElementAt('uniforms.x_shape', i, inputRank)}) - 1; } - offset += k * ${inputStrides[i]}; + offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)}); `; - } + } - return ` + return ` var offset = 0; var k = 0; ${block} value = x[offset]; `; - }; - -const getPadWrap = - (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { - const inputRank = inputDims.length; +}; - let block = ''; - for (let i = inputRank - 1; i >= 0; --i) { - block += ` - k = i32(${output.indicesGet('indices', i)}) - ${pads[i]}; +const getPadWrap = (output: IndicesHelper, inputRank: number, padsLength: number): string => { + let block = ''; + for (let i = inputRank - 1; i >= 0; --i) { + block += ` + k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)}; if (k < 0) { - k += ${inputDims[i]}; + k += i32(${getElementAt('uniforms.x_shape', i, inputRank)}]); } - if (k >= ${inputDims[i]}) { - k -= ${inputDims[i]}; + if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) { + k -= i32(${getElementAt('uniforms.x_shape', i, inputRank)}); } - offset += k * ${inputStrides[i]}; + offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)}); `; - } + } - return ` + return ` var offset = 0; var k = 0; ${block} value = x[offset]; `; - }; - -const getPadSnippet = - (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], attributes: PadAttributes, - dataType: string): string => { - switch (attributes.mode) { - case 0: - return getPadConstant(output, inputDims, inputStrides, attributes.pads, dataType, attributes.value); - case 1: - return getPadReflect(output, inputDims, inputStrides, attributes.pads); - case 2: - return getPadEdge(output, inputDims, inputStrides, attributes.pads); - case 3: - return getPadWrap(output, inputDims, inputStrides, attributes.pads); - default: - throw new Error('Invalid mode'); - } - }; - -const generatePadCode = - (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: PadAttributes, dataType: string): - string => { - const inputDims = inputs[0].dims; - const outputDims = ShapeUtil.padShape(inputDims.slice(), attributes.pads); - const outputSize = ShapeUtil.size(outputDims); - const inputStrides = ShapeUtil.computeStrides(inputDims); - - const output = outputVariable('output', inputs[0].dataType, outputDims); - const input = inputVariable('x', inputs[0].dataType, inputDims); - - const padSnippet = getPadSnippet(output, inputDims, inputStrides, attributes, dataType); - const padCode = ` - ${shaderHelper.declareVariables(input, output)} - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - - let indices = ${output.offsetToIndices('global_idx')}; - - var value = ${dataType}(0); - ${padSnippet} - output[global_idx] = value; - }`; - return padCode; - }; +}; + +const getPadSnippet = (output: IndicesHelper, inputRank: number, attributes: PadAttributes): string => { + switch (attributes.mode) { + case 0: + return getPadConstant(output, inputRank, attributes.pads.length); + case 1: + return getPadReflect(output, inputRank, attributes.pads.length); + case 2: + return getPadEdge(output, inputRank, attributes.pads.length); + case 3: + return getPadWrap(output, inputRank, attributes.pads.length); + default: + throw new Error('Invalid mode'); + } +}; const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttributes): ProgramInfo => { const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads); + const inputDims = inputs[0].dims; + const outputSize = ShapeUtil.size(outputShape); + const programUniforms: ProgramUniform[] = + [{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: attributes.pads}]; + if (attributes.mode === 0) { + programUniforms.push({type: inputs[0].dataType, data: attributes.value}); + } + + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const input = inputVariable('x', inputs[0].dataType, inputDims.length); + const dataType = input.type.value; + const padSnippet = getPadSnippet(output, inputDims.length, attributes); + const uniforms: UniformsArrayType = + [{name: 'output_size', type: 'u32'}, {name: 'pads', type: 'i32', length: attributes.pads.length}]; + if (attributes.mode === 0) { + uniforms.push({name: 'constant_value', type: dataType as UniformDataElementType}); + } + + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + + let indices = ${output.offsetToIndices('global_idx')}; + + var value = ${dataType}(0); + ${padSnippet} + output[global_idx] = value; + }`; + }; + return { name: 'Pad', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: `${attributes.mode}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}, + programUniforms }), - getShaderSource: shaderHelper => generatePadCode(shaderHelper, inputs, attributes, 'f32'), + getShaderSource, }; }; @@ -223,7 +217,7 @@ const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes const pads: number[] = []; updatePads.forEach(v => pads.push(v)); - return createAttributeWithCacheKey({mode: attributes.mode, value, pads}); + return {mode: attributes.mode, value, pads}; } else { return attributes; } @@ -234,10 +228,3 @@ export const pad = (context: ComputeContext, attributes: PadAttributes): void => const updatedAttributes = createPadAttributesFromInputs(context.inputs, attributes); context.compute(createPadProgramInfo(context.inputs, updatedAttributes), {inputs: [0]}); }; - -export const parsePadAttributes = (attributes: Record): PadAttributes => { - const mode = attributes.mode as number; - const value = attributes.value as number; - const pads = attributes.pads as number[]; - return createAttributeWithCacheKey({mode, value, pads}); -}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 9e9b361c1af1..5521650e8ded 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -3,6 +3,7 @@ import {env} from 'onnxruntime-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -56,7 +57,8 @@ const getUniformAndPadInfo = generatePoolingCode( - shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, -1e5, uniforms, - hasPads, pwStartEndNotZero, phStartEndNotZero), + shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, + (input.dataType === DataType.float16) ? -65504 : -1e5, uniforms, hasPads, pwStartEndNotZero, + phStartEndNotZero), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/range.ts b/js/web/lib/wasm/jsep/webgpu/ops/range.ts index 9cf66111bf70..a21f48ef9ded 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/range.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/range.ts @@ -4,9 +4,9 @@ import {env} from 'onnxruntime-common'; import {DataType} from '../../../wasm-common'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, outputVariable, ShaderHelper, UniformDataElementType, UniformsArrayType} from './common'; const validateInputsContent = (start: number, limit: number, delta: number): void => { const sameStartLimit = start === limit; @@ -22,23 +22,35 @@ const createRangeProgramInfo = (start: number, limit: number, delta: number, dat const numElements = Math.abs(Math.ceil((limit - start) / delta)); const outputShape: number[] = [numElements]; const outputSize = numElements; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: dataType, data: start}, {type: dataType, data: delta}, + ...createTensorShapeVariables(outputShape) + ]; - const output = outputVariable('output', dataType, outputShape); - const wgslType = output.type.storage; - - const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(output)} + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', dataType, outputShape.length); + const wgslType = output.type.value; + const uniforms: UniformsArrayType = [ + {name: 'outputSize', type: 'u32'}, {name: 'start', type: wgslType as UniformDataElementType}, + {name: 'delta', type: wgslType as UniformDataElementType} + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - output[global_idx] = ${wgslType}(${start}) + ${wgslType}(global_idx) * ${wgslType}(${delta}); + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + output[global_idx] = uniforms.start + ${wgslType}(global_idx) * uniforms.delta; }`; + }; + return { name: 'Range', - shaderCache: {hint: [start, limit, delta].map(x => x.toString()).join('_')}, + shaderCache: {hint: `${dataType}`}, getShaderSource, - getRunData: () => ( - {outputs: [{dims: outputShape, dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}}) + getRunData: () => ({ + outputs: [{dims: outputShape, dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms + }) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts index 7c440cbffea7..210b3ee7e2fc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -131,7 +131,7 @@ export const createReduceSharedProgramInfo = const workgroupSize = 32; const sharedMemorySnippet = ` - var aBestValues : array<${output.type.storage}, ${workgroupSize}>; + var aBestValues : array; `; const getShaderSource = (shaderHelper: ShaderHelper) => ` @@ -145,10 +145,10 @@ export const createReduceSharedProgramInfo = let outputIndex = global_idx / ${workgroupSize}; let offset = outputIndex * uniforms.reduceSize; - var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]}); + var bestValue = f32(${reduceInitValues[reduceType]}); let Length = uniforms.reduceSize; for (var k = local_idx; k < Length; k = k + ${workgroupSize}) { - let candidate = ${output.type.storage}(${input.getByOffset('offset + k')}); + let candidate = f32(${input.getByOffset('offset + k')}); bestValue = ${reduceOps[reduceType]}; } aBestValues[local_idx] = bestValue; @@ -172,8 +172,8 @@ export const createReduceSharedProgramInfo = output.setByOffset( 'outputIndex', `${ - reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` : - `${reduceOutputValues[reduceType]}`}`)}; + reduceType === 'mean' ? `${output.type.storage}(bestValue / f32(uniforms.reduceSize))` : + `${output.type.storage}(${reduceOutputValues[reduceType]})`}`)}; } }`; @@ -185,7 +185,7 @@ export const createReduceSharedProgramInfo = getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: outputSize}, - programUniforms: [{type: 'uint32', data: reduceSize}] + programUniforms: [{type: DataType.uint32, data: reduceSize}] }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index e8851ac54694..e8205ba6fd92 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -100,10 +100,8 @@ export const createReduceProgramInfo = getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ - {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape), - ...createTensorShapeVariables(outputShape) - ] + programUniforms: + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)] }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index bea3e8625b41..2c6b537de1f0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -70,7 +71,6 @@ const validateInputs = const rank = inputs[0].dims.length; if (roiInputIndex > 0 && inputs.length > roiInputIndex && inputs[roiInputIndex].dims.length > 0) { inputs[roiInputIndex].getFloat32Array().forEach((value) => roi.push(value)); - } else if (attributes.coordinateTransformMode === 'tf_crop_and_resize') { throw new Error('Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize'); } @@ -110,41 +110,48 @@ const validateInputs = const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode, dType: string): string => - `fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType}, - lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` + + `fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: f32, lengthResized: u32, + lengthOriginal: u32, roiStart: f32, roiEnd: f32) -> ${dType} { ` + (() => { switch (coordinateTransferMode) { case 'asymmetric': - return 'return xResized / xScale;'; + return `return ${dType}(xResized) / ${dType}(xScale);`; case 'pytorch_half_pixel': - return 'if (lengthResized > 1) { \ - return (xResized + 0.5) / xScale - 0.5; \ - } else { \ - return 0.0; \ - }'; + return `if (lengthResized > 1) { + return (${dType}(xResized) + 0.5) / ${dType}(xScale) - 0.5; + } else { + return 0.0; + }`; case 'tf_half_pixel_for_nn': - return 'return (xResized + 0.5) / xScale;'; + return `return (${dType}(xResized) + 0.5) / ${dType}(xScale);`; case 'align_corners': - return 'if (lengthResized == 1) { \ - return 0.0; \ - } else { \ - return xResized * (lengthOriginal - 1) / (lengthResized - 1); \ - }'; + return `if (lengthResized == 1) { + return 0.0; + } else { + // The whole part and the fractional part are calculated separately due to inaccuracy of floating + // point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an + // offset-by-one error later in floor(). + let whole = ${dType}(xResized * (lengthOriginal - 1) / (lengthResized - 1)); + let fract = + ${dType}(xResized * (lengthOriginal - 1) % (lengthResized - 1)) / ${dType}(lengthResized - 1); + return whole + fract; + }`; case 'tf_crop_and_resize': - return `if (lengthResized > 1) { \ - return roiStart * (lengthOriginal - 1) + \ - (xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \ - } else { \ - return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \ + return `if (lengthResized > 1) { + return ${dType}(roiStart) * ${dType}(lengthOriginal - 1) + + (${dType}(xResized) * ${dType}(roiEnd - roiStart) * ${dType}(lengthOriginal - 1)) / + ${dType}(lengthResized - 1); + } else { + return 0.5 * ${dType}(roiStart + roiEnd) * ${dType}(lengthOriginal - 1); }`; case 'half_pixel_symmetric': - return [ - 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;', - 'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);', - 'return offset + ((xResized + 0.5) / xScale) - 0.5;' - ].join('\n'); + return `const outputWidth = ${dType}xScale * ${dType}(lengthResized); + const adjustment = ${dType}(lengthResized) / outputWidth; + const center = ${dType}(lengthOriginal) / 2; + const offset = center * (1 - adjustment); + return offset + ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`; case 'half_pixel': - return 'return ((xResized + 0.5) / xScale) - 0.5;'; + return `return ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`; default: throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); } @@ -254,15 +261,15 @@ const calculateOriginalIndicesFromOutputIndices = output.type.value}, ${outputShape.length}> { var original_indices: array<${output.type.value}, ${outputShape.length}>; for (var i:u32 = 0; i < ${outputShape.length}; i++) { - var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')}); + var output_index = ${output.indicesGet('output_indices', 'i')}; var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)}; var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)}; var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)}; if (scale == 1.0) { - original_indices[i] = output_index; + original_indices[i] = ${output.type.value}(output_index); } else { - var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)}); - var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)}); + var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}; + var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)}; original_indices[i] = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i, input_shape_i, roi_low, roi_hi); } @@ -276,23 +283,23 @@ const calculateInputIndicesFromOutputIndices = fn calculateInputIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { var input_indices: ${input.type.indices}; for (var i:u32 = 0; i < ${outputShape.length}; i++) { - var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')}); + var output_index = ${output.indicesGet('output_indices', 'i')}; var input_index: u32; var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)}; if (scale == 1.0) { - input_index = u32(output_index); + input_index = output_index; } else { var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)}; var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)}; - var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)}); - var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)}); + var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}; + var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)}; var original_idx = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i, input_shape_i, roi_low, roi_hi); - if (!${useExtrapolation} || (original_idx >= 0 && original_idx < input_shape_i)) { + if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${output.type.value}(input_shape_i))) { if (original_idx < 0) { input_index = 0; - } else if (original_idx > (input_shape_i - 1)) { - input_index = u32(input_shape_i) - 1; + } else if (original_idx > ${output.type.value}(input_shape_i - 1)) { + input_index = input_shape_i - 1; } else { input_index = u32(getNearestPixelFromOriginal(original_idx, scale < 1)); } @@ -391,8 +398,8 @@ const bicubicInterpolation = fn ${direction}CubicInterpolation(input_indices: ${input.type.indices}, output_indices: ${ output.type.indices}) -> ${dType} { var output_index = ${output.indicesGet('output_indices', idx)}; - var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(output_index), ${scales[idx]}, - ${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); + var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(output_index, ${scales[idx]}, + ${outputShape[idx]}, ${inputShape[idx]}, ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx); var coefs = getCubicInterpolationCoefs(fractOriginalIdx); @@ -635,11 +642,8 @@ const createResizeProgramInfo = outputs: [{dims: outputShape, dataType: inputTensor.dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, - {type: 'float32', data: scales}, - {type: 'float32', data: roi}, - ...createTensorShapeVariables(inputShape), - ...createTensorShapeVariables(outputShape), + {type: DataType.uint32, data: outputSize}, {type: DataType.float, data: scales}, + {type: DataType.float, data: roi}, ...createTensorShapeVariables(inputShape, outputShape) ] }) }; @@ -656,6 +660,10 @@ export const resize = (context: ComputeContext, attributes: ResizeAttributes): v const scales: number[] = []; const sizes: number[] = []; const roi: number[] = []; + + // Note that scales in resize are always f32. roi can be f32 or f16. + // TODO: Currently this code does not support f16 for roi when passed as optional input. + const opsetVersion = getOpsetVersionFromCustomDataBuffer(context); if (attributes.antialias !== 0) { throw Error('Only default value (0) for Antialias attribute is supported'); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts b/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts new file mode 100644 index 000000000000..a58087072e4c --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, WORKGROUP_SIZE} from './common'; + +export interface RotaryEmbeddingAttributes { + readonly interleaved: boolean; + readonly numHeads: number; + readonly rotaryEmbeddingDim: number; + readonly scale: number; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddingAttributes): void => { + const [input, positionIds, cosCache, sinCache] = inputs; + const {numHeads, rotaryEmbeddingDim} = attributes; + + if (input.dims.length !== 3 && input.dims.length !== 4) { + throw new Error(`Input 'x' is expected to have 3 or 4 dimensions, got ${input.dims.length}`); + } + if (!ShapeUtil.areEqual(positionIds.dims, []) && !ShapeUtil.areEqual(positionIds.dims, [1]) && + positionIds.dims.length !== 2) { + throw new Error(`Input 'position_ids' is expected to have 0, 1, or 2 dimensions, got ${positionIds.dims.length}`); + } + if (cosCache.dims.length !== 2) { + throw new Error(`Input 'cos_cache' is expected to have 2 dimensions, got ${cosCache.dims.length}`); + } + if (sinCache.dims.length !== 2) { + throw new Error(`Input 'sin_cache' is expected to have 2 dimensions, got ${sinCache.dims.length}`); + } + if (!ShapeUtil.areEqual(cosCache.dims, sinCache.dims)) { + throw new Error('Inputs \'cos_cache\' and \'sin_cache\' are expected to have the same shape'); + } + + if (rotaryEmbeddingDim > 0 && numHeads === 0) { + throw new Error('num_heads must be provided if rotary_embedding_dim is specified'); + } + + const batchSize = input.dims[0]; + const sequenceLength = input.dims[input.dims.length - 2]; + const maxSequenceLength = cosCache.dims[0]; + const hiddenSize = ShapeUtil.sizeFromDimension(input.dims, 1) / sequenceLength; + const headSize = rotaryEmbeddingDim === 0 ? cosCache.dims[1] * 2 : hiddenSize / numHeads; + if (rotaryEmbeddingDim > headSize) { + throw new Error('rotary_embedding_dim must be less than or equal to head_size'); + } + + if (positionIds.dims.length === 2) { + if (batchSize !== positionIds.dims[0]) { + throw new Error(`Input 'position_ids' dimension 0 should be of size batch_size, got ${positionIds.dims[0]}`); + } + if (sequenceLength !== positionIds.dims[1]) { + throw new Error(`Input 'position_ids' dimension 1 should be of size sequence_length, got ${positionIds.dims[1]}`); + } + } + + if (headSize / 2 !== cosCache.dims[1] && rotaryEmbeddingDim / 2 !== cosCache.dims[1]) { + throw new Error(`Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got ${ + cosCache.dims[1]}`); + } + + if (sequenceLength > maxSequenceLength) { + throw new Error('Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported'); + } +}; + +const createRotaryEmbeddingProgramInfo = + (inputs: readonly TensorView[], attributes: RotaryEmbeddingAttributes): ProgramInfo => { + const {interleaved, numHeads, rotaryEmbeddingDim, scale} = attributes; + const batchSize = inputs[0].dims[0]; + const batchStride = ShapeUtil.sizeFromDimension(inputs[0].dims, 1); + const sequenceLength = inputs[0].dims[inputs[0].dims.length - 2]; + const hiddenSize = batchStride / sequenceLength; + const halfRotaryEmbeddingDim = inputs[2].dims[1]; + const headSize = rotaryEmbeddingDim === 0 ? halfRotaryEmbeddingDim * 2 : hiddenSize / numHeads; + + // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape + // [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy] + // to unfold the global index in shader. + const globalShape = + new Array(batchSize, sequenceLength, hiddenSize / headSize, headSize - halfRotaryEmbeddingDim); + const globalStrides = ShapeUtil.computeStrides(globalShape); + + const programUniforms: ProgramUniform[] = [ + {type: DataType.float, data: scale}, + {type: DataType.uint32, data: globalShape}, + {type: DataType.uint32, data: globalStrides}, + + // strides for addressing the input/output tensor, in permutated order to align with the unfolded global index, + // i.e. BSNH + ...(inputs[0].dims.length === 3 ? + new Array({type: DataType.uint32, data: [batchStride, hiddenSize, headSize, 1]}) : + []), + ...(inputs[0].dims.length === 4 ? + new Array( + {type: DataType.uint32, data: [batchStride, headSize, sequenceLength * headSize, 1]}) : + []), + + ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, inputs[2].dims, inputs[3].dims, inputs[0].dims), + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length); + const positionIds = inputVariable('position_ids', inputs[1].dataType, inputs[1].dims.length); + const cosCache = inputVariable('cos_cache', inputs[2].dataType, inputs[2].dims.length); + const sinCache = inputVariable('sin_cache', inputs[3].dataType, inputs[3].dims.length); + const output = outputVariable('output', inputs[0].dataType, inputs[0].dims.length); + + shaderHelper.registerUniforms([ + {name: 'scale', type: 'f32'}, + {name: 'global_shape', type: 'u32', length: globalShape.length}, + {name: 'global_strides', type: 'u32', length: globalStrides.length}, + {name: 'input_output_strides', type: 'u32', length: globalStrides.length}, + ]); + + return ` + ${shaderHelper.declareVariables(input, positionIds, cosCache, sinCache, output)} + + ${shaderHelper.mainStart(WORKGROUP_SIZE)} + let half_rotary_emb_dim = uniforms.${cosCache.name}_shape[1]; + let bsnh = global_idx / uniforms.global_strides % uniforms.global_shape; + let size = uniforms.global_shape[0] * uniforms.global_strides[0]; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('size')} + + if (bsnh[3] < half_rotary_emb_dim) { + let position_ids_idx = + ${positionIds.broadcastedIndicesToOffset('bsnh.xy', outputVariable('', positionIds.type.tensor, 2))}; + let position_id = + u32(${positionIds.getByOffset('position_ids_idx')}) + select(0, bsnh[1], position_ids_idx == 0); + let i = dot(bsnh, uniforms.input_output_strides) + select(0, bsnh[3], ${interleaved}); + let j = i + select(half_rotary_emb_dim, 1, ${interleaved}); + let re = ${input.getByOffset('i')} * ${cosCache.get('position_id', 'bsnh[3]')} - + ${input.getByOffset('j')} * ${sinCache.get('position_id', 'bsnh[3]')}; + ${output.setByOffset('i', 're')} + let im = ${input.getByOffset('i')} * ${sinCache.get('position_id', 'bsnh[3]')} + + ${input.getByOffset('j')} * ${cosCache.get('position_id', 'bsnh[3]')}; + ${output.setByOffset('j', 'im')} + } else { + let k = dot(bsnh, uniforms.input_output_strides) + half_rotary_emb_dim; + ${output.setByOffset('k', input.getByOffset('k'))} + } + }`; + }; + + return { + name: 'RotaryEmbedding', + shaderCache: { + hint: createAttributeWithCacheKey({ + interleaved, + }).cacheKey, + inputDependencies: ['rank', 'rank', 'rank', 'rank'], + }, + getShaderSource, + getRunData: () => ({ + outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(ShapeUtil.size(globalShape) / WORKGROUP_SIZE)}, + programUniforms, + }), + }; + }; + +export const rotaryEmbedding = (context: ComputeContext, attributes: RotaryEmbeddingAttributes): void => { + validateInputs(context.inputs, attributes); + context.compute(createRotaryEmbeddingProgramInfo(context.inputs, attributes)); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 7e500f865c19..e7dc34d2fc75 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -4,12 +4,12 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; -export interface SkipLayerNormAttributes extends AttributeWithCacheKey { +export interface SkipLayerNormAttributes { + simplified: boolean; epsilon: number; } @@ -73,73 +73,89 @@ const validateInputs = (inputs: readonly TensorView[]): void => { const createSkipLayerNormProgramInfo = (inputs: readonly TensorView[], attributes: SkipLayerNormAttributes, outputCount: number, isTraining: boolean): ProgramInfo => { + const simplified = attributes.simplified; + const inputShape = inputs[0].dims; const inputSize = ShapeUtil.size(inputShape); const outputShape = inputShape; const outputSize = inputSize; const hiddenSize = inputShape.slice(-1)[0]; const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : []; - const hasBetaInput = inputs.length > 3; + const hasBetaInput = !simplified && inputs.length > 3; const hasBiasInput = inputs.length > 4; const hasMeanOutput = isTraining && outputCount > 1; const hasInvStdDevOutput = isTraining && outputCount > 2; const hasInputSkipBiasSumOutput = outputCount > 3; const components = getMaxComponents(hiddenSize); - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), - inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), - ]; - if (hasBetaInput) { - variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); - } - if (hasBiasInput) { - variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); - } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); - if (hasMeanOutput) { - variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInvStdDevOutput) { - variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInputSkipBiasSumOutput) { - variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); - } - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const hiddenSize: f32 = ${hiddenSize}; - const hiddenSizeVectorized: u32 = ${hiddenSize / components}; - const epsilon: f32 = ${attributes.epsilon}; - ${shaderHelper.declareVariables(...variables)} + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, + {type: DataType.uint32, data: components}, + {type: DataType.uint32, data: hiddenSize}, + {type: DataType.float, data: attributes.epsilon}, + ]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniformsArray: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, + {name: 'components', type: 'u32'}, + {name: 'hidden_size', type: 'u32'}, + {name: 'epsilon', type: 'f32'}, + ]; + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), + inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), + ]; + if (hasBetaInput) { + variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); + } + if (hasBiasInput) { + variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanOutput) { + variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdDevOutput) { + variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); + } + if (hasInputSkipBiasSumOutput) { + variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components)); + } + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return ` + + ${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)} - let offset = global_idx * hiddenSizeVectorized; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size / uniforms.hidden_size')} + let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components; + let offset = global_idx * hidden_size_vectorized; var sum = ${fillVector('f32', components)}; var squareSum = ${fillVector('f32', components)}; - for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { - let skipValue = skip[offset + i]; - let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'}; - let inputValue = x[offset + i]; - let value = inputValue + skipValue + biasValue; - ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''} + for (var i: u32 = 0; i < hidden_size_vectorized; i++) { + let skip_value = skip[offset + i]; + let bias_value = ${hasBiasInput ? 'bias[i]' : '0.0'}; + let input_value = x[offset + i]; + let value = input_value + skip_value + bias_value; + ${hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : ''} output[offset + i] = value; - let f32Value = ${castToF32(dataType, components, 'value')}; - sum += f32Value; - squareSum += f32Value * f32Value; + let f32_value = ${castToF32(dataType, components, 'value')}; + sum += f32_value; + squareSum += f32_value * f32_value; } - let mean = ${sumVector('sum', components)} / hiddenSize; - let variance = sqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon); - ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''} - ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = 1.0 / variance;' : ''} - for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { - output[offset + i] = (output[offset + i] - ${dataType}(mean)) / ${dataType}(variance) * gamma[i] - + ${hasBetaInput ? 'beta[i]' : '0.0'}; + let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size); + let inv_std_dev = inverseSqrt(${sumVector('squareSum', components)} / f32(uniforms.hidden_size) ${ + simplified ? '' : '- mean * mean'} + uniforms.epsilon); + ${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''} + ${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''} + for (var i: u32 = 0; i < hidden_size_vectorized; i++) { + output[offset + i] = (output[offset + i] ${simplified ? '' : `- ${dataType}(mean)`}) * ${ + dataType}(inv_std_dev) * gamma[i] ${hasBetaInput ? '+ beta[i]' : ''}; } }`; + }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; if (outputCount > 1) { outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); @@ -150,12 +166,14 @@ const createSkipLayerNormProgramInfo = if (outputCount > 3) { outputs.push({dims: inputShape, dataType: inputs[0].dataType}); } - return { name: 'SkipLayerNormalization', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: { + hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`, + inputDependencies: inputs.map((_input, _index) => 'type') + }, getShaderSource, - getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}), + getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}), }; }; @@ -178,8 +196,3 @@ export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNorm context.compute( createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {outputs}); }; - -export const parseSkipLayerNormAttributes = (attributes: Record): SkipLayerNormAttributes => { - const epsilon = attributes.epsilon as number; - return createAttributeWithCacheKey({epsilon}); -}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 5212c6475dce..a5e71f30e596 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -155,9 +155,9 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice ]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: starts}, {type: 'int32', data: signs}, - {type: 'uint32', data: steps}, ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: starts}, + {type: DataType.int32, data: signs}, {type: DataType.uint32, data: steps}, + ...createTensorShapeVariables(inputs[0].dims, outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 324dc3af1a71..b0e3ddd14965 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -5,6 +5,7 @@ // performance limitations when the reduced axis is long. Need to add // a optimized codepath for this. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -136,7 +137,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut getRunData: () => ({ outputs: [{dims: shape, dataType: input.dataType}], dispatchGroup: {x: rows}, - programUniforms: [{type: 'uint32', data: packedCols}] + programUniforms: [{type: DataType.int32, data: packedCols}] }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index b8582614fa21..a09ac78b1700 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -67,24 +68,23 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const dataType = inputs[0].dataType; const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); const outputs = new Array(attributes.numOutputs); - const input = inputVariable('input', dataType, inputShape); + const input = inputVariable('input', dataType, inputShape.length); const sizeInSplitAxis = new Array(attributes.numOutputs); const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; let previousSum = 0; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: inputSize}]; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: inputSize}]; for (let i = 0; i < attributes.numOutputs; i++) { previousSum += attributes.splitSizes[i]; sizeInSplitAxis[i] = previousSum; const outputShape = inputShape.slice(); outputShape[attributes.axis] = attributes.splitSizes[i]; outputShapes.push(outputShape); - outputs[i] = outputVariable(`output${i}`, dataType, outputShape); + outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } - programUniforms.push({type: 'uint32', data: sizeInSplitAxis}); - programUniforms.push(...createTensorShapeVariables(inputShape)); - outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape))); + programUniforms.push( + {type: DataType.uint32, data: sizeInSplitAxis}, ...createTensorShapeVariables(inputShape, ...outputShapes)); const getShaderSource = (shaderHelper: ShaderHelper) => ` ${ shaderHelper.registerUniform('input_size', 'u32') diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index 90a36a7bec2a..f9728575fe07 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -79,10 +79,8 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ - {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape) - ], + programUniforms: + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)], }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index c4d43e9f466f..7ae801222b87 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -1,12 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface TransposeAttributes extends AttributeWithCacheKey { readonly perm: number[]; @@ -39,12 +40,9 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu const inputDataType = inputTensor.dataType; const inputRank = inputTensor.dims.length; const perm = getAdjustedPerm(inputRank, permAttr); - const useShapesUniforms = enableShapesUniforms(inputRank); const outputShape = getOutputShape(inputTensor.dims, perm); - const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape; - const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims; - const output = outputVariable('output', inputDataType, outShapeOrRank); - const input = inputVariable('a', inputDataType, inShapeOrRank); + const output = outputVariable('output', inputDataType, outputShape.length); + const input = inputVariable('a', inputDataType, inputRank); const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} @@ -61,21 +59,14 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu }`; return { name: 'Transpose', - shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']}, + shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']}, getRunData: (inputs) => { const outputSize = ShapeUtil.size(outputShape); return { outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: useShapesUniforms ? - [ - {type: 'uint32', data: outputSize}, - ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape), - ] : - [ - {type: 'uint32', data: outputSize}, - ], + programUniforms: + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)], }; }, getShaderSource, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index a25e7fe4229b..5f105c745739 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -53,7 +53,7 @@ const createElementwiseProgramInfo = dispatchGroup: {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}, programUniforms: [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, + {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, ], }) }); @@ -178,7 +178,7 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void attributes.cacheKey)); }; -export const erfImpl = (dataType: string, varType = 'f32') => ` +export const erfImpl = (varType = 'f32') => ` const r0: ${varType} = 0.3275911; const r1: ${varType} = 0.254829592; const r2: ${varType} = -0.284496736; @@ -186,7 +186,7 @@ const r3: ${varType} = 1.421413741; const r4: ${varType} = -1.453152027; const r5: ${varType} = 1.061405429; -fn erf_vf32(v: ${dataType}) -> ${dataType} { +fn erf_vf32(v: vec4<${varType}>) -> vec4<${varType}> { let absv = abs(v); let x = 1.0 / (1.0 + r0 * absv); return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); @@ -194,8 +194,7 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} { export const erf = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType))); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(dataType))); }; export const exp = (context: ComputeContext): void => { @@ -209,8 +208,7 @@ export const floor = (context: ComputeContext): void => { export const gelu = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, - erfImpl(`vec4<${dataType}>`, dataType))); + context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(dataType))); }; export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => { @@ -242,6 +240,26 @@ export const sigmoid = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`)); }; +export interface HardSigmoidAttributes extends AttributeWithCacheKey { + readonly alpha: number; + readonly beta: number; +} + +export const parseHardSigmoidAttributes = (attributes: Record): HardSigmoidAttributes => + createAttributeWithCacheKey(attributes as { + alpha: number; + beta: number; + }); + +export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'HardSigmoid', + a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${ + attributes.beta})))`, + undefined, attributes.cacheKey)); +}; + export const sin = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin')); }; @@ -258,8 +276,31 @@ export const tan = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tan', 'tan')); }; +export const tanhExpression = (a: string) => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`; + export const tanh = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', 'tanh')); + // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', tanhExpression)); +}; + +export const fastGeluImpl = (varType = 'f32') => ` +const fast_gelu_a: ${varType} = 0.5; +const fast_gelu_b: ${varType} = 0.7978845608028654; +const fast_gelu_c: ${varType} = 0.035677408136300125; + +fn tanh_v(v: vec4<${varType}>) -> vec4<${varType}> { + return ${tanhExpression('v')}; +} +`; + +export const fastGeluExpression = (x: string) => + `(fast_gelu_a + fast_gelu_a * tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`; + +export const fastGelu = (context: ComputeContext): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'FastGelu', fastGeluExpression, fastGeluImpl(dataType), undefined, + context.inputs[0].dataType)); }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 687ee054096c..a6375847fc42 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -27,7 +27,7 @@ const createWhereOpProgramShader = const expressionA = `a_data[index_a${x}][component_a${x}]`; const expressionB = `b_data[index_b${x}][component_b${x}]`; // eslint-disable-next-line no-bitwise - const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`; return ` let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)}; @@ -38,6 +38,7 @@ const createWhereOpProgramShader = let index_c${x} = offset_c${x} / 4u; let component_a${x} = offset_a${x} % 4u; let component_b${x} = offset_b${x} % 4u; + let component_c${x} = offset_c${x} % 4u; ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); `; }; @@ -76,7 +77,6 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC)); let outputShape = dimsA; let outputSize = ShapeUtil.size(dimsA); - const vecSize = Math.ceil(outputSize / 4); // TODO: deal with zero-sized tensors (eg. dims=[1,0]) if (isBroadcast) { @@ -88,6 +88,8 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => outputSize = ShapeUtil.size(outputShape); } + const vecSize = Math.ceil(outputSize / 4); + return { name: 'Where', shaderCache: {inputDependencies: ['rank', 'rank', 'rank']}, @@ -96,10 +98,8 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}, - programUniforms: [ - {type: 'uint32', data: vecSize}, ...createTensorShapeVariables(dimsC), ...createTensorShapeVariables(dimsA), - ...createTensorShapeVariables(dimsB), ...createTensorShapeVariables(outputShape) - ], + programUniforms: + [{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape)], }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index ae5bf68483b4..ccbcbe48505d 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {tensorDataTypeEnumToString} from '../../wasm-common'; +import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; + import {WebGpuBackend} from '../backend-webgpu'; import {LOG_DEBUG} from '../log'; -import {TensorView} from '../tensor-view'; import {createShaderHelper} from './ops/common'; import {Artifact, GpuData, ProgramInfo} from './types'; @@ -32,13 +32,12 @@ export class ProgramManager { setArtifact(key: unknown, artifact: Artifact): void { this.repo.set(key, artifact); } - run(buildArtifact: Artifact, inputTensorViews: readonly TensorView[], outputTensorViews: readonly TensorView[], - inputs: GpuData[], outputs: GpuData[], dispatchGroup: [number, number, number], + run(buildArtifact: Artifact, inputs: GpuData[], outputs: GpuData[], dispatchGroup: [number, number, number], uniformBufferBinding: GPUBindingResource|undefined): void { + TRACE_FUNC_BEGIN(buildArtifact.programInfo.name); const device = this.backend.device; - const computePassEncoder = this.backend.getComputePassEncoder(); - computePassEncoder.setPipeline(buildArtifact.computePipeline); + this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); const entries = []; for (const input of inputs) { entries.push({binding: entries.length, resource: {buffer: input.buffer}}); @@ -51,94 +50,44 @@ export class ProgramManager { } const bindGroup = device.createBindGroup( {layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name}); - computePassEncoder.setBindGroup(0, bindGroup); - computePassEncoder.dispatchWorkgroups(...dispatchGroup); + if (this.backend.sessionStatus === 'capturing') { + const commandInfo = { + kernelId: this.backend.currentKernelId!, + computePipeline: buildArtifact.computePipeline, + bindGroup, + dispatchGroup + }; + const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); + sessionCommandList!.push(commandInfo); + } + computePassEncoder.setPipeline(buildArtifact.computePipeline); + computePassEncoder.setBindGroup(0, bindGroup); + computePassEncoder.dispatchWorkgroups(...dispatchGroup); + this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++; - if (this.backend.isQueryEnabled()) { - if (typeof this.backend.queryData === 'undefined') { - this.backend.queryData = this.backend.gpuDataManager.create( - // eslint-disable-next-line no-bitwise - this.backend.querySetCount * 8, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE); - } - const syncData = this.backend.gpuDataManager.create( - // eslint-disable-next-line no-bitwise - this.backend.querySetCount * 8, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); - + if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber || + this.backend.queryType === 'at-passes') { this.backend.endComputePass(); - this.backend.getCommandEncoder().resolveQuerySet(this.backend.querySet!, 0, 2, this.backend.queryData.buffer, 0); - this.backend.getCommandEncoder().copyBufferToBuffer( - this.backend.queryData.buffer, 0, syncData.buffer, 0, this.backend.querySetCount * 8); - this.backend.flush(); - - const kernelId = this.backend.currentKernelId!; - const kernelInfo = this.backend.kernels.get(kernelId)!; - - void syncData.buffer.mapAsync(GPUMapMode.READ).then(() => { - const mappedData = new BigUint64Array(syncData.buffer.getMappedRange()); - const [startTimeU64, endTimeU64] = mappedData; - const [kernelType, kernelName] = kernelInfo; - - syncData.buffer.unmap(); - - if (typeof this.backend.queryTimeBase === 'undefined') { - this.backend.queryTimeBase = startTimeU64; - } - - const startTime = Number(startTimeU64 - this.backend.queryTimeBase); - const endTime = Number(endTimeU64 - this.backend.queryTimeBase); - - if (!Number.isSafeInteger(startTime) || !Number.isSafeInteger(endTime)) { - throw new RangeError('incorrect timestamp range'); - } - - this.backend.gpuDataManager.release(syncData.id); - if (this.backend.env.webgpu.profiling?.ondata) { - this.backend.env.webgpu.profiling.ondata({ - version: 1, - inputsMetadata: inputTensorViews.map( - value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})), - outputsMetadata: outputTensorViews.map( - value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})), - kernelId, - kernelType, - kernelName, - startTime, - endTime, - }); - } else { - // if no callback is provided, print the profiling message to console - let inputShapes = ''; - inputTensorViews.forEach((value, i) => { - inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; - }); - let outputShapes = ''; - outputTensorViews.forEach((value, i) => { - outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; - }); - // eslint-disable-next-line no-console - console.log(`[profiling] kernel "${kernelId}|${kernelName}|${buildArtifact.programInfo.name}" ${inputShapes}${ - outputShapes}execution time: ${endTime - startTime} ns`); - } - }); } - - if (this.backend.pendingDispatchNumber >= 16) { + if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber) { this.backend.flush(); } + TRACE_FUNC_END(buildArtifact.programInfo.name); } dispose(): void { // this.repo.forEach(a => this.glContext.deleteProgram(a.program)); } build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact { + TRACE_FUNC_BEGIN(programInfo.name); const device = this.backend.device; const extensions: string[] = []; if (device.features.has('shader-f16')) { extensions.push('enable f16;'); } - const shaderHelper = createShaderHelper(normalizedDispatchGroupSize); + const shaderHelper = createShaderHelper(normalizedDispatchGroupSize, this.backend.device.limits); const userCode = programInfo.getShaderSource(shaderHelper); const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`; const shaderModule = device.createShaderModule({code, label: programInfo.name}); @@ -147,7 +96,8 @@ export class ProgramManager { const computePipeline = device.createComputePipeline( {compute: {module: shaderModule, entryPoint: 'main'}, layout: 'auto', label: programInfo.name}); - return {programInfo, computePipeline}; + TRACE_FUNC_END(programInfo.name); + return {programInfo, computePipeline, uniformVariablesInfo: shaderHelper.variablesInfo}; } normalizeDispatchGroupSize(dispatchGroup: ReturnType['dispatchGroup']): diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 23fa33a9bba8..2a584fc0a221 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -1,10 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../wasm-common'; import {TensorView} from '../tensor-view'; import {ShaderHelper} from './ops/common'; +export type SessionState = 'default'|'capturing'|'replaying'; + export enum GpuDataType { default = 0, upload = 1, @@ -12,6 +15,13 @@ export enum GpuDataType { } export type GpuDataId = number; +export type GpuArchitecture = 'ampere'; +export type GpuVendor = 'amd'|'intel'|'nvidia'; +export interface AdapterInfo { + isArchitecture: (architecture: GpuArchitecture) => boolean; + isVendor: (vendor: GpuVendor) => boolean; +} + export interface GpuData { type: GpuDataType; id: GpuDataId; @@ -23,12 +33,13 @@ export interface TensorInfo { dataType: number; } - export interface ProgramUniform { - type: 'int32'|'float32'|'uint32'; + type: DataType; data: number|readonly number[]; } +export type ProgramUniformVariableInfo = [type: DataType, length: number]; + /** * Represent the dependency of a program on a specific input tensor. * @@ -116,6 +127,7 @@ export interface ProgramInfo { export interface Artifact { programInfo: ProgramInfo; computePipeline: GPUComputePipeline; + uniformVariablesInfo: readonly ProgramUniformVariableInfo[]|undefined; } export interface ComputeContextInputsOutputsMapping { @@ -144,6 +156,11 @@ export interface ComputeContextInputsOutputsMapping { * A ComputeContext instance carries the states that representing the current running of a kernel. */ export interface ComputeContext { + /** + * gpu adapter info + */ + readonly adapterInfo: AdapterInfo; + /** * stores the pointer to OpKernelContext */ @@ -171,4 +188,8 @@ export interface ComputeContext { compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[]; output(index: number, dims: readonly number[]): number; + getMaxComputeWorkgroupSizes(): [number, number, number]; + getMaxComputeWorkgroupStoragesize(): number; } + +export type TimestampQuery = 'none'|'inside-passes'|'at-passes'; diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index 4df524cdcfb2..3ce37a2d6b65 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -79,8 +79,14 @@ self.onmessage = (ev: MessageEvent): void => { } case 'create': { const {model, options} = message!; - const sessionMetadata = createSession(model, options); - postMessage({type, out: sessionMetadata} as OrtWasmMessage); + createSession(model, options) + .then( + sessionMetadata => { + postMessage({type, out: sessionMetadata} as OrtWasmMessage); + }, + err => { + postMessage({type, err}); + }); break; } case 'release': @@ -97,7 +103,7 @@ self.onmessage = (ev: MessageEvent): void => { } else { postMessage( {type, out: outputs} as OrtWasmMessage, - extractTransferableBuffers(outputs as SerializableTensorMetadata[])); + extractTransferableBuffers([...inputs, ...outputs] as SerializableTensorMetadata[])); } }, err => { diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 86017a4ec690..6ff4e86b1235 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -155,7 +155,7 @@ export const createSession = ensureWorker(); return new Promise((resolve, reject) => { enqueueCallbacks('create', [resolve, reject]); - const message: OrtWasmMessage = {type: 'create', in : {model, options}}; + const message: OrtWasmMessage = {type: 'create', in : {model, options: {...options}}}; const transferable: Transferable[] = []; if (model instanceof Uint8Array) { transferable.push(model.buffer); diff --git a/js/web/lib/wasm/session-handler-inference.ts b/js/web/lib/wasm/session-handler-inference.ts index b62287483208..2bece248669f 100644 --- a/js/web/lib/wasm/session-handler-inference.ts +++ b/js/web/lib/wasm/session-handler-inference.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {readFile} from 'node:fs/promises'; -import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; +import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages'; import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper'; import {isGpuBufferSupportedType} from './wasm-common'; +import {loadFile} from './wasm-utils-load-file'; export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { switch (tensor.location) { @@ -43,23 +43,18 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan outputNames: string[]; async fetchModelAndCopyToWasmMemory(path: string): Promise { - // fetch model from url and move to wasm heap. The arraybufffer that held the http - // response is freed once we return - const response = await fetch(path); - if (response.status !== 200) { - throw new Error(`failed to load model: ${path}`); - } - const arrayBuffer = await response.arrayBuffer(); - return copyFromExternalBuffer(new Uint8Array(arrayBuffer)); + // fetch model from url and move to wasm heap. + return copyFromExternalBuffer(await loadFile(path)); } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { + TRACE_FUNC_BEGIN(); let model: Parameters[0]; if (typeof pathOrBuffer === 'string') { if (typeof process !== 'undefined' && process.versions && process.versions.node) { // node - model = await readFile(pathOrBuffer); + model = await loadFile(pathOrBuffer); } else { // browser // fetch model and copy to wasm heap. @@ -70,6 +65,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan } [this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options); + TRACE_FUNC_END(); } async dispose(): Promise { @@ -78,6 +74,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise { + TRACE_FUNC_BEGIN(); const inputArray: Tensor[] = []; const inputIndices: number[] = []; Object.entries(feeds).forEach(kvp => { @@ -115,6 +112,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan for (let i = 0; i < results.length; i++) { resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); } + TRACE_FUNC_END(); return resultMap; } diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 45ea48a2df20..48eac5749472 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -60,9 +60,6 @@ const setExecutionProviders = // check EP name switch (epName) { - case 'xnnpack': - epName = 'XNNPACK'; - break; case 'webnn': epName = 'WEBNN'; if (typeof ep !== 'string') { @@ -171,6 +168,18 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); } + if (sessionOptions.enableGraphCapture !== undefined) { + if (typeof sessionOptions.enableGraphCapture !== 'boolean') { + throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`); + } + const keyDataOffset = allocWasmString('enableGraphCapture', allocs); + const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs); + if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError( + `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`); + } + } + if (sessionOptions.freeDimensionOverrides) { for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) { if (typeof name !== 'string') { diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index b9eff45e890c..54eaf5e0c43c 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -3,6 +3,12 @@ import {Tensor} from 'onnxruntime-common'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + // This file includes common definitions. They do NOT have dependency on the WebAssembly instance. /** @@ -117,7 +123,8 @@ export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32Arr Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => { switch (type) { case 'float16': - return Uint16Array; + // allow Float16Array polyfill. + return typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : Uint16Array; case 'float32': return Float32Array; case 'uint8': @@ -169,7 +176,8 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro * Check whether the given tensor type is supported by GPU buffer */ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || - type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32'; + type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint8' || + type === 'bool'; /** * Map string data location to integer value diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index a9dfd9218bb6..9b27051f1b9f 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -9,6 +9,7 @@ import {setSessionOptions} from './session-options'; import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; +import {loadFile} from './wasm-utils-load-file'; // #region Initializations @@ -83,27 +84,57 @@ export const initRuntime = async(env: Env): Promise => { * @param epName */ export const initEp = async(env: Env, epName: string): Promise => { - if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') { - // perform WebGPU availability check - if (typeof navigator === 'undefined' || !navigator.gpu) { - throw new Error('WebGPU is not supported in current environment'); - } - const adapter = await navigator.gpu.requestAdapter(); - if (!adapter) { - throw new Error( - 'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); - } + if (!BUILD_DEFS.DISABLE_WEBGPU) { + // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires + const initJsep = require('./jsep/init').init; - if (!env.wasm.simd) { - throw new Error( - 'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP'); - } + if (epName === 'webgpu') { + // perform WebGPU availability check + if (typeof navigator === 'undefined' || !navigator.gpu) { + throw new Error('WebGPU is not supported in current environment'); + } - // init JSEP if available + let adapter = env.webgpu.adapter as GPUAdapter | null; + if (!adapter) { + // if adapter is not set, request a new adapter. + const powerPreference = env.webgpu.powerPreference; + if (powerPreference !== undefined && powerPreference !== 'low-power' && + powerPreference !== 'high-performance') { + throw new Error(`Invalid powerPreference setting: "${powerPreference}"`); + } + const forceFallbackAdapter = env.webgpu.forceFallbackAdapter; + if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') { + throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`); + } + adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter}); + if (!adapter) { + throw new Error( + 'Failed to get GPU adapter. ' + + 'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); + } + } else { + // if adapter is set, validate it. + if (typeof adapter.limits !== 'object' || typeof adapter.features !== 'object' || + typeof adapter.requestDevice !== 'function') { + throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.'); + } + } - // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires - const initJsep = require('./jsep/init').init; - await initJsep(getInstance(), env, adapter); + if (!env.wasm.simd) { + throw new Error( + 'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP'); + } + + await initJsep('webgpu', getInstance(), env, adapter); + } + if (epName === 'webnn') { + // perform WebNN availability check + if (typeof navigator === 'undefined' || !(navigator as unknown as {ml: unknown}).ml) { + throw new Error('WebNN is not supported in current environment'); + } + + await initJsep('webnn', getInstance(), env); + } } }; @@ -138,7 +169,7 @@ type IOBindingState = { */ type SessionMetadata = [ inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], - bindingState: IOBindingState|null + bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBound: boolean ]; const activeSessions = new Map(); @@ -187,108 +218,136 @@ export const copyFromExternalBuffer = (model: Uint8Array): [number, number] => { * @param options an optional session options object. * @returns a 3-elements tuple containing [session handle, input names, output names] */ -export const createSession = - (modelData: Uint8Array|SerializableInternalBuffer, - options?: InferenceSession.SessionOptions): SerializableSessionMetadata => { - let modelDataOffset: number, modelDataLength: number; - const wasm = getInstance(); +export const createSession = async( + modelData: Uint8Array|SerializableInternalBuffer, + options?: InferenceSession.SessionOptions): Promise => { + let modelDataOffset: number, modelDataLength: number; + const wasm = getInstance(); - if (Array.isArray(modelData)) { - // if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data - [modelDataOffset, modelDataLength] = modelData; - } else if (modelData.buffer === wasm.HEAPU8.buffer) { - // if model data uses the same buffer as the WASM heap, we don't need to copy it. - [modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength]; - } else { - // otherwise, copy the model data to the WASM heap. - [modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData); + if (Array.isArray(modelData)) { + // if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data + [modelDataOffset, modelDataLength] = modelData; + } else if (modelData.buffer === wasm.HEAPU8.buffer) { + // if model data uses the same buffer as the WASM heap, we don't need to copy it. + [modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength]; + } else { + // otherwise, copy the model data to the WASM heap. + [modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData); + } + + let sessionHandle = 0; + let sessionOptionsHandle = 0; + let ioBindingHandle = 0; + let allocs: number[] = []; + const inputNamesUTF8Encoded = []; + const outputNamesUTF8Encoded = []; + + try { + [sessionOptionsHandle, allocs] = setSessionOptions(options); + + if (options?.externalData && wasm.mountExternalData) { + const loadingPromises = []; + for (const file of options.externalData) { + const path = typeof file === 'string' ? file : file.path; + loadingPromises.push(loadFile(typeof file === 'string' ? file : file.data).then(data => { + wasm.mountExternalData!(path, data); + })); } - let sessionHandle = 0; - let sessionOptionsHandle = 0; - let ioBindingHandle = 0; - let allocs: number[] = []; - const inputNamesUTF8Encoded = []; - const outputNamesUTF8Encoded = []; + // wait for all external data files to be loaded + await Promise.all(loadingPromises); + } - try { - [sessionOptionsHandle, allocs] = setSessionOptions(options); + sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); + if (sessionHandle === 0) { + checkLastError('Can\'t create a session.'); + } - sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); - if (sessionHandle === 0) { - checkLastError('Can\'t create a session.'); - } + const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); - const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); + const enableGraphCapture = !!options?.enableGraphCapture; - const inputNames = []; - const outputNames = []; - const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; - for (let i = 0; i < inputCount; i++) { - const name = wasm._OrtGetInputName(sessionHandle, i); - if (name === 0) { - checkLastError('Can\'t get an input name.'); - } - inputNamesUTF8Encoded.push(name); - inputNames.push(wasm.UTF8ToString(name)); + const inputNames = []; + const outputNames = []; + const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; + for (let i = 0; i < inputCount; i++) { + const name = wasm._OrtGetInputName(sessionHandle, i); + if (name === 0) { + checkLastError('Can\'t get an input name.'); + } + inputNamesUTF8Encoded.push(name); + inputNames.push(wasm.UTF8ToString(name)); + } + for (let i = 0; i < outputCount; i++) { + const name = wasm._OrtGetOutputName(sessionHandle, i); + if (name === 0) { + checkLastError('Can\'t get an output name.'); + } + outputNamesUTF8Encoded.push(name); + const nameString = wasm.UTF8ToString(name); + outputNames.push(nameString); + + if (!BUILD_DEFS.DISABLE_WEBGPU) { + if (enableGraphCapture && options?.preferredOutputLocation === undefined) { + outputPreferredLocations.push('gpu-buffer'); + continue; } - for (let i = 0; i < outputCount; i++) { - const name = wasm._OrtGetOutputName(sessionHandle, i); - if (name === 0) { - checkLastError('Can\'t get an output name.'); - } - outputNamesUTF8Encoded.push(name); - const nameString = wasm.UTF8ToString(name); - outputNames.push(nameString); - - if (!BUILD_DEFS.DISABLE_WEBGPU) { - const location = typeof options?.preferredOutputLocation === 'string' ? - options.preferredOutputLocation : - options?.preferredOutputLocation?.[nameString] ?? 'cpu'; - if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { - throw new Error(`Not supported preferred output location: ${location}.`); - } - outputPreferredLocations.push(location); - } + const location = typeof options?.preferredOutputLocation === 'string' ? + options.preferredOutputLocation : + options?.preferredOutputLocation?.[nameString] ?? 'cpu'; + if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${location}.`); + } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${ + location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`); } + outputPreferredLocations.push(location); + } + } - // use IO binding only when at least one output is preffered to be on GPU. - let bindingState: IOBindingState|null = null; - if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) { - ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); - if (ioBindingHandle === 0) { - checkLastError('Can\'t create IO binding.'); - } + // use IO binding only when at least one output is preffered to be on GPU. + let bindingState: IOBindingState|null = null; + if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) { + ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); + if (ioBindingHandle === 0) { + checkLastError('Can\'t create IO binding.'); + } - bindingState = { - handle: ioBindingHandle, - outputPreferredLocations, - outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)), - }; - } + bindingState = { + handle: ioBindingHandle, + outputPreferredLocations, + outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)), + }; + } - activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); - return [sessionHandle, inputNames, outputNames]; - } catch (e) { - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + activeSessions.set( + sessionHandle, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, enableGraphCapture, false]); + return [sessionHandle, inputNames, outputNames]; + } catch (e) { + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - if (ioBindingHandle !== 0) { - wasm._OrtReleaseBinding(ioBindingHandle); - } + if (ioBindingHandle !== 0) { + wasm._OrtReleaseBinding(ioBindingHandle); + } - if (sessionHandle !== 0) { - wasm._OrtReleaseSession(sessionHandle); - } - throw e; - } finally { - wasm._free(modelDataOffset); - if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); - } - allocs.forEach(alloc => wasm._free(alloc)); - } - }; + if (sessionHandle !== 0) { + wasm._OrtReleaseSession(sessionHandle); + } + throw e; + } finally { + wasm._free(modelDataOffset); + if (sessionOptionsHandle !== 0) { + wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + } + allocs.forEach(alloc => wasm._free(alloc)); + + // unmount external data if necessary + wasm.unmountExternalData?.(); + } +}; export const releaseSession = (sessionId: number): void => { const wasm = getInstance(); @@ -296,13 +355,16 @@ export const releaseSession = (sessionId: number): void => { if (!session) { throw new Error(`cannot release session. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture] = session; if (ioBindingState) { + if (enableGraphCapture) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); + } wasm._OrtReleaseBinding(ioBindingState.handle); } - wasm.jsepUnregisterBuffers?.(sessionId); + wasm.jsepOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -311,70 +373,80 @@ export const releaseSession = (sessionId: number): void => { }; export const prepareInputOutputTensor = - (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): - void => { - if (!tensor) { - tensorHandles.push(0); - return; - } + (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number, + enableGraphCapture = false): void => { + if (!tensor) { + tensorHandles.push(0); + return; + } - const wasm = getInstance(); + const wasm = getInstance(); - const dataType = tensor[0]; - const dims = tensor[1]; - const location = tensor[3]; + const dataType = tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; - let rawData: number; - let dataByteLength: number; + let rawData: number; + let dataByteLength: number; - if (dataType === 'string' && location === 'gpu-buffer') { - throw new Error('String tensor is not supported on GPU.'); - } + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } - if (location === 'gpu-buffer') { - const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; - const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; - dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; - rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); - } else { - const data = tensor[2]; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - let dataIndex = rawData / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); - } - } else { - dataByteLength = data.byteLength; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); - } - } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error( + `External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`); + } + + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, - dataLocationStringToEnum(location)); - if (tensor === 0) { - checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + const registerBuffer = wasm.jsepRegisterBuffer; + if (!registerBuffer) { + throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); + } + rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; + + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); } - tensorHandles.push(tensor); - } finally { - wasm.stackRestore(stack); + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); } - }; + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(location)); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } + }; /** * perform inference run @@ -387,7 +459,12 @@ export const run = async( if (!session) { throw new Error(`cannot run inference. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const sessionHandle = session[0]; + const inputNamesUTF8Encoded = session[1]; + const outputNamesUTF8Encoded = session[2]; + const ioBindingState = session[3]; + const enableGraphCapture = session[4]; + const inputOutputBound = session[5]; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -410,13 +487,15 @@ export const run = async( // create input tensors for (let i = 0; i < inputCount; i++) { - prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + prepareInputOutputTensor( + inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], enableGraphCapture); } // create output tensors for (let i = 0; i < outputCount; i++) { prepareInputOutputTensor( - outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i], + enableGraphCapture); } let inputValuesIndex = inputValuesOffset / 4; @@ -432,7 +511,7 @@ export const run = async( wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState && !inputOutputBound) { const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; if (inputNamesUTF8Encoded.length !== inputCount) { @@ -469,10 +548,13 @@ export const run = async( } } } + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, true]); } + wasm.jsepOnRunStart?.(sessionHandle); let errorCode: number; - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); @@ -540,7 +622,11 @@ export const run = async( // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU // tensor for it. There is no mapping GPU buffer for an empty tensor. if (preferredLocation === 'gpu-buffer' && size > 0) { - const gpuBuffer = wasm.jsepGetBuffer(dataOffset); + const getBuffer = wasm.jsepGetBuffer; + if (!getBuffer) { + throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.'); + } + const gpuBuffer = getBuffer(dataOffset); const elementSize = getTensorElementSize(dataType); if (elementSize === undefined || !isGpuBufferSupportedType(type)) { throw new Error(`Unsupported data type: ${type}`); @@ -552,7 +638,7 @@ export const run = async( output.push([ type, dims, { gpuBuffer, - download: wasm.jsepCreateDownloader(gpuBuffer, size * elementSize, type), + download: wasm.jsepCreateDownloader!(gpuBuffer, size * elementSize, type), dispose: () => { wasm._OrtReleaseTensor(tensor); } @@ -578,10 +664,12 @@ export const run = async( } } - if (ioBindingState) { + if (ioBindingState && !enableGraphCapture) { wasm._OrtClearBoundOutputs(ioBindingState.handle); + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, false]); } - return output; } finally { wasm.stackRestore(beforeRunStack); diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 2b7d492cc70b..9b9334c93b78 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -28,13 +28,34 @@ let initialized = false; let initializing = false; let aborted = false; -const isMultiThreadSupported = (): boolean => { - try { - // If 'SharedArrayBuffer' is not available, WebAssembly threads will not work. - if (typeof SharedArrayBuffer === 'undefined') { - return false; +const isMultiThreadSupported = (numThreads: number): boolean => { + // WebAssembly threads are set to 1 (single thread). + if (numThreads === 1) { + return false; + } + + // If 'SharedArrayBuffer' is not available, WebAssembly threads will not work. + if (typeof SharedArrayBuffer === 'undefined') { + if (typeof self !== 'undefined' && !self.crossOriginIsolated) { + // eslint-disable-next-line no-console + console.warn( + 'env.wasm.numThreads is set to ' + numThreads + + ', but this will not work unless you enable crossOriginIsolated mode. ' + + 'See https://web.dev/cross-origin-isolation-guide/ for more info.'); } + return false; + } + + // onnxruntime-web does not support multi-threads in Node.js. + if (typeof process !== 'undefined' && process.versions && process.versions.node) { + // eslint-disable-next-line no-console + console.warn( + 'env.wasm.numThreads is set to ' + numThreads + + ', however, currently onnxruntime-web does not support multi-threads in Node.js. ' + + 'Please consider using onnxruntime-node for performance critical scenarios.'); + } + try { // Test for transferability of SABs (for browsers. needed for Firefox) // https://groups.google.com/forum/#!msg/mozilla.dev.platform/IHkBZlHETpA/dwsMNchWEQAJ if (typeof MessageChannel !== 'undefined') { @@ -106,7 +127,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const numThreads = flags.numThreads!; const simd = flags.simd!; - const useThreads = numThreads > 1 && isMultiThreadSupported(); + const useThreads = isMultiThreadSupported(numThreads); const useSimd = simd && isSimdSupported(); const wasmPaths = flags.wasmPaths; @@ -167,6 +188,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise }; if (!BUILD_DEFS.DISABLE_WASM_THREAD && useThreads) { + config.numThreads = numThreads; if (typeof Blob === 'undefined') { config.mainScriptUrlOrBlob = path.join(__dirname, 'ort-wasm-threaded.js'); } else { diff --git a/js/web/lib/wasm/wasm-utils-load-file.ts b/js/web/lib/wasm/wasm-utils-load-file.ts new file mode 100644 index 000000000000..c6cdba2320bd --- /dev/null +++ b/js/web/lib/wasm/wasm-utils-load-file.ts @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import * as fs from 'fs'; +import {readFile} from 'node:fs/promises'; + +/** + * Load a file into a Uint8Array. + * + * @param file - the file to load. Can be a URL/path, a Blob, an ArrayBuffer, or a Uint8Array. + * @returns a Uint8Array containing the file data. + */ +export const loadFile = async(file: string|Blob|ArrayBufferLike|Uint8Array): Promise => { + if (typeof file === 'string') { + if (typeof process !== 'undefined' && process.versions && process.versions.node) { + // load file into ArrayBuffer in Node.js + try { + return new Uint8Array(await readFile(file)); + } catch (e) { + if (e.code === 'ERR_FS_FILE_TOO_LARGE') { + // file is too large, use fs.createReadStream instead + const stream = fs.createReadStream(file); + const chunks: Uint8Array[] = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + return new Uint8Array(Buffer.concat(chunks)); + } + throw e; + } + } else { + // load file into ArrayBuffer in browsers + const response = await fetch(file); + if (!response.ok) { + throw new Error(`failed to load external data file: ${file}`); + } + const contentLengthHeader = response.headers.get('Content-Length'); + const fileSize = contentLengthHeader ? parseInt(contentLengthHeader, 10) : 0; + if (fileSize < 1073741824 /* 1GB */) { + // when Content-Length header is not set, we cannot determine the file size. We assume it is small enough to + // load into memory. + return new Uint8Array(await response.arrayBuffer()); + } else { + // file is too large, use stream instead + if (!response.body) { + throw new Error(`failed to load external data file: ${file}, no response body.`); + } + const reader = response.body.getReader(); + + let buffer; + try { + // try to create ArrayBuffer directly + buffer = new ArrayBuffer(fileSize); + } catch (e) { + if (e instanceof RangeError) { + // use WebAssembly Memory to allocate larger ArrayBuffer + const pages = Math.ceil(fileSize / 65536); + buffer = new WebAssembly.Memory({initial: pages, maximum: pages}).buffer; + } else { + throw e; + } + } + + let offset = 0; + // eslint-disable-next-line no-constant-condition + while (true) { + const {done, value} = await reader.read(); + if (done) { + break; + } + const chunkSize = value.byteLength; + const chunk = new Uint8Array(buffer, offset, chunkSize); + chunk.set(value); + offset += chunkSize; + } + return new Uint8Array(buffer, 0, fileSize); + } + } + + } else if (file instanceof Blob) { + return new Uint8Array(await file.arrayBuffer()); + } else if (file instanceof Uint8Array) { + return file; + } else { + return new Uint8Array(file); + } +}; diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 890c5a0f3476..72fe383f04fe 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.17.0", + "version": "1.18.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.17.0", + "version": "1.18.0", "license": "MIT", "dependencies": { "flatbuffers": "^1.12.0", @@ -28,7 +28,7 @@ "@webgpu/types": "^0.1.38", "base64-js": "^1.5.1", "chai": "^4.3.7", - "electron": "^23.1.2", + "electron": "^28.1.4", "globby": "^13.1.3", "karma": "^6.4.1", "karma-browserstack-launcher": "^1.6.0", @@ -49,10 +49,10 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.17.0", + "version": "1.18.0", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/@chiragrupani/karma-chromium-edge-launcher": { @@ -862,9 +862,9 @@ } }, "node_modules/cross-spawn/node_modules/semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true, "bin": { "semver": "bin/semver" @@ -1042,14 +1042,14 @@ "dev": true }, "node_modules/electron": { - "version": "23.3.13", - "resolved": "https://registry.npmjs.org/electron/-/electron-23.3.13.tgz", - "integrity": "sha512-BaXtHEb+KYKLouUXlUVDa/lj9pj4F5kiE0kwFdJV84Y2EU7euIDgPthfKtchhr5MVHmjtavRMIV/zAwEiSQ9rQ==", + "version": "28.1.4", + "resolved": "https://registry.npmjs.org/electron/-/electron-28.1.4.tgz", + "integrity": "sha512-WE6go611KOhtH6efRPMnVC7FE7DCKnQ3ZyHFeI1DbaCy8OU4UjZ8/CZGcuZmZgRdxSBEHoHdgaJkWRHZzF0FOg==", "dev": true, "hasInstallScript": true, "dependencies": { "@electron/get": "^2.0.0", - "@types/node": "^16.11.26", + "@types/node": "^18.11.18", "extract-zip": "^2.0.1" }, "bin": { @@ -1059,12 +1059,6 @@ "node": ">= 12.20.55" } }, - "node_modules/electron/node_modules/@types/node": { - "version": "16.18.14", - "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.14.tgz", - "integrity": "sha512-wvzClDGQXOCVNU4APPopC2KtMYukaF1MN/W3xAmslx22Z4/IF1/izDMekuyoUlwfnDHYCIZGaj7jMwnJKBTxKw==", - "dev": true - }, "node_modules/emoji-regex": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", @@ -1357,9 +1351,9 @@ "dev": true }, "node_modules/follow-redirects": { - "version": "1.15.2", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.2.tgz", - "integrity": "sha512-VQLG33o04KaQ8uYi2tVNbdrWp1QWxNNea+nmIB4EVM28v0hmP17z7aG1+wAkNzVq4KeXTq3221ye5qTJP91JwA==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true, "funding": [ { @@ -1432,9 +1426,9 @@ } }, "node_modules/get-func-name": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.0.tgz", - "integrity": "sha512-Hm0ixYtaSZ/V7C8FJrtZIuBBI+iSgL+1Aq82zSu8VQNB4S3Gk8e7Qs3VwBDJAhmRZcFqkl3tQu36g/Foh5I5ig==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.2.tgz", + "integrity": "sha512-8vXOvuE167CtIc3OyItco7N/dpRtBbYOsPsXCz7X/PMnlGjYjSGuZJgM1Y7mmew7BKf9BqvLX2tnOVy1BBUsxQ==", "dev": true, "engines": { "node": "*" @@ -1542,9 +1536,9 @@ } }, "node_modules/global-agent/node_modules/semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "optional": true, "dependencies": { @@ -2635,9 +2629,9 @@ } }, "node_modules/protobufjs": { - "version": "7.2.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.4.tgz", - "integrity": "sha512-AT+RJgD2sH8phPmCf7OUZR8xGdcJRga4+1cOaXJ64hvcSkVhNcRHOwIxUatPH15+nj59WAGTDv3LSGZPEQbJaQ==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "hasInstallScript": true, "dependencies": { "@protobufjs/aspromise": "^1.1.2", @@ -2908,9 +2902,9 @@ "dev": true }, "node_modules/semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", "dev": true, "bin": { "semver": "bin/semver.js" @@ -4203,9 +4197,9 @@ }, "dependencies": { "semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true } } @@ -4339,22 +4333,14 @@ "dev": true }, "electron": { - "version": "23.3.13", - "resolved": "https://registry.npmjs.org/electron/-/electron-23.3.13.tgz", - "integrity": "sha512-BaXtHEb+KYKLouUXlUVDa/lj9pj4F5kiE0kwFdJV84Y2EU7euIDgPthfKtchhr5MVHmjtavRMIV/zAwEiSQ9rQ==", + "version": "28.1.4", + "resolved": "https://registry.npmjs.org/electron/-/electron-28.1.4.tgz", + "integrity": "sha512-WE6go611KOhtH6efRPMnVC7FE7DCKnQ3ZyHFeI1DbaCy8OU4UjZ8/CZGcuZmZgRdxSBEHoHdgaJkWRHZzF0FOg==", "dev": true, "requires": { "@electron/get": "^2.0.0", - "@types/node": "^16.11.26", + "@types/node": "^18.11.18", "extract-zip": "^2.0.1" - }, - "dependencies": { - "@types/node": { - "version": "16.18.14", - "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.14.tgz", - "integrity": "sha512-wvzClDGQXOCVNU4APPopC2KtMYukaF1MN/W3xAmslx22Z4/IF1/izDMekuyoUlwfnDHYCIZGaj7jMwnJKBTxKw==", - "dev": true - } } }, "emoji-regex": { @@ -4609,9 +4595,9 @@ "dev": true }, "follow-redirects": { - "version": "1.15.2", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.2.tgz", - "integrity": "sha512-VQLG33o04KaQ8uYi2tVNbdrWp1QWxNNea+nmIB4EVM28v0hmP17z7aG1+wAkNzVq4KeXTq3221ye5qTJP91JwA==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true }, "from": { @@ -4657,9 +4643,9 @@ "dev": true }, "get-func-name": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.0.tgz", - "integrity": "sha512-Hm0ixYtaSZ/V7C8FJrtZIuBBI+iSgL+1Aq82zSu8VQNB4S3Gk8e7Qs3VwBDJAhmRZcFqkl3tQu36g/Foh5I5ig==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.2.tgz", + "integrity": "sha512-8vXOvuE167CtIc3OyItco7N/dpRtBbYOsPsXCz7X/PMnlGjYjSGuZJgM1Y7mmew7BKf9BqvLX2tnOVy1BBUsxQ==", "dev": true }, "get-intrinsic": { @@ -4742,9 +4728,9 @@ }, "dependencies": { "semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "optional": true, "requires": { @@ -5517,7 +5503,7 @@ "onnxruntime-common": { "version": "file:../common", "requires": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "p-cancelable": { @@ -5595,9 +5581,9 @@ "dev": true }, "protobufjs": { - "version": "7.2.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.4.tgz", - "integrity": "sha512-AT+RJgD2sH8phPmCf7OUZR8xGdcJRga4+1cOaXJ64hvcSkVhNcRHOwIxUatPH15+nj59WAGTDv3LSGZPEQbJaQ==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "requires": { "@protobufjs/aspromise": "^1.1.2", "@protobufjs/base64": "^1.1.2", @@ -5780,9 +5766,9 @@ "dev": true }, "semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", "dev": true }, "semver-compare": { diff --git a/js/web/package.json b/js/web/package.json index 9b4531d7766f..384565dc0da9 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -1,6 +1,5 @@ { "license": "MIT", - "browser": "dist/ort-web.min.js", "unpkg": "dist/ort.min.js", "name": "onnxruntime-web", "repository": { @@ -8,7 +7,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.17.0", + "version": "1.18.0", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^1.12.0", @@ -24,6 +23,7 @@ "build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md", "pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts", "test:e2e": "node ./test/e2e/run", + "test:training:e2e": "node ./test/training/e2e/run", "prebuild": "tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit", "build": "node ./script/build", "test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli", @@ -46,7 +46,7 @@ "@webgpu/types": "^0.1.38", "base64-js": "^1.5.1", "chai": "^4.3.7", - "electron": "^23.1.2", + "electron": "^28.1.4", "globby": "^13.1.3", "karma": "^6.4.1", "karma-browserstack-launcher": "^1.6.0", @@ -68,11 +68,14 @@ "exports": { ".": { "node": "./dist/ort.node.min.js", + "types": "./types.d.ts", "default": { "import": "./dist/esm/ort.min.js", "require": "./dist/cjs/ort.min.js", + "types": "./types.d.ts", "default": { "development": "./dist/ort.js", + "types": "./types.d.ts", "default": "./dist/ort.min.js" } } @@ -80,34 +83,41 @@ "./experimental": { "import": "./dist/esm/ort.all.min.js", "require": "./dist/cjs/ort.all.min.js", + "types": "./types.d.ts", "default": { "development": "./dist/ort.all.js", + "types": "./types.d.ts", "default": "./dist/ort.all.min.js" } }, "./wasm": { "import": "./dist/esm/ort.wasm.min.js", "require": "./dist/cjs/ort.wasm.min.js", + "types": "./types.d.ts", "default": "./dist/ort.wasm.min.js" }, "./wasm-core": { "import": "./dist/esm/ort.wasm-core.min.js", "require": "./dist/cjs/ort.wasm-core.min.js", + "types": "./types.d.ts", "default": "./dist/ort.wasm-core.min.js" }, "./webgl": { "import": "./dist/esm/ort.webgl.min.js", "require": "./dist/cjs/ort.webgl.min.js", + "types": "./types.d.ts", "default": "./dist/ort.webgl.min.js" }, "./webgpu": { "import": "./dist/esm/ort.webgpu.min.js", "require": "./dist/cjs/ort.webgpu.min.js", + "types": "./types.d.ts", "default": "./dist/ort.webgpu.min.js" }, "./training": { "import": "./dist/esm/ort.training.wasm.min.js", "require": "./dist/cjs/ort.training.wasm.min.js", + "types": "./types.d.ts", "default": "./dist/ort.training.wasm.min.js" } }, diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 5151f27582c1..d3652f382035 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -121,7 +121,11 @@ async function buildOrt({ case 'node:fs/promises': case 'node:fs': case 'fs': - return {contents: 'export const readFile = undefined;'}; + return { + contents: 'export const readFile = undefined;' + + 'export const readFileSync = undefined;' + + 'export const createReadStream = undefined;' + }; case 'node:os': case 'os': return {contents: 'export const cpus = undefined;'}; @@ -367,10 +371,7 @@ async function main() { if (BUNDLE_MODE === 'dev') { // ort.all.js - await addBuildTask(buildOrt({ - outputBundleName: 'ort.all', - format: 'iife', - })); + await addBuildTask(buildOrt({outputBundleName: 'ort.all', format: 'iife', define: {...DEFAULT_DEFINE}})); } if (BUNDLE_MODE === 'perf') { @@ -404,7 +405,11 @@ async function main() { // ort.webgl[.min].js await addAllWebBuildTasks({ outputBundleName: 'ort.webgl', - define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WASM': 'true'}, + define: { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_WEBGPU': 'true', + 'BUILD_DEFS.DISABLE_WASM': 'true', + }, }); // ort.wasm-core[.min].js await addAllWebBuildTasks({ diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index ee955ec8d4f1..adcd940178e0 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -29,14 +29,15 @@ Options: *** General Options *** -h, --help Print this message. - -d, --debug Specify to run test runner in debug mode. - Debug mode outputs verbose log for test runner, sets up environment debug flag, and keeps karma not to exit after tests completed. + -d, --debug Specify to run test runner in debug mode. Debug mode does the following: + - outputs verbose log for test runner + - sets up environment debug flag (env.debug = true) + - opens Chromium debug port at 9333 and keeps karma not to exit after tests completed. -b=<...>, --backend=<...> Specify one or more backend(s) to run the test upon. Backends can be one or more of the following, splitted by comma: webgl webgpu wasm - xnnpack webnn -e=<...>, --env=<...> Specify the environment to run the test. Should be one of the following: chrome (default) @@ -48,42 +49,61 @@ Options: bs (for BrowserStack tests) -p, --profile Enable profiler. Profiler will generate extra logs which include the information of events time consumption + -t, --trace Enable trace. -P[=<...>], --perf[=<...>] Generate performance number. Cannot be used with flag --debug. This flag can be used with a number as value, specifying the total count of test cases to run. The test cases may be used multiple times. Default value is 10. -c, --file-cache Enable file cache. + +*** Session Options *** + -u=<...>, --optimized-model-file-path=<...> Specify whether to dump the optimized model. + -o=<...>, --graph-optimization-level=<...> Specify graph optimization level. + Default is 'all'. Valid values are 'disabled', 'basic', 'extended', 'all'. -i=<...>, --io-binding=<...> Specify the IO binding testing type. Should be one of the following: - none (default) + none (default) gpu-tensor use pre-allocated GPU tensors for inputs and outputs gpu-location use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer' -*** Session Options *** - -u=<...>, --optimized-model-file-path=<...> Specify whether to dump the optimized model. - -o=<...>, --graph-optimization-level=<...> Specify graph optimization level. - Default is 'all'. Valid values are 'disabled', 'basic', 'extended', 'all'. *** Logging Options *** - --log-verbose=<...> Set log level to verbose - --log-info=<...> Set log level to info - --log-warning=<...> Set log level to warning - --log-error=<...> Set log level to error - The 4 flags above specify the logging configuration. Each flag allows to specify one or more category(s), splitted by comma. If use the flags without value, the log level will be applied to all category. + --log-verbose Set log level to verbose + --log-info Set log level to info + --log-warning Set log level to warning + --log-error Set log level to error + The 4 flags above specify the logging configuration. *** Backend Options *** + --wasm.<...>=<...> Set global environment flags for each backend. + --webgl.<...>=<...> These flags can be used multiple times to set multiple flags. For example: + --webgpu.<...>=<...> --webgpu.profiling.mode=default --wasm.numThreads=1 --wasm.simd=false + --webnn.<...>=<...> + + --webnn-device-type Set the WebNN device type (cpu/gpu/npu) + -x, --wasm-number-threads Set the WebAssembly number of threads + ("--wasm-number-threads" is deprecated. use "--wasm.numThreads" or "-x" instead) --wasm-init-timeout Set the timeout for WebAssembly backend initialization, in milliseconds + (deprecated. use "--wasm.initTimeout" instead) --wasm-enable-simd Set whether to enable SIMD + (deprecated. use "--wasm.simd" instead) --wasm-enable-proxy Set whether to enable proxy worker + (deprecated. use "--wasm.proxy" instead) --webgl-context-id Set the WebGL context ID (webgl/webgl2) + (deprecated. use "--webgl.contextId" instead) --webgl-matmul-max-batch-size Set the WebGL matmulMaxBatchSize + (deprecated. use "--webgl.matmulMaxBatchSize" instead) --webgl-texture-cache-mode Set the WebGL texture cache mode (initializerOnly/full) + (deprecated. use "--webgl.textureCacheMode" instead) --webgl-texture-pack-mode Set the WebGL texture pack mode (true/false) + (deprecated. use "--webgl.pack" instead) --webgpu-profiling-mode Set the WebGPU profiling mode (off/default) + (deprecated. use "--webgpu.profiling.mode" instead) *** Browser Options *** --no-sandbox This flag will be passed to Chrome. Sometimes Chrome need this flag to work together with Karma. + --user-data-dir=<...> This flag will be passed to browsers to specify the user data directory. --chromium-flags=<...> This flag will be passed to Chrome and Edge browsers. Can be used multiple times. Examples: @@ -110,7 +130,7 @@ Examples: export declare namespace TestRunnerCliArgs { type Mode = 'suite0'|'suite1'|'model'|'unittest'|'op'; - type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'|'webnn'; + type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'webnn'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; type BundleMode = 'dev'|'perf'; type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; @@ -171,11 +191,12 @@ export interface TestRunnerCliArgs { cpuOptions?: InferenceSession.CpuExecutionProviderOption; cudaOptions?: InferenceSession.CudaExecutionProviderOption; - cudaFlags?: Record; wasmOptions?: InferenceSession.WebAssemblyExecutionProviderOption; webglOptions?: InferenceSession.WebGLExecutionProviderOption; + webnnOptions?: InferenceSession.WebNNExecutionProviderOption; globalEnvFlags?: Test.Options['globalEnvFlags']; noSandbox?: boolean; + userDataDir?: string; chromiumFlags: string[]; } @@ -259,40 +280,29 @@ function parseCpuOptions(_args: minimist.ParsedArgs): InferenceSession.CpuExecut return {name: 'cpu'}; } -function parseCpuFlags(_args: minimist.ParsedArgs): Record { - return {}; -} - function parseWasmOptions(_args: minimist.ParsedArgs): InferenceSession.WebAssemblyExecutionProviderOption { return {name: 'wasm'}; } function parseWasmFlags(args: minimist.ParsedArgs): Env.WebAssemblyFlags { - const numThreads = args.x || args['wasm-number-threads']; + const wasm = args.wasm || {}; + const numThreads = wasm.numThreads = wasm.numThreads ?? (args.x ?? args['wasm-number-threads']); if (typeof numThreads !== 'undefined' && typeof numThreads !== 'number') { - throw new Error('Flag "x"/"wasm-number-threads" must be a number value'); + throw new Error('Flag "wasm.numThreads"/"x"/"wasm-number-threads" must be a number value'); } - const initTimeout = args['wasm-init-timeout']; + const initTimeout = wasm.initTimeout = wasm.initTimeout ?? args['wasm-init-timeout']; if (typeof initTimeout !== 'undefined' && typeof initTimeout !== 'number') { - throw new Error('Flag "wasm-init-timeout" must be a number value'); - } - let simd = args['wasm-enable-simd']; - if (simd === 'true') { - simd = true; - } else if (simd === 'false') { - simd = false; - } else if (typeof simd !== 'undefined' && typeof simd !== 'boolean') { - throw new Error('Flag "wasm-enable-simd" must be a boolean value'); - } - let proxy = args['wasm-enable-proxy']; - if (proxy === 'true') { - proxy = true; - } else if (proxy === 'false') { - proxy = false; - } else if (typeof proxy !== 'undefined' && typeof proxy !== 'boolean') { - throw new Error('Flag "wasm-enable-proxy" must be a boolean value'); - } - return {numThreads, initTimeout, simd, proxy}; + throw new Error('Flag "wasm.initTimeout"/"wasm-init-timeout" must be a number value'); + } + const simd = wasm.simd = parseBooleanArg(wasm.simd ?? args['wasm-enable-simd']); + if (typeof simd !== 'undefined' && typeof simd !== 'boolean') { + throw new Error('Flag "wasm.simd"/"wasm-enable-simd" must be a boolean value'); + } + const proxy = wasm.proxy = parseBooleanArg(wasm.proxy ?? args['wasm-enable-proxy']); + if (typeof proxy !== 'undefined' && typeof proxy !== 'boolean') { + throw new Error('Flag "wasm.proxy"/"wasm-enable-proxy" must be a boolean value'); + } + return wasm; } function parseWebglOptions(_args: minimist.ParsedArgs): InferenceSession.WebGLExecutionProviderOption { @@ -300,47 +310,58 @@ function parseWebglOptions(_args: minimist.ParsedArgs): InferenceSession.WebGLEx } function parseWebglFlags(args: minimist.ParsedArgs): Partial { - const contextId = args['webgl-context-id']; + const webgl = args.webgl || {}; + const contextId = webgl.contextId = webgl.contextId ?? args['webgl-context-id']; if (contextId !== undefined && contextId !== 'webgl' && contextId !== 'webgl2') { - throw new Error('Flag "webgl-context-id" is invalid'); + throw new Error('Flag "webgl.contextId"/"webgl-context-id" is invalid'); } - const matmulMaxBatchSize = args['webgl-matmul-max-batch-size']; + const matmulMaxBatchSize = webgl.matmulMaxBatchSize = webgl.matmulMaxBatchSize ?? args['webgl-matmul-max-batch-size']; if (matmulMaxBatchSize !== undefined && typeof matmulMaxBatchSize !== 'number') { - throw new Error('Flag "webgl-matmul-max-batch-size" must be a number value'); + throw new Error('Flag "webgl.matmulMaxBatchSize"/"webgl-matmul-max-batch-size" must be a number value'); } - const textureCacheMode = args['webgl-texture-cache-mode']; + const textureCacheMode = webgl.textureCacheMode = webgl.textureCacheMode ?? args['webgl-texture-cache-mode']; if (textureCacheMode !== undefined && textureCacheMode !== 'initializerOnly' && textureCacheMode !== 'full') { - throw new Error('Flag "webgl-texture-cache-mode" is invalid'); + throw new Error('Flag "webgl.textureCacheMode"/"webgl-texture-cache-mode" is invalid'); } - const pack = args['webgl-texture-pack-mode']; + const pack = webgl.pack = parseBooleanArg(webgl.pack ?? args['webgl-texture-pack-mode']); if (pack !== undefined && typeof pack !== 'boolean') { - throw new Error('Flag "webgl-texture-pack-mode" is invalid'); + throw new Error('Flag "webgl.pack"/"webgl-texture-pack-mode" is invalid'); } - const async = args['webgl-async']; + const async = webgl.async = parseBooleanArg(webgl.async ?? args['webgl-async']); if (async !== undefined && typeof async !== 'boolean') { - throw new Error('Flag "webgl-async" is invalid'); + throw new Error('Flag "webgl.async"/"webgl-async" is invalid'); } - return {contextId, matmulMaxBatchSize, textureCacheMode, pack}; + return webgl; } function parseWebgpuFlags(args: minimist.ParsedArgs): Partial { - const profilingMode = args['webgpu-profiling-mode']; + const webgpu = args.webgpu || {}; + const profilingMode = (webgpu.profiling = webgpu.profiling ?? {}).mode = + webgpu?.profiling?.mode ?? webgpu.profilingMode ?? args['webgpu-profiling-mode']; if (profilingMode !== undefined && profilingMode !== 'off' && profilingMode !== 'default') { throw new Error('Flag "webgpu-profiling-mode" is invalid'); } - const validateInputContent = args['webgpu-validate-input-content']; + const validateInputContent = webgpu.validateInputContent = + parseBooleanArg(webgpu.validateInputContent ?? args['webgpu-validate-input-content']); if (validateInputContent !== undefined && typeof validateInputContent !== 'boolean') { throw new Error('Flag "webgpu-validate-input-content" is invalid'); } - return {profilingMode, validateInputContent}; + return webgpu; } -function parseGlobalEnvFlags(args: minimist.ParsedArgs): NonNullable { +function parseWebNNOptions(args: minimist.ParsedArgs): InferenceSession.WebNNExecutionProviderOption { + const deviceType = args['webnn-device-type']; + if (deviceType !== undefined && !['cpu', 'gpu', 'npu'].includes(deviceType)) { + throw new Error('Flag "webnn-device-type" is invalid'); + } + return {name: 'webnn', deviceType}; +} + +function parseGlobalEnvFlags(args: minimist.ParsedArgs) { const wasm = parseWasmFlags(args); const webgl = parseWebglFlags(args); const webgpu = parseWebgpuFlags(args); - const cpuFlags = parseCpuFlags(args); - return {webgl, wasm, webgpu, ...cpuFlags}; + return {webgl, wasm, webgpu}; } export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs { @@ -368,13 +389,13 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs } // Option: -b=<...>, --backend=<...> - const browserBackends = ['webgl', 'webgpu', 'wasm', 'xnnpack', 'webnn']; + const browserBackends = ['webgl', 'webgpu', 'wasm', 'webnn']; // TODO: remove this when Chrome support WebNN. // we need this for now because Chrome does not support webnn yet, // and ChromeCanary is not in CI. - const defaultBrowserBackends = ['webgl', 'webgpu', 'wasm', 'xnnpack' /*, 'webnn'*/]; + const defaultBrowserBackends = ['webgl', 'webgpu', 'wasm' /*, 'webnn'*/]; const nodejsBackends = ['cpu', 'wasm']; const backendArgs = args.backend || args.b; const backend = (typeof backendArgs !== 'string') ? (env === 'node' ? nodejsBackends : defaultBrowserBackends) : @@ -385,19 +406,14 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs } } - const globalEnvFlags = parseGlobalEnvFlags(args); - - if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) { - throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.'); - } - // Options: // --log-verbose=<...> // --log-info=<...> // --log-warning=<...> // --log-error=<...> const logConfig = parseLogConfig(args); - globalEnvFlags.logLevel = logConfig[0]?.config.minimalSeverity; + let logLevel = logConfig[0]?.config.minimalSeverity; + // Option: -p, --profile const profile = (args.profile || args.p) ? true : false; if (profile) { @@ -405,9 +421,18 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs logConfig.push({category: 'Profiler.node', config: {minimalSeverity: 'verbose'}}); logConfig.push({category: 'Profiler.op', config: {minimalSeverity: 'verbose'}}); logConfig.push({category: 'Profiler.backend', config: {minimalSeverity: 'verbose'}}); - globalEnvFlags.logLevel = 'verbose'; + logLevel = 'verbose'; } + // Option: -t, --trace + const trace = parseBooleanArg(args.trace || args.t, false); + + // Options: + // --wasm.<...>=<...> + // --webgl.<...>=<...> + // --webgpu.<...>=<...> + const globalEnvFlags = {...parseGlobalEnvFlags(args), debug, trace, logLevel}; + // Option: -P[=<...>], --perf[=<...>] const perfArg = (args.perf || args.P); const perf = perfArg ? true : false; @@ -449,10 +474,14 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs const wasmOptions = parseWasmOptions(args); const webglOptions = parseWebglOptions(args); + const webnnOptions = parseWebNNOptions(args); // Option: --no-sandbox const noSandbox = !!args['no-sandbox']; + // Option: --user-data-dir + const userDataDir = args['user-data-dir']; + // parse chromium flags let chromiumFlags = args['chromium-flags']; if (!chromiumFlags) { @@ -487,9 +516,11 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs fileCache, cpuOptions, webglOptions, + webnnOptions, wasmOptions, globalEnvFlags, noSandbox, + userDataDir, chromiumFlags }; } diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index 74a03290332a..03d637b35bc7 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -12,6 +12,7 @@ import * as os from 'os'; import * as path from 'path'; import {inspect} from 'util'; +import {onnx} from '../lib/onnxjs/ort-schema/protobuf/onnx'; import {bufferToBase64} from '../test/test-shared'; import {Test} from '../test/test-types'; @@ -165,6 +166,7 @@ async function main() { debug: args.debug, cpuOptions: args.cpuOptions, webglOptions: args.webglOptions, + webnnOptions: args.webnnOptions, wasmOptions: args.wasmOptions, globalEnvFlags: args.globalEnvFlags } @@ -263,10 +265,12 @@ async function main() { let modelUrl: string|null = null; let cases: Test.ModelTestCase[] = []; + let externalData: Array<{data: string; path: string}>|undefined; npmlog.verbose('TestRunnerCli.Init.Model', `Start to prepare test data from folder: ${testDataRootFolder}`); try { + const maybeExternalDataFiles: Array<[fileNameWithoutExtension: string, size: number]> = []; for (const thisPath of fs.readdirSync(testDataRootFolder)) { const thisFullPath = path.join(testDataRootFolder, thisPath); const stat = fs.lstatSync(thisFullPath); @@ -281,6 +285,8 @@ async function main() { } else { throw new Error('there are multiple model files under the folder specified'); } + } else { + maybeExternalDataFiles.push([path.parse(thisPath).name, stat.size]); } } else if (stat.isDirectory()) { const dataFiles: string[] = []; @@ -306,6 +312,34 @@ async function main() { if (modelUrl === null) { throw new Error('there are no model file under the folder specified'); } + // for performance consideration, we do not parse every model. when we think it's likely to have external + // data, we will parse it. We think it's "likely" when one of the following conditions is met: + // 1. any file in the same folder has the similar file name as the model file + // (e.g., model file is "model_abc.onnx", and there is a file "model_abc.pb" or "model_abc.onnx.data") + // 2. the file size is larger than 1GB + const likelyToHaveExternalData = maybeExternalDataFiles.some( + ([fileNameWithoutExtension, size]) => + path.basename(modelUrl!).startsWith(fileNameWithoutExtension) || size >= 1 * 1024 * 1024 * 1024); + if (likelyToHaveExternalData) { + const model = onnx.ModelProto.decode(fs.readFileSync(path.join(testDataRootFolder, path.basename(modelUrl!)))); + const externalDataPathSet = new Set(); + for (const initializer of model.graph!.initializer!) { + if (initializer.externalData) { + for (const data of initializer.externalData) { + if (data.key === 'location') { + externalDataPathSet.add(data.value!); + } + } + } + } + externalData = []; + const externalDataPaths = [...externalDataPathSet]; + for (const dataPath of externalDataPaths) { + const fullPath = path.resolve(testDataRootFolder, dataPath); + const url = path.join(TEST_DATA_BASE, path.relative(TEST_ROOT, fullPath)); + externalData.push({data: url, path: dataPath}); + } + } } catch (e) { npmlog.error('TestRunnerCli.Init.Model', `Failed to prepare test data. Error: ${inspect(e)}`); throw e; @@ -339,9 +373,23 @@ async function main() { npmlog.verbose('TestRunnerCli.Init.Model', ` Model file: ${modelUrl}`); npmlog.verbose('TestRunnerCli.Init.Model', ` Backend: ${backend}`); npmlog.verbose('TestRunnerCli.Init.Model', ` Test set(s): ${cases.length} (${caseCount})`); + if (externalData) { + npmlog.verbose('TestRunnerCli.Init.Model', ` External data: ${externalData.length}`); + for (const data of externalData) { + npmlog.verbose('TestRunnerCli.Init.Model', ` - ${data.path}`); + } + } npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); - return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases, ioBinding}; + return { + name: path.basename(testDataRootFolder), + platformCondition, + modelUrl, + backend, + cases, + ioBinding, + externalData + }; } function tryLocateModelTestFolder(searchPattern: string): string { @@ -494,14 +542,13 @@ async function main() { npmlog.info('TestRunnerCli.Run', '(4/4) Running karma to start test runner...'); const webgpu = args.backends.indexOf('webgpu') > -1; const webnn = args.backends.indexOf('webnn') > -1; - const browser = getBrowserNameFromEnv( - args.env, - args.bundleMode === 'perf' ? 'perf' : - args.debug ? 'debug' : - 'test', - webgpu, webnn); + const browser = getBrowserNameFromEnv(args.env); const karmaArgs = ['karma', 'start', `--browsers ${browser}`]; const chromiumFlags = ['--enable-features=SharedArrayBuffer', ...args.chromiumFlags]; + if (args.bundleMode === 'dev' && !args.debug) { + // use headless for 'test' mode (when 'perf' and 'debug' are OFF) + chromiumFlags.push('--headless=new'); + } if (args.debug) { karmaArgs.push('--log-level info --timeout-mocha 9999999'); chromiumFlags.push('--remote-debugging-port=9333'); @@ -522,7 +569,13 @@ async function main() { if (webnn) { chromiumFlags.push('--enable-experimental-web-platform-features'); } + if (process.argv.includes('--karma-debug')) { + karmaArgs.push('--log-level debug'); + } karmaArgs.push(`--bundle-mode=${args.bundleMode}`); + if (args.userDataDir) { + karmaArgs.push(`--user-data-dir="${args.userDataDir}"`); + } karmaArgs.push(...chromiumFlags.map(flag => `--chromium-flags=${flag}`)); if (browser.startsWith('Edge')) { // There are currently 2 Edge browser launchers: @@ -614,15 +667,14 @@ async function main() { fs.writeJSONSync(path.join(TEST_ROOT, './testdata-config.json'), config); } - function getBrowserNameFromEnv( - env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean, webnn: boolean) { + function getBrowserNameFromEnv(env: TestRunnerCliArgs['env']) { switch (env) { case 'chrome': - return selectChromeBrowser(mode, webgpu, webnn); + return 'ChromeTest'; case 'edge': return 'EdgeTest'; case 'firefox': - return 'Firefox'; + return 'FirefoxTest'; case 'electron': return 'Electron'; case 'safari': @@ -633,22 +685,6 @@ async function main() { throw new Error(`env "${env}" not supported.`); } } - - function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean, webnn: boolean) { - if (webnn) { - return 'ChromeCanaryTest'; - } else if (webgpu) { - return 'ChromeTest'; - } else { - switch (mode) { - case 'debug': - case 'perf': - return 'ChromeTest'; - default: - return 'ChromeTestHeadless'; - } - } - } } void main(); diff --git a/js/web/test/data/ops/add_zero-sized.jsonc b/js/web/test/data/ops/add_zero-sized.jsonc new file mode 100644 index 000000000000..37e08cd7f20a --- /dev/null +++ b/js/web/test/data/ops/add_zero-sized.jsonc @@ -0,0 +1,31 @@ +[ + { + "name": "Add with no attributes", + "operator": "Add", + "attributes": [], + "cases": [ + { + "name": "T[2,0] T[2,1]", + "inputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [2, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/concat_zero-sized.jsonc b/js/web/test/data/ops/concat_zero-sized.jsonc new file mode 100644 index 000000000000..be9625145d15 --- /dev/null +++ b/js/web/test/data/ops/concat_zero-sized.jsonc @@ -0,0 +1,641 @@ +[ + { + "name": "Concat 2D axis=0", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": -2, "type": "int" }], + "cases": [ + { + "name": "X", + "inputs": [ + { + "data": [], + "dims": [1, 4, 0, 64], + "type": "float32" + }, + { + "data": [ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], + "dims": [1, 4, 36, 64], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], + "dims": [1, 4, 36, 64], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 0, + "type": "int" + } + ], + "cases": [ + { + "name": "Some but not all input tensors are zero-sized", + "inputs": [ + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 1, + "type": "int" + } + ], + "cases": [ + { + "name": "All input tensors are zero-sized", + "inputs": [ + { + "data": [], + "dims": [0, 0], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 2], + "type": "float32" + }, + { + "data": [], + "dims": [0, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [0, 6], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index 7038e2a4f876..8ed48dd07e6f 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -392,5 +392,267 @@ ] } ] + }, + { + "name": "ConvTranspose without bias addition C", + "operator": "ConvTranspose", + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "strides", "data": [2, 2], "type": "ints" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, + 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, + 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 + ], + "dims": [1, 4, 16, 16], + "type": "float32" + }, + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15 + ], + "dims": [4, 4, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, + 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, + 196, 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, + 288, 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, + 560, 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, + 552, 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 16, 20, + 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, 240, 208, + 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, 192, 224, + 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, 288, 360, 304, + 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, 560, 464, 580, + 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, 552, 644, 576, + 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, + 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, 240, 208, 260, 224, 280, + 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, 192, 224, 216, 252, 240, + 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, 288, 360, 304, 380, 320, 400, + 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, 560, 464, 580, 480, 600, 496, + 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, 552, 644, 576, 672, 600, 700, + 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, + 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, + 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, 192, 224, 216, 252, 240, 280, 264, 308, + 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, 288, 360, 304, 380, 320, 400, 336, 420, 352, + 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, 560, 464, 580, 480, 600, 496, 620, 384, 448, + 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, 552, 644, 576, 672, 600, 700, 624, 728, 648, + 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, + 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, + 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, + 336, 392, 360, 420, 256, 320, 272, 340, 288, 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, + 480, 400, 500, 416, 520, 432, 540, 448, 560, 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, + 456, 532, 480, 560, 504, 588, 528, 616, 552, 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, + 812, 720, 840, 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, + 180, 160, 200, 176, 220, 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, + 120, 140, 144, 168, 168, 196, 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, + 420, 256, 320, 272, 340, 288, 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, + 416, 520, 432, 540, 448, 560, 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, + 560, 504, 588, 528, 616, 552, 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, + 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, + 176, 220, 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, + 168, 168, 196, 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, + 272, 340, 288, 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, + 540, 448, 560, 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, + 528, 616, 552, 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, + 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, + 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, + 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, 288, + 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, 560, + 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, 552, + 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 32, 36, 64, 72, + 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, 396, 384, 432, 416, + 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, 280, 308, 320, 352, + 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, 612, 576, 648, 608, + 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, 896, 1008, 928, 1044, + 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, 880, 968, 920, 1012, + 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 32, + 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, 396, 384, + 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, 280, 308, + 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, 612, 576, + 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, 896, 1008, + 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, 880, 968, + 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, + 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, + 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, + 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, + 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, + 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, + 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, + 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, 288, 324, + 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, + 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, 600, 660, + 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, 900, 832, + 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, 836, 800, + 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, + 1200, 1320, 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, + 288, 324, 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, + 176, 200, 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, + 600, 660, 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, + 900, 832, 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, + 836, 800, 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, + 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, + 256, 288, 288, 324, 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, + 132, 160, 176, 200, 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, + 560, 616, 600, 660, 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, + 864, 800, 900, 832, 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, + 792, 760, 836, 800, 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, + 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, + 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, + 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, + 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, + 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, + 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, + 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, + 180, 192, 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, + 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, + 484, 480, 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, + 704, 792, 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, + 640, 704, 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, + 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 48, 52, 96, 104, 144, 156, + 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, + 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, + 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, + 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, + 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, + 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, + 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, + 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, + 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, + 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, + 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, + 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, + 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, + 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, + 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, + 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, + 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, + 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, + 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, + 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, + 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, + 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, + 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, + 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, + 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, + 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, + 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, + 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, + 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, + 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, + 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, + 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, + 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, + 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, + 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, + 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, + 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, + 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, + 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, + 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, + 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, + 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, + 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, + 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, + 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, + 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, + 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, + 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, + 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, + 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, + 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, + 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, 1800, 1736, 1860 + ], + "dims": [1, 4, 32, 32], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/conv.jsonc b/js/web/test/data/ops/conv.jsonc index 2e8eaaba191d..cc10df586423 100644 --- a/js/web/test/data/ops/conv.jsonc +++ b/js/web/test/data/ops/conv.jsonc @@ -298,7 +298,157 @@ } ] }, - + { + "name": "conv - vectorize group - A", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [1, 1], "type": "ints" }, + { "name": "group", "data": 2, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0], + "dims": [1, 2, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0], + "dims": [2, 1, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0], + "dims": [1, 2, 3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "conv - vectorize group - B", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + }, + { + "data": [0.1, 0.2, 0.3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [27.1, 37.1, 57.1, 67.1, 293.2, 319.2, 371.2, 397.2, 847.3, 889.3, 409.3, 428.3], + "dims": [1, 3, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "conv - vectorize group - C", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0 + ], + "dims": [1, 3, 3, 4], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [34, 44, 54, 74, 84, 94, 386, 412, 438, 490, 516, 542, 1122, 1164, 1206, 1290, 1332, 1374], + "dims": [1, 3, 2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "conv - vectorize group - D", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "strides", "data": [2, 2], "type": "ints" } + ], + "cases": [ + { + "name": "T[0] strides = [2, 2]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0 + ], + "dims": [1, 3, 3, 4], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [34, 54, 386, 438, 1122, 1206], + "dims": [1, 3, 1, 2], + "type": "float32" + } + ] + } + ] + }, { "name": "conv - pointwise", "operator": "Conv", diff --git a/js/web/test/data/ops/expand.jsonc b/js/web/test/data/ops/expand.jsonc index 22bc04d558d9..613b4507b2b1 100644 --- a/js/web/test/data/ops/expand.jsonc +++ b/js/web/test/data/ops/expand.jsonc @@ -168,20 +168,39 @@ "name": "Expand - last dim is not divisible by 4", "inputs": [ { - "data": [true, false, false, true, true, true, false, false, false, true, true, true], - "dims": [2, 6], + "data": [true, false, false, true, true, true], + "dims": [1, 6], "type": "bool" }, { - "data": [2, 1], + "data": [3, 1], "dims": [2], "type": "int64" } ], "outputs": [ { - "data": [true, false, false, true, true, true, false, false, false, true, true, true], - "dims": [2, 6], + "data": [ + true, + false, + false, + true, + true, + true, + true, + false, + false, + true, + true, + true, + true, + false, + false, + true, + true, + true + ], + "dims": [3, 6], "type": "bool" } ] diff --git a/js/web/test/data/ops/fast-gelu.jsonc b/js/web/test/data/ops/fast-gelu.jsonc new file mode 100644 index 000000000000..2550173e9540 --- /dev/null +++ b/js/web/test/data/ops/fast-gelu.jsonc @@ -0,0 +1,211 @@ +[ + { + "name": "FastGelu test without bias", + "operator": "FastGelu", + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "scalar", + "inputs": [ + { + "data": [1], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.841192], + "dims": [], + "type": "float32" + } + ] + }, + { + "name": "[2x4]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.435415, 0.53057, 0.630432], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[3x5]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.841192, 1.9546, 2.99636, 3.99993, 5, 0.950581, + 1.0617, 1.17393, 1.28671, 1.39957 + ], + "dims": [3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "FastGelu test with bias", + "operator": "FastGelu", + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "scalar", + "inputs": [ + { + "data": [1], + "dims": [], + "type": "float32" + }, + { + "data": [0.5], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.39957], + "dims": [], + "type": "float32" + } + ] + }, + { + "name": "[2x4], [4]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.950581, 2.16968, 3.29869, 4.39999, 1.39957, 2.58835, 3.69973, 4.8], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[2x4], [3]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.950581, 2.16968, 3.29869, 1.28671, 2.48492, 3.59959, 1.62411, 2.79331], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[3x5], [2]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + }, + { + "data": [2, 3], + "dims": [2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.06267, 3.19813, 2.27567, 3.39909, 2.48492, 3.99993, 3.99993, 6, 6, 8, 3.09737, 4.19997, 3.29869, + 4.39999, 3.49938 + ], + "dims": [3, 5], + "type": "float32" + } + ] + }, + { + "name": "[3x5], [7]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + }, + { + "data": [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7], + "dims": [7], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.16968, 2.38072, 2.58835, 2.79331, 2.99636, 3.59959, 4.7, 5.1, 6.2, 7.3, 3.49938, 3.69973, 3.89989, + 4.09996, 3.59959 + ], + "dims": [3, 5], + "type": "float32" + } + ] + }, + { + "name": "[4x4], [8]", + "inputs": [ + { + "data": [0.8, -0.5, 0.0, 1, 1.3, 2.1, -0.2, 1.1, 0.5, 0.2, 0.3, -0.6, 3.1, 2.2, -1.1, 0.0], + "dims": [4, 4], + "type": "float32" + }, + { + "data": [-0.5, 0.6, 1.2, 2.1, 1.3, -1, 0, 3.1], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.185371, 0.0539828, 1.0617, 3.09737, 2.58835, 0.950581, -0.0841486, 4.19997, 0, 0.630432, 1.39957, + 1.39957, 4.39999, 1.0617, -0.149419, 3.09737 + ], + "dims": [4, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index 812e9d7c2def..6a10e3b96a26 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -108,5 +108,327 @@ ] } ] + }, + { + "name": "fused conv with clip", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Clip", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [400.0, 600.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [400, 470, 600, 600], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused conv with HardSigmoid", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 1, 1], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv with HardSigmoid", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 1, 1], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused group-conv with HardSigmoid", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1], + "dims": [1, 3, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC group-conv with HardSigmoid", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [1, 2, 2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused group-conv with LeakyRelu", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9, -6, 51, 47, -170, -10, 251, 229, 847, 889, 973, 1015], + "dims": [1, 3, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC group-conv with LeakyRelu", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-162, 63, -158, 33, 281, 85, 105, 337, 455, 177, 515, 609], + "dims": [1, 2, 2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused conv with LeakyRelu", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-540, -860, 390, 430], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv with LeakyRelu", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-540, -860, 390, 430], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/gather.jsonc b/js/web/test/data/ops/gather.jsonc index 0be077d237b8..d218d120d356 100644 --- a/js/web/test/data/ops/gather.jsonc +++ b/js/web/test/data/ops/gather.jsonc @@ -99,6 +99,28 @@ "operator": "Gather", "attributes": [], "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [false, true, false, false], + "dims": [4], + "type": "bool" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [true], + "dims": [], + "type": "bool" + } + ] + }, { "name": "data[2,4] indices[1]", "inputs": [ diff --git a/js/web/test/data/ops/instance-norm.jsonc b/js/web/test/data/ops/instance-norm.jsonc index 6a4e6912405e..f28b016d47ab 100644 --- a/js/web/test/data/ops/instance-norm.jsonc +++ b/js/web/test/data/ops/instance-norm.jsonc @@ -38,6 +38,79 @@ } ] }, + { + "name": "Simple test with NHWC, components 1", + "operator": "InstanceNormalization", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5], + "dims": [1, 5, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [4, 5, 6, 7, 8], + "dims": [5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.775264263153076, 4, 5.224735260009766, 2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6, + 9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539 + ], + "dims": [1, 5, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Simple test with NHWC, components 2", + "operator": "InstanceNormalization", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8], + "dims": [2, 6, 1, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [6], + "type": "float32" + }, + { + "data": [4, 5, 6, 7, 8, 9], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4, 5, 6, 7, 8, 9, 4, 5, 6, 7, 8, 9], + "dims": [2, 6, 1, 1], + "type": "float32" + } + ] + } + ] + }, { "name": "Simple test with NCHW", "operator": "InstanceNormalization", @@ -75,5 +148,161 @@ ] } ] + }, + { + "name": "Simple test with NCHW, components 1", + "operator": "InstanceNormalization", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5], + "dims": [1, 5, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [4, 5, 6, 7, 8], + "dims": [5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.775264263153076, 4, 5.224735260009766, 2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6, + 9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539 + ], + "dims": [1, 5, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Simple test with NCHW, components 2", + "operator": "InstanceNormalization", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2], + "dims": [1, 3, 6, 1], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [4, 5, 6], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.5361523628234863, 3.1216912269592285, 3.70723032951355, 4.292769432067871, 4.878308296203613, + 5.4638471603393555, 1.8666191101074219, 3.9555397033691406, 6.044460296630859, 8.133380889892578, + 6.044460296630859, 3.9555397033691406, 10.3915433883667, 8.634925842285156, 6.878308296203613, + 5.121691703796387, 3.365074634552002, 1.6084575653076172 + ], + "dims": [1, 3, 6, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Simple test with NHWC, components 1, buffer reuse", + "operator": "InstanceNormalization", + "inputShapeDefinitions": "rankOnly", + "opset": { + "domain": "", + "version": 17 + }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3, 1, 1], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [4, 5, 6], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4, 5, 6, 4, 5, 6], + "dims": [2, 3, 1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Simple test with NHWC, components 2, buffer reuse", + "operator": "InstanceNormalization", + "inputShapeDefinitions": "rankOnly", + "opset": { + "domain": "", + "version": 17 + }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2], + "dims": [1, 6, 1, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [6], + "type": "float32" + }, + { + "data": [4, 5, 6, 7, 8, 9], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.775264263153076, 4, 5.224735260009766, 2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6, + 9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539, + 16.348413467407227, 9, 1.6515865325927734 + ], + "dims": [1, 6, 1, 3], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc new file mode 100644 index 000000000000..63e0a0ed5287 --- /dev/null +++ b/js/web/test/data/ops/matmulnbits.jsonc @@ -0,0 +1,2486 @@ +[ + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ], + "dims": [8, 16], + "type": "float32" + }, + { + "dims": [8, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + }, + { + "dims": [8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0, + -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, + 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232, + -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032, + -16405, -48288, -16247 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ], + "dims": [8, 16], + "type": "float32" + }, + { + "dims": [8, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + }, + { + "dims": [8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + }, + { + "dims": [8], + "type": "uint8", + "data": [248, 249, 250, 251, 252, 253, 254, 255] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + 0, -505, -1600, -2043, -3904, -4285, -6912, -7231, 0, -1449, -5312, -6027, -12864, -12845, -22656, -21903, + 0, -2393, -9024, -10011, -21824, -21405, -38400, -36575, 0, -3337, -12736, -13995, -30784, -29965, -54144, + -51247, 0, -4281, -16448, -17979, -39744, -38525, -69888, -65919, 0, -5225, -20160, -21963, -48704, + -47085, -85632, -80591, 0, -6169, -23872, -25947, -57664, -55645, -101376, -95263, 0, -7113, -27584, + -29931, -66624, -64205, -117120, -109935 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=8, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [8, 32], + "type": "float32" + }, + { + "dims": [8, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + -1073, -3763, -5429, -6071, -5689, -4283, -1853, 1601, -2449, -12499, -19477, -23383, -24217, -21979, + -16669, -8287, -3825, -21235, -33525, -40695, -42745, -39675, -31485, -18175, -5201, -29971, -47573, + -58007, -61273, -57371, -46301, -28063, -6577, -38707, -61621, -75319, -79801, -75067, -61117, -37951, + -7953, -47443, -75669, -92631, -98329, -92763, -75933, -47839, -9329, -56179, -89717, -109943, -116857, + -110459, -90749, -57727, -10705, -64915, -103765, -127255, -135385, -128155, -105565, -67615 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=8, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [8, 32], + "type": "float32" + }, + { + "dims": [8, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + }, + { + "dims": [8], + "type": "uint8", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + 1935, 6941, 12491, 18585, 25223, 32405, 40131, 48401, 4655, 17661, 31211, 45305, 59943, 75125, 90851, + 107121, 7375, 28381, 49931, 72025, 94663, 117845, 141571, 165841, 10095, 39101, 68651, 98745, 129383, + 160565, 192291, 224561, 12815, 49821, 87371, 125465, 164103, 203285, 243011, 283281, 15535, 60541, 106091, + 152185, 198823, 246005, 293731, 342001, 18255, 71261, 124811, 178905, 233543, 288725, 344451, 400721, + 20975, 81981, 143531, 205625, 268263, 331445, 395171, 459441 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=48, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 48, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=48, N=8, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, + 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, + 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, + 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, + 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, + 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, + 379, 380, 381, 382, 383 + ], + "dims": [8, 48], + "type": "float32" + }, + { + "dims": [8, 3, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192 + ] + }, + { + "dims": [24], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + -7569, -13416, -24375, -14292, -20445, 5568, 4221, 46164, -17697, -39528, -73383, -45588, -66861, 10560, + 1869, 128916, -27825, -65640, -122391, -76884, -113277, 15552, -483, 211668, -37953, -91752, -171399, + -108180, -159693, 20544, -2835, 294420, -48081, -117864, -220407, -139476, -206109, 25536, -5187, 377172, + -58209, -143976, -269415, -170772, -252525, 30528, -7539, 459924, -68337, -170088, -318423, -202068, + -298941, 35520, -9891, 542676, -78465, -196200, -367431, -233364, -345357, 40512, -12243, 625428 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=48, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 48, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=48, N=8, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, + 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, + 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, + 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, + 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, + 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, + 379, 380, 381, 382, 383 + ], + "dims": [8, 48], + "type": "float32" + }, + { + "dims": [8, 3, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192 + ] + }, + { + "dims": [24], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + }, + { + "dims": [16], + "type": "uint8", + "data": [240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + -1353, -5984, -24751, -31500, -63509, -72376, -117627, -128612, -6105, -20576, -74527, -94284, -190565, + -215608, -354219, -384548, -10857, -35168, -124303, -157068, -317621, -358840, -590811, -640484, -15609, + -49760, -174079, -219852, -444677, -502072, -827403, -896420, -20361, -64352, -223855, -282636, -571733, + -645304, -1063995, -1152356, -25113, -78944, -273631, -345420, -698789, -788536, -1300587, -1408292, + -29865, -93536, -323407, -408204, -825845, -931768, -1537179, -1664228, -34617, -108128, -373183, -470988, + -952901, -1075000, -1773771, -1920164 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=64, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 64, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=64, N=8, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, + 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, + 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, + 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, + 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, + 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, + 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, + 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, + 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, + 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, + 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, + 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, + 505, 506, 507, 508, 509, 510, 511 + ], + "dims": [8, 64], + "type": "float32" + }, + { + "dims": [8, 4, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + -13572, -28812, -27668, -10140, 23772, 74068, 140748, 192564, -33796, -91532, -100116, -59548, 30172, + 169044, 357068, 531252, -54020, -154252, -172564, -108956, 36572, 264020, 573388, 869940, -74244, -216972, + -245012, -158364, 42972, 358996, 789708, 1208628, -94468, -279692, -317460, -207772, 49372, 453972, + 1006028, 1547316, -114692, -342412, -389908, -257180, 55772, 548948, 1222348, 1886004, -134916, -405132, + -462356, -306588, 62172, 643924, 1438668, 2224692, -155140, -467852, -534804, -355996, 68572, 738900, + 1654988, 2563380 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=64, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 64, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=64, N=8, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, + 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, + 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, + 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, + 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, + 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, + 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, + 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, + 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, + 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, + 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, + 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, + 505, 506, 507, 508, 509, 510, 511 + ], + "dims": [8, 64], + "type": "float32" + }, + { + "dims": [8, 4, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [16], + "type": "uint8", + "data": [240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + -26004, -63644, -96932, -125868, -150452, -170684, -186564, -229340, -60564, -157084, -249252, -337068, + -420532, -499644, -574404, -707804, -95124, -250524, -401572, -548268, -690612, -828604, -962244, + -1186268, -129684, -343964, -553892, -759468, -960692, -1157564, -1350084, -1664732, -164244, -437404, + -706212, -970668, -1230772, -1486524, -1737924, -2143196, -198804, -530844, -858532, -1181868, -1500852, + -1815484, -2125764, -2621660, -233364, -624284, -1010852, -1393068, -1770932, -2144444, -2513604, + -3100124, -267924, -717724, -1163172, -1604268, -2041012, -2473404, -2901444, -3578588 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=80, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 80, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=80, N=8, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, + 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, + 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, + 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, + 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, + 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, + 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, + 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, + 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, + 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, + 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, + 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, + 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, + 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, + 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, + 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, + 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, + 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, + 631, 632, 633, 634, 635, 636, 637, 638, 639 + ], + "dims": [8, 80], + "type": "float32" + }, + { + "dims": [8, 5, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320 + ] + }, + { + "dims": [40], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 + ] + }, + { + "dims": [24], + "type": "uint8", + "data": [ + 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, + 261, 262, 263 + ] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + -19988, -63429, -155448, -216179, -358428, 740351, 259888, 172481, -56788, -186869, -451128, -632899, + -1053788, 1574031, 1165488, 546481, -93588, -310309, -746808, -1049619, -1749148, 2407711, 2071088, + 920481, -130388, -433749, -1042488, -1466339, -2444508, 3241391, 2976688, 1294481, -167188, -557189, + -1338168, -1883059, -3139868, 4075071, 3882288, 1668481, -203988, -680629, -1633848, -2299779, -3835228, + 4908751, 4787888, 2042481, -240788, -804069, -1929528, -2716499, -4530588, 5742431, 5693488, 2416481, + -277588, -927509, -2225208, -3133219, -5225948, 6576111, 6599088, 2790481 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, -2944, -1161, -3040, -715, -2880, -13, -2464, 945, 0, + -1073, -3808, -2643, -6848, -3445, -9120, -3479, -10624, -2745, -11360, -1243, -11328, 1027, -10528, 4065, + 0, -1761, -6496, -4323, -11712, -5605, -15648, -5607, -18304, -4329, -19680, -1771, -19776, 2067, -18592, + 7185, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, -25984, -5913, -28000, -2299, -28224, 3107, + -26656, 10305, 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, -33664, -7497, -36320, -2827, + -36672, 4147, -34720, 13425, 0, -3825, -14560, -9363, -26304, -12085, -35232, -11991, -41344, -9081, + -44640, -3355, -45120, 5187, -42784, 16545, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, + -49024, -10665, -52960, -3883, -53568, 6227, -50848, 19665, 0, -5201, -19936, -12723, -36032, -16405, + -48288, -16247, -56704, -12249, -61280, -4411, -62016, 7267, -58912, 22785, 0, -5889, -22624, -14403, + -40896, -18565, -54816, -18375, -64384, -13833, -69600, -4939, -70464, 8307, -66976, 25905, 0, -6577, + -25312, -16083, -45760, -20725, -61344, -20503, -72064, -15417, -77920, -5467, -78912, 9347, -75040, + 29025, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, -79744, -17001, -86240, -5995, -87360, + 10387, -83104, 32145, 0, -7953, -30688, -19443, -55488, -25045, -74400, -24759, -87424, -18585, -94560, + -6523, -95808, 11427, -91168, 35265, 0, -8641, -33376, -21123, -60352, -27205, -80928, -26887, -95104, + -20169, -102880, -7051, -104256, 12467, -99232, 38385, 0, -9329, -36064, -22803, -65216, -29365, -87456, + -29015, -102784, -21753, -111200, -7579, -112704, 13507, -107296, 41505, 0, -10017, -38752, -24483, + -70080, -31525, -93984, -31143, -110464, -23337, -119520, -8107, -121152, 14547, -115360, 44625, 0, + -10705, -41440, -26163, -74944, -33685, -100512, -33271, -118144, -24921, -127840, -8635, -129600, 15587, + -123424, 47745 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + }, + { + "dims": [16], + "type": "uint8", + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, 608, 208, 1296, -288, 1280, -1488, 560, -3392, -864, -6000, -2992, -9312, -5824, -13328, -9360, 0, + 1824, 336, 3792, -1568, 3520, -5712, 1008, -12096, -3744, -20720, -10736, -31584, -19968, -44688, -31440, + 0, 3040, 464, 6288, -2848, 5760, -9936, 1456, -20800, -6624, -35440, -18480, -53856, -34112, -76048, + -53520, 0, 4256, 592, 8784, -4128, 8000, -14160, 1904, -29504, -9504, -50160, -26224, -76128, -48256, + -107408, -75600, 0, 5472, 720, 11280, -5408, 10240, -18384, 2352, -38208, -12384, -64880, -33968, -98400, + -62400, -138768, -97680, 0, 6688, 848, 13776, -6688, 12480, -22608, 2800, -46912, -15264, -79600, -41712, + -120672, -76544, -170128, -119760, 0, 7904, 976, 16272, -7968, 14720, -26832, 3248, -55616, -18144, + -94320, -49456, -142944, -90688, -201488, -141840, 0, 9120, 1104, 18768, -9248, 16960, -31056, 3696, + -64320, -21024, -109040, -57200, -165216, -104832, -232848, -163920, 0, 10336, 1232, 21264, -10528, 19200, + -35280, 4144, -73024, -23904, -123760, -64944, -187488, -118976, -264208, -186000, 0, 11552, 1360, 23760, + -11808, 21440, -39504, 4592, -81728, -26784, -138480, -72688, -209760, -133120, -295568, -208080, 0, + 12768, 1488, 26256, -13088, 23680, -43728, 5040, -90432, -29664, -153200, -80432, -232032, -147264, + -326928, -230160, 0, 13984, 1616, 28752, -14368, 25920, -47952, 5488, -99136, -32544, -167920, -88176, + -254304, -161408, -358288, -252240, 0, 15200, 1744, 31248, -15648, 28160, -52176, 5936, -107840, -35424, + -182640, -95920, -276576, -175552, -389648, -274320, 0, 16416, 1872, 33744, -16928, 30400, -56400, 6384, + -116544, -38304, -197360, -103664, -298848, -189696, -421008, -296400, 0, 17632, 2000, 36240, -18208, + 32640, -60624, 6832, -125248, -41184, -212080, -111408, -321120, -203840, -452368, -318480, 0, 18848, + 2128, 38736, -19488, 34880, -64848, 7280, -133952, -44064, -226800, -119152, -343392, -217984, -483728, + -340560 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -428, -1288, -1068, -2288, -1420, -3000, -1484, -3424, -1260, -3560, -748, -3408, 52, -2968, 1140, + -2272, 2516, -1224, 4180, 80, 6132, 1672, 8372, 3552, 10900, 5720, 13716, 8176, 16820, 10920, 12276, 0, + -1116, -3976, -2748, -7152, -3580, -9528, -3612, -11104, -2844, -11880, -1276, -11856, 1092, -11032, 4260, + -8160, 8228, -6984, 12996, -3760, 18564, 264, 24932, 5088, 32100, 10712, 40068, 17136, 48836, 24360, + 42532, 0, -1804, -6664, -4428, -12016, -5740, -16056, -5740, -18784, -4428, -20200, -1804, -20304, 2132, + -19096, 7380, -14048, 13940, -12744, 21812, -7600, 30996, -1144, 41492, 6624, 53300, 15704, 66420, 26096, + 80852, 37800, 72788, 0, -2492, -9352, -6108, -16880, -7900, -22584, -7868, -26464, -6012, -28520, -2332, + -28752, 3172, -27160, 10500, -19936, 19652, -18504, 30628, -11440, 43428, -2552, 58052, 8160, 74500, + 20696, 92772, 35056, 112868, 51240, 103044, 0, -3180, -12040, -7788, -21744, -10060, -29112, -9996, + -34144, -7596, -36840, -2860, -37200, 4212, -35224, 13620, -25824, 25364, -24264, 39444, -15280, 55860, + -3960, 74612, 9696, 95700, 25688, 119124, 44016, 144884, 64680, 133300, 0, -3868, -14728, -9468, -26608, + -12220, -35640, -12124, -41824, -9180, -45160, -3388, -45648, 5252, -43288, 16740, -31712, 31076, -30024, + 48260, -19120, 68292, -5368, 91172, 11232, 116900, 30680, 145476, 52976, 176900, 78120, 163556, 0, -4556, + -17416, -11148, -31472, -14380, -42168, -14252, -49504, -10764, -53480, -3916, -54096, 6292, -51352, + 19860, -37600, 36788, -35784, 57076, -22960, 80724, -6776, 107732, 12768, 138100, 35672, 171828, 61936, + 208916, 91560, 193812, 0, -5244, -20104, -12828, -36336, -16540, -48696, -16380, -57184, -12348, -61800, + -4444, -62544, 7332, -59416, 22980, -43488, 42500, -41544, 65892, -26800, 93156, -8184, 124292, 14304, + 159300, 40664, 198180, 70896, 240932, 105000, 224068, 0, -5932, -22792, -14508, -41200, -18700, -55224, + -18508, -64864, -13932, -70120, -4972, -70992, 8372, -67480, 26100, -49376, 48212, -47304, 74708, -30640, + 105588, -9592, 140852, 15840, 180500, 45656, 224532, 79856, 272948, 118440, 254324, 0, -6620, -25480, + -16188, -46064, -20860, -61752, -20636, -72544, -15516, -78440, -5500, -79440, 9412, -75544, 29220, + -55264, 53924, -53064, 83524, -34480, 118020, -11000, 157412, 17376, 201700, 50648, 250884, 88816, 304964, + 131880, 284580, 0, -7308, -28168, -17868, -50928, -23020, -68280, -22764, -80224, -17100, -86760, -6028, + -87888, 10452, -83608, 32340, -61152, 59636, -58824, 92340, -38320, 130452, -12408, 173972, 18912, 222900, + 55640, 277236, 97776, 336980, 145320, 314836, 0, -7996, -30856, -19548, -55792, -25180, -74808, -24892, + -87904, -18684, -95080, -6556, -96336, 11492, -91672, 35460, -67040, 65348, -64584, 101156, -42160, + 142884, -13816, 190532, 20448, 244100, 60632, 303588, 106736, 368996, 158760, 345092, 0, -8684, -33544, + -21228, -60656, -27340, -81336, -27020, -95584, -20268, -103400, -7084, -104784, 12532, -99736, 38580, + -72928, 71060, -70344, 109972, -46000, 155316, -15224, 207092, 21984, 265300, 65624, 329940, 115696, + 401012, 172200, 375348, 0, -9372, -36232, -22908, -65520, -29500, -87864, -29148, -103264, -21852, + -111720, -7612, -113232, 13572, -107800, 41700, -78816, 76772, -76104, 118788, -49840, 167748, -16632, + 223652, 23520, 286500, 70616, 356292, 124656, 433028, 185640, 405604, 0, -10060, -38920, -24588, -70384, + -31660, -94392, -31276, -110944, -23436, -120040, -8140, -121680, 14612, -115864, 44820, -84704, 82484, + -81864, 127604, -53680, 180180, -18040, 240212, 25056, 307700, 75608, 382644, 133616, 465044, 199080, + 435860, 0, -10748, -41608, -26268, -75248, -33820, -100920, -33404, -118624, -25020, -128360, -8668, + -130128, 15652, -123928, 47940, -90592, 88196, -87624, 136420, -57520, 192612, -19448, 256772, 26592, + 328900, 80600, 408996, 142576, 497060, 212520, 466116, 0, -11436, -44296, -27948, -80112, -35980, -107448, + -35532, -126304, -26604, -136680, -9196, -138576, 16692, -131992, 51060, -96480, 93908, -93384, 145236, + -61360, 205044, -20856, 273332, 28128, 350100, 85592, 435348, 151536, 529076, 225960, 496372, 0, -12124, + -46984, -29628, -84976, -38140, -113976, -37660, -133984, -28188, -145000, -9724, -147024, 17732, -140056, + 54180, -102368, 99620, -99144, 154052, -65200, 217476, -22264, 289892, 29664, 371300, 90584, 461700, + 160496, 561092, 239400, 526628, 0, -12812, -49672, -31308, -89840, -40300, -120504, -39788, -141664, + -29772, -153320, -10252, -155472, 18772, -148120, 57300, -108256, 105332, -104904, 162868, -69040, 229908, + -23672, 306452, 31200, 392500, 95576, 488052, 169456, 593108, 252840, 556884, 0, -13500, -52360, -32988, + -94704, -42460, -127032, -41916, -149344, -31356, -161640, -10780, -163920, 19812, -156184, 60420, + -114144, 111044, -110664, 171684, -72880, 242340, -25080, 323012, 32736, 413700, 100568, 514404, 178416, + 625124, 266280, 587140, 0, -14188, -55048, -34668, -99568, -44620, -133560, -44044, -157024, -32940, + -169960, -11308, -172368, 20852, -164248, 63540, -120032, 116756, -116424, 180500, -76720, 254772, -26488, + 339572, 34272, 434900, 105560, 540756, 187376, 657140, 279720, 617396, 0, -14876, -57736, -36348, -104432, + -46780, -140088, -46172, -164704, -34524, -178280, -11836, -180816, 21892, -172312, 66660, -125920, + 122468, -122184, 189316, -80560, 267204, -27896, 356132, 35808, 456100, 110552, 567108, 196336, 689156, + 293160, 647652, 0, -15564, -60424, -38028, -109296, -48940, -146616, -48300, -172384, -36108, -186600, + -12364, -189264, 22932, -180376, 69780, -131808, 128180, -127944, 198132, -84400, 279636, -29304, 372692, + 37344, 477300, 115544, 593460, 205296, 721172, 306600, 677908, 0, -16252, -63112, -39708, -114160, -51100, + -153144, -50428, -180064, -37692, -194920, -12892, -197712, 23972, -188440, 72900, -137696, 133892, + -133704, 206948, -88240, 292068, -30712, 389252, 38880, 498500, 120536, 619812, 214256, 753188, 320040, + 708164, 0, -16940, -65800, -41388, -119024, -53260, -159672, -52556, -187744, -39276, -203240, -13420, + -206160, 25012, -196504, 76020, -143584, 139604, -139464, 215764, -92080, 304500, -32120, 405812, 40416, + 519700, 125528, 646164, 223216, 785204, 333480, 738420, 0, -17628, -68488, -43068, -123888, -55420, + -166200, -54684, -195424, -40860, -211560, -13948, -214608, 26052, -204568, 79140, -149472, 145316, + -145224, 224580, -95920, 316932, -33528, 422372, 41952, 540900, 130520, 672516, 232176, 817220, 346920, + 768676, 0, -18316, -71176, -44748, -128752, -57580, -172728, -56812, -203104, -42444, -219880, -14476, + -223056, 27092, -212632, 82260, -155360, 151028, -150984, 233396, -99760, 329364, -34936, 438932, 43488, + 562100, 135512, 698868, 241136, 849236, 360360, 798932, 0, -19004, -73864, -46428, -133616, -59740, + -179256, -58940, -210784, -44028, -228200, -15004, -231504, 28132, -220696, 85380, -161248, 156740, + -156744, 242212, -103600, 341796, -36344, 455492, 45024, 583300, 140504, 725220, 250096, 881252, 373800, + 829188, 0, -19692, -76552, -48108, -138480, -61900, -185784, -61068, -218464, -45612, -236520, -15532, + -239952, 29172, -228760, 88500, -167136, 162452, -162504, 251028, -107440, 354228, -37752, 472052, 46560, + 604500, 145496, 751572, 259056, 913268, 387240, 859444, 0, -20380, -79240, -49788, -143344, -64060, + -192312, -63196, -226144, -47196, -244840, -16060, -248400, 30212, -236824, 91620, -173024, 168164, + -168264, 259844, -111280, 366660, -39160, 488612, 48096, 625700, 150488, 777924, 268016, 945284, 400680, + 889700, 0, -21068, -81928, -51468, -148208, -66220, -198840, -65324, -233824, -48780, -253160, -16588, + -256848, 31252, -244888, 94740, -178912, 173876, -174024, 268660, -115120, 379092, -40568, 505172, 49632, + 646900, 155480, 804276, 276976, 977300, 414120, 919956, 0, -21756, -84616, -53148, -153072, -68380, + -205368, -67452, -241504, -50364, -261480, -17116, -265296, 32292, -252952, 97860, -184800, 179588, + -179784, 277476, -118960, 391524, -41976, 521732, 51168, 668100, 160472, 830628, 285936, 1009316, 427560, + 950212 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 660, 888, 2196, 2064, 4020, 3528, 6132, 5280, 8532, 7320, 11220, 9648, 14196, 12264, 17460, 15136, + 21012, 18360, 24852, 21840, 28980, 25608, 33396, 29664, 38100, 34008, 43092, 38640, 48372, 43560, 46004, + 0, 2020, 2296, 6660, 5392, 12100, 9288, 18340, 13984, 25380, 19480, 33220, 25776, 41860, 32872, 51300, + 42016, 61540, 49464, 72580, 58960, 84420, 69256, 97060, 80352, 110500, 92248, 124740, 104944, 139780, + 118440, 139748, 0, 3380, 3704, 11124, 8720, 20180, 15048, 30548, 22688, 42228, 31640, 55220, 41904, 69524, + 53480, 85140, 68896, 102068, 80568, 120308, 96080, 139860, 112904, 160724, 131040, 182900, 150488, 206388, + 171248, 231188, 193320, 233492, 0, 4740, 5112, 15588, 12048, 28260, 20808, 42756, 31392, 59076, 43800, + 77220, 58032, 97188, 74088, 118980, 95776, 142596, 111672, 168036, 133200, 195300, 156552, 224388, 181728, + 255300, 208728, 288036, 237552, 322596, 268200, 327236, 0, 6100, 6520, 20052, 15376, 36340, 26568, 54964, + 40096, 75924, 55960, 99220, 74160, 124852, 94696, 152820, 122656, 183124, 142776, 215764, 170320, 250740, + 200200, 288052, 232416, 327700, 266968, 369684, 303856, 414004, 343080, 420980, 0, 7460, 7928, 24516, + 18704, 44420, 32328, 67172, 48800, 92772, 68120, 121220, 90288, 152516, 115304, 186660, 149536, 223652, + 173880, 263492, 207440, 306180, 243848, 351716, 283104, 400100, 325208, 451332, 370160, 505412, 417960, + 514724, 0, 8820, 9336, 28980, 22032, 52500, 38088, 79380, 57504, 109620, 80280, 143220, 106416, 180180, + 135912, 220500, 176416, 264180, 204984, 311220, 244560, 361620, 287496, 415380, 333792, 472500, 383448, + 532980, 436464, 596820, 492840, 608468, 0, 10180, 10744, 33444, 25360, 60580, 43848, 91588, 66208, 126468, + 92440, 165220, 122544, 207844, 156520, 254340, 203296, 304708, 236088, 358948, 281680, 417060, 331144, + 479044, 384480, 544900, 441688, 614628, 502768, 688228, 567720, 702212, 0, 11540, 12152, 37908, 28688, + 68660, 49608, 103796, 74912, 143316, 104600, 187220, 138672, 235508, 177128, 288180, 230176, 345236, + 267192, 406676, 318800, 472500, 374792, 542708, 435168, 617300, 499928, 696276, 569072, 779636, 642600, + 795956, 0, 12900, 13560, 42372, 32016, 76740, 55368, 116004, 83616, 160164, 116760, 209220, 154800, + 263172, 197736, 322020, 257056, 385764, 298296, 454404, 355920, 527940, 418440, 606372, 485856, 689700, + 558168, 777924, 635376, 871044, 717480, 889700, 0, 14260, 14968, 46836, 35344, 84820, 61128, 128212, + 92320, 177012, 128920, 231220, 170928, 290836, 218344, 355860, 283936, 426292, 329400, 502132, 393040, + 583380, 462088, 670036, 536544, 762100, 616408, 859572, 701680, 962452, 792360, 983444, 0, 15620, 16376, + 51300, 38672, 92900, 66888, 140420, 101024, 193860, 141080, 253220, 187056, 318500, 238952, 389700, + 310816, 466820, 360504, 549860, 430160, 638820, 505736, 733700, 587232, 834500, 674648, 941220, 767984, + 1053860, 867240, 1077188, 0, 16980, 17784, 55764, 42000, 100980, 72648, 152628, 109728, 210708, 153240, + 275220, 203184, 346164, 259560, 423540, 337696, 507348, 391608, 597588, 467280, 694260, 549384, 797364, + 637920, 906900, 732888, 1022868, 834288, 1145268, 942120, 1170932, 0, 18340, 19192, 60228, 45328, 109060, + 78408, 164836, 118432, 227556, 165400, 297220, 219312, 373828, 280168, 457380, 364576, 547876, 422712, + 645316, 504400, 749700, 593032, 861028, 688608, 979300, 791128, 1104516, 900592, 1236676, 1017000, + 1264676, 0, 19700, 20600, 64692, 48656, 117140, 84168, 177044, 127136, 244404, 177560, 319220, 235440, + 401492, 300776, 491220, 391456, 588404, 453816, 693044, 541520, 805140, 636680, 924692, 739296, 1051700, + 849368, 1186164, 966896, 1328084, 1091880, 1358420, 0, 21060, 22008, 69156, 51984, 125220, 89928, 189252, + 135840, 261252, 189720, 341220, 251568, 429156, 321384, 525060, 418336, 628932, 484920, 740772, 578640, + 860580, 680328, 988356, 789984, 1124100, 907608, 1267812, 1033200, 1419492, 1166760, 1452164, 0, 22420, + 23416, 73620, 55312, 133300, 95688, 201460, 144544, 278100, 201880, 363220, 267696, 456820, 341992, + 558900, 445216, 669460, 516024, 788500, 615760, 916020, 723976, 1052020, 840672, 1196500, 965848, 1349460, + 1099504, 1510900, 1241640, 1545908, 0, 23780, 24824, 78084, 58640, 141380, 101448, 213668, 153248, 294948, + 214040, 385220, 283824, 484484, 362600, 592740, 472096, 709988, 547128, 836228, 652880, 971460, 767624, + 1115684, 891360, 1268900, 1024088, 1431108, 1165808, 1602308, 1316520, 1639652, 0, 25140, 26232, 82548, + 61968, 149460, 107208, 225876, 161952, 311796, 226200, 407220, 299952, 512148, 383208, 626580, 498976, + 750516, 578232, 883956, 690000, 1026900, 811272, 1179348, 942048, 1341300, 1082328, 1512756, 1232112, + 1693716, 1391400, 1733396, 0, 26500, 27640, 87012, 65296, 157540, 112968, 238084, 170656, 328644, 238360, + 429220, 316080, 539812, 403816, 660420, 525856, 791044, 609336, 931684, 727120, 1082340, 854920, 1243012, + 992736, 1413700, 1140568, 1594404, 1298416, 1785124, 1466280, 1827140, 0, 27860, 29048, 91476, 68624, + 165620, 118728, 250292, 179360, 345492, 250520, 451220, 332208, 567476, 424424, 694260, 552736, 831572, + 640440, 979412, 764240, 1137780, 898568, 1306676, 1043424, 1486100, 1198808, 1676052, 1364720, 1876532, + 1541160, 1920884, 0, 29220, 30456, 95940, 71952, 173700, 124488, 262500, 188064, 362340, 262680, 473220, + 348336, 595140, 445032, 728100, 579616, 872100, 671544, 1027140, 801360, 1193220, 942216, 1370340, + 1094112, 1558500, 1257048, 1757700, 1431024, 1967940, 1616040, 2014628, 0, 30580, 31864, 100404, 75280, + 181780, 130248, 274708, 196768, 379188, 274840, 495220, 364464, 622804, 465640, 761940, 606496, 912628, + 702648, 1074868, 838480, 1248660, 985864, 1434004, 1144800, 1630900, 1315288, 1839348, 1497328, 2059348, + 1690920, 2108372, 0, 31940, 33272, 104868, 78608, 189860, 136008, 286916, 205472, 396036, 287000, 517220, + 380592, 650468, 486248, 795780, 633376, 953156, 733752, 1122596, 875600, 1304100, 1029512, 1497668, + 1195488, 1703300, 1373528, 1920996, 1563632, 2150756, 1765800, 2202116, 0, 33300, 34680, 109332, 81936, + 197940, 141768, 299124, 214176, 412884, 299160, 539220, 396720, 678132, 506856, 829620, 660256, 993684, + 764856, 1170324, 912720, 1359540, 1073160, 1561332, 1246176, 1775700, 1431768, 2002644, 1629936, 2242164, + 1840680, 2295860, 0, 34660, 36088, 113796, 85264, 206020, 147528, 311332, 222880, 429732, 311320, 561220, + 412848, 705796, 527464, 863460, 687136, 1034212, 795960, 1218052, 949840, 1414980, 1116808, 1624996, + 1296864, 1848100, 1490008, 2084292, 1696240, 2333572, 1915560, 2389604, 0, 36020, 37496, 118260, 88592, + 214100, 153288, 323540, 231584, 446580, 323480, 583220, 428976, 733460, 548072, 897300, 714016, 1074740, + 827064, 1265780, 986960, 1470420, 1160456, 1688660, 1347552, 1920500, 1548248, 2165940, 1762544, 2424980, + 1990440, 2483348, 0, 37380, 38904, 122724, 91920, 222180, 159048, 335748, 240288, 463428, 335640, 605220, + 445104, 761124, 568680, 931140, 740896, 1115268, 858168, 1313508, 1024080, 1525860, 1204104, 1752324, + 1398240, 1992900, 1606488, 2247588, 1828848, 2516388, 2065320, 2577092, 0, 38740, 40312, 127188, 95248, + 230260, 164808, 347956, 248992, 480276, 347800, 627220, 461232, 788788, 589288, 964980, 767776, 1155796, + 889272, 1361236, 1061200, 1581300, 1247752, 1815988, 1448928, 2065300, 1664728, 2329236, 1895152, 2607796, + 2140200, 2670836, 0, 40100, 41720, 131652, 98576, 238340, 170568, 360164, 257696, 497124, 359960, 649220, + 477360, 816452, 609896, 998820, 794656, 1196324, 920376, 1408964, 1098320, 1636740, 1291400, 1879652, + 1499616, 2137700, 1722968, 2410884, 1961456, 2699204, 2215080, 2764580, 0, 41460, 43128, 136116, 101904, + 246420, 176328, 372372, 266400, 513972, 372120, 671220, 493488, 844116, 630504, 1032660, 821536, 1236852, + 951480, 1456692, 1135440, 1692180, 1335048, 1943316, 1550304, 2210100, 1781208, 2492532, 2027760, 2790612, + 2289960, 2858324, 0, 42820, 44536, 140580, 105232, 254500, 182088, 384580, 275104, 530820, 384280, 693220, + 509616, 871780, 651112, 1066500, 848416, 1277380, 982584, 1504420, 1172560, 1747620, 1378696, 2006980, + 1600992, 2282500, 1839448, 2574180, 2094064, 2882020, 2364840, 2952068 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, 56908, 81124, + 108476, 138964, 140844, -3868, -21508, -33964, -41236, -43324, -40228, -31948, -18484, 5252, 23996, 53012, + 87212, 126596, 171164, 220916, 228236, -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, + 4900, 30108, 70196, 117516, 172068, 233852, 302868, 315628, -6620, -38980, -62060, -75860, -80380, -75620, + -61580, -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -7996, -47716, -76108, -93172, + -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, 466772, 490412, -9372, + -56452, -90156, -110484, -117436, -111012, -91212, -58036, 3844, 48444, 121748, 208428, 308484, 421916, + 548724, 577804, -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, + 238732, 353956, 484604, 630676, 665196, -12124, -73924, -118252, -145108, -154492, -146404, -120844, + -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -13500, -82660, -132300, -162420, + -173020, -164100, -135660, -87700, 2788, 66780, 173300, 299340, 444900, 609980, 794580, 839980, -14876, + -91396, -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, + 876532, 927372, -16252, -100132, -160396, -197044, -210076, -199492, -165292, -107476, 2084, 79004, + 207668, 359948, 535844, 735356, 958484, 1014764, -17628, -108868, -174444, -214356, -228604, -217188, + -180108, -117364, 1732, 85116, 224852, 390252, 581316, 798044, 1040436, 1102156, -19004, -117604, -188492, + -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, 860732, 1122388, + 1189548, -20380, -126340, -202540, -248980, -265660, -252580, -209740, -137140, 1028, 97340, 259220, + 450860, 672260, 923420, 1204340, 1276940, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, 170956, 205540, 243260, + 284116, 296364, -3868, -2948, 3156, 14444, 30916, 52572, 79412, 111436, 153732, 191036, 238612, 291372, + 349316, 412444, 480756, 506636, -5244, -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, + 337716, 411788, 493092, 581628, 677396, 716908, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, + 284100, 350716, 436820, 532204, 636868, 750812, 874036, 927180, -7996, -4580, 10164, 36236, 73636, 122364, + 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, -9372, -5124, 12500, + 43500, 87876, 145628, 216756, 301260, 414468, 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, + -10748, -5668, 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, + 1258364, 1463956, 1557996, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, 544836, 670076, + 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -13500, -6756, 19508, 65292, 130596, 215420, 319764, + 443628, 610020, 749916, 932340, 1134284, 1355748, 1596732, 1857236, 1978540, -14876, -7300, 21844, 72556, + 144836, 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, + -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, 1130548, 1375116, 1643300, + 1935100, 2250516, 2399084, -17628, -8388, 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, + 1229652, 1495532, 1787076, 2104284, 2447156, 2609356, -19004, -8932, 28852, 94348, 187556, 308476, 457108, + 633452, 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -20380, -9476, 31188, + 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, 2442652, 2840436, + 3029900, -21756, -10020, 33524, 108876, 216036, 355004, 525780, 728364, 1001124, 1228956, 1526964, + 1856780, 2218404, 2611836, 3037076, 3240172 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -59740, -53956, -47084, -39124, -30076, -19940, -8716, 3596, 16996, 31484, 47060, 63724, 81476, + 100316, 120244, 109004, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, + 56908, 81124, 108476, 138964, 140844, -199356, -184548, -166604, -145524, -121308, -93956, -63468, -29844, + 6916, 46812, 89844, 136012, 185316, 237756, 293332, 287532, -3868, -21508, -33964, -41236, -43324, -40228, + -31948, -18484, 5252, 23996, 53012, 87212, 126596, 171164, 220916, 228236, -338972, -315140, -286124, + -251924, -212540, -167972, -118220, -63284, -3164, 62140, 132628, 208300, 289156, 375196, 466420, 466060, + -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, 4900, 30108, 70196, 117516, 172068, 233852, + 302868, 315628, -478588, -445732, -405644, -358324, -303772, -241988, -172972, -96724, -13244, 77468, + 175412, 280588, 392996, 512636, 639508, 644588, -6620, -38980, -62060, -75860, -80380, -75620, -61580, + -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -618204, -576324, -525164, -464724, + -395004, -316004, -227724, -130164, -23324, 92796, 218196, 352876, 496836, 650076, 812596, 823116, -7996, + -47716, -76108, -93172, -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, + 466772, 490412, -757820, -706916, -644684, -571124, -486236, -390020, -282476, -163604, -33404, 108124, + 260980, 425164, 600676, 787516, 985684, 1001644, -9372, -56452, -90156, -110484, -117436, -111012, -91212, + -58036, 3844, 48444, 121748, 208428, 308484, 421916, 548724, 577804, -897436, -837508, -764204, -677524, + -577468, -464036, -337228, -197044, -43484, 123452, 303764, 497452, 704516, 924956, 1158772, 1180172, + -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, 238732, 353956, + 484604, 630676, 665196, -1037052, -968100, -883724, -783924, -668700, -538052, -391980, -230484, -53564, + 138780, 346548, 569740, 808356, 1062396, 1331860, 1358700, -12124, -73924, -118252, -145108, -154492, + -146404, -120844, -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -1176668, -1098692, + -1003244, -890324, -759932, -612068, -446732, -263924, -63644, 154108, 389332, 642028, 912196, 1199836, + 1504948, 1537228, -13500, -82660, -132300, -162420, -173020, -164100, -135660, -87700, 2788, 66780, + 173300, 299340, 444900, 609980, 794580, 839980, -1316284, -1229284, -1122764, -996724, -851164, -686084, + -501484, -297364, -73724, 169436, 432116, 714316, 1016036, 1337276, 1678036, 1715756, -14876, -91396, + -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, 876532, + 927372, -1455900, -1359876, -1242284, -1103124, -942396, -760100, -556236, -330804, -83804, 184764, + 474900, 786604, 1119876, 1474716, 1851124, 1894284, -16252, -100132, -160396, -197044, -210076, -199492, + -165292, -107476, 2084, 79004, 207668, 359948, 535844, 735356, 958484, 1014764, -1595516, -1490468, + -1361804, -1209524, -1033628, -834116, -610988, -364244, -93884, 200092, 517684, 858892, 1223716, 1612156, + 2024212, 2072812, -17628, -108868, -174444, -214356, -228604, -217188, -180108, -117364, 1732, 85116, + 224852, 390252, 581316, 798044, 1040436, 1102156, -1735132, -1621060, -1481324, -1315924, -1124860, + -908132, -665740, -397684, -103964, 215420, 560468, 931180, 1327556, 1749596, 2197300, 2251340, -19004, + -117604, -188492, -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, + 860732, 1122388, 1189548, -1874748, -1751652, -1600844, -1422324, -1216092, -982148, -720492, -431124, + -114044, 230748, 603252, 1003468, 1431396, 1887036, 2370388, 2429868, -20380, -126340, -202540, -248980, + -265660, -252580, -209740, -137140, 1028, 97340, 259220, 450860, 672260, 923420, 1204340, 1276940, + -2014364, -1882244, -1720364, -1528724, -1307324, -1056164, -775244, -464564, -124124, 246076, 646036, + 1075756, 1535236, 2024476, 2543476, 2608396, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332, -2153980, -2012836, -1839884, + -1635124, -1398556, -1130180, -829996, -498004, -134204, 261404, 688820, 1148044, 1639076, 2161916, + 2716564, 2786924, -23132, -143812, -230636, -283604, -302716, -287972, -239372, -156916, 324, 109564, + 293588, 511468, 763204, 1048796, 1368244, 1451724, -2293596, -2143428, -1959404, -1741524, -1489788, + -1204196, -884748, -531444, -144284, 276732, 731604, 1220332, 1742916, 2299356, 2889652, 2965452, -24508, + -152548, -244684, -300916, -321244, -305668, -254188, -166804, -28, 115676, 310772, 541772, 808676, + 1111484, 1450196, 1539116, -2433212, -2274020, -2078924, -1847924, -1581020, -1278212, -939500, -564884, + -154364, 292060, 774388, 1292620, 1846756, 2436796, 3062740, 3143980, -25884, -161284, -258732, -318228, + -339772, -323364, -269004, -176692, -380, 121788, 327956, 572076, 854148, 1174172, 1532148, 1626508, + -2572828, -2404612, -2198444, -1954324, -1672252, -1352228, -994252, -598324, -164444, 307388, 817172, + 1364908, 1950596, 2574236, 3235828, 3322508, -27260, -170020, -272780, -335540, -358300, -341060, -283820, + -186580, -732, 127900, 345140, 602380, 899620, 1236860, 1614100, 1713900, -2712444, -2535204, -2317964, + -2060724, -1763484, -1426244, -1049004, -631764, -174524, 322716, 859956, 1437196, 2054436, 2711676, + 3408916, 3501036, -28636, -178756, -286828, -352852, -376828, -358756, -298636, -196468, -1084, 134012, + 362324, 632684, 945092, 1299548, 1696052, 1801292, -2852060, -2665796, -2437484, -2167124, -1854716, + -1500260, -1103756, -665204, -184604, 338044, 902740, 1509484, 2158276, 2849116, 3582004, 3679564, -30012, + -187492, -300876, -370164, -395356, -376452, -313452, -206356, -1436, 140124, 379508, 662988, 990564, + 1362236, 1778004, 1888684, -2991676, -2796388, -2557004, -2273524, -1945948, -1574276, -1158508, -698644, + -194684, 353372, 945524, 1581772, 2262116, 2986556, 3755092, 3858092, -31388, -196228, -314924, -387476, + -413884, -394148, -328268, -216244, -1788, 146236, 396692, 693292, 1036036, 1424924, 1859956, 1976076, + -3131292, -2926980, -2676524, -2379924, -2037180, -1648292, -1213260, -732084, -204764, 368700, 988308, + 1654060, 2365956, 3123996, 3928180, 4036620, -32764, -204964, -328972, -404788, -432412, -411844, -343084, + -226132, -2140, 152348, 413876, 723596, 1081508, 1487612, 1941908, 2063468, -3270908, -3057572, -2796044, + -2486324, -2128412, -1722308, -1268012, -765524, -214844, 384028, 1031092, 1726348, 2469796, 3261436, + 4101268, 4215148, -34140, -213700, -343020, -422100, -450940, -429540, -357900, -236020, -2492, 158460, + 431060, 753900, 1126980, 1550300, 2023860, 2150860, -3410524, -3188164, -2915564, -2592724, -2219644, + -1796324, -1322764, -798964, -224924, 399356, 1073876, 1798636, 2573636, 3398876, 4274356, 4393676, + -35516, -222436, -357068, -439412, -469468, -447236, -372716, -245908, -2844, 164572, 448244, 784204, + 1172452, 1612988, 2105812, 2238252, -3550140, -3318756, -3035084, -2699124, -2310876, -1870340, -1377516, + -832404, -235004, 414684, 1116660, 1870924, 2677476, 3536316, 4447444, 4572204, -36892, -231172, -371116, + -456724, -487996, -464932, -387532, -255796, -3196, 170684, 465428, 814508, 1217924, 1675676, 2187764, + 2325644, -3689756, -3449348, -3154604, -2805524, -2402108, -1944356, -1432268, -865844, -245084, 430012, + 1159444, 1943212, 2781316, 3673756, 4620532, 4750732, -38268, -239908, -385164, -474036, -506524, -482628, + -402348, -265684, -3548, 176796, 482612, 844812, 1263396, 1738364, 2269716, 2413036, -3829372, -3579940, + -3274124, -2911924, -2493340, -2018372, -1487020, -899284, -255164, 445340, 1202228, 2015500, 2885156, + 3811196, 4793620, 4929260, -39644, -248644, -399212, -491348, -525052, -500324, -417164, -275572, -3900, + 182908, 499796, 875116, 1308868, 1801052, 2351668, 2500428, -3968988, -3710532, -3393644, -3018324, + -2584572, -2092388, -1541772, -932724, -265244, 460668, 1245012, 2087788, 2988996, 3948636, 4966708, + 5107788, -41020, -257380, -413260, -508660, -543580, -518020, -431980, -285460, -4252, 189020, 516980, + 905420, 1354340, 1863740, 2433620, 2587820, -4108604, -3841124, -3513164, -3124724, -2675804, -2166404, + -1596524, -966164, -275324, 475996, 1287796, 2160076, 3092836, 4086076, 5139796, 5286316, -42396, -266116, + -427308, -525972, -562108, -535716, -446796, -295348, -4604, 195132, 534164, 935724, 1399812, 1926428, + 2515572, 2675212, -4248220, -3971716, -3632684, -3231124, -2767036, -2240420, -1651276, -999604, -285404, + 491324, 1330580, 2232364, 3196676, 4223516, 5312884, 5464844, -43772, -274852, -441356, -543284, -580636, + -553412, -461612, -305236, -4956, 201244, 551348, 966028, 1445284, 1989116, 2597524, 2762604, -4387836, + -4102308, -3752204, -3337524, -2858268, -2314436, -1706028, -1033044, -295484, 506652, 1373364, 2304652, + 3300516, 4360956, 5485972, 5643372 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -24924, -16964, -7916, 2220, 13444, 25756, 39156, 53644, 69220, 85884, 103636, 122476, 142404, + 163420, 185524, 176460, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, + 170956, 205540, 243260, 284116, 296364, -33468, -8292, 20020, 51468, 86052, 123772, 164628, 208620, + 255748, 306012, 359412, 415948, 475620, 538428, 604372, 608940, -3868, -2948, 3156, 14444, 30916, 52572, + 79412, 111436, 153732, 191036, 238612, 291372, 349316, 412444, 480756, 506636, -42012, 380, 47956, 100716, + 158660, 221788, 290100, 363596, 442276, 526140, 615188, 709420, 808836, 913436, 1023220, 1041420, -5244, + -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, 337716, 411788, 493092, 581628, 677396, + 716908, -50556, 9052, 75892, 149964, 231268, 319804, 415572, 518572, 628804, 746268, 870964, 1002892, + 1142052, 1288444, 1442068, 1473900, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, 284100, + 350716, 436820, 532204, 636868, 750812, 874036, 927180, -59100, 17724, 103828, 199212, 303876, 417820, + 541044, 673548, 815332, 966396, 1126740, 1296364, 1475268, 1663452, 1860916, 1906380, -7996, -4580, 10164, + 36236, 73636, 122364, 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, + -67644, 26396, 131764, 248460, 376484, 515836, 666516, 828524, 1001860, 1186524, 1382516, 1589836, + 1808484, 2038460, 2279764, 2338860, -9372, -5124, 12500, 43500, 87876, 145628, 216756, 301260, 414468, + 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, -76188, 35068, 159700, 297708, 449092, 613852, + 791988, 983500, 1188388, 1406652, 1638292, 1883308, 2141700, 2413468, 2698612, 2771340, -10748, -5668, + 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, 1258364, 1463956, + 1557996, -84732, 43740, 187636, 346956, 521700, 711868, 917460, 1138476, 1374916, 1626780, 1894068, + 2176780, 2474916, 2788476, 3117460, 3203820, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, + 544836, 670076, 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -93276, 52412, 215572, 396204, + 594308, 809884, 1042932, 1293452, 1561444, 1846908, 2149844, 2470252, 2808132, 3163484, 3536308, 3636300, + -13500, -6756, 19508, 65292, 130596, 215420, 319764, 443628, 610020, 749916, 932340, 1134284, 1355748, + 1596732, 1857236, 1978540, -101820, 61084, 243508, 445452, 666916, 907900, 1168404, 1448428, 1747972, + 2067036, 2405620, 2763724, 3141348, 3538492, 3955156, 4068780, -14876, -7300, 21844, 72556, 144836, + 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, -110364, + 69756, 271444, 494700, 739524, 1005916, 1293876, 1603404, 1934500, 2287164, 2661396, 3057196, 3474564, + 3913500, 4374004, 4501260, -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, + 1130548, 1375116, 1643300, 1935100, 2250516, 2399084, -118908, 78428, 299380, 543948, 812132, 1103932, + 1419348, 1758380, 2121028, 2507292, 2917172, 3350668, 3807780, 4288508, 4792852, 4933740, -17628, -8388, + 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, 1229652, 1495532, 1787076, 2104284, 2447156, + 2609356, -127452, 87100, 327316, 593196, 884740, 1201948, 1544820, 1913356, 2307556, 2727420, 3172948, + 3644140, 4140996, 4663516, 5211700, 5366220, -19004, -8932, 28852, 94348, 187556, 308476, 457108, 633452, + 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -135996, 95772, 355252, 642444, + 957348, 1299964, 1670292, 2068332, 2494084, 2947548, 3428724, 3937612, 4474212, 5038524, 5630548, 5798700, + -20380, -9476, 31188, 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, + 2442652, 2840436, 3029900, -144540, 104444, 383188, 691692, 1029956, 1397980, 1795764, 2223308, 2680612, + 3167676, 3684500, 4231084, 4807428, 5413532, 6049396, 6231180, -21756, -10020, 33524, 108876, 216036, + 355004, 525780, 728364, 1001124, 1228956, 1526964, 1856780, 2218404, 2611836, 3037076, 3240172, -153084, + 113116, 411124, 740940, 1102564, 1495996, 1921236, 2378284, 2867140, 3387804, 3940276, 4524556, 5140644, + 5788540, 6468244, 6663660, -23132, -10564, 35860, 116140, 230276, 378268, 560116, 775820, 1066308, + 1308796, 1626068, 1977196, 2362180, 2781020, 3233716, 3450444, -161628, 121788, 439060, 790188, 1175172, + 1594012, 2046708, 2533260, 3053668, 3607932, 4196052, 4818028, 5473860, 6163548, 6887092, 7096140, -24508, + -11108, 38196, 123404, 244516, 401532, 594452, 823276, 1131492, 1388636, 1725172, 2097612, 2505956, + 2950204, 3430356, 3660716, -170172, 130460, 466996, 839436, 1247780, 1692028, 2172180, 2688236, 3240196, + 3828060, 4451828, 5111500, 5807076, 6538556, 7305940, 7528620, -25884, -11652, 40532, 130668, 258756, + 424796, 628788, 870732, 1196676, 1468476, 1824276, 2218028, 2649732, 3119388, 3626996, 3870988, -178716, + 139132, 494932, 888684, 1320388, 1790044, 2297652, 2843212, 3426724, 4048188, 4707604, 5404972, 6140292, + 6913564, 7724788, 7961100, -27260, -12196, 42868, 137932, 272996, 448060, 663124, 918188, 1261860, + 1548316, 1923380, 2338444, 2793508, 3288572, 3823636, 4081260, -187260, 147804, 522868, 937932, 1392996, + 1888060, 2423124, 2998188, 3613252, 4268316, 4963380, 5698444, 6473508, 7288572, 8143636, 8393580, -28636, + -12740, 45204, 145196, 287236, 471324, 697460, 965644, 1327044, 1628156, 2022484, 2458860, 2937284, + 3457756, 4020276, 4291532, -195804, 156476, 550804, 987180, 1465604, 1986076, 2548596, 3153164, 3799780, + 4488444, 5219156, 5991916, 6806724, 7663580, 8562484, 8826060, -30012, -13284, 47540, 152460, 301476, + 494588, 731796, 1013100, 1392228, 1707996, 2121588, 2579276, 3081060, 3626940, 4216916, 4501804, -204348, + 165148, 578740, 1036428, 1538212, 2084092, 2674068, 3308140, 3986308, 4708572, 5474932, 6285388, 7139940, + 8038588, 8981332, 9258540, -31388, -13828, 49876, 159724, 315716, 517852, 766132, 1060556, 1457412, + 1787836, 2220692, 2699692, 3224836, 3796124, 4413556, 4712076, -212892, 173820, 606676, 1085676, 1610820, + 2182108, 2799540, 3463116, 4172836, 4928700, 5730708, 6578860, 7473156, 8413596, 9400180, 9691020, -32764, + -14372, 52212, 166988, 329956, 541116, 800468, 1108012, 1522596, 1867676, 2319796, 2820108, 3368612, + 3965308, 4610196, 4922348, -221436, 182492, 634612, 1134924, 1683428, 2280124, 2925012, 3618092, 4359364, + 5148828, 5986484, 6872332, 7806372, 8788604, 9819028, 10123500, -34140, -14916, 54548, 174252, 344196, + 564380, 834804, 1155468, 1587780, 1947516, 2418900, 2940524, 3512388, 4134492, 4806836, 5132620, -229980, + 191164, 662548, 1184172, 1756036, 2378140, 3050484, 3773068, 4545892, 5368956, 6242260, 7165804, 8139588, + 9163612, 10237876, 10555980, -35516, -15460, 56884, 181516, 358436, 587644, 869140, 1202924, 1652964, + 2027356, 2518004, 3060940, 3656164, 4303676, 5003476, 5342892, -238524, 199836, 690484, 1233420, 1828644, + 2476156, 3175956, 3928044, 4732420, 5589084, 6498036, 7459276, 8472804, 9538620, 10656724, 10988460, + -36892, -16004, 59220, 188780, 372676, 610908, 903476, 1250380, 1718148, 2107196, 2617108, 3181356, + 3799940, 4472860, 5200116, 5553164, -247068, 208508, 718420, 1282668, 1901252, 2574172, 3301428, 4083020, + 4918948, 5809212, 6753812, 7752748, 8806020, 9913628, 11075572, 11420940, -38268, -16548, 61556, 196044, + 386916, 634172, 937812, 1297836, 1783332, 2187036, 2716212, 3301772, 3943716, 4642044, 5396756, 5763436, + -255612, 217180, 746356, 1331916, 1973860, 2672188, 3426900, 4237996, 5105476, 6029340, 7009588, 8046220, + 9139236, 10288636, 11494420, 11853420, -39644, -17092, 63892, 203308, 401156, 657436, 972148, 1345292, + 1848516, 2266876, 2815316, 3422188, 4087492, 4811228, 5593396, 5973708, -264156, 225852, 774292, 1381164, + 2046468, 2770204, 3552372, 4392972, 5292004, 6249468, 7265364, 8339692, 9472452, 10663644, 11913268, + 12285900, -41020, -17636, 66228, 210572, 415396, 680700, 1006484, 1392748, 1913700, 2346716, 2914420, + 3542604, 4231268, 4980412, 5790036, 6183980, -272700, 234524, 802228, 1430412, 2119076, 2868220, 3677844, + 4547948, 5478532, 6469596, 7521140, 8633164, 9805668, 11038652, 12332116, 12718380, -42396, -18180, 68564, + 217836, 429636, 703964, 1040820, 1440204, 1978884, 2426556, 3013524, 3663020, 4375044, 5149596, 5986676, + 6394252, -281244, 243196, 830164, 1479660, 2191684, 2966236, 3803316, 4702924, 5665060, 6689724, 7776916, + 8926636, 10138884, 11413660, 12750964, 13150860, -43772, -18724, 70900, 225100, 443876, 727228, 1075156, + 1487660, 2044068, 2506396, 3112628, 3783436, 4518820, 5318780, 6183316, 6604524, -289788, 251868, 858100, + 1528908, 2264292, 3064252, 3928788, 4857900, 5851588, 6909852, 8032692, 9220108, 10472100, 11788668, + 13169812, 13583340 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -1560, -2576, -3048, -2976, -2360, -1200, 504, 2736, 5544, 8880, 12760, 17184, 22152, 27664, 26040, + -29312, -26520, -23184, -19304, -14880, -9912, -4400, 1656, 8256, 15400, 23088, 31320, 40096, 49416, + 59280, 53816, 0, -5368, -9168, -11400, -12064, -11160, -8688, -4648, 2224, 8136, 16880, 27192, 39072, + 52520, 67536, 68760, -98432, -91256, -82512, -72200, -60320, -46872, -31856, -15272, 2880, 22600, 43888, + 66744, 91168, 117160, 144720, 142104, 0, -9176, -15760, -19752, -21152, -19960, -16176, -9800, 1712, + 10728, 24880, 41624, 60960, 82888, 107408, 111480, -167552, -155992, -141840, -125096, -105760, -83832, + -59312, -32200, -2496, 29800, 64688, 102168, 142240, 184904, 230160, 230392, 0, -12984, -22352, -28104, + -30240, -28760, -23664, -14952, 1200, 13320, 32880, 56056, 82848, 113256, 147280, 154200, -236672, + -220728, -201168, -177992, -151200, -120792, -86768, -49128, -7872, 37000, 85488, 137592, 193312, 252648, + 315600, 318680, 0, -16792, -28944, -36456, -39328, -37560, -31152, -20104, 688, 15912, 40880, 70488, + 104736, 143624, 187152, 196920, -305792, -285464, -260496, -230888, -196640, -157752, -114224, -66056, + -13248, 44200, 106288, 173016, 244384, 320392, 401040, 406968, 0, -20600, -35536, -44808, -48416, -46360, + -38640, -25256, 176, 18504, 48880, 84920, 126624, 173992, 227024, 239640, -374912, -350200, -319824, + -283784, -242080, -194712, -141680, -82984, -18624, 51400, 127088, 208440, 295456, 388136, 486480, 495256, + 0, -24408, -42128, -53160, -57504, -55160, -46128, -30408, -336, 21096, 56880, 99352, 148512, 204360, + 266896, 282360, -444032, -414936, -379152, -336680, -287520, -231672, -169136, -99912, -24000, 58600, + 147888, 243864, 346528, 455880, 571920, 583544, 0, -28216, -48720, -61512, -66592, -63960, -53616, -35560, + -848, 23688, 64880, 113784, 170400, 234728, 306768, 325080, -513152, -479672, -438480, -389576, -332960, + -268632, -196592, -116840, -29376, 65800, 168688, 279288, 397600, 523624, 657360, 671832, 0, -32024, + -55312, -69864, -75680, -72760, -61104, -40712, -1360, 26280, 72880, 128216, 192288, 265096, 346640, + 367800, -582272, -544408, -497808, -442472, -378400, -305592, -224048, -133768, -34752, 73000, 189488, + 314712, 448672, 591368, 742800, 760120, 0, -35832, -61904, -78216, -84768, -81560, -68592, -45864, -1872, + 28872, 80880, 142648, 214176, 295464, 386512, 410520, -651392, -609144, -557136, -495368, -423840, + -342552, -251504, -150696, -40128, 80200, 210288, 350136, 499744, 659112, 828240, 848408, 0, -39640, + -68496, -86568, -93856, -90360, -76080, -51016, -2384, 31464, 88880, 157080, 236064, 325832, 426384, + 453240, -720512, -673880, -616464, -548264, -469280, -379512, -278960, -167624, -45504, 87400, 231088, + 385560, 550816, 726856, 913680, 936696, 0, -43448, -75088, -94920, -102944, -99160, -83568, -56168, -2896, + 34056, 96880, 171512, 257952, 356200, 466256, 495960, -789632, -738616, -675792, -601160, -514720, + -416472, -306416, -184552, -50880, 94600, 251888, 420984, 601888, 794600, 999120, 1024984, 0, -47256, + -81680, -103272, -112032, -107960, -91056, -61320, -3408, 36648, 104880, 185944, 279840, 386568, 506128, + 538680, -858752, -803352, -735120, -654056, -560160, -453432, -333872, -201480, -56256, 101800, 272688, + 456408, 652960, 862344, 1084560, 1113272, 0, -51064, -88272, -111624, -121120, -116760, -98544, -66472, + -3920, 39240, 112880, 200376, 301728, 416936, 546000, 581400, -927872, -868088, -794448, -706952, -605600, + -490392, -361328, -218408, -61632, 109000, 293488, 491832, 704032, 930088, 1170000, 1201560, 0, -54872, + -94864, -119976, -130208, -125560, -106032, -71624, -4432, 41832, 120880, 214808, 323616, 447304, 585872, + 624120, -996992, -932824, -853776, -759848, -651040, -527352, -388784, -235336, -67008, 116200, 314288, + 527256, 755104, 997832, 1255440, 1289848, 0, -58680, -101456, -128328, -139296, -134360, -113520, -76776, + -4944, 44424, 128880, 229240, 345504, 477672, 625744, 666840, -1066112, -997560, -913104, -812744, + -696480, -564312, -416240, -252264, -72384, 123400, 335088, 562680, 806176, 1065576, 1340880, 1378136, 0, + -62488, -108048, -136680, -148384, -143160, -121008, -81928, -5456, 47016, 136880, 243672, 367392, 508040, + 665616, 709560, -1135232, -1062296, -972432, -865640, -741920, -601272, -443696, -269192, -77760, 130600, + 355888, 598104, 857248, 1133320, 1426320, 1466424, 0, -66296, -114640, -145032, -157472, -151960, -128496, + -87080, -5968, 49608, 144880, 258104, 389280, 538408, 705488, 752280, -1204352, -1127032, -1031760, + -918536, -787360, -638232, -471152, -286120, -83136, 137800, 376688, 633528, 908320, 1201064, 1511760, + 1554712, 0, -70104, -121232, -153384, -166560, -160760, -135984, -92232, -6480, 52200, 152880, 272536, + 411168, 568776, 745360, 795000, -1273472, -1191768, -1091088, -971432, -832800, -675192, -498608, -303048, + -88512, 145000, 397488, 668952, 959392, 1268808, 1597200, 1643000, 0, -73912, -127824, -161736, -175648, + -169560, -143472, -97384, -6992, 54792, 160880, 286968, 433056, 599144, 785232, 837720, -1342592, + -1256504, -1150416, -1024328, -878240, -712152, -526064, -319976, -93888, 152200, 418288, 704376, 1010464, + 1336552, 1682640, 1731288, 0, -77720, -134416, -170088, -184736, -178360, -150960, -102536, -7504, 57384, + 168880, 301400, 454944, 629512, 825104, 880440, -1411712, -1321240, -1209744, -1077224, -923680, -749112, + -553520, -336904, -99264, 159400, 439088, 739800, 1061536, 1404296, 1768080, 1819576, 0, -81528, -141008, + -178440, -193824, -187160, -158448, -107688, -8016, 59976, 176880, 315832, 476832, 659880, 864976, 923160, + -1480832, -1385976, -1269072, -1130120, -969120, -786072, -580976, -353832, -104640, 166600, 459888, + 775224, 1112608, 1472040, 1853520, 1907864, 0, -85336, -147600, -186792, -202912, -195960, -165936, + -112840, -8528, 62568, 184880, 330264, 498720, 690248, 904848, 965880, -1549952, -1450712, -1328400, + -1183016, -1014560, -823032, -608432, -370760, -110016, 173800, 480688, 810648, 1163680, 1539784, 1938960, + 1996152, 0, -89144, -154192, -195144, -212000, -204760, -173424, -117992, -9040, 65160, 192880, 344696, + 520608, 720616, 944720, 1008600, -1619072, -1515448, -1387728, -1235912, -1060000, -859992, -635888, + -387688, -115392, 181000, 501488, 846072, 1214752, 1607528, 2024400, 2084440, 0, -92952, -160784, -203496, + -221088, -213560, -180912, -123144, -9552, 67752, 200880, 359128, 542496, 750984, 984592, 1051320, + -1688192, -1580184, -1447056, -1288808, -1105440, -896952, -663344, -404616, -120768, 188200, 522288, + 881496, 1265824, 1675272, 2109840, 2172728, 0, -96760, -167376, -211848, -230176, -222360, -188400, + -128296, -10064, 70344, 208880, 373560, 564384, 781352, 1024464, 1094040, -1757312, -1644920, -1506384, + -1341704, -1150880, -933912, -690800, -421544, -126144, 195400, 543088, 916920, 1316896, 1743016, 2195280, + 2261016, 0, -100568, -173968, -220200, -239264, -231160, -195888, -133448, -10576, 72936, 216880, 387992, + 586272, 811720, 1064336, 1136760, -1826432, -1709656, -1565712, -1394600, -1196320, -970872, -718256, + -438472, -131520, 202600, 563888, 952344, 1367968, 1810760, 2280720, 2349304, 0, -104376, -180560, + -228552, -248352, -239960, -203376, -138600, -11088, 75528, 224880, 402424, 608160, 842088, 1104208, + 1179480, -1895552, -1774392, -1625040, -1447496, -1241760, -1007832, -745712, -455400, -136896, 209800, + 584688, 987768, 1419040, 1878504, 2366160, 2437592, 0, -108184, -187152, -236904, -257440, -248760, + -210864, -143752, -11600, 78120, 232880, 416856, 630048, 872456, 1144080, 1222200, -1964672, -1839128, + -1684368, -1500392, -1287200, -1044792, -773168, -472328, -142272, 217000, 605488, 1023192, 1470112, + 1946248, 2451600, 2525880, 0, -111992, -193744, -245256, -266528, -257560, -218352, -148904, -12112, + 80712, 240880, 431288, 651936, 902824, 1183952, 1264920, -2033792, -1903864, -1743696, -1553288, -1332640, + -1081752, -800624, -489256, -147648, 224200, 626288, 1058616, 1521184, 2013992, 2537040, 2614168, 0, + -115800, -200336, -253608, -275616, -266360, -225840, -154056, -12624, 83304, 248880, 445720, 673824, + 933192, 1223824, 1307640, -2102912, -1968600, -1803024, -1606184, -1378080, -1118712, -828080, -506184, + -153024, 231400, 647088, 1094040, 1572256, 2081736, 2622480, 2702456, 0, -119608, -206928, -261960, + -284704, -275160, -233328, -159208, -13136, 85896, 256880, 460152, 695712, 963560, 1263696, 1350360, + -2172032, -2033336, -1862352, -1659080, -1423520, -1155672, -855536, -523112, -158400, 238600, 667888, + 1129464, 1623328, 2149480, 2707920, 2790744 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 2664, 5872, 9624, 13920, 18760, 24144, 30072, 36528, 43560, 51120, 59224, 67872, 77064, 86800, 89400, + 38272, 45288, 52848, 60952, 69600, 78792, 88528, 98808, 109632, 121000, 132912, 145368, 158368, 171912, + 186000, 184760, 0, 7048, 15664, 25848, 37600, 50920, 65808, 82264, 101552, 119880, 141040, 163768, 188064, + 213928, 241360, 255000, 100224, 119816, 140976, 163704, 188000, 213864, 241296, 270296, 300864, 333000, + 366704, 401976, 438816, 477224, 517200, 527000, 0, 11432, 25456, 42072, 61280, 83080, 107472, 134456, + 166576, 196200, 230960, 268312, 308256, 350792, 395920, 420600, 162176, 194344, 229104, 266456, 306400, + 348936, 394064, 441784, 492096, 545000, 600496, 658584, 719264, 782536, 848400, 869240, 0, 15816, 35248, + 58296, 84960, 115240, 149136, 186648, 231600, 272520, 320880, 372856, 428448, 487656, 550480, 586200, + 224128, 268872, 317232, 369208, 424800, 484008, 546832, 613272, 683328, 757000, 834288, 915192, 999712, + 1087848, 1179600, 1211480, 0, 20200, 45040, 74520, 108640, 147400, 190800, 238840, 296624, 348840, 410800, + 477400, 548640, 624520, 705040, 751800, 286080, 343400, 405360, 471960, 543200, 619080, 699600, 784760, + 874560, 969000, 1068080, 1171800, 1280160, 1393160, 1510800, 1553720, 0, 24584, 54832, 90744, 132320, + 179560, 232464, 291032, 361648, 425160, 500720, 581944, 668832, 761384, 859600, 917400, 348032, 417928, + 493488, 574712, 661600, 754152, 852368, 956248, 1065792, 1181000, 1301872, 1428408, 1560608, 1698472, + 1842000, 1895960, 0, 28968, 64624, 106968, 156000, 211720, 274128, 343224, 426672, 501480, 590640, 686488, + 789024, 898248, 1014160, 1083000, 409984, 492456, 581616, 677464, 780000, 889224, 1005136, 1127736, + 1257024, 1393000, 1535664, 1685016, 1841056, 2003784, 2173200, 2238200, 0, 33352, 74416, 123192, 179680, + 243880, 315792, 395416, 491696, 577800, 680560, 791032, 909216, 1035112, 1168720, 1248600, 471936, 566984, + 669744, 780216, 898400, 1024296, 1157904, 1299224, 1448256, 1605000, 1769456, 1941624, 2121504, 2309096, + 2504400, 2580440, 0, 37736, 84208, 139416, 203360, 276040, 357456, 447608, 556720, 654120, 770480, 895576, + 1029408, 1171976, 1323280, 1414200, 533888, 641512, 757872, 882968, 1016800, 1159368, 1310672, 1470712, + 1639488, 1817000, 2003248, 2198232, 2401952, 2614408, 2835600, 2922680, 0, 42120, 94000, 155640, 227040, + 308200, 399120, 499800, 621744, 730440, 860400, 1000120, 1149600, 1308840, 1477840, 1579800, 595840, + 716040, 846000, 985720, 1135200, 1294440, 1463440, 1642200, 1830720, 2029000, 2237040, 2454840, 2682400, + 2919720, 3166800, 3264920, 0, 46504, 103792, 171864, 250720, 340360, 440784, 551992, 686768, 806760, + 950320, 1104664, 1269792, 1445704, 1632400, 1745400, 657792, 790568, 934128, 1088472, 1253600, 1429512, + 1616208, 1813688, 2021952, 2241000, 2470832, 2711448, 2962848, 3225032, 3498000, 3607160, 0, 50888, + 113584, 188088, 274400, 372520, 482448, 604184, 751792, 883080, 1040240, 1209208, 1389984, 1582568, + 1786960, 1911000, 719744, 865096, 1022256, 1191224, 1372000, 1564584, 1768976, 1985176, 2213184, 2453000, + 2704624, 2968056, 3243296, 3530344, 3829200, 3949400, 0, 55272, 123376, 204312, 298080, 404680, 524112, + 656376, 816816, 959400, 1130160, 1313752, 1510176, 1719432, 1941520, 2076600, 781696, 939624, 1110384, + 1293976, 1490400, 1699656, 1921744, 2156664, 2404416, 2665000, 2938416, 3224664, 3523744, 3835656, + 4160400, 4291640, 0, 59656, 133168, 220536, 321760, 436840, 565776, 708568, 881840, 1035720, 1220080, + 1418296, 1630368, 1856296, 2096080, 2242200, 843648, 1014152, 1198512, 1396728, 1608800, 1834728, 2074512, + 2328152, 2595648, 2877000, 3172208, 3481272, 3804192, 4140968, 4491600, 4633880, 0, 64040, 142960, 236760, + 345440, 469000, 607440, 760760, 946864, 1112040, 1310000, 1522840, 1750560, 1993160, 2250640, 2407800, + 905600, 1088680, 1286640, 1499480, 1727200, 1969800, 2227280, 2499640, 2786880, 3089000, 3406000, 3737880, + 4084640, 4446280, 4822800, 4976120, 0, 68424, 152752, 252984, 369120, 501160, 649104, 812952, 1011888, + 1188360, 1399920, 1627384, 1870752, 2130024, 2405200, 2573400, 967552, 1163208, 1374768, 1602232, 1845600, + 2104872, 2380048, 2671128, 2978112, 3301000, 3639792, 3994488, 4365088, 4751592, 5154000, 5318360, 0, + 72808, 162544, 269208, 392800, 533320, 690768, 865144, 1076912, 1264680, 1489840, 1731928, 1990944, + 2266888, 2559760, 2739000, 1029504, 1237736, 1462896, 1704984, 1964000, 2239944, 2532816, 2842616, + 3169344, 3513000, 3873584, 4251096, 4645536, 5056904, 5485200, 5660600, 0, 77192, 172336, 285432, 416480, + 565480, 732432, 917336, 1141936, 1341000, 1579760, 1836472, 2111136, 2403752, 2714320, 2904600, 1091456, + 1312264, 1551024, 1807736, 2082400, 2375016, 2685584, 3014104, 3360576, 3725000, 4107376, 4507704, + 4925984, 5362216, 5816400, 6002840, 0, 81576, 182128, 301656, 440160, 597640, 774096, 969528, 1206960, + 1417320, 1669680, 1941016, 2231328, 2540616, 2868880, 3070200, 1153408, 1386792, 1639152, 1910488, + 2200800, 2510088, 2838352, 3185592, 3551808, 3937000, 4341168, 4764312, 5206432, 5667528, 6147600, + 6345080, 0, 85960, 191920, 317880, 463840, 629800, 815760, 1021720, 1271984, 1493640, 1759600, 2045560, + 2351520, 2677480, 3023440, 3235800, 1215360, 1461320, 1727280, 2013240, 2319200, 2645160, 2991120, + 3357080, 3743040, 4149000, 4574960, 5020920, 5486880, 5972840, 6478800, 6687320, 0, 90344, 201712, 334104, + 487520, 661960, 857424, 1073912, 1337008, 1569960, 1849520, 2150104, 2471712, 2814344, 3178000, 3401400, + 1277312, 1535848, 1815408, 2115992, 2437600, 2780232, 3143888, 3528568, 3934272, 4361000, 4808752, + 5277528, 5767328, 6278152, 6810000, 7029560, 0, 94728, 211504, 350328, 511200, 694120, 899088, 1126104, + 1402032, 1646280, 1939440, 2254648, 2591904, 2951208, 3332560, 3567000, 1339264, 1610376, 1903536, + 2218744, 2556000, 2915304, 3296656, 3700056, 4125504, 4573000, 5042544, 5534136, 6047776, 6583464, + 7141200, 7371800, 0, 99112, 221296, 366552, 534880, 726280, 940752, 1178296, 1467056, 1722600, 2029360, + 2359192, 2712096, 3088072, 3487120, 3732600, 1401216, 1684904, 1991664, 2321496, 2674400, 3050376, + 3449424, 3871544, 4316736, 4785000, 5276336, 5790744, 6328224, 6888776, 7472400, 7714040, 0, 103496, + 231088, 382776, 558560, 758440, 982416, 1230488, 1532080, 1798920, 2119280, 2463736, 2832288, 3224936, + 3641680, 3898200, 1463168, 1759432, 2079792, 2424248, 2792800, 3185448, 3602192, 4043032, 4507968, + 4997000, 5510128, 6047352, 6608672, 7194088, 7803600, 8056280, 0, 107880, 240880, 399000, 582240, 790600, + 1024080, 1282680, 1597104, 1875240, 2209200, 2568280, 2952480, 3361800, 3796240, 4063800, 1525120, + 1833960, 2167920, 2527000, 2911200, 3320520, 3754960, 4214520, 4699200, 5209000, 5743920, 6303960, + 6889120, 7499400, 8134800, 8398520, 0, 112264, 250672, 415224, 605920, 822760, 1065744, 1334872, 1662128, + 1951560, 2299120, 2672824, 3072672, 3498664, 3950800, 4229400, 1587072, 1908488, 2256048, 2629752, + 3029600, 3455592, 3907728, 4386008, 4890432, 5421000, 5977712, 6560568, 7169568, 7804712, 8466000, + 8740760, 0, 116648, 260464, 431448, 629600, 854920, 1107408, 1387064, 1727152, 2027880, 2389040, 2777368, + 3192864, 3635528, 4105360, 4395000, 1649024, 1983016, 2344176, 2732504, 3148000, 3590664, 4060496, + 4557496, 5081664, 5633000, 6211504, 6817176, 7450016, 8110024, 8797200, 9083000, 0, 121032, 270256, + 447672, 653280, 887080, 1149072, 1439256, 1792176, 2104200, 2478960, 2881912, 3313056, 3772392, 4259920, + 4560600, 1710976, 2057544, 2432304, 2835256, 3266400, 3725736, 4213264, 4728984, 5272896, 5845000, + 6445296, 7073784, 7730464, 8415336, 9128400, 9425240, 0, 125416, 280048, 463896, 676960, 919240, 1190736, + 1491448, 1857200, 2180520, 2568880, 2986456, 3433248, 3909256, 4414480, 4726200, 1772928, 2132072, + 2520432, 2938008, 3384800, 3860808, 4366032, 4900472, 5464128, 6057000, 6679088, 7330392, 8010912, + 8720648, 9459600, 9767480, 0, 129800, 289840, 480120, 700640, 951400, 1232400, 1543640, 1922224, 2256840, + 2658800, 3091000, 3553440, 4046120, 4569040, 4891800, 1834880, 2206600, 2608560, 3040760, 3503200, + 3995880, 4518800, 5071960, 5655360, 6269000, 6912880, 7587000, 8291360, 9025960, 9790800, 10109720, 0, + 134184, 299632, 496344, 724320, 983560, 1274064, 1595832, 1987248, 2333160, 2748720, 3195544, 3673632, + 4182984, 4723600, 5057400, 1896832, 2281128, 2696688, 3143512, 3621600, 4130952, 4671568, 5243448, + 5846592, 6481000, 7146672, 7843608, 8571808, 9331272, 10122000, 10451960, 0, 138568, 309424, 512568, + 748000, 1015720, 1315728, 1648024, 2052272, 2409480, 2838640, 3300088, 3793824, 4319848, 4878160, 5223000, + 1958784, 2355656, 2784816, 3246264, 3740000, 4266024, 4824336, 5414936, 6037824, 6693000, 7380464, + 8100216, 8852256, 9636584, 10453200, 10794200 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4, batchDim = [1]", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4, batchDim = [1]; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ], + "dims": [1, 8, 16], + "type": "float32" + }, + { + "dims": [8, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + }, + { + "dims": [1, 8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [1, 8, 8], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0, + -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, + 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232, + -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032, + -16405, -48288, -16247 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4, batchDim = [1, 2]", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; symmetric, batchDim = [1, 2]", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [1, 2, 8, 16], + "type": "float32" + }, + { + "dims": [8, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + }, + { + "dims": [1, 8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [1, 2, 8, 8], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0, + -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, + 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232, + -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032, + -16405, -48288, -16247, 0, -5889, -22624, -14403, -40896, -18565, -54816, -18375, 0, -6577, -25312, + -16083, -45760, -20725, -61344, -20503, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, 0, + -7953, -30688, -19443, -55488, -25045, -74400, -24759, 0, -8641, -33376, -21123, -60352, -27205, -80928, + -26887, 0, -9329, -36064, -22803, -65216, -29365, -87456, -29015, 0, -10017, -38752, -24483, -70080, + -31525, -93984, -31143, 0, -10705, -41440, -26163, -74944, -33685, -100512, -33271 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; output shape = 8 X 16; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ], + "dims": [8, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [8, 16], + "type": "float32", + "data": [ + 0, 728, 688, 2376, 1632, 4280, 2832, 6440, 4288, 8856, 6000, 11528, 7968, 14456, 10192, 17640, 0, 2200, + 1840, 7176, 4448, 12920, 7824, 19432, 11968, 26712, 16880, 34760, 22560, 43576, 29008, 53160, 0, 3672, + 2992, 11976, 7264, 21560, 12816, 32424, 19648, 44568, 27760, 57992, 37152, 72696, 47824, 88680, 0, 5144, + 4144, 16776, 10080, 30200, 17808, 45416, 27328, 62424, 38640, 81224, 51744, 101816, 66640, 124200, 0, + 6616, 5296, 21576, 12896, 38840, 22800, 58408, 35008, 80280, 49520, 104456, 66336, 130936, 85456, 159720, + 0, 8088, 6448, 26376, 15712, 47480, 27792, 71400, 42688, 98136, 60400, 127688, 80928, 160056, 104272, + 195240, 0, 9560, 7600, 31176, 18528, 56120, 32784, 84392, 50368, 115992, 71280, 150920, 95520, 189176, + 123088, 230760, 0, 11032, 8752, 35976, 21344, 64760, 37776, 97384, 58048, 133848, 82160, 174152, 110112, + 218296, 141904, 266280 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; output shape = 16 X 8; K=16, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [8, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + }, + { + "dims": [8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [16, 8], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0, + -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, + 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232, + -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032, + -16405, -48288, -16247, 0, -5889, -22624, -14403, -40896, -18565, -54816, -18375, 0, -6577, -25312, + -16083, -45760, -20725, -61344, -20503, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, 0, + -7953, -30688, -19443, -55488, -25045, -74400, -24759, 0, -8641, -33376, -21123, -60352, -27205, -80928, + -26887, 0, -9329, -36064, -22803, -65216, -29365, -87456, -29015, 0, -10017, -38752, -24483, -70080, + -31525, -93984, -31143, 0, -10705, -41440, -26163, -74944, -33685, -100512, -33271 + ] + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/rotary-embedding.jsonc b/js/web/test/data/ops/rotary-embedding.jsonc new file mode 100644 index 000000000000..1b564ecc7740 --- /dev/null +++ b/js/web/test/data/ops/rotary-embedding.jsonc @@ -0,0 +1,925 @@ +[ + { + "name": "RotaryEmbedding with no attributes", + "operator": "RotaryEmbedding", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [], + "cases": [ + { + "name": "T[2,8,24] T[1] T[16,3] T[16,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, -0.5912, 1.1312, 0.7562, -1.2023, -0.5833, -0.4407, + 0.1766, 1.0224, -0.4826, -0.5421, -0.5342, -0.6413, 1.3314, -0.4498, 0.5493, 0.0539, 0.2601, 0.857, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -1.9791, 0.7787, -0.7749, -0.1398, 1.1414, -0.6354, + 0.0352, -0.4765, -0.0409, 1.1993, 0.5374, -0.193, 2.5211, -0.0452, -0.3105, -0.9407, -0.0034, 1.5199, + -0.848, 0.5266, 0.0299, -0.0498, 1.0651, 0.886, -1.4702, -0.2134, -0.8707, 1.6159, -0.2356, 0.9444, + 0.5937, 0.7203, 0.5061, 1.5192, -0.4897, 0.9231, 0.2654, -0.1441, 0.5407, -1.5476, 0.6455, -1.1382, 0.464, + -0.4986, 0.1289, 2.7631, 0.1405, 1.1191, 2.1134, -0.9754, 0.1757, -0.1319, -0.2735, 0.3355, -0.6008, + -1.1164, 0.2577, -0.7226, -0.9244, 1.8737, 0.6052, 1.1904, 1.2195, -0.047, -1.0914, 1.0223, 0.3152, + 1.7528, -0.765, 1.8299, -0.2784, -0.2719, 0.1885, 2.1432, 0.8527, 0.0965, -0.0625, 0.8269, 1.0122, + -1.4482, -0.0644, 0.3215, 0.5908, -1.4197, 0.2113, 0.0306, 0.3604, 0.3166, -0.8975, -0.6393, -1.2944, + -0.0243, -0.2354, -0.7087, 1.1566, 0.4296, 0.5599, -0.7776, 0.3339, 0.1759, 2.1108, 1.0702, 0.8279, + -0.2969, 0.712, -0.2068, -0.1548, 0.1553, 0.6207, -0.169, -0.5816, 1.2632, 0.0695, 1.1862, -1.1874, + -0.7468, -0.932, -0.8579, -0.9647, -0.0991, 0.0195, 1.1213, -1.4873, -0.2043, -1.0466, -1.5772, -0.0489, + 0.343, 0.1264, 0.1519, -1.3639, -1.6593, 1.8127, -1.4459, -0.2158, -0.9792, -1.4392, 0.6508, 0.8964, + 0.5717, -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, + -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, -0.5996, + -1.0962, 1.6327, 1.3951, 0.8784, 0.3389, 1.2907, 0.3124, 0.7299, 1.422, 0.3375, 0.0438, 1.8698, -0.2635, + -2.0799, -0.6313, 0.409, -1.1458, 0.0784, -1.8848, -1.6165, 0.6179, 0.9905, -0.0729, 0.5054, -0.6681, + -1.4382, 1.7547, -0.9605, -0.4558, -1.6105, 0.2979, 1.1537, -1.5604, 1.2779, -1.2514, 0.6056, 0.5763, + -3.3558, 0.2836, 0.6909, -0.7631, 2.4451, -0.35, 1.3289, -0.6494, 0.3478, 1.0038, -0.2937, 0.9238, + -1.2185, 0.4138, 0.5033, 0.9174, 1.8131, 1.4436, -0.4207, 0.022, -0.6807, -1.3306, 1.5646, 0.3338, 0.7105, + 0.4683, -0.6179, 0.0818, -0.0488, -0.981, -1.3632, 0.0929, -1.7926, -0.2921, -0.4792, 0.6756, -0.3413, + -0.2242, -0.2111, 0.6282, 0.1667, -1.4055, 1.5895, 1.0838, -0.9077, -0.806, 0.7967, -2.9351, 2.4179, + -0.4026, 0.6451, 1.6845, -0.0901, 0.6106, 2.3603, 1.3908, -0.7917, -0.6734, -0.1213, -1.1116, -0.7401, + -0.7879, 0.0606, -2.3337, -1.2603, -1.7245, -0.3533, -0.9421, -0.1776, 0.3992, -1.7142, -0.5319, -0.8848, + 0.6513, 1.0002, -1.4699, -1.4254, 0.7013, 0.2414, 0.2551, -0.7457, 0.3133, -1.0941, -0.3682, -0.0163, + -0.0645, -0.8101, 0.1415, 0.0551, 0.5873, -0.5887, -1.4733, -0.8565, 0.74, -0.5033, 0.0553, 0.9265, + -0.8652, -0.0288, -0.2209, 0.061, 0.6776, 0.4361, -0.8052, 0.3955, 0.8988, 0.8238, 0.2262, 1.2912, 0.6488, + 1.2114, 1.3569, 0.2983, 0.4718, -1.1936, 0.7928, -0.8665, 0.9468, 1.1629, 0.0616, -1.3136, -0.2764, + 0.0277, -0.1126, 0.2342, -0.5866, -1.8219, 1.1079, 0.5795, -1.4249 + ], + "dims": [2, 8, 24], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0, -0.6536, 0.9828, 1.0, + 0.2837, 0.9732, 0.9999, 0.9602, 0.9615, 0.9999, 0.7539, 0.9477, 0.9999, -0.1455, 0.9318, 0.9999, -0.9111, + 0.914, 0.9998, -0.8391, 0.8942, 0.9998, 0.0044, 0.8725, 0.9997, 0.8439, 0.8488, 0.9997, 0.9074, 0.8234, + 0.9996, 0.1367, 0.7962, 0.9995, -0.7597, 0.7673, 0.9995 + ], + "dims": [16, 3], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065, -0.7568, 0.1846, + 0.0086, -0.9589, 0.23, 0.0108, -0.2794, 0.2749, 0.0129, 0.657, 0.3192, 0.0151, 0.9894, 0.3629, 0.0172, + 0.4121, 0.4057, 0.0194, -0.544, 0.4477, 0.0215, -1.0, 0.4887, 0.0237, -0.5366, 0.5286, 0.0259, 0.4202, + 0.5675, 0.028, 0.9906, 0.605, 0.0302, 0.6503, 0.6413, 0.0323 + ], + "dims": [16, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -0.8618, -0.0922, -0.9073, -0.7032, -0.5762, -0.2371, 0.6923, 1.1571, 0.7572, -1.1471, -0.5302, -0.4391, + 0.5516, 1.0461, -0.4812, -0.1443, -0.4862, -0.6423, 0.674, -0.4614, 0.5475, 1.1495, 0.2389, 0.8582, + -0.0259, -0.6099, -0.223, 1.0963, -1.5704, -0.4595, 0.9507, 0.6696, -0.7721, -1.7415, 1.2087, -0.6387, + -1.1052, -0.5243, -0.04, -0.4671, 0.4909, -0.1931, -0.1937, -0.0447, -0.3171, 2.6839, -0.0076, 1.5185, + 0.8465, 0.3737, 0.0242, -0.0703, 1.1279, 0.8862, 1.2275, -0.1786, -0.8767, -1.8072, -0.263, 0.9387, + -0.8021, 0.7813, 0.5001, -1.4202, -0.385, 0.9263, -0.0443, -0.2323, 0.548, 1.5696, 0.6193, -1.1346, + 1.7878, -0.516, 0.1192, -2.1572, 0.046, 1.1202, -1.4812, -0.9082, 0.1728, -1.5132, -0.4489, 0.337, + -0.1541, -0.9266, 0.2416, 0.927, -1.1146, 1.8758, -0.4312, 1.3714, 1.2106, -0.4272, -0.8529, 1.0328, + 1.8441, 1.7698, -0.762, 0.2168, 0.1322, -0.2802, 0.146, 2.1002, 0.8437, -0.1534, 0.4321, 0.836, 0.5955, + -1.5452, -0.0491, -0.8794, 0.2418, -1.4203, 0.3635, 0.2362, 0.3672, -0.1128, -0.8664, -0.6354, -1.4409, + -0.3413, -0.2409, -0.3188, 1.1054, 0.4265, 0.5867, -1.3279, 0.3201, 0.0125, 1.8157, 1.0745, 0.7372, + -0.2429, 0.71, -0.4299, -0.2304, 0.1645, 0.9489, -0.1816, -0.5968, 1.0394, 0.0204, 1.1786, -0.3315, + -0.3997, -0.9304, -1.4268, -1.1526, -0.1132, 0.149, 1.3967, -1.4634, -0.1412, -0.6339, -1.5995, -0.1366, + 0.7604, 0.1514, 0.0824, -1.183, -1.6572, 2.0099, -0.9108, -0.2256, 0.4527, -1.8254, 0.6475, 0.8964, + 0.5717, -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, + -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, -1.4979, + -1.1358, 1.632, 0.2493, 0.8266, 0.3424, -0.4992, 0.2964, 0.7298, 1.8544, 0.3516, 0.0454, 1.5415, -0.2822, + -2.0774, 1.2323, 0.3963, -1.1503, -0.4775, -1.9287, -1.6164, 0.3998, 0.902, -0.0764, -1.8059, -0.5762, + -1.4362, -0.2706, -1.0183, -0.462, 2.0891, 0.1782, 1.1591, -0.8151, 1.3, -1.2464, -0.5099, 0.5098, + -3.3525, 0.4326, 0.7414, -0.7775, -0.4271, -0.3807, 1.3245, 2.4936, 0.3139, 1.0095, 0.2323, 0.845, + -1.2244, -0.4511, 0.6266, 0.9095, -1.7981, 1.5241, -0.4121, 0.2341, -0.4737, -1.3333, -1.615, 0.4164, + 0.71, -0.2429, -0.5656, 0.0863, 0.0352, -0.7227, -1.3613, -0.0988, -1.9114, -0.3009, 0.1435, 0.7029, + -0.3467, 0.5092, -0.0828, 0.6253, 0.7113, -1.2138, 1.5964, -0.8346, -1.1515, -0.7923, -0.8254, -3.0038, + 2.4033, -0.3398, 0.0922, 1.7053, 1.1114, 0.7462, 2.366, -0.8409, -0.6654, -0.653, -0.7899, -1.0957, + -0.7149, -0.1072, -0.1967, -2.3416, -1.2609, -1.6375, -0.3576, 0.9413, -0.5694, 0.3954, 0.1383, -0.7477, + -0.8689, 1.8286, 0.851, -1.4793, -0.1597, 0.8541, 0.238, 1.4392, -0.5644, 0.3158, -1.0686, -0.1313, + -0.0181, 0.2438, -0.8801, 0.1413, -0.3587, 0.8002, -0.5982, -1.4301, -0.662, 0.7324, -0.725, 0.061, + 0.9293, -0.6902, -0.0125, -0.2089, -0.1664, 0.5428, 0.4245, -0.7901, 0.5665, 0.9044, 0.1948, -0.1723, + 1.2705, 1.0303, 1.2202, 1.3762, -0.2959, 0.7237, -1.2077, 0.7937, -0.6705, 0.9287, 1.0583, 0.0496, + -1.3118, 0.5556, 0.0459, -0.1324, -0.5513, -0.7409, -1.8002, 0.9892, 0.3619, -1.4522 + ], + "dims": [2, 8, 24], + "type": "float32" + } + ] + }, + { + "name": "T[2,8,24] Scalar T[16,3] T[16,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, -0.5912, 1.1312, 0.7562, -1.2023, -0.5833, -0.4407, + 0.1766, 1.0224, -0.4826, -0.5421, -0.5342, -0.6413, 1.3314, -0.4498, 0.5493, 0.0539, 0.2601, 0.857, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -1.9791, 0.7787, -0.7749, -0.1398, 1.1414, -0.6354, + 0.0352, -0.4765, -0.0409, 1.1993, 0.5374, -0.193, 2.5211, -0.0452, -0.3105, -0.9407, -0.0034, 1.5199, + -0.848, 0.5266, 0.0299, -0.0498, 1.0651, 0.886, -1.4702, -0.2134, -0.8707, 1.6159, -0.2356, 0.9444, + 0.5937, 0.7203, 0.5061, 1.5192, -0.4897, 0.9231, 0.2654, -0.1441, 0.5407, -1.5476, 0.6455, -1.1382, 0.464, + -0.4986, 0.1289, 2.7631, 0.1405, 1.1191, 2.1134, -0.9754, 0.1757, -0.1319, -0.2735, 0.3355, -0.6008, + -1.1164, 0.2577, -0.7226, -0.9244, 1.8737, 0.6052, 1.1904, 1.2195, -0.047, -1.0914, 1.0223, 0.3152, + 1.7528, -0.765, 1.8299, -0.2784, -0.2719, 0.1885, 2.1432, 0.8527, 0.0965, -0.0625, 0.8269, 1.0122, + -1.4482, -0.0644, 0.3215, 0.5908, -1.4197, 0.2113, 0.0306, 0.3604, 0.3166, -0.8975, -0.6393, -1.2944, + -0.0243, -0.2354, -0.7087, 1.1566, 0.4296, 0.5599, -0.7776, 0.3339, 0.1759, 2.1108, 1.0702, 0.8279, + -0.2969, 0.712, -0.2068, -0.1548, 0.1553, 0.6207, -0.169, -0.5816, 1.2632, 0.0695, 1.1862, -1.1874, + -0.7468, -0.932, -0.8579, -0.9647, -0.0991, 0.0195, 1.1213, -1.4873, -0.2043, -1.0466, -1.5772, -0.0489, + 0.343, 0.1264, 0.1519, -1.3639, -1.6593, 1.8127, -1.4459, -0.2158, -0.9792, -1.4392, 0.6508, 0.8964, + 0.5717, -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, + -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, -0.5996, + -1.0962, 1.6327, 1.3951, 0.8784, 0.3389, 1.2907, 0.3124, 0.7299, 1.422, 0.3375, 0.0438, 1.8698, -0.2635, + -2.0799, -0.6313, 0.409, -1.1458, 0.0784, -1.8848, -1.6165, 0.6179, 0.9905, -0.0729, 0.5054, -0.6681, + -1.4382, 1.7547, -0.9605, -0.4558, -1.6105, 0.2979, 1.1537, -1.5604, 1.2779, -1.2514, 0.6056, 0.5763, + -3.3558, 0.2836, 0.6909, -0.7631, 2.4451, -0.35, 1.3289, -0.6494, 0.3478, 1.0038, -0.2937, 0.9238, + -1.2185, 0.4138, 0.5033, 0.9174, 1.8131, 1.4436, -0.4207, 0.022, -0.6807, -1.3306, 1.5646, 0.3338, 0.7105, + 0.4683, -0.6179, 0.0818, -0.0488, -0.981, -1.3632, 0.0929, -1.7926, -0.2921, -0.4792, 0.6756, -0.3413, + -0.2242, -0.2111, 0.6282, 0.1667, -1.4055, 1.5895, 1.0838, -0.9077, -0.806, 0.7967, -2.9351, 2.4179, + -0.4026, 0.6451, 1.6845, -0.0901, 0.6106, 2.3603, 1.3908, -0.7917, -0.6734, -0.1213, -1.1116, -0.7401, + -0.7879, 0.0606, -2.3337, -1.2603, -1.7245, -0.3533, -0.9421, -0.1776, 0.3992, -1.7142, -0.5319, -0.8848, + 0.6513, 1.0002, -1.4699, -1.4254, 0.7013, 0.2414, 0.2551, -0.7457, 0.3133, -1.0941, -0.3682, -0.0163, + -0.0645, -0.8101, 0.1415, 0.0551, 0.5873, -0.5887, -1.4733, -0.8565, 0.74, -0.5033, 0.0553, 0.9265, + -0.8652, -0.0288, -0.2209, 0.061, 0.6776, 0.4361, -0.8052, 0.3955, 0.8988, 0.8238, 0.2262, 1.2912, 0.6488, + 1.2114, 1.3569, 0.2983, 0.4718, -1.1936, 0.7928, -0.8665, 0.9468, 1.1629, 0.0616, -1.3136, -0.2764, + 0.0277, -0.1126, 0.2342, -0.5866, -1.8219, 1.1079, 0.5795, -1.4249 + ], + "dims": [2, 8, 24], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0, -0.6536, 0.9828, 1.0, + 0.2837, 0.9732, 0.9999, 0.9602, 0.9615, 0.9999, 0.7539, 0.9477, 0.9999, -0.1455, 0.9318, 0.9999, -0.9111, + 0.914, 0.9998, -0.8391, 0.8942, 0.9998, 0.0044, 0.8725, 0.9997, 0.8439, 0.8488, 0.9997, 0.9074, 0.8234, + 0.9996, 0.1367, 0.7962, 0.9995, -0.7597, 0.7673, 0.9995 + ], + "dims": [16, 3], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065, -0.7568, 0.1846, + 0.0086, -0.9589, 0.23, 0.0108, -0.2794, 0.2749, 0.0129, 0.657, 0.3192, 0.0151, 0.9894, 0.3629, 0.0172, + 0.4121, 0.4057, 0.0194, -0.544, 0.4477, 0.0215, -1.0, 0.4887, 0.0237, -0.5366, 0.5286, 0.0259, 0.4202, + 0.5675, 0.028, 0.9906, 0.605, 0.0302, 0.6503, 0.6413, 0.0323 + ], + "dims": [16, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -0.8618, -0.0922, -0.9073, -0.7032, -0.5762, -0.2371, 0.6923, 1.1571, 0.7572, -1.1471, -0.5302, -0.4391, + 0.5516, 1.0461, -0.4812, -0.1443, -0.4862, -0.6423, 0.674, -0.4614, 0.5475, 1.1495, 0.2389, 0.8582, + -0.0259, -0.6099, -0.223, 1.0963, -1.5704, -0.4595, 0.9507, 0.6696, -0.7721, -1.7415, 1.2087, -0.6387, + -1.1052, -0.5243, -0.04, -0.4671, 0.4909, -0.1931, -0.1937, -0.0447, -0.3171, 2.6839, -0.0076, 1.5185, + 0.8465, 0.3737, 0.0242, -0.0703, 1.1279, 0.8862, 1.2275, -0.1786, -0.8767, -1.8072, -0.263, 0.9387, + -0.8021, 0.7813, 0.5001, -1.4202, -0.385, 0.9263, -0.0443, -0.2323, 0.548, 1.5696, 0.6193, -1.1346, + 1.7878, -0.516, 0.1192, -2.1572, 0.046, 1.1202, -1.4812, -0.9082, 0.1728, -1.5132, -0.4489, 0.337, + -0.1541, -0.9266, 0.2416, 0.927, -1.1146, 1.8758, -0.4312, 1.3714, 1.2106, -0.4272, -0.8529, 1.0328, + 1.8441, 1.7698, -0.762, 0.2168, 0.1322, -0.2802, 0.146, 2.1002, 0.8437, -0.1534, 0.4321, 0.836, 0.5955, + -1.5452, -0.0491, -0.8794, 0.2418, -1.4203, 0.3635, 0.2362, 0.3672, -0.1128, -0.8664, -0.6354, -1.4409, + -0.3413, -0.2409, -0.3188, 1.1054, 0.4265, 0.5867, -1.3279, 0.3201, 0.0125, 1.8157, 1.0745, 0.7372, + -0.2429, 0.71, -0.4299, -0.2304, 0.1645, 0.9489, -0.1816, -0.5968, 1.0394, 0.0204, 1.1786, -0.3315, + -0.3997, -0.9304, -1.4268, -1.1526, -0.1132, 0.149, 1.3967, -1.4634, -0.1412, -0.6339, -1.5995, -0.1366, + 0.7604, 0.1514, 0.0824, -1.183, -1.6572, 2.0099, -0.9108, -0.2256, 0.4527, -1.8254, 0.6475, 0.8964, + 0.5717, -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, + -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, -1.4979, + -1.1358, 1.632, 0.2493, 0.8266, 0.3424, -0.4992, 0.2964, 0.7298, 1.8544, 0.3516, 0.0454, 1.5415, -0.2822, + -2.0774, 1.2323, 0.3963, -1.1503, -0.4775, -1.9287, -1.6164, 0.3998, 0.902, -0.0764, -1.8059, -0.5762, + -1.4362, -0.2706, -1.0183, -0.462, 2.0891, 0.1782, 1.1591, -0.8151, 1.3, -1.2464, -0.5099, 0.5098, + -3.3525, 0.4326, 0.7414, -0.7775, -0.4271, -0.3807, 1.3245, 2.4936, 0.3139, 1.0095, 0.2323, 0.845, + -1.2244, -0.4511, 0.6266, 0.9095, -1.7981, 1.5241, -0.4121, 0.2341, -0.4737, -1.3333, -1.615, 0.4164, + 0.71, -0.2429, -0.5656, 0.0863, 0.0352, -0.7227, -1.3613, -0.0988, -1.9114, -0.3009, 0.1435, 0.7029, + -0.3467, 0.5092, -0.0828, 0.6253, 0.7113, -1.2138, 1.5964, -0.8346, -1.1515, -0.7923, -0.8254, -3.0038, + 2.4033, -0.3398, 0.0922, 1.7053, 1.1114, 0.7462, 2.366, -0.8409, -0.6654, -0.653, -0.7899, -1.0957, + -0.7149, -0.1072, -0.1967, -2.3416, -1.2609, -1.6375, -0.3576, 0.9413, -0.5694, 0.3954, 0.1383, -0.7477, + -0.8689, 1.8286, 0.851, -1.4793, -0.1597, 0.8541, 0.238, 1.4392, -0.5644, 0.3158, -1.0686, -0.1313, + -0.0181, 0.2438, -0.8801, 0.1413, -0.3587, 0.8002, -0.5982, -1.4301, -0.662, 0.7324, -0.725, 0.061, + 0.9293, -0.6902, -0.0125, -0.2089, -0.1664, 0.5428, 0.4245, -0.7901, 0.5665, 0.9044, 0.1948, -0.1723, + 1.2705, 1.0303, 1.2202, 1.3762, -0.2959, 0.7237, -1.2077, 0.7937, -0.6705, 0.9287, 1.0583, 0.0496, + -1.3118, 0.5556, 0.0459, -0.1324, -0.5513, -0.7409, -1.8002, 0.9892, 0.3619, -1.4522 + ], + "dims": [2, 8, 24], + "type": "float32" + } + ] + }, + { + "name": "T[2,4,8,6] T[1] T[16,3] T[16,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -0.848, 0.5266, 0.0299, -0.0498, 1.0651, 0.886, 0.464, + -0.4986, 0.1289, 2.7631, 0.1405, 1.1191, 0.3152, 1.7528, -0.765, 1.8299, -0.2784, -0.2719, -1.2944, + -0.0243, -0.2354, -0.7087, 1.1566, 0.4296, -1.1874, -0.7468, -0.932, -0.8579, -0.9647, -0.0991, -1.019, + 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, -0.5912, 1.1312, 0.7562, -1.2023, -0.5833, -0.4407, -1.9791, + 0.7787, -0.7749, -0.1398, 1.1414, -0.6354, -1.4702, -0.2134, -0.8707, 1.6159, -0.2356, 0.9444, 2.1134, + -0.9754, 0.1757, -0.1319, -0.2735, 0.3355, 0.1885, 2.1432, 0.8527, 0.0965, -0.0625, 0.8269, 0.5599, + -0.7776, 0.3339, 0.1759, 2.1108, 1.0702, 0.0195, 1.1213, -1.4873, -0.2043, -1.0466, -1.5772, 0.1036, + -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 0.1766, 1.0224, -0.4826, -0.5421, -0.5342, -0.6413, 0.0352, + -0.4765, -0.0409, 1.1993, 0.5374, -0.193, 0.5937, 0.7203, 0.5061, 1.5192, -0.4897, 0.9231, -0.6008, + -1.1164, 0.2577, -0.7226, -0.9244, 1.8737, 1.0122, -1.4482, -0.0644, 0.3215, 0.5908, -1.4197, 0.8279, + -0.2969, 0.712, -0.2068, -0.1548, 0.1553, -0.0489, 0.343, 0.1264, 0.1519, -1.3639, -1.6593, 1.0311, + -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, 1.3314, -0.4498, 0.5493, 0.0539, 0.2601, 0.857, 2.5211, + -0.0452, -0.3105, -0.9407, -0.0034, 1.5199, 0.2654, -0.1441, 0.5407, -1.5476, 0.6455, -1.1382, 0.6052, + 1.1904, 1.2195, -0.047, -1.0914, 1.0223, 0.2113, 0.0306, 0.3604, 0.3166, -0.8975, -0.6393, 0.6207, -0.169, + -0.5816, 1.2632, 0.0695, 1.1862, 1.8127, -1.4459, -0.2158, -0.9792, -1.4392, 0.6508, 0.8964, 0.5717, + -0.239, 0.6983, -1.3416, 0.2715, -0.5996, -1.0962, 1.6327, 1.3951, 0.8784, 0.3389, 0.5054, -0.6681, + -1.4382, 1.7547, -0.9605, -0.4558, -0.2937, 0.9238, -1.2185, 0.4138, 0.5033, 0.9174, -0.4792, 0.6756, + -0.3413, -0.2242, -0.2111, 0.6282, -0.1213, -1.1116, -0.7401, -0.7879, 0.0606, -2.3337, -1.0941, -0.3682, + -0.0163, -0.0645, -0.8101, 0.1415, 0.8238, 0.2262, 1.2912, 0.6488, 1.2114, 1.3569, -0.2852, 0.6051, + 0.2167, -0.2181, -1.6306, 1.4788, 1.2907, 0.3124, 0.7299, 1.422, 0.3375, 0.0438, -1.6105, 0.2979, 1.1537, + -1.5604, 1.2779, -1.2514, 1.8131, 1.4436, -0.4207, 0.022, -0.6807, -1.3306, 0.1667, -1.4055, 1.5895, + 1.0838, -0.9077, -0.806, -1.2603, -1.7245, -0.3533, -0.9421, -0.1776, 0.3992, 0.0551, 0.5873, -0.5887, + -1.4733, -0.8565, 0.74, 0.2983, 0.4718, -1.1936, 0.7928, -0.8665, 0.9468, 0.2754, -0.0261, -0.4618, + -0.5646, -1.0389, 0.5819, 1.8698, -0.2635, -2.0799, -0.6313, 0.409, -1.1458, 0.6056, 0.5763, -3.3558, + 0.2836, 0.6909, -0.7631, 1.5646, 0.3338, 0.7105, 0.4683, -0.6179, 0.0818, 0.7967, -2.9351, 2.4179, + -0.4026, 0.6451, 1.6845, -1.7142, -0.5319, -0.8848, 0.6513, 1.0002, -1.4699, -0.5033, 0.0553, 0.9265, + -0.8652, -0.0288, -0.2209, 1.1629, 0.0616, -1.3136, -0.2764, 0.0277, -0.1126, 1.3697, 0.0002, 1.5333, + -1.0556, -0.1254, 0.1527, 0.0784, -1.8848, -1.6165, 0.6179, 0.9905, -0.0729, 2.4451, -0.35, 1.3289, + -0.6494, 0.3478, 1.0038, -0.0488, -0.981, -1.3632, 0.0929, -1.7926, -0.2921, -0.0901, 0.6106, 2.3603, + 1.3908, -0.7917, -0.6734, -1.4254, 0.7013, 0.2414, 0.2551, -0.7457, 0.3133, 0.061, 0.6776, 0.4361, + -0.8052, 0.3955, 0.8988, 0.2342, -0.5866, -1.8219, 1.1079, 0.5795, -1.4249 + ], + "dims": [2, 4, 8, 6], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0, -0.6536, 0.9828, 1.0, + 0.2837, 0.9732, 0.9999, 0.9602, 0.9615, 0.9999, 0.7539, 0.9477, 0.9999, -0.1455, 0.9318, 0.9999, -0.9111, + 0.914, 0.9998, -0.8391, 0.8942, 0.9998, 0.0044, 0.8725, 0.9997, 0.8439, 0.8488, 0.9997, 0.9074, 0.8234, + 0.9996, 0.1367, 0.7962, 0.9995, -0.7597, 0.7673, 0.9995 + ], + "dims": [16, 3], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065, -0.7568, 0.1846, + 0.0086, -0.9589, 0.23, 0.0108, -0.2794, 0.2749, 0.0129, 0.657, 0.3192, 0.0151, 0.9894, 0.3629, 0.0172, + 0.4121, 0.4057, 0.0194, -0.544, 0.4477, 0.0215, -1.0, 0.4887, 0.0237, -0.5366, 0.5286, 0.0259, 0.4202, + 0.5675, 0.028, 0.9906, 0.605, 0.0302, 0.6503, 0.6413, 0.0323 + ], + "dims": [16, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -0.8618, -0.0922, -0.9073, -0.7032, -0.5762, -0.2371, + -0.0259, -0.6099, -0.223, 1.0963, -1.5704, -0.4595, 0.8465, 0.3737, 0.0242, -0.0703, 1.1279, 0.8862, + 1.7878, -0.516, 0.1192, -2.1572, 0.046, 1.1202, 1.8441, 1.7698, -0.762, 0.2168, 0.1322, -0.2802, -1.4409, + -0.3413, -0.2409, -0.3188, 1.1054, 0.4265, -0.3315, -0.3997, -0.9304, -1.4268, -1.1526, -0.1132, -1.019, + 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, 0.6923, 1.1571, 0.7572, -1.1471, -0.5302, -0.4391, 0.9507, + 0.6696, -0.7721, -1.7415, 1.2087, -0.6387, 1.2275, -0.1786, -0.8767, -1.8072, -0.263, 0.9387, -1.4812, + -0.9082, 0.1728, -1.5132, -0.4489, 0.337, 0.146, 2.1002, 0.8437, -0.1534, 0.4321, 0.836, 0.5867, -1.3279, + 0.3201, 0.0125, 1.8157, 1.0745, 0.149, 1.3967, -1.4634, -0.1412, -0.6339, -1.5995, 0.1036, -0.3514, + 0.2421, 0.6463, 0.873, -0.9276, 0.5516, 1.0461, -0.4812, -0.1443, -0.4862, -0.6423, -1.1052, -0.5243, + -0.04, -0.4671, 0.4909, -0.1931, -0.8021, 0.7813, 0.5001, -1.4202, -0.385, 0.9263, -0.1541, -0.9266, + 0.2416, 0.927, -1.1146, 1.8758, 0.5955, -1.5452, -0.0491, -0.8794, 0.2418, -1.4203, 0.7372, -0.2429, 0.71, + -0.4299, -0.2304, 0.1645, -0.1366, 0.7604, 0.1514, 0.0824, -1.183, -1.6572, 1.0311, -1.9557, -0.1482, + 1.7376, 2.2039, -0.6589, 0.674, -0.4614, 0.5475, 1.1495, 0.2389, 0.8582, -0.1937, -0.0447, -0.3171, + 2.6839, -0.0076, 1.5185, -0.0443, -0.2323, 0.548, 1.5696, 0.6193, -1.1346, -0.4312, 1.3714, 1.2106, + -0.4272, -0.8529, 1.0328, 0.3635, 0.2362, 0.3672, -0.1128, -0.8664, -0.6354, 0.9489, -0.1816, -0.5968, + 1.0394, 0.0204, 1.1786, 2.0099, -0.9108, -0.2256, 0.4527, -1.8254, 0.6475, 0.8964, 0.5717, -0.239, 0.6983, + -1.3416, 0.2715, -1.4979, -1.1358, 1.632, 0.2493, 0.8266, 0.3424, -1.8059, -0.5762, -1.4362, -0.2706, + -1.0183, -0.462, 0.2323, 0.845, -1.2244, -0.4511, 0.6266, 0.9095, 0.1435, 0.7029, -0.3467, 0.5092, + -0.0828, 0.6253, -0.7899, -1.0957, -0.7149, -0.1072, -0.1967, -2.3416, -1.0686, -0.1313, -0.0181, 0.2438, + -0.8801, 0.1413, 0.1948, -0.1723, 1.2705, 1.0303, 1.2202, 1.3762, -0.2852, 0.6051, 0.2167, -0.2181, + -1.6306, 1.4788, -0.4992, 0.2964, 0.7298, 1.8544, 0.3516, 0.0454, 2.0891, 0.1782, 1.1591, -0.8151, 1.3, + -1.2464, -1.7981, 1.5241, -0.4121, 0.2341, -0.4737, -1.3333, 0.7113, -1.2138, 1.5964, -0.8346, -1.1515, + -0.7923, -1.2609, -1.6375, -0.3576, 0.9413, -0.5694, 0.3954, -0.3587, 0.8002, -0.5982, -1.4301, -0.662, + 0.7324, -0.2959, 0.7237, -1.2077, 0.7937, -0.6705, 0.9287, 0.2754, -0.0261, -0.4618, -0.5646, -1.0389, + 0.5819, 1.5415, -0.2822, -2.0774, 1.2323, 0.3963, -1.1503, -0.5099, 0.5098, -3.3525, 0.4326, 0.7414, + -0.7775, -1.615, 0.4164, 0.71, -0.2429, -0.5656, 0.0863, -0.8254, -3.0038, 2.4033, -0.3398, 0.0922, + 1.7053, 0.1383, -0.7477, -0.8689, 1.8286, 0.851, -1.4793, -0.725, 0.061, 0.9293, -0.6902, -0.0125, + -0.2089, 1.0583, 0.0496, -1.3118, 0.5556, 0.0459, -0.1324, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, + 0.1527, -0.4775, -1.9287, -1.6164, 0.3998, 0.902, -0.0764, -0.4271, -0.3807, 1.3245, 2.4936, 0.3139, + 1.0095, 0.0352, -0.7227, -1.3613, -0.0988, -1.9114, -0.3009, 1.1114, 0.7462, 2.366, -0.8409, -0.6654, + -0.653, -0.1597, 0.8541, 0.238, 1.4392, -0.5644, 0.3158, -0.1664, 0.5428, 0.4245, -0.7901, 0.5665, 0.9044, + -0.5513, -0.7409, -1.8002, 0.9892, 0.3619, -1.4522 + ], + "dims": [2, 4, 8, 6], + "type": "float32" + } + ] + }, + { + "name": "T[1,2,18] T[1,2] T[4,3] T[4,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, + -0.8663, -0.2656, 0.1665, 0.7911, -0.932, -0.8579, -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, + -0.848, 0.5266, -1.2944, -0.0243, -0.2354, -0.7087, -0.9647, -0.0991, -0.2994, -0.065, -1.572, -1.3211 + ], + "dims": [1, 2, 18], + "type": "float32" + }, + { + "data": [0, 1], + "dims": [1, 2], + "type": "int64" + }, + { + "data": [1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0], + "dims": [4, 3], + "type": "float32" + }, + { + "data": [0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065], + "dims": [4, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, + -0.8663, -0.2656, 0.1665, 0.7911, -0.932, -0.8579, -0.8618, -0.0922, -0.9073, -0.7032, -0.5762, -0.2371, + -0.4377, 0.537, -1.2929, -0.7267, -0.2107, -0.7115, -0.4666, -0.0261, -0.2965, -0.8469, -1.5749, -1.3217 + ], + "dims": [1, 2, 18], + "type": "float32" + } + ] + }, + { + "name": "T[1,3,2,6] T[1,2] T[4,3] T[4,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -0.848, 0.5266, -1.2944, -0.0243, -0.2354, -0.7087, + -0.8663, -0.2656, 0.1665, 0.7911, -0.932, -0.8579, -0.9647, -0.0991, -0.2994, -0.065, -1.572, -1.3211 + ], + "dims": [1, 3, 2, 6], + "type": "float32" + }, + { + "data": [0, 1], + "dims": [1, 2], + "type": "int64" + }, + { + "data": [1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0], + "dims": [4, 3], + "type": "float32" + }, + { + "data": [0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065], + "dims": [4, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -0.8618, -0.0922, -0.9073, -0.7032, -0.5762, -0.2371, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -0.4377, 0.537, -1.2929, -0.7267, -0.2107, -0.7115, + -0.8663, -0.2656, 0.1665, 0.7911, -0.932, -0.8579, -0.4666, -0.0261, -0.2965, -0.8469, -1.5749, -1.3217 + ], + "dims": [1, 3, 2, 6], + "type": "float32" + } + ] + } + ] + }, + { + "name": "RotaryEmbedding with interleaved pattern", + "operator": "RotaryEmbedding", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "interleaved", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[1,3,8] T[1] T[8,2] T[8,2]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -0.132, -0.2751, -0.235, 0.0937, -1.2188, 1.1676, -1.0574, -0.1188, + -0.7396, -1.2425, -0.1752, 0.699, -0.811, 0.6737, -1.1233, -0.0919, -0.6861, 0.7202, 0.1963, 0.6142 + ], + "dims": [1, 3, 8], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 0.5403, 0.9999, -0.4161, 0.9998, -0.99, 0.9996, -0.6536, 0.9992, 0.2837, 0.9988, 0.9602, 0.9982, + 0.7539, 0.9976 + ], + "dims": [8, 2], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.8415, 0.01, 0.9093, 0.02, 0.1411, 0.03, -0.7568, 0.04, -0.9589, 0.05, -0.2794, 0.06, 0.657, + 0.0699 + ], + "dims": [8, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -0.132, -0.2751, -0.235, 0.0937, -1.6411, -0.3948, -1.0561, -0.1294, + 0.646, -1.2937, -0.1822, 0.6972, -0.2751, -1.0178, -1.1212, -0.1143, -0.3694, -0.9235, 0.184, 0.618 + ], + "dims": [1, 3, 8], + "type": "float32" + } + ] + }, + { + "name": "T[1,3,8] Scalar T[8,2] T[8,2]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -0.132, -0.2751, -0.235, 0.0937, -1.2188, 1.1676, -1.0574, -0.1188, + -0.7396, -1.2425, -0.1752, 0.699, -0.811, 0.6737, -1.1233, -0.0919, -0.6861, 0.7202, 0.1963, 0.6142 + ], + "dims": [1, 3, 8], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 0.5403, 0.9999, -0.4161, 0.9998, -0.99, 0.9996, -0.6536, 0.9992, 0.2837, 0.9988, 0.9602, 0.9982, + 0.7539, 0.9976 + ], + "dims": [8, 2], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.8415, 0.01, 0.9093, 0.02, 0.1411, 0.03, -0.7568, 0.04, -0.9589, 0.05, -0.2794, 0.06, 0.657, + 0.0699 + ], + "dims": [8, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -0.132, -0.2751, -0.235, 0.0937, -1.6411, -0.3948, -1.0561, -0.1294, + 0.646, -1.2937, -0.1822, 0.6972, -0.2751, -1.0178, -1.1212, -0.1143, -0.3694, -0.9235, 0.184, 0.618 + ], + "dims": [1, 3, 8], + "type": "float32" + } + ] + }, + { + "name": "T[1,2,3,4] T[1] T[8,2] T[8,2]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.0574, -0.1188, -0.811, 0.6737, -1.1233, -0.0919, + -0.132, -0.2751, -0.235, 0.0937, -0.7396, -1.2425, -0.1752, 0.699, -0.6861, 0.7202, 0.1963, 0.6142 + ], + "dims": [1, 2, 3, 4], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 0.5403, 0.9999, -0.4161, 0.9998, -0.99, 0.9996, -0.6536, 0.9992, 0.2837, 0.9988, 0.9602, 0.9982, + 0.7539, 0.9976 + ], + "dims": [8, 2], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.8415, 0.01, 0.9093, 0.02, 0.1411, 0.03, -0.7568, 0.04, -0.9589, 0.05, -0.2794, 0.06, 0.657, + 0.0699 + ], + "dims": [8, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.6411, -0.3948, -1.0561, -0.1294, -0.2751, -1.0178, -1.1212, -0.1143, + -0.132, -0.2751, -0.235, 0.0937, 0.646, -1.2937, -0.1822, 0.6972, -0.3694, -0.9235, 0.184, 0.618 + ], + "dims": [1, 2, 3, 4], + "type": "float32" + } + ] + }, + { + "name": "T[2,8,24] T[1] T[16,3] T[16,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, -0.5912, 1.1312, 0.7562, -1.2023, -0.5833, -0.4407, + 0.1766, 1.0224, -0.4826, -0.5421, -0.5342, -0.6413, 1.3314, -0.4498, 0.5493, 0.0539, 0.2601, 0.857, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -1.9791, 0.7787, -0.7749, -0.1398, 1.1414, -0.6354, + 0.0352, -0.4765, -0.0409, 1.1993, 0.5374, -0.193, 2.5211, -0.0452, -0.3105, -0.9407, -0.0034, 1.5199, + -0.848, 0.5266, 0.0299, -0.0498, 1.0651, 0.886, -1.4702, -0.2134, -0.8707, 1.6159, -0.2356, 0.9444, + 0.5937, 0.7203, 0.5061, 1.5192, -0.4897, 0.9231, 0.2654, -0.1441, 0.5407, -1.5476, 0.6455, -1.1382, 0.464, + -0.4986, 0.1289, 2.7631, 0.1405, 1.1191, 2.1134, -0.9754, 0.1757, -0.1319, -0.2735, 0.3355, -0.6008, + -1.1164, 0.2577, -0.7226, -0.9244, 1.8737, 0.6052, 1.1904, 1.2195, -0.047, -1.0914, 1.0223, 0.3152, + 1.7528, -0.765, 1.8299, -0.2784, -0.2719, 0.1885, 2.1432, 0.8527, 0.0965, -0.0625, 0.8269, 1.0122, + -1.4482, -0.0644, 0.3215, 0.5908, -1.4197, 0.2113, 0.0306, 0.3604, 0.3166, -0.8975, -0.6393, -1.2944, + -0.0243, -0.2354, -0.7087, 1.1566, 0.4296, 0.5599, -0.7776, 0.3339, 0.1759, 2.1108, 1.0702, 0.8279, + -0.2969, 0.712, -0.2068, -0.1548, 0.1553, 0.6207, -0.169, -0.5816, 1.2632, 0.0695, 1.1862, -1.1874, + -0.7468, -0.932, -0.8579, -0.9647, -0.0991, 0.0195, 1.1213, -1.4873, -0.2043, -1.0466, -1.5772, -0.0489, + 0.343, 0.1264, 0.1519, -1.3639, -1.6593, 1.8127, -1.4459, -0.2158, -0.9792, -1.4392, 0.6508, 0.8964, + 0.5717, -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, + -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, -0.5996, + -1.0962, 1.6327, 1.3951, 0.8784, 0.3389, 1.2907, 0.3124, 0.7299, 1.422, 0.3375, 0.0438, 1.8698, -0.2635, + -2.0799, -0.6313, 0.409, -1.1458, 0.0784, -1.8848, -1.6165, 0.6179, 0.9905, -0.0729, 0.5054, -0.6681, + -1.4382, 1.7547, -0.9605, -0.4558, -1.6105, 0.2979, 1.1537, -1.5604, 1.2779, -1.2514, 0.6056, 0.5763, + -3.3558, 0.2836, 0.6909, -0.7631, 2.4451, -0.35, 1.3289, -0.6494, 0.3478, 1.0038, -0.2937, 0.9238, + -1.2185, 0.4138, 0.5033, 0.9174, 1.8131, 1.4436, -0.4207, 0.022, -0.6807, -1.3306, 1.5646, 0.3338, 0.7105, + 0.4683, -0.6179, 0.0818, -0.0488, -0.981, -1.3632, 0.0929, -1.7926, -0.2921, -0.4792, 0.6756, -0.3413, + -0.2242, -0.2111, 0.6282, 0.1667, -1.4055, 1.5895, 1.0838, -0.9077, -0.806, 0.7967, -2.9351, 2.4179, + -0.4026, 0.6451, 1.6845, -0.0901, 0.6106, 2.3603, 1.3908, -0.7917, -0.6734, -0.1213, -1.1116, -0.7401, + -0.7879, 0.0606, -2.3337, -1.2603, -1.7245, -0.3533, -0.9421, -0.1776, 0.3992, -1.7142, -0.5319, -0.8848, + 0.6513, 1.0002, -1.4699, -1.4254, 0.7013, 0.2414, 0.2551, -0.7457, 0.3133, -1.0941, -0.3682, -0.0163, + -0.0645, -0.8101, 0.1415, 0.0551, 0.5873, -0.5887, -1.4733, -0.8565, 0.74, -0.5033, 0.0553, 0.9265, + -0.8652, -0.0288, -0.2209, 0.061, 0.6776, 0.4361, -0.8052, 0.3955, 0.8988, 0.8238, 0.2262, 1.2912, 0.6488, + 1.2114, 1.3569, 0.2983, 0.4718, -1.1936, 0.7928, -0.8665, 0.9468, 1.1629, 0.0616, -1.3136, -0.2764, + 0.0277, -0.1126, 0.2342, -0.5866, -1.8219, 1.1079, 0.5795, -1.4249 + ], + "dims": [2, 8, 24], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0, -0.6536, 0.9828, 1.0, + 0.2837, 0.9732, 0.9999, 0.9602, 0.9615, 0.9999, 0.7539, 0.9477, 0.9999, -0.1455, 0.9318, 0.9999, -0.9111, + 0.914, 0.9998, -0.8391, 0.8942, 0.9998, 0.0044, 0.8725, 0.9997, 0.8439, 0.8488, 0.9997, 0.9074, 0.8234, + 0.9996, 0.1367, 0.7962, 0.9995, -0.7597, 0.7673, 0.9995 + ], + "dims": [16, 3], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065, -0.7568, 0.1846, + 0.0086, -0.9589, 0.23, 0.0108, -0.2794, 0.2749, 0.0129, 0.657, 0.3192, 0.0151, 0.9894, 0.3629, 0.0172, + 0.4121, 0.4057, 0.0194, -0.544, 0.4477, 0.0215, -1.0, 0.4887, 0.0237, -0.5366, 0.5286, 0.0259, 0.4202, + 0.5675, 0.028, 0.9906, 0.605, 0.0302, 0.6503, 0.6413, 0.0323 + ], + "dims": [16, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -0.4713, -0.954, -0.9229, 0.3027, -0.5708, -0.2363, -1.2713, 0.1137, 0.8112, -1.1659, -0.5824, -0.4419, + -0.7649, 0.7011, -0.4569, -0.5639, -0.5328, -0.6424, 1.0979, 0.8773, 0.5462, 0.0793, 0.2582, 0.8576, + 0.2653, 1.2295, -0.1839, -0.4517, -1.5052, -0.4651, 0.1155, -2.1237, -0.7586, -0.211, 1.1441, -0.6304, + 0.4186, 0.2303, -0.1519, 1.1903, 0.5382, -0.1906, -1.008, 2.3112, -0.222, -0.9655, -0.0099, 1.5198, + 0.7652, -0.641, 0.0365, -0.0452, 1.0593, 0.8929, 1.4856, 0.0038, -1.0865, 1.4794, -0.2417, 0.9428, + -0.6894, -0.6293, 0.2904, 1.5747, -0.4956, 0.9199, -0.2424, 0.1801, 0.7503, -1.4576, 0.6529, -1.134, + -0.6807, -0.0252, -0.3834, 2.7394, 0.1308, 1.1203, -2.1196, -0.9618, 0.197, -0.0972, -0.2764, 0.3332, + -0.4522, 1.1844, 0.3867, -0.6626, -0.9405, 1.8656, 0.5053, -1.2361, 1.2072, 0.1789, -1.1002, 1.0129, + 1.7702, 0.1949, -1.1653, 1.6049, -0.2755, -0.2749, 2.1087, 0.4272, 0.8076, 0.29, -0.0714, 0.8261, -1.1016, + -1.3814, -0.1366, 0.2981, 0.606, -1.4132, 0.0893, -0.1939, 0.2779, 0.391, -0.8906, -0.6489, -1.2496, + 0.3383, -0.0315, -0.7461, 1.151, 0.4445, 0.3203, -0.9031, 0.2727, 0.2609, 2.0968, 1.0974, 0.712, -0.5164, + 0.7415, -0.0031, -0.1568, 0.1533, 0.5487, -0.3357, -0.9064, 1.0546, 0.0542, 1.187, -0.4045, -1.3431, + -0.6094, -1.1105, -0.9631, -0.1137, -0.7219, 0.8582, -1.3443, -0.6684, -1.0227, -1.5929, -0.2622, 0.2264, + 0.0713, 0.1843, -1.3387, -1.6797, 2.3165, 0.1009, 0.1081, -0.9969, -1.4488, 0.6291, 0.8964, 0.5717, + -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, -0.0261, + -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, 0.5985, -1.0968, + 1.5662, 1.4693, 0.8776, 0.3408, 0.4345, 1.2549, 0.6631, 1.4543, 0.3374, 0.0445, 1.232, 1.4311, -2.0483, + -0.7272, 0.4114, -1.1449, 1.6283, -0.9524, -1.6435, 0.5422, 0.9907, -0.0708, 0.3972, 0.7376, -1.5947, + 1.6138, -0.9586, -0.46, 0.3993, -1.5884, 1.2934, -1.4467, 1.2833, -1.2459, -0.776, 0.3108, -3.3677, + -0.0287, 0.6942, -0.7601, -0.6993, 2.369, 1.3834, -0.5234, 0.3435, 1.0053, 0.1604, -0.956, -1.2641, + 0.2406, 0.4973, 0.9206, -1.9987, -1.1733, -0.4197, -0.0366, -0.672, -1.335, -1.596, -0.1097, 0.6386, + 0.5624, -0.6184, 0.0778, 0.1867, 0.9643, -1.3629, -0.0972, -1.7907, -0.3037, 0.8245, -0.0789, -0.294, + -0.2833, -0.2165, 0.6264, -1.1726, 0.7926, 1.3621, 1.3586, -0.9007, -0.8138, -2.7421, 1.3155, 2.4507, + 0.0507, 0.6305, 1.69, 0.521, -0.3309, 2.063, 1.8026, -0.7859, -0.6802, -1.1003, -0.199, -0.5391, -0.937, + 0.0857, -2.333, -2.0112, 0.7193, -0.1272, -0.9981, -0.1818, 0.3973, -0.9963, 1.4929, -1.0109, 0.4304, + 1.016, -1.459, 0.2682, 1.5658, 0.1762, 0.3038, -0.7491, 0.3052, -1.1534, -0.0478, 0.0021, -0.0665, + -0.8118, 0.131, 0.2171, 0.5485, -0.161, -1.5784, -0.866, 0.7289, -0.4678, 0.1937, 1.1287, -0.5772, + -0.0259, -0.2212, 0.2479, 0.6336, 0.6407, -0.6543, 0.3838, 0.9039, 0.4724, 0.7117, 1.0165, 1.027, 1.1908, + 1.375, -0.085, 0.5517, -1.3842, 0.3703, -0.8806, 0.9336, 0.8362, 0.8105, -1.1566, -0.6813, 0.0294, + -0.1122, 0.562, -0.2884, -2.0803, 0.4684, 0.6009, -1.416 + ], + "dims": [2, 8, 24], + "type": "float32" + } + ] + }, + { + "name": "T[2,8,24] Scalar T[16,3] T[16,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, -0.5912, 1.1312, 0.7562, -1.2023, -0.5833, -0.4407, + 0.1766, 1.0224, -0.4826, -0.5421, -0.5342, -0.6413, 1.3314, -0.4498, 0.5493, 0.0539, 0.2601, 0.857, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -1.9791, 0.7787, -0.7749, -0.1398, 1.1414, -0.6354, + 0.0352, -0.4765, -0.0409, 1.1993, 0.5374, -0.193, 2.5211, -0.0452, -0.3105, -0.9407, -0.0034, 1.5199, + -0.848, 0.5266, 0.0299, -0.0498, 1.0651, 0.886, -1.4702, -0.2134, -0.8707, 1.6159, -0.2356, 0.9444, + 0.5937, 0.7203, 0.5061, 1.5192, -0.4897, 0.9231, 0.2654, -0.1441, 0.5407, -1.5476, 0.6455, -1.1382, 0.464, + -0.4986, 0.1289, 2.7631, 0.1405, 1.1191, 2.1134, -0.9754, 0.1757, -0.1319, -0.2735, 0.3355, -0.6008, + -1.1164, 0.2577, -0.7226, -0.9244, 1.8737, 0.6052, 1.1904, 1.2195, -0.047, -1.0914, 1.0223, 0.3152, + 1.7528, -0.765, 1.8299, -0.2784, -0.2719, 0.1885, 2.1432, 0.8527, 0.0965, -0.0625, 0.8269, 1.0122, + -1.4482, -0.0644, 0.3215, 0.5908, -1.4197, 0.2113, 0.0306, 0.3604, 0.3166, -0.8975, -0.6393, -1.2944, + -0.0243, -0.2354, -0.7087, 1.1566, 0.4296, 0.5599, -0.7776, 0.3339, 0.1759, 2.1108, 1.0702, 0.8279, + -0.2969, 0.712, -0.2068, -0.1548, 0.1553, 0.6207, -0.169, -0.5816, 1.2632, 0.0695, 1.1862, -1.1874, + -0.7468, -0.932, -0.8579, -0.9647, -0.0991, 0.0195, 1.1213, -1.4873, -0.2043, -1.0466, -1.5772, -0.0489, + 0.343, 0.1264, 0.1519, -1.3639, -1.6593, 1.8127, -1.4459, -0.2158, -0.9792, -1.4392, 0.6508, 0.8964, + 0.5717, -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, + -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, -0.5996, + -1.0962, 1.6327, 1.3951, 0.8784, 0.3389, 1.2907, 0.3124, 0.7299, 1.422, 0.3375, 0.0438, 1.8698, -0.2635, + -2.0799, -0.6313, 0.409, -1.1458, 0.0784, -1.8848, -1.6165, 0.6179, 0.9905, -0.0729, 0.5054, -0.6681, + -1.4382, 1.7547, -0.9605, -0.4558, -1.6105, 0.2979, 1.1537, -1.5604, 1.2779, -1.2514, 0.6056, 0.5763, + -3.3558, 0.2836, 0.6909, -0.7631, 2.4451, -0.35, 1.3289, -0.6494, 0.3478, 1.0038, -0.2937, 0.9238, + -1.2185, 0.4138, 0.5033, 0.9174, 1.8131, 1.4436, -0.4207, 0.022, -0.6807, -1.3306, 1.5646, 0.3338, 0.7105, + 0.4683, -0.6179, 0.0818, -0.0488, -0.981, -1.3632, 0.0929, -1.7926, -0.2921, -0.4792, 0.6756, -0.3413, + -0.2242, -0.2111, 0.6282, 0.1667, -1.4055, 1.5895, 1.0838, -0.9077, -0.806, 0.7967, -2.9351, 2.4179, + -0.4026, 0.6451, 1.6845, -0.0901, 0.6106, 2.3603, 1.3908, -0.7917, -0.6734, -0.1213, -1.1116, -0.7401, + -0.7879, 0.0606, -2.3337, -1.2603, -1.7245, -0.3533, -0.9421, -0.1776, 0.3992, -1.7142, -0.5319, -0.8848, + 0.6513, 1.0002, -1.4699, -1.4254, 0.7013, 0.2414, 0.2551, -0.7457, 0.3133, -1.0941, -0.3682, -0.0163, + -0.0645, -0.8101, 0.1415, 0.0551, 0.5873, -0.5887, -1.4733, -0.8565, 0.74, -0.5033, 0.0553, 0.9265, + -0.8652, -0.0288, -0.2209, 0.061, 0.6776, 0.4361, -0.8052, 0.3955, 0.8988, 0.8238, 0.2262, 1.2912, 0.6488, + 1.2114, 1.3569, 0.2983, 0.4718, -1.1936, 0.7928, -0.8665, 0.9468, 1.1629, 0.0616, -1.3136, -0.2764, + 0.0277, -0.1126, 0.2342, -0.5866, -1.8219, 1.1079, 0.5795, -1.4249 + ], + "dims": [2, 8, 24], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0, -0.6536, 0.9828, 1.0, + 0.2837, 0.9732, 0.9999, 0.9602, 0.9615, 0.9999, 0.7539, 0.9477, 0.9999, -0.1455, 0.9318, 0.9999, -0.9111, + 0.914, 0.9998, -0.8391, 0.8942, 0.9998, 0.0044, 0.8725, 0.9997, 0.8439, 0.8488, 0.9997, 0.9074, 0.8234, + 0.9996, 0.1367, 0.7962, 0.9995, -0.7597, 0.7673, 0.9995 + ], + "dims": [16, 3], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065, -0.7568, 0.1846, + 0.0086, -0.9589, 0.23, 0.0108, -0.2794, 0.2749, 0.0129, 0.657, 0.3192, 0.0151, 0.9894, 0.3629, 0.0172, + 0.4121, 0.4057, 0.0194, -0.544, 0.4477, 0.0215, -1.0, 0.4887, 0.0237, -0.5366, 0.5286, 0.0259, 0.4202, + 0.5675, 0.028, 0.9906, 0.605, 0.0302, 0.6503, 0.6413, 0.0323 + ], + "dims": [16, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -0.4713, -0.954, -0.9229, 0.3027, -0.5708, -0.2363, -1.2713, 0.1137, 0.8112, -1.1659, -0.5824, -0.4419, + -0.7649, 0.7011, -0.4569, -0.5639, -0.5328, -0.6424, 1.0979, 0.8773, 0.5462, 0.0793, 0.2582, 0.8576, + 0.2653, 1.2295, -0.1839, -0.4517, -1.5052, -0.4651, 0.1155, -2.1237, -0.7586, -0.211, 1.1441, -0.6304, + 0.4186, 0.2303, -0.1519, 1.1903, 0.5382, -0.1906, -1.008, 2.3112, -0.222, -0.9655, -0.0099, 1.5198, + 0.7652, -0.641, 0.0365, -0.0452, 1.0593, 0.8929, 1.4856, 0.0038, -1.0865, 1.4794, -0.2417, 0.9428, + -0.6894, -0.6293, 0.2904, 1.5747, -0.4956, 0.9199, -0.2424, 0.1801, 0.7503, -1.4576, 0.6529, -1.134, + -0.6807, -0.0252, -0.3834, 2.7394, 0.1308, 1.1203, -2.1196, -0.9618, 0.197, -0.0972, -0.2764, 0.3332, + -0.4522, 1.1844, 0.3867, -0.6626, -0.9405, 1.8656, 0.5053, -1.2361, 1.2072, 0.1789, -1.1002, 1.0129, + 1.7702, 0.1949, -1.1653, 1.6049, -0.2755, -0.2749, 2.1087, 0.4272, 0.8076, 0.29, -0.0714, 0.8261, -1.1016, + -1.3814, -0.1366, 0.2981, 0.606, -1.4132, 0.0893, -0.1939, 0.2779, 0.391, -0.8906, -0.6489, -1.2496, + 0.3383, -0.0315, -0.7461, 1.151, 0.4445, 0.3203, -0.9031, 0.2727, 0.2609, 2.0968, 1.0974, 0.712, -0.5164, + 0.7415, -0.0031, -0.1568, 0.1533, 0.5487, -0.3357, -0.9064, 1.0546, 0.0542, 1.187, -0.4045, -1.3431, + -0.6094, -1.1105, -0.9631, -0.1137, -0.7219, 0.8582, -1.3443, -0.6684, -1.0227, -1.5929, -0.2622, 0.2264, + 0.0713, 0.1843, -1.3387, -1.6797, 2.3165, 0.1009, 0.1081, -0.9969, -1.4488, 0.6291, 0.8964, 0.5717, + -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, -0.0261, + -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, 0.5985, -1.0968, + 1.5662, 1.4693, 0.8776, 0.3408, 0.4345, 1.2549, 0.6631, 1.4543, 0.3374, 0.0445, 1.232, 1.4311, -2.0483, + -0.7272, 0.4114, -1.1449, 1.6283, -0.9524, -1.6435, 0.5422, 0.9907, -0.0708, 0.3972, 0.7376, -1.5947, + 1.6138, -0.9586, -0.46, 0.3993, -1.5884, 1.2934, -1.4467, 1.2833, -1.2459, -0.776, 0.3108, -3.3677, + -0.0287, 0.6942, -0.7601, -0.6993, 2.369, 1.3834, -0.5234, 0.3435, 1.0053, 0.1604, -0.956, -1.2641, + 0.2406, 0.4973, 0.9206, -1.9987, -1.1733, -0.4197, -0.0366, -0.672, -1.335, -1.596, -0.1097, 0.6386, + 0.5624, -0.6184, 0.0778, 0.1867, 0.9643, -1.3629, -0.0972, -1.7907, -0.3037, 0.8245, -0.0789, -0.294, + -0.2833, -0.2165, 0.6264, -1.1726, 0.7926, 1.3621, 1.3586, -0.9007, -0.8138, -2.7421, 1.3155, 2.4507, + 0.0507, 0.6305, 1.69, 0.521, -0.3309, 2.063, 1.8026, -0.7859, -0.6802, -1.1003, -0.199, -0.5391, -0.937, + 0.0857, -2.333, -2.0112, 0.7193, -0.1272, -0.9981, -0.1818, 0.3973, -0.9963, 1.4929, -1.0109, 0.4304, + 1.016, -1.459, 0.2682, 1.5658, 0.1762, 0.3038, -0.7491, 0.3052, -1.1534, -0.0478, 0.0021, -0.0665, + -0.8118, 0.131, 0.2171, 0.5485, -0.161, -1.5784, -0.866, 0.7289, -0.4678, 0.1937, 1.1287, -0.5772, + -0.0259, -0.2212, 0.2479, 0.6336, 0.6407, -0.6543, 0.3838, 0.9039, 0.4724, 0.7117, 1.0165, 1.027, 1.1908, + 1.375, -0.085, 0.5517, -1.3842, 0.3703, -0.8806, 0.9336, 0.8362, 0.8105, -1.1566, -0.6813, 0.0294, + -0.1122, 0.562, -0.2884, -2.0803, 0.4684, 0.6009, -1.416 + ], + "dims": [2, 8, 24], + "type": "float32" + } + ] + }, + { + "name": "T[2,4,8,6] T[1] T[16,3] T[16,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -0.848, 0.5266, 0.0299, -0.0498, 1.0651, 0.886, 0.464, + -0.4986, 0.1289, 2.7631, 0.1405, 1.1191, 0.3152, 1.7528, -0.765, 1.8299, -0.2784, -0.2719, -1.2944, + -0.0243, -0.2354, -0.7087, 1.1566, 0.4296, -1.1874, -0.7468, -0.932, -0.8579, -0.9647, -0.0991, -1.019, + 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, -0.5912, 1.1312, 0.7562, -1.2023, -0.5833, -0.4407, -1.9791, + 0.7787, -0.7749, -0.1398, 1.1414, -0.6354, -1.4702, -0.2134, -0.8707, 1.6159, -0.2356, 0.9444, 2.1134, + -0.9754, 0.1757, -0.1319, -0.2735, 0.3355, 0.1885, 2.1432, 0.8527, 0.0965, -0.0625, 0.8269, 0.5599, + -0.7776, 0.3339, 0.1759, 2.1108, 1.0702, 0.0195, 1.1213, -1.4873, -0.2043, -1.0466, -1.5772, 0.1036, + -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 0.1766, 1.0224, -0.4826, -0.5421, -0.5342, -0.6413, 0.0352, + -0.4765, -0.0409, 1.1993, 0.5374, -0.193, 0.5937, 0.7203, 0.5061, 1.5192, -0.4897, 0.9231, -0.6008, + -1.1164, 0.2577, -0.7226, -0.9244, 1.8737, 1.0122, -1.4482, -0.0644, 0.3215, 0.5908, -1.4197, 0.8279, + -0.2969, 0.712, -0.2068, -0.1548, 0.1553, -0.0489, 0.343, 0.1264, 0.1519, -1.3639, -1.6593, 1.0311, + -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, 1.3314, -0.4498, 0.5493, 0.0539, 0.2601, 0.857, 2.5211, + -0.0452, -0.3105, -0.9407, -0.0034, 1.5199, 0.2654, -0.1441, 0.5407, -1.5476, 0.6455, -1.1382, 0.6052, + 1.1904, 1.2195, -0.047, -1.0914, 1.0223, 0.2113, 0.0306, 0.3604, 0.3166, -0.8975, -0.6393, 0.6207, -0.169, + -0.5816, 1.2632, 0.0695, 1.1862, 1.8127, -1.4459, -0.2158, -0.9792, -1.4392, 0.6508, 0.8964, 0.5717, + -0.239, 0.6983, -1.3416, 0.2715, -0.5996, -1.0962, 1.6327, 1.3951, 0.8784, 0.3389, 0.5054, -0.6681, + -1.4382, 1.7547, -0.9605, -0.4558, -0.2937, 0.9238, -1.2185, 0.4138, 0.5033, 0.9174, -0.4792, 0.6756, + -0.3413, -0.2242, -0.2111, 0.6282, -0.1213, -1.1116, -0.7401, -0.7879, 0.0606, -2.3337, -1.0941, -0.3682, + -0.0163, -0.0645, -0.8101, 0.1415, 0.8238, 0.2262, 1.2912, 0.6488, 1.2114, 1.3569, -0.2852, 0.6051, + 0.2167, -0.2181, -1.6306, 1.4788, 1.2907, 0.3124, 0.7299, 1.422, 0.3375, 0.0438, -1.6105, 0.2979, 1.1537, + -1.5604, 1.2779, -1.2514, 1.8131, 1.4436, -0.4207, 0.022, -0.6807, -1.3306, 0.1667, -1.4055, 1.5895, + 1.0838, -0.9077, -0.806, -1.2603, -1.7245, -0.3533, -0.9421, -0.1776, 0.3992, 0.0551, 0.5873, -0.5887, + -1.4733, -0.8565, 0.74, 0.2983, 0.4718, -1.1936, 0.7928, -0.8665, 0.9468, 0.2754, -0.0261, -0.4618, + -0.5646, -1.0389, 0.5819, 1.8698, -0.2635, -2.0799, -0.6313, 0.409, -1.1458, 0.6056, 0.5763, -3.3558, + 0.2836, 0.6909, -0.7631, 1.5646, 0.3338, 0.7105, 0.4683, -0.6179, 0.0818, 0.7967, -2.9351, 2.4179, + -0.4026, 0.6451, 1.6845, -1.7142, -0.5319, -0.8848, 0.6513, 1.0002, -1.4699, -0.5033, 0.0553, 0.9265, + -0.8652, -0.0288, -0.2209, 1.1629, 0.0616, -1.3136, -0.2764, 0.0277, -0.1126, 1.3697, 0.0002, 1.5333, + -1.0556, -0.1254, 0.1527, 0.0784, -1.8848, -1.6165, 0.6179, 0.9905, -0.0729, 2.4451, -0.35, 1.3289, + -0.6494, 0.3478, 1.0038, -0.0488, -0.981, -1.3632, 0.0929, -1.7926, -0.2921, -0.0901, 0.6106, 2.3603, + 1.3908, -0.7917, -0.6734, -1.4254, 0.7013, 0.2414, 0.2551, -0.7457, 0.3133, 0.061, 0.6776, 0.4361, + -0.8052, 0.3955, 0.8988, 0.2342, -0.5866, -1.8219, 1.1079, 0.5795, -1.4249 + ], + "dims": [2, 4, 8, 6], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0, -0.6536, 0.9828, 1.0, + 0.2837, 0.9732, 0.9999, 0.9602, 0.9615, 0.9999, 0.7539, 0.9477, 0.9999, -0.1455, 0.9318, 0.9999, -0.9111, + 0.914, 0.9998, -0.8391, 0.8942, 0.9998, 0.0044, 0.8725, 0.9997, 0.8439, 0.8488, 0.9997, 0.9074, 0.8234, + 0.9996, 0.1367, 0.7962, 0.9995, -0.7597, 0.7673, 0.9995 + ], + "dims": [16, 3], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065, -0.7568, 0.1846, + 0.0086, -0.9589, 0.23, 0.0108, -0.2794, 0.2749, 0.0129, 0.657, 0.3192, 0.0151, 0.9894, 0.3629, 0.0172, + 0.4121, 0.4057, 0.0194, -0.544, 0.4477, 0.0215, -1.0, 0.4887, 0.0237, -0.5366, 0.5286, 0.0259, 0.4202, + 0.5675, 0.028, 0.9906, 0.605, 0.0302, 0.6503, 0.6413, 0.0323 + ], + "dims": [16, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -0.4713, -0.954, -0.9229, 0.3027, -0.5708, -0.2363, + 0.2653, 1.2295, -0.1839, -0.4517, -1.5052, -0.4651, 0.7652, -0.641, 0.0365, -0.0452, 1.0593, 0.8929, + -0.6807, -0.0252, -0.3834, 2.7394, 0.1308, 1.1203, 1.7702, 0.1949, -1.1653, 1.6049, -0.2755, -0.2749, + -1.2496, 0.3383, -0.0315, -0.7461, 1.151, 0.4445, -0.4045, -1.3431, -0.6094, -1.1105, -0.9631, -0.1137, + -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, -1.2713, 0.1137, 0.8112, -1.1659, -0.5824, -0.4419, + 0.1155, -2.1237, -0.7586, -0.211, 1.1441, -0.6304, 1.4856, 0.0038, -1.0865, 1.4794, -0.2417, 0.9428, + -2.1196, -0.9618, 0.197, -0.0972, -0.2764, 0.3332, 2.1087, 0.4272, 0.8076, 0.29, -0.0714, 0.8261, 0.3203, + -0.9031, 0.2727, 0.2609, 2.0968, 1.0974, -0.7219, 0.8582, -1.3443, -0.6684, -1.0227, -1.5929, 0.1036, + -0.3514, 0.2421, 0.6463, 0.873, -0.9276, -0.7649, 0.7011, -0.4569, -0.5639, -0.5328, -0.6424, 0.4186, + 0.2303, -0.1519, 1.1903, 0.5382, -0.1906, -0.6894, -0.6293, 0.2904, 1.5747, -0.4956, 0.9199, -0.4522, + 1.1844, 0.3867, -0.6626, -0.9405, 1.8656, -1.1016, -1.3814, -0.1366, 0.2981, 0.606, -1.4132, 0.712, + -0.5164, 0.7415, -0.0031, -0.1568, 0.1533, -0.2622, 0.2264, 0.0713, 0.1843, -1.3387, -1.6797, 1.0311, + -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, 1.0979, 0.8773, 0.5462, 0.0793, 0.2582, 0.8576, -1.008, 2.3112, + -0.222, -0.9655, -0.0099, 1.5198, -0.2424, 0.1801, 0.7503, -1.4576, 0.6529, -1.134, 0.5053, -1.2361, + 1.2072, 0.1789, -1.1002, 1.0129, 0.0893, -0.1939, 0.2779, 0.391, -0.8906, -0.6489, 0.5487, -0.3357, + -0.9064, 1.0546, 0.0542, 1.187, 2.3165, 0.1009, 0.1081, -0.9969, -1.4488, 0.6291, 0.8964, 0.5717, -0.239, + 0.6983, -1.3416, 0.2715, 0.5985, -1.0968, 1.5662, 1.4693, 0.8776, 0.3408, 0.3972, 0.7376, -1.5947, 1.6138, + -0.9586, -0.46, 0.1604, -0.956, -1.2641, 0.2406, 0.4973, 0.9206, 0.8245, -0.0789, -0.294, -0.2833, + -0.2165, 0.6264, -1.1003, -0.199, -0.5391, -0.937, 0.0857, -2.333, -1.1534, -0.0478, 0.0021, -0.0665, + -0.8118, 0.131, 0.4724, 0.7117, 1.0165, 1.027, 1.1908, 1.375, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, + 1.4788, 0.4345, 1.2549, 0.6631, 1.4543, 0.3374, 0.0445, 0.3993, -1.5884, 1.2934, -1.4467, 1.2833, -1.2459, + -1.9987, -1.1733, -0.4197, -0.0366, -0.672, -1.335, -1.1726, 0.7926, 1.3621, 1.3586, -0.9007, -0.8138, + -2.0112, 0.7193, -0.1272, -0.9981, -0.1818, 0.3973, 0.2171, 0.5485, -0.161, -1.5784, -0.866, 0.7289, + -0.085, 0.5517, -1.3842, 0.3703, -0.8806, 0.9336, 0.2754, -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, + 1.232, 1.4311, -2.0483, -0.7272, 0.4114, -1.1449, -0.776, 0.3108, -3.3677, -0.0287, 0.6942, -0.7601, + -1.596, -0.1097, 0.6386, 0.5624, -0.6184, 0.0778, -2.7421, 1.3155, 2.4507, 0.0507, 0.6305, 1.69, -0.9963, + 1.4929, -1.0109, 0.4304, 1.016, -1.459, -0.4678, 0.1937, 1.1287, -0.5772, -0.0259, -0.2212, 0.8362, + 0.8105, -1.1566, -0.6813, 0.0294, -0.1122, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, 1.6283, + -0.9524, -1.6435, 0.5422, 0.9907, -0.0708, -0.6993, 2.369, 1.3834, -0.5234, 0.3435, 1.0053, 0.1867, + 0.9643, -1.3629, -0.0972, -1.7907, -0.3037, 0.521, -0.3309, 2.063, 1.8026, -0.7859, -0.6802, 0.2682, + 1.5658, 0.1762, 0.3038, -0.7491, 0.3052, 0.2479, 0.6336, 0.6407, -0.6543, 0.3838, 0.9039, 0.562, -0.2884, + -2.0803, 0.4684, 0.6009, -1.416 + ], + "dims": [2, 4, 8, 6], + "type": "float32" + } + ] + } + ] + }, + { + "name": "RotaryEmbedding with custom rotary dim", + "operator": "RotaryEmbedding", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "rotary_embedding_dim", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "T[1,2,6] T[1,2] T[2,2] T[2,2]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586 + ], + "dims": [1, 2, 6], + "type": "float32" + }, + { + "data": [0, 1], + "dims": [1, 2], + "type": "int64" + }, + { + "data": [1.0, 1.0, 1.0, 0.5403], + "dims": [2, 2], + "type": "float32" + }, + { + "data": [0.0, 0.0, 0.0, 0.8415], + "dims": [2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, 1.0076, -0.0427, -0.225, -0.8673, -1.5071, -0.4586 + ], + "dims": [1, 2, 6], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/simplified-layer-norm.jsonc b/js/web/test/data/ops/simplified-layer-norm.jsonc new file mode 100644 index 000000000000..346919ab63e4 --- /dev/null +++ b/js/web/test/data/ops/simplified-layer-norm.jsonc @@ -0,0 +1,48 @@ +[ + { + "name": "SimplifiedLayerNormalization", + "operator": "SimplifiedLayerNormalization", + "opset": { "domain": "", "version": 16 }, + "attributes": [ + { + "name": "epsilon", + "data": 1e-5, + "type": "float" + } + ], + "inputShapeDefinitions": "rankOnly", + "cases": [ + { + "name": "default", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [2, 2, 2, 2, 2, 2, 2, 2], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.39605894684791565, 0.7921178936958313, 1.1881768703460693, 1.5842357873916626, 1.9802948236465454, + 2.3763537406921387, 2.7724127769470215, 3.168471574783325, 1.4164010286331177, 1.5737788677215576, + 1.731156826019287, 1.888534665107727, 2.045912504196167, 2.2032904624938965, 2.360668420791626, + 2.5180463790893555 + ], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": null, + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/skip-simplified-layer-norm.jsonc b/js/web/test/data/ops/skip-simplified-layer-norm.jsonc new file mode 100644 index 000000000000..9cf521238224 --- /dev/null +++ b/js/web/test/data/ops/skip-simplified-layer-norm.jsonc @@ -0,0 +1,53 @@ +[ + { + "name": "SkipSimplifiedLayerNormalization", + "operator": "SkipSimplifiedLayerNormalization", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { + "name": "epsilon", + "data": 1e-5, + "type": "float" + } + ], + "inputShapeDefinitions": "rankOnly", + "cases": [ + { + "name": "default", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [2, 2, 2, 2, 2, 2, 2, 2], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.21693046391010284, 0.650791347026825, 1.084652304649353, 1.5185132026672363, 1.9523741006851196, + 2.386234998703003, 2.820096015930176, 3.2539567947387695, 1.3915272951126099, 1.5552364587783813, + 1.7189455032348633, 1.8826546669006348, 2.046363592147827, 2.2100727558135986, 2.37378191947937, + 2.5374910831451416 + ], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": null, + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/tanh.jsonc b/js/web/test/data/ops/tanh.jsonc new file mode 100644 index 000000000000..f7691535bd71 --- /dev/null +++ b/js/web/test/data/ops/tanh.jsonc @@ -0,0 +1,26 @@ +[ + { + "name": "tanh with no attributes", + "operator": "Tanh", + "attributes": [], + "cases": [ + { + "name": "T[2,4]", + "inputs": [ + { + "data": [-1000, -1, 0, 0.1, 0.2, 0.3, 0.4, 1000], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-1, -0.761594, 0, 0.099668, 0.197375, 0.291313, 0.379949, 1], + "dims": [2, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc index 047fd6fd7511..990120dd3708 100644 --- a/js/web/test/data/ops/where.jsonc +++ b/js/web/test/data/ops/where.jsonc @@ -168,5 +168,39 @@ ] } ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1", + "inputs": [ + { + "data": [true, false], + "dims": [1, 1, 2, 1], + "type": "bool" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 4], + "type": "float32" + }, + { + "data": [5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/e2e/browser-test-webgl.js b/js/web/test/e2e/browser-test-webgl.js index e503f38ae573..974c81d064c8 100644 --- a/js/web/test/e2e/browser-test-webgl.js +++ b/js/web/test/e2e/browser-test-webgl.js @@ -6,3 +6,16 @@ it('Browser E2E testing - WebGL backend', async function() { await testFunction(ort, {executionProviders: ['webgl']}); }); + +it('Browser E2E testing - invalid buffer', async () => { + try { + await ort.InferenceSession.create( + new Uint8Array(Array.from({length: 100}, () => 42)), {executionProviders: ['webgl']}); + + // Should not reach here. + assert(false); + } catch (e) { + assert(e.message.includes('as ONNX format')); + assert(e.message.includes('as ORT format')); + } +}); diff --git a/js/web/test/e2e/browser-test-webgpu-external-data.js b/js/web/test/e2e/browser-test-webgpu-external-data.js new file mode 100644 index 000000000000..8fb0b4d6ec54 --- /dev/null +++ b/js/web/test/e2e/browser-test-webgpu-external-data.js @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +it('Browser E2E testing - WebGPU backend with external data', async function() { + const session = await ort.InferenceSession.create('./model_with_orig_ext_data.onnx', { + executionProviders: ['webgpu'], + externalData: [{data: './model_with_orig_ext_data.bin', path: 'model_with_orig_ext_data.bin'}] + }); + + const fetches = await session.run({X: new ort.Tensor('float32', [1, 1], [1, 2])}); + + const Y = fetches.Y; + + assert(Y instanceof ort.Tensor); + assert(Y.dims.length === 2 && Y.dims[0] === 2 && Y.dims[1] === 3); + assert(Y.data[0] === 1); + assert(Y.data[1] === 1); + assert(Y.data[2] === 0); + assert(Y.data[3] === 0); + assert(Y.data[4] === 0); + assert(Y.data[5] === 0); +}); diff --git a/js/web/test/e2e/karma.conf.js b/js/web/test/e2e/karma.conf.js index b7ff408fa29c..b541d9d12011 100644 --- a/js/web/test/e2e/karma.conf.js +++ b/js/web/test/e2e/karma.conf.js @@ -15,6 +15,8 @@ if (typeof USER_DATA !== 'string') { throw new Error('flag --user-data= is required'); } +const flags = ['--ignore-gpu-blocklist', '--gpu-vendor-id=0x10de']; + module.exports = function(config) { const distPrefix = SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/'; config.set({ @@ -25,10 +27,14 @@ module.exports = function(config) { {pattern: TEST_MAIN}, {pattern: './node_modules/onnxruntime-web/dist/*.wasm', included: false, nocache: true}, {pattern: './model.onnx', included: false}, + {pattern: './model_with_orig_ext_data.onnx', included: false}, + {pattern: './model_with_orig_ext_data.bin', included: false}, ], plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins], proxies: { '/model.onnx': '/base/model.onnx', + '/model_with_orig_ext_data.onnx': '/base/model_with_orig_ext_data.onnx', + '/model_with_orig_ext_data.bin': '/base/model_with_orig_ext_data.bin', '/test-wasm-path-override/ort-wasm.wasm': '/base/node_modules/onnxruntime-web/dist/ort-wasm.wasm', '/test-wasm-path-override/renamed.wasm': '/base/node_modules/onnxruntime-web/dist/ort-wasm.wasm', }, @@ -43,10 +49,11 @@ module.exports = function(config) { hostname: 'localhost', browsers: [], customLaunchers: { - Chrome_default: {base: 'ChromeHeadless', chromeDataDir: USER_DATA}, + Chrome_default: {base: 'Chrome', flags, chromeDataDir: USER_DATA}, Chrome_no_threads: { - base: 'ChromeHeadless', + base: 'Chrome', chromeDataDir: USER_DATA, + flags // TODO: no-thread flags }, Edge_default: {base: 'Edge', edgeDataDir: USER_DATA} diff --git a/js/web/test/e2e/model_with_orig_ext_data.bin b/js/web/test/e2e/model_with_orig_ext_data.bin new file mode 100644 index 000000000000..d69e6beeff85 Binary files /dev/null and b/js/web/test/e2e/model_with_orig_ext_data.bin differ diff --git a/js/web/test/e2e/model_with_orig_ext_data.onnx b/js/web/test/e2e/model_with_orig_ext_data.onnx new file mode 100644 index 000000000000..6f9cce0bc5b4 --- /dev/null +++ b/js/web/test/e2e/model_with_orig_ext_data.onnx @@ -0,0 +1,19 @@ +  onnx-example: +: +X +model_with_orig_ext_dataY"Pad* +mode"constant +test-model*JBmodel_with_orig_ext_dataj( +locationmodel_with_orig_ext_data.binpZ +X +  + +Z& +model_with_orig_ext_data + + +b +Y +  + +B \ No newline at end of file diff --git a/js/web/test/e2e/run.js b/js/web/test/e2e/run.js index 2776f6dff46a..46c04792f1b9 100644 --- a/js/web/test/e2e/run.js +++ b/js/web/test/e2e/run.js @@ -119,6 +119,7 @@ async function testAllBrowserCases({hostInKarma}) { await runKarma({hostInKarma, main: './browser-test-wasm-path-override-prefix.js'}); await runKarma({hostInKarma, main: './browser-test-wasm-path-override-prefix.js', ortMain: 'ort.wasm.min.js'}); await runKarma({hostInKarma, main: './browser-test-wasm-image-tensor-image.js'}); + await runKarma({hostInKarma, main: './browser-test-webgpu-external-data.js', ortMain: 'ort.webgpu.min.js'}); } async function runKarma({hostInKarma, main, browser = BROWSER, ortMain = 'ort.min.js'}) { diff --git a/js/web/test/e2e/simple-http-server.js b/js/web/test/e2e/simple-http-server.js index 1244aaddafd2..6a6162855df8 100644 --- a/js/web/test/e2e/simple-http-server.js +++ b/js/web/test/e2e/simple-http-server.js @@ -16,6 +16,7 @@ const validRequests = { '/dist/ort-wasm-simd.wasm': ['dist/ort-wasm-simd.wasm', 'application/wasm'], '/dist/ort-wasm-threaded.wasm': ['dist/ort-wasm-threaded.wasm', 'application/wasm'], '/dist/ort-wasm-simd-threaded.wasm': ['dist/ort-wasm-simd-threaded.wasm', 'application/wasm'], + '/dist/ort-wasm-simd.jsep.wasm': ['dist/ort-wasm-simd.jsep.wasm', 'application/wasm'], // proxied .wasm files: '/test-wasm-path-override/ort-wasm.wasm': ['dist/ort-wasm.wasm', 'application/wasm'], @@ -25,6 +26,7 @@ const validRequests = { '/dist/ort.min.js': ['dist/ort.min.js', 'text/javascript'], '/dist/ort.js': ['dist/ort.js', 'text/javascript'], '/dist/ort.webgl.min.js': ['dist/ort.webgl.min.js', 'text/javascript'], + '/dist/ort.webgpu.min.js': ['dist/ort.webgpu.min.js', 'text/javascript'], '/dist/ort.wasm.min.js': ['dist/ort.wasm.min.js', 'text/javascript'], '/dist/ort.wasm-core.min.js': ['dist/ort.wasm-core.min.js', 'text/javascript'], }; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 594ce9feed31..811e3659b598 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -472,11 +472,11 @@ // "test_cumsum_2d_axis_0", // "test_cumsum_2d_axis_1", // "test_cumsum_2d_negative_axis", - // "test_depthtospace_crd_mode_example", - // "test_depthtospace_crd_mode", - // "test_depthtospace_dcr_mode", - // "test_depthtospace_example", - // "test_depthtospace", + "test_depthtospace_crd_mode_example", + "test_depthtospace_crd_mode", + "test_depthtospace_dcr_mode", + "test_depthtospace_example", + "test_depthtospace", // // "test_dequantizelinear_axis", // // "test_dequantizelinear", // // "test_det_2d", @@ -553,7 +553,7 @@ "test_gemm_broadcast", "test_gemm_default_matrix_bias", "test_gemm_default_no_bias", - "test_gemm_default_scalar_bias", + // "test_gemm_default_scalar_bias", "test_gemm_default_single_elem_vector_bias", "test_gemm_default_vector_bias", "test_gemm_default_zero_bias", @@ -597,9 +597,9 @@ // // "test_hardmax_example", // // "test_hardmax_negative_axis", // // "test_hardmax_one_hot", - // // "test_hardsigmoid_default", - // // "test_hardsigmoid_example", - // // "test_hardsigmoid", + "test_hardsigmoid_default", + "test_hardsigmoid_example", + "test_hardsigmoid", // // "test_hardswish_expanded", // // "test_hardswish", "test_if", @@ -637,9 +637,9 @@ "test_layer_normalization_4d_axis_negative_1", // // "test_layer_normalization_4d_axis_negative_2_expanded", "test_layer_normalization_4d_axis_negative_2", - "test_layer_normalization_4d_axis_negative_3_expanded", + // "test_layer_normalization_4d_axis_negative_3_expanded", "test_layer_normalization_4d_axis_negative_3", - "test_layer_normalization_4d_axis_negative_4_expanded", + // "test_layer_normalization_4d_axis_negative_4_expanded", "test_layer_normalization_4d_axis_negative_4", "test_layer_normalization_4d_axis0_expanded", "test_layer_normalization_4d_axis0", @@ -1231,7 +1231,7 @@ "test_split_variable_parts_1d", "test_split_variable_parts_2d", "test_split_variable_parts_default_axis", - // // "test_split_zero_size_splits", + "test_split_zero_size_splits", "test_sqrt_example", "test_sqrt", "test_squeeze_negative_axes", @@ -1334,6 +1334,7 @@ "acos.jsonc", "add.jsonc", "add_int32.jsonc", + "add_zero-sized.jsonc", //"and.jsonc", "asin.jsonc", "attention.jsonc", @@ -1343,16 +1344,19 @@ "ceil.jsonc", "concat.jsonc", "concat_int32.jsonc", + "concat_zero-sized.jsonc", "cast.jsonc", "conv.jsonc", "cos.jsonc", "div.jsonc", "div_int32.jsonc", - //"depth-to-space.jsonc", + "depth-to-space.jsonc", "equal.jsonc", "exp.jsonc", "expand.jsonc", + "fast-gelu.jsonc", "floor.jsonc", + "fused-conv.jsonc", "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", @@ -1361,6 +1365,7 @@ "less.jsonc", "log.jsonc", "matmul.jsonc", + "matmulnbits.jsonc", "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", @@ -1380,7 +1385,10 @@ "pow_int32.jsonc", "pow-big-number.jsonc", "reshape.jsonc", + "rotary-embedding.jsonc", + "simplified-layer-norm.jsonc", "skip-layer-norm.jsonc", + "skip-simplified-layer-norm.jsonc", "slice.jsonc", //"softmax.jsonc", "sin.jsonc", @@ -1389,6 +1397,7 @@ "sub.jsonc", "sub_int32.jsonc", "tan.jsonc", + "tanh.jsonc", "tile.jsonc", "transpose.jsonc", "transpose_int32_uint32.jsonc", @@ -1501,99 +1510,1046 @@ "webnn": { "onnx": ["resnet50", "squeezenet", "tiny_yolov2", "emotion_ferplus"], "node": [ - // Check in node tests that have native Wasm implementations. - // (i.e.) not tests that rely on the fallback cpu implementations. - // Use the 'cpu' level of node tests to test those implementations. + "test_abs", + // "test_acos_example", + // "test_acos", + // "test_acosh_example", + // "test_acosh", + // // "test_adagrad_multiple", + // // "test_adagrad", + // // "test_adam_multiple", + // // "test_adam", "test_add_bcast", + // "test_add_uint8", "test_add", - "test_sub_bcast", - "test_sub_example", - "test_sub", - "test_mul_bcast", - "test_mul_example", - "test_mul", - "test_div_bcast", - "test_div_example", - "test_div", - "test_xor_bcast3v1d", - "test_xor_bcast3v2d", - "test_xor_bcast4v2d", - "test_xor_bcast4v3d", - "test_xor_bcast4v4d", - "test_xor2d", - "test_xor3d", - "test_xor4d", - "test_or_bcast3v1d", - "test_or_bcast3v2d", - "test_or_bcast4v2d", - "test_or_bcast4v3d", - "test_or_bcast4v4d", - "test_and_bcast3v1d", - "test_and_bcast3v2d", - "test_and_bcast4v2d", - "test_and_bcast4v3d", - "test_and_bcast4v4d", - "test_and2d", - "test_and3d", - "test_and4d", - "test_prelu_broadcast", - "test_prelu_example", + // "test_and_bcast3v1d", + // "test_and_bcast3v2d", + // "test_and_bcast4v2d", + // "test_and_bcast4v3d", + // "test_and_bcast4v4d", + // "test_and2d", + // "test_and3d", + // "test_and4d", + "test_argmax_default_axis_example_select_last_index", + "test_argmax_default_axis_example", + "test_argmax_default_axis_random_select_last_index", + "test_argmax_default_axis_random", + "test_argmax_keepdims_example_select_last_index", + "test_argmax_keepdims_example", + "test_argmax_keepdims_random_select_last_index", + "test_argmax_keepdims_random", + "test_argmax_negative_axis_keepdims_example_select_last_index", + "test_argmax_negative_axis_keepdims_example", + "test_argmax_negative_axis_keepdims_random_select_last_index", + "test_argmax_negative_axis_keepdims_random", + "test_argmax_no_keepdims_example_select_last_index", + "test_argmax_no_keepdims_example", + "test_argmax_no_keepdims_random_select_last_index", + "test_argmax_no_keepdims_random", + "test_argmin_default_axis_example_select_last_index", + "test_argmin_default_axis_example", + "test_argmin_default_axis_random_select_last_index", + "test_argmin_default_axis_random", + "test_argmin_keepdims_example_select_last_index", + "test_argmin_keepdims_example", + "test_argmin_keepdims_random_select_last_index", + "test_argmin_keepdims_random", + "test_argmin_negative_axis_keepdims_example_select_last_index", + "test_argmin_negative_axis_keepdims_example", + "test_argmin_negative_axis_keepdims_random_select_last_index", + "test_argmin_negative_axis_keepdims_random", + "test_argmin_no_keepdims_example_select_last_index", + "test_argmin_no_keepdims_example", + "test_argmin_no_keepdims_random_select_last_index", + "test_argmin_no_keepdims_random", + // "test_asin_example", + // "test_asin", + // "test_asinh_example", + // "test_asinh", + // "test_atan_example", + // "test_atan", + // "test_atanh_example", + // "test_atanh", + // "test_averagepool_1d_default", + // "test_averagepool_2d_ceil", + "test_averagepool_2d_default", + "test_averagepool_2d_pads_count_include_pad", + "test_averagepool_2d_pads", + "test_averagepool_2d_precomputed_pads_count_include_pad", + "test_averagepool_2d_precomputed_pads", + "test_averagepool_2d_precomputed_same_upper", + "test_averagepool_2d_precomputed_strides", + "test_averagepool_2d_same_lower", + "test_averagepool_2d_same_upper", + "test_averagepool_2d_strides", + // "test_averagepool_3d_default", "test_basic_conv_with_padding", "test_basic_conv_without_padding", + // "test_basic_convinteger", + "test_batchnorm_epsilon_training_mode", "test_batchnorm_epsilon", + "test_batchnorm_example_training_mode", "test_batchnorm_example", - "opset{10,11,12}/test_cast_STRING_to_FLOAT", - "test_clip_splitbounds", - "test_clip_outbounds", - "test_clip_inbounds", - "test_clip_example", - "test_clip_default_min", - "test_clip_default_max", + // // "test_bernoulli_double_expanded", + // // "test_bernoulli_double", + // // "test_bernoulli_expanded", + // // "test_bernoulli_seed_expanded", + // // "test_bernoulli_seed", + // // "test_bernoulli", + // // "test_bitshift_left_uint16", + // // "test_bitshift_left_uint32", + // // "test_bitshift_left_uint64", + // // "test_bitshift_left_uint8", + // // "test_bitshift_right_uint16", + // // "test_bitshift_right_uint32", + // // "test_bitshift_right_uint64", + // // "test_bitshift_right_uint8", + // // "test_blackmanwindow_expanded", + // // "test_blackmanwindow_symmetric_expanded", + // // "test_blackmanwindow_symmetric", + // // "test_blackmanwindow", + // // "test_cast_BFLOAT16_to_FLOAT", + "test_cast_DOUBLE_to_FLOAT", + // "test_cast_DOUBLE_to_FLOAT16", + // // "test_cast_FLOAT_to_BFLOAT16", + "test_cast_FLOAT_to_DOUBLE", + // // "test_cast_FLOAT_to_FLOAT16", + // // "test_cast_FLOAT_to_STRING", + // "test_cast_FLOAT16_to_DOUBLE", + // "test_cast_FLOAT16_to_FLOAT", + // // "test_cast_STRING_to_FLOAT", + // // "test_castlike_BFLOAT16_to_FLOAT_expanded", + // // "test_castlike_BFLOAT16_to_FLOAT", + // // "test_castlike_DOUBLE_to_FLOAT_expanded", + // // "test_castlike_DOUBLE_to_FLOAT", + // // "test_castlike_DOUBLE_to_FLOAT16_expanded", + // // "test_castlike_DOUBLE_to_FLOAT16", + // // "test_castlike_FLOAT_to_BFLOAT16_expanded", + // // "test_castlike_FLOAT_to_BFLOAT16", + // // "test_castlike_FLOAT_to_DOUBLE_expanded", + // // "test_castlike_FLOAT_to_DOUBLE", + // // "test_castlike_FLOAT_to_FLOAT16_expanded", + // // "test_castlike_FLOAT_to_FLOAT16", + // // "test_castlike_FLOAT_to_STRING_expanded", + // // "test_castlike_FLOAT_to_STRING", + // // "test_castlike_FLOAT16_to_DOUBLE_expanded", + // // "test_castlike_FLOAT16_to_DOUBLE", + // // "test_castlike_FLOAT16_to_FLOAT_expanded", + // // "test_castlike_FLOAT16_to_FLOAT", + // // "test_castlike_STRING_to_FLOAT_expanded", + // // "test_castlike_STRING_to_FLOAT", + "test_ceil_example", + "test_ceil", + // "test_celu_expanded", + // "test_celu", "test_clip_default_inbounds", + "test_clip_default_int8_inbounds", + "test_clip_default_int8_max", + "test_clip_default_int8_min", + "test_clip_default_max", + "test_clip_default_min", + "test_clip_example", + "test_clip_inbounds", + "test_clip_outbounds", + "test_clip_splitbounds", "test_clip", + // // "test_compress_0", + // // "test_compress_1", + // // "test_compress_default_axis", + // // "test_compress_negative_axis", + "test_concat_1d_axis_0", + "test_concat_1d_axis_negative_1", + "test_concat_2d_axis_0", + "test_concat_2d_axis_1", + "test_concat_2d_axis_negative_1", + "test_concat_2d_axis_negative_2", + "test_concat_3d_axis_0", + "test_concat_3d_axis_1", + "test_concat_3d_axis_2", + "test_concat_3d_axis_negative_1", + "test_concat_3d_axis_negative_2", + "test_concat_3d_axis_negative_3", + "test_conv_with_autopad_same", "test_conv_with_strides_and_asymmetric_padding", "test_conv_with_strides_no_padding", "test_conv_with_strides_padding", - "test_gemm_nobroadcast", - "test_gemm_broadcast", - "test_matmul_2d", - "test_matmul_3d", - "test_matmul_4d", - "test_softmax_axis_0", - "test_softmax_axis_1", - "test_softmax_axis_2", - "test_softmax_default_axis", - "test_softmax_example", - "test_softmax_large_number", - "test_sum_example", - "test_sum_one_input", - "test_sum_two_inputs", - "test_averagepool_1d_default", - "test_averagepool_2d_default", - "test_averagepool_2d_pads", - "test_averagepool_2d_precomputed_pads", - "test_averagepool_2d_precomputed_same_upper", - "test_averagepool_2d_precomputed_strides", - "test_averagepool_2d_same_upper", - "test_averagepool_2d_same_lower", - "test_averagepool_2d_strides", - "test_averagepool_3d_default", - "test_maxpool_1d_default", - "test_maxpool_2d_default", - "test_maxpool_2d_pads", - "test_maxpool_2d_precomputed_pads", - "test_maxpool_2d_precomputed_same_upper", - "test_maxpool_2d_precomputed_strides", - "test_maxpool_2d_same_lower", - "test_maxpool_2d_same_upper", - "test_maxpool_2d_strides", - "test_maxpool_3d_default", - "test_globalaveragepool_precomputed", - "test_globalaveragepool", - "test_globalmaxpool_precomputed", - "test_globalmaxpool", - "test_instancenorm_epsilon", - "test_instancenorm_example" + // // "test_convinteger_with_padding", + // // "test_convinteger_without_padding", + "test_convtranspose_1d", + // // "test_convtranspose_3d", + // "test_convtranspose_autopad_same", + "test_convtranspose_dilations", + "test_convtranspose_kernel_shape", + "opset{9,17}/test_convtranspose_output_shape", + "test_convtranspose_pad", + "test_convtranspose_pads", + "test_convtranspose_with_kernel", + "test_convtranspose", + "test_cos_example", + "test_cos", + // "test_cosh_example", + // "test_cosh", + // "test_cumsum_1d_exclusive", + // "test_cumsum_1d_reverse_exclusive", + // "test_cumsum_1d_reverse", + // "test_cumsum_1d", + // "test_cumsum_2d_axis_0", + // "test_cumsum_2d_axis_1", + // "test_cumsum_2d_negative_axis", + // "test_depthtospace_crd_mode_example", + // "test_depthtospace_crd_mode", + // "test_depthtospace_dcr_mode", + // "test_depthtospace_example", + // "test_depthtospace", + // // "test_dequantizelinear_axis", + // // "test_dequantizelinear", + // // "test_det_2d", + // // "test_det_nd", + // // "test_dft_axis", + // // "test_dft_inverse", + // // "test_dft", + "test_div_bcast", + "test_div_example", + // "test_div_uint8", + "test_div", + // // "test_dropout_default_mask_ratio", + // // "test_dropout_default_mask", + // // "test_dropout_default_old", + // // "test_dropout_default_ratio", + // // "test_dropout_default", + // // "test_dropout_random_old", + // // "test_dropout_random", + // // "test_dynamic_slice_default_axes", + // // "test_dynamic_slice_end_out_of_bounds", + // // "test_dynamic_slice_neg", + // // "test_dynamic_slice_start_out_of_bounds", + // // "test_dynamic_slice", + // // "test_dynamicquantizelinear_expanded", + // // "test_dynamicquantizelinear_max_adjusted_expanded", + // // "test_dynamicquantizelinear_max_adjusted", + // // "test_dynamicquantizelinear_min_adjusted_expanded", + // // "test_dynamicquantizelinear_min_adjusted", + // // "test_dynamicquantizelinear", + // "test_edge_pad", + // "test_einsum_batch_diagonal", + // "test_einsum_batch_matmul", + // "test_einsum_inner_prod", + // "test_einsum_sum", + // "test_einsum_transpose", + "test_elu_default", + "test_elu_example", + "test_elu", + "test_equal_bcast", + "test_equal", + // "test_erf", + "test_exp_example", + "test_exp", + // "test_expand_dim_changed", + // "test_expand_dim_unchanged", + // "test_eyelike_populate_off_main_diagonal", + // "test_eyelike_with_dtype", + // "test_eyelike_without_dtype", + "test_flatten_axis0", + "test_flatten_axis1", + "test_flatten_axis2", + "test_flatten_axis3", + "test_flatten_default_axis", + "test_flatten_negative_axis1", + "test_flatten_negative_axis2", + "test_flatten_negative_axis3", + "test_flatten_negative_axis4", + "test_floor_example", + "test_floor", + "test_gather_0", + "test_gather_1", + "test_gather_2d_indices", + "test_gather_negative_indices", + "test_gather_elements_0", + "test_gather_elements_1", + "test_gather_elements_negative_indices", + // "test_gathernd_example_float32", + // "test_gathernd_example_int32_batch_dim1", + // "test_gathernd_example_int32", + "test_gemm_all_attributes", + "test_gemm_alpha", + "test_gemm_beta", + "test_gemm_broadcast", + "test_gemm_default_matrix_bias", + "test_gemm_default_no_bias", + // "test_gemm_default_scalar_bias", + "test_gemm_default_single_elem_vector_bias", + "test_gemm_default_vector_bias", + "test_gemm_default_zero_bias", + "test_gemm_nobroadcast", + "test_gemm_transposeA", + "test_gemm_transposeB", + "test_globalaveragepool_precomputed", + "test_globalaveragepool", + "test_globalmaxpool_precomputed", + "test_globalmaxpool", + "test_greater_bcast", + "test_greater_equal_bcast_expanded", + "test_greater_equal_bcast", + "test_greater_equal_expanded", + "test_greater_equal", + "test_greater", + // // "test_gridsample_aligncorners_true", + // // "test_gridsample_bicubic", + // // "test_gridsample_bilinear", + // // "test_gridsample_border_padding", + // // "test_gridsample_nearest", + // // "test_gridsample_reflection_padding", + // // "test_gridsample_zeros_padding", + // // "test_gridsample", + // // "test_gru_batchwise", + // // "test_gru_defaults", + // // "test_gru_seq_length", + // // "test_gru_with_initial_bias", + // // "test_hammingwindow_expanded", + // // "test_hammingwindow_symmetric_expanded", + // // "test_hammingwindow_symmetric", + // // "test_hammingwindow", + // // "test_hannwindow_expanded", + // // "test_hannwindow_symmetric_expanded", + // // "test_hannwindow_symmetric", + // // "test_hannwindow", + // // "test_hardmax_axis_0", + // // "test_hardmax_axis_1", + // // "test_hardmax_axis_2", + // // "test_hardmax_default_axis", + // // "test_hardmax_example", + // // "test_hardmax_negative_axis", + // // "test_hardmax_one_hot", + "test_hardsigmoid_default", + "test_hardsigmoid_example", + "test_hardsigmoid", + "test_hardswish_expanded", + "test_hardswish", + // "test_if", + // TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra + // supports Sequence and Optional types + // "test_if_seq", + // "test_if_opt", + "test_instancenorm_epsilon", + "test_instancenorm_example", + // "test_isinf_negative", + // "test_isinf_positive", + // "test_isinf", + // "test_isnan", + // "test_layer_normalization_2d_axis_negative_1_expanded", + "test_layer_normalization_2d_axis_negative_1", + // "test_layer_normalization_2d_axis_negative_2_expanded", + "test_layer_normalization_2d_axis_negative_2", + // "test_layer_normalization_2d_axis0_expanded", + "test_layer_normalization_2d_axis0", + // "test_layer_normalization_2d_axis1_expanded", + "test_layer_normalization_2d_axis1", + // "test_layer_normalization_3d_axis_negative_1_epsilon_expanded", + "test_layer_normalization_3d_axis_negative_1_epsilon", + // "test_layer_normalization_3d_axis_negative_2_epsilon_expanded", + "test_layer_normalization_3d_axis_negative_2_epsilon", + // "test_layer_normalization_3d_axis_negative_3_epsilon_expanded", + "test_layer_normalization_3d_axis_negative_3_epsilon", + // "test_layer_normalization_3d_axis0_epsilon_expanded", + "test_layer_normalization_3d_axis0_epsilon", + // "test_layer_normalization_3d_axis1_epsilon_expanded", + "test_layer_normalization_3d_axis1_epsilon", + // "test_layer_normalization_3d_axis2_epsilon_expanded", + "test_layer_normalization_3d_axis2_epsilon", + // "test_layer_normalization_4d_axis_negative_1_expanded", + "test_layer_normalization_4d_axis_negative_1", + // "test_layer_normalization_4d_axis_negative_2_expanded", + "test_layer_normalization_4d_axis_negative_2", + // "test_layer_normalization_4d_axis_negative_3_expanded", + "test_layer_normalization_4d_axis_negative_3", + // "test_layer_normalization_4d_axis_negative_4_expanded", + "test_layer_normalization_4d_axis_negative_4", + // "test_layer_normalization_4d_axis0_expanded", + "test_layer_normalization_4d_axis0", + // "test_layer_normalization_4d_axis1_expanded", + "test_layer_normalization_4d_axis1", + // "test_layer_normalization_4d_axis2_expanded", + "test_layer_normalization_4d_axis2", + // "test_layer_normalization_4d_axis3_expanded", + "test_layer_normalization_4d_axis3", + // "test_layer_normalization_default_axis_expanded", + "test_layer_normalization_default_axis", + "test_leakyrelu_default", + "test_leakyrelu_example", + "test_leakyrelu", + "test_less_bcast", + "test_less_equal_bcast_expanded", + "test_less_equal_bcast", + "test_less_equal_expanded", + "test_less_equal", + "test_less", + "test_log_example", + "test_log", + // // "test_logsoftmax_axis_0_expanded", + // // "test_logsoftmax_axis_0", + // // "test_logsoftmax_axis_1_expanded", + // // "test_logsoftmax_axis_1", + // // "test_logsoftmax_axis_2_expanded", + // // "test_logsoftmax_axis_2", + // // "test_logsoftmax_default_axis_expanded", + // // "test_logsoftmax_default_axis", + // // "test_logsoftmax_example_1_expanded", + // // "test_logsoftmax_example_1", + // // "test_logsoftmax_large_number_expanded", + // // "test_logsoftmax_large_number", + // // "test_logsoftmax_negative_axis_expanded", + // // "test_logsoftmax_negative_axis", + // "test_lrn_default", + // "test_lrn", + // // "test_lstm_batchwise", + // // "test_lstm_defaults", + // // "test_lstm_with_initial_bias", + // // "test_lstm_with_peepholes", + "test_matmul_2d", + "test_matmul_3d", + "test_matmul_4d", + // // "test_matmulinteger", + "test_max_example", + // "test_max_float16", + "test_max_float32", + "test_max_float64", + // "test_max_int16", + // "test_max_int32", + // "test_max_int64", + // "test_max_int8", + "test_max_one_input", + "test_max_two_inputs", + // "test_max_uint16", + // "test_max_uint32", + // "test_max_uint64", + // "test_max_uint8", + // "test_maxpool_1d_default", + // "test_maxpool_2d_ceil", + "test_maxpool_2d_default", + "test_maxpool_2d_dilations", + "test_maxpool_2d_pads", + "test_maxpool_2d_precomputed_pads", + "test_maxpool_2d_precomputed_same_upper", + "test_maxpool_2d_precomputed_strides", + "test_maxpool_2d_same_lower", + "test_maxpool_2d_same_upper", + "test_maxpool_2d_strides", + // "test_maxpool_2d_uint8", + // "test_maxpool_3d_default", + // "test_maxpool_with_argmax_2d_precomputed_pads", + // "test_maxpool_with_argmax_2d_precomputed_strides", + // // "test_maxunpool_export_with_output_shape", + // // "test_maxunpool_export_without_output_shape", + // // "test_mean_example", + // // "test_mean_one_input", + // // "test_mean_two_inputs", + // // "test_melweightmatrix", + "test_min_example", + // "test_min_float16", + "test_min_float32", + "test_min_float64", + // "test_min_int16", + // "test_min_int32", + // "test_min_int64", + // "test_min_int8", + "test_min_one_input", + "test_min_two_inputs", + // "test_min_uint16", + // "test_min_uint32", + // "test_min_uint64", + // "test_min_uint8", + // "test_mod_bcast", + // "test_mod_broadcast", + // "test_mod_float_mixed_sign_example", + // "test_mod_fmod_mixed_sign_example", + // "test_mod_int64_fmod", + // "test_mod_int64_mixed_sign_example", + // "test_mod_mixed_sign_float16", + // "test_mod_mixed_sign_float32", + // "test_mod_mixed_sign_float64", + // "test_mod_mixed_sign_int16", + // "test_mod_mixed_sign_int32", + // "test_mod_mixed_sign_int64", + // "test_mod_mixed_sign_int8", + // "test_mod_uint16", + // "test_mod_uint32", + // "test_mod_uint64", + // "test_mod_uint8", + // // "test_momentum_multiple", + // // "test_momentum", + "test_mul_bcast", + "test_mul_example", + // "test_mul_uint8", + "test_mul", + // "test_mvn_expanded", + // "test_mvn", + "test_neg_example", + "test_neg", + // // "test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NC_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NC", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_mean_weight_negative_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_weight", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_no_weight_reduction_mean_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_no_weight_reduction_mean_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_mean_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_mean", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_sum_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_sum", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_mean_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_mean", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight", + // // "test_nesterov_momentum", + // // "test_nllloss_NC_expanded", + // // "test_nllloss_NC", + // // "test_nllloss_NCd1_expanded", + // // "test_nllloss_NCd1_ii_expanded", + // // "test_nllloss_NCd1_ii", + // // "test_nllloss_NCd1_mean_weight_negative_ii_expanded", + // // "test_nllloss_NCd1_mean_weight_negative_ii", + // // "test_nllloss_NCd1_weight_expanded", + // // "test_nllloss_NCd1_weight_ii_expanded", + // // "test_nllloss_NCd1_weight_ii", + // // "test_nllloss_NCd1_weight", + // // "test_nllloss_NCd1", + // // "test_nllloss_NCd1d2_expanded", + // // "test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded", + // // "test_nllloss_NCd1d2_no_weight_reduction_mean_ii", + // // "test_nllloss_NCd1d2_reduction_mean_expanded", + // // "test_nllloss_NCd1d2_reduction_mean", + // // "test_nllloss_NCd1d2_reduction_sum_expanded", + // // "test_nllloss_NCd1d2_reduction_sum", + // // "test_nllloss_NCd1d2_with_weight_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_mean_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_mean", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum_ii", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum", + // // "test_nllloss_NCd1d2_with_weight", + // // "test_nllloss_NCd1d2", + // // "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded", + // // "test_nllloss_NCd1d2d3_none_no_weight_negative_ii", + // // "test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded", + // // "test_nllloss_NCd1d2d3_sum_weight_high_ii", + // // "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded", + // // "test_nllloss_NCd1d2d3d4d5_mean_weight", + // // "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded", + // // "test_nllloss_NCd1d2d3d4d5_none_no_weight", + // "test_nonmaxsuppression_center_point_box_format", + // "test_nonmaxsuppression_flipped_coordinates", + // "test_nonmaxsuppression_identical_boxes", + // "test_nonmaxsuppression_limit_output_size", + // "test_nonmaxsuppression_single_box", + // "test_nonmaxsuppression_suppress_by_IOU_and_scores", + // "test_nonmaxsuppression_suppress_by_IOU", + // "test_nonmaxsuppression_two_batches", + // "test_nonmaxsuppression_two_classes", + // "test_nonzero_example", + "test_not_2d", + "test_not_3d", + "test_not_4d", + // // "test_onehot_negative_indices", + // // "test_onehot_with_axis", + // // "test_onehot_with_negative_axis", + // // "test_onehot_without_axis", + // // "test_optional_get_element_sequence", + // // "test_optional_get_element", + // // "test_optional_has_element_empty", + // // "test_optional_has_element", + // "test_or_bcast3v1d", + // "test_or_bcast3v2d", + // "test_or_bcast4v2d", + // "test_or_bcast4v3d", + // "test_or_bcast4v4d", + // "test_or2d", + // "test_or3d", + // "test_or4d", + "test_pow_bcast_array", + "test_pow_bcast_scalar", + "test_pow_example", + // "test_pow_types_float", + // "test_pow_types_float32_int32", + // "test_pow_types_float32_int64", + // "test_pow_types_float32_uint32", + // "test_pow_types_float32_uint64", + // "test_pow_types_int", + // "test_pow_types_int32_float32", + // "test_pow_types_int32_int32", + // "test_pow_types_int64_float32", + // "test_pow_types_int64_int64", + "test_pow", + "test_prelu_broadcast", + "test_prelu_example", + // // "test_qlinearconv", + // // "test_qlinearmatmul_2D", + // // "test_qlinearmatmul_3D", + // // "test_quantizelinear_axis", + // // "test_quantizelinear", + // "test_range_float_type_positive_delta_expanded", + // "test_range_float_type_positive_delta", + // "test_range_int32_type_negative_delta_expanded", + // "test_range_int32_type_negative_delta", + "test_reciprocal_example", + "test_reciprocal", + "test_reduce_l1_default_axes_keepdims_example", + "test_reduce_l1_default_axes_keepdims_random", + "test_reduce_l1_do_not_keepdims_example", + "test_reduce_l1_do_not_keepdims_random", + "test_reduce_l1_keep_dims_example", + "test_reduce_l1_keep_dims_random", + "test_reduce_l1_negative_axes_keep_dims_example", + "test_reduce_l1_negative_axes_keep_dims_random", + "test_reduce_l2_default_axes_keepdims_example", + "test_reduce_l2_default_axes_keepdims_random", + "test_reduce_l2_do_not_keepdims_example", + "test_reduce_l2_do_not_keepdims_random", + "test_reduce_l2_keep_dims_example", + "test_reduce_l2_keep_dims_random", + "test_reduce_l2_negative_axes_keep_dims_example", + "test_reduce_l2_negative_axes_keep_dims_random", + "test_reduce_log_sum_asc_axes", + "test_reduce_log_sum_default", + "test_reduce_log_sum_desc_axes", + // tests "test_reduce_log_sum_exp_*" on opset17/opset18 are excluded because they use float64. + // "opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_example", + // "opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_random", + // "opset{7,8,9}/test_reduce_log_sum_exp_do_not_keepdims_example", + // "opset{7,8,9}/test_reduce_log_sum_exp_do_not_keepdims_random", + // "opset{7,8,9}/test_reduce_log_sum_exp_keepdims_example", + // "opset{7,8,9}/test_reduce_log_sum_exp_keepdims_random", + // "opset11/test_reduce_log_sum_exp_negative_axes_keepdims_example", + // "opset11/test_reduce_log_sum_exp_negative_axes_keepdims_random", + "test_reduce_log_sum_negative_axes", + "test_reduce_log_sum", + "test_reduce_max_default_axes_keepdim_example", + // "test_reduce_max_default_axes_keepdims_random", + // "test_reduce_max_do_not_keepdims_example", + // "test_reduce_max_do_not_keepdims_random", + // "test_reduce_max_keepdims_example", + // "test_reduce_max_keepdims_random", + // "test_reduce_max_negative_axes_keepdims_example", + // "test_reduce_max_negative_axes_keepdims_random", + // "test_reduce_mean_default_axes_keepdims_example", + // "test_reduce_mean_default_axes_keepdims_random", + // "test_reduce_mean_do_not_keepdims_example", + // "test_reduce_mean_do_not_keepdims_random", + // "test_reduce_mean_keepdims_example", + // "test_reduce_mean_keepdims_random", + // "test_reduce_mean_negative_axes_keepdims_example", + // "test_reduce_mean_negative_axes_keepdims_random", + // "test_reduce_min_default_axes_keepdims_example", + // "test_reduce_min_default_axes_keepdims_random", + // "test_reduce_min_do_not_keepdims_example", + // "test_reduce_min_do_not_keepdims_random", + // "test_reduce_min_keepdims_example", + // "test_reduce_min_keepdims_random", + // "test_reduce_min_negative_axes_keepdims_example", + // "test_reduce_min_negative_axes_keepdims_random", + // "test_reduce_prod_default_axes_keepdims_example", + // "test_reduce_prod_default_axes_keepdims_random", + // "test_reduce_prod_do_not_keepdims_example", + // "test_reduce_prod_do_not_keepdims_random", + // "test_reduce_prod_keepdims_example", + // "test_reduce_prod_keepdims_random", + // "test_reduce_prod_negative_axes_keepdims_example", + // "test_reduce_prod_negative_axes_keepdims_random", + // "test_reduce_sum_default_axes_keepdims_example", + // "test_reduce_sum_default_axes_keepdims_random", + // "test_reduce_sum_do_not_keepdims_example", + // "test_reduce_sum_do_not_keepdims_random", + "test_reduce_sum_empty_axes_input_noop_example", + "test_reduce_sum_empty_axes_input_noop_random", + // "test_reduce_sum_keepdims_example", + // "test_reduce_sum_keepdims_random", + // "test_reduce_sum_negative_axes_keepdims_example", + // "test_reduce_sum_negative_axes_keepdims_random", + // "test_reduce_sum_square_default_axes_keepdims_example", + // "test_reduce_sum_square_default_axes_keepdims_random", + // "test_reduce_sum_square_do_not_keepdims_example", + // "test_reduce_sum_square_do_not_keepdims_random", + // "test_reduce_sum_square_keepdims_example", + // "test_reduce_sum_square_keepdims_random", + // "test_reduce_sum_square_negative_axes_keepdims_example", + // "test_reduce_sum_square_negative_axes_keepdims_random", + // "test_reflect_pad", + "test_relu", + "test_reshape_allowzero_reordered", + "test_reshape_extended_dims", + "test_reshape_negative_dim", + "test_reshape_negative_extended_dims", + "test_reshape_one_dim", + "test_reshape_reduced_dims", + "test_reshape_reordered_all_dims", + "test_reshape_reordered_dims", + "test_reshape_reordered_last_dims", + "test_reshape_zero_and_negative_dim", + "test_reshape_zero_dim", + "test_resize_downsample_linear", + "test_resize_downsample_nearest", + "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside", + // "test_resize_downsample_scales_cubic_align_corners", + "test_resize_downsample_scales_cubic", + // "test_resize_downsample_scales_linear_align_corners", + "test_resize_downsample_scales_linear", + "test_resize_downsample_scales_nearest", + "test_resize_downsample_sizes_cubic", + "test_resize_downsample_sizes_linear_pytorch_half_pixel", + "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn", + "test_resize_downsample_sizes_nearest", + "test_resize_nearest", + "test_resize_tf_crop_and_resize", + "test_resize_upsample_linear", + "test_resize_upsample_nearest", + "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside", + "test_resize_upsample_scales_cubic_align_corners", + "test_resize_upsample_scales_cubic_asymmetric", + "test_resize_upsample_scales_cubic", + "test_resize_upsample_scales_linear_align_corners", + "test_resize_upsample_scales_linear", + "test_resize_upsample_scales_nearest", + "test_resize_upsample_sizes_cubic", + "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel", + "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners", + "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", + "test_resize_upsample_sizes_nearest", + // // "test_reversesequence_batch", + // // "test_reversesequence_time", + // // "test_rnn_seq_length", + // // "test_roialign_aligned_false", + // // "test_roialign_aligned_true", + // // "test_roialign", + // // "test_round", + // // "test_scan_sum", + // // "test_scan9_sum", + // // "test_scatter_elements_with_axis", + // // "test_scatter_elements_with_duplicate_indices", + // // "test_scatter_elements_with_negative_indices", + // // "test_scatter_elements_without_axis", + // // "test_scatter_with_axis", + // // "test_scatter_without_axis", + // // "test_scatternd_add", + // // "test_scatternd_multiply", + // // "test_scatternd", + // // "test_sce_mean_3d_expanded", + // // "test_sce_mean_3d_log_prob_expanded", + // // "test_sce_mean_3d_log_prob", + // // "test_sce_mean_3d", + // // "test_sce_mean_expanded", + // // "test_sce_mean_log_prob_expanded", + // // "test_sce_mean_log_prob", + // // "test_sce_mean_no_weight_ii_3d_expanded", + // // "test_sce_mean_no_weight_ii_3d_log_prob_expanded", + // // "test_sce_mean_no_weight_ii_3d_log_prob", + // // "test_sce_mean_no_weight_ii_3d", + // // "test_sce_mean_no_weight_ii_4d_expanded", + // // "test_sce_mean_no_weight_ii_4d_log_prob_expanded", + // // "test_sce_mean_no_weight_ii_4d_log_prob", + // // "test_sce_mean_no_weight_ii_4d", + // // "test_sce_mean_no_weight_ii_expanded", + // // "test_sce_mean_no_weight_ii_log_prob_expanded", + // // "test_sce_mean_no_weight_ii_log_prob", + // // "test_sce_mean_no_weight_ii", + // // "test_sce_mean_weight_expanded", + // // "test_sce_mean_weight_ii_3d_expanded", + // // "test_sce_mean_weight_ii_3d_log_prob_expanded", + // // "test_sce_mean_weight_ii_3d_log_prob", + // // "test_sce_mean_weight_ii_3d", + // // "test_sce_mean_weight_ii_4d_expanded", + // // "test_sce_mean_weight_ii_4d_log_prob_expanded", + // // "test_sce_mean_weight_ii_4d_log_prob", + // // "test_sce_mean_weight_ii_4d", + // // "test_sce_mean_weight_ii_expanded", + // // "test_sce_mean_weight_ii_log_prob_expanded", + // // "test_sce_mean_weight_ii_log_prob", + // // "test_sce_mean_weight_ii", + // // "test_sce_mean_weight_log_prob_expanded", + // // "test_sce_mean_weight_log_prob", + // // "test_sce_mean_weight", + // // "test_sce_mean", + // // "test_sce_NCd1_mean_weight_negative_ii_expanded", + // // "test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded", + // // "test_sce_NCd1_mean_weight_negative_ii_log_prob", + // // "test_sce_NCd1_mean_weight_negative_ii", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii", + // // "test_sce_NCd1d2d3_sum_weight_high_ii_expanded", + // // "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded", + // // "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob", + // // "test_sce_NCd1d2d3_sum_weight_high_ii", + // // "test_sce_NCd1d2d3d4d5_mean_weight_expanded", + // // "test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded", + // // "test_sce_NCd1d2d3d4d5_mean_weight_log_prob", + // // "test_sce_NCd1d2d3d4d5_mean_weight", + // // "test_sce_NCd1d2d3d4d5_none_no_weight_expanded", + // // "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded", + // // "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob", + // // "test_sce_NCd1d2d3d4d5_none_no_weight", + // // "test_sce_none_expanded", + // // "test_sce_none_log_prob_expanded", + // // "test_sce_none_log_prob", + // // "test_sce_none_weights_expanded", + // // "test_sce_none_weights_log_prob_expanded", + // // "test_sce_none_weights_log_prob", + // // "test_sce_none_weights", + // // "test_sce_none", + // // "test_sce_sum_expanded", + // // "test_sce_sum_log_prob_expanded", + // // "test_sce_sum_log_prob", + // // "test_sce_sum", + // "test_selu_default", + // "test_selu_example", + // "test_selu", + // // "test_sequence_insert_at_back", + // // "test_sequence_insert_at_front", + // // "test_sequence_map_add_1_sequence_1_tensor_expanded", + // // "test_sequence_map_add_1_sequence_1_tensor", + // // "test_sequence_map_add_2_sequences_expanded", + // // "test_sequence_map_add_2_sequences", + // // "test_sequence_map_extract_shapes_expanded", + // // "test_sequence_map_extract_shapes", + // // "test_sequence_map_identity_1_sequence_1_tensor_expanded", + // // "test_sequence_map_identity_1_sequence_1_tensor", + // // "test_sequence_map_identity_1_sequence_expanded", + // // "test_sequence_map_identity_1_sequence", + // // "test_sequence_map_identity_2_sequences_expanded", + // // "test_sequence_map_identity_2_sequences", + // "test_shrink_hard", + // "test_shrink_soft", + "test_sigmoid_example", + "test_sigmoid", + // "test_sign", + // "test_simple_rnn_batchwise", + // "test_simple_rnn_defaults", + // "test_simple_rnn_with_initial_bias", + "test_sin_example", + "test_sin", + // "test_sinh_example", + // "test_sinh", + // // "test_size_example", + // // "test_size", + // "test_slice_default_axes", + // "test_slice_default_steps", + // "test_slice_end_out_of_bounds", + // "test_slice_neg_steps", + // "test_slice_neg", + // "test_slice_negative_axes", + // "test_slice_start_out_of_bounds", + // "test_slice", + // "test_softmax_axis_0_expanded", + "test_softmax_axis_0", + // "test_softmax_axis_1_expanded", + "test_softmax_axis_1", + // "test_softmax_axis_2_expanded", + "test_softmax_axis_2", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight", + // "test_softmax_cross_entropy_mean_3d_expanded", + // "test_softmax_cross_entropy_mean_3d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_3d_log_prob", + // "test_softmax_cross_entropy_mean_3d", + // "test_softmax_cross_entropy_mean_expanded", + // "test_softmax_cross_entropy_mean_log_prob_expanded", + // "test_softmax_cross_entropy_mean_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index", + // "test_softmax_cross_entropy_mean_weight_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d_log_prob", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d_log_prob", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d", + // "test_softmax_cross_entropy_mean_weight_ignore_index_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_log_prob", + // "test_softmax_cross_entropy_mean_weight_ignore_index", + // "test_softmax_cross_entropy_mean_weight_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_log_prob", + // "test_softmax_cross_entropy_mean_weight", + // "test_softmax_cross_entropy_mean", + // "test_softmax_cross_entropy_none_expanded", + // "test_softmax_cross_entropy_none_log_prob_expanded", + // "test_softmax_cross_entropy_none_log_prob", + // "test_softmax_cross_entropy_none_weights_expanded", + // "test_softmax_cross_entropy_none_weights_log_prob_expanded", + // "test_softmax_cross_entropy_none_weights_log_prob", + // "test_softmax_cross_entropy_none_weights", + // "test_softmax_cross_entropy_none", + // "test_softmax_cross_entropy_sum_expanded", + // "test_softmax_cross_entropy_sum_log_prob_expanded", + // "test_softmax_cross_entropy_sum_log_prob", + // "test_softmax_cross_entropy_sum", + // "opset13/test_softmax_default_axis_expanded", + "opset13/test_softmax_default_axis", + // "test_softmax_example_expanded", + "test_softmax_example", + // "test_softmax_large_number_expanded", + "test_softmax_large_number", + // "test_softmax_negative_axis_expanded", + "test_softmax_negative_axis", + // // "test_softplus_example", + // // "test_softplus", + // // "test_softsign_example", + // // "test_softsign", + // "test_spacetodepth_example", + // "test_spacetodepth", + "test_split_equal_parts_1d", + "test_split_equal_parts_2d", + "test_split_equal_parts_default_axis", + "test_split_variable_parts_1d", + "test_split_variable_parts_2d", + "test_split_variable_parts_default_axis", + "test_split_zero_size_splits", + "test_sqrt_example", + "test_sqrt", + "test_squeeze_negative_axes", + "test_squeeze", + // // "test_stft_with_window", + // // "test_stft", + // // "test_strnormalizer_export_monday_casesensintive_lower", + // // "test_strnormalizer_export_monday_casesensintive_nochangecase", + // // "test_strnormalizer_export_monday_casesensintive_upper", + // // "test_strnormalizer_export_monday_empty_output", + // // "test_strnormalizer_export_monday_insensintive_upper_twodim", + // // "test_strnormalizer_nostopwords_nochangecase", + "test_sub_bcast", + "test_sub_example", + // "test_sub_uint8", + "test_sub", + // "test_sum_example", + // "test_sum_one_input", + // "test_sum_two_inputs", + "test_tan_example", + "test_tan", + "test_tanh_example", + "test_tanh", + // // "test_tfidfvectorizer_tf_batch_onlybigrams_skip0", + // // "test_tfidfvectorizer_tf_batch_onlybigrams_skip5", + // // "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5", + // // "test_tfidfvectorizer_tf_only_bigrams_skip0", + // // "test_tfidfvectorizer_tf_onlybigrams_levelempty", + // // "test_tfidfvectorizer_tf_onlybigrams_skip5", + // // "test_tfidfvectorizer_tf_uniandbigrams_skip5", + // "test_thresholdedrelu_default", + // "test_thresholdedrelu_example", + // "test_thresholdedrelu", + // "test_tile_precomputed", + // "test_tile", + // // "test_top_k_negative_axis", + // // "test_top_k_smallest", + // // "test_top_k", + // // "test_training_dropout_default_mask", + // // "test_training_dropout_default", + // // "test_training_dropout_mask", + // // "test_training_dropout_zero_ratio_mask", + // // "test_training_dropout_zero_ratio", + // // "test_training_dropout", + "test_transpose_all_permutations_0", + "test_transpose_all_permutations_1", + "test_transpose_all_permutations_2", + "test_transpose_all_permutations_3", + "test_transpose_all_permutations_4", + "test_transpose_all_permutations_5", + "test_transpose_default", + // "test_tril_neg", + // "test_tril_one_row_neg", + // "test_tril_out_neg", + // "test_tril_out_pos", + // "test_tril_pos", + // "test_tril_square_neg", + // "test_tril_square", + // "test_tril_zero", + // "test_tril", + // "test_triu_neg", + // "test_triu_one_row", + // "test_triu_out_neg_out", + // "test_triu_out_pos", + // "test_triu_pos", + // "test_triu_square_neg", + // "test_triu_square", + // "test_triu_zero", + // "test_triu", + // // "test_unique_not_sorted_without_axis", + // // "test_unique_sorted_with_axis_3d", + // // "test_unique_sorted_with_axis", + // // "test_unique_sorted_with_negative_axis", + // // "test_unique_sorted_without_axis", + "test_unsqueeze_axis_0", + "test_unsqueeze_axis_1", + "test_unsqueeze_axis_2", + "test_unsqueeze_axis_3", + "test_unsqueeze_negative_axes", + "test_unsqueeze_three_axes", + "test_unsqueeze_two_axes", + "test_unsqueeze_unsorted_axes", + "test_unsqueeze", + // "test_wrap_pad" + // "test_upsample_nearest", + "test_where_example" + // "test_where_long_example", + // "test_xor_bcast3v1d", + // "test_xor_bcast3v2d", + // "test_xor_bcast4v2d", + // "test_xor_bcast4v3d", + // "test_xor_bcast4v4d", + // "test_xor2d", + // "test_xor3d", + // "test_xor4d" ], "ops": [] } diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index 9bd0ec1425f9..96e374f87aed 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -19,49 +19,7 @@ if (ORT_WEB_TEST_CONFIG.model.some(testGroup => testGroup.tests.some(test => tes } // set flags -const options = ORT_WEB_TEST_CONFIG.options; -if (options.debug !== undefined) { - ort.env.debug = options.debug; -} -if (options.globalEnvFlags) { - const flags = options.globalEnvFlags; - if (flags.logLevel !== undefined) { - ort.env.logLevel = flags.logLevel; - } - if (flags.webgl?.contextId !== undefined) { - ort.env.webgl.contextId = flags.webgl.contextId; - } - if (flags.webgl?.matmulMaxBatchSize !== undefined) { - ort.env.webgl.matmulMaxBatchSize = flags.webgl.matmulMaxBatchSize; - } - if (flags.webgl?.textureCacheMode !== undefined) { - ort.env.webgl.textureCacheMode = flags.webgl.textureCacheMode; - } - if (flags.webgl?.pack !== undefined) { - ort.env.webgl.pack = flags.webgl.pack; - } - if (flags.webgl?.async !== undefined) { - ort.env.webgl.async = flags.webgl.async; - } - if (flags.wasm?.numThreads !== undefined) { - ort.env.wasm.numThreads = flags.wasm.numThreads; - } - if (flags.wasm?.simd !== undefined) { - ort.env.wasm.simd = flags.wasm.simd; - } - if (flags.wasm?.proxy !== undefined) { - ort.env.wasm.proxy = flags.wasm.proxy; - } - if (flags.wasm?.initTimeout !== undefined) { - ort.env.wasm.initTimeout = flags.wasm.initTimeout; - } - if (flags.webgpu?.profilingMode !== undefined) { - ort.env.webgpu.profiling = {mode: flags.webgpu.profilingMode}; - } - if (flags.webgpu?.validateInputContent !== undefined) { - ort.env.webgpu.validateInputContent = flags.webgpu.validateInputContent; - } -} +Object.assign(ort.env, ORT_WEB_TEST_CONFIG.options.globalEnvFlags); // Set logging configuration for (const logConfig of ORT_WEB_TEST_CONFIG.log) { @@ -110,8 +68,7 @@ for (const group of ORT_WEB_TEST_CONFIG.model) { let context: ModelTestContext; before('prepare session', async () => { - context = await ModelTestContext.create( - test, ORT_WEB_TEST_CONFIG.profile, ORT_WEB_TEST_CONFIG.options.sessionOptions); + context = await ModelTestContext.create(test, ORT_WEB_TEST_CONFIG.profile, ORT_WEB_TEST_CONFIG.options); }); after('release session', async () => { diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 5e9b0910a2c6..d8ee5ef95320 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -39,10 +39,6 @@ const ONNXRUNTIME_THRESHOLD_RELATIVE_ERROR = 1.00001; */ const now = (typeof performance !== 'undefined' && performance.now) ? () => performance.now() : Date.now; -function toInternalTensor(tensor: ort.Tensor): Tensor { - return new Tensor( - tensor.dims, tensor.type as Tensor.DataType, undefined, undefined, tensor.data as Tensor.NumberType); -} function fromInternalTensor(tensor: Tensor): ort.Tensor { return new ort.Tensor(tensor.type, tensor.data as ort.Tensor.DataType, tensor.dims); } @@ -96,7 +92,7 @@ async function loadTensors( const outputs: Test.NamedTensor[] = []; let dataFileType: 'none'|'pb'|'npy' = 'none'; - const allowInt64 = ['wasm', 'xnnpack', 'webgpu'].includes(backendName); + const allowInt64 = ['wasm', 'webgpu', 'webnn'].includes(backendName); for (const dataFile of testCase.dataFiles) { const ext = extname(dataFile); @@ -137,7 +133,8 @@ async function loadTensors( } async function initializeSession( - modelFilePath: string, backendHint: string, ioBindingMode: Test.IOBindingMode, profile: boolean, + modelFilePath: string, backendHint: ort.InferenceSession.ExecutionProviderConfig, ioBindingMode: Test.IOBindingMode, + profile: boolean, externalData: ort.InferenceSession.SessionOptions['externalData'], sessionOptions: ort.InferenceSession.SessionOptions, fileCache?: FileCacheBuffer): Promise { const preloadModelData: Uint8Array|undefined = fileCache && fileCache[modelFilePath] ? fileCache[modelFilePath] : undefined; @@ -152,7 +149,8 @@ async function initializeSession( executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile, - preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined + preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined, + externalData }; let session: ort.InferenceSession; @@ -161,7 +159,8 @@ async function initializeSession( if (preloadModelData) { session = await ort.InferenceSession.create(preloadModelData, sessionConfig); } else { - session = await ort.InferenceSession.create(modelFilePath, sessionConfig); + const modelData = await readFile(modelFilePath); + session = await ort.InferenceSession.create(modelData, sessionConfig); } } catch (e) { Logger.error( @@ -232,9 +231,8 @@ export class ModelTestContext { /** * create a ModelTestContext object that used in every test cases in the given ModelTest. */ - static async create( - modelTest: Test.ModelTest, profile: boolean, - sessionOptions?: ort.InferenceSession.SessionOptions): Promise { + static async create(modelTest: Test.ModelTest, profile: boolean, testOptions?: Test.Options): + Promise { if (this.initializing) { throw new Error('cannot create a ModelTestContext object when the previous creation is not done'); } @@ -243,8 +241,12 @@ export class ModelTestContext { this.initializing = true; const initStart = now(); + const executionProviderConfig = + modelTest.backend === 'webnn' ? (testOptions?.webnnOptions || 'webnn') : modelTest.backend!; const session = await initializeSession( - modelTest.modelUrl, modelTest.backend!, modelTest.ioBinding, profile, sessionOptions || {}, this.cache); + modelTest.modelUrl, executionProviderConfig, modelTest.ioBinding, profile, modelTest.externalData, + testOptions?.sessionOptions || {}, this.cache); + const initEnd = now(); for (const testCase of modelTest.cases) { @@ -313,7 +315,7 @@ export class TensorResultValidator { } else if (backend === 'webgpu') { this.absoluteThreshold = WEBGPU_THRESHOLD_ABSOLUTE_ERROR; this.relativeThreshold = WEBGPU_THRESHOLD_RELATIVE_ERROR; - } else if (backend === 'wasm' || backend === 'xnnpack' || backend === 'webnn') { + } else if (backend === 'wasm' || backend === 'webnn') { this.absoluteThreshold = WASM_THRESHOLD_ABSOLUTE_ERROR; this.relativeThreshold = WASM_THRESHOLD_RELATIVE_ERROR; } else if (backend === 'onnxruntime') { @@ -325,6 +327,10 @@ export class TensorResultValidator { } checkTensorResult(actual: Tensor[], expected: Tensor[]): void { + this.checkApiTensorResult(actual.map(fromInternalTensor), expected.map(fromInternalTensor)); + } + + checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void { // check output size expect(actual.length, 'size of output tensors').to.equal(expected.length); @@ -342,10 +348,6 @@ export class TensorResultValidator { } } - checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void { - this.checkTensorResult(actual.map(toInternalTensor), expected.map(toInternalTensor)); - } - checkNamedTensorResult(actual: Record, expected: Test.NamedTensor[]): void { // check output size expect(Object.getOwnPropertyNames(actual).length, 'size of output tensors').to.equal(expected.length); @@ -359,7 +361,7 @@ export class TensorResultValidator { } // This function check whether 2 tensors should be considered as 'match' or not - areEqual(actual: Tensor, expected: Tensor): boolean { + areEqual(actual: ort.Tensor, expected: ort.Tensor): boolean { if (!actual || !expected) { return false; } @@ -387,13 +389,13 @@ export class TensorResultValidator { switch (actualType) { case 'string': - return this.strictEqual(actual.stringData, expected.stringData); + return this.strictEqual(actual.data, expected.data); case 'float32': case 'float64': return this.floatEqual( - actual.numberData as number[] | Float32Array | Float64Array, - expected.numberData as number[] | Float32Array | Float64Array); + actual.data as number[] | Float32Array | Float64Array, + expected.data as number[] | Float32Array | Float64Array); case 'uint8': case 'int8': @@ -404,10 +406,8 @@ export class TensorResultValidator { case 'int64': case 'bool': return TensorResultValidator.integerEqual( - actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | - Int32Array, - expected.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | - Int32Array); + actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, + expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array); default: throw new Error('type not implemented or not supported'); @@ -574,7 +574,9 @@ export async function sessionRun(options: { // replace the CPU tensors in feeds into GPU tensors for (const name in feeds) { if (Object.hasOwnProperty.call(feeds, name)) { - feeds[name] = createGpuTensorForInput(feeds[name]); + if (feeds[name].size > 0) { + feeds[name] = createGpuTensorForInput(feeds[name]); + } } } } @@ -583,7 +585,11 @@ export async function sessionRun(options: { for (const name in options.outputsMetaInfo) { if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) { const {type, dims} = options.outputsMetaInfo[name]; - fetches[name] = createGpuTensorForOutput(type, dims); + if (dims.some(d => d === 0)) { + fetches[name] = new ort.Tensor(type, [], dims); + } else { + fetches[name] = createGpuTensorForOutput(type, dims); + } } } } @@ -628,8 +634,8 @@ export async function runModelTestSet( try { const feeds: Record = {}; const outputsMetaInfo: Record = {}; - testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor); - testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor); + testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor); + testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor); const [start, end, outputs] = await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding}); if (context.perfData.count === 0) { diff --git a/js/web/test/test-shared.ts b/js/web/test/test-shared.ts index 7c327e7c97ac..55beb66e37e6 100644 --- a/js/web/test/test-shared.ts +++ b/js/web/test/test-shared.ts @@ -15,14 +15,33 @@ export function bufferToBase64(buffer: Uint8Array): string { return base64.fromByteArray(buffer); } +async function retry(fn: () => Promise, maxRetries = 3, delay = 100): Promise { + let retries = maxRetries; + do { + try { + return await fn(); + } catch (err) { + if (retries-- === 0) { + throw err; + } + await new Promise(resolve => setTimeout(resolve, delay)); + } + // eslint-disable-next-line no-constant-condition + } while (true); +} + export async function readFile(file: string) { if (typeof process !== 'undefined' && process.versions && process.versions.node) { // node return fs.readFile(file); } else { // browser - const response = await fetch(file); - return new Uint8Array(await response.arrayBuffer()); + // + // use "retry" to workaround the error "TypeError: Failed to fetch" in some test environments + return retry(async () => { + const response = await fetch(file); + return new Uint8Array(await response.arrayBuffer()); + }); } } diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index 5bdc8d84cc7a..14b9fd7c005a 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -65,6 +65,7 @@ export declare namespace Test { export interface ModelTest { name: string; modelUrl: string; + externalData?: InferenceSession.SessionOptions['externalData']; backend?: string; // value should be populated at build time ioBinding: IOBindingMode; platformCondition?: PlatformCondition; @@ -143,6 +144,7 @@ export declare namespace Test { cudaFlags?: Record; wasmOptions?: InferenceSession.WebAssemblyExecutionProviderOption; webglOptions?: InferenceSession.WebGLExecutionProviderOption; + webnnOptions?: InferenceSession.WebNNExecutionProviderOption; globalEnvFlags?: EnvOptions; } diff --git a/js/web/test/training/e2e/browser-test-wasm.js b/js/web/test/training/e2e/browser-test-wasm.js new file mode 100644 index 000000000000..fa87389f7ac4 --- /dev/null +++ b/js/web/test/training/e2e/browser-test-wasm.js @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +describe('Browser E2E testing for training package', function() { + it('Check that training package encompasses inference', async function() { + ort.env.wasm.numThreads = 1; + await testInferenceFunction(ort, {executionProviders: ['wasm']}); + }); + + it('Check training functionality, all options', async function() { + ort.env.wasm.numThreads = 1; + await testTrainingFunctionAll(ort, {executionProviders: ['wasm']}); + }); + + it('Check training functionality, minimum options', async function() { + ort.env.wasm.numThreads = 1; + await testTrainingFunctionMin(ort, {executionProviders: ['wasm']}); + }); +}); diff --git a/js/web/test/training/e2e/common.js b/js/web/test/training/e2e/common.js new file mode 100644 index 000000000000..b6040b63d56b --- /dev/null +++ b/js/web/test/training/e2e/common.js @@ -0,0 +1,246 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +const DATA_FOLDER = 'data/'; +const TRAININGDATA_TRAIN_MODEL = DATA_FOLDER + 'training_model.onnx'; +const TRAININGDATA_OPTIMIZER_MODEL = DATA_FOLDER + 'adamw.onnx'; +const TRAININGDATA_EVAL_MODEL = DATA_FOLDER + 'eval_model.onnx'; +const TRAININGDATA_CKPT = DATA_FOLDER + 'checkpoint.ckpt'; + +const trainingSessionAllOptions = { + checkpointState: TRAININGDATA_CKPT, + trainModel: TRAININGDATA_TRAIN_MODEL, + evalModel: TRAININGDATA_EVAL_MODEL, + optimizerModel: TRAININGDATA_OPTIMIZER_MODEL +} + +const trainingSessionMinOptions = { + checkpointState: TRAININGDATA_CKPT, + trainModel: TRAININGDATA_TRAIN_MODEL, +} + +// ASSERT METHODS + +function assert(cond) { + if (!cond) throw new Error(); +} + +function assertStrictEquals(actual, expected) { + if (actual !== expected) { + let strRep = actual; + if (typeof actual === 'object') { + strRep = JSON.stringify(actual); + } + throw new Error(`expected: ${expected}; got: ${strRep}`); + } +} + +function assertTwoListsUnequal(list1, list2) { + if (list1.length !== list2.length) { + return; + } + for (let i = 0; i < list1.length; i++) { + if (list1[i] !== list2[i]) { + return; + } + } + throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`); +} + +// HELPER METHODS FOR TESTS + +function generateGaussianRandom(mean=0, scale=1) { + const u = 1 - Math.random(); + const v = Math.random(); + const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v); + return z * scale + mean; +} + +function generateGaussianFloatArray(length) { + const array = new Float32Array(length); + + for (let i = 0; i < length; i++) { + array[i] = generateGaussianRandom(); + } + + return array; +} + +/** + * creates the TrainingSession and verifies that the input and output names of the training model loaded into the + * training session are correct. + * @param {} ort + * @param {*} createOptions + * @param {*} options + * @returns + */ +async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) { + const trainingSession = await ort.TrainingSession.create(createOptions, options); + + assertStrictEquals(trainingSession.trainingInputNames[0], 'input-0'); + assertStrictEquals(trainingSession.trainingInputNames[1], 'labels'); + assertStrictEquals(trainingSession.trainingInputNames.length, 2); + assertStrictEquals(trainingSession.trainingOutputNames[0], 'onnx::loss::21273'); + assertStrictEquals(trainingSession.trainingOutputNames.length, 1); + return trainingSession; +} + +/** + * verifies that the eval input and output names associated with the eval model loaded into the given training session + * are correct. + */ +function checkEvalModel(trainingSession) { + assertStrictEquals(trainingSession.evalInputNames[0], 'input-0'); + assertStrictEquals(trainingSession.evalInputNames[1], 'labels'); + assertStrictEquals(trainingSession.evalInputNames.length, 2); + assertStrictEquals(trainingSession.evalOutputNames[0], 'onnx::loss::21273'); + assertStrictEquals(trainingSession.evalOutputNames.length, 1); +} + +/** + * Checks that accessing trainingSession.evalInputNames or trainingSession.evalOutputNames will throw an error if + * accessed + * @param {} trainingSession + */ +function checkNoEvalModel(trainingSession) { + try { + assertStrictEquals(trainingSession.evalInputNames, "should have thrown an error upon accessing"); + } catch (error) { + assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); + } + try { + assertStrictEquals(trainingSession.evalOutputNames, "should have thrown an error upon accessing"); + } catch (error) { + assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); + } +} + +/** + * runs the train step with the given inputs and checks that the tensor returned is of type float32 and has a length + * of 1 for the loss. + * @param {} trainingSession + * @param {*} feeds + * @returns + */ +var runTrainStepAndCheck = async function(trainingSession, feeds) { + const results = await trainingSession.runTrainStep(feeds); + assertStrictEquals(Object.keys(results).length, 1); + assertStrictEquals(results['onnx::loss::21273'].data.length, 1); + assertStrictEquals(results['onnx::loss::21273'].type, 'float32'); + return results; +}; + +var loadParametersBufferAndCheck = async function(trainingSession, paramsLength, constant, paramsBefore) { + // make a float32 array that is filled with the constant + const newParams = new Float32Array(paramsLength); + for (let i = 0; i < paramsLength; i++) { + newParams[i] = constant; + } + + const newParamsUint8 = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength); + + await trainingSession.loadParametersBuffer(newParamsUint8); + const paramsAfterLoad = await trainingSession.getContiguousParameters(); + + // check that the parameters have changed + assertTwoListsUnequal(paramsAfterLoad.data, paramsBefore.data); + assertStrictEquals(paramsAfterLoad.dims[0], paramsLength); + + // check that the parameters have changed to what they should be + for (let i = 0; i < paramsLength; i++) { + // round to the same number of digits (4 decimal places) + assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), constant.toFixed(4)); + } + + return paramsAfterLoad; +} + +// TESTS + +var testInferenceFunction = async function(ort, options) { + const session = await ort.InferenceSession.create('data/model.onnx', options || {}); + + const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); + const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]); + + const fetches = + await session.run({a: new ort.Tensor('float32', dataA, [3, 4]), b: new ort.Tensor('float32', dataB, [4, 3])}); + + const c = fetches.c; + + assert(c instanceof ort.Tensor); + assert(c.dims.length === 2 && c.dims[0] === 3 && c.dims[1] === 3); + assert(c.data[0] === 700); + assert(c.data[1] === 800); + assert(c.data[2] === 900); + assert(c.data[3] === 1580); + assert(c.data[4] === 1840); + assert(c.data[5] === 2100); + assert(c.data[6] === 2460); + assert(c.data[7] === 2880); + assert(c.data[8] === 3300); +}; + +var testTrainingFunctionMin = async function(ort, options) { + const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options); + checkNoEvalModel(trainingSession); + const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); + const labels = new ort.Tensor('int32', [2, 1], [2]); + const feeds = {"input-0": input0, "labels": labels}; + + // check getParametersSize + const paramsSize = await trainingSession.getParametersSize(); + assertStrictEquals(paramsSize, 397510); + + // check getContiguousParameters + const originalParams = await trainingSession.getContiguousParameters(); + assertStrictEquals(originalParams.dims.length, 1); + assertStrictEquals(originalParams.dims[0], 397510); + assertStrictEquals(originalParams.data[0], -0.025190064683556557); + assertStrictEquals(originalParams.data[2000], -0.034044936299324036); + + await runTrainStepAndCheck(trainingSession, feeds); + + await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, originalParams); +} + +var testTrainingFunctionAll = async function(ort, options) { + const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options); + checkEvalModel(trainingSession); + + const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); + const labels = new ort.Tensor('int32', [2, 1], [2]); + let feeds = {"input-0": input0, "labels": labels}; + + // check getParametersSize + const paramsSize = await trainingSession.getParametersSize(); + assertStrictEquals(paramsSize, 397510); + + // check getContiguousParameters + const originalParams = await trainingSession.getContiguousParameters(); + assertStrictEquals(originalParams.dims.length, 1); + assertStrictEquals(originalParams.dims[0], 397510); + assertStrictEquals(originalParams.data[0], -0.025190064683556557); + assertStrictEquals(originalParams.data[2000], -0.034044936299324036); + + const results = await runTrainStepAndCheck(trainingSession, feeds); + + await trainingSession.runOptimizerStep(feeds); + feeds = {"input-0": input0, "labels": labels}; + // check getContiguousParameters after optimizerStep -- that the parameters have been updated + const optimizedParams = await trainingSession.getContiguousParameters(); + assertTwoListsUnequal(originalParams.data, optimizedParams.data); + + const results2 = await runTrainStepAndCheck(trainingSession, feeds); + + // check that loss decreased after optimizer step and training again + assert(results2['onnx::loss::21273'].data < results['onnx::loss::21273'].data); + + await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams); +} + +if (typeof module === 'object') { + module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll, testTest]; +} diff --git a/js/web/test/training/e2e/data/model.onnx b/js/web/test/training/e2e/data/model.onnx new file mode 100644 index 000000000000..088124bd4862 --- /dev/null +++ b/js/web/test/training/e2e/data/model.onnx @@ -0,0 +1,16 @@ + backend-test:b + +a +bc"MatMultest_matmul_2dZ +a +  + +Z +b +  + +b +c +  + +B \ No newline at end of file diff --git a/js/web/test/training/e2e/karma.conf.js b/js/web/test/training/e2e/karma.conf.js new file mode 100644 index 000000000000..e441cb65b412 --- /dev/null +++ b/js/web/test/training/e2e/karma.conf.js @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +const args = require('minimist')(process.argv.slice(2)); +const SELF_HOST = !!args['self-host']; +const ORT_MAIN = args['ort-main']; +const TEST_MAIN = args['test-main']; +if (typeof TEST_MAIN !== 'string') { + throw new Error('flag --test-main= is required'); +} +const USER_DATA = args['user-data']; +if (typeof USER_DATA !== 'string') { + throw new Error('flag --user-data= is required'); +} + +module.exports = function(config) { + const distPrefix = SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/'; + config.set({ + frameworks: ['mocha'], + files: [ + {pattern: distPrefix + ORT_MAIN}, + {pattern: './common.js'}, + {pattern: TEST_MAIN}, + {pattern: './node_modules/onnxruntime-web/dist/*.wasm', included: false, nocache: true}, + {pattern: './data/*', included: false}, + ], + plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins], + proxies: { + '/model.onnx': '/base/model.onnx', + '/data/': '/base/data/', + }, + client: {captureConsole: true, mocha: {expose: ['body'], timeout: 60000}}, + reporters: ['mocha'], + captureTimeout: 120000, + reportSlowerThan: 100, + browserDisconnectTimeout: 600000, + browserNoActivityTimeout: 300000, + browserDisconnectTolerance: 0, + browserSocketTimeout: 60000, + hostname: 'localhost', + browsers: [], + customLaunchers: { + Chrome_default: {base: 'ChromeHeadless', chromeDataDir: USER_DATA}, + Chrome_no_threads: { + base: 'ChromeHeadless', + chromeDataDir: USER_DATA, + // TODO: no-thread flags + }, + Edge_default: {base: 'Edge', edgeDataDir: USER_DATA} + } + }); +}; diff --git a/js/web/test/training/e2e/package.json b/js/web/test/training/e2e/package.json new file mode 100644 index 000000000000..5f11a27de6df --- /dev/null +++ b/js/web/test/training/e2e/package.json @@ -0,0 +1,14 @@ +{ + "devDependencies": { + "@chiragrupani/karma-chromium-edge-launcher": "^2.2.2", + "fs-extra": "^11.1.0", + "globby": "^13.1.3", + "karma": "^6.4.1", + "karma-chrome-launcher": "^3.1.1", + "karma-mocha": "^2.0.1", + "karma-mocha-reporter": "^2.2.5", + "light-server": "^2.9.1", + "minimist": "^1.2.7", + "mocha": "^10.2.0" + } +} diff --git a/js/web/test/training/e2e/run.js b/js/web/test/training/e2e/run.js new file mode 100644 index 000000000000..379a8136f3ff --- /dev/null +++ b/js/web/test/training/e2e/run.js @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +const path = require('path'); +const fs = require('fs-extra'); +const {spawn} = require('child_process'); +const startServer = require('./simple-http-server'); +const minimist = require('minimist'); + +// copy whole folder to out-side of /js/ because we need to test in a folder that no `package.json` file +// exists in its parent folder. +// here we use /build/js/e2e-training/ for the test + +const TEST_E2E_SRC_FOLDER = __dirname; +const JS_ROOT_FOLDER = path.resolve(__dirname, '../../../..'); +const TEST_E2E_RUN_FOLDER = path.resolve(JS_ROOT_FOLDER, '../build/js/e2e-training'); +const NPM_CACHE_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../npm_cache'); +const CHROME_USER_DATA_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../user_data'); +fs.emptyDirSync(TEST_E2E_RUN_FOLDER); +fs.emptyDirSync(NPM_CACHE_FOLDER); +fs.emptyDirSync(CHROME_USER_DATA_FOLDER); +fs.copySync(TEST_E2E_SRC_FOLDER, TEST_E2E_RUN_FOLDER); + +// training data to copy +const ORT_ROOT_FOLDER = path.resolve(JS_ROOT_FOLDER, '..'); +const TRAINING_DATA_FOLDER = path.resolve(ORT_ROOT_FOLDER, 'onnxruntime/test/testdata/training_api'); +const TRAININGDATA_DEST = path.resolve(TEST_E2E_RUN_FOLDER, 'data'); + +// always use a new folder as user-data-dir +let nextUserDataDirId = 0; +function getNextUserDataDir() { + const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString()) + nextUserDataDirId++; + fs.emptyDirSync(dir); + return dir; +} + +// commandline arguments +const BROWSER = minimist(process.argv.slice(2)).browser || 'Chrome_default'; + +async function main() { + // find packed package + const {globbySync} = await import('globby'); + + const ORT_COMMON_FOLDER = path.resolve(JS_ROOT_FOLDER, 'common'); + const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', {cwd: ORT_COMMON_FOLDER}); + + const PACKAGES_TO_INSTALL = []; + + if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length === 1) { + PACKAGES_TO_INSTALL.push(path.resolve(ORT_COMMON_FOLDER, ORT_COMMON_PACKED_FILEPATH_CANDIDATES[0])); + } else if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length > 1) { + throw new Error('multiple packages found for onnxruntime-common.'); + } + + const ORT_WEB_FOLDER = path.resolve(JS_ROOT_FOLDER, 'web'); + const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', {cwd: ORT_WEB_FOLDER}); + if (ORT_WEB_PACKED_FILEPATH_CANDIDATES.length !== 1) { + throw new Error('cannot find exactly single package for onnxruntime-web.'); + } + PACKAGES_TO_INSTALL.push(path.resolve(ORT_WEB_FOLDER, ORT_WEB_PACKED_FILEPATH_CANDIDATES[0])); + + // we start here: + + // install dev dependencies + await runInShell(`npm install`); + + // npm install with "--cache" to install packed packages with an empty cache folder + await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" ${PACKAGES_TO_INSTALL.map(i => `"${i}"`).join(' ')}`); + + // prepare training data + prepareTrainingDataByCopying(); + + console.log('==============================================================='); + console.log("Running self-hosted tests"); + console.log('==============================================================='); + // test cases with self-host (ort hosted in same origin) + await testAllBrowserCases({hostInKarma: true}); + + console.log('==============================================================='); + console.log("Running not self-hosted tests"); + console.log('==============================================================='); + // test cases without self-host (ort hosted in same origin) + startServer(path.resolve(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web')); + await testAllBrowserCases({hostInKarma: false}); + + // no error occurs, exit with code 0 + process.exit(0); +} + +async function testAllBrowserCases({hostInKarma}) { + await runKarma({hostInKarma, main: './browser-test-wasm.js'}); +} + +async function runKarma({hostInKarma, main, browser = BROWSER, ortMain = 'ort.training.wasm.min.js'}) { + console.log('==============================================================='); + console.log(`Running karma with the following binary: ${ortMain}`); + console.log('==============================================================='); + const selfHostFlag = hostInKarma ? '--self-host' : ''; + await runInShell(`npx karma start --single-run --browsers ${browser} ${selfHostFlag} --ort-main=${ + ortMain} --test-main=${main} --user-data=${getNextUserDataDir()}`); +} + +async function runInShell(cmd) { + console.log('==============================================================='); + console.log(' Running command in shell:'); + console.log(' > ' + cmd); + console.log('==============================================================='); + let complete = false; + const childProcess = spawn(cmd, {shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER}); + childProcess.on('close', function(code) { + if (code !== 0) { + process.exit(code); + } else { + complete = true; + } + }); + while (!complete) { + await delay(100); + } +} + +async function delay(ms) { + return new Promise(function(resolve) { + setTimeout(function() { + resolve(); + }, ms); + }); +} + +function prepareTrainingDataByCopying() { + fs.copySync(TRAINING_DATA_FOLDER, TRAININGDATA_DEST); + console.log(`Copied ${TRAINING_DATA_FOLDER} to ${TRAININGDATA_DEST}`); +} + +main(); diff --git a/js/web/test/training/e2e/simple-http-server.js b/js/web/test/training/e2e/simple-http-server.js new file mode 100644 index 000000000000..a157c7dd93ad --- /dev/null +++ b/js/web/test/training/e2e/simple-http-server.js @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +// this is a simple HTTP server that enables CORS. +// following code is based on https://developer.mozilla.org/en-US/docs/Learn/Server-side/Node_server_without_framework + +const http = require('http'); +const fs = require('fs'); +const path = require('path'); + +const validRequests = { + // .wasm files + '/dist/ort-wasm.wasm': ['dist/ort-wasm.wasm', 'application/wasm'], + '/dist/ort-wasm-simd.wasm': ['dist/ort-wasm-simd.wasm', 'application/wasm'], + '/dist/ort-training-wasm-simd.wasm': ['dist/ort-training-wasm-simd.wasm', 'application/wasm'], + '/dist/ort-wasm-threaded.wasm': ['dist/ort-wasm-threaded.wasm', 'application/wasm'], + '/dist/ort-wasm-simd-threaded.wasm': ['dist/ort-wasm-simd-threaded.wasm', 'application/wasm'], + + // proxied .wasm files: + '/test-wasm-path-override/ort-wasm.wasm': ['dist/ort-training-wasm.wasm', 'application/wasm'], + //'/test-wasm-path-override/renamed.wasm': ['dist/ort-wasm.wasm', 'application/wasm'], + + // .js files + '/dist/ort.min.js': ['dist/ort.min.js', 'text/javascript'], + '/dist/ort.training.simd.wasm.min.js': ['dist/ort.training.simd.wasm.min.js', 'text/javascript'], + '/dist/ort.training.wasm.min.js': ['dist/ort.training.wasm.min.js', 'text/javascript'], + '/dist/ort.js': ['dist/ort.js', 'text/javascript'], + '/dist/ort.webgl.min.js': ['dist/ort.webgl.min.js', 'text/javascript'], + '/dist/ort.wasm.min.js': ['dist/ort.wasm.min.js', 'text/javascript'], + '/dist/ort.wasm-core.min.js': ['dist/ort.wasm-core.min.js', 'text/javascript'], +}; + +module.exports = function(dir) { + http.createServer(function(request, response) { + console.log(`request ${request.url.replace(/\n|\r/g, '')}`); + + const requestData = validRequests[request.url]; + if (!request) { + response.writeHead(404); + response.end('404'); + } else { + const [filePath, contentType] = requestData; + fs.readFile(path.resolve(dir, filePath), function(error, content) { + if (error) { + if (error.code == 'ENOENT') { + response.writeHead(404); + response.end('404'); + } else { + response.writeHead(500); + response.end('500'); + } + } else { + response.setHeader('access-control-allow-origin', '*'); + response.writeHead(200, {'Content-Type': contentType}); + response.end(content, 'utf-8'); + } + }); + } + }) + .listen(8081); + console.log('Server running at http://127.0.0.1:8081/'); + }; diff --git a/js/web/test/unittests/backends/webgl/test-conv-new.ts b/js/web/test/unittests/backends/webgl/test-conv-new.ts index 8c186b9b3645..014fc57f2155 100644 --- a/js/web/test/unittests/backends/webgl/test-conv-new.ts +++ b/js/web/test/unittests/backends/webgl/test-conv-new.ts @@ -893,7 +893,9 @@ describe('New Conv tests', () => { const expected = cpuConv( inputTensor, kernelTensor, biasTensor, testData.autoPad, testData.dilations, testData.pads, testData.strides); - if (!validator.areEqual(actual, expected)) { + try { + validator.checkTensorResult([actual], [expected]); + } catch { console.log(actual.dims, `[${actual.numberData.slice(0, 20).join(',')},...]`); console.log(expected.dims, `[${expected.numberData.slice(0, 20).join(',')},...]`); throw new Error('Expected and Actual did not match'); diff --git a/objectivec/include/ort_coreml_execution_provider.h b/objectivec/include/ort_coreml_execution_provider.h index a015b6fd60c8..6ff18176ebeb 100644 --- a/objectivec/include/ort_coreml_execution_provider.h +++ b/objectivec/include/ort_coreml_execution_provider.h @@ -41,6 +41,17 @@ NS_ASSUME_NONNULL_BEGIN */ @property BOOL onlyEnableForDevicesWithANE; +/** + * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also allow inputs with + * dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes. + */ +@property BOOL onlyAllowStaticInputShapes; + +/** + * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. + */ +@property BOOL createMLProgram; + @end @interface ORTSessionOptions (ORTSessionOptionsCoreMLEP) diff --git a/objectivec/ort_coreml_execution_provider.mm b/objectivec/ort_coreml_execution_provider.mm index 6340fdea1c3a..58b47d68eea6 100644 --- a/objectivec/ort_coreml_execution_provider.mm +++ b/objectivec/ort_coreml_execution_provider.mm @@ -26,7 +26,10 @@ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOpti const uint32_t flags = (options.useCPUOnly ? COREML_FLAG_USE_CPU_ONLY : 0) | (options.enableOnSubgraphs ? COREML_FLAG_ENABLE_ON_SUBGRAPH : 0) | - (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0); + (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0) | + (options.onlyAllowStaticInputShapes ? COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES : 0) | + (options.createMLProgram ? COREML_FLAG_CREATE_MLPROGRAM : 0); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML( [self CXXAPIOrtSessionOptions], flags)); return YES; diff --git a/objectivec/ort_value.mm b/objectivec/ort_value.mm index b9dc1a9885c6..c61a7ea80923 100644 --- a/objectivec/ort_value.mm +++ b/objectivec/ort_value.mm @@ -148,6 +148,9 @@ - (nullable ORTValueTypeInfo*)typeInfoWithError:(NSError**)error { - (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error { try { const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo(); + if (!tensorTypeAndShapeInfo) { + ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION); + } return CXXAPIToPublicTensorTypeAndShapeInfo(tensorTypeAndShapeInfo); } ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) @@ -156,6 +159,9 @@ - (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError* - (nullable NSMutableData*)tensorDataWithError:(NSError**)error { try { const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo(); + if (!tensorTypeAndShapeInfo) { + ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION); + } if (tensorTypeAndShapeInfo.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { ORT_CXX_API_THROW( "This ORTValue holds string data. Please call tensorStringDataWithError: " @@ -182,6 +188,9 @@ - (nullable NSMutableData*)tensorDataWithError:(NSError**)error { - (nullable NSArray*)tensorStringDataWithError:(NSError**)error { try { const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo(); + if (!tensorTypeAndShapeInfo) { + ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION); + } const size_t elementCount = tensorTypeAndShapeInfo.GetElementCount(); const size_t tensorStringDataLength = _value->GetStringTensorDataLength(); std::vector tensorStringData(tensorStringDataLength, '\0'); diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 57219c50f39a..c3699f0fb33a 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -7,7 +7,7 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_ or the `Github project `_. """ -__version__ = "1.17.0" +__version__ = "1.18.0" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). diff --git a/onnxruntime/contrib_ops/cpu/activations.cc b/onnxruntime/contrib_ops/cpu/activations.cc index 556699192d2e..3e0533dd8b9e 100644 --- a/onnxruntime/contrib_ops/cpu/activations.cc +++ b/onnxruntime/contrib_ops/cpu/activations.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/activation/activations.h" -#include "activations.h" +#include "contrib_ops/cpu/activations.h" namespace onnxruntime { namespace contrib { @@ -26,14 +26,6 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), ThresholdedRelu); -ONNX_OPERATOR_KERNEL_EX( - Gelu, - kMSDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Gelu); - ONNX_OPERATOR_KERNEL_EX( QuickGelu, kMSDomain, diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h index aed4c2229215..7e64235d3fc3 100644 --- a/onnxruntime/contrib_ops/cpu/activations.h +++ b/onnxruntime/contrib_ops/cpu/activations.h @@ -54,47 +54,6 @@ namespace contrib { DEFINE_ELE_KERNEL(ScaledTanh); DEFINE_ELE_KERNEL(ParametricSoftplus); -template -class Gelu : public OpKernel { - public: - Gelu(const OpKernelInfo& info) : OpKernel(info) { - } - - Status Compute(OpKernelContext* context) const override { - const Tensor* input = context->Input(0); - const T* input_data = input->Data(); - - Tensor* output = context->Output(0, input->Shape()); - T* output_data = output->MutableData(); - - concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); - int64_t elem_count = input->Shape().Size(); - constexpr int64_t length_per_task = 4096; // this number comes from FastGelu. - int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; - concurrency::ThreadPool::TryBatchParallelFor( - tp, static_cast(task_count), - [&](ptrdiff_t task_idx) { - const auto start = task_idx * length_per_task; - const T* p_input = input_data + start; - T* p_output = output_data + start; - int64_t count = std::min(length_per_task, elem_count - start); - - for (int64_t i = 0; i < count; i++) { - T value = p_input[i]; - p_output[i] = value * static_cast(M_SQRT1_2); - } - - MlasComputeErf(p_output, p_output, narrow(count)); - - for (int64_t i = 0; i < count; i++) { - p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); - } - }, - 0); - return Status::OK(); - } -}; - // Implement a new one instead of inheriting from ElementWiseRangedTransform so that we can call // MlasComputeLogistic instead of using Eigen for better perf. template diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h index d72868cd8fa9..56c8e2911e28 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace aten_ops { -typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); +typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size, DLManagedTensor** dlpack_inputs, size_t output_size, DLManagedTensor** dlpack_outputs); @@ -22,17 +22,17 @@ class ATenOperatorExecutor { return instance; } - void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) { - ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw); - p_is_cpu_argument_func_ = reinterpret_cast(p_is_cpu_argument_func_raw); + void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { + ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw); + p_is_tensor_argument_func_ = reinterpret_cast(p_is_tensor_argument_func_raw); p_execute_aten_op_func_ = reinterpret_cast(p_execute_aten_op_func_raw); } bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; } - bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { - ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized."); - return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); + bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { + ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized."); + return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); } void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size, @@ -43,7 +43,7 @@ class ATenOperatorExecutor { } private: - IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr; + IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr; ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 4711ccf487cc..768676259aa1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -211,6 +211,12 @@ Status Attention::Compute(OpKernelContext* context) const { relative_position_bias, ¶meters)); + if (parameters.do_rotary) { + ORT_NOT_IMPLEMENTED( + "Rotary embedding is not supported in Attention CPU kernel. \ + Please fuse the model with MHA + RotaryEmbedding."); + } + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int input_hidden_size = parameters.input_hidden_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index 5d224bdc2235..515a967aa238 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -253,6 +253,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, output_parameters->is_unidirectional = is_unidirectional_; output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr); output_parameters->do_rotary = do_rotary_; + output_parameters->rotary_embedding = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_; output_parameters->mask_filter_value = mask_filter_value_; output_parameters->scale = scale_; output_parameters->mask_type = mask_type; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index 5ee40c4b9866..a6782daa58f1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -38,6 +38,7 @@ class AttentionBase { is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_embedding_ = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); @@ -72,6 +73,7 @@ class AttentionBase { bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V. bool past_present_share_buffer_; // whether or not the past (if used) and present tensor share the same buffer bool do_rotary_; // whether or not to use rotary embeddings + int rotary_embedding_; // rotary embedding dimension float mask_filter_value_; // the value to be used for filtered out positions float scale_; // the scale to be used for softmax }; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index a7f83469a768..5a0c3af05c9d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -56,6 +56,7 @@ struct AttentionParameters { int v_head_size; // hidden size per head of V int num_heads; int num_splits; + int rotary_embedding; bool is_unidirectional; bool past_present_share_buffer; bool do_rotary; @@ -63,6 +64,7 @@ struct AttentionParameters { bool pass_past_in_kv; float mask_filter_value; float scale; + bool use_tf32; AttentionMaskType mask_type; AttentionQkvFormat qkv_format; }; @@ -81,6 +83,7 @@ struct PackedAttentionParameters { int token_count; bool has_relative_position_bias; bool broadcast_res_pos_bias; + bool use_tf32; }; // Parameters deduced from node attributes and inputs/outputs. @@ -95,13 +98,19 @@ struct GroupQueryAttentionParameters { int kv_hidden_size; int kv_num_heads; int num_splits; // number of splits for splitkv + int rotary_dim; // rotary embedding dimension bool is_unidirectional; // causal int local_window_size; bool kv_share_buffer; + bool is_packed_qkv; bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool do_rotary; + bool rotary_interleaved; float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; + int zeros_count; + int* zero_ptr; }; namespace attention { @@ -132,6 +141,10 @@ constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FL // Default value for the above setting. constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513; +// Environment variable to enable loading more KV data in flight in +// DecoderMaskedMultiHeadAttention/DecoderMaskedSelfAttention kernels +constexpr const char* kDecoderMaskedAttentionLoadKVDataInFlight = "ORT_DECODER_MASKED_ATTENTION_LOAD_KV_DATA_IN_FLIGHT"; + } // namespace attention } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index b761b1afd852..34f57c1655cc 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -140,26 +140,35 @@ class AttentionCPUBase : public AttentionBase { if (mask_data != nullptr) { PrepareMask(mask_index, mask_index_dims, mask_data, causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_); - } else { // no any mask - const int memset_loop_len = batch_size * num_heads_; - const double memset_cost = static_cast(sequence_length) * total_sequence_length; - - ThreadPool::TryParallelFor(tp, memset_loop_len, memset_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - for (std::ptrdiff_t i = begin; i != end; ++i) { - const int output_offset = static_cast(i) * sequence_length * total_sequence_length; - T* output = attention_probs + output_offset; - memset(output, 0, static_cast(sequence_length) * total_sequence_length * sizeof(T)); - } - }); } const int loop_len = batch_size * num_heads_; const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; - // The cost of Gemm - const double cost = static_cast(head_size) * sequence_length * total_sequence_length; + TensorOpCost unit_cost; + const size_t probs_matrix_bytes = SafeInt(sequence_length) * total_sequence_length * sizeof(T); + unit_cost.compute_cycles = static_cast(2 * sequence_length * head_size * total_sequence_length); + unit_cost.bytes_loaded = static_cast((sequence_length + total_sequence_length) * head_size * sizeof(T)); + unit_cost.bytes_stored = static_cast(probs_matrix_bytes); + + if (mask_data != nullptr) { + unit_cost.bytes_loaded += static_cast(probs_matrix_bytes); + unit_cost.bytes_stored += static_cast(probs_matrix_bytes); + } + + if (present || present_key) { + double bytes_to_copy_key = static_cast(sizeof(T) * present_chunk_length); + unit_cost.bytes_loaded += bytes_to_copy_key; + unit_cost.bytes_stored += bytes_to_copy_key; + } - ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + if (relative_position_bias_data != nullptr) { + unit_cost.compute_cycles += static_cast(sequence_length * total_sequence_length); + unit_cost.bytes_loaded += probs_matrix_bytes * 2; + unit_cost.bytes_stored += probs_matrix_bytes; + } + + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const int batch_index = static_cast(i) / num_heads_; @@ -171,7 +180,7 @@ class AttentionCPUBase : public AttentionBase { if (mask_data != nullptr) { memcpy(output, mask_data + mask_offset, - static_cast(sequence_length) * total_sequence_length * sizeof(T)); + probs_matrix_bytes); } const T* k = K + kv_input_chunk_length * i; @@ -188,7 +197,7 @@ class AttentionCPUBase : public AttentionBase { // B: K' (B x N x) T x H (B x N x) H x T H x T // C: attention_probs (B x N x) S x T (B x N x) S x T S x T math::Gemm(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha, - Q + q_input_chunk_length * i, k, 1.0, + Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f, output, nullptr); if (relative_position_bias_data != nullptr) { @@ -238,10 +247,24 @@ class AttentionCPUBase : public AttentionBase { present += SafeInt(batch_size) * num_heads_ * total_sequence_length * v_head_size; } - const double cost = - static_cast(sequence_length) * static_cast(v_head_size) * static_cast(sequence_length); + // The cost of Gemm + TensorOpCost unit_cost; + unit_cost.compute_cycles = static_cast(2 * sequence_length * v_head_size * total_sequence_length); + unit_cost.bytes_loaded = static_cast((sequence_length + v_head_size) * total_sequence_length * sizeof(T)); + unit_cost.bytes_stored = static_cast(sequence_length * v_head_size * sizeof(T)); + + if (present || present_value) { + double bytes_to_copy_value = static_cast(present_chunk_length * sizeof(T)); + unit_cost.bytes_loaded += bytes_to_copy_value; + unit_cost.bytes_stored += bytes_to_copy_value; + } + + const size_t bytes_to_copy_trans = SafeInt(v_head_size) * sizeof(T); + double bytes_to_copy_trans_all = static_cast(sequence_length * bytes_to_copy_trans); + unit_cost.bytes_loaded += bytes_to_copy_trans_all; + unit_cost.bytes_stored += bytes_to_copy_trans_all; - ThreadPool::TryParallelFor(tp, SafeInt(batch_size) * num_heads_, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + ThreadPool::TryParallelFor(tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const T* v = V + kv_input_chunk_length * i; if (nullptr != present) { @@ -263,9 +286,8 @@ class AttentionCPUBase : public AttentionBase { T* src = current_tmp_data; ptrdiff_t dest_offset = (SafeInt(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size; T* dest = output + dest_offset; - const auto bytes_to_copy = SafeInt(v_head_size) * sizeof(T); for (int j = 0; j < sequence_length; j++) { - memcpy(dest, src, bytes_to_copy); + memcpy(dest, src, bytes_to_copy_trans); src += v_head_size; dest += v_hidden_size; } diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 694c40bf3eda..c4e4b4ec707f 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -40,6 +40,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i num_heads_ = static_cast(num_heads); mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; } // Reshape Q/K/V from BxSxD to BxSxNxH @@ -57,11 +58,12 @@ Status Reshape_BSD_to_BSNH(Tensor* qkv, // Transpose Q/K/V from BxSxNxH to BxNxSxH Status Transpose_BSNH_to_BNSH(const Tensor* qkv, - OrtValue& qkv_transposed) { + OrtValue& qkv_transposed, + concurrency::ThreadPool* tp = nullptr) { std::vector permutations({0, 2, 1, 3}); gsl::span permutations_span{permutations}; size_t from = 2, to = 1; - SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable(), from, to); + SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable(), from, to, nullptr, tp); return Status::OK(); } @@ -142,7 +144,8 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable(), batch_size, sequence_length, num_heads, head_size)); // Transpose Q from BxSxNxH to BxNxSxH - ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable(), qkv_with_bias_transposed)); + auto tp = context->GetOperatorThreadPool(); + ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable(), qkv_with_bias_transposed, tp)); return Status::OK(); } @@ -283,8 +286,9 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { nullptr, ¶meters, num_heads_, - scale, mask_filter_value_, + scale, + is_unidirectional_, past_present_share_buffer, false)); diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h index 4c86b777e984..fb7da78a5c0a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h @@ -18,6 +18,7 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase { protected: int num_heads_; // number of attention heads float mask_filter_value_; + bool is_unidirectional_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 00e82c9844b3..c91f5b601b4e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -25,6 +25,7 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, float scale, + bool is_unidirectional, bool past_present_share_buffer, bool dmmha_packing) { // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None @@ -315,7 +316,7 @@ Status CheckInputs(const T* query, output_parameters->head_size = hidden_size / num_heads; output_parameters->v_head_size = v_hidden_size / num_heads; output_parameters->num_heads = num_heads; - output_parameters->is_unidirectional = false; + output_parameters->is_unidirectional = is_unidirectional; output_parameters->past_present_share_buffer = past_present_share_buffer; output_parameters->mask_filter_value = mask_filter_value; output_parameters->mask_type = mask_type; @@ -342,6 +343,7 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, float scale, + bool is_unidirectional, bool past_present_share_buffer, bool dmmha_packing, int max_threads_per_block) { @@ -350,8 +352,8 @@ Status CheckInputs(const T* query, } return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, - past_seq_len, parameters, num_heads, mask_filter_value, scale, past_present_share_buffer, - dmmha_packing); + past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional, + past_present_share_buffer, dmmha_packing); } } // namespace multihead_attention_helper diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 47f462d75fcc..aa8b5b5f608f 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -27,7 +27,13 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { scale = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); + + if (rotary_embedding_dim > 0) { + ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); + } } template @@ -42,6 +48,8 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { position_ids, cos_cache, sin_cache, + num_heads, + rotary_embedding_dim, ¶meters)); Tensor* output = context->Output(0, input->Shape()); @@ -59,61 +67,66 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; + const int n_heads = parameters.num_heads; const int head_size = parameters.head_size; const int position_ids_format = parameters.position_ids_format; - const int half_head_size = head_size / 2; + const int rotary_emb_dim = parameters.rotary_embedding_dim; + const int half_rotary_emb_dim = rotary_emb_dim / 2; + // Default input tensor shape is [batch, seq_len, hidden_size] int head_stride = head_size; - int seq_stride = num_heads * head_stride; + int seq_stride = n_heads * head_stride; int batch_stride = sequence_length * seq_stride; if (parameters.transposed) { - // Transposed input tensor shape is [batch, num_heads, seq_len, head_size] + // Transposed input tensor shape is [batch, n_heads, seq_len, head_size] seq_stride = head_size; head_stride = sequence_length * seq_stride; - batch_stride = num_heads * head_stride; + batch_stride = n_heads * head_stride; } AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); auto* tp = context->GetOperatorThreadPool(); - const int loop_len = batch_size * sequence_length * num_heads; - const double cost = static_cast(head_size); + const int loop_len = batch_size * sequence_length * n_heads; + const double cost = static_cast(rotary_emb_dim); ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { - const int b = static_cast((ptr / num_heads) / sequence_length); - const int s = static_cast((ptr / num_heads) % sequence_length); - const int n = static_cast(ptr % num_heads); + const int b = static_cast((ptr / n_heads) / sequence_length); + const int s = static_cast((ptr / n_heads) % sequence_length); + const int n = static_cast(ptr % n_heads); const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; const T* input_data = input_src + block_offset; T* output_data = output_dest + block_offset; - // Cache is (M, H/2) + // Cache is (M, H/2) or (M, rotary_embedding_dim/2) const int position_id = (position_ids_format == 0) ? static_cast(pos_ids_data[0]) + s : static_cast(pos_ids_data[b * sequence_length + s]); - const int cache_offset = position_id * half_head_size; + const int cache_offset = position_id * half_rotary_emb_dim; const T* cos_data = cos_cache_data + cache_offset; const T* sin_data = sin_cache_data + cache_offset; int cache_idx = 0; T sign = 0; int j = 0; - for (int i = 0; i < head_size; i++) { + for (int i = 0; i < rotary_emb_dim; i++) { if (interleaved) { - cache_idx = (i / 2) % half_head_size; + cache_idx = (i / 2) % half_rotary_emb_dim; sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign } else { - cache_idx = i % half_head_size; - sign = (i < half_head_size) ? static_cast(-1) : static_cast(1); - j = (i + half_head_size) % head_size; + cache_idx = i % half_rotary_emb_dim; + sign = (i < half_rotary_emb_dim) ? static_cast(-1) : static_cast(1); + j = (i + half_rotary_emb_dim) % rotary_emb_dim; } output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; } + for (int i = rotary_emb_dim; i < head_size; i++) { + output_data[i] = input_data[i]; + } } }); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h index be834a66cdc6..4e32424a22b6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h @@ -16,6 +16,8 @@ class RotaryEmbedding final : public OpKernel { protected: float scale; + int num_heads; + int rotary_embedding_dim; bool interleaved; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h index 7b2e8289f7b0..dcbb36d1c4a3 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -11,14 +11,15 @@ namespace rotary_embedding_helper { // Parameters deduced from node attributes and inputs/outputs. struct RotaryParameters { - int batch_size; // Batch size used by input - int sequence_length; // Sequence length used by input - int hidden_size; // Hidden size used by input - int head_size; // Head size used by cos/sin cache * 2 - int num_heads; // num_heads = hidden_size / head_size - int max_sequence_length; // Sequence length used by cos/sin cache - int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) - bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size + int rotary_embedding_dim; // Rotary embedding dimension. + int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) }; template @@ -26,11 +27,13 @@ Status CheckInputs(const T* input, const T* position_ids, const T* cos_cache, const T* sin_cache, + int num_heads, + int rotary_embedding_dim, void* parameters) { // input : (batch_size, sequence_length, hidden_size) // position ids : (1) or (batch_size, sequence_length) - // cos cache : (max_sequence_length, head_size / 2) - // sin cache : (max_sequence_length, head_size / 2) + // cos cache : (max_sequence_length, rotary_embedding_dim / 2) + // sin cache : (max_sequence_length, rotary_embedding_dim / 2) // Check input const auto& input_dims = input->Shape().GetDims(); @@ -60,6 +63,12 @@ Status CheckInputs(const T* input, "the same shape"); } + // Check num_heads and rotary_embedding_dim + if (rotary_embedding_dim > 0 && num_heads == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ", + "specified"); + } + // Get attributes from inputs int batch_size = static_cast(input_dims[0]); int sequence_length = static_cast(input_dims[1]); @@ -73,8 +82,13 @@ Status CheckInputs(const T* input, transposed = true; } int max_sequence_length = static_cast(cos_cache_dims[0]); - int head_size = static_cast(cos_cache_dims[1]) * 2; - int num_heads = hidden_size / head_size; + int head_size = rotary_embedding_dim == 0 ? static_cast(cos_cache_dims[1]) * 2 + : static_cast(hidden_size / num_heads); + if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ", + "head_size"); + } + int position_ids_format = -1; // Check position_ids input shapes @@ -91,23 +105,15 @@ Status CheckInputs(const T* input, } else { position_ids_format = 0; } + // Check cos_cache input shapes if (max_sequence_length != static_cast(cos_cache_dims[0])) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", "max_sequence_length, got ", cos_cache_dims[0]); } - if ((head_size / 2) != static_cast(cos_cache_dims[1])) { + if ((head_size / 2) != static_cast(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast(cos_cache_dims[1]))) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", - "head_size / 2, got ", cos_cache_dims[1]); - } - // Check sin_cache input shapes - if (max_sequence_length != static_cast(sin_cache_dims[0])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ", - "max_sequence_length, got ", sin_cache_dims[0]); - } - if ((head_size / 2) != static_cast(sin_cache_dims[1])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ", - "head_size / 2, got ", sin_cache_dims[1]); + "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]); } // Set rotary parameters @@ -117,10 +123,11 @@ Status CheckInputs(const T* input, output_parameters->sequence_length = sequence_length; output_parameters->hidden_size = hidden_size; output_parameters->head_size = head_size; - output_parameters->num_heads = num_heads; + output_parameters->num_heads = num_heads > 0 ? num_heads : static_cast(hidden_size / head_size); output_parameters->max_sequence_length = max_sequence_length; output_parameters->position_ids_format = position_ids_format; output_parameters->transposed = transposed; + output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size; } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index a9703dc68dd2..602dd98d8c0d 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -1,6 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" + +#include +#include + +#include "core/common/common.h" #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/framework/op_kernel.h" @@ -10,9 +16,57 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" +#ifdef ORT_NEURAL_SPEED +#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" +#endif + namespace onnxruntime { namespace contrib { +namespace { +int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { + const auto accuracy_level = std::clamp(accuracy_level_attr, + static_cast(CompMostAccurate), + static_cast(CompLeastAccurate)); + +#if defined(ORT_NEURAL_SPEED) + + ORT_UNUSED_PARAMETER(nbits); + ORT_UNUSED_PARAMETER(block_size); + + // Neural Speed APIs already expect a minimum accuracy level so just use the given value. + return accuracy_level; + +#else // defined(ORT_NEURAL_SPEED) + + // Find a supported accuracy level that is not less accurate than the one given. + // CompMostAccurate is always supported with the fallback implementation. + // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. + int64_t effective_accuracy_level = accuracy_level; + for (; effective_accuracy_level > CompMostAccurate; --effective_accuracy_level) { + const auto compute_type = static_cast(effective_accuracy_level); + if (MlasIsSQNBitGemmAvailable(nbits, block_size, compute_type)) { + break; + } + } + + return effective_accuracy_level; + +#endif // defined(ORT_NEURAL_SPEED) +} +} // namespace + +bool GetType(const NodeArg& node_arg, int32_t& type) { + type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) { + return false; + } + + type = type_proto->tensor_type().elem_type(); + return true; +} + class MatMulNBits final : public OpKernel { public: MatMulNBits(const OpKernelInfo& info) @@ -21,18 +75,31 @@ class MatMulNBits final : public OpKernel { N_{narrow(info.GetAttr("N"))}, block_size_{narrow(info.GetAttr("block_size"))}, nbits_{narrow(info.GetAttr("bits"))}, - accuracy_level_{info.GetAttr("accuracy_level")} { + accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { + const auto& node = info.node(); + auto input_defs = node.InputDefs(); + // g_idx + if (input_defs.size() > 4) { + act_order_ = true; + } + int32_t type; + if (input_defs.size() > 3 && GetType(*input_defs[3], type)) { + zero_point_is_not_quant_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8; + } + ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); - is_asym_ = info.GetInputCount() >= 4; +#ifdef ORT_NEURAL_SPEED const Tensor* tensor_B = nullptr; const Tensor* tensor_scale = nullptr; const Tensor* tensor_zero_point = nullptr; bool B_constant = info.TryGetConstantInput(1, &tensor_B); bool scale_constant = info.TryGetConstantInput(2, &tensor_scale); bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point); + is_asym_ = info.GetInputCount() >= 4; all_constant_ = B_constant && scale_constant; all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_; +#endif } Status Compute(OpKernelContext* context) const override; @@ -49,31 +116,47 @@ class MatMulNBits final : public OpKernel { const size_t N_; const size_t block_size_; const size_t nbits_; + bool act_order_{false}; + bool zero_point_is_not_quant_{false}; const int64_t accuracy_level_; const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; + +#if defined(ORT_NEURAL_SPEED) + bool is_asym_{false}; bool all_constant_{false}; + +#endif // defined(ORT_NEURAL_SPEED) }; Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; + if (act_order_ || zero_point_is_not_quant_) { + return Status::OK(); + } +#if defined(ORT_NEURAL_SPEED) + if (!all_constant_) { return Status::OK(); } - auto compt_type = static_cast(accuracy_level_); MLAS_THREADPOOL* pool = NULL; + if (nbits_ != 4) { + return Status::OK(); + } + auto comp_type = static_cast(accuracy_level_); + auto nbits = static_cast(nbits_); if (input_idx == 1) { - packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast(nbits_), is_asym_, compt_type); + packed_b_size_ = NSNBitsGemmPackBSize(N_, K_, block_size_, nbits, is_asym_, comp_type); if (packed_b_size_ == 0) return Status::OK(); auto qptr = tensor.Data(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); std::memset(packed_b_.get(), 0, packed_b_size_); - MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, false, compt_type, pool); + NSNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, false, + comp_type, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -82,8 +165,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } if (input_idx == 2 && packed_b_ != nullptr) { auto sptr = tensor.Data(); - MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, !is_asym_, compt_type, pool); + NSNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, !is_asym_, + comp_type, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -92,8 +175,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } if (input_idx == 3 && packed_b_ != nullptr) { auto zptr = tensor.Data(); - MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, is_asym_, compt_type, pool); + NSNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits, is_asym_, is_asym_, + comp_type, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -101,12 +184,38 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } +#else // defined(ORT_NEURAL_SPEED) + + if (input_idx == 1) { + const auto compute_type = static_cast(accuracy_level_); + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + return Status::OK(); + } + packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type); + if (packed_b_size_ == 0) { + return Status::OK(); + } + auto qptr = tensor.DataRaw(); + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get()); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + +#endif // defined(ORT_NEURAL_SPEED) + return Status::OK(); } Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; + +#if defined(ORT_NEURAL_SPEED) + // Pack three tensors into one buffer if (input_idx == 1) { used_shared_buffers = true; @@ -120,16 +229,27 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep used_shared_buffers = true; packed_b_ = std::move(prepacked_buffers[0]); } + +#else // defined(ORT_NEURAL_SPEED) + + if (input_idx == 1) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + +#endif // defined(ORT_NEURAL_SPEED) + return Status::OK(); } Status MatMulNBits::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); - const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); - if (packed_b_.get()) { +#if defined(ORT_NEURAL_SPEED) + + if (packed_b_) { TensorShape b_shape({static_cast(N_), static_cast(K_)}); MatMulComputeHelper helper; @@ -147,7 +267,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t N = static_cast(helper.N()); const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - std::vector gemm_params(max_len); + std::vector gemm_params(max_len); AllocatorPtr allocator; auto status = ctx->GetTempSpaceAllocator(&allocator); ORT_RETURN_IF_ERROR(status); @@ -158,22 +278,24 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { gemm_params[i].C = y_data + helper.OutputOffsets()[i]; gemm_params[i].ldc = N; } - auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); + auto ws_size = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); // workspace for activation process(dynamic quantization and others) auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); - MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), - thread_pool); + NSSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), thread_pool); return Status::OK(); } - const Tensor* b = ctx->Input(1); +#endif // defined(ORT_NEURAL_SPEED) + const Tensor* scales = ctx->Input(2); - const Tensor* zero_points = ctx->Input(3); - const uint8_t* b_data = b->Data(); + const Tensor* zero_points = ctx->InputCount() > 3 ? ctx->Input(3) : nullptr; + const Tensor* reorder_idx = ctx->InputCount() > 4 ? ctx->Input(4) : nullptr; + const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); TensorShape b_shape({static_cast(N_), static_cast(K_)}); + const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); @@ -181,8 +303,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { Tensor* y = ctx->Output(0, helper.OutputShape()); // Bail out early if the output is going to be empty - if (y->Shape().Size() == 0) + if (y->Shape().Size() == 0) { return Status::OK(); + } auto* y_data = y->MutableData(); @@ -192,53 +315,98 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - if (MlasIsSQNBitGemmAvailable(nbits_, block_size_)) { - // number of bytes or elements between adjacent matrices - size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes; - MlasBlockwiseQuantizedBufferSizes(static_cast(nbits_), static_cast(block_size_), /* columnwise */ true, - static_cast(K), static_cast(N), - b_data_matrix_stride_in_bytes, b_scale_matrix_stride, - &b_zero_point_matrix_stride_in_bytes); - - const size_t b_matrix_size = K * N; - - InlinedVector data(batch_count); - for (size_t i = 0; i < batch_count; ++i) { - const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size; - - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes; - data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride; - data[i].QuantBZeroPoint = zero_points_data != nullptr - ? zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes - : nullptr; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; + const bool has_single_b_matrix = + (!act_order_) && (!zero_point_is_not_quant_) && + std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; }); + + if (has_single_b_matrix) { + const auto compute_type = static_cast(accuracy_level_); + + if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + IAllocatorUniquePtr workspace{}; + if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, + nbits_, block_size_, compute_type); + workspace_size > 0) { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + } + + const void* b_data = [&]() -> const void* { + if (packed_b_) { + return packed_b_.get(); + } + + const Tensor* b = ctx->Input(1); + return b->DataRaw(); + }(); + + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].QuantBData = b_data; + data[i].QuantBScale = scales_data; + data[i].QuantBZeroPoint = zero_points_data; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + } + + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + + return Status::OK(); } - - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, data.data(), thread_pool); - - return Status::OK(); } - const size_t ldb = helper.Ldb(true); + const Tensor* b = ctx->Input(1); + const uint8_t* b_data = b->Data(); + const size_t ldb = helper.Ldb(true); AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - // dequantize b, only 4b quantization is supported for now - MlasDequantizeBlockwise( - tmp_b_data_ptr.get(), // dequantized output - b_data, // quantized input - scales_data, // quantization scales - zero_points_data, // quantization zero points - static_cast(block_size_), // quantization block size - column_wise_quant_, // columnwise quantization or row-wise - static_cast(K_), // number of rows in quantized input - static_cast(N_), // number of columns in quantized input - thread_pool); - + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { + // dequantize b, only 4b quantization is supported for now + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { + ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); + // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! + if ((zero_points && zero_points->IsDataType())) { + DequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + reorder_idx_data, + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { + DequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + reorder_idx_data, + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } + } #if 0 // for debug auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); @@ -269,7 +437,9 @@ ONNX_OPERATOR_KERNEL_EX( kCpuExecutionProvider, KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), MatMulNBits); } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc new file mode 100644 index 000000000000..7e343d85f404 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" + +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/float16.h" +#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +template +void Dequantize4BitsKernelReOrder( + T* output, const uint8_t* quant_data, const T* scale_data, + const zeroT* zero_points, const int32_t* reorder_idx, int block_size, + int groups_per_threadblock, int total_groups, int out_rows, int out_cols, + int blockIdx_x, int threadIdx_x) { + const int group_id = blockIdx_x * groups_per_threadblock + ((threadIdx_x * 8) / block_size); + if (group_id >= total_groups) { + return; + } + const int scales_shape_x = (out_cols + block_size - 1) / block_size; + const int zero_point_shape_x = (scales_shape_x + 1) / 2; + + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx_x * 8) & (block_size - 1)); + + const int out_x = element_offset % (scales_shape_x * block_size); + const int out_y = element_offset / (scales_shape_x * block_size); + if (out_y >= out_rows || out_x >= out_cols) { + return; + } + T* output_i = output + out_y * out_cols + out_x; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + const int remain_x = std::min(8, out_cols - out_x); + const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx_x * 8) & (block_size - 1)); + for (int i = 0; i < remain_x; i++) { + int32_t rid = reorder_idx ? reorder_idx_with_off[i] : kb_idx; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + float zp_f = 8; + if (zero_points) { + if constexpr (std::is_same_v) { + zp_f = *(zero_points + n_idx * scales_shape_x + rid); + } else { + uint8_t zp = 8; + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * MLFloat16(zp_f); + output_i[i] = static_cast((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * zp_f; + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +void DequantizeBlockwise( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + const int32_t* reorder_idx, // reorder_idx for groupwise quantization + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* pool) { + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + constexpr int element_per_thread = 8; + int groups_per_threadblock = 256 * element_per_thread / block_size; + int groups_per_K = ceildiv(K, block_size); + int total_groups = N * groups_per_K; // total elemenets in quant_data + int blocks_per_grid = static_cast(ceildiv(total_groups, groups_per_threadblock)); + concurrency::ThreadPool::TrySimpleParallelFor( + pool, static_cast(blocks_per_grid), + [&](std::ptrdiff_t block_id) { + for (int j = 0; j < 256; j++) { + Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points, + reorder_idx, block_size, groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); + } + }); +} + +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const float* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h new file mode 100644 index 000000000000..5061ac5c800a --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +template +void DequantizeBlockwise( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + const int32_t* reorder_idx, // quantization zero points + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* thread_pool); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h new file mode 100644 index 000000000000..864abffd131f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h @@ -0,0 +1,45 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +--*/ + +#pragma once + +#include "contrib_ops/cpu/quantization/neural_speed_wrapper.h" + +namespace bestla { + +using tAVX512F = gemm::SCoreRowNAvx512f<48, 8>; +using tAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>; +using tAVX512_FP16 = gemm::HCoreRowNAvx512fp16<96, 8>; +using tAVX_VNNI = gemm::ICoreRowNAvxvnni<24, 4>; +using tAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>; +using tAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>; +using tAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>; +using tAVX2 = gemm::SCoreRowNAvx2<24, 4>; +using tAVX_VNNI_KBlock = gemm::ICoreRowNAvxvnniKBlock<24, 2>; +using tAVX512_VNNI_KBlock = gemm::ICoreRowNAvx512vnniKBlock<48, 4>; +using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<48, 16>; +using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<48, 16>; + +template +using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger; +template +using tWeiNFloat = prologue_b::gemm::WeightKBlockNFloat; + +class ORTThreading : public parallel::IThreading { + public: + explicit ORTThreading(void* tp); + void parallel_for(const parallel::thread_func& func) const override; + void set_threads(int nthreads) override { + (void)(nthreads); + assert(0); + } + void sync() const override { assert(0); } + void* mTp; +}; + +} // namespace bestla diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc new file mode 100644 index 000000000000..73aaa4ae61a6 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc @@ -0,0 +1,438 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + neural_speed_gemm.cpp + +Abstract: + + GEMM template combinations of neural_speed. +--*/ + +#include "contrib_ops/cpu/quantization/neural_speed_defs.h" +#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" +#include "core/platform/threadpool.h" + +using ThreadPool = onnxruntime::concurrency::ThreadPool; + +namespace bestla { + +ORTThreading::ORTThreading(void* tp) + : IThreading(ThreadPool::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) {} + +void ORTThreading::parallel_for(const parallel::thread_func& func) const { + ThreadPool::TrySimpleParallelFor(reinterpret_cast(mTp), mThreadNum, + [&](ptrdiff_t tid) { func(static_cast(tid)); }); +} + +template +static void NSSQ4GemmCompF32(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, + parallel::IThreading* th) { + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + if (M <= 16) { + using Parallel = parallel::gemm::SchedulerKBlock; + using Launcher = + wrapper::gemm::LauncherKBlock; + static Launcher kernel; + auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); + if (B->IsAsym()) { + reduceA.assign(WorkSpace); + ORTThreading single(nullptr); + kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single); + } + typename Launcher::Param args{gp, + {A, lda_, &reduceA}, + {B}, + {B->template SPtr(), B->SDtype(), B->CStep(), B->template ZPtr(), + reduceA.template RPtr(), reduceA.lda}, + {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); + } else { + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + static Launcher kernel; + typename Launcher::Param args{gp, {A, lda_}, {B}, {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); + } +} + +template +static void NSSQ4GemmCompInt8(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, + parallel::IThreading* th) { + using Parallel = parallel::gemm::SchedulerKBlockS; + using Launcher = + wrapper::gemm::LauncherIntKBlock; + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + static Launcher kernel; + auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->IsAsym()); + quanA.assign(WorkSpace); + if (M <= 16) { + ORTThreading single(nullptr); + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); + } else { + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); + } + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + typename Launcher::Param args{gp, {A, lda_, &quanA}, {B}, {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); +} + +template +static size_t NSSQ4GemmCompF32WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { + auto M_ = static_cast(M); + auto K_ = static_cast(K); + (void)(A); + (void)(N); + (void)(C); + (void)(lda); + (void)(ldc); + if (M <= 16) { + using ProA = prologue_a::gemm::ActivationKBlockBaseF32; + static ProA proA; + if (B->IsAsym()) { + auto reduceA = proA.createStorage(M_, K_, B->mBlockSize); + return reduceA.mSize; + } + return 0; + } else { + // using ProA = prologue_a::gemm::ActivationBase; + return 0; + } +} + +template +static size_t NSSQ4GemmCompInt8WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { + (void)(N); + (void)(lda); + (void)(ldc); + (void)(A); + (void)(C); + using ProA = prologue_a::gemm::ActivationF32KBlockQuantize; + static ProA proA; + auto quanA = + proA.createStorage(static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->IsAsym()); + return quanA.mSize; +} + +} // namespace bestla + +using namespace bestla; + +static bool NSSQ4GemmBatchDriver(size_t M, size_t N, size_t K, size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, + void* ThreadPool) { + GetCPUDevice(); + bestla::ORTThreading orth(ThreadPool); + bool processed = true; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto BlkSize = kptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); + } else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, + DataParams[i].ldc, WorkSpace, &orth); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && + BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, + &orth); + } else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, + &orth); + } else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && + BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); + } + } + } + } else { + processed = false; + break; + } + } + return processed; +} + +static size_t NSSQ4GemmBatchWorkspaceSize(size_t M, size_t N, size_t K, size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { + GetCPUDevice(); + size_t size = 0; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto NTile = + gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + auto BlkSize = kptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc), + size); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } + } + } + } + } + return size; +} + +template +static size_t NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) { + static T proB; + auto stor = proB.createStorage(static_cast(N), static_cast(K), static_cast(block_size), + BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, isAsym); + // TODO(Yu) support more scale dtype + return stor.mSize; +} + +static bool NSQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); + auto uptr = std::unique_ptr(ptr); + ORTThreading orth(ThreadPool); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto ldb_ = static_cast(ldb); + GetCPUDevice(); + if (ptr) { + auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto wptr = reinterpret_cast(ptr); + auto BlkSize = wptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } + } + } + return true; + } + return false; +} + +template +static void NSQ4GemmPackBImpl(void* PackedBuf, size_t BlkSize, const uint8_t* QData, const float* Scale, + const uint8_t* Zp, size_t N, size_t K, bool IsAsym, bool lastCall, size_t ldb, + void* ThreadPool) { + static T proB; + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto stor = proB.createStorage(N_, K_, static_cast(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + BTLA_DTYPE::BF16, IsAsym); + stor.assign(reinterpret_cast(PackedBuf)); + ORTThreading orth(ThreadPool); + proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); + if (lastCall) { + proB.reduceWeight(&stor, &orth); + } +} + +static size_t NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, NS_SQNBIT_COMPUTE_TYPE CompType) { + GetCPUDevice(); + if (K % BlkSize != 0) { + return 0; + } + // from low precision to high precision + switch (CompType) { + case NSCompInt8: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + } + [[fallthrough]]; + case NSCompBf16: + case NSCompFp16: + case NSCompFp32: + case NSCompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + [[fallthrough]]; + default: + return 0; + } +} + +static bool NSQ4GemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, + size_t K, size_t ldb, size_t BlkSize, bool isAsym, bool lastCall, + NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { + GetCPUDevice(); + // explicit statement fall through. + switch (CompType) { + case NSCompInt8: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, + K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + } + [[fallthrough]]; + case NSCompBf16: + case NSCompFp16: + case NSCompFp32: + case NSCompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, + lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, + ldb, ThreadPool); + return true; + } + [[fallthrough]]; + default: + return false; + } +} + +size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, + NS_SQNBIT_COMPUTE_TYPE CompType) { + if (nbits == 4) { + auto jsize = NSQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); + if (jsize) { + return jsize; + } + } + return 0; +} + +void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, + size_t ldb, size_t BlkSize, int nbits, bool isAsym, bool lastCall, + NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { + if (nbits == 4) { + if (NSQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { + return; + } + } +} + +void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + if (NSQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { + return; + } +} + +size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + return NSSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); +} + +void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, + void* ThreadPool) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + if (NSSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { + // PackedWeight is created by bestla + return; + } +} diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h new file mode 100644 index 000000000000..ebcb3027a209 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h @@ -0,0 +1,129 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + neural_speed_gemm.h + +Abstract: + + Prepack-weight GEMM APIs of neural_speed. +--*/ + +#pragma once + +#include +#include + +/** + * @brief Define compute types of block quantization + */ +enum NS_SQNBIT_COMPUTE_TYPE { + NSCompUndef = 0, /*!< undef */ + NSCompFp32 = 1, /*!< input fp32, accumulator fp32 */ + NSCompFp16 = 2, /*!< input fp16, accumulator fp16 */ + NSCompBf16 = 3, /*!< input bf16, accumulator fp32 */ + NSCompInt8 = 4 /*!< input int8, accumulator int32 */ +}; + +/** + * @brief Data parameters for NBits GEMM routine + * C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * All except C are [in] parameters + */ +struct NS_SQNBITS_GEMM_DATA_PACKED_PARAMS { + const float* A = nullptr; /**< address of A (float32 matrix)*/ + const void* B = nullptr; /**< address of B (packed nbits blob)*/ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldc = 0; /**< leading dimension of C*/ +}; + +/** + * @brief Compute the byte size of the parameter combination + * + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @return size of the packing buffer, 0 if the operation is not yet supported. + */ +size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t block_size, int nbits, bool is_asym, + NS_SQNBIT_COMPUTE_TYPE comp_type); + +/** + * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. + * + * @param PackedBuf packed data buffer + * @param QData quantized data buffer + * @param Scale scale pointer + * @param Zp zero point pointer + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization (default 4) + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor + * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where + * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up + * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale + * (is_asym is false) and Zp(is_asym is true). + * @param thread_pool + */ +void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, + size_t ldb, size_t block_size, int nbits, bool is_asym, bool last_call, + NS_SQNBIT_COMPUTE_TYPE comp_type, void* thread_pool); + +/** + * @brief Unpack and dequantize to fp32 + * + * @param FpData unpacked float32 data + * @param PackedBuf quantized and packed data + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param thread_pool + */ +void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* thread_pool); + +/** + * @brief Get the workspace size required by computation. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @return Workspace size in bytes + */ +size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams); + +/** + * @brief Batched GEMM: C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] WorkSpace temporary buffer + * @param[in] ThreadPool + * @return + */ +void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, + void* ThreadPool = nullptr); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h new file mode 100644 index 000000000000..e7df50408ef0 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h @@ -0,0 +1,40 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +//----------------------------------------------------------------------------- +#pragma once +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wsign-compare" +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-value" +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma GCC diagnostic ignored "-Wuninitialized" +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-parameter" + +#elif defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4457) +#pragma warning(disable : 4189) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4267) +#pragma warning(disable : 4702) +#pragma warning(disable : 4127) +#endif + +#include "bestla/bestla_prologue_a.h" +#include "bestla/bestla_wrapper.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc index ee9ae7167945..af163b6be702 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc @@ -1,4 +1,4 @@ -// Copyright (c Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "qlinear_util.h" diff --git a/onnxruntime/contrib_ops/cpu/tokenizer.cc b/onnxruntime/contrib_ops/cpu/tokenizer.cc index 1787fb9b3c4a..89371106b379 100644 --- a/onnxruntime/contrib_ops/cpu/tokenizer.cc +++ b/onnxruntime/contrib_ops/cpu/tokenizer.cc @@ -2,12 +2,29 @@ // Licensed under the MIT License. #include "core/common/common.h" +#include "core/common/inlined_containers.h" #include "core/common/narrow.h" +#include "core/common/safeint.h" #include "core/common/utf8_util.h" -#include "core/framework/tensor.h" #include "core/framework/op_kernel.h" +#include "core/framework/tensor.h" #include "re2/re2.h" +#ifdef _MSC_VER +#include +#define ORT_PMR_ALLOCATOR_SUPPORTED +#endif + +#include +#include +#include + +#ifdef ORT_PMR_ALLOCATOR_SUPPORTED +using SlicesVector = std::pmr::vector; +#else +using SlicesVector = std::vector; +#endif + namespace onnxruntime { namespace contrib { @@ -21,6 +38,10 @@ class Tokenizer final : public OpKernel { Status Compute(OpKernelContext* context) const override; private: + Status EstimateNumberOfTokens(gsl::span input_span, + size_t& max_tokens_per_row, + size_t& total_tokens_estimate) const; + Status CharTokenize(OpKernelContext* context, size_t N, size_t C, gsl::span input_dims) const; @@ -31,11 +52,14 @@ class Tokenizer final : public OpKernel { size_t N, size_t C, gsl::span input_dims) const; + void OutputData(gsl::span rows, + size_t max_tokens, size_t max_output_index, std::string* output_data) const; + bool mark_{false}; std::string pad_value_; - int64_t mincharnum_{0}; + size_t mincharnum_{0}; bool char_tokenezation_{false}; - std::vector> separators_; + InlinedVector> separators_; std::unique_ptr regex_; }; @@ -50,8 +74,8 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( contrib::Tokenizer); namespace tokenizer_details { -constexpr char start_text = 0x2; -constexpr char end_text = 0x3; +constexpr char kStartMarker = 0x2; +constexpr char kEndMarker = 0x3; } // namespace tokenizer_details using namespace tokenizer_details; @@ -65,9 +89,11 @@ Tokenizer::Tokenizer(const OpKernelInfo& info) : OpKernel(info) { status = info.GetAttr("pad_value", &pad_value_); ORT_ENFORCE(status.IsOK(), "attribute pad_value is not set"); - status = info.GetAttr("mincharnum", &mincharnum_); + int64_t mincharnum = 0; + status = info.GetAttr("mincharnum", &mincharnum); ORT_ENFORCE(status.IsOK(), "attribute mincharnum is not set"); - ORT_ENFORCE(mincharnum_ > 0, "attribute mincharnum must have a positive value"); + ORT_ENFORCE(mincharnum > 0, "attribute mincharnum must have a positive value"); + mincharnum_ = narrow(mincharnum); // Optional attributes either or std::vector separators; @@ -114,6 +140,25 @@ Tokenizer::Tokenizer(const OpKernelInfo& info) : OpKernel(info) { } } +Status Tokenizer::EstimateNumberOfTokens(gsl::span input_span, + size_t& max_tokens_per_row, size_t& total_tokens_estimate) const { + total_tokens_estimate = 0; + max_tokens_per_row = 0; + for (const auto& s : input_span) { + size_t utf8_chars = 0; // length in utf8 chars + if (!utf8_validate(reinterpret_cast(s.data()), s.size(), + utf8_chars)) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Input string contains invalid utf8 chars: " + s); + } + auto tokens = std::max(1, utf8_chars / mincharnum_); + total_tokens_estimate += tokens; + max_tokens_per_row = std::max(max_tokens_per_row, tokens); + } + + return Status::OK(); +} + Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C, gsl::span input_dims) const { // With char tokenzation we get as many tokens as the number of @@ -131,14 +176,13 @@ Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C, tokens)) { // Please do not include the input text in the error message as it could // be deemed as a compliance violation by teams using this operator - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Input string contains invalid utf8 chars"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input string contains invalid utf8 chars:", s); } max_tokens = std::max(max_tokens, tokens); ++curr_input; } - std::vector output_dims(input_dims.begin(), input_dims.end()); + TensorShapeVector output_dims(input_dims.begin(), input_dims.end()); // Check if we have no output due to apparently empty strings input. if (max_tokens == 0) { output_dims.push_back(0); @@ -160,31 +204,30 @@ Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C, while (curr_input != last) { const auto& s = *curr_input; if (mark_) { - (output_data + output_index)->assign(&start_text, 1); + output_data[output_index].assign(&kStartMarker, 1); ++output_index; } size_t tokens = 0; const size_t str_len = s.size(); for (size_t token_idx = 0; token_idx < str_len;) { size_t tlen = 0; - bool result = utf8_bytes(static_cast(s[token_idx]), tlen); + [[maybe_unused]] bool result = utf8_bytes(static_cast(s[token_idx]), tlen); assert(result); - (void)result; assert(token_idx + tlen <= str_len); - *(output_data + output_index) = s.substr(token_idx, tlen); + output_data[output_index] = s.substr(token_idx, tlen); ++output_index; token_idx += tlen; ++tokens; } if (mark_) { - (output_data + output_index)->assign(&end_text, 1); + output_data[output_index].assign(&kEndMarker, 1); ++output_index; } // Padding strings assert(tokens + (static_cast(mark_) * 2) <= max_tokens); const size_t pads = max_tokens - (static_cast(mark_) * 2) - tokens; for (size_t p = 0; p < pads; ++p) { - *(output_data + output_index) = pad_value_; + output_data[output_index] = pad_value_; ++output_index; } ++curr_input; @@ -192,37 +235,162 @@ Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C, return Status::OK(); } +namespace { + +// We use std::vector in this case, because InlinedVector::clear() is incompatible +// with std::vector. It also deallocates memory, which is not what we want. + +// The compiler we are using GCC on Linux and Clang on MacOS does not +// have the library that support C++17 PMR. So we are only using it on Windows +// since the problem is acute on the platform. + +#ifdef ORT_PMR_ALLOCATOR_SUPPORTED +///

+/// This class provides a thin abstraction over the std::pmr::monotonic_buffer_resource +/// If the allocated buffer is not enough, additional allocations are done using +/// new/delete. +/// +class MonotonicAllocatorWithDefault : public std::pmr::monotonic_buffer_resource { + public: + MonotonicAllocatorWithDefault(void* ptr, size_t size_in_bytes) + : monotonic_buffer_resource(ptr, size_in_bytes, std::pmr::get_default_resource()) {} + MonotonicAllocatorWithDefault(void* ptr, size_t size_in_bytes, std::pmr::memory_resource* upstream) + : monotonic_buffer_resource(ptr, size_in_bytes, upstream) {} +}; + +class MemoryAllocator { + public: + explicit MemoryAllocator(size_t num_of_slices) { + size_t allocated_size = 0; + void* ptr = AlignedAllocate(num_of_slices, allocated_size); + resource_.emplace(ptr, allocated_size); + } + + SlicesVector CreateVectorWithAllocator() { + return SlicesVector(&resource_.value()); + } + + SlicesVector& EmplaceBack(std::vector& rows) { + return rows.emplace_back(&resource_.value()); + } + + private: + /// + /// Pre-allocate memory for the tokens to reduce a number of individual + /// allocations and thus memory contention. + /// Used in conjunction with PMR memory allocatior + /// + /// number of objects of T + /// buffer holder + /// aligned allocated size + /// pointer to the buffer + void* AlignedAllocate(size_t num, size_t& allocated_size) { + constexpr size_t alignment = alignof(re2::StringPiece); + const size_t size_bytes = SafeInt(num) * sizeof(re2::StringPiece) + alignment; + buf_holder_ = std::make_unique(size_bytes); + void* ptr = buf_holder_.get(); + allocated_size = size_bytes; + return std::align(alignment, size_bytes, ptr, allocated_size); + } + + std::unique_ptr buf_holder_; + std::optional resource_; +}; + +#else + +class MemoryAllocator { + public: + explicit MemoryAllocator(size_t /* num_of_slices */) { + } + + SlicesVector CreateVectorWithAllocator() const { + return SlicesVector{}; + } + + SlicesVector& EmplaceBack(std::vector& rows) const { + return rows.emplace_back(); + } +}; + +#endif +} // namespace + +void Tokenizer::OutputData(gsl::span rows, + size_t max_tokens, [[maybe_unused]] size_t max_output_index, std::string* output_data) const { + size_t output_index = 0; + for (const auto& row : rows) { + [[maybe_unused]] size_t c_idx = output_index; + if (mark_) { + output_data[output_index++].assign(&kStartMarker, 1); + } + // Output tokens for this row + for (const auto& token : row) { + output_data[output_index++].assign(token.data(), token.length()); + } + if (mark_) { + output_data[output_index++].assign(&kEndMarker, 1); + } + const size_t pads = max_tokens - (static_cast(mark_) * 2) - row.size(); + for (size_t p = 0; p < pads; ++p) { + output_data[output_index++] = pad_value_; + } + assert(output_index <= max_output_index); + assert((output_index - c_idx) <= max_tokens); + } +} + Status Tokenizer::SeparatorExpressionTokenizer(OpKernelContext* ctx, size_t N, size_t C, gsl::span input_dims) const { using namespace re2; - std::vector> rows; - rows.reserve(N * C); + + auto X = ctx->Input(0); + const auto input_span = X->DataAsSpan(); + + // Let's estimate maximum number of tokens + // It is hard to estimate the number of separate characters that would not appear in the + // output. + size_t total_tokens_estimate = 0; + size_t max_tokens_per_row = 0; + ORT_RETURN_IF_ERROR(EstimateNumberOfTokens(input_span, max_tokens_per_row, total_tokens_estimate)); + // Add a scratch token vector allocation + total_tokens_estimate += max_tokens_per_row; + + // Pre-allocate memory for all tokens (StringPieces) + MemoryAllocator allocator(total_tokens_estimate); + + // Make sure the vectors below are destroyed before the allocator + const size_t vector_num = SafeInt(N) * C; + + std::vector rows; + rows.reserve(vector_num); + + // Re-use the same vector for each tokenization round + SlicesVector tokens = allocator.CreateVectorWithAllocator(); + tokens.reserve(max_tokens_per_row); // We do not constraint the search to match // on the beginning or end of the string - const RE2::Anchor anchor = RE2::UNANCHORED; + constexpr RE2::Anchor anchor = RE2::UNANCHORED; // Scan all strings and attempt to find separators in them // collect all the output tokens here size_t max_tokens = 0; - auto X = ctx->Input(0); - auto const input_data = X->Data(); - auto curr_input = input_data; - auto const last = input_data + N * C; - while (curr_input != last) { - const auto& s = *curr_input; + for (const auto& s : input_span) { size_t utf8_chars = 0; // length in utf8 chars - if (!utf8_validate(reinterpret_cast(s.data()), s.size(), - utf8_chars)) { + if (!utf8_len(reinterpret_cast(s.data()), s.size(), + utf8_chars)) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Input string contains invalid utf8 chars: " + s); } - std::vector row{s}; + const auto expected_tokens = std::max(1, utf8_chars / mincharnum_); + auto& row = allocator.EmplaceBack(rows); + row.reserve(expected_tokens); + row.emplace_back(s); for (const auto& sep : separators_) { - std::vector tokens; for (const auto& text : row) { const auto end_pos = text.length(); size_t start_pos = 0; @@ -244,7 +412,7 @@ Status Tokenizer::SeparatorExpressionTokenizer(OpKernelContext* ctx, return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Match contains invalid utf8 chars: " + std::string{submatch}); } - if (utf8_chars >= size_t(mincharnum_)) { + if (utf8_chars >= mincharnum_) { tokens.emplace_back(text.data() + start_pos, token_len); } // Update starting position @@ -263,23 +431,32 @@ Status Tokenizer::SeparatorExpressionTokenizer(OpKernelContext* ctx, utf8_chars = 0; utf8_len(reinterpret_cast(text.data() + start_pos), trailing_len, utf8_chars); - if (utf8_chars >= size_t(mincharnum_)) { + if (utf8_chars >= mincharnum_) { tokens.emplace_back(text.data() + start_pos, trailing_len); } } } while (match); } // row - // Replace the row with the results of this tokenezation - row.swap(tokens); + + // We want to preserve the buffer for the next separator + // copying slices is cheaper than allocating new memory + if (!tokens.empty()) { + row = tokens; + tokens.clear(); + continue; + } + + // Nothing more to match for any remaining separators + row.clear(); + tokens.clear(); + break; } // separators_ max_tokens = std::max(max_tokens, row.size()); - rows.push_back(std::move(row)); - ++curr_input; } - std::vector output_dims(input_dims.begin(), input_dims.end()); + TensorShapeVector output_dims(input_dims.begin(), input_dims.end()); // Check if we have no output due to either empty input - // everything is a separator + // or everything is a separator if (max_tokens == 0) { output_dims.push_back(0); TensorShape output_shape(output_dims); @@ -297,39 +474,8 @@ Status Tokenizer::SeparatorExpressionTokenizer(OpKernelContext* ctx, auto output_tensor = ctx->Output(0, output_shape); auto const output_data = output_tensor->MutableData(); -#ifdef _DEBUG - const size_t max_output_index = N * C * max_tokens; -#endif - size_t output_index = 0; - curr_input = input_data; - for (auto& row : rows) { -#ifdef _DEBUG - size_t c_idx = output_index; -#endif - if (mark_) { - (output_data + output_index)->assign(&start_text, 1); - ++output_index; - } - // Output tokens for this row - for (const auto& token : row) { - (output_data + output_index)->assign(token.data(), token.size()); - ++output_index; - } - if (mark_) { - (output_data + output_index)->assign(&end_text, 1); - ++output_index; - } - const size_t pads = max_tokens - (static_cast(mark_) * 2) - row.size(); - for (size_t p = 0; p < pads; ++p) { - *(output_data + output_index) = pad_value_; - ++output_index; - } -#ifdef _DEBUG - assert(output_index <= max_output_index); - assert((output_index - c_idx) <= max_tokens); -#endif - ++curr_input; - } + OutputData(rows, max_tokens, narrow(output_shape.Size()), output_data); + return Status::OK(); } @@ -337,71 +483,78 @@ Status Tokenizer::TokenExpression(OpKernelContext* ctx, size_t N, size_t C, gsl::span input_dims) const { using namespace re2; - // Represents a token that will be output after - // first is the index, second is the size; - std::vector> tokens; - tokens.reserve(N * C); size_t max_tokens = 0; auto X = ctx->Input(0); - auto const input_data = X->Data(); - auto curr_input = input_data; - auto const last = input_data + N * C; + const auto input_span = X->DataAsSpan(); + + // Let's estimate maximum number of tokens + size_t total_tokens_estimate = 0; + size_t max_tokens_per_row = 0; + ORT_RETURN_IF_ERROR(EstimateNumberOfTokens(input_span, max_tokens_per_row, total_tokens_estimate)); + + // Pre-allocate memory for all tokens (StringPieces) + MemoryAllocator allocator(total_tokens_estimate); + + // Make sure the vectors below are destroyed before the allocator + const size_t vector_num = SafeInt(N) * C; + + // We use std::vector in this case, because InlinedVector::clear() is incompatible + // with std::vector. It also deallocates memory, which is not what we want. + std::vector rows; + rows.reserve(vector_num); // We do not constraint the search to match // on the beginning or end of the string - const RE2::Anchor anchor = RE2::UNANCHORED; - - while (curr_input != last) { - const auto& s = *curr_input; + constexpr RE2::Anchor anchor = RE2::UNANCHORED; + for (const auto& s : input_span) { size_t utf8_chars = 0; - if (!utf8_validate(reinterpret_cast(s.data()), s.size(), - utf8_chars)) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Input string contains invalid utf8 chars: " + s); - } - - tokens.emplace_back(); - auto& row = tokens.back(); - - StringPiece text(s); - const auto end_pos = s.length(); - size_t start_pos = 0; - StringPiece submatch; - - bool match = true; - do { - match = regex_->Match(text, start_pos, end_pos, anchor, &submatch, 1); - if (match) { - // Record pos/len - assert(submatch.data() != nullptr); - size_t match_pos = submatch.data() - s.data(); - assert(match_pos >= start_pos); - // Guard against empty match and make - // sure we make progress either way - auto token_len = submatch.length(); - utf8_chars = 0; - if (!utf8_len(reinterpret_cast(submatch.data()), token_len, utf8_chars)) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Match contains invalid utf8 chars: " + std::string{submatch}); - } - if (utf8_chars >= size_t(mincharnum_)) { - row.push_back(submatch); - start_pos = match_pos + token_len; - } else { - size_t bytes = 0; - utf8_bytes(*submatch.data(), bytes); - start_pos = match_pos + bytes; + utf8_len(reinterpret_cast(s.data()), s.size(), utf8_chars); + + auto& row = allocator.EmplaceBack(rows); + + if (utf8_chars >= mincharnum_) { + auto estimated_tokens = std::max(1, utf8_chars / mincharnum_); + row.reserve(estimated_tokens); + + StringPiece text(s); + const auto end_pos = s.length(); + size_t start_pos = 0; + StringPiece submatch; + + bool match = true; + do { + match = regex_->Match(text, start_pos, end_pos, anchor, &submatch, 1); + if (match) { + // Record pos/len + assert(submatch.data() != nullptr); + size_t match_pos = submatch.data() - s.data(); + assert(match_pos >= start_pos); + // Guard against empty match and make + // sure we make progress either way + auto token_len = submatch.length(); + utf8_chars = 0; + if (!utf8_len(reinterpret_cast(submatch.data()), token_len, utf8_chars)) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Match contains invalid utf8 chars: " + std::string{submatch}); + } + if (utf8_chars >= mincharnum_) { + row.push_back(submatch); + start_pos = match_pos + token_len; + } else { + size_t bytes = 0; + utf8_bytes(*submatch.data(), bytes); + start_pos = match_pos + bytes; + } } - } - } while (match); + } while (match); + } max_tokens = std::max(max_tokens, row.size()); - ++curr_input; } // Check for empty output - std::vector output_dims(input_dims.begin(), input_dims.end()); + TensorShapeVector output_dims(input_dims.begin(), input_dims.end()); // Check if we have no output due to either empty input // everything is a separator if (max_tokens == 0) { @@ -421,40 +574,7 @@ Status Tokenizer::TokenExpression(OpKernelContext* ctx, auto output_tensor = ctx->Output(0, output_shape); auto const output_data = output_tensor->MutableData(); -#ifdef _DEBUG - const size_t max_output_index = N * C * max_tokens; -#endif - curr_input = input_data; - size_t output_index = 0; - for (const auto& row : tokens) { - assert(curr_input != last); -#ifdef _DEBUG - size_t c_idx = output_index; -#endif - if (mark_) { - (output_data + output_index)->assign(&start_text, 1); - ++output_index; - } - // Output tokens for this row - for (const auto& token : row) { - (output_data + output_index)->assign(token.data(), token.length()); - ++output_index; - } - if (mark_) { - (output_data + output_index)->assign(&end_text, 1); - ++output_index; - } - const size_t pads = max_tokens - (static_cast(mark_) * 2) - row.size(); - for (size_t p = 0; p < pads; ++p) { - *(output_data + output_index) = pad_value_; - ++output_index; - } -#ifdef _DEBUG - assert(output_index <= max_output_index); - assert((output_index - c_idx) <= max_tokens); -#endif - ++curr_input; - } + OutputData(rows, max_tokens, narrow(output_shape.Size()), output_data); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 56d950ca2f41..b18e122980ed 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -258,7 +258,7 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } @@ -397,12 +397,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch output_sequences_scores); // Output per token scores - if (output_scores) { - gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = beam_state.scores; - assert(target.size() == source.size()); - ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } + gsl::span per_token_scores = beam_state.scores; + this->beam_scorer_->OutputScores(per_token_scores, output_scores); return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 94547887d3a9..8f5cdc97f27e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -214,7 +214,7 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } @@ -404,12 +404,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches output_sequences_scores); // Output per token scores - if (output_scores) { - gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = beam_state.scores; - assert(target.size() == source.size()); - ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } + gsl::span per_token_scores = beam_state.scores; + this->beam_scorer_->OutputScores(per_token_scores, output_scores); return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 91b93a125ad7..af0904b7d6e4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -134,8 +134,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe TensorShape no_speech_probs_shape{parameters->batch_size}; Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape); if (no_speech_probs && no_speech_probs->MutableData()) { - ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size, - "no_speech_token id out of range, it is ", parameters->no_speech_token, + ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size, + "no_speech_token_id is out of range, it is ", parameters->no_speech_token_id, ", vocab_size is ", parameters->vocab_size); this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData(); } @@ -226,7 +226,7 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } @@ -500,12 +500,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe output_sequences_scores); // Output per token scores - if (output_scores) { - gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = beam_state.scores; - assert(target.size() == source.size()); - ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } + gsl::span per_token_scores = beam_state.scores; + this->beam_scorer_->OutputScores(per_token_scores, output_scores); return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 3962486d5b5e..93837e785b4a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -123,8 +123,20 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { logits_processor = logits_processor_tensor ? static_cast(*logits_processor_tensor->Data()) : 0; ORT_ENFORCE(logits_processor >= 0, "logits_processor shall be a non-negative integer, got ", logits_processor); -} + if (this->model_type == IGenerationParameters::kModelTypeWhisper) { + auto* temperature_tensor = context->Input(14); + if (temperature_tensor) { + if (temperature_tensor->IsDataType()) { + temperature = *temperature_tensor->Data(); + } else { + temperature = static_cast(*temperature_tensor->Data()); + } + } else { + temperature = 1.0f; + } + } +} void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) { // Override vocab_size using the inferred shape from the decoder subgraph ONLY IF // the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch) @@ -141,7 +153,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeWhisper)); ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper); - no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); + // Token ids are defined below in the order that they appear in the tokenizer + translate_token_id = static_cast(info.GetAttrOrDefault("translate_token_id", -1LL)); + transcribe_token_id = static_cast(info.GetAttrOrDefault("transcribe_token_id", -1LL)); + start_of_lm_token_id = static_cast(info.GetAttrOrDefault("start_of_lm_token_id", -1LL)); + no_speech_token_id = static_cast(info.GetAttrOrDefault("no_speech_token_id", -1LL)); + no_timestamps_token_id = static_cast(info.GetAttrOrDefault("no_timestamps_token_id", -1LL)); + beginning_timestamp_token_id = static_cast(info.GetAttrOrDefault("beginning_timestamp_token_id", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; cross_qk_output_id = 3; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index 7e2e5b212922..0eccbe26605f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -50,11 +50,12 @@ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) con return beams_.back().score < current_score; } +template void BeamHypotheses::Output( int top_k, int max_length, - gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) - gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty + gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) + gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty { // Copy the top_k beams into the sequences ORT_ENFORCE(top_k <= beams_used_); @@ -67,7 +68,7 @@ void BeamHypotheses::Output( gsl::copy(item.hypothesis, target); if (!sequences_scores.empty()) - sequences_scores[index] = item.score; + sequences_scores[index] = (T)item.score; } } @@ -181,21 +182,21 @@ void BeamSearchScorer::Process(ISequences& sequences, } } -void BeamSearchScorer::Finalize(ISequences& sequences, - gsl::span& final_beam_scores, - Tensor* output_sequences, - Tensor* output_sequence_scores) { - ORT_ENFORCE(output_sequences != nullptr); - +template +void OutputSequenceScores(BeamSearchScorer* scorer, + ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { // Finalize all open beam hypotheses and add to generated hypotheses. - for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) { - BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; + for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) { + BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index]; if (beam_hyp.done_) { continue; } - for (size_t beam_index = 0; beam_index < num_beams_; beam_index++) { - size_t batch_beam_index = batch_index * num_beams_ + beam_index; + for (size_t beam_index = 0; beam_index < scorer->num_beams_; beam_index++) { + size_t batch_beam_index = batch_index * scorer->num_beams_ + beam_index; float final_score = final_beam_scores[batch_beam_index]; auto final_tokens = sequences.GetSequence(narrow(batch_beam_index)); beam_hyp.Add(final_tokens, final_score); @@ -206,26 +207,59 @@ void BeamSearchScorer::Finalize(ISequences& sequences, gsl::span output = output_sequences->MutableDataAsSpan(); // Fill output sequences with pad token ID so that we do not need append it later. - std::fill_n(output.data(), output.size(), pad_token_id_); + std::fill_n(output.data(), output.size(), scorer->pad_token_id_); // Score of each sequence, with shape (batch_size * num_return_sequences). - gsl::span sequence_scores; + gsl::span sequence_scores; if (output_sequence_scores) { - sequence_scores = output_sequence_scores->MutableDataAsSpan(); + sequence_scores = output_sequence_scores->MutableDataAsSpan(); } // Select the best hypotheses according to number of sequences to return. - for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) { - BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; + for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) { + BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index]; - auto batch_output = output.subspan(batch_index * num_return_sequences_ * max_length_, - num_return_sequences_ * max_length_); - gsl::span sequence_scores_buffer; + auto batch_output = output.subspan(batch_index * scorer->num_return_sequences_ * scorer->max_length_, + scorer->num_return_sequences_ * scorer->max_length_); + gsl::span sequence_scores_buffer; if (!sequence_scores.empty()) - sequence_scores_buffer = sequence_scores.subspan(batch_index * num_return_sequences_, num_return_sequences_); + sequence_scores_buffer = sequence_scores.subspan(batch_index * scorer->num_return_sequences_, scorer->num_return_sequences_); + + beam_hyp.template Output(narrow(scorer->num_return_sequences_), narrow(scorer->max_length_), batch_output, + sequence_scores_buffer); + } +} + +void BeamSearchScorer::Finalize(ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { + ORT_ENFORCE(output_sequences != nullptr); - beam_hyp.Output(narrow(num_return_sequences_), narrow(max_length_), batch_output, - sequence_scores_buffer); + if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType()) { + OutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); + } else { + ORT_ENFORCE(output_sequence_scores->IsDataType()); + OutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); + } +} + +void BeamSearchScorer::OutputScores(gsl::span& final_scores, Tensor* output_scores) { + if (output_scores) { + if (output_scores->IsDataType()) { + gsl::span target = output_scores->MutableDataAsSpan(); + ORT_ENFORCE(target.size() == final_scores.size()); + std::copy_n(final_scores.data(), final_scores.size(), target.data()); + } else { + ORT_ENFORCE(output_scores->IsDataType()); + gsl::span target = output_scores->MutableDataAsSpan(); + ORT_ENFORCE(target.size() == final_scores.size()); + const float* src = final_scores.data(); + MLFloat16* dst = target.data(); + for (size_t i = 0; i < target.size(); i++) { + dst[i] = MLFloat16(src[i]); + } + } } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index 94b6d340d9f4..dc92e8038a68 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -35,10 +35,11 @@ struct BeamHypotheses { bool CanImprove(float best_sum_logprobs, int current_length) const; // Output results - void Output(int top_k, // number of sequences to return - int max_length, // max sequence length - gsl::span& sequences, // buffer with pad token, shape (num_return_sequences, max_length) - gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) + template + void Output(int top_k, // number of sequences to return + int max_length, // max sequence length + gsl::span& sequences, // buffer with pad token, shape (num_return_sequences, max_length) + gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) gsl::span beams_; // Beam width sized array of hypotheses, sorted by highest scoring int beams_used_; // Number of elements used in beams_ @@ -60,13 +61,14 @@ struct BeamSearchScorer : IBeamScorer { Tensor* output_sequences, Tensor* output_sequence_scores) override; + void OutputScores(gsl::span& final_scores, Tensor* output_scores) override; + bool IsDone() const override { return not_done_count_ == 0; } gsl::span GetNextScores() override { return next_beam_scores_; } gsl::span GetNextTokens() override { return next_beam_tokens_; } gsl::span GetNextIndicesCPU() override { return next_beam_indices_; } - private: size_t batch_size_; size_t num_beams_; size_t max_length_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index f6faf2e325f8..b1dd55eb20f3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -120,6 +120,9 @@ struct IBeamScorer { Tensor* output_sequences, Tensor* output_sequence_scores) = 0; + virtual void OutputScores(gsl::span& final_scores, + Tensor* output_scores) = 0; + virtual bool IsDone() const = 0; // GPU version will return false here, as it asynchronously queues up the event virtual bool IsDoneLater() const { return false; } // GPU version waits for the asynchous result to complete here @@ -180,7 +183,14 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span extra_decoding_ids; - int32_t no_speech_token = -1; + + // Token ids are defined below in the order that they appear in the tokenizer + int32_t translate_token_id = -1; + int32_t transcribe_token_id = -1; + int32_t start_of_lm_token_id = -1; + int32_t no_speech_token_id = -1; + int32_t no_timestamps_token_id = -1; + int32_t beginning_timestamp_token_id = -1; void* no_speech_probs = nullptr; int cross_qk_layer_head_input_id = -1; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index f39f090c78b0..c74e9160cc43 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -17,14 +17,6 @@ namespace onnxruntime { namespace contrib { namespace transformers { -#ifdef DEBUG_GENERATION -template -void DumpScores(const char* name, const NextTokenScores& next_token_scores) { - std::cout << name << std::endl; - ORT_UNUSED_PARAMETER(next_token_scores); -} -#endif - // Interface for all scorers for beam search or beam sample. template MinLengthLogitsProcessor::MinLengthLogitsProcessor(int min_length, int eos_token_id) @@ -36,10 +28,6 @@ void MinLengthLogitsProcessor::Process(const ISequences* sequences, if (sequences->GetSequenceLength() < min_length_) { next_token_scores.SetScore(eos_token_id_, std::numeric_limits::lowest()); } - -#ifdef DEBUG_GENERATION - DumpScores("MinLengthLogitsProcessor", next_token_scores); -#endif } template @@ -68,10 +56,6 @@ void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences, beam_token_scores[word_id] = (score < 0 ? score * penalty_ : score / penalty_); } } - -#ifdef DEBUG_GENERATION - DumpScores("RepetitionPenaltyLogitsProcessor", next_token_scores); -#endif } template @@ -109,10 +93,6 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, beam_token_scores[word_id] = std::numeric_limits::lowest(); } } - -#ifdef DEBUG_GENERATION - DumpScores("NoRepeatNGramLogitsProcessor", next_token_scores); -#endif } template @@ -136,10 +116,6 @@ void VocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, } } } - -#ifdef DEBUG_GENERATION - DumpScores("VocabMaskLogitsProcessor", next_token_scores); -#endif } template @@ -171,10 +147,6 @@ void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, } } } - -#ifdef DEBUG_GENERATION - DumpScores("PrefixVocabMaskLogitsProcessor", next_token_scores); -#endif } template @@ -193,10 +165,6 @@ void TemperatureLogitsProcessor::Process(const ISequences* /*sequences*/, *p /= temperature_; ++p; } - -#ifdef DEBUG_GENERATION - DumpScores("TemperatureLogitsProcessor", next_token_scores); -#endif } template @@ -218,10 +186,6 @@ void PresencePenaltyLogitsProcessor::Process(const ISequences*, for (size_t i = 0; i < next_token_scores.scores.size(); i++) { *p -= presence_mask_[i] * presence_penalty_; } - -#ifdef DEBUG_GENERATION - DumpScores("PresencePenaltyLogitsProcessor", next_token_scores); -#endif } void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 4688ff272cee..231eb17d1a94 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -10,6 +10,7 @@ #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" +#include namespace onnxruntime { namespace contrib { @@ -34,6 +35,14 @@ struct NextTokenScores { } }; +#ifdef DEBUG_GENERATION +template +void DumpScores(const char* name, const NextTokenScores& next_token_scores) { + std::cout << name << std::endl; + ORT_UNUSED_PARAMETER(next_token_scores); +} +#endif + // Interface for all scorers for beam search or beam sample. template class ILogitsProcessor { @@ -150,19 +159,25 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} + TimestampLogitsProcessor(int end_of_text_token_id, // <|endoftext|> + int start_of_transcript_token_id, // <|startoftranscript|> + int translate_token_id, // <|translate|> + int transcribe_token_id, // <|transcribe|> + int start_of_lm_token_id, // <|startoflm|> + int no_timestamps_token_id, // <|notimestamps|> + int beginning_timestamp_token_id, // <|0.00|> + int max_initial_timestamp_index) + : end_of_text_token_id_(end_of_text_token_id), + start_of_transcript_token_id_(start_of_transcript_token_id), + translate_token_id_(translate_token_id), + transcribe_token_id_(transcribe_token_id), + start_of_lm_token_id_(start_of_lm_token_id), + no_timestamps_token_id_(no_timestamps_token_id), + beginning_timestamp_token_id_(beginning_timestamp_token_id), + max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override { - // TODO: translate_token_id_ and transcribe_token_id_ need to support both multilingual and English-only models. - const int beg_token_id_ = eos_token_id_ + 107; - const int not_token_id_ = eos_token_id_ + 106; - const int solm_token_id_ = eos_token_id_ + 105; - const int sot_token_id_ = eos_token_id_ + 1; - constexpr int translate_token_id_ = 50358; - constexpr int transcribe_token_id_ = 50359; - const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; for (int i = 0; i < batch_beam_size; i++) { @@ -174,7 +189,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { size_t sample_begin = 0; for (size_t j = 0; j < seq_length; j++) { sample_begin++; - if (sequence[j] >= beg_token_id_) { + if (sequence[j] >= beginning_timestamp_token_id_) { break; } } @@ -182,30 +197,30 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Suppress tokens for (int j = 0; j < vocab_size; j++) { // Suppress notimestamps and solm tokens - if (j == not_token_id_ || j == solm_token_id_) { + if (j == no_timestamps_token_id_ || j == start_of_lm_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } // Suppress sot, translate and transcribe tokens if (seq_length > sample_begin) { - if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { + if (j == start_of_transcript_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } } } // Timestamps should be in pair except the first one - const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; - const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; + const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beginning_timestamp_token_id_; + const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beginning_timestamp_token_id_; if (last_was_timestamp) { if (penultimate_was_timestamp) { // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated - for (int j = beg_token_id_; j < vocab_size; j++) { + for (int j = beginning_timestamp_token_id_; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } else { // If timestamp doesn't show up in pair, generate timestamp - for (int j = 0; j < eos_token_id_; j++) { + for (int j = 0; j < end_of_text_token_id_; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -214,7 +229,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Find timestamp tokens std::vector timestamps; for (const auto& word_id : sequence) { - if (word_id >= beg_token_id_) { + if (word_id >= beginning_timestamp_token_id_) { timestamps.push_back(word_id); } } @@ -231,13 +246,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { timestamp_last = timestamps.back() + 1; } - for (int j = beg_token_id_; j < timestamp_last; j++) { + for (int j = beginning_timestamp_token_id_; j < timestamp_last; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } if (seq_length == sample_begin) { - const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; + const int last_allowed = beginning_timestamp_token_id_ + max_initial_timestamp_index_; for (int j = last_allowed + 1; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } @@ -247,8 +262,8 @@ class TimestampLogitsProcessor : public ILogitsProcessor { float timestamp_logprob = std::numeric_limits::lowest(); { float logsumexp = 0.0f; - const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); - for (int j = beg_token_id_; j < vocab_size; ++j) { + const float logprob_max = *std::max_element(beam_token_scores.begin() + beginning_timestamp_token_id_, beam_token_scores.end()); + for (int j = beginning_timestamp_token_id_; j < vocab_size; ++j) { if (beam_token_scores[j] > std::numeric_limits::lowest()) { logsumexp += expf(beam_token_scores[j] - logprob_max); } @@ -258,21 +273,23 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } } - const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); + const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beginning_timestamp_token_id_); if (timestamp_logprob > max_text_token_logprob) { - for (int j = 0; j < beg_token_id_; ++j) { + for (int j = 0; j < beginning_timestamp_token_id_; ++j) { beam_token_scores[j] = std::numeric_limits::lowest(); } } } - -#ifdef DEBUG_GENERATION - DumpScores("TimestampLogitsProcessor", next_token_scores); -#endif } private: - int eos_token_id_; + int end_of_text_token_id_; + int start_of_transcript_token_id_; + int translate_token_id_; + int transcribe_token_id_; + int start_of_lm_token_id_; + int no_timestamps_token_id_; + int beginning_timestamp_token_id_; int max_initial_timestamp_index_; }; @@ -334,7 +351,15 @@ class LogitsProcessorList : public ILogitsProcessorList { // Add timestamp processor for whisper model if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) { constexpr int max_initial_timestamp_index = 50; - timestamp_processor_ = std::make_unique>(parameters.eos_token_id, max_initial_timestamp_index); + // Token ids are passed below in the order that they appear in the tokenizer + timestamp_processor_ = std::make_unique>(parameters.eos_token_id, + parameters.decoder_start_token_id, + parameters.translate_token_id, + parameters.transcribe_token_id, + parameters.start_of_lm_token_id, + parameters.no_timestamps_token_id, + parameters.beginning_timestamp_token_id, + max_initial_timestamp_index); processor_list_.push_back(timestamp_processor_.get()); } diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.cc b/onnxruntime/contrib_ops/cuda/activation/activations.cc index 1a86c5dbece5..6303858b9bd4 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations.cc +++ b/onnxruntime/contrib_ops/cuda/activation/activations.cc @@ -49,7 +49,6 @@ namespace cuda { UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain); UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain); UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain); -UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain); UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain); REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16) diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.h b/onnxruntime/contrib_ops/cuda/activation/activations.h index ab339f276c2b..fc9a71b0b7fa 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations.h +++ b/onnxruntime/contrib_ops/cuda/activation/activations.h @@ -66,17 +66,6 @@ class ScaledTanh final : public UnaryElementwise { float beta_; }; -template -class Gelu final : public UnaryElementwise { - public: - Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {} - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - MAKE_FUNC_CTX_NULL() -}; - template class QuickGelu final : public UnaryElementwise { public: diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu index 0c856815fd43..36f33fbb24c1 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu @@ -36,20 +36,6 @@ struct OP_ScaledTanh : public CtxScaledTanh { } }; -template -struct OP_Gelu : public CtxGelu { - __device__ __inline__ T operator()(const T& a) const { - return _Gelu(a); - } -}; - -template <> -struct OP_Gelu : public CtxGelu { - __device__ __inline__ half operator()(const half& a) const { - return static_cast(_Gelu(static_cast(a))); - } -}; - template struct OP_QuickGelu : public CtxQuickGelu { __device__ __inline__ T operator()(const T& a) const { diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h index 5d18283a395e..782d4bf59a5a 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h @@ -11,14 +11,12 @@ namespace cuda { typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine; typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus; typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh; -typedef onnxruntime::cuda::CtxNull CtxGelu; typedef onnxruntime::cuda::CtxAlpha CtxQuickGelu; #define UNARY_CONTRIB_ACTIVATION_OPS() \ UNARY_ACTIVATION_OP_NAME(ScaledTanh) \ UNARY_ACTIVATION_OP_NAME(Affine) \ UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \ - UNARY_ACTIVATION_OP_NAME(Gelu) \ UNARY_ACTIVATION_OP_NAME(QuickGelu) #define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 626e4c0b87a3..9e6752b45186 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -640,7 +640,7 @@ void InvokeAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count, - bool do_rotary = false, int past_sequence_length = 0) { + bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0) { assert(num_heads <= max_threads_per_block); if (do_rotary) { @@ -650,20 +650,20 @@ void InvokeAddBiasTranspose( if (format != 1 && format != 2 && format != 3) { ORT_THROW("format must be 1, 2 or 3 for rotary attention"); } - if (qk_head_size != 64 && qk_head_size != 128) { - ORT_THROW("qk_head_size must be 64 or 128 for rotary attention"); + if (rotary_embedding != 32 && rotary_embedding != 64 && rotary_embedding != 128) { + ORT_THROW("rotary_embedding must be 32, 64 or 128 for rotary attention"); } if (v_head_size != -1 && qk_head_size != v_head_size) { ORT_THROW("qk_head_size must be equal to v_head_size for rotary attention"); } const int step = past_sequence_length == 0 ? sequence_length : past_sequence_length; - size_t smem_size = 2 * qk_head_size * sizeof(T); + size_t smem_size = 2 * rotary_embedding * sizeof(T); const dim3 grid(sequence_length, num_heads, batch_size); const dim3 block((qk_head_size / 2 + 31) / 32 * 32, 1, 1); AddBiasTransposeQKV<<>>(total_matrix_count, input, biases, output, - qkv_add_bias, qk_head_size, qk_head_size, + qkv_add_bias, rotary_embedding, qk_head_size, step, format); #else ORT_THROW("Rotary Attention is supported on sm >= 530. Current sm is", __CUDA_ARCH__); @@ -727,7 +727,7 @@ void LaunchAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const half* input, const half* biases, half* output, bool enable_half4, const int v_head_size, - half* qkv_add_bias, int total_matrix_count, bool do_rotary, int past_sequence_length) { + half* qkv_add_bias, int total_matrix_count, bool do_rotary, int rotary_embedding, int past_sequence_length) { total_matrix_count = std::max(num_matrices, total_matrix_count); if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) { const int H = qk_head_size / 4; @@ -753,7 +753,7 @@ void LaunchAddBiasTranspose( InvokeAddBiasTranspose( stream, num_matrices, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, - qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length); + qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding, past_sequence_length); } } @@ -763,7 +763,7 @@ void LaunchAddBiasTranspose( const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const float* input, const float* biases, float* output, bool /*enable_half4*/, const int v_head_size, float* qkv_add_bias, int total_matrix_count, bool do_rotary, - int past_sequence_length) { + int rotary_embedding, int past_sequence_length) { total_matrix_count = std::max(num_matrices, total_matrix_count); if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) { const int H = qk_head_size / 4; @@ -789,7 +789,8 @@ void LaunchAddBiasTranspose( InvokeAddBiasTranspose( stream, num_matrices, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, - qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length); + qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding, + past_sequence_length); } } @@ -842,11 +843,11 @@ void InvokeAddBiasTransposeTrt( template <> void LaunchAddBiasTransposeTrt( - cudaStream_t stream, const int max_threads_per_block, - const int batch_size, const int sequence_length, - const int num_heads, const int head_size, - const float* biases, const float* query, const float* key, const float* value, float* output, - bool is_cross_attention, int kv_sequence_length) { + cudaStream_t /*stream*/, const int /*max_threads_per_block*/, + const int /*batch_size*/, const int /*sequence_length*/, + const int /*num_heads*/, const int /*head_size*/, + const float* /*biases*/, const float* /*query*/, const float* /*key*/, const float* /*value*/, float* /*output*/, + bool /*is_cross_attention*/, int /*kv_sequence_length*/) { ORT_ENFORCE(false, "Shall not call this since fused kernel does not support float input."); } diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h index d903267c99a0..efc31db43bcd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h @@ -33,7 +33,7 @@ void LaunchAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr, - int total_matrix_count = -1, bool do_rotary = false, int past_sequence_length = 0); + int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0); // Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format. // For self attention: diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index bf6431cf1afb..7a807342ad68 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -84,6 +84,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + // Use the second dimension from weight for bias to get q_hidden_size when bias is nullptr std::vector bias_dims{weights->Shape().GetDims()[1]}; const TensorShape bias_shape{bias_dims}; @@ -251,7 +253,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool use_fused_cross_attention = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 83c426e7e6ed..a93fdf74dc28 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -58,12 +58,12 @@ size_t AlignSize(size_t bytes) { return bytesAligned; } -void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) { - if (this->sequence_length != sequence_length) { +void CumulatedSequenceLengthCache::Initialize(int32_t seq_length, cudaStream_t stream) { + if (this->sequence_length != seq_length) { ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, - this->max_batch_size, sequence_length, stream); - this->sequence_length = sequence_length; + this->max_batch_size, seq_length, stream); + this->sequence_length = seq_length; } } @@ -213,9 +213,9 @@ Status FusedTrtCrossAttention( template <> Status FusedTrtCrossAttention( - cudaStream_t stream, - contrib::AttentionParameters& parameters, - AttentionData& data) { + cudaStream_t /*stream*/, + contrib::AttentionParameters& /*parameters*/, + AttentionData& /*data*/) { return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "Trt fused cross attention does not support float tensor"); } @@ -276,9 +276,9 @@ Status FusedTrtSelfAttention( // Template Specialization for float type template <> Status FusedTrtSelfAttention( - cudaStream_t stream, - contrib::AttentionParameters& parameters, - AttentionData& data) { + cudaStream_t /*stream*/, + contrib::AttentionParameters& /*parameters*/, + AttentionData& /*data*/) { return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "Trt fused attention does not support float tensor"); } @@ -313,10 +313,11 @@ Status FlashAttention( parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.v_head_size); + bool is_bf16 = false; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), true)); @@ -460,7 +461,8 @@ Status UnfusedAttention( total_sequence_length, sequence_length, qk_head_size, &alpha, data.k, qk_head_size, present_size_per_batch_k, data.q, qk_head_size, sequence_length * qk_head_size, - &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); + &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, + device_prop, parameters.use_tf32)); DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); @@ -513,7 +515,7 @@ Status UnfusedAttention( v_head_size, sequence_length, total_sequence_length, &one, data.v, v_head_size, present_size_per_batch_v, scratch2, total_sequence_length, sequence_length * total_sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 5c65a30918ec..b843966d88e8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -65,7 +65,8 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, - 3, parameters.do_rotary, parameters.past_sequence_length); + 3, parameters.do_rotary, parameters.rotary_embedding, + parameters.past_sequence_length); } return Status::OK(); } @@ -230,7 +231,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + T* /*q*/, T* /*k*/, T* /*v*/, AttentionQkvFormat& qkv_format) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; @@ -278,7 +279,7 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + T* /*q*/, T* k, T* /*v*/, AttentionQkvFormat& qkv_format) { const int batch_size = parameters.batch_size; const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index db78722cc0e4..c12cb374d9ad 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -242,18 +242,18 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) { using AlignedAK = AttentionKernel; #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) -#pragma warning(disable : 6287) +#pragma warning(disable : 6287 4189) // kAligned is used via capture so 4189 warning seems incorrect #endif // Run a more efficient kernel with `isAligned=True` when memory is correctly aligned. bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 && params.qk_head_size % AlignedAK::kAlignmentK == 0 && params.v_head_size % AlignedAK::kAlignmentV == 0; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() { LaunchCutlassFmha(params); })); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif } template diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index 3f703ae3d05e..ceee17c2a2d0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -273,13 +273,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data()), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (h2, h1)*(h1, S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(q_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop, UseTF32())); // gemm_query_buffer in col-base: (h2, S*B) // calcualte k, v @@ -298,13 +298,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } else { gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * key_sequence_length * hidden_size, @@ -318,13 +318,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(key->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } } else { @@ -342,13 +342,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } else { kv_sequence_length = cache_sequence_length; @@ -372,6 +372,8 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { device_prop, #ifdef USE_ROCM GetTuningContext(), +#else + UseTF32(), #endif context->GetComputeStream(), cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu index 1dc22a9c8ea9..c0b199678918 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -17,7 +17,7 @@ Status DecoderQkvToContext( const cudaDeviceProp& device_prop, Stream* ort_stream, cublasHandle_t& cublas, - const size_t element_size, + const size_t /*element_size*/, const int batch_size, const int sequence_length, const int kv_sequence_length, @@ -37,7 +37,8 @@ Status DecoderQkvToContext( T* workspace_buffer, T* output, T* new_key_cache, - T* new_value_cache) { + T* new_value_cache, + bool use_tf32) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int BN = batch_size * num_heads; const int BHN = BN * head_size; @@ -128,14 +129,14 @@ Status DecoderQkvToContext( kv_sequence_length, sequence_length, head_size, &alpha, key_cache, head_size, strideA, q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32)); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, kv_sequence_length, sequence_length, head_size, &alpha, k, head_size, strideA, q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32)); } constexpr bool is_unidirectional = false; @@ -163,14 +164,14 @@ Status DecoderQkvToContext( head_size, sequence_length, kv_sequence_length, &one, value_cache, head_size, strideA, scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + &zero, scratch3, head_size, strideB, BN, device_prop, use_tf32)); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, head_size, sequence_length, kv_sequence_length, &one, v, head_size, strideA, scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + &zero, scratch3, head_size, strideB, BN, device_prop, use_tf32)); } // scratch3 is BxNxSxH, transpose to output SxBxNxH @@ -180,6 +181,7 @@ Status DecoderQkvToContext( Status LaunchDecoderAttentionKernel( const cudaDeviceProp& device_prop, + bool use_tf32, Stream* stream, cublasHandle_t& cublas, const size_t element_size, @@ -228,7 +230,8 @@ Status LaunchDecoderAttentionKernel( reinterpret_cast(workspace_buffer), reinterpret_cast(output), reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); + reinterpret_cast(new_value_cache), + use_tf32); } else { return DecoderQkvToContext( device_prop, @@ -254,7 +257,8 @@ Status LaunchDecoderAttentionKernel( reinterpret_cast(workspace_buffer), reinterpret_cast(output), reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); + reinterpret_cast(new_value_cache), + use_tf32); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h index 9db9ccb45e33..f9667a613e64 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h @@ -11,6 +11,7 @@ namespace cuda { Status LaunchDecoderAttentionKernel( const cudaDeviceProp& prop, // Device Properties + bool use_tf32, // Use TF32 Stream* stream, // ORT Stream cublasHandle_t& cublas, // Cublas handle const size_t element_size, // Element size of input tensor diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 54aad9cbaf38..66c0aceaed1e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -70,6 +70,11 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* auto& device_prop = GetDeviceProp(); DecoderMaskedMultiHeadAttentionParams parameters; + + parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault( + attention::kDecoderMaskedAttentionLoadKVDataInFlight, false); + + bool is_unidirectional = false; bool is_dmmha_packing = (key == nullptr && value == nullptr); ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, @@ -84,6 +89,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* num_heads_, mask_filter_value_, scale_, + is_unidirectional, past_present_share_buffer_, is_dmmha_packing, // dmmha_packing device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc index 69ed07101e64..07a6fbd60e17 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc @@ -52,6 +52,10 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont auto& device_prop = GetDeviceProp(); DecoderMaskedMultiHeadAttentionParams parameters; + + parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault( + attention::kDecoderMaskedAttentionLoadKVDataInFlight, false); + ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), @@ -139,7 +143,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); // Update the q, k, and v buffers parameters.q = gemm_buffer.get(); diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 892f5c181a60..8b8e4e267f89 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -4,9 +4,13 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cudnn_common.h" #include "fast_gelu.h" -#include "fast_gelu_impl.h" +#include "core/providers/cuda/tensor/gelu_impl.h" #include "contrib_ops/cpu/bert/bias_gelu_helper.h" -#include "transformer_common.h" +#ifdef USE_ROCM +#include "contrib_ops/rocm/bert/elementwise.h" +#else +#include "contrib_ops/cuda/bert/transformer_common.h" +#endif namespace onnxruntime { namespace contrib { @@ -31,8 +35,10 @@ using namespace ONNX_NAMESPACE; template FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { +#ifndef USE_ROCM const TransformerOptions* options = TransformerOptions::GetInstance(); use_half2_ = !options->DisableHalf2(); +#endif } template @@ -50,6 +56,13 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); typedef typename ToCudaType::MappedType CudaT; +#ifdef USE_ROCM + return LaunchElementwiseKernel( + GetTuningContext(), context->GetComputeStream(), + reinterpret_cast(input->Data()), static_cast(input_length), + (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, static_cast(bias_length), + reinterpret_cast(output->MutableData())); +#else return LaunchFastGeluKernel(GetDeviceProp(), Stream(context), static_cast(input_length), @@ -58,6 +71,7 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, reinterpret_cast(output->MutableData()), use_half2_); +#endif } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h index 3e642a70afef..26f3bd5a0392 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h @@ -18,7 +18,9 @@ class FastGelu final : public CudaKernel { Status ComputeInternal(OpKernelContext* ctx) const override; private: +#ifndef USE_ROCM bool use_half2_; +#endif }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 33e7a3349477..9efb6f08e8e9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -344,52 +344,148 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio bool has_beams = params.cache_indir != nullptr && !params.is_cross_attention; const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_max_seq_length] : nullptr; - for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { - bool is_masked = (params.mask != nullptr) && (params.mask[bi_total_seq_length + ti] == 0); + if (!params.kv_data_in_flight) { + for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + bool is_masked = (params.mask != nullptr) && (params.mask[bi_total_seq_length + ti] == 0); - // The keys loaded from the key cache. - K_vec_k k_vec[K_VECS_PER_THREAD]; - if (ti < tlength) { - if (has_beams) { - const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size; + // The keys loaded from the key cache. + K_vec_k k_vec[K_VECS_PER_THREAD]; + if (ti < tlength) { + if (has_beams) { + const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size; #pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_sequence_length + ti; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + ti; - k_vec[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); - } - } else { + k_vec[ii] = vec_conversion( + (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); + } + } else { #pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_sequence_length + ti; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + ti; - k_vec[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); + k_vec[ii] = vec_conversion( + (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); + } } } - } - // Perform the dot product and normalize qk. - // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! - float qk = Qk_dot::dot(q_vec, k_vec) * inv_sqrt_dh; + // Perform the dot product and normalize qk. + // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! + float qk = Qk_dot::dot(q_vec, k_vec) * inv_sqrt_dh; - // This is a deviation from FasterTransformer kernel implementation - // but this aligns with ORT's other Attention kernels which strives to - // mimic PyTorch when dealing with mask filter values - if (is_masked) { - qk += params.mask_filter_value; + // This is a deviation from FasterTransformer kernel implementation + // but this aligns with ORT's other Attention kernels which strives to + // mimic PyTorch when dealing with mask filter values + if (is_masked) { + qk += params.mask_filter_value; + } + + // Store the product to shared memory. There's one qk value per timestep. Update the max. + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + if (params.relative_attention_bias != nullptr) { + qk = add_vec(qk, + reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + ti]); + } + qk_max = fmaxf(qk_max, qk); + qk_smem[ti] = qk; + } } + } else { + // TODO(hasesh): Tune this value for different workloads. Currently, it is tuned for Whisper model + // Also tune it for different architectures. This works best for Whisper on 80GB A100. + constexpr int K_CACHE_DATA_LOAD_UNROLL = 4; - // Store the product to shared memory. There's one qk value per timestep. Update the max. - if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - if (params.relative_attention_bias != nullptr) { - qk = add_vec(qk, - reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + ti]); + for (int ti = ko; ti < ti_end; ti += (K_CACHE_DATA_LOAD_UNROLL * K_PER_ITER)) { + int is_masked[K_CACHE_DATA_LOAD_UNROLL]; + int beam_offset[K_CACHE_DATA_LOAD_UNROLL]; + int time_step[K_CACHE_DATA_LOAD_UNROLL]; + bool time_bounds_cond[K_CACHE_DATA_LOAD_UNROLL]; + +#pragma unroll + for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) { + is_masked[k_unroll] = 1; + beam_offset[k_unroll] = 0; + time_step[k_unroll] = ti + k_unroll * K_PER_ITER; + time_bounds_cond[k_unroll] = (time_step[k_unroll] < tlength); + } + +#pragma unroll + for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) { + if (time_bounds_cond[k_unroll] && params.mask != nullptr) { + is_masked[k_unroll] = params.mask[bi_total_seq_length + time_step[k_unroll]]; + } + } + + if (has_beams) { + int head_maxlength_headsize_prod = params.num_heads * params.max_sequence_length * head_size; + +#pragma unroll + for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) { + if (time_bounds_cond[k_unroll]) { + beam_offset[k_unroll] = beam_indices[time_step[k_unroll]] * head_maxlength_headsize_prod; + } + } + } + + // The keys loaded from the key cache. + K_vec_k k_vec[K_CACHE_DATA_LOAD_UNROLL][K_VECS_PER_THREAD]; + +#pragma unroll + for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) { + if (time_bounds_cond[k_unroll]) { + if (has_beams) { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + time_step[k_unroll]; + + k_vec[k_unroll][ii] = vec_conversion( + (*reinterpret_cast(&k_cache_batch[beam_offset[k_unroll] + jj * QK_ELTS_IN_16B]))); + } + } else { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + time_step[k_unroll]; + + k_vec[k_unroll][ii] = vec_conversion( + (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); + } + } + } + } + + // Perform the dot product and normalize qk. + // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! + float qk[K_CACHE_DATA_LOAD_UNROLL]; +#pragma unroll + for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) { + qk[k_unroll] = Qk_dot::dot(q_vec, k_vec[k_unroll]) * inv_sqrt_dh; + } + +// This is a deviation from FasterTransformer kernel implementation +// but this aligns with ORT's other Attention kernels which strives to +// mimic PyTorch when dealing with mask filter values +#pragma unroll + for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) { + if (time_bounds_cond[k_unroll] && is_masked[k_unroll] == 0) { + qk[k_unroll] += params.mask_filter_value; + } + } + +// Store the product to shared memory. There's one qk value per timestep. Update the max. +#pragma unroll + for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) { + if (time_bounds_cond[k_unroll] && (tidx % THREADS_PER_KEY == 0)) { + if (params.relative_attention_bias != nullptr) { + qk[k_unroll] = add_vec(qk[k_unroll], + reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + time_step[k_unroll]]); + } + qk_max = fmaxf(qk_max, qk[k_unroll]); + qk_smem[time_step[k_unroll]] = qk[k_unroll]; + } } - qk_max = fmaxf(qk_max, qk); - qk_smem[ti] = qk; } } @@ -504,18 +600,80 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio V_vec_acum out; zero(out); - // Loop over the timesteps to compute the partial outputs. - for (int ti = vo; ti < tlength; ti += V_PER_ITER) { - // Fetch offset based on cache_indir when beam sampling - const int beam_src = has_beams ? params.cache_indir[bi_max_seq_length + ti] : 0; - const int beam_offset = has_beams ? beam_src * params.num_heads * params.max_sequence_length * head_size : 0; + if (!params.kv_data_in_flight) { + // Loop over the timesteps to compute the partial outputs. + for (int ti = vo; ti < tlength; ti += V_PER_ITER) { + // Fetch offset based on cache_indir when beam sampling + const int beam_src = has_beams ? params.cache_indir[bi_max_seq_length + ti] : 0; + const int beam_offset = has_beams ? beam_src * params.num_heads * params.max_sequence_length * head_size : 0; + + // Load the values from the cache. + V_vec_k v = vec_conversion(*reinterpret_cast(&v_cache_batch[beam_offset + ti * head_size])); + + // Load the logits from shared memory. + T logit = logits_smem[ti]; + out = fma(logit, v, out); + } + } else { + // Loop over the timesteps to compute the partial outputs. + + // TODO(hasesh): Tune this value for different workloads. Currently, it is tuned for Whisper model + // Also tune it for different architectures. This works best for Whisper on 80GB A100. + constexpr int V_CACHE_DATA_LOAD_UNROLL = 8; + + for (int ti = vo; ti < tlength; ti += V_CACHE_DATA_LOAD_UNROLL * V_PER_ITER) { + int beam_src[V_CACHE_DATA_LOAD_UNROLL]; + int beam_offset[V_CACHE_DATA_LOAD_UNROLL]; + int time_step[V_CACHE_DATA_LOAD_UNROLL]; + bool time_bounds_cond[V_CACHE_DATA_LOAD_UNROLL]; + +#pragma unroll + for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) { + beam_src[v_unroll] = 0; + beam_offset[v_unroll] = 0; + time_step[v_unroll] = ti + v_unroll * V_PER_ITER; + time_bounds_cond[v_unroll] = (time_step[v_unroll] < tlength); + } + + int head_maxlength_headsize_prod = params.num_heads * params.max_sequence_length * head_size; + + if (has_beams) { +// Do the global memory read and corresponding compute in separate unrolled loops +#pragma unroll + for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) { + if (time_bounds_cond[v_unroll]) { + beam_src[v_unroll] = params.cache_indir[bi_max_seq_length + time_step[v_unroll]]; + } + } + +#pragma unroll + for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) { + if (time_bounds_cond[v_unroll]) { + beam_offset[v_unroll] = beam_src[v_unroll] * head_maxlength_headsize_prod; + } + } + } - // Load the values from the cache. - V_vec_k v = vec_conversion(*reinterpret_cast(&v_cache_batch[beam_offset + ti * head_size])); + // Load the values from the V-cache and logits from shared memory. + V_vec_k v[V_CACHE_DATA_LOAD_UNROLL]; + T logits[V_CACHE_DATA_LOAD_UNROLL]; - // Load the logits from shared memory. - T logit = logits_smem[ti]; - out = fma(logit, v, out); +// Do the global memory read and compute in separate unrolled loops +#pragma unroll + for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) { + if (time_bounds_cond[v_unroll]) { + v[v_unroll] = vec_conversion(*reinterpret_cast(&v_cache_batch[beam_offset[v_unroll] + time_step[v_unroll] * head_size])); + logits[v_unroll] = logits_smem[time_step[v_unroll]]; + } + } + +#pragma unroll + for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) { + if (time_bounds_cond[v_unroll]) { + out = fma(logits[v_unroll], v[v_unroll], out); + } + } + } } // One group of threads computes the product(s) for the current timestep. diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h index 4b408dafa2d8..1a17757d1ec2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h @@ -22,6 +22,12 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { bool is_cross_attention = false; bool is_packed_qkv = false; + // Useful to better use global memory bandwidth on certain CUDA architectures. + // Turned off by default for now until we fully understand performance implications + // for all types of workloads. + // Can be turned on by appropriate environment variable (see attention_common.h). + bool kv_data_in_flight = false; + void* q = nullptr; void* q_bias = nullptr; @@ -62,4 +68,4 @@ void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cud } // namespace cuda } // namespace contrib -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 76190aad68fd..0f58a74c4d2f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -35,6 +35,7 @@ void set_params_fprop(Flash_fwd_params& params, void* softmax_lse_d, float softmax_scale, bool is_causal, + bool is_bf16, bool kv_bsnh = true, int window_size_left = -1, int window_size_right = -1) { @@ -44,7 +45,7 @@ void set_params_fprop(Flash_fwd_params& params, params.v_ptr = v; params.o_ptr = out; - params.is_bf16 = false; + params.is_bf16 = is_bf16; // All stride are in elements, not bytes. if (kv_bsnh) { @@ -240,6 +241,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, int seqlen_k, float softmax_scale, bool is_causal, + bool is_bf16, int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded @@ -264,6 +266,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, + is_bf16, kv_bsnh, local_window_size, is_causal ? 0 : -1); @@ -306,7 +309,8 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, int max_seqlen_q, int max_seqlen_k, float softmax_scale, - bool is_causal) { + bool is_causal, + bool is_bf16) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); @@ -326,6 +330,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, + is_bf16, true, -1, is_causal ? 0 : -1); @@ -350,13 +355,15 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in Status mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size - void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size - void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x num_heads x seqlen_q void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) int batch_size, int num_heads, int num_heads_k, @@ -364,22 +371,23 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int seqlen_q, int seqlen_k, int seqlen_k_new, + int rotary_dim, const float softmax_scale, bool is_causal, + bool is_bf16, bool past_bsnh, // otherwise bnsh int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - int local_window_size) { - // if (seqlen_q == 1) { - // is_causal = false; - // } // causal=true is the same as causal=false in this case - + int local_window_size, + bool is_rotary_interleaved, + bool is_packed_qkv) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + // In kv-cache case, seqlen_k_max as kv sequence length Flash_fwd_params params; set_params_fprop(params, batch_size, @@ -394,20 +402,30 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, + is_bf16, past_bsnh, local_window_size, is_causal ? 0 : -1); params.dprops = &dprops; - if (k != nullptr && v != nullptr) { + if (k_new != nullptr && v_new != nullptr) { params.seqlen_knew = seqlen_k_new; - params.knew_ptr = k; - params.vnew_ptr = v; + params.knew_ptr = k_new; + params.vnew_ptr = v_new; // All stride are in elements, not bytes. - params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; - params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; - params.knew_row_stride = num_heads_k * head_size; - params.vnew_row_stride = num_heads_k * head_size; + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + } else { + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + } params.knew_head_stride = head_size; params.vnew_head_stride = head_size; } else { @@ -427,6 +445,13 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.cu_seqlens_k = static_cast(seqlens_k_); } + if (rotary_cos != nullptr) { + params.rotary_cos_ptr = rotary_cos; + params.rotary_sin_ptr = rotary_sin; + params.is_rotary_interleaved = is_rotary_interleaved; + params.rotary_dim = rotary_dim; + } + params.num_splits = num_splits; if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { params.softmax_lseaccum_ptr = softmax_lse_accum; @@ -437,7 +462,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, } // Only split kernel supports appending to KV cache - run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr); + run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index efc1f565c4fa..24891bcc4d49 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -51,6 +51,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, int seqlen_k, float softmax_scale, bool is_causal, + bool is_bf16, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded @@ -73,7 +74,8 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, int max_seqlen_q, int max_seqlen_k, float softmax_scale, - bool is_causal); + bool is_causal, + bool is_bf16); Status mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, @@ -85,6 +87,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x num_heads x seqlen_q void* seqlens_k_, // batch_size + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) int batch_size, int num_heads, int num_heads_k, @@ -92,13 +96,17 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int seqlen_q, int seqlen_k, int seqlen_k_new, + int rotary_dim, const float softmax_scale, bool is_causal, + bool is_bf16, bool past_bsnh, // otherwise bnsh int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - int local_window_size = -1); + int local_window_size = -1, + bool is_rotary_interleaved = false, + bool is_packed_qkv = false); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_sm80.cu new file mode 100644 index 000000000000..431eb2bd69de --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_bf16_sm80.cu new file mode 100644 index 000000000000..0cb48272dec3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_sm80.cu new file mode 100644 index 000000000000..142e922f7103 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_bf16_sm80.cu new file mode 100644 index 000000000000..2142b1c34311 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_sm80.cu new file mode 100644 index 000000000000..751363184e23 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_sm80.cu new file mode 100644 index 000000000000..ebf023643597 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_sm80.cu new file mode 100644 index 000000000000..166bb2a0072f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_sm80.cu new file mode 100644 index 000000000000..c8760b8168db --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu new file mode 100644 index 000000000000..3ca416f6580c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu new file mode 100644 index 000000000000..3e37c9af80b3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu new file mode 100644 index 000000000000..79606fd05b4d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu new file mode 100644 index 000000000000..0b0d9384709c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu new file mode 100644 index 000000000000..8eb5c8f84544 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu new file mode 100644 index 000000000000..0141f27aa199 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu new file mode 100644 index 000000000000..489d2d47bc70 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu new file mode 100644 index 000000000000..bcfd47e76b99 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h index 05ac2476690c..5b70988949bb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h @@ -23,11 +23,15 @@ } \ }() -#define FP16_SWITCH(COND, ...) \ - [&] { \ - assert(COND); \ - using elem_type = cutlass::half_t; \ - return __VA_ARGS__(); \ +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ }() #define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc new file mode 100644 index 000000000000..49bf79188efd --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.cc @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/bert/gemma_rotary_emb.h" +#include "contrib_ops/cuda/bert/gemma_rotary_emb_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T, U) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GemmaRotaryEmbedding, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("U", DataTypeImpl::GetTensorType()), \ + GemmaRotaryEmbedding); + +REGISTER_KERNEL_TYPED(MLFloat16, float) + +template +GemmaRotaryEmbedding::GemmaRotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { +} + +template +Status GemmaRotaryEmbedding::ComputeInternal(OpKernelContext* context) const { + const Tensor* emb = context->Input(0); + const Tensor* q = context->Input(1); + const Tensor* q_rot = context->Input(2); + const Tensor* k = context->Input(3); + const Tensor* k_rot = context->Input(4); + + const auto& emb_dims = emb->Shape().GetDims(); + const auto& q_dims = q->Shape().GetDims(); + int batch_size = static_cast(q_dims[0]); + int num_heads = static_cast(q_dims[1]); + int seq_len = static_cast(q_dims[2]); + int dim = static_cast(q_dims[3]); + + // q_dims should be [batch_size, num_heads, seq_len, dim] + // emb_dims should be [batch_size, seq, dim] + ORT_ENFORCE(emb_dims.size() == 3, "emb_dims should be 3D"); + ORT_ENFORCE(q_dims.size() == 4, "emb_dims should be 4D"); + ORT_ENFORCE(emb_dims[0] == batch_size, "emb_dims[0] should match q_dims[0]"); + ORT_ENFORCE(emb_dims[1] == seq_len, "emb_dims[1] should match q_dims[2]"); + ORT_ENFORCE(emb_dims[2] == dim, "emb_dims[2] should match q_dims[3]"); + + Tensor* output1 = context->Output(0, q_dims); + Tensor* output2 = context->Output(1, q_dims); + + typedef typename ToCudaType::MappedType CudaT; + typedef typename ToCudaType::MappedType CudaU; + return LaunchGemmaRotaryEmbeddingKernel( + Stream(context), + reinterpret_cast(output1->template MutableData()), + reinterpret_cast(output2->template MutableData()), + reinterpret_cast(emb->template Data()), + reinterpret_cast(q->template Data()), + reinterpret_cast(q_rot->template Data()), + reinterpret_cast(k->template Data()), + reinterpret_cast(k_rot->template Data()), + batch_size, + num_heads, + seq_len, + dim); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h new file mode 100644 index 000000000000..e63236d2ab7c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using onnxruntime::cuda::CudaKernel; +using onnxruntime::cuda::ToCudaType; + +template +class GemmaRotaryEmbedding final : public CudaKernel { + public: + GemmaRotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu new file mode 100644 index 000000000000..9e00ca713a44 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.cu @@ -0,0 +1,104 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ +/* +Kernel implementation for Gamma rotary embeddings. +This implementation below subgraph + (emb) + / \ + / \ + Sin Cos + | | + Cast Cast + | | + Unsqueeze Unsqueeze + \/ \/ \/ \/ + Mul Mul Mul Mul + \ / \ / + Add Add + | | + (output1) (output2) +*/ + +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/bert/gemma_rotary_emb_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +constexpr int kThreadsPerBlock = GridDim::maxThreadsPerBlock; + +template +__global__ void GemmaRotaryEmb( + T* output1, + T* output2, + const U* emb, + const T* q, + const T* q_rot, + const T* k, + const T* k_rot, + const int batch_size, + const int num_heads, + const int seq_len, + const int dim) { + + const int qk_idx = blockIdx.x * blockDim.x + threadIdx.x; + // index [i, j, k, l] -> [i, k, l] + const int emb_idx = qk_idx / (num_heads * seq_len * dim) * (seq_len * dim) + qk_idx % (seq_len * dim); + if (qk_idx < batch_size * num_heads * seq_len * dim) { + T sin_val = static_cast(sin(emb[emb_idx])); + T cos_val = static_cast(cos(emb[emb_idx])); + output1[qk_idx] = q[qk_idx] * cos_val + q_rot[qk_idx] * sin_val; + output2[qk_idx] = k[qk_idx] * cos_val + k_rot[qk_idx] * sin_val; + } +} + +template +Status LaunchGemmaRotaryEmbeddingKernel( + cudaStream_t stream, + T* output1, + T* output2, + const U* emb, + const T* q, + const T* q_rot, + const T* k, + const T* k_rot, + const int batch_size, + const int num_heads, + const int seq_len, + const int dim + ) { + int blocksPerGrid = static_cast(ceil(float(batch_size * num_heads * seq_len * dim) / kThreadsPerBlock)); + + GemmaRotaryEmb<<>>( + output1, output2, + emb, q, q_rot, k, k_rot, + batch_size, num_heads, seq_len, dim + ); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchGemmaRotaryEmbeddingKernel( + cudaStream_t stream, + half* output1, + half* output2, + const float* emb, + const half* q, + const half* q_rot, + const half* k, + const half* k_rot, + const int batch_size, + const int num_heads, + const int seq_len, + const int dim); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h new file mode 100644 index 000000000000..c57fbe0d7e92 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/gemma_rotary_emb_impl.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status LaunchGemmaRotaryEmbeddingKernel( + cudaStream_t stream, + T* output1, + T* output2, + const U* emb, + const T* q, + const T* q_rot, + const T* k, + const T* k_rot, + const int batch_size, + const int num_heads, + const int seq_len, + const int dim); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 93892169f6c7..112f609d4659 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -34,6 +34,7 @@ namespace cuda { // REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) template GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) @@ -46,6 +47,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) kv_num_heads_ = static_cast(kv_num_heads); is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION @@ -61,6 +64,9 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) #else disable_memory_efficient_attention_ = true; #endif + if (!disable_flash_attention_) { + zeros_ = this->GetScratchBuffer(kZerosCount, nullptr); + } } template @@ -72,6 +78,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* past_value = context->Input(4); const Tensor* seqlens_k = context->Input(5); const Tensor* total_seqlen = context->Input(6); + const Tensor* cos_cache = context->Input(7); + const Tensor* sin_cache = context->Input(8); auto& device_prop = GetDeviceProp(); GroupQueryAttentionParameters parameters; @@ -83,6 +91,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { value, past_key, past_value, + cos_cache, + sin_cache, ¶meters, num_heads_, kv_num_heads_, @@ -92,7 +102,18 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { scale_, device_prop.maxThreadsPerBlock)); parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + parameters.zeros_count = kZerosCount; + parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; int sequence_length = parameters.sequence_length; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; + + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); + } TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size); @@ -149,18 +170,31 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (use_memory_efficient_attention && needs_buff) { kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size); } + size_t rotary_buffer_bytes = 0; + if (use_memory_efficient_attention && do_rotary_) { + rotary_buffer_bytes = 2 * sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.sequence_length * parameters.head_size; + rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length; + } size_t fmha_buffer_bytes = 0; if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); } + size_t unpacked_qkv_bytes = 0; + if (use_memory_efficient_attention && parameters.is_packed_qkv) { + unpacked_qkv_bytes = (parameters.batch_size * parameters.sequence_length * (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T)); + } auto k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); auto v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, context->GetComputeStream()); auto fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); + auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, context->GetComputeStream()); #else constexpr bool use_memory_efficient_attention = false; auto k_buffer = GetScratchBuffer(0, context->GetComputeStream()); auto v_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto rotary_buffer = GetScratchBuffer(0, context->GetComputeStream()); auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto unpacked_qkv_buffer = GetScratchBuffer(0, context->GetComputeStream()); #endif // seqlens_k buffer @@ -181,8 +215,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present_value = context->Output(2, present_shape); data.query = reinterpret_cast(query->Data()); - data.key = reinterpret_cast(key->Data()); - data.value = reinterpret_cast(value->Data()); + data.key = key == nullptr ? nullptr : reinterpret_cast(key->Data()); + data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data()); data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); data.output = reinterpret_cast(output->MutableData()); @@ -228,6 +262,17 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (fmha_buffer != nullptr) { data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); } + if (unpacked_qkv_buffer != nullptr) { + data.unpacked_qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); + } + if (rotary_buffer != nullptr) { + data.rotary_buffer = reinterpret_cast(rotary_buffer.get()); + } + // Rotary Embedding + if (parameters.do_rotary) { + data.cos_cache = reinterpret_cast(cos_cache->Data()); + data.sin_cache = reinterpret_cast(sin_cache->Data()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 54a8127e29e7..15573ece166f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -23,10 +23,15 @@ class GroupQueryAttention final : public CudaKernel { int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention int local_window_size_; + bool is_unidirectional_; bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; float scale_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) + IAllocatorUniquePtr zeros_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 2cb9955807f2..1a7c3fcea3fa 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -16,6 +16,8 @@ Status CheckInputs(const Tensor* query, const Tensor* value, const Tensor* past_key, const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, void* parameters, int num_heads, int kv_num_heads, @@ -24,19 +26,18 @@ Status CheckInputs(const Tensor* query, bool is_past_bsnh, float scale) { // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length - // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) - // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) + // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr + // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr // no packing for q/k/v: - // query (Q) : (B, S, D) - // key (K) : (B, S, D_kv) - // value (V) : (B, S, D_kv) + // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) + // key (K) : (B, S, D_kv) or nullptr + // value (V) : (B, S, D_kv) or nullptr ORT_UNUSED_PARAMETER(value); AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; - + const bool is_packed_qkv = key == nullptr; const auto& query_dims = query->Shape().GetDims(); - const auto& key_dims = key->Shape().GetDims(); if (query_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", @@ -46,10 +47,69 @@ Status CheckInputs(const Tensor* query, int batch_size = static_cast(query_dims[0]); int sequence_length = static_cast(query_dims[1]); int q_hidden_size = static_cast(query_dims[2]); - int head_size = static_cast(q_hidden_size) / num_heads; + int head_size = 0; - int kv_hidden_size = static_cast(key_dims[2]); + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } + int kv_hidden_size = 0; + // Check key and value when not packed + if (!is_packed_qkv) { + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } else if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != key_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 1 (sequence length)"); + } + kv_hidden_size = static_cast(key_dims[2]); + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } else if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 1 (sequence length)"); + } else if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + } else { + // Check packed qkv + head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + q_hidden_size = head_size * num_heads; + kv_hidden_size = head_size * kv_num_heads; + } + + // Check past-present KV int32_t past_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { const auto& past_key_dims = past_key->Shape().GetDims(); @@ -130,41 +190,6 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent."); } - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } - - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } - - if (static_cast(sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)"); - } - - if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - // Check seqlens_k tensor (holding past seqlen for token gen) const auto& seqlens_dim = seqlens_k->Shape().GetDims(); if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { @@ -180,6 +205,42 @@ Status CheckInputs(const Tensor* query, int total_sequence_length = *((*total_seqlen).template Data()); int present_sequence_length = std::max(total_sequence_length, past_sequence_length); + int rotary_dim = 0; + if (cos_cache != nullptr && sin_cache != nullptr) { + const auto& cos_dims = cos_cache->Shape().GetDims(); + const auto& sin_dims = sin_cache->Shape().GetDims(); + + if (head_size % 16 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size shall be a multiple of 16. Got head_size % 16 == ", + head_size % 16); + } + if (cos_dims[0] < present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 0 should be of max_sequence_length."); + } + if (sin_dims[0] < present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 0 should be of max_sequence_length."); + } + if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (cos_dims[1] != sin_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache dimension 1 must be the same."); + } + rotary_dim = static_cast(cos_dims[1] * 2); + } else if (cos_cache != nullptr || sin_cache != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); + } + bool is_prompt = sequence_length != 1; if (parameters != nullptr) { @@ -190,9 +251,11 @@ Status CheckInputs(const Tensor* query, output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors output_parameters->hidden_size = q_hidden_size; output_parameters->num_heads = num_heads; - output_parameters->head_size = q_hidden_size / num_heads; + output_parameters->head_size = head_size; output_parameters->kv_hidden_size = kv_hidden_size; output_parameters->kv_num_heads = kv_num_heads; + output_parameters->rotary_dim = rotary_dim; + output_parameters->is_packed_qkv = is_packed_qkv; output_parameters->is_unidirectional = true; output_parameters->is_prompt = is_prompt; output_parameters->scale = scale; @@ -208,6 +271,8 @@ Status CheckInputs(const Tensor* query, const Tensor* value, const Tensor* past_key, const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, void* parameters, int num_heads, int kv_num_heads, @@ -220,7 +285,7 @@ Status CheckInputs(const Tensor* query, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); + return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); } } // namespace group_query_attention_helper diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index b22ccb68c1e7..f519be1c9714 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -42,6 +42,7 @@ limitations under the License. #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/attention_impl.h" #include "core/providers/cuda/shared_inc/cuda_call.h" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" #include using namespace onnxruntime::cuda; @@ -150,10 +151,13 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, template Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, + const void* new_key, + const void* new_value, cudaStream_t stream, - const int max_threads_per_block) { + const int max_threads_per_block, + const bool past_only = false) { const int batch_size = parameters.batch_size; - const int kv_sequence_length = parameters.sequence_length; + const int kv_sequence_length = past_only ? 0 : parameters.sequence_length; const int past_sequence_length = parameters.seqlen_past_kv_cache; const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; @@ -170,14 +174,14 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter ConcatNewToPastKV<<>>(kv_sequence_length, past_sequence_length, reinterpret_cast(data.past_key), - reinterpret_cast(data.key), + reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKV<<>>(kv_sequence_length, past_sequence_length, reinterpret_cast(data.past_value), - reinterpret_cast(data.value), + reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); @@ -190,7 +194,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter H, kv_num_heads, reinterpret_cast(data.past_key), - reinterpret_cast(data.key), + reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); @@ -199,7 +203,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter H, kv_num_heads, reinterpret_cast(data.past_value), - reinterpret_cast(data.value), + reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); @@ -280,6 +284,8 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, template Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, + const void* new_key, + const void* new_value, cudaStream_t stream, const int max_threads_per_block) { const int batch_size = parameters.batch_size; @@ -299,12 +305,12 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, const dim3 block(H, kv_num_heads, 1); ConcatKVInPlace<<>>(present_sequence_length, reinterpret_cast(data.present_key), - reinterpret_cast(data.key), + reinterpret_cast(new_key), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatKVInPlace<<>>(present_sequence_length, reinterpret_cast(data.present_value), - reinterpret_cast(data.value), + reinterpret_cast(new_value), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } else { @@ -315,14 +321,14 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, H, kv_num_heads, reinterpret_cast(data.present_key), - reinterpret_cast(data.key), + reinterpret_cast(new_key), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatKVInPlaceLarge<<>>(present_sequence_length, H, kv_num_heads, reinterpret_cast(data.present_value), - reinterpret_cast(data.value), + reinterpret_cast(new_value), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } @@ -441,7 +447,6 @@ Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, return CUDA_CALL(cudaGetLastError()); } - __global__ void PastToTotalSeqlen(int32_t* seqlens_k, int32_t* seqlens_k_buff, const int add_seqlen) { @@ -451,7 +456,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, // Convert Past to Total sequence length tensor Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int threads_per_block) { + const int /*threads_per_block*/) { if (parameters.is_prompt) { return Status::OK(); } @@ -468,6 +473,83 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i return CUDA_CALL(cudaGetLastError()); } +// Kernel to unpack qkv from packed qkv +template +__global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, + const int kv_num_heads, const int head_size, const int sequence_length, + const int batch_size) { + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + int d = (num_heads + 2 * kv_num_heads) * head_size; + const int qkv_size = batch_size * sequence_length * d; + const int q_size = num_heads * head_size; + const int k_size = kv_num_heads * head_size; + if (tid < qkv_size) { + int batch = tid / (d * sequence_length); + int sequence = (tid % (d * sequence_length)) / d; + int offset = tid % d; + if (offset < q_size) { + int unpacked_i = batch * sequence_length * num_heads * head_size + sequence * num_heads * head_size + offset; + unpacked_q[unpacked_i] = packed_qkv[tid]; + } else if (offset < q_size + k_size) { + int unpacked_i = batch * sequence_length * kv_num_heads * head_size + sequence * kv_num_heads * head_size + (offset - q_size); + unpacked_k[unpacked_i] = packed_qkv[tid]; + } else { + int unpacked_i = batch * sequence_length * kv_num_heads * head_size + sequence * kv_num_heads * head_size + (offset - q_size - k_size); + unpacked_v[unpacked_i] = packed_qkv[tid]; + } + } +} + +// Unpack packed qkv +template +Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, + const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, + cudaStream_t stream, const int max_threads_per_block) { + const int threads = max_threads_per_block; + const int blocks = (batch_size * sequence_length * (num_heads + 2 * kv_num_heads) * head_size + threads - 1) / threads; + UnpackQKV<<>>(packed_qkv, unpacked_q, unpacked_k, unpacked_v, num_heads, kv_num_heads, + head_size, sequence_length, batch_size); + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, const int seqlen, + const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + if (s < seqlens_k[b] + 1) { + position_ids[tid] = s; + } else { + position_ids[tid] = 1; + } + } +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + position_ids[tid] = seqlens_k[tid]; + } +} + +// Convert seqlens_k to position_ids +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, + int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) { + const int seqlen = parameters.sequence_length; + const int batch_size = parameters.batch_size; + const int threads = max_threads_per_block; + const int blocks = (batch_size * seqlen + threads - 1) / threads; + if (parameters.is_prompt) { + SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else { + SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); + } + return CUDA_CALL(cudaGetLastError()); +} + ////////// Launch Kernels #if USE_FLASH_ATTENTION @@ -482,89 +564,64 @@ Status FlashAttention( const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int kv_sequence_length = parameters.sequence_length; - const int present_sequence_length = parameters.seqlen_present_kv_cache; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; + bool is_causal = true; + bool is_bf16 = std::is_same::value; void* query = reinterpret_cast(const_cast(data.query)); - void* key = reinterpret_cast(const_cast(data.key)); - void* value = reinterpret_cast(const_cast(data.value)); - - bool is_causal = true; + void* key; + void* value; - // Note: seqlens_k is past sequence length for flash - if (parameters.is_prompt) { - // Launch kernel to copy seqlen - constexpr int thr_per_blk = 256; - int blk_in_grid = (batch_size + thr_per_blk -1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); + if (!parameters.is_packed_qkv) { + key = reinterpret_cast(const_cast(data.key)); + value = reinterpret_cast(const_cast(data.value)); + } else { + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key = reinterpret_cast(query) + key_offset; + value = reinterpret_cast(key) + value_offset; } void* seqlens_k = reinterpret_cast(data.seqlens_k); - - if (parameters.kv_share_buffer) { - // Share buffer case - if (data.past_key == nullptr || data.past_key != data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv shall share the same tensor when kv_share_buffer is on."); - } - - if (parameters.is_prompt) { - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); - key = nullptr; - value = nullptr; - seqlens_k = reinterpret_cast(data.seqlens_k_total); - } - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), - seqlens_k, batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, kv_sequence_length, - scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), parameters.local_window_size)); - } else { - // Not share buffer case - // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient - if (data.past_key != nullptr && data.past_key == data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv share the same tensor but kv_share_buffer is not on."); - } - - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); - - if (!parameters.is_prompt) { - ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + if (parameters.is_prompt) { + // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value + // user should use seqlens_k to index into output to get new tokens + if (batch_size <= parameters.zeros_count) { + seqlens_k = parameters.zero_ptr; + } else { + // Launch kernel to create larger seqlen tensor when batch_size > 256 + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, 0, batch_size); + seqlens_k = data.seqlens_k_total; } - - seqlens_k = reinterpret_cast(data.seqlens_k_total); - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); - DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size); - DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size); - DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast(data.softmax_lse), - seqlens_k, batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, 0, - scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), parameters.local_window_size)); + } else if (!parameters.kv_share_buffer) { // copy past kv to present kv + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, nullptr, nullptr, stream, max_threads_per_block, + true)); } + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); + void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, + reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, + batch_size, num_heads, kv_num_heads, head_size, sequence_length, + parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim, + scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, + parameters.is_packed_qkv)); + + // if (parameters.left_padding && parameters.is_prompt) { + // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); + // } + DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); @@ -589,15 +646,62 @@ Status EfficientAttention( const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - const void* query = reinterpret_cast(data.query); - const void* key = reinterpret_cast(data.key); - const void* value = reinterpret_cast(data.value); + const void* query; + const void* key; + const void* value; + + if (!parameters.is_packed_qkv) { + query = reinterpret_cast(data.query); + key = reinterpret_cast(data.key); + value = reinterpret_cast(data.value); + } else { + size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); + size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + auto q = reinterpret_cast(data.unpacked_qkv_buffer); + auto k = reinterpret_cast(data.unpacked_qkv_buffer + q_size); + auto v = reinterpret_cast(data.unpacked_qkv_buffer + q_size + k_size); + ORT_RETURN_IF_ERROR(LaunchUnpackQKV(reinterpret_cast(data.query), q, k, v, num_heads, kv_num_heads, + head_size, sequence_length, batch_size, stream, max_threads_per_block)); + query = reinterpret_cast(q); + key = reinterpret_cast(k); + value = reinterpret_cast(v); + } + + if (parameters.do_rotary) { + size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); + size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + auto q_buffer = reinterpret_cast(data.rotary_buffer); + auto k_buffer = q_buffer + q_size; + auto position_ids_buff = reinterpret_cast(k_buffer + k_size); + ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, data.seqlens_k, position_ids_buff, stream, + max_threads_per_block)); + DUMP_TENSOR_INIT(); + DUMP_TENSOR("position_ids", position_ids_buff, batch_size, sequence_length); + // Launch rotary embedding kernel + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, q_buffer, reinterpret_cast(query), + position_ids_buff, data.cos_cache, data.sin_cache, + parameters.batch_size, parameters.sequence_length, + parameters.num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + device_prop.maxThreadsPerBlock, /*transposed*/ false)); + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, k_buffer, reinterpret_cast(key), + position_ids_buff, data.cos_cache, data.sin_cache, + parameters.batch_size, parameters.sequence_length, + parameters.kv_num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + device_prop.maxThreadsPerBlock, /*transposed*/ false)); + query = reinterpret_cast(q_buffer); + key = reinterpret_cast(k_buffer); + } if (parameters.is_prompt) { // Launch kernel to copy seqlen constexpr int thr_per_blk = 256; int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); + repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, + batch_size); } else { ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); } @@ -609,7 +713,7 @@ Status EfficientAttention( "Past and present kv shall share the same tensor when kv_share_buffer is on."); } // Concatenate new kv in place - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, key, value, stream, max_threads_per_block)); } else { // Not share buffer case if (data.past_key != nullptr && data.past_key == data.present_key) { @@ -617,7 +721,7 @@ Status EfficientAttention( "Past and present kv share the same tensor but kv_share_buffer is not on."); } // Copy past and concat new KV to present buffer - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, key, value, stream, max_threads_per_block)); } // Ungroup if grouped, otherwise use present kv directly @@ -670,7 +774,6 @@ Status EfficientAttention( p.has_custom_right_padding = true; run_memory_efficient_attention(p); - DUMP_TENSOR_INIT(); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); return Status::OK(); @@ -682,7 +785,7 @@ Status EfficientAttention( template Status QkvToContext( const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, + cublasHandle_t& /*cublas*/, Stream* ort_stream, contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data) { @@ -713,6 +816,15 @@ template Status QkvToContext( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data); +template struct GroupQueryAttentionData; + +template Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index de32d7ea9316..32341afa0e3f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -21,6 +21,8 @@ struct GroupQueryAttentionData { const T* past_key = nullptr; const T* past_value = nullptr; int* seqlens_k = nullptr; + const T* cos_cache = nullptr; + const T* sin_cache = nullptr; // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; @@ -28,6 +30,8 @@ struct GroupQueryAttentionData { int* seqlens_k_total = nullptr; // Memory Efficient buffers T* fmha_buffer = nullptr; + T* unpacked_qkv_buffer = nullptr; + T* rotary_buffer = nullptr; T* k = nullptr; T* v = nullptr; // Output Tensors diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc index e556ae4a490e..9c5d0e9834f6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc @@ -136,7 +136,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, weights_data, n, input_data, k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); } else { // q const CudaT* q_weight = weights_data; @@ -145,7 +145,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, q_weight, n, input_data, k, - &zero, q_data, n, device_prop)); + &zero, q_data, n, device_prop, UseTF32())); // k const CudaT* k_weight = q_weight + static_cast(hidden_size) * hidden_size; CudaT* k_data = q_data + static_cast(batch_size) * sequence_length * hidden_size; @@ -153,7 +153,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, k_weight, n, input_data, k, - &zero, k_data, n, device_prop)); + &zero, k_data, n, device_prop, UseTF32())); // v const CudaT* v_weight = k_weight + static_cast(hidden_size) * hidden_size; @@ -162,7 +162,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, v_weight, n, input_data, k, - &zero, v_data, n, device_prop)); + &zero, v_data, n, device_prop, UseTF32())); } // Wait for async copy of batch_global_num @@ -195,7 +195,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(global_weights->Data()), n, input_data, k, - &zero, global_gemm_buffer, n, device_prop)); + &zero, global_gemm_buffer, n, device_prop, UseTF32())); } else { // global q const CudaT* global_q_weight = global_weights_data; @@ -205,7 +205,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_q_weight, n, input_data, k, - &zero, global_q, n, device_prop)); + &zero, global_q, n, device_prop, UseTF32())); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, @@ -226,7 +226,8 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { hidden_size, // ldc static_cast(max_num_global) * hidden_size, // strideC batch_size, // batch count - device_prop)); + device_prop, + UseTF32())); } // global k const CudaT* global_k_weight = global_weights_data + static_cast(hidden_size) * hidden_size; @@ -235,7 +236,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_k_weight, n, input_data, k, - &zero, global_k, n, device_prop)); + &zero, global_k, n, device_prop, UseTF32())); // global v const CudaT* global_v_weight = global_k_weight + static_cast(hidden_size) * hidden_size; @@ -244,7 +245,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_v_weight, n, input_data, k, - &zero, global_v, n, device_prop)); + &zero, global_v, n, device_prop, UseTF32())); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu index f00239460071..c9c66b73b3e9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu @@ -1005,7 +1005,6 @@ Status LaunchLongformerAttentionKernel( bool disable_compact_memory, bool use_merged_qkv_weights, bool use_half4) { - CublasMathModeSetter helper(device_prop, cublas, CUBLAS_TENSOR_OP_MATH); size_t softmax_workspace_size = GetLongformerSoftmaxWorkspaceSize(element_size, batch_size, num_heads, diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index ebd66d8c6528..2ef011cdd9a2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -44,6 +44,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead."); disable_fused_self_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); @@ -92,6 +94,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, @@ -105,6 +109,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { num_heads_, mask_filter_value_, scale_, + is_unidirectional_, false, // past_present_share_buffer false, // dmmha_packing device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index c162f7133cc1..86a32c92ce00 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -25,6 +25,7 @@ class MultiHeadAttention final : public CudaKernel { int num_heads_; // number of attention heads float mask_filter_value_; float scale_; + bool is_unidirectional_; bool disable_fused_self_attention_; bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index ec8b1d051b3d..e4b90727121c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -268,6 +268,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* relative_position_bias = context->Input(5); PackedAttentionParameters parameters; + parameters.use_tf32 = UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), @@ -303,17 +304,17 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { int m = parameters.token_count; int n = parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size; int k = parameters.input_hidden_size; - gemm_buffer = this->GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); + gemm_buffer = this->template GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); cublasHandle_t cublas = this->GetCublasHandle(context); // Gemm, note that CUDA assumes col-major, so result(N, M) = 1 * weights x input + 1 x bias - // The bias part is not included here since we fuse bias, transpose and output 3 matrice into one cuda kernel. + // The bias part is not included here since we fuse bias, transpose and output 3 matrices into one cuda kernel. CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool no_qkv_workspace = false; // need workspace to add bias @@ -327,7 +328,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { false, use_memory_efficient_attention, no_qkv_workspace); - auto work_space = this->GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + auto work_space = this->template GetScratchBuffer(workSpaceSize, context->GetComputeStream()); typedef typename ToCudaType::MappedType CudaT; PackedAttentionData data; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index 3b5232083940..a84a310b46ca 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -440,7 +440,7 @@ Status LaunchTransposeRemovePadding( template Status FusedScaledDotProductAttention( - const cudaDeviceProp& device_prop, + const cudaDeviceProp& /*device_prop*/, cudaStream_t stream, PackedAttentionParameters& parameters, PackedAttentionData& data) { @@ -596,7 +596,7 @@ Status UnfusedScaledDotProductAttention( q, qk_head_size, sequence_length * qk_head_size, &zero, scaled_qk, sequence_length, sequence_length * sequence_length, - batches, device_prop)); + batches, device_prop, parameters.use_tf32)); DUMP_TENSOR_D("PackedAttention unfused QK", scaled_qk, batch_size * num_heads, sequence_length, sequence_length); @@ -624,7 +624,7 @@ Status UnfusedScaledDotProductAttention( v_head_size, sequence_length, sequence_length, &one, v, v_head_size, sequence_length * v_head_size, attention_score, sequence_length, sequence_length * sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose and remove padding to output token_countxNxH_v Status result = LaunchTransposeRemovePadding( diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 1b026e64778e..00ab32886112 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -228,6 +228,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co const Tensor* relative_position_bias = context->Input(6); PackedAttentionParameters parameters; + parameters.use_tf32 = UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(), key, value, @@ -297,7 +298,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co use_flash_attention, use_memory_efficient_attention, no_qkv_workspace); - auto work_space = this->GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + auto work_space = this->template GetScratchBuffer(workSpaceSize, context->GetComputeStream()); typedef typename ToCudaType::MappedType CudaT; PackedMultiHeadAttentionData data; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 8a508241d80b..982c7eaa2cb2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -381,7 +381,7 @@ void InvokeTranspose( const T* query, const T* key, const T* value, const T* bias, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, - AttentionQkvFormat source_format, AttentionQkvFormat target_format, + [[maybe_unused]] AttentionQkvFormat source_format, AttentionQkvFormat target_format, const int32_t* token_offset, int32_t token_count, cudaStream_t stream) { if (key != nullptr && value != nullptr) { @@ -551,7 +551,7 @@ void LaunchTranspose( template Status FusedAttentionTrt( - const cudaDeviceProp& device_prop, + const cudaDeviceProp& /*device_prop*/, cudaStream_t stream, PackedAttentionParameters& parameters, PackedMultiHeadAttentionData& data) { @@ -639,7 +639,8 @@ Status FlashAttention( sequence_length, sequence_length, scale, - false // is causal + false, // is causal + false // is bf16 )); DUMP_TENSOR_INIT(); @@ -774,7 +775,7 @@ Status UnfusedAttention( q, qk_head_size, sequence_length * qk_head_size, &zero, scaled_qk, sequence_length, sequence_length * sequence_length, - batches, device_prop)); + batches, device_prop, parameters.use_tf32)); // Q, K and V are ready now DUMP_TENSOR_INIT(); @@ -807,7 +808,7 @@ Status UnfusedAttention( v_head_size, sequence_length, sequence_length, &one, v, v_head_size, sequence_length * v_head_size, attention_score, sequence_length, sequence_length * sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose and remove padding to output TxNxH_v Status result = LaunchTransposeRemovePadding( diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index 92ba808dd85c..05f55d9106d0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -200,7 +200,7 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c D, BNS, head_size, &one, reinterpret_cast(weight_tensor.template Data()), (int)D, reinterpret_cast(workspace.get()), (int)head_size, - &zero, gemm_output, ld_gemm_output, device_prop)); + &zero, gemm_output, ld_gemm_output, device_prop, UseTF32())); auto status = LaunchGatedRelativePositionBiasKernel( device_prop, stream, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index 2d12e975d88d..ab7479f2938f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -29,10 +29,13 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { scale = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); } @@ -48,6 +51,8 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { position_ids, cos_cache, sin_cache, + num_heads, + rotary_embedding_dim, ¶meters)); Tensor* output = context->Output(0, input->Shape()); @@ -71,13 +76,12 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length, parameters.num_heads, parameters.head_size, + parameters.rotary_embedding_dim, parameters.max_sequence_length, parameters.position_ids_format, interleaved, device_prop.maxThreadsPerBlock, parameters.transposed); - - return Status::OK(); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h index 6dab2ad56749..d52f61d67044 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h @@ -19,6 +19,8 @@ class RotaryEmbedding final : public CudaKernel { protected: float scale; + int num_heads; + int rotary_embedding_dim; bool interleaved; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index e1b83bd8caf5..bd50e8646c4c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -7,9 +7,9 @@ Licensed under the MIT License. Kernel implementation for rotary embeddings. */ -#include -#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/bert/rotary_embedding_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include using namespace onnxruntime::cuda; @@ -18,141 +18,120 @@ namespace contrib { namespace cuda { template -__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH - const T* input, // BxSxNxH - const T* cos_cache, // Mx(H/2) - const T* sin_cache, // Mx(H/2) - const int64_t* position_ids, // (1) or BxS - const int sequence_length, - const int num_heads, - const int head_size, - const int position_ids_format, - const bool interleaved, - const int batch_stride, - const int seq_stride, +__global__ void RotaryEmbeddingBSNH(T *output, // BxSxNxH + const T *input, // BxSxNxH + const T *cos_cache, // Mx(H/2) + const T *sin_cache, // Mx(H/2) + const int64_t *position_ids, // (1) or BxS + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int position_ids_format, + const bool interleaved, const int batch_stride, const int seq_stride, const int head_stride) { - // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length - // Use .x in innermost loop to access global memory efficiently - - const int b = blockIdx.z; - const int s = blockIdx.y; - const int n = blockIdx.x; - - const int i = threadIdx.x; - - const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - - const T* input_data = input + block_offset; - T* output_data = output + block_offset; - - // Cache is (M, H/2) - const int half_head_size = head_size / 2; - const int position_id = (position_ids_format == 0) ? \ - static_cast(position_ids[0]) + s \ - : static_cast(position_ids[b * sequence_length + s]); - const int cache_offset = position_id * half_head_size; - const T* cos_data = cos_cache + cache_offset; - const T* sin_data = sin_cache + cache_offset; - - int cache_idx = 0; - T sign = 0; - int j = 0; - if (interleaved) { - cache_idx = (i / 2) % half_head_size; - sign = (i % 2 == 0) ? -1 : 1; - j = (i % 2 == 0) ? i+1 : i-1; // i - sign - } else { - cache_idx = i % half_head_size; - sign = (i < half_head_size) ? -1 : 1; - j = (i + half_head_size) % head_size; - } - output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length + // Use .x in innermost loop to access global memory efficiently + + const int b = blockIdx.y; + const int s = blockIdx.x; + const int n = blockIdx.z; + + const int i = threadIdx.x; + + if (i >= head_size) { + return; + } + + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; + + const T *input_data = input + block_offset; + T *output_data = output + block_offset; + + if (i >= rotary_embedding_dim) { + output_data[i] = input_data[i]; + return; + } + + // Cache is (M, H/2) + const int half_rotary_embedding_dim = rotary_embedding_dim / 2; + const int position_id = (position_ids_format == 0) ? static_cast(position_ids[0]) + s + : static_cast(position_ids[b * sequence_length + s]); + const int cache_offset = position_id * half_rotary_embedding_dim; + const T *cos_data = cos_cache + cache_offset; + const T *sin_data = sin_cache + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + if (interleaved) { + cache_idx = (i / 2) % half_rotary_embedding_dim; + sign = (i % 2 == 0) ? -1 : 1; + j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign + } else { + cache_idx = i % half_rotary_embedding_dim; + sign = (i < half_rotary_embedding_dim) ? -1 : 1; + j = (i + half_rotary_embedding_dim) % rotary_embedding_dim; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; } - template -Status LaunchRotaryEmbeddingKernel( - cudaStream_t stream, - T* output, - const T* input, - const int64_t* position_ids, - const T* cos_cache, - const T* sin_cache, - const int batch_size, - const int sequence_length, - const int num_heads, - const int head_size, - const int max_sequence_length, - const int position_ids_format, - const bool interleaved, - const int max_threads_per_block, - const bool transposed) { - - constexpr int smem_size = 0; - const dim3 grid(num_heads, sequence_length, batch_size); - const dim3 block(head_size, 1, 1); - - // Note: Current implementation assumes head_size <= max_threads_per_block - // because head_size is currently large for LLaMA-2. For smaller head_size - // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` - // instead. This will require kernel changes to support. - - // Default input tensor shape is [batch, seq, hidden_size] - int head_stride = head_size; - int seq_stride = num_heads * head_stride; - int batch_stride = sequence_length * seq_stride; - if (transposed) { - // When transposed, input tensor shape is [batch, num_heads, seq, head_size] - seq_stride = head_size; - head_stride = sequence_length * seq_stride; - batch_stride = num_heads * head_stride; - } - - assert(head_size <= max_threads_per_block); - RotaryEmbeddingBSNH<<>>( - output, input, cos_cache, sin_cache, position_ids, - sequence_length, num_heads, head_size, position_ids_format, interleaved, - batch_stride, seq_stride, head_stride - ); - - return CUDA_CALL(cudaGetLastError()); +Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T *output, const T *input, const int64_t *position_ids, + const T *cos_cache, const T *sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int /*max_sequence_length*/, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, const bool transposed) { + // Note: Current implementation assumes head_size <= max_threads_per_block + // because head_size is currently large for LLaMA-2. For smaller head_size + // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` + // instead. This will require kernel changes to support. + ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block"); + + int tpb = (head_size + 31) / 32 * 32; + + const dim3 block(tpb); + const dim3 grid(sequence_length, batch_size, num_heads); + + // Default input tensor shape is [batch, seq, hidden_size] + int head_stride = head_size; + int seq_stride = num_heads * head_stride; + int batch_stride = sequence_length * seq_stride; + if (transposed) { + // When transposed, input tensor shape is [batch, num_heads, seq, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } + + assert(head_size <= max_threads_per_block); + RotaryEmbeddingBSNH<<>>(output, input, cos_cache, sin_cache, position_ids, sequence_length, + num_heads, head_size, rotary_embedding_dim, position_ids_format, + interleaved, batch_stride, seq_stride, head_stride); + + return CUDA_CALL(cudaGetLastError()); } -template Status LaunchRotaryEmbeddingKernel( - cudaStream_t stream, - float* output, - const float* input, - const int64_t* position_ids, - const float* cos_cache, - const float* sin_cache, - const int batch_size, - const int sequence_length, - const int num_heads, - const int head_size, - const int max_sequence_length, - const int position_ids_format, - const bool interleaved, - const int max_threads_per_block, - const bool transposed); - -template Status LaunchRotaryEmbeddingKernel( - cudaStream_t stream, - half* output, - const half* input, - const int64_t* position_ids, - const half* cos_cache, - const half* sin_cache, - const int batch_size, - const int sequence_length, - const int num_heads, - const int head_size, - const int max_sequence_length, - const int position_ids_format, - const bool interleaved, - const int max_threads_per_block, - const bool transposed); - - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime +template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, float *output, const float *input, + const int64_t *position_ids, const float *cos_cache, + const float *sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, const bool transposed); + +template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, half *output, const half *input, + const int64_t *position_ids, const half *cos_cache, + const half *sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, const bool transposed); + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, BFloat16 *output, const BFloat16 *input, const int64_t *position_ids, + const BFloat16 *cos_cache, const BFloat16 *sin_cache, const int batch_size, const int sequence_length, + const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length, + const int position_ids_format, const bool interleaved, const int max_threads_per_block, const bool transposed); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index ee1ccc43dcbf..36300fe7a660 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -21,6 +21,7 @@ Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu index 8fb6575d27cc..4a4e3eeecf64 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu @@ -53,9 +53,9 @@ class FusedMHARunnerFP16v2::mhaImpl { ~mhaImpl() {} - void setup(const int S, const int B) { + void setup(const int seq_len, const int B) { // For bert and vit, use flash attention when sequence length is larger than the threshold. - use_flash_attention = is_flash_attention(S); + use_flash_attention = is_flash_attention(seq_len); params.force_unroll = use_flash_attention; @@ -68,26 +68,26 @@ class FusedMHARunnerFP16v2::mhaImpl { warps_n = 1; } else { if (sm == 70) { - if (S == 64 || S == 96) { + if (seq_len == 64 || seq_len == 96) { warps_m = 2; warps_n = 2; - } else if (S == 128) { + } else if (seq_len == 128) { warps_m = 1; warps_n = 4; - } else if (S == 256 || S == 384) { + } else if (seq_len == 256 || seq_len == 384) { warps_m = 1; warps_n = 8; } else { ORT_ENFORCE(false, "Unsupported sequence length"); } } else { - if (S == 32 || S == 64 || S == 96 || S == 128) { + if (seq_len == 32 || seq_len == 64 || seq_len == 96 || seq_len == 128) { warps_m = 2; warps_n = 2; - } else if (S == 192 || S == 256) { + } else if (seq_len == 192 || seq_len == 256) { warps_m = 1; warps_n = 4; - } else if (S == 384) { + } else if (seq_len == 384) { warps_m = 1; warps_n = 8; } else { @@ -99,7 +99,7 @@ class FusedMHARunnerFP16v2::mhaImpl { // The number of threads per CTA. threads_per_cta = warps_m * warps_n * warps_k * 32; // The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension. - xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m); + xmmas_m = (seq_len + 16 * warps_m - 1) / (16 * warps_m); const float scale_bmm1 = interface->mScale; const float scale_softmax = 1.f; // Seems to be only required for int8 @@ -111,7 +111,7 @@ class FusedMHARunnerFP16v2::mhaImpl { params.b = B; params.h = interface->mNumHeads; - params.s = S; + params.s = seq_len; params.d = interface->mHeadSize; params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half); @@ -121,7 +121,7 @@ class FusedMHARunnerFP16v2::mhaImpl { has_causal_mask = false; } - void setup_causal_masked_fmha(const int S, const int B) { + void setup_causal_masked_fmha(const int seq_len, const int B) { const float scale_bmm1 = interface->mScale; const float scale_softmax = 1.f; // Seems to be only required for int8 const float scale_bmm2 = 1.f; @@ -132,7 +132,7 @@ class FusedMHARunnerFP16v2::mhaImpl { params.b = B; params.h = interface->mNumHeads; - params.s = S; + params.s = seq_len; params.d = interface->mHeadSize; params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half); @@ -182,30 +182,30 @@ class FusedMHARunnerFP16v2::mhaImpl { return max_seq_len; } - int S = max_seq_len; + int seq_len = max_seq_len; if (max_seq_len <= 32) { - S = (sm == 70) ? 64 : 32; + seq_len = (sm == 70) ? 64 : 32; } else if (max_seq_len <= 64) { - S = 64; + seq_len = 64; } else if (max_seq_len <= 96) { - S = 96; + seq_len = 96; } else if (max_seq_len <= 128) { - S = 128; + seq_len = 128; } else if (max_seq_len <= 192) { - S = (sm == 70) ? 256 : 192; + seq_len = (sm == 70) ? 256 : 192; } else if (max_seq_len <= 256) { - S = 256; + seq_len = 256; } else if (max_seq_len <= 384) { - S = 384; + seq_len = 384; } - return S; + return seq_len; } protected: - bool is_flash_attention(const int S) const { + bool is_flash_attention(const int seq_len) const { ORT_ENFORCE(interface->mHasCausalMask == false); - return interface->mEnableFlashAttention && S >= kMinSequenceLengthFlashAttention; + return interface->mEnableFlashAttention && seq_len >= kMinSequenceLengthFlashAttention; } private: @@ -232,12 +232,12 @@ FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads, pimpl(new mhaImpl(this)) { } -void FusedMHARunnerFP16v2::setup(const int S, const int B) { - MHARunner::setup(S, B); +void FusedMHARunnerFP16v2::setup(const int seq_len, const int B) { + MHARunner::setup(seq_len, B); if (mHasCausalMask) { - pimpl->setup_causal_masked_fmha(S, B); + pimpl->setup_causal_masked_fmha(seq_len, B); } else { - pimpl->setup(S, B); + pimpl->setup(seq_len, B); } } diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 9b989dac9a94..1dbbe8c4e7ea 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS +#include #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" @@ -18,25 +18,18 @@ namespace cuda { #if defined(ORT_USE_NCCL) -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ShardedMoE, \ - kMSDomain, \ - 1, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(0, 0) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ShardedMoE, kMSDomain, 1, T, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ ShardedMoE); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) -using namespace ONNX_NAMESPACE; - template ShardedMoE::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("tensor_shards", &tensor_shards_).IsOK()); ORT_ENFORCE(op_kernel_info.GetAttr("local_experts_start_index", &local_experts_start_index_).IsOK()); rank_to_experts_start_index_.resize(nccl_->Size()); // Initialize rank_to_experts_start_index_[0] to a value to convey that it is not initialized. @@ -57,27 +50,34 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { // Create a {Rank, ExpertsStartIndex} map on Host. AutoDestoryCudaEvent cuda_event; cudaEvent_t& copy_event = cuda_event.Get(); - ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event)); const Tensor* input = context->Input(0); const Tensor* router_probs = context->Input(1); const Tensor* fc1_experts_weights = context->Input(2); - const Tensor* fc2_experts_weights = context->Input(3); - const Tensor* fc1_experts_bias_optional = context->Input(4); + const Tensor* fc1_experts_bias_optional = context->Input(3); + const Tensor* fc2_experts_weights = context->Input(4); const Tensor* fc2_experts_bias_optional = context->Input(5); + const Tensor* fc3_experts_weights_optional = context->Input(6); + const Tensor* fc3_experts_bias_optional = context->Input(7); + + MoEParameters moe_params(tensor_shards_); + MoEQuantType quant_type = MoEQuantType::None; + ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, + fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, + fc3_experts_weights_optional, fc3_experts_bias_optional)); - MoEParameters moe_params; - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights, - fc1_experts_bias_optional, fc2_experts_bias_optional)); - ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, - "num_experts should be divisible by world_size"); + ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); + + if (moe_params.parallel_type == MoEParallelType::EP || moe_params.parallel_type == MoEParallelType::EPAndTP) { + ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event)); + } - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, + normalize_routing_weights_); - size_t ws_size = - moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), - static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), - static_cast(k_)); + size_t ws_size = moe_runner.getWorkspaceSize( + static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), static_cast(k_)); size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT); size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT); @@ -95,54 +95,71 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { IAllocatorUniquePtr expert_for_source_row = IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); - // fc1_scales and fc2_scales are used in quantized MoE - const CudaT* fc1_scales_ptr = nullptr; - const CudaT* fc2_scales_ptr = nullptr; - - moe_runner.run_moe_fc(reinterpret_cast(input->template Data()), - reinterpret_cast(router_probs->template Data()), - reinterpret_cast(fc1_experts_weights->template Data()), - std::move(fc1_scales_ptr), - fc1_experts_bias_optional == nullptr - ? nullptr - : reinterpret_cast(fc1_experts_bias_optional->template Data()), - activation_type_, reinterpret_cast(fc2_experts_weights->template Data()), - std::move(fc2_scales_ptr), static_cast(moe_params.num_rows), - static_cast(moe_params.hidden_size), - static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), - static_cast(moe_params.local_num_experts), static_cast(local_experts_start_index_), - static_cast(k_), reinterpret_cast(work_space.get()), - reinterpret_cast(fc2_output.get()), reinterpret_cast(expert_scales.get()), - reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), - reinterpret_cast(expert_for_source_row.get()), Stream(context)); + const CudaT* fc_scales_ptr = nullptr; + + moe_runner.run_moe_fc( + reinterpret_cast(input->template Data()), + reinterpret_cast(router_probs->template Data()), + reinterpret_cast(fc1_experts_weights->template Data()), std::move(fc_scales_ptr), + fc1_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc1_experts_bias_optional->template Data()), + activation_type_, + fc3_experts_weights_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_weights_optional->template Data()), + std::move(fc_scales_ptr), + fc3_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_bias_optional->template Data()), + reinterpret_cast(fc2_experts_weights->template Data()), std::move(fc_scales_ptr), + static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), + static_cast(moe_params.local_num_experts), static_cast(local_experts_start_index_), + static_cast(k_), reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), Stream(context)); Tensor* output = context->Output(0, input->Shape()); - size_t stride_count = moe_params.hidden_size; - size_t stride_bytes = stride_count * sizeof(CudaT); - int64_t total_past_rows = 0; - int64_t total_covered_rows = 0; - if (copy_event != nullptr) { - CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event)); + if (moe_params.parallel_type == MoEParallelType::None) { + fc2_output_bc = std::move(fc2_output); } - NCCL_RETURN_IF_ERROR(ncclGroupStart()); - for (int rank = 0; rank < nccl_->Size(); ++rank) { - int64_t experts_start_index = rank_to_experts_start_index_[rank]; - moe_runner.get_total_rows_info(experts_start_index, - moe_params.local_num_experts, - total_past_rows, - total_covered_rows); - const char* src = reinterpret_cast(fc2_output.get()) + total_past_rows * stride_bytes; - char* dst = reinterpret_cast(fc2_output_bc.get()) + total_past_rows * stride_bytes; - NCCL_RETURN_IF_ERROR(ncclBroadcast(src, - dst, - total_covered_rows * stride_count, - GetNcclDataType(input->DataType()), - rank, - nccl_->Comm(), - Stream(context))); + + if (moe_params.parallel_type == MoEParallelType::EPAndTP) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expert and Tensor Parallelism is not supported yet"); + } + + if (moe_params.parallel_type == MoEParallelType::TP) { + ORT_ENFORCE(moe_params.tensor_shards == nccl_->Size()); + NCCL_RETURN_IF_ERROR(ncclGroupStart()); + NCCL_RETURN_IF_ERROR(ncclAllReduce(reinterpret_cast(fc2_output.get()), + reinterpret_cast(fc2_output_bc.get()), fc2_output_size / sizeof(CudaT), + GetNcclDataType(input->DataType()), ncclSum, nccl_->Comm(), Stream(context))); + NCCL_RETURN_IF_ERROR(ncclGroupEnd()); + } + + if (moe_params.parallel_type == MoEParallelType::EP) { + size_t stride_count = moe_params.hidden_size; + size_t stride_bytes = stride_count * sizeof(CudaT); + int64_t total_past_rows = 0; + int64_t total_covered_rows = 0; + if (copy_event != nullptr) { + CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event)); + } + NCCL_RETURN_IF_ERROR(ncclGroupStart()); + for (int rank = 0; rank < nccl_->Size(); ++rank) { + int64_t experts_start_index = rank_to_experts_start_index_[rank]; + moe_runner.get_total_rows_info(experts_start_index, moe_params.local_num_experts, total_past_rows, + total_covered_rows); + const char* src = reinterpret_cast(fc2_output.get()) + total_past_rows * stride_bytes; + char* dst = reinterpret_cast(fc2_output_bc.get()) + total_past_rows * stride_bytes; + NCCL_RETURN_IF_ERROR(ncclBroadcast(src, dst, total_covered_rows * stride_count, + GetNcclDataType(input->DataType()), rank, nccl_->Comm(), Stream(context))); + } + NCCL_RETURN_IF_ERROR(ncclGroupEnd()); } - NCCL_RETURN_IF_ERROR(ncclGroupEnd()); ort_fastertransformer::finalize_moe_routing_kernelLauncher( reinterpret_cast(fc2_output_bc.get()), reinterpret_cast(output->template MutableData()), @@ -158,8 +175,7 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { } template -Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, - OpKernelContext* context, +Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, OpKernelContext* context, cudaEvent_t& cuda_event) const { if (rank_to_experts_start_index_[0] != std::numeric_limits::min()) { return Status::OK(); @@ -176,23 +192,16 @@ Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, IAllocator::MakeUniquePtr(allocator, nccl_->Size(), false, stream); // Only happens in the first run. - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(), - &local_experts_start_index_, - IndexTypeSize, - cudaMemcpyHostToDevice, - Stream(context))); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(), &local_experts_start_index_, IndexTypeSize, + cudaMemcpyHostToDevice, Stream(context))); NCCL_RETURN_IF_ERROR(ncclAllGather(reinterpret_cast(experts_start_index_d.get()), - reinterpret_cast(rank_to_experts_start_index_d.get()), - 1, - GetNcclDataType(DataTypeImpl::GetType()), - nccl_->Comm(), + reinterpret_cast(rank_to_experts_start_index_d.get()), 1, + GetNcclDataType(DataTypeImpl::GetType()), nccl_->Comm(), Stream(context))); // The const_cast<> violates the const modifier to make sure the synchronization happens only once per session. CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast(rank_to_experts_start_index_.data()), - rank_to_experts_start_index_d.get(), - nccl_->Size() * IndexTypeSize, - cudaMemcpyDeviceToHost, - Stream(context))); + rank_to_experts_start_index_d.get(), nccl_->Size() * IndexTypeSize, + cudaMemcpyDeviceToHost, Stream(context))); CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(&cuda_event, cudaEventDisableTiming)); CUDA_RETURN_IF_ERROR(cudaEventRecord(cuda_event, Stream(context))); @@ -204,5 +213,3 @@ Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h index cbd483fddab7..827283a794dd 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" @@ -28,6 +26,7 @@ class ShardedMoE final : public NcclKernel, public MoEBase { Status SynchronizeExpertsStartIndex(AllocatorPtr& alloc, OpKernelContext* ctx, cudaEvent_t& cuda_event) const; int64_t local_experts_start_index_; + int64_t tensor_shards_; std::vector rank_to_experts_start_index_; }; @@ -36,5 +35,3 @@ class ShardedMoE final : public NcclKernel, public MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index be7e9f6a8225..583e67b2e6de 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -70,13 +70,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); -#ifdef USE_CUTLASS class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE); -#endif +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QMoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); @@ -97,6 +97,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GemmaRotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); @@ -120,6 +122,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); @@ -167,10 +170,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllR class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll); -#ifdef USE_CUTLASS class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE); -#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul); @@ -204,6 +205,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze); #endif +#ifdef ENABLE_CUDA_NHWC_OPS +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample); +#endif + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -270,13 +275,13 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, -#ifdef USE_CUTLASS BuildKernelCreateInfo, BuildKernelCreateInfo, -#endif + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -297,6 +302,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -320,6 +327,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -373,10 +381,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, -#ifdef USE_CUTLASS BuildKernelCreateInfo, BuildKernelCreateInfo, -#endif BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -410,6 +416,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif +#ifdef ENABLE_CUDA_NHWC_OPS + BuildKernelCreateInfo, +#endif }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 87e88ac31c99..dea5391c7629 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -24,7 +24,8 @@ namespace { template struct DispatchGroupNorm { - Status operator()(cudaStream_t stream, + Status operator()(CudaTuningContext* tuning_ctx, + Stream* ort_stream, Tensor* output, Tensor* add_out, const Tensor* input, @@ -44,7 +45,8 @@ struct DispatchGroupNorm { int channels_per_block) { typedef typename ToCudaType::MappedType CudaT; return LaunchGroupNormKernel( - stream, + tuning_ctx, + ort_stream, reinterpret_cast(output->MutableData()), add_out == nullptr ? nullptr : reinterpret_cast(add_out->MutableData()), reinterpret_cast(input->Data()), @@ -209,7 +211,8 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { context->GetComputeStream()); utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(Stream(context), output, add_out, input, skip, bias, + return dispatcher.InvokeRet(GetTuningContext(), + context->GetComputeStream(), output, add_out, input, skip, bias, gamma, beta, workspace.get(), epsilon_, batch_size, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc new file mode 100644 index 000000000000..5dec69052884 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc @@ -0,0 +1,101 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/diffusion/group_norm_common_base.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +int NextSize(int x) { + for (size_t i = 0; i < kNumOfSizes; ++i) { + if (x <= kSizes[i]) { + return kSizes[i]; + } + } + + return x; +} + +int32_t GetThreadsPerBlock(int32_t channels_per_block, int32_t channels_per_thread) { + return NextSize(channels_per_block) / channels_per_thread; +} + +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { + int32_t max_divisor = -1; + for (int32_t i = 1; i <= std::sqrt(n); i++) { + if (n % i == 0) { + int32_t divisor1 = n / i; + int32_t divisor2 = i; + + if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { + max_divisor = divisor1; + } + if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { + max_divisor = divisor2; + } + } + } + return max_divisor; +} + +// Find proper channels per block based on a cost function: The cost is number of channels corresponding to +// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has +// work to do so it is ideal case. +int FindChannelsPerBlock(int num_channels, int channels_per_group) { + int min_cost = -1; + int best_candidate = -1; + for (size_t i = kNumOfSizes; i > 0; --i) { + if (kSizes[i - 1] < channels_per_group) { + break; + } + + int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; + int blocks = (num_channels + channels_per_block - 1) / channels_per_block; + int cost = blocks * kSizes[i - 1] - num_channels; + if (cost == 0) { + return channels_per_block; + } + + if (min_cost == -1 || cost < min_cost) { + min_cost = cost; + best_candidate = channels_per_block; + } + } + + return best_candidate; +} + +int GetChannelsPerBlock(int num_channels, int num_groups) { + int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_block = channels_per_group; + if (channels_per_group < kMaxSize / 2) { + channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); + } + return channels_per_block; +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h new file mode 100644 index 000000000000..a80584d3293a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h @@ -0,0 +1,186 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include "core/providers/cuda/cuda_common.h" +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. +constexpr static int32_t CHANNELS_PER_THREAD = 2; + +constexpr static int kSizes[] = {128, 256, 320, 384, 512}; +constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; + +int32_t GetThreadsPerBlock(int32_t channels_per_block, int32_t channels_per_thread); + +static inline int32_t DivUp(int32_t m, int32_t n) { + return (m + n - 1) / n; +} + +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor); + +int GetChannelsPerBlock(int num_channels, int num_groups); + +template +struct GroupNormNHWCParams { + // The output buffer. Shape is (n, h, w, c). + T* dst; + + // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). + T* add_out; + + // The input buffer. Shape is (n, h, w, c). + T const* src; + + // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). + T const* skip; + + // Optional input buffer for bias tensor. Shape is (c). + T const* bias; + + // The gamma scaling factor. + float const* gamma; + + // The beta term to add in GN. + float const* beta; + + // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. + float* group_sum_buffer; + + // The number of instances in the batch. + int32_t n; + + // The height and width of each activation map. + int32_t h; + int32_t w; + + // Number of channels. + int32_t c; + + // Number of groups. + int32_t groups; + + // Do we apply the SiLU activation function? + bool use_silu; + + // Precomputed values and parameters to control the execution of the kernels. + + // Number of activations per instance (h * w) + int32_t hw; + + // Number of activations per block + int32_t hw_per_block; + + // Number of channels per block in the C dimension. + int32_t channels_per_block; + + // Number of channels per group in the C dimension. + int32_t channels_per_group; + + // The precomputed stride between instances. + int32_t hwc; + // The inverse of hw*channels_per_group to compute mean of a group. + float inv_hw_channels_per_group; + // The precomputed number of groups per block. + int32_t groups_per_block; + + // Number of threads per block + int32_t threads_per_block; + + // Epsilon to get stable variance in normalization. + float epsilon; + + // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. + bool broadcast_skip; + + // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. + T* skip_workspace; + + GroupNormNHWCParams(T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + float* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + int32_t channels_per_group_in = num_channels / num_groups; + // channels_per_block is computed in PrePack. + // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. + if (channels_per_block < channels_per_group_in) { + channels_per_block = GetChannelsPerBlock(num_channels, num_groups); + } + + this->use_silu = use_silu; + this->dst = output; + this->add_out = add_out; + this->src = input; + this->skip = skip; + this->bias = bias; + this->gamma = gamma; + this->beta = beta; + this->group_sum_buffer = workspace; + this->n = batch_size; + this->h = height; + this->w = width; + this->c = num_channels; + this->groups = num_groups; + this->hw = this->h * this->w; + + // This will allocate as many blocks as possible to partition HW. + // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. + // TODO: tune this logic to find proper blocks when hw is small. + constexpr int32_t max_blocks_per_hw = 1024; + const int32_t blocks_per_hw = FindMaxDivisor(this->hw, max_blocks_per_hw); + this->hw_per_block = DivUp(this->hw, blocks_per_hw); + + this->channels_per_block = channels_per_block; + this->channels_per_group = channels_per_group_in; + this->hwc = this->hw * this->c; + this->inv_hw_channels_per_group = 1.F / (float)(this->hw * this->channels_per_group); + this->groups_per_block = channels_per_block / this->channels_per_group; + this->epsilon = epsilon; + this->broadcast_skip = broadcast_skip; + + // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. + this->skip_workspace = (this->add_out != nullptr) ? this->add_out : this->dst; + + this->threads_per_block = GetThreadsPerBlock(channels_per_block, CHANNELS_PER_THREAD); + } +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 48b161552ce0..4909dc5e3897 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -27,6 +27,8 @@ #include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cuda/diffusion/group_norm_common_base.h" +#include "contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh" using namespace onnxruntime::cuda; @@ -34,329 +36,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -namespace { - -// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. -constexpr static int32_t CHANNELS_PER_THREAD = 2; - -constexpr static int kSizes[] = {128, 256, 320, 384, 512}; -constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); -constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; - -int NextSize(int x) { - for (size_t i = 0; i < kNumOfSizes; ++i) { - if (x <= kSizes[i]) { - return kSizes[i]; - } - } - - return x; -} -} // namespace - -static inline int32_t DivUp(int32_t m, int32_t n) { - return (m + n - 1) / n; -} - -static inline __device__ __host__ float sigmoid(float x) { - return 1.F / (1.F + expf(-x)); -} - -struct GroupSums { - // Is it the 1st element of the group? - int32_t flag; - // The sum. - float sum; - // The sum of squares. - float sum_sq; -}; - -struct GroupSumsOp { - inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { - GroupSums dst; - dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); - dst.flag = a.flag + b.flag; - return dst; - } -}; - -template -struct GroupNormNHWCParams { - // The output buffer. Shape is (n, h, w, c). - T* dst; - - // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). - T* add_out; - - // The input buffer. Shape is (n, h, w, c). - T const* src; - - // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). - T const* skip; - - // Optional input buffer for bias tensor. Shape is (c). - T const* bias; - - // The gamma scaling factor. - float const* gamma; - - // The beta term to add in GN. - float const* beta; - - // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. - float* group_sum_buffer; - - // The number of instances in the batch. - int32_t n; - - // The height and width of each activation map. - int32_t h; - int32_t w; - - // Number of channels. - int32_t c; - - // Number of groups. - int32_t groups; - - // Do we apply the SiLU activation function? - bool use_silu; - - // Precomputed values and parameters to control the execution of the kernels. - - // Number of activations per instance (h * w) - int32_t hw; - - // Number of activations per block - int32_t hw_per_block; - - // Number of channels per block in the C dimension. - int32_t channels_per_block; - - // Number of channels per group in the C dimension. - int32_t channels_per_group; - - // The precomputed stride between instances. - int32_t hwc; - // The inverse of hw*channels_per_group to compute mean of a group. - float inv_hw_channels_per_group; - // The precomputed number of groups per block. - int32_t groups_per_block; - - // Number of threads per block - int32_t threads_per_block; - - // Epsilon to get stable variance in normalization. - float epsilon; - - // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. - bool broadcast_skip; - - // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. - T* skip_workspace; -}; - -template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); - -template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { - // Fetch two channels per thread. - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - - float2 f2 = __half22float2(h2); - - // Update the sum. - sum += f2.x + f2.y; - - // Update the sum of squares. - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { - // Fetch two channels per thread. - float2 f2 = *reinterpret_cast(&src[offset]); - - // Update the sum. - sum += f2.x + f2.y; - - // Update the sum of squares. - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] -template -inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); - -template <> -inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { - // Fetch two channels per thread. - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); - __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); - h2 = h2 + b; - h2 = h2 + s; - - *reinterpret_cast<__half2*>(&add_out[offset]) = h2; - - float2 f2 = __half22float2(h2); - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template <> -inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { - float2 f2 = *reinterpret_cast(&src[offset]); - float2 s = *reinterpret_cast(&skip[skip_offset]); - float2 b = *reinterpret_cast(&bias[bias_offset]); - f2.x += s.x + b.x; - f2.y += s.y + b.y; - - *reinterpret_cast(&add_out[offset]) = f2; - - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] -template -inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); - -template <> -inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); - h2 = h2 + s; - - *reinterpret_cast<__half2*>(&add_out[offset]) = h2; - - float2 f2 = __half22float2(h2); - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template <> -inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { - float2 f2 = *reinterpret_cast(&src[offset]); - float2 s = *reinterpret_cast(&skip[skip_offset]); - f2.x += s.x; - f2.y += s.y; - *reinterpret_cast(&add_out[offset]) = f2; - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template -__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { - // The object in charge of doing the sums for the different blocks. - typedef cub::BlockScan BlockScan; - - // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage temp_storage; - - // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. - __shared__ float2 smem[THREADS_PER_BLOCK]; - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The channel loaded by that thread. - int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; - - if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { - return; - } - - // The first activation loaded by that block. - int32_t hw_begin = blockIdx.y * params.hw_per_block; - // The last activation loaded by that block. - int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - - // The sums. - float sum = 0.F; - float sum_sq = 0.F; - - // Iterate over the activations to compute the sums. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - if (params.skip != nullptr) { - // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) - const int64_t bias_offset = static_cast(ci); - T* add_out = params.skip_workspace; - if (params.broadcast_skip) { - const int64_t skip_offset = static_cast(ni) * params.c + ci; - - if (params.bias != nullptr) { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); - } - } else { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); - } - } - } else { - if (params.bias != nullptr) { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); - } - } else { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); - } - } - } - } else { // GroupNorm - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - UpdateSum(params.src, offset, sum, sum_sq); - } - } - - // The group index relative to the first group within the same block. - int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; - // The channel in the group. - int32_t cj = ci % params.channels_per_group; - - // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; - - // Do the segmented scan. InclusiveScan is not deterministic. - GroupSums out; - BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); - - // Store the results for the groups in shared memory (to produce coalesced stores later). - // For each group, only the last thread of that group is picked to save sum to shared memory. - if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { - smem[gi] = make_float2(out.sum, out.sum_sq); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groups_per_block) { - return; - } - - // The global group index. - // Use neighboring threads for coalesced write. - int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; - - if (gj < params.groups) { - float2 sums = smem[threadIdx.x]; - const int index = (2 * ni) * params.groups + gj; - atomicAdd(¶ms.group_sum_buffer[index], sums.x); - atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); - } -} - template void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; @@ -370,119 +49,26 @@ void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) // The number of instances. grid.z = params.n; +#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ + GroupNormNHWCSumKernel \ + <<>>( \ + params.skip_workspace, params.group_sum_buffer, params.src, params.skip, params.bias, \ + params.channels_per_block, params.hw_per_block, params.hw, params.hwc, params.c, \ + params.channels_per_group, params.groups, params.groups_per_block, params.broadcast_skip); \ + break; + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. switch (params.threads_per_block) { case 256: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) case 192: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) case 160: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) case 128: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) case 64: - GroupNormNHWCSumKernel<<>>(params); - break; - } -} - -template -__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, - float2& gamma_f2, float2& beta_f2, bool silu); - -template <> -__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, - float2& gamma_f2, float2& beta_f2, bool silu) { - // Fetch two channels per thread. - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - - // Extract the two half values. - float2 f2 = __half22float2(h2); - - // Normalize the channels. - f2.x = (f2.x - mean) * inv_std_dev; - f2.y = (f2.y - mean) * inv_std_dev; - - // Scale by gamma and add beta. - f2.x = gamma_f2.x * f2.x + beta_f2.x; - f2.y = gamma_f2.y * f2.y + beta_f2.y; - - // Apply SiLU activation if needed. - if (silu) { - f2.x = f2.x * sigmoid(f2.x); - f2.y = f2.y * sigmoid(f2.y); - } - - *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); -} - -template <> -__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, - float2& gamma_f2, float2& beta_f2, bool silu) { - // Fetch two channels per thread. - float2 f2 = *reinterpret_cast(&src[offset]); - - // Normalize the channels. - f2.x = (f2.x - mean) * inv_std_dev; - f2.y = (f2.y - mean) * inv_std_dev; - - // Scale by gamma and add beta. - f2.x = gamma_f2.x * f2.x + beta_f2.x; - f2.y = gamma_f2.y * f2.y + beta_f2.y; - - // Apply SiLU activation if needed. - if (silu) { - f2.x = f2.x * sigmoid(f2.x); - f2.y = f2.y * sigmoid(f2.y); - } - - *reinterpret_cast(&dst[offset]) = f2; -} - -template -__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { - // The channel loaded by that thread. - int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; - if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { - return; - } - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The group that thread works on. - int32_t gi = ci / params.channels_per_group; - - // Load the sum and sum of squares for the group. - float sum = 0.F, sum_sq = 0.F; - if (gi < params.groups) { - const int index = (2 * ni) * params.groups + gi; - sum = params.group_sum_buffer[index]; - sum_sq = params.group_sum_buffer[index + params.groups]; - } - - // Load gamma/beta. Fetch two per thread. - float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); - - // Compute the mean. - float mean = sum * params.inv_hw_channels_per_group; - // Compute the variance. - float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); - // Compute the inverse of the stddev. - float inv_std_dev = rsqrtf(var + params.epsilon); - - int32_t hw_begin = blockIdx.y * params.hw_per_block; - int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - - const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; - int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); + LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) } } @@ -497,83 +83,34 @@ void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t strea // The number of instances. grid.z = params.n; +#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ + GroupNormNHWCScaleKernel \ + <<>>( \ + params.dst, params.src, params.skip, params.gamma, params.beta, params.skip_workspace, \ + params.group_sum_buffer, params.epsilon, params.c, params.channels_per_block, params.channels_per_group, \ + params.groups, params.hwc, params.inv_hw_channels_per_group, params.hw, params.hw_per_block, \ + params.use_silu); \ + break; + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. switch (params.threads_per_block) { case 256: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) case 192: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) case 160: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) case 128: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) case 64: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) } } -int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { - int32_t max_divisor = -1; - for (int32_t i = 1; i <= std::sqrt(n); i++) { - if (n % i == 0) { - int32_t divisor1 = n / i; - int32_t divisor2 = i; - - if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { - max_divisor = divisor1; - } - if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { - max_divisor = divisor2; - } - } - } - return max_divisor; -} - -// Find proper channels per block based on a cost function: The cost is number of channels corresponding to -// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has -// work to do so it is ideal case. -int FindChannelsPerBlock(int num_channels, int channels_per_group) { - int min_cost = -1; - int best_candidate = -1; - for (size_t i = kNumOfSizes; i > 0; --i) { - if (kSizes[i - 1] < channels_per_group) { - break; - } - - int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; - int blocks = (num_channels + channels_per_block - 1) / channels_per_block; - int cost = blocks * kSizes[i - 1] - num_channels; - if (cost == 0) { - return channels_per_block; - } - - if (min_cost == -1 || cost < min_cost) { - min_cost = cost; - best_candidate = channels_per_block; - } - } - - return best_candidate; -} - -int GetChannelsPerBlock(int num_channels, int num_groups) { - int32_t channels_per_group = num_channels / num_groups; - int32_t channels_per_block = channels_per_group; - if (channels_per_group < kMaxSize / 2) { - channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); - } - return channels_per_block; -} - template Status LaunchGroupNormKernel( - cudaStream_t stream, + CudaTuningContext* tuning_ctx, + Stream* ort_stream, T* output, T* add_out, const T* input, @@ -591,19 +128,17 @@ Status LaunchGroupNormKernel( bool use_silu, bool broadcast_skip, int channels_per_block) { - GroupNormNHWCParams params; - int32_t channels_per_group = num_channels / num_groups; - // channels_per_block is computed in PrePack. - // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. - if (channels_per_block < channels_per_group) { - channels_per_block = GetChannelsPerBlock(num_channels, num_groups); - } + // tuning_ctx only used for ROCm EP. + ORT_UNUSED_PARAMETER(tuning_ctx); - // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases - if (channels_per_block % channels_per_group != 0 || - channels_per_block > kMaxSize || - (channels_per_group % CHANNELS_PER_THREAD != 0)) { + GroupNormNHWCParams params(output, add_out, input, skip, bias, gamma, beta, reinterpret_cast(workspace), epsilon, + batch_size, num_channels, height, width, num_groups, use_silu, + broadcast_skip, channels_per_block); + + if (params.channels_per_block % params.channels_per_group != 0 || + params.channels_per_block > kMaxSize || + (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "GroupNorm in CUDA does not support the input: n=", batch_size, " h=", height, @@ -612,42 +147,7 @@ Status LaunchGroupNormKernel( " groups=", num_groups); } - params.use_silu = use_silu; - params.dst = output; - params.add_out = add_out; - params.src = input; - params.skip = skip; - params.bias = bias; - params.gamma = gamma; - params.beta = beta; - params.group_sum_buffer = reinterpret_cast(workspace); - params.n = batch_size; - params.h = height; - params.w = width; - params.c = num_channels; - params.groups = num_groups; - params.hw = params.h * params.w; - - // This will allocate as many blocks as possible to partition HW. - // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. - // TODO: tune this logic to find proper blocks when hw is small. - constexpr int32_t max_blocks_per_hw = 1024; - const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw); - params.hw_per_block = DivUp(params.hw, blocks_per_hw); - - params.channels_per_block = channels_per_block; - params.channels_per_group = channels_per_group; - params.hwc = params.hw * params.c; - params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group); - params.groups_per_block = channels_per_block / params.channels_per_group; - params.epsilon = epsilon; - params.broadcast_skip = broadcast_skip; - - // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. - params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst; - - params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD; - + auto stream = static_cast(ort_stream->GetHandle()); CUDA_RETURN_IF_ERROR(cudaMemsetAsync( params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); @@ -663,14 +163,14 @@ Status LaunchGroupNormKernel( return Status::OK(); } -template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, half* add_out, +template Status LaunchGroupNormKernel(CudaTuningContext* tuning_ctx, Stream* stream, half* output, half* add_out, const half* input, const half* skip, const half* bias, const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, bool silu, bool broadcast_skip, int channels_per_block); -template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, float* add_out, +template Status LaunchGroupNormKernel(CudaTuningContext* tuning_ctx, Stream* stream, float* output, float* add_out, const float* input, const float* skip, const float* bias, const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h index 9532aeecb2f5..98f38a1475ee 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -8,6 +8,8 @@ #include #include +#include "core/providers/cuda/tunable/cuda_tunable.h" + namespace onnxruntime { namespace contrib { namespace cuda { @@ -21,7 +23,8 @@ int GetChannelsPerBlock(int num_channels, int num_groups); template Status LaunchGroupNormKernel( - cudaStream_t stream, + CudaTuningContext* tuning_ctx, + Stream* ort_stream, T* output, // normalized output tensor. Shape is (n, h, w, c) T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c) const T* input, // input tensor. Shape is (n, h, w, c) diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh new file mode 100644 index 000000000000..ecd06315e370 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh @@ -0,0 +1,451 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +static inline __device__ __host__ float sigmoid(float x) { + return 1.F / (1.F + expf(-x)); +} + +struct GroupSums { + // Is it the 1st element of the group? + int32_t flag; + // The sum. + float sum; + // The sum of squares. + float sum_sq; +}; + +struct GroupSumsOp { + inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { + GroupSums dst; + dst.sum = b.flag ? b.sum : (a.sum + b.sum); + dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); + dst.flag = a.flag + b.flag; + return dst; + } +}; + +template +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq) { + using VecT = onnxruntime::cuda::aligned_vector; + const VecT input_v = *reinterpret_cast(src + offset); + +#pragma unroll + for (int i = 0; i < ILP; i++) { + const float val = static_cast(input_v.val[i]); + sum += val; + sum_sq += val * val; + } +} + +template <> +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + float2 f2 = __half22float2(h2); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] +template +inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + using VecT = onnxruntime::cuda::aligned_vector; + const VecT input_v = *reinterpret_cast(src + offset); + const VecT skip_v = *reinterpret_cast(skip + skip_offset); + const VecT bias_v = *reinterpret_cast(bias + bias_offset); + VecT output_v = *reinterpret_cast(add_out + offset); + +#pragma unroll + for (int i = 0; i < ILP; i++) { + output_v.val[i] = input_v.val[i] + skip_v.val[i] + bias_v.val[i]; + const float val = static_cast(output_v.val[i]); + sum += val; + sum_sq += val * val; + } + *(reinterpret_cast(add_out + offset)) = output_v; +} + +template <> +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); + h2 = h2 + b; + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + float2 b = *reinterpret_cast(&bias[bias_offset]); + f2.x += s.x + b.x; + f2.y += s.y + b.y; + + *reinterpret_cast(&add_out[offset]) = f2; + + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] +template +inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + using VecT = onnxruntime::cuda::aligned_vector; + const VecT input_v = *reinterpret_cast(src + offset); + const VecT skip_v = *reinterpret_cast(skip + skip_offset); + VecT output_v = *reinterpret_cast(add_out + offset); + +#pragma unroll + for (int i = 0; i < ILP; i++) { + output_v.val[i] = input_v.val[i] + skip_v.val[i]; + const float val = static_cast(output_v.val[i]); + sum += val; + sum_sq += val * val; + } + *(reinterpret_cast(add_out + offset)) = output_v; +} + +template <> +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + f2.x += s.x; + f2.y += s.y; + *reinterpret_cast(&add_out[offset]) = f2; + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template +__global__ void GroupNormNHWCSumKernel(T* skip_workspace, float* group_sum_buffer, const T* src, const T* skip, const T* bias, + int32_t channels_per_block, int32_t hw_per_block, int32_t hw, int32_t hwc, int32_t c, + int32_t channels_per_group, int32_t groups, int32_t groups_per_block, bool broadcast_skip) { + // The object in charge of doing the sums for the different blocks. + typedef cub::BlockScan BlockScan; + + // Allocate shared memory for BlockScan. + __shared__ typename BlockScan::TempStorage temp_storage; + + // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. + __shared__ float2 smem[THREADS_PER_BLOCK]; + + // The instance in the batch. + int32_t ni = blockIdx.z; + + // The channel loaded by that thread. + int32_t ci = blockIdx.x * channels_per_block + threadIdx.x * ILP; + + if (ci >= c || threadIdx.x * ILP >= channels_per_block) { + return; + } + + // The first activation loaded by that block. + int32_t hw_begin = blockIdx.y * hw_per_block; + // The last activation loaded by that block. + int32_t hw_end = min(hw_begin + hw_per_block, hw); + + // The sums. + float sum = 0.F; + float sum_sq = 0.F; + + // Iterate over the activations to compute the sums. + int64_t offset = static_cast(ni) * hwc + static_cast(hw_begin) * c + ci; + if (skip != nullptr) { + // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) + const int64_t bias_offset = static_cast(ci); + T* add_out = skip_workspace; + if (broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * c + ci; + + if (bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + AddSkipBias(add_out, src, skip, bias, offset, skip_offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + AddSkip(add_out, src, skip, offset, skip_offset, sum, sum_sq); + } + } + } else { + if (bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + AddSkipBias(add_out, src, skip, bias, offset, offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + AddSkip(add_out, src, skip, offset, offset, sum, sum_sq); + } + } + } + } else { // GroupNorm + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + UpdateSum(src, offset, sum, sum_sq); + } + } + + // The group index relative to the first group within the same block. + int32_t gi = threadIdx.x * ILP / channels_per_group; + // The channel in the group. + int32_t cj = ci % channels_per_group; + + // The data for the summations. + GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; + + // Do the segmented scan. InclusiveScan is not deterministic. + GroupSums out; + BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); + + // Store the results for the groups in shared memory (to produce coalesced stores later). + // For each group, only the last thread of that group is picked to save sum to shared memory. + if (cj == channels_per_group - ILP) { + smem[gi] = make_float2(out.sum, out.sum_sq); + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Threads that have nothing left to do, exit. + if (threadIdx.x >= groups_per_block) { + return; + } + + // The global group index. + // Use neighboring threads for coalesced write. + int32_t gj = blockIdx.x * groups_per_block + threadIdx.x; + + if (gj < groups) { + float2 sums = smem[threadIdx.x]; + const int index = (2 * ni) * groups + gj; + atomicAdd(&group_sum_buffer[index], sums.x); + atomicAdd(&group_sum_buffer[index + groups], sums.y); + } +} + +template +__device__ void computeGroupNormVec(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma_v, const float* beta_v, bool silu) { + using VecT = onnxruntime::cuda::aligned_vector; + const VecT input_v = *reinterpret_cast(src + offset); + VecT output_v; + +#pragma unroll + for (int i = 0; i < ILP; i++) { + float val = static_cast(input_v.val[i]); + val = (val - mean) * inv_std_dev; + val = gamma_v[i] * val + beta_v[i]; + + if (silu) { + val = val * sigmoid(val); + } + output_v.val[i] = static_cast(val); + } + *(reinterpret_cast(dst + offset)) = output_v; +} + +template +__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu); + +template <> +__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + // Extract the two half values. + float2 f2 = __half22float2(h2); + + // Normalize the channels. + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; + + // Scale by gamma and add beta. + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; + + // Apply SiLU activation if needed. + if (silu) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); +} + +template <> +__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Normalize the channels. + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; + + // Scale by gamma and add beta. + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; + + // Apply SiLU activation if needed. + if (silu) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast(&dst[offset]) = f2; +} + +template +__device__ void ComputeGroupNormKernel(const T* input, T* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) { + using VecF = onnxruntime::cuda::aligned_vector; + + const VecF gamma_v = *reinterpret_cast(gamma + ci); + const VecF beta_v = *reinterpret_cast(beta + ci); + // Iterate over the activations to compute the sums. + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + // Fetch ILP channels per thread. + computeGroupNormVec(input, dst, offset, mean, inv_std_dev, gamma_v.val, beta_v.val, use_silu); + } +} + +template <> +__device__ void ComputeGroupNormKernel(const float* input, float* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) { + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(&gamma[ci]); + float2 beta_f2 = *reinterpret_cast(&beta[ci]); + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + ComputeGroupNorm(input, dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, use_silu); + } +} + +template <> +__device__ void ComputeGroupNormKernel(const half* input, half* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) { + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(&gamma[ci]); + float2 beta_f2 = *reinterpret_cast(&beta[ci]); + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + ComputeGroupNorm(input, dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, use_silu); + } +} + +template +__global__ void GroupNormNHWCScaleKernel(T* dst, const T* src, const T* skip, const float* gamma, const float* beta, + const T* skip_workspace, const float* group_sum_buffer, float epsilon, + int32_t c, int32_t channels_per_block, int32_t channels_per_group, + int32_t groups, int32_t hwc, float inv_hw_channels_per_group, + int32_t hw, int32_t hw_per_block, bool use_silu) { + // The channel loaded by that thread. + int32_t ci = blockIdx.x * channels_per_block + threadIdx.x * ILP; + if (ci >= c || threadIdx.x * ILP >= channels_per_block) { + return; + } + + // The instance in the batch. + int32_t ni = blockIdx.z; + + // The group that thread works on. + int32_t gi = ci / channels_per_group; + + // Load the sum and sum of squares for the group. + float sum = 0.F, sum_sq = 0.F; + if (gi < groups) { + const int index = (2 * ni) * groups + gi; + sum = group_sum_buffer[index]; + sum_sq = group_sum_buffer[index + groups]; + } + + // Compute the mean. + float mean = sum * inv_hw_channels_per_group; + // Compute the variance. + float var = sum_sq * inv_hw_channels_per_group - (mean * mean); + // Compute the inverse of the stddev. + float inv_std_dev = rsqrtf(var + epsilon); + + int32_t hw_begin = blockIdx.y * hw_per_block; + int32_t hw_end = min(hw_begin + hw_per_block, hw); + + const T* input = (skip != nullptr) ? skip_workspace : src; + int64_t offset = static_cast(ni) * hwc + static_cast(hw_begin) * c + ci; + ComputeGroupNormKernel(input, dst, offset, mean, inv_std_dev, gamma, beta, use_silu, c, ci, hw_begin, hw_end); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.cc b/onnxruntime/contrib_ops/cuda/grid_sample.cc index 4c2999c279e0..2500de39d353 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample.cc +++ b/onnxruntime/contrib_ops/cuda/grid_sample.cc @@ -9,22 +9,23 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ +#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ GridSample, \ - kMSDomain, \ - 1, \ + DOMAIN, \ + VERSION, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - GridSample); + onnxruntime::contrib::cuda::GridSample); -REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain) +REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain) -template -GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { +template +GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros"); align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0)); @@ -48,8 +49,8 @@ GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { } } -template -Status GridSample::ComputeInternal(OpKernelContext* context) const { +template +Status GridSample::ComputeInternal(OpKernelContext* context) const { const Tensor* X = context->Input(0); const auto& dims_input = X->Shape().GetDims(); const Tensor* Grid = context->Input(1); @@ -61,11 +62,13 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]); ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2"); + using Ch = Channels; + TensorShapeVector dims_output(4); - dims_output[0] = dims_input[0]; - dims_output[1] = dims_input[1]; - dims_output[2] = dims_grid[1]; - dims_output[3] = dims_grid[2]; + dims_output[Ch::N] = dims_input[Ch::N]; + dims_output[Ch::C] = dims_input[Ch::C]; + dims_output[Ch::H] = dims_grid[1 /* Grid::H */]; + dims_output[Ch::W] = dims_grid[2 /* Grid::W */]; Tensor* Y = context->Output(0, dims_output); // Return early if the output tensor is going to be of size 0 if (Y->Shape().Size() == 0) { @@ -74,7 +77,7 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; CudaT* Y_data = reinterpret_cast(Y->MutableData()); - GridSampleImpl( + GridSampleImpl( Stream(context), reinterpret_cast(X->Data()), reinterpret_cast(Grid->Data()), @@ -89,4 +92,8 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { } } // namespace cuda } // namespace contrib + +namespace cuda { +REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain) +} // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.h b/onnxruntime/contrib_ops/cuda/grid_sample.h index 08ca58c7cc45..16581bfe7748 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample.h +++ b/onnxruntime/contrib_ops/cuda/grid_sample.h @@ -12,7 +12,7 @@ namespace cuda { using namespace onnxruntime::cuda; -template +template class GridSample final : public CudaKernel { public: explicit GridSample(const OpKernelInfo& info); diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu index 8a391eca7e86..b23da635bc83 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu @@ -50,28 +50,34 @@ __device__ T GsReflect(T x, float x_min, float x_max) { return static_cast(fx); } -template +template __device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x, - int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) { + int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) { T pixel = 0.0f; + + auto PixelOffset = [bIdx, cIdx, C, H, W](int64_t x, int64_t y) -> int64_t { + return Layout == LAYOUT_NCHW + ? (bIdx * C * H * W + cIdx * H * W + y * W + x) + : (bIdx * H * W * C + y * W * C + x * C + cIdx); + }; + if (padding_mode == 0) { // zeros if (x >= 0 && x < W && y >= 0 && y < H) { - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + pixel = input_data[PixelOffset(x, y)]; } - } else if (padding_mode == 1) { //border + } else if (padding_mode == 1) { // border x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x)); y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y)); - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + pixel = input_data[PixelOffset(x, y)]; } else { // Reflection - x = (int64_t) GsReflect(x, border[0], border[2]); - y = (int64_t) GsReflect(y, border[1], border[3]); - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + x = (int64_t)GsReflect(x, border[0], border[2]); + y = (int64_t)GsReflect(y, border[1], border[3]); + pixel = input_data[PixelOffset(x, y)]; } return pixel; } -__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) -{ +__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) { float cubic_alpha = -0.75f; x = abs(x); coeffs[0] = (((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha); @@ -93,7 +99,7 @@ __device__ T GsBicubicInterpolate(T p[4][4], float x, float y) { return pixel; } -template +template __global__ void _GridSampleKernel( const T* input_data, const T* grid_data, @@ -110,16 +116,32 @@ __global__ void _GridSampleKernel( { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out); // extract batch index, channel index, y index, x index for current thread - int BIdx = idx / (C * H_out * W_out ); - int tmpBCnt = BIdx * (C * H_out * W_out); + int BIdx, yIdx, xIdx, cIdx; + if constexpr (Layout == LAYOUT_NCHW) { + BIdx = idx / (C * H_out * W_out); + int tmpBCnt = BIdx * (C * H_out * W_out); + + cIdx = (idx - tmpBCnt) / (H_out * W_out); + int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out); - int cIdx = (idx - tmpBCnt) / (H_out * W_out); - int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out); + yIdx = (idx - tmpCCnt) / W_out; + int tmpHCnt = tmpCCnt + yIdx * W_out; - int yIdx = (idx - tmpCCnt) / W_out; - int tmpHCnt = tmpCCnt + yIdx * W_out; + xIdx = (idx - tmpHCnt); + } else { + static_assert(Layout == LAYOUT_NHWC, "Unsupported layout"); - int xIdx = (idx - tmpHCnt); + BIdx = idx / (H_out * W_out * C); + int tmpBCnt = BIdx * (H_out * W_out * C); + + yIdx = (idx - tmpBCnt) / (W_out * C); + int tmpHCnt = tmpBCnt + yIdx * (W_out * C); + + xIdx = (idx - tmpHCnt) / C; + int tmpWCnt = tmpHCnt + xIdx * C; + + cIdx = (idx - tmpWCnt); + } int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx; T grid_X = grid_data[grid_idx * 2 + 0]; @@ -147,8 +169,9 @@ __global__ void _GridSampleKernel( if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max || grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound if (padding_mode == 1) { // border - grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); - grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); + // Clamping must not be done here, see #10607 + // grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); + // grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); } else if (padding_mode == 2) { // reflection grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max); grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max); @@ -175,10 +198,10 @@ __global__ void _GridSampleKernel( w_lb = w_b * w_l; w_rb = w_b * w_r; - T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border); - T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); - T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); - T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); + T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border); + T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); + T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); + T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v; output_data[outIdx] = interpoV; return; @@ -186,7 +209,8 @@ __global__ void _GridSampleKernel( if (mode == 1) { // nearest int x_n = grid_x_imgSpace; int y_n = grid_y_imgSpace; - output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); + output_data[outIdx] = + PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); return; } if (mode == 2) { // bicubic @@ -195,7 +219,8 @@ __global__ void _GridSampleKernel( T p[4][4] = {}; // [H][W] for (int64_t h = 0; h < 4; h++) { for (int64_t w = 0; w < 4; w++) { - p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); + p[h][w] = + PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); } } T dx = grid_x_imgSpace - x0 - 1; @@ -204,7 +229,7 @@ __global__ void _GridSampleKernel( } } -template +template void GridSampleImpl( cudaStream_t stream, const T* input_data, @@ -216,17 +241,23 @@ void GridSampleImpl( const int64_t H_out, const int64_t W_out, T* output_data) { - int blocksPerGrid = (int)(ceil(static_cast(dims[0] * dims[1] * H_out * W_out) / GridDim::maxThreadsPerBlock)); - _GridSampleKernel<<>>( - input_data, grid_data, mode, padding_mode, align_corners, dims[0], dims[1], dims[2], dims[3], H_out, W_out, output_data); + using Ch = Channels; + + int blocksPerGrid = static_cast( + ceil(static_cast(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock)); + _GridSampleKernel<<>>( + input_data, grid_data, mode, padding_mode, align_corners, + dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W], + H_out, W_out, output_data); } -#define SPECIALIZED_IMPL(T) \ - template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \ - const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \ - const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data); +#define SPECIALIZED_IMPL(T, IsNHWC) \ + template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \ + const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \ + const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data); -SPECIALIZED_IMPL(float) +SPECIALIZED_IMPL(float, false) // NCHW +SPECIALIZED_IMPL(float, true) // NHWC } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h index 6df86ce16190..62cd66a48fa8 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h @@ -8,7 +8,7 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template void GridSampleImpl( cudaStream_t stream, const T* input_data, diff --git a/onnxruntime/contrib_ops/cuda/inverse.cc b/onnxruntime/contrib_ops/cuda/inverse.cc index 81e161e60642..9075dda26f86 100644 --- a/onnxruntime/contrib_ops/cuda/inverse.cc +++ b/onnxruntime/contrib_ops/cuda/inverse.cc @@ -78,9 +78,9 @@ struct Inverse::ComputeImpl { cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; // Make a copy of the input which will serve as a workspace as well. - if (std::is_same::value || std::is_same::value) { + if constexpr (std::is_same::value || std::is_same::value) { IAllocatorUniquePtr input_workspace = inst->GetScratchBuffer(input_count, ort_stream); - if (std::is_same::value) { + if constexpr (std::is_same::value) { // Convert from MLFloat16(half) to float Impl_Cast(stream, reinterpret_cast(input.Data()), input_workspace.get(), input_count); } else { @@ -96,7 +96,7 @@ struct Inverse::ComputeImpl { // Need to compute ptrs for output buffers // Output for MLFloat IAllocatorUniquePtr output_ptrs = inst->GetScratchBuffer(n_batches, ort_stream); - if (std::is_same::value) { + if constexpr (std::is_same::value) { IAllocatorUniquePtr ml_float_output = inst->GetScratchBuffer(input_count, ort_stream); ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(stream, ml_float_output.get(), num_batches, rows, output_ptrs)); // Do the inverse @@ -112,7 +112,7 @@ struct Inverse::ComputeImpl { ORT_RETURN_IF_ERROR(CheckForSingularity(stream, info, info_cpu, num_batches)); // We are done here } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { IAllocatorUniquePtr input_workspace = inst->GetScratchBuffer(static_cast(input_count), ort_stream); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_workspace.get(), input.Data(), sizeof(double) * input_count, cudaMemcpyDeviceToDevice, stream)); diff --git a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu index ca94477114ee..47a64502b348 100644 --- a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu +++ b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu @@ -97,8 +97,8 @@ void ComplexMul_Impl( const TArray* rhs_padded_strides, const T* rhs_data, const TArray* fdm_output_strides, - const onnxruntime::cuda::fast_divmod& fdm_H, - const onnxruntime::cuda::fast_divmod& fdm_C, + const onnxruntime::cuda::fast_divmod& /*fdm_H*/, + const onnxruntime::cuda::fast_divmod& /*fdm_C*/, T* output_data, int64_t count, int64_t lhs_size, diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 064b6dd39243..28ab27ee33d1 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -174,7 +174,7 @@ Status GemmFloat8::ComputeGemm( int32_t dtype_A, int32_t dtype_B, int32_t dtype_C, int32_t dtype_Y, const TensorShape& shape_A, const TensorShape& shape_B, - const TensorShape& shape_C, const TensorShape& shape_Y, + const TensorShape& shape_C, const TensorShape& /*shape_Y*/, bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b, const void* p_input_c, const void* p_scale_a, const void* p_scale_b, const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h new file mode 100644 index 000000000000..07c38c58e446 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates exposing architecture support for multiply-add operations +*/ + +#pragma once +#include "contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +// Tag which triggers MMA which will trigger +struct OpMultiplyAddDequantizeInterleavedBToA; + +/* + Below we have extra tags to signal what kind of dequantization we want to do + (per col, scale only fine grained, finegrained with zero). This still lets us + the existing template infrastructure (incl. that in CUTLASS). However, we + split out the template below into OpMultiplyAddDequantizeInterleavedBToA along + with the quantization op before instantiating the GEMM pieces. + + Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of + code we need to duplicate. + */ +struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; + +// The default just forwards the original operator +template +struct TagOperator { + using TaggedOperator = MmaOp; +}; + +// Specializations below attach more information to the operator +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +}; + +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +}; + +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; +}; + +// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original +// operator + the extra information. If no extra info was tagged, the dequant op per column scaling +// as a default. +template +struct DetagOperator { + using Operator = TaggedMmaOp; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +}; + +} // namespace arch +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h similarity index 62% rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h index 9b97690fe70f..99cbe4a66049 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h @@ -13,9 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifdef USE_CUTLASS - #pragma once #include @@ -29,19 +26,22 @@ namespace ort_fastertransformer { template inline int compute_occupancy_for_kernel() { - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage)); if (smem_size > (48 << 10)) { - cudaError_t status = - cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - if (status == cudaError::cudaErrorInvalidValue) { - // Clear the error bit since we can ignore this. - // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an - // occupancy of 0. This will cause the heuristic to ignore this configuration. - status = cudaGetLastError(); + cudaFuncAttributes attr; + int device = 0; + int max_smem_per_block = 0; + CUDA_CALL_THROW(cudaGetDevice(&device)); + CUDA_CALL_THROW(cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::Kernel)); + if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) + // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this + // configuration. return 0; } - CUDA_CALL_THROW(status); } int max_active_blocks = -1; @@ -52,5 +52,3 @@ inline int compute_occupancy_for_kernel() { } } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h similarity index 58% rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h index 617f9992d180..da8cb6d294ef 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h @@ -28,56 +28,68 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ - -#ifdef USE_CUTLASS - /*! \file - \brief Scheduler for grouped GEMM + \brief Functor performing linear combination with a maximum operation used by epilogues. */ #pragma once +#include "cutlass/array.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -#include "cutlass/matrix_coord.h" - -#include "moe_problem_visitor.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { -namespace gemm { -namespace kernel { +namespace epilogue { +namespace thread { -/// Visitor class to abstract away the algorithm for iterating over tiles -template -struct GemmMoeProblemVisitor - : public MoeProblemVisitor, ThreadblockShape, - GroupScheduleMode_, PrefetchTileCount, ThreadCount> { - static bool const kTransposed = Transposed; +///////////////////////////////////////////////////////////////////////////////////////////////// - using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; - using Base = - MoeProblemVisitor; - using Params = typename Base::Params; - using SharedStorage = typename Base::SharedStorage; +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} - // - // Methods - // - CUTLASS_DEVICE - GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) - : Base(params_, shared_storage_, block_idx) {} -}; +__forceinline__ __device__ float tanh_opt(float x) { +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + float const exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} ///////////////////////////////////////////////////////////////////////////////////////////////// +template <> +struct GELU_taylor { + static bool const kIsHeavy = true; + + CUTLASS_DEVICE + float operator()(float const& z) const { + float k0 = static_cast(0.7978845608028654); + float k1 = static_cast(0.044715); + + return static_cast( + cutlass::constants::half() * z * + (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const { return this->operator()(scalar); } +}; -} // namespace kernel -} // namespace gemm +} // namespace thread +} // namespace epilogue } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h new file mode 100644 index 000000000000..affd1d83a35d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h @@ -0,0 +1,306 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column. + + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_conversion.h" +#include "tensorrt_llm/common/quantization.h" + +namespace tk = tensorrt_llm::common; + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +class EpilogueVisitorPerRowPerCol { + public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using AlphaScaleElementType = typename ScaleTileIterator::Element; + + using ElementCompute = ElementCompute_; + using AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + + /// Argument structure + struct Arguments { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} + + explicit Arguments(typename ElementwiseFunctor::Params elementwise_) + : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} + + Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, int64_t batch_stride_C_, + int64_t batch_stride_D_) + : elementwise(elementwise_), + batch_stride_alpha(batch_stride_alpha_), + batch_stride_C(batch_stride_C_), + batch_stride_D(batch_stride_D_) {} + }; + + struct Params { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + explicit Params(Arguments const& args) + : elementwise(args.elementwise), + batch_stride_alpha(args.batch_stride_alpha), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D) {} + }; + + /// Shared storage + struct SharedStorage {}; + + private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + bool const per_token_quant_; + bool const per_channel_quant_; + + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + + AlphaScaleElementType element_alpha_row_ = 1.0f; + AlphaScaleElementType element_alpha_col_ = 1.0f; + typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator beta_; + + int column_offset_; + + MatrixCoord thread_offset_; + + public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, + typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, + AlphaScaleElementType* ptr_alpha_row, AlphaScaleElementType* ptr_alpha_col, + typename OutputTileIterator::Element* ptr_C, typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), + int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) + : params_(params), + shared_storage_(shared_storage), + extent_(problem_size), + elementwise_(params.elementwise), + per_token_quant_(quant_option.hasPerTokenScaling()), + per_channel_quant_(quant_option.hasPerChannelScaling()), + ptr_alpha_row_(ptr_alpha_row), + ptr_alpha_col_(ptr_alpha_col), + iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset), + iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset), + iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset), + extent_real_(problem_size_real) { + beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); + } + + if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) { + element_alpha_col_ = *ptr_alpha_col_; + } + + if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) { + element_alpha_row_ = *ptr_alpha_row_; + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() { + if (per_channel_quant_) { + iterator_alpha_col_.load(fragment_alpha_col_); + } + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + fragment_C_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) { + int thread_offset_row = + iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); + + arch::global_load( + element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); + } + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) { + NumericArrayConverter source_converter; + + ComputeFragment result = source_converter(accum); + if (per_channel_quant_) { + ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; + result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); + } else { + result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) {} + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + + private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, ComputeFragment const& scale_col, + AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col[i] * scale_row); + } + + return result; + } + + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, AlphaScaleElementType const& scale_col, + AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col * scale_row); + } + + return result; + } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h new file mode 100644 index 000000000000..40f126d56616 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -0,0 +1,247 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h + +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/platform/platform.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/epilogue/thread/linear_combination_gelu.h" +#include "cutlass/epilogue/thread/linear_combination_hardswish.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_relu0.h" +#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" + +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. +template +struct DefaultIteratorsTensorOp { + using WarpTileIterator = + cutlass::epilogue::warp::TileIteratorTensorOpMixed; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; + + static int const kFragmentsPerIteration = 2; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. +/// +/// Satisfies: ReadableTileIterator +/// +template +class SharedLoadIteratorMixed { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = int32_t; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = + Array; + + /// Memory access size + using AccessType = AlignedArray; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + + private: + // + // Data members + // + + /// Byte-level pointer + LoadType const* pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed(TensorRef ref, int thread_idx) : stride_((ref.stride(0) / LoadType::kElements)) { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] = reinterpret_cast(ref.data()); + + int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = (col_idx * static_cast(sizeof(LoadType)) / 128) % kLoadsPerAccess; + + col_idx += (bank_offset + i) % kLoadsPerAccess; + + pointers_[i] += thread_offset.row() * stride_ + col_idx; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const { + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ + pointer_offset / LoadType::kElements; + + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + int vector_idx = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); + + LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; + } + } + } + } + } + } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_pointer_offset(frag, 0); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h similarity index 55% rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h index f41c42440f19..b784646c31f8 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h @@ -1,11 +1,12 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -22,65 +23,14 @@ * */ -#ifdef USE_CUTLASS - #pragma once -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/epilogue/thread/scale_type.h" -#include "cutlass/functional.h" -#include "cutlass/half.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h" #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/epilogue/thread/linear_combination_generic.h" #include "cutlass/epilogue/thread/linear_combination_relu.h" #include "cutlass/epilogue/thread/linear_combination_silu.h" -namespace cutlass { -namespace epilogue { -namespace thread { - -__forceinline__ __device__ float copysignf_pos(float a, float b) { - float r; - r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); - return r; -} - -__forceinline__ __device__ float tanh_opt(float x) { -#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) - const float exp_val = -1.f * fabs(2 * x); - return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); -#else - return fast_tanh(x); -#endif -} - -template <> -struct GELU_taylor { - static const bool kIsHeavy = true; - CUTLASS_DEVICE - float operator()(float const& z) const { - float k0 = float(0.7978845608028654); - float k1 = float(0.044715); - - return float( - cutlass::constants::half() * z * - (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); - } - - using Params = LinearCombinationGenericParams; - - CUTLASS_DEVICE - float operator()(float const& scalar, Params const& params_) const { return this->operator()(scalar); } -}; - -} // namespace thread -} // namespace epilogue -} // namespace cutlass - namespace ort_fastertransformer { struct EpilogueOpBiasSilu {}; @@ -89,49 +39,71 @@ struct EpilogueOpBiasReLU {}; struct EpilogueOpBiasFtGelu {}; +struct EpilogueOpDefaultSilu {}; + +struct EpilogueOpDefaultReLU {}; + +struct EpilogueOpDefaultFtGelu {}; + struct EpilogueOpBias {}; -struct EpilogueOpNoBias {}; +struct EpilogueOpDefault {}; template struct Epilogue {}; +constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; + template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationSilu; + ElementAccumulator, BiasScaleMode>; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationRelu; + ElementAccumulator, BiasScaleMode>; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationGeneric< cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator, - ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling, - cutlass::FloatRoundStyle::round_to_nearest, true>; + ElementAccumulator, BiasScaleMode, cutlass::FloatRoundStyle::round_to_nearest, true>; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombination; + ElementAccumulator, BiasScaleMode>; }; +constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; + template -struct Epilogue { - using Op = - cutlass::epilogue::thread::LinearCombination; +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu; }; -} // namespace ort_fastertransformer +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric< + cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator, + ElementAccumulator, DefaultScaleMode, cutlass::FloatRoundStyle::round_to_nearest, true>; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination; +}; -#endif +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h new file mode 100644 index 000000000000..f5064afc23ae --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +// #include +#include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/* + This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) + It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs + and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs. + + Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support + that feature at the moment. + */ + +template +class GemmUniversalBaseCompat { + public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + + protected: + /// Kernel parameters object + typename GemmKernel::Params params_; + + protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = + const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + + public: + /// Constructs the GEMM. + GemmUniversalBaseCompat() {} + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) { + return Status::kErrorInvalidProblem; + } + + return GemmKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); + + size_t workspace_bytes = 0; + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); + } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) { + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); + + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); + + return result; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); + + int max_active_blocks = -1; + int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + if (smem_size <= (48 << 10)) { + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel, + GemmKernel::kThreadCount, smem_size); + + if (result == cudaSuccess) { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } else { + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel, + GemmKernel::kThreadCount, 0); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); + + return -1; + } + + if (smem_capacity < 0) { + int device_idx = 0; + result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + return -1; + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + return -1; + } + + smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } + + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + + return occupancy; + } + + CUTLASS_TRACE_HOST(" returning internal error"); + + return -1; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + size_t workspace_bytes = get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) { + if (!workspace) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + + return Status::kErrorInternal; + } + } + } + + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + // Initialize the Params structure + params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); + + // Specify shared memory capacity for kernel. + int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = + cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + + // + // Configure grid and block dimensions + // + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage)); + + // + // Launch kernel + // + + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } + + /// Runs the kernel using initialized state. + Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h new file mode 100644 index 000000000000..b226b73e86fe --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h @@ -0,0 +1,476 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h +*/ + +#pragma once + +#include +#include +#include +#include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk, + int64_t* splitk_buffer_offsets) { + // in_tensor: [problem_idx, k_partition, hidden_size] + // Note that different requests of in_tensor might have different hidden_size (=m*n) + // so, we need to use splitk_buffer_offsets. + // out_tensor: problem_idx * [hidden_size] + + int const problem_idx = blockIdx.y; + GemmCoord problem = problem_sizes[problem_idx]; + int const hidden_size = problem.m() * problem.n(); + const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; + T_OUT* out_tensor_ = out_tensor[problem_idx]; + + for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x) { + float sum = 0.0f; + for (int k_idx = 0; k_idx < splitk; k_idx++) { + sum += static_cast(in_tensor_[k_idx * hidden_size + i]); + } + out_tensor_[i] = (T_OUT)(sum); + } +} + +/// GEMM Grouped +template +class BaseSplitkGrouped { + public: + using BaseKernel = BaseKernel_; + + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + /// Argument structure + using Arguments = typename BaseKernel::Arguments; + + using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; + + protected: + /// Kernel parameters object + typename BaseKernel::Params gemm_params_; + + private: + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) { + int32_t tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; + BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); + tiles += problem_tile_count(problem); + } + return tiles; + } + + /// Copy from `data` to `workspace` + Status copy_to_workspace(void* workspace, void* data, size_t bytes) { + cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); + if (cuda_error != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + cuda_error = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Precomputes scheduling information for the grouped GEMM + Status precompute(Arguments const& args, int32_t tile_count, void* workspace) { + size_t workspace_bytes = get_workspace_size(args); + std::vector host_workspace(workspace_bytes); + BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes, args.problem_count, args.threadblock_count, + reinterpret_cast(host_workspace.data())); + return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + } + + /// Reorder `data` according to `indices` + template + static void reorder_array(T* data, std::vector const& indices) { + // For now, simply create a copy of the data and then copy over to the original. + std::vector copy(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + copy.at(i) = data[indices[i]]; + } + + memcpy(data, copy.data(), indices.size() * sizeof(T)); + } + + public: + /// Constructs the GEMM. + BaseSplitkGrouped() {} + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { return BaseKernel::can_implement(args); } + + /// Get the number of tiles in a problem + static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) { + auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); + return BaseKernel::ProblemVisitor::tile_count(grid); + } + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(Arguments const& args) { + if (args.host_problem_sizes == nullptr) { + CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); + return -1; + } + + return group_tile_count(args.host_problem_sizes, args.problem_count); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + size_t total_mn = 0; + for (int i = 0; i < args.problem_count; i++) { + total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); + } + size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices; + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size(args.host_problem_sizes, args.problem_count, + args.threadblock_count); + } + return workSpaceSize; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { return dim3(args.threadblock_count, 1, 1); } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); + + int smem_size = static_cast(sizeof(typename BaseKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel, + BaseKernel::kThreadCount, smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Sorts each pointer passed in according to the indices that sort + /// `problem_sizes_ptr` in descending order of problem-K dimension. + static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr, + int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr, + int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr) { + std::vector indices(problem_count); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_ptr](size_t i, size_t j) { + return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); + }); + + reorder_array(problem_sizes_ptr, indices); + reorder_array(lda_host_ptr, indices); + reorder_array(ldb_host_ptr, indices); + reorder_array(ldc_host_ptr, indices); + reorder_array(ldd_host_ptr, indices); + reorder_array(offset_A_ptr, indices); + reorder_array(offset_B_ptr, indices); + reorder_array(offset_C_ptr, indices); + reorder_array(offset_D_ptr, indices); + } + + /// Computes the number of threadblocks to launch for the grouped kernel + static int sufficient(cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, + int available_sm_count = -1) { + // Determine the number of blocks that would be launched to fill up a single + // wave on the GPU with each SM having maximum occupancy. + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result)); + return 0; + } + + int multiprocessor_count; + result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result)); + return 0; + } + + bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count); + if (override_sm_count) { + available_sm_count = multiprocessor_count; + } + + int max_active_blocks = maximum_active_blocks(); + if (max_active_blocks <= 0) { + return 0; + } + + int occupancy_based_block_count = available_sm_count * max_active_blocks; + + if (problem_sizes_ptr == nullptr || problem_count == 0) { + return occupancy_based_block_count; + } + + int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); + + // If the group contains a single problem, launching the exact number of + // threadblocks needed to cover the problem minimizes the work performed + // per threadblock in finding the next tile to compute. We return total_tiles + // unless the user has provided the SM count. + if (problem_count == 1 && override_sm_count) { + return total_tiles; + } + + // Choose between the full wave of threadblocks and the tile count. If there + // are fewer tiles in the group than threadblocks in the full wave, only + // some threadblocks will be assigned tiles. Those threadblocks + // which are not assigned tiles still need to perform the work of iterating through + // problem sizes to determine that they have no work to do. This competes for cycles + // with those threadblocks that are assigned tiles to compute. + return std::min(total_tiles, occupancy_based_block_count); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Workspace + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); + } else { + gemm_params_ = typename BaseKernel::Params(args, workspace); + } + + // Specify shared memory capacity for kernel. + int smem_size = static_cast(sizeof(typename BaseKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = + cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + gemm_params_.update(args, workspace, tile_count); + } else { + gemm_params_.update(args, workspace); + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + if (!gemm_params_.problem_visitor.problem_count) { + return Status::kSuccess; + } + + // + // Launch kernel + // + + // Launch splitk grouped gemm + { + dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); + dim3 block(BaseKernel::kThreadCount, 1, 1); + + int smem_size = static_cast(sizeof(typename BaseKernel::SharedStorage)); + cutlass::Kernel<<>>(gemm_params_); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + // Launch splitkReduction + { + dim3 grid(32, gemm_params_.problem_visitor.problem_count); + dim3 block(256); + splitkReduction<<>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split, + gemm_params_.problem_visitor.problem_sizes, + gemm_params_.split_k_slices, gemm_params_.splitk_buffer_offsets); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } + + /// Initializes and runs the kernel. + Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM Grouped +template +class SplitkGemmGrouped : public BaseSplitkGrouped { + public: + using GemmKernel = GemmKernel_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h similarity index 71% rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h index efb30d07507b..2b3478a38fc2 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -1,11 +1,12 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,53 +14,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -/* - This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is - quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices - to be consumed by CUTLASS. - - Note that for int4, ThreadBlockK MUST be 64. - - */ - -#ifdef USE_CUTLASS - #pragma once -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" #include "cutlass/arch/arch.h" #include "cutlass/arch/mma.h" -#include "cutlass/platform/platform.h" +#include "cutlass/bfloat16.h" #include "cutlass/cutlass.h" #include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" namespace cutlass { namespace gemm { namespace kernel { -template -struct LayoutDetailsB {}; - -// Volta specialiations. Volta will dequantize before STS, so we need a different operator -template -struct LayoutDetailsB { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 8; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. -// TODO - Switch this to column major for weights since gemms should be more performant. -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - template struct MixedGemmArchTraits {}; @@ -68,7 +38,7 @@ struct MixedGemmArchTraits { static constexpr int Stages = 2; using OperatorClass = cutlass::arch::OpClassSimt; using AccType = float; - using LayoutB = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; static constexpr int ElementsPerAccessA = 1; static constexpr int ElementsPerAccessB = 1; @@ -82,10 +52,13 @@ struct MixedGemmArchTraits { // ========================= Volta Traits =========================== // Volta will always dequantize after the global memory load. // This will instantiate any HMMA tensorcore kernels for Volta. +// Note that volta does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. template struct MixedGemmArchTraits< TypeA, TypeB, cutlass::arch::Sm70, - typename cutlass::platform::enable_if::value>::type> { + typename cutlass::platform::enable_if::value || + cutlass::platform::is_same::value>::type> { private: using LayoutDetails = LayoutDetailsB; @@ -105,10 +78,13 @@ struct MixedGemmArchTraits< }; // ======================= Turing Traits ============================== +// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. template struct MixedGemmArchTraits< TypeA, TypeB, cutlass::arch::Sm75, - typename cutlass::platform::enable_if::value>::type> { + typename cutlass::platform::enable_if::value || + cutlass::platform::is_same::value>::type> { private: using LayoutDetails = LayoutDetailsB; @@ -131,7 +107,8 @@ struct MixedGemmArchTraits< template struct MixedGemmArchTraits< TypeA, TypeB, cutlass::arch::Sm80, - typename cutlass::platform::enable_if::value>::type> { + typename cutlass::platform::enable_if::value || + cutlass::platform::is_same::value>::type> { private: using LayoutDetails = LayoutDetailsB; @@ -153,5 +130,3 @@ struct MixedGemmArchTraits< } // namespace kernel } // namespace gemm } // namespace cutlass - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_int8_traits.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_int8_traits.h new file mode 100644 index 000000000000..fe4bc0940d9e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_int8_traits.h @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassSimt; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; +}; + +// ======================= Turing Traits ============================== +template <> +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; +}; + +// ======================= Ampere Traits ============================== +template <> +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h new file mode 100644 index 000000000000..9339be92dfb2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h @@ -0,0 +1,206 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" + +#include "cutlass/layout/permute.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, + /// Operation performed by GEMM + typename Operator = typename device::DefaultGemmConfiguration::Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// + typename Enable = void> +struct DefaultSplitkGemmGrouped; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Permute result D + typename PermuteDLayout> +struct DefaultSplitkGemmGrouped::value>::type> { + // If true, we must construct a 'transposed-and-exchanged' Mma operator. + static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = + kernel::detail::MapArguments; + + // Define the default GEMM kernel + using DefaultGemmKernel = + typename kernel::DefaultGemm::GemmKernel; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::SplitkGemmGrouped; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h new file mode 100644 index 000000000000..778d45f39eab --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -0,0 +1,513 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +template +inline constexpr bool dependent_false_v = false; +} + +template +struct GemmFpAIntB { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + int group_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, + typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : problem_size(problem_size), + group_size(group_size), + ref_A(ref_A), + ref_B(ref_B), + ref_scale(ref_scale), + ref_zero(ref_zero), + ref_C(ref_C), + ref_D(ref_D), + batch_count(serial_split_k_factor), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) {} + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + int group_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size), + group_size(args.group_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.ref_A.layout()), + ref_A(args.ref_A), + params_B(args.ref_B.layout()), + ref_B(args.ref_B), + params_scale(args.ref_scale.layout()), + ref_scale(args.ref_scale), + ref_zero(args.ref_zero), + params_C(args.ref_C.layout()), + ref_C(args.ref_C), + params_D(args.ref_D.layout()), + ref_D(args.ref_D), + output_op(args.output_op), + semaphore(static_cast(workspace)), + gemm_k_size(gemm_k_size), + gather_A_indices(args.gather_A_indices), + gather_B_indices(args.gather_B_indices), + scatter_D_indices(args.scatter_D_indices) {} + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement(Arguments const& args) { + static int const kAlignmentA = + (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = + (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + + static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; + + static int const kAlignmentC = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(args.ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!args.ref_scale.good()) { + return Status::kErrorNotSupported; + } + + if constexpr (hasZero(Mma::QuantOp)) { + if (!args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } else { + if (args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } + + if constexpr (isFinegrained(Mma::QuantOp)) { + if (args.group_size != 64 && args.group_size != 128) { + return Status::kErrorNotSupported; + } + } + + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + + // Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator + // has a different constructor signature than a regular cutlass iterator + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + } + + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && kInterleave == 1 || + platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; + cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), {params.problem_size.m(), problem_size_k}, + thread_idx, tb_offset_A, params.gather_A_indices); + + typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, + thread_idx, tb_offset_B, params.gather_B_indices); + + typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = initialize_scale( + params.params_scale, params.ref_scale.data(), params.ref_zero.data(), + {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. +#else + static_assert(false, + "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h new file mode 100644 index 000000000000..6cb5cc4e1334 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h @@ -0,0 +1,66 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Scheduler for grouped GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/matrix_coord.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct GemmMoeProblemVisitor + : public MoeProblemVisitor, ThreadblockShape, + GroupScheduleMode_, PrefetchTileCount, ThreadCount> { + static bool const kTransposed = Transposed; + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base = + MoeProblemVisitor; + using Params = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + // + // Methods + // + CUTLASS_DEVICE + GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, shared_storage_, block_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h new file mode 100644 index 000000000000..fb35b2dbf12c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h @@ -0,0 +1,516 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. + + original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/trace.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" + +namespace tk = tensorrt_llm::common; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithEpilogueVisitor { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = + typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + tk::QuantMode quant_option; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {} + + /// constructs an arguments structure + Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, TensorRefB ref_B_, + tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, TensorRefAlphaRow ref_alpha_row_, + TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_, int64_t batch_stride_B_, + typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(mode_), + problem_size(problem_size_), + batch_count(batch_count_), + ref_A(ref_A_), + ref_B(ref_B_), + quant_option(quant_option_), + ref_alpha_col(ref_alpha_col_), + ref_alpha_row(ref_alpha_row_), + ref_C(ref_C_), + ref_D(ref_D_), + batch_stride_A(batch_stride_A_), + batch_stride_B(batch_stride_B_), + batch_stride_D(0), + epilogue_visitor(epilogue_visitor_) {} + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void* ptr_A; + void* ptr_B; + tk::QuantMode quant_option; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), + params_A(0), + params_B(0), + params_alpha_col(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_alpha_col(nullptr), + ptr_alpha_row(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + batch_stride_A(0), + batch_stride_B(0) {} + + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) + : problem_size(args.problem_size), + swizzle_log_tile(0), + params_A(args.ref_A.layout()), + params_B(args.ref_B.layout()), + params_alpha_col(args.ref_alpha_col.layout()), + params_alpha_row(args.ref_alpha_col.layout()), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(args.problem_size.k()), + ptr_A(args.ref_A.data()), + ptr_B(args.ref_B.data()), + quant_option(args.quant_option), + ptr_alpha_col(args.ref_alpha_col.data()), + ptr_alpha_row(args.ref_alpha_row.data()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + epilogue_visitor(args.epilogue_visitor) { + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = + const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { return can_implement(args.problem_size); } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + +#define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + +#if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } +#endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B(params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, shared_storage.epilogue.visitor, params.problem_size.mn(), thread_idx, warp_idx, + lane_idx, params.params_alpha_col, params.params_C, params.params_D, params.quant_option, params.ptr_alpha_row, + params.ptr_alpha_col, params.ptr_C, params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 720) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + // replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. + run_kernel(params, shared_storage); +#else + static_assert(false, + "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h new file mode 100644 index 000000000000..35d22b2f55a8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -0,0 +1,126 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is + quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices + to be consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. + + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct LayoutDetailsB {}; + +// Volta specialiations. Volta will dequantize before STS, so we need a different operator +template +struct LayoutDetailsB { + static constexpr int ThreadblockK = 64; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 8; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. +// Switch this to column major for weights since gemms should be more performant. +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, +// which signals that we want to dequantize after loading from smem. +template + struct LayoutDetailsB < + uint8_t, + Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template + struct LayoutDetailsB < + uint4b_t, + Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template +struct LayoutDetailsB= 90>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 90>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h new file mode 100644 index 000000000000..9e3e9d20d7f6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -0,0 +1,471 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms. +// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global. +template +using void_t = void; + +template +struct use_dq_gemm : platform::false_type {}; + +template +struct use_dq_gemm> : platform::true_type {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeFCGemm { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = false; + + // Optional transpose + using MapArguments = + kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + static_assert(!kTransposed, "Transpose problem not supported"); + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = + GemmMoeProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + int problem_count; + int threadblock_count; + int group_size; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t* total_rows_before_expert; + int64_t gemm_n; + int64_t gemm_k; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + weight_scales(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + total_rows_before_expert(nullptr), + gemm_n(0), + gemm_k(0), + host_problem_sizes(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op, + ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C, + ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, + GemmCoord* host_problem_sizes = nullptr) + : problem_count(problem_count), + threadblock_count(threadblock_count), + group_size(group_size), + output_op(output_op), + ptr_A(const_cast(ptr_A)), + ptr_B(const_cast(ptr_B)), + weight_scales(const_cast(weight_scales)), + ptr_C(const_cast(ptr_C)), + ptr_D(ptr_D), + total_rows_before_expert(total_rows_before_expert), + gemm_n(gemm_n), + gemm_k(gemm_k), + host_problem_sizes(nullptr) { + if (platform::is_same::value || platform::is_same::value) { + assert(weight_scales); + } + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int group_size; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr) {} + + CUTLASS_HOST_DEVICE + explicit Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor(args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, + tile_count), + threadblock_count(args.threadblock_count), + group_size(args.group_size), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + weight_scales(args.weight_scales), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D) {} + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { + problem_visitor = typename ProblemVisitor::Params(args.total_rows_before_expert, args.gemm_n, args.gemm_k, + args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + weight_scales = args.weight_scales; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + MoeFCGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { return Status::kSuccess; } + + static Status can_implement(Arguments const& args) { + if (platform::is_same::value || platform::is_same::value) { + if (args.weight_scales == nullptr) { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t"); + return Status::kInvalid; + } + } else if (args.weight_scales != nullptr) { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); + return Status::kInvalid; + } else if (args.group_size != args.gemm_k) { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)"); + return Status::kInvalid; + } else if (static_cast(args.gemm_n) < Mma::IteratorB::AccessType::kElements) { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment"); + return Status::kInvalid; + } + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert(platform::is_same::value && kInterleave == 1 || + platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // + // Problem visitor. + // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + // Outer 'persistent' loop to iterate over tiles + int loop = 0; + while (problem_visitor.next_tile()) { + loop++; + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset(static_cast(cta_idx / grid_shape.n()) * Mma::Shape::kM, + static_cast(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + const int64_t rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + + char* byte_ptr_B = (reinterpret_cast(params.ptr_B)) + problem_idx * bytes_per_expert_matrix; + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B = + platform::is_same::value ? gemm_n : gemm_k * kInterleave; + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, + {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, + tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + auto CreateMMA = [&]() { + if constexpr (use_dq_gemm::value) + return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + else + return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + }; + Mma mma = CreateMMA(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); + + if constexpr (use_dq_gemm::value) { + const MatrixCoord scale_extent = {1, problem_size.n()}; + typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), weight_scale_ptr, + scale_extent, thread_idx, tb_offset_scale); + + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } else { + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; + ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + LayoutC layout_C(0); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params_C, ptr_C, problem_size.mn(), thread_idx, + threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params_D, ptr_D, problem_size.mn(), thread_idx, + threadblock_offset.mn()); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + run_kernel(params, + shared_storage); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. +#else + // static_assert(false, + // "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); + ; +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_problem_visitor.h similarity index 80% rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_problem_visitor.h index 157437439cd0..6852d4c811b4 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_problem_visitor.h @@ -1,40 +1,24 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * http://www.apache.org/licenses/LICENSE-2.0 * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ /*! \file \brief Base scheduler for grouped problems, using MoE */ -#ifdef USE_CUTLASS - #pragma once #include "cutlass/gemm/kernel/grouped_problem_visitor.h" @@ -108,7 +92,7 @@ struct BaseMoeProblemVisitor { /// Get the grid shape CUTLASS_HOST_DEVICE - static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem) { return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); } @@ -147,9 +131,9 @@ struct BaseMoeProblemVisitor { } CUTLASS_HOST_DEVICE - static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { return ProblemSizeHelper::tile_count(grid); } + static int32_t tile_count(cutlass::gemm::GemmCoord const& grid) { return ProblemSizeHelper::tile_count(grid); } - static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count) { + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count) { int32_t total_tiles = 0; for (int32_t i = 0; i < problem_count; ++i) { auto problem = host_problem_sizes_ptr[i]; @@ -278,17 +262,15 @@ struct MoeProblemVisitor +struct SplitkGemmGrouped { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; + + // Optional transpose + using MapArguments = + kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + + using ElementFinalOutput = typename MapArguments::ElementA; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = + GemmGroupedProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmCoord* problem_sizes; + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // splitK + int split_k_slices; + int64_t* splitk_buffer_offsets; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr), + host_problem_sizes(nullptr), + split_k_slices(1), + splitk_buffer_offsets(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count, + typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, + ElementFinalOutput** ptr_C, ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda, + typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc, + typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices, + int64_t* splitk_buffer_offsets) + : problem_sizes(problem_sizes), + problem_count(problem_count), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_D(ptr_D), + lda(lda), + ldb(ldb), + ldc(ldc), + ldd(ldd), + host_problem_sizes(host_problem_sizes), + split_k_slices(split_k_slices), + splitk_buffer_offsets(splitk_buffer_offsets) {} + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + ElementC* ptr_C_split; + ElementC* ptr_D_split; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; + + // + // Methods + // + + // splitk + GemmCoord grid_tiled_shape; + int swizzle_log_tile; + int gemm_k_size; + GemmCoord* host_problem_sizes; + int split_k_slices; + int64_t* splitk_buffer_offsets; + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_C_split(nullptr), + ptr_D_split(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr), + swizzle_log_tile(0), + gemm_k_size(0), + host_problem_sizes(nullptr), + split_k_slices(1), + splitk_buffer_offsets(nullptr) {} + + CUTLASS_HOST_DEVICE + explicit(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), + host_problem_sizes(args.host_problem_sizes), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D), + ptr_C_split(reinterpret_cast(workspace)), + ptr_D_split(reinterpret_cast(workspace)), + lda(args.lda), + ldb(args.ldb), + ldc(args.ldc), + ldd(args.ldd), + split_k_slices(args.split_k_slices), + splitk_buffer_offsets(args.splitk_buffer_offsets) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.host_problem_sizes[0], {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); + + // only support same k + int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK; + int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + } + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { + problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + ptr_C_split = workspace; + ptr_D_split = workspace; + + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldd = args.ldd; + } + }; + + /// Shared memory storage structure + struct SharedStorage { + union { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + } kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + SplitkGemmGrouped() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { return Status::kSuccess; } + + static Status can_implement(Arguments const& args) { return Status::kSuccess; } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + // + // Problem visitor. + // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + // Load element pointers. Exchange pointers and strides if working on the transpose + ElementA* ptr_A = + reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + + ElementB* ptr_B = + reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + cutlass::gemm::GemmCoord threadblock_offset(static_cast(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + static_cast(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, + 0); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k; + if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) { + problem_size_k = problem_size.k(); + } else { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, + tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = params.ptr_C_split; + ElementC* ptr_D = params.ptr_D_split; + + LayoutC layout_C(params.ldc[problem_idx]); + LayoutC layout_D(params.ldd[problem_idx]); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // assume identity swizzle + MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n()); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params_C, ptr_C, problem_size.mn(), thread_idx, + threadblock_offset_C); + + iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params_D, ptr_D, problem_size.mn(), thread_idx, + threadblock_offset_C); + iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma.h new file mode 100644 index 000000000000..8bbc1ee4e6c4 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -0,0 +1,120 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { +//////////////////////////////////////////////////////////////////////////////// + +// We need to distinguish here, since we want volta support. It is too much effort +// to write shared memory iterators that are probably needed for volta to function +// properly. As a result, we allow converters both after the LDG (for volta) and after +// the LDS for Turing+. +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Warp level Mma + typename MmaOperator, + /// Math operation perform by warp level operator + typename MathOperator> +struct SetConverters {}; + +// Dequantize after LDG, so set transforms accordingly +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = + FastInterleavedAndBiasedNumericArrayConverter; + + using TransformAfterLDS = + NumericArrayConverter; +}; + +// Dequantize after LDS, so set transforms accordingly + +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = + NumericArrayConverter; + + using TransformAfterLDS = + FastInterleavedAndBiasedNumericArrayConverter; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale_, + /// Layout for the scale operand + typename LayoutScale_, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// + typename Enable = void> +struct DqMma; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h new file mode 100644 index 000000000000..8b9d6b0b14ad --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -0,0 +1,289 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIterators; + +// Fine grained iterators +template +struct DefaultScaleIterators> { + using IteratorScale = + cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +// Per column iterators +template +struct DefaultScaleIterators> { + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); + + private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + + public: + // Define iterators over tiles from the scale operand + using IteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, Element, Layout, 0, + IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for elementA + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// + typename Operator_, + /// + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && + !layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, + layout::RowMajor, OperatorClass, std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + using ScaleIterators = + DefaultScaleIterators; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, + typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, + layout::RowMajor, typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>; +}; + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// + typename Operator_, + /// + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && + layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, layout::ColumnMajor, + ElementAccumulator, layout::RowMajor, OperatorClass, std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator; + + using ScaleIterators = + DefaultScaleIterators; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, + typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, + layout::RowMajor, typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h new file mode 100644 index 000000000000..91c4cd342569 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -0,0 +1,245 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + static constexpr bool DqAfterLDG = platform::is_same::value; + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaCoreElementA = typename platform::conditional::type; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, MmaCoreElementA, LayoutA, MmaCoreElementB, LayoutB, + ElementAccumulator, layout::RowMajor, OperatorClass, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = + transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, ElementScale, + LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>; + + using SmemScaleType = typename platform::conditional::type; + using SmemIteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, + SmemScaleType, LayoutScale, 0, IteratorScaleThreadMap, + kAlignmentScale>; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, + IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, + typename Converters::TransformAfterLDG, typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + static constexpr bool DqAfterLDG = platform::is_same::value; + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaCoreElementA = typename platform::conditional::type; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, MmaCoreElementA, LayoutA, MmaCoreElementB, layout::ColumnMajor, + ElementAccumulator, layout::RowMajor, OperatorClass, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileIterator; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = + transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, ElementScale, + LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>; + + using SmemScaleType = typename platform::conditional::type; + using SmemIteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, + SmemScaleType, LayoutScale, 0, IteratorScaleThreadMap, + kAlignmentScale>; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, + IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, + typename Converters::TransformAfterLDG, typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma.h new file mode 100644 index 000000000000..1a3e7e39c965 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma.h @@ -0,0 +1,283 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma_bf16.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma_bf16.h new file mode 100644 index 000000000000..4afd482f8562 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -0,0 +1,345 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + private: + // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS. + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaElementA = typename platform::conditional::type; + using MmaElementB = typename platform::conditional::type; + + public: + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaPipelined; +}; + +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaMultistage; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h new file mode 100644 index 000000000000..cf5ba6faa0c8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -0,0 +1,237 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// +// SFINAE trick so I can keep the same loop code for Volta and dispatch to the +// correct warp level mma. On volta, all data is stored to shared memory as FP16. +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, + typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, + typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset) { + warp_mma(D, A, B, C); +} + +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, + typename WarpMma::TransformedFragmentA const& A, + typename WarpMma::TransformedFragmentB const& B, typename WarpMma::FragmentC const& C, + int const warp_tileB_k_offset) { + warp_mma(D, A, B, C, warp_tileB_k_offset); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// The type of the scales + typename ElementScale_, + /// Number of stages, + int Stages, + /// The dequantizing op to be performed. + WeightOnlyQuantOp DequantOp, + /// Used for partial specialization, + typename Enable = bool> +class DqMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; + + static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); + + // Finegrained scales get streamed in via cp.async + static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; + // We always have scales. + static constexpr int ScaleElementsPerStage = Shape::kN; + // We sometimes have a bias + static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + static constexpr int kNumKIterationsPerWarpBLoad = + Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = + MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + /// Shape of the shared memory buffer for the scales for the B matrix. + using ShapeScale = MatrixShape; + /// Shape of the shared memory buffer for the biases of the B matrix. + using ShapeZero = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_zero; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { return TensorRefA{operand_A.data(), LayoutA()}; } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h new file mode 100644 index 000000000000..f11e94d9d2b9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = void> +class DqMmaMultistage; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h" diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h new file mode 100644 index 000000000000..dd934b9a0036 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -0,0 +1,634 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + /// The group size for quantization + int group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); + typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr = + reinterpret_cast(this->smem_iterator_scale_.get_scale()); + typename IteratorScale::AccessType* smem_zero_ptr = + reinterpret_cast(this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) { + cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); + } + + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, IteratorScale& iterator_scale, + int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill(dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill(dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + typename Dequantizer::FragmentZero warp_frag_zeros; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); + + run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, iterator_scale, group_start_iteration_A, + group_start_iteration_B); + + // This is the first group of a given stage, so we issue the loads for the B scales immediately. + if (group_start_iteration_B == 0) { + copy_scales_and_advance(iterator_scale); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, iterator_scale, group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - + // #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + } + } + + // Load the scale needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + // Update internal pointer to set of scales in shared memory. + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h new file mode 100644 index 000000000000..33bcb1910638 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h @@ -0,0 +1,586 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + ///< Group size for quantization. Not used by this main loop since it assumes per-column + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group as + // the first load of A. + FragmentScale tb_frag_scales; + tb_frag_scales.clear(); + iterator_scale.load(tb_frag_scales); + this->smem_iterator_scale_.store(tb_frag_scales); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill(dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill(dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + warp_dequantizer_.load(warp_frag_scales); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + + run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + } + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h new file mode 100644 index 000000000000..2c85ba8a1995 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -0,0 +1,379 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Used for partial specialization + typename Enable = bool> +class DqMmaPipelined : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = + warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined( + typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation + ///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this + ///< argument is not added, it does not affect compilation for sm>=80. + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) { ///< source accumulator tile + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA = + NumericArrayConverter; + + using TransformScale = NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. + TransformA transformA; + TransformScale transformScale; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + FragmentScale tb_frag_scales; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + WarpFragmentScale warp_frag_scales; + + tb_frag_A.clear(); + tb_frag_B.clear(); + tb_frag_scales.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + iterator_scale.load(tb_frag_scales); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + warp_dequantizer_.load(warp_frag_scales); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, + warp_tileB_k_compute_offset); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h new file mode 100644 index 000000000000..f0b6f4fcaad3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -0,0 +1,103 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements, + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp { + private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; + + // Chosen so we get K=16 for int8 and K=32 for int4. + static constexpr int LoadInstructionK = 8 * sizeof_bits::value / sizeof_bits::value; + + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = GemmShape; + + public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1>>; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h new file mode 100644 index 000000000000..a368c6d22026 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value && + ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports underlying HMMA"); + + static_assert(platform::is_same::value || + (platform::is_same::value && ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + public: + /// Iterates over the A operand in memory + using IteratorA = + MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = Array; + + /// Iterates over the B operand in memory + using IteratorB = + MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, LayoutB, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, + int const warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " + "B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h new file mode 100644 index 000000000000..51ca8282e42f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -0,0 +1,534 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + WeightOnlyQuantOp QuantOp_, + /// + typename Enable = void> +class MmaTensorOpDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer< + MmaOperator_, Shape_, Operand::kB, bfloat16_t, layout::RowMajor, 32, QuantOp_, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 80 && + platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) {} + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, + FragmentScale const& zero_frag) { + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Turing & Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer< + MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 75 && + platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) {} + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, + FragmentScale const& zero_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + + if constexpr (hasZero(QuantOp)) { + plus plus_op; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = + plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer< + MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_, + typename platform::enable_if< + platform::is_same::value && + platform::is_same::value>::type> { + public: + static_assert(platform::is_same>::value, ""); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + /// Warp mma shape + using Shape = Shape_; + + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B. + static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array; + using AccessType = Array; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const base_col = lane_idx & 0xF8; + int const thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + AccessType* scale_frag_ptr = reinterpret_cast(&scale_frag); + + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + scale_frag_ptr[tile_iter] = *reinterpret_cast(pointer_ + ColsPerMmaTile * tile_iter); + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { + static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, ""); + + multiplies mul_op; + operand_frag = mul_op(operand_frag, scale_frag); + } + + private: + ElementScale const* pointer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer< + MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_, + typename platform::enable_if< + platform::is_same::value && + platform::is_same::value>::type> { + public: + static_assert(platform::is_same>::value, ""); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + /// Warp mma shape + using Shape = Shape_; + + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B. + static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const base_col = lane_idx & 0xF8 + lane_idx % 4; + int const thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + // For col major B, each thread will jump 4 cols to get its next value inside + // of the super mma. + CUTLASS_PRAGMA_UNROLL + for (int mma_iter = 0; mma_iter < 2; ++mma_iter) { + scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter]; + } + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { + using MmaOperandB = typename ArchMmaOperator::FragmentB; + static constexpr int total_n_mmas = 2 * TileNIterations; + static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, ""); + + multiplies mul_op; + + MmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + private: + ElementScale const* pointer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h new file mode 100644 index 000000000000..0841218a480b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace ort_fastertransformer { +// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight only quantization. +enum class CutlassTileConfig { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=16 + CtaShape16x128x64_WarpShape16x32x64, + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x64x128_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x64x64, + CtaShape128x128x64_WarpShape128x32x64, + CtaShape128x256x64_WarpShape64x64x64, + + // Warp configs for M=256 + CtaShape256x128x64_WarpShape64x64x64, + + // TensorCore config CTA_N = 256, CTA_K = 64 + CtaShape16x256x64_WarpShape16x64x64 +}; + +enum class SplitKStyle { + NO_SPLIT_K, + SPLIT_K_SERIAL, + // SPLIT_K_PARALLEL // Not supported yet +}; + +enum class CutlassTileConfigSM90 { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // CTA configs for M=64 + CtaShape64x16x128B, + CtaShape64x32x128B, + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + + // CTA configs for M=128 + CtaShape128x16x128B, + CtaShape128x32x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, +}; + +enum class MainloopScheduleType { + AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this + // defaults to the "legacy" main loop schedule. +}; + +enum class EpilogueScheduleType { + AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For + // architectures older than hopper, the epilogue is always performed by the same thread block as the main loop. +}; + +enum class ClusterShape { ClusterShape_1x1x1, + ClusterShape_2x1x1, + ClusterShape_1x2x1, + ClusterShape_2x2x1 }; + +struct CutlassGemmConfig { + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; + + // config options for sm90 + CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; + MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; + EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; + ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + + CutlassGemmConfig() {} + + CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) + : tile_config(tile_config), split_k_style(split_k_style), split_k_factor(split_k_factor), stages(stages) {} + + CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule, + EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) + : tile_config_sm90(tile_config_sm90), + mainloop_schedule(mainloop_schedule), + epilogue_schedule(epilogue_schedule), + cluster_shape(cluster_shape) {} +}; + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h new file mode 100644 index 000000000000..7fd1745aa2c5 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h @@ -0,0 +1,392 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/numeric_types.h" + +namespace cutlass { + +// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low +// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally +// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. +// This converter will uninterleave the data and subtract the bias while converting to the result type. +template +struct FastInterleavedAndBiasedNumericArrayConverter {}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; + } + + // Truncate the fp32 representation and pack up as bfloat16s. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = + __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + result.clear(); // Suppress compiler warning + arch::device_breakpoint(); +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h similarity index 97% rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h index 111d5240e40a..e5abefa35bc8 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h @@ -31,9 +31,6 @@ /*! \file \brief Defines new layouts needed for MoE */ - -#ifdef USE_CUTLASS - #pragma once #include "cutlass/cutlass.h" @@ -45,7 +42,7 @@ namespace cutlass { namespace layout { template -class ColumnMajorTileInterleave { +struct ColumnMajorTileInterleave { static constexpr int kRowsPerTile = RowsPerTile; static constexpr int kColumnsInterleaved = ColumnsInterleaved; }; @@ -62,5 +59,3 @@ struct IsColumnMajorTileInterleave> { } // namespace layout } // namespace cutlass - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h new file mode 100644 index 000000000000..79811ef3e611 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM + quantization. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +class FineGrainedScaleZeroIterator; + +template +class FineGrainedScaleZeroIterator { + public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; + + static int const kAccessesPerVector = 1; + + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using AccessType = AlignedArray; + + // For compatibility with existing iterator interface + struct Params { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + explicit Params(Layout const& layout) : stride_(layout.stride(0)) { + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + BytePointer pointer_zero_; + + bool is_valid_ = false; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor + Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params), + pointer_scale_(reinterpret_cast(const_cast(pointer_scale))), + pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset = + threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + if (pointer_zero_ != nullptr) { + pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); + } + + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; + + int const thread_row = thread_id / THREADS_PER_ROW; + int const thread_col = thread_id % THREADS_PER_ROW; + + const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + if (pointer_zero_ != nullptr) { + pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); + } + + // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on + // a given iteration. The same threads will be responsible for issues reads since the number of scales + // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ + // outside of the constructor. + int const global_row = threadblock_offset.row() + thread_row; + int const global_col = threadblock_offset.column() + thread_col * kAlignment; + + bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; + bool const col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator(params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), + group_size) {} + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + if (pointer_zero_ != nullptr) { + pointer_zero_ += row_byte_offset + col_byte_offset; + } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) { is_valid_ &= (!enable); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return is_valid_; } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const { return reinterpret_cast(pointer_scale_); } + + /// Returns a zero pointer + CUTLASS_HOST_DEVICE + AccessType* get_zero() const { return reinterpret_cast(pointer_zero_); } +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h new file mode 100644 index 000000000000..403221a95601 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h @@ -0,0 +1,50 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. +*/ + +#pragma once + +namespace cutlass { + +enum class WeightOnlyQuantOp { UNDEFINED, + PER_COLUMN_SCALE_ONLY, + FINEGRAINED_SCALE_ONLY, + FINEGRAINED_SCALE_AND_ZEROS }; + +constexpr bool isFinegrained(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +} + +constexpr bool hasZero(WeightOnlyQuantOp op) { return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; } + +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc index f0abd46572a9..cd59e904ad9e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifdef USE_CUTLASS #include "cutlass_heuristic.h" @@ -66,9 +65,9 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, } // Check that the workspace has sufficient space for this split-k factor - const int ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); - const int ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); - const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + const size_t ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + const size_t ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; if (required_ws_bytes > workspace_bytes) { return false; @@ -128,7 +127,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector= multi_processor_count * 256 ? 1 : split_k_limit; - for (int ii = 0; ii < candidate_configs.size(); ++ii) { + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { CutlassGemmConfig candidate_config = candidate_configs[ii]; TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); int occupancy = occupancies[ii]; @@ -152,8 +151,8 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector(ctas_per_wave); + const float current_score = static_cast(num_waves_total) - num_waves_fractional; const float score_slack = 0.1f; if (current_score < config_score || @@ -186,5 +185,3 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector #include @@ -38,4 +37,3 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector -using void_t = void; - -template -struct use_dq_gemm : platform::false_type {}; - -template -struct use_dq_gemm> : platform::true_type {}; - -// SFINAE overload for dequantizing gemm -template ::value, bool>::type = true> -CUTLASS_DEVICE static void run_mma(Mma mma, int gemm_k_iterations, typename Mma::FragmentC& accum, - typename Mma::IteratorA iterator_A, typename Mma::IteratorB iterator_B, - typename Mma::FragmentC const& src_accum, ElementScale* weight_scale_ptr, - MatrixCoord scale_extent, const int thread_idx, MatrixCoord tb_offset_scale) { - typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), weight_scale_ptr, - scale_extent, thread_idx, tb_offset_scale); - - mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_scale, src_accum); -} - -// SFINAE overload for normal gemm. This completely ignores the scale parameters -template ::value, bool>::type = true> -CUTLASS_DEVICE static void run_mma(Mma mma, int gemm_k_iterations, typename Mma::FragmentC& accum, - typename Mma::IteratorA iterator_A, typename Mma::IteratorB iterator_B, - typename Mma::FragmentC const& src_accum, ElementScale* weight_scale_ptr, - MatrixCoord scale_extent, const int thread_idx, MatrixCoord tb_offset_scale) { - mma(gemm_k_iterations, accum, iterator_A, iterator_B, src_accum); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MoeFCGemm { - public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - static bool const kTransposed = false; - - // Optional transpose - using MapArguments = - kernel::detail::MapArguments; - - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. - static_assert(!kTransposed, "Transpose problem not supported"); - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename MapArguments::LayoutC; - using ElementScale = ElementC; - - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; - - // Type definitions about the mainloop. - using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using ProblemVisitor = - GemmMoeProblemVisitor; - - // - // Structures - // - - /// Argument structure - struct Arguments { - // - // Data members - // - - int problem_count; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA* ptr_A; - ElementB* ptr_B; - ElementScale* weight_scales; - ElementC* ptr_C; - ElementC* ptr_D; - - int64_t* total_rows_before_expert; - int64_t gemm_n; - int64_t gemm_k; - - // Only used by device-level operator - GemmCoord* host_problem_sizes; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() - : problem_count(0), - threadblock_count(0), - ptr_A(nullptr), - ptr_B(nullptr), - weight_scales(nullptr), - ptr_C(nullptr), - ptr_D(nullptr), - total_rows_before_expert(nullptr), - gemm_n(0), - gemm_k(0), - host_problem_sizes(nullptr) {} - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments(int problem_count, int threadblock_count, typename EpilogueOutputOp::Params output_op, - const ElementA* ptr_A, const ElementB* ptr_B, const ElementScale* weight_scales, const ElementC* ptr_C, - ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, - GemmCoord* host_problem_sizes = nullptr) - : problem_count(problem_count), - threadblock_count(threadblock_count), - output_op(output_op), - ptr_A(const_cast(ptr_A)), - ptr_B(const_cast(ptr_B)), - weight_scales(const_cast(weight_scales)), - ptr_C(const_cast(ptr_C)), - ptr_D(ptr_D), - total_rows_before_expert(total_rows_before_expert), - gemm_n(gemm_n), - gemm_k(gemm_k), - host_problem_sizes(nullptr) { - if (platform::is_same::value || platform::is_same::value) { - assert(weight_scales); - } - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params { - typename ProblemVisitor::Params problem_visitor; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA* ptr_A; - ElementB* ptr_B; - ElementScale* weight_scales; - ElementC* ptr_C; - ElementC* ptr_D; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() : ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr) {} - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - : problem_visitor(args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, - tile_count), - threadblock_count(args.threadblock_count), - output_op(args.output_op), - ptr_A(args.ptr_A), - ptr_B(args.ptr_B), - weight_scales(args.weight_scales), - ptr_C(args.ptr_C), - ptr_D(args.ptr_D) {} - - CUTLASS_HOST_DEVICE - void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { - problem_visitor = typename ProblemVisitor::Params(args.total_rows_before_expert, args.gemm_n, args.gemm_k, - args.problem_count, workspace, tile_count); - threadblock_count = args.threadblock_count; - output_op = args.output_op; - ptr_A = args.ptr_A; - ptr_B = args.ptr_B; - weight_scales = args.weight_scales; - ptr_C = args.ptr_C; - ptr_D = args.ptr_D; - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename ProblemVisitor::SharedStorage problem_visitor; - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - public: - // - // Methods - // - - CUTLASS_DEVICE - MoeFCGemm() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { return Status::kSuccess; } - - static Status can_implement(Arguments const& args) { - if (args.weight_scales != nullptr) { - CUTLASS_TRACE_HOST( - "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); - return Status::kInvalid; - } - return Status::kSuccess; - } - - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { - return 0; - } - - // The dummy template parameter is not used and exists so that we can compile this code using - // a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in - // a namespace - template - struct KernelRunner { - CUTLASS_DEVICE - static void run_kernel(Params const& params, SharedStorage& shared_storage) { CUTLASS_NOT_IMPLEMENTED(); } - }; - - template - struct KernelRunner { - CUTLASS_DEVICE - static void run_kernel(Params const& params, SharedStorage& shared_storage) { - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. - // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - static_assert(platform::is_same::value && kInterleave == 1 || - platform::is_same::value && kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // - // Problem visitor. - // - ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - - const int64_t gemm_k = params.problem_visitor.gemm_k; - const int64_t gemm_n = params.problem_visitor.gemm_n; - int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; - - // Outer 'persistent' loop to iterate over tiles - while (problem_visitor.next_tile()) { - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); - - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - - cutlass::gemm::GemmCoord threadblock_offset(int(cta_idx / grid_shape.n()) * Mma::Shape::kM, - int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); - - // Load element pointers. Exchange pointers and strides if working on the transpose - const int64_t rows_to_jump = - problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; - ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; - typename LayoutA::LongIndex ldm_A = gemm_k; - - char* byte_ptr_B = ((char*)params.ptr_B) + problem_idx * bytes_per_expert_matrix; - ElementB* ptr_B = reinterpret_cast(byte_ptr_B); - typename LayoutB::LongIndex ldm_B = - platform::is_same::value ? gemm_n : gemm_k * kInterleave; - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_offset.m(), - 0, - }; - - cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; - - cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, - tb_offset_A); - - typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, - {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, - tb_offset_B); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - - int lane_idx = threadIdx.x % 32; - - // - // Matrix multiply phase - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Wait for all threads to finish their epilogue phases from the previous tile. - __syncthreads(); - - // Compute threadblock-scoped matrix multiply-add - ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); - run_mma(mma, gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, weight_scale_ptr, - {1, problem_size.n()}, thread_idx, tb_offset_scale); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - ElementC* ptr_C = reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; - ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; - - LayoutC layout_C(0); - LayoutC layout_D(gemm_n); - - typename Epilogue::OutputTileIterator::Params params_C(layout_C); - typename Epilogue::OutputTileIterator::Params params_D(layout_D); - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C(params_C, ptr_C, problem_size.mn(), thread_idx, - threadblock_offset.mn()); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D(params_D, ptr_D, problem_size.mn(), thread_idx, - threadblock_offset.mn()); - - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); - - // Next tile - problem_visitor.advance(gridDim.x); - } - } - }; - - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) - static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) - static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h index a30bd1c1e9df..7e29dde8f897 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -14,12 +14,10 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - #pragma once +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h" #include -#include "ft_gemm_configs.h" namespace ort_fastertransformer { @@ -44,8 +42,9 @@ class MoeGemmRunner { int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream); - void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert, - int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream); + void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cudaStream_t stream); private: template @@ -64,5 +63,3 @@ class MoeGemmRunner { }; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu index 1d0dfe7c5a64..15cab9dd4a9b 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu @@ -13,13 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif -#ifdef USE_CUTLASS - -#include "moe_gemm_kernels_template.h" +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" +#if defined(_MSC_VER) +#pragma warning(pop) +#endif namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint4.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint4.cu new file mode 100644 index 000000000000..1309a7c32a37 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint4.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif +namespace ort_fastertransformer { +template class MoeGemmRunner; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu index 7a5d97902ee8..0277fab9df95 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif -#ifdef USE_CUTLASS +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" -#include "moe_gemm_kernels_template.h" +#if defined(_MSC_VER) +#pragma warning(pop) +#endif namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index 3fd0fc47055a..d81808e217fb 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -14,29 +14,38 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - // Ignore CUTLASS warnings about type punning #ifdef __GNUC__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" #endif +// Ignore CUTLASS warning C4100: unreferenced formal parameter +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + +#include "cutlass/arch/arch.h" #include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/arch/arch.h" #include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" -#include "compute_occupancy.h" -#include "epilogue_helpers.h" -#include "layout_traits_helper.h" -#include "moe_cutlass_kernel.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif #ifdef __GNUC__ #pragma GCC diagnostic pop @@ -59,10 +68,6 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream, int* kernel_occupancy = nullptr) { - if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) { - ORT_THROW("[FT Error][MoeGemm] Grouped gemm does not support split-k"); - } - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, "Specialized for half, float"); @@ -79,10 +84,11 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w using CutlassWeightType_ = typename cutlass::platform::conditional::value, cutlass::half_t, WeightType>::type; + using CutlassWeightType = CutlassWeightType_; - // We need separate config for each architecture since we will target different tensorcore instructions. For float, - // we do not target TCs. + // We need separate config for each architecture since we will target different tensorcore instructions. For + // float, we do not target TCs. using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; using ElementAccumulator = typename MixedGemmArchTraits::AccType; @@ -111,17 +117,17 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w return; } int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); - if (occupancy == 0) { - ORT_THROW("[FT Error][MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM kernel"); - } - const int threadblock_count = multi_processor_count * occupancy; + ORT_ENFORCE(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); + int const threadblock_count = multi_processor_count * occupancy; - typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); + typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), + biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); + int const group_size = gemm_k; typename GemmGrouped::Arguments args( - num_experts, threadblock_count, epilogue_op, reinterpret_cast(A), - reinterpret_cast(B), reinterpret_cast(weight_scales), - reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, gemm_n, + num_experts, threadblock_count, group_size, epilogue_op, reinterpret_cast(A), + reinterpret_cast(B), reinterpret_cast(weight_scales), + reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, gemm_n, gemm_k); GemmGrouped gemm; @@ -151,10 +157,10 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w template struct dispatch_stages { - static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, - CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) { + static void dispatch(const T* /*A*/, const WeightType* /*B*/, const T* /*weight_scales*/, const T* /*biases*/, + T* /*C*/, int64_t* /*total_rows_before_expert*/, int64_t /*gemm_n*/, int64_t /*gemm_k*/, + int /*num_experts*/, CutlassGemmConfig /*gemm_config*/, int /*multi_processor_count*/, + cudaStream_t /*stream*/, [[maybe_unused]] int* occupancy = nullptr) { std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); ORT_THROW("[FT Error][dispatch_stages::dispatch] " + err_msg); @@ -223,10 +229,28 @@ template < typename T, typename WeightType, typename arch, typename EpilogueTag, typename std::enable_if::value && std::is_same::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, + int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + ORT_ENFORCE(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) { + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + } + break; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + ORT_ENFORCE(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) { + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 64, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + } + break; case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: dispatch_gemm_config, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, @@ -246,13 +270,13 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig gemm_config, multi_processor_count, stream, occupancy); break; case CutlassTileConfig::Undefined: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + ORT_THROW("GEMM config undefined."); break; case CutlassTileConfig::ChooseWithHeuristic: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + ORT_THROW("GEMM config should have already been set by heuristic."); break; default: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for same type MoE tensorop GEMM."); + ORT_THROW("Config is invalid for same type tensorop GEMM."); break; } } @@ -302,8 +326,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig template ::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, + int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: @@ -395,20 +419,20 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightTyp cudaStream_t stream) { switch (activation_type) { case ActivationType::Relu: - run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - num_experts, stream); + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, stream); break; case ActivationType::Gelu: - run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, - gemm_k, num_experts, stream); + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, stream); break; case ActivationType::Silu: - run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - num_experts, stream); + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, stream); break; case ActivationType::Identity: - run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - num_experts, stream); + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); break; case ActivationType::InvalidType: ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); @@ -420,13 +444,11 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightTyp } template -void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, +void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, + T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream) { - run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - num_experts, stream); + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); } } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 9232e8d01293..360c0aacd9c7 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -16,13 +16,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - +#include #include #include #include #include -#include // Ignore CUTLASS warnings about type punning #ifdef __GNUC__ @@ -32,7 +30,6 @@ #include "cutlass/array.h" #include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" #ifdef __GNUC__ #pragma GCC diagnostic pop @@ -51,7 +48,6 @@ #endif namespace ort_fastertransformer { - static constexpr int WARP_SIZE = 32; // ====================== Softmax things =============================== @@ -110,14 +106,15 @@ __launch_bounds__(TPB) __global__ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 template -__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, const int) { +__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, int, bool) { // Does not support pre-Kepler architectures ; } #else template -__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, const bool* finished, T* output, - int* indices, int* source_rows, int num_experts, int k) { +__launch_bounds__(TPB) __global__ + void moe_top_k(const T* inputs_after_softmax, const bool* finished, T* output, int* indices, int* source_rows, + int num_experts, int k, bool normalize_routing_weights) { using cub_kvp = cub::KeyValuePair; using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -130,6 +127,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, const bool should_process_row = finished ? !finished[block_row] : true; const int thread_read_offset = blockIdx.x * num_experts; + float output_row_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { thread_kvp.key = 0; thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities @@ -157,6 +155,13 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, output[idx] = result_kvp.value; indices[idx] = should_process_row ? result_kvp.key : num_experts; source_rows[idx] = k_idx * num_rows + block_row; + + if (normalize_routing_weights && k_idx == k - 1) { +#pragma unroll + for (int ki = 0; ki < k; ++ki) { + output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); + } + } } __syncthreads(); } @@ -180,7 +185,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topk_gating_softmax(const T* input, const bool* finished, T* output, int num_rows, int* indices, - int* source_rows, int k) { + int* source_rows, int k, bool normalize_routing_weights) { // We begin by enforcing compile time assertions and setting up compile time constants. static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); @@ -298,6 +303,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ int start_col = first_elt_read_by_thread; static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + float output_row_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { // First, each thread does the local argmax float max_val = row_chunk[0]; @@ -338,8 +344,16 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ // single) thread per row of the input/output matrices. const int idx = k * thread_row + k_idx; output[idx] = T(max_val); + output_row_sum = output_row_sum + static_cast(max_val); indices[idx] = should_process_row ? expert : NUM_EXPERTS; source_rows[idx] = k_idx * num_rows + thread_row; + + if (normalize_routing_weights && k_idx == k - 1) { +#pragma unroll + for (int ki = 0; ki < k; ++ki) { + output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); + } + } } // Finally, we clear the value in the thread with the current max if there is another iteration to run. @@ -372,7 +386,8 @@ struct TopkConstants { template void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row, - int num_rows, int num_experts, int k, cudaStream_t stream) { + int num_rows, int /*num_experts*/, int k, bool normalize_routing_weights, + cudaStream_t stream) { static constexpr unsigned long MAX_BYTES_PER_LDG = 16; static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); @@ -383,62 +398,62 @@ void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; dim3 block_dim(WARP_SIZE, WARPS_PER_TB); - topk_gating_softmax - <<>>(input, finished, output, num_rows, indices, source_row, k); + topk_gating_softmax<<>>( + input, finished, output, num_rows, indices, source_row, k, normalize_routing_weights); } template void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_output, - int* indices, int* source_row, int num_rows, int num_experts, - int k, cudaStream_t stream) { + int* indices, int* source_row, int num_rows, int num_experts, int k, + bool normalize_routing_weights, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; switch (num_experts) { case 2: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 4: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 8: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 16: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 32: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 64: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 128: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 256: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } default: { static constexpr int TPB = 256; moe_softmax<<>>(input, finished, softmax_temp_output, num_experts); - moe_top_k - <<>>(softmax_temp_output, finished, output, indices, source_row, num_experts, k); + moe_top_k<<>>(softmax_temp_output, finished, output, indices, source_row, + num_experts, k, normalize_routing_weights); } } } @@ -505,8 +520,8 @@ __global__ void compute_total_rows_before_expert_kernel(const int* sorted_expert total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); } -__global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, int num_experts, - int local_num_experts, int local_experts_start_index) { +__global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, int num_experts, int local_num_experts, + int local_experts_start_index) { const int expert = blockIdx.x * blockDim.x + threadIdx.x; const int local_experts_end_index = local_experts_start_index + local_num_experts - 1; @@ -523,25 +538,30 @@ __global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, i } template -CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version) { - total_past_rows_ = 0; - total_covered_rows_ = 0; +CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, bool has_fc3, + bool normalize_routing_weights) + : has_fc3_(has_fc3), + total_past_rows_(0), + total_covered_rows_(0), + normalize_routing_weights_(normalize_routing_weights) { moe_gemm_runner_.initialize(sm_version); } template -size_t CutlassMoeFCRunner::getWorkspaceSize(int num_rows, const int hidden_size, - const int inter_size, int num_experts, - int k) { - const int buf_size = static_cast(pad_to_multiple_of_16(k * num_rows * hidden_size)); - const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size)); - const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts)); - const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows)); - int num_softmax_outs = 0; +size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_rows, const size_t hidden_size, + const size_t inter_size, size_t num_experts, + size_t k) { + total_covered_rows_ = k * num_rows; + + const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size); + const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size); + const size_t padded_experts = pad_to_multiple_of_16(num_experts); + const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); + size_t num_softmax_outs = 0; const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) { - num_softmax_outs = static_cast(pad_to_multiple_of_16(num_rows * num_experts)); + num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts); } // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them @@ -550,13 +570,13 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(int num_rows, total_ws_bytes += buf_size * sizeof(T); // permuted_data total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ total_ws_bytes += num_softmax_outs * sizeof(T); - const int bytes_for_fc1_result = interbuf_size * sizeof(T); - const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows))); - sorter_.update_num_experts(num_experts); + const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); + const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)); + sorter_.update_num_experts(static_cast(num_experts)); - int bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; if (sorter_ws_size_bytes > bytes_for_fc1_result) { - int remaining_bytes = static_cast(pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result)); + size_t remaining_bytes = pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result); bytes_for_intermediate_and_sorting += remaining_bytes; } @@ -565,39 +585,140 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(int num_rows, } template -void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, int num_rows, - const int hidden_size, const int inter_size, - int num_experts, int k) { - const int buf_size = static_cast(pad_to_multiple_of_16(k * num_rows * hidden_size)); - const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size)); - const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts)); - const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows)); - - source_rows_ = (int*)ws_ptr; +void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, size_t num_rows, + const size_t hidden_size, const size_t inter_size, + size_t num_experts, size_t k) { + const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size); + const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size); + const size_t padded_experts = pad_to_multiple_of_16(num_experts); + const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); + + source_rows_ = reinterpret_cast(ws_ptr); permuted_rows_ = source_rows_ + num_moe_inputs; permuted_experts_ = permuted_rows_ + num_moe_inputs; - permuted_data_ = (T*)(permuted_experts_ + num_moe_inputs); + permuted_data_ = reinterpret_cast(permuted_experts_ + num_moe_inputs); - total_rows_before_expert_ = (int64_t*)(permuted_data_ + buf_size); + total_rows_before_expert_ = reinterpret_cast(permuted_data_ + buf_size); - fc1_result_ = (T*)(total_rows_before_expert_ + padded_experts); + if (has_fc3_) { + fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size); + } else { + fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + } const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) { - softmax_out_ = (T*)(fc1_result_ + interbuf_size); + softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); } else { softmax_out_ = nullptr; } } +namespace { + +struct __align__(8) Half4 { + half2 x; + half2 y; +}; + +// TODO(wy): move to common header +template +struct T4; +template <> +struct T4 { + using Type = float4; +}; +template <> +struct T4 { + using Type = Half4; +}; + +template +struct T2; +template <> +struct T2 { + using Type = float2; +}; +template <> +struct T2 { + using Type = half2; +}; + +inline __device__ float2 operator*(const float2 a, const float2 b) { return make_float2(a.x * b.x, a.y * b.y); } + +inline __device__ float4 operator*(const float4 a, const float4 b) { + return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); +} + +// TODO(wy): use cuda common header and investigate pipeline build issue. +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ + ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) +inline __device__ half operator*(const half a, const half b) { + return __float2half(__half2float(a) * __half2float(b)); +} + +inline __device__ half2 operator*(const half2 a, const half2 b) { + return make_half2(a.x * b.x, a.y * b.y); +} +#endif + +// TODO(wy): use cuda common header and investigate pipeline build issue. +inline __device__ Half4 operator*(const Half4 a, const Half4 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ + ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + Half4 result; + result.x = a.x * b.x; + result.y = a.y * b.y; + return result; +#else + return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; +#endif +} + +} // anonymous namespace + +template +__global__ void elementWiseMulKernel(T* output, T const* input, size_t inter_size) { + int const tid = threadIdx.x; + int const token = blockIdx.x; + + output = output + token * inter_size; + input = input + token * inter_size; + for (int i = tid; i < inter_size; i += blockDim.x) { + T fc1_value = input[i]; + output[i] = fc1_value * output[i]; + } +} + +template +void elementWiseMul(T* output, T const* input, int inter_size, int num_tokens, cudaStream_t stream) { + int const blocks = num_tokens; + + if (inter_size & 3 == 0) { + using vec_type = typename T4::Type; + int const threads = std::min(inter_size / 4, 1024); + elementWiseMulKernel<<>>( + reinterpret_cast(output), reinterpret_cast(input), inter_size / 4); + } else if (inter_size & 1 == 0) { + using vec_type = typename T2::Type; + int const threads = std::min(inter_size / 2, 1024); + elementWiseMulKernel<<>>( + reinterpret_cast(output), reinterpret_cast(input), inter_size / 2); + } else { + int const threads = std::min(inter_size, 1024); + elementWiseMulKernel<<>>(output, input, inter_size); + } +} + template void CutlassMoeFCRunner::run_moe_fc( const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, - const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, - const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, - int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, - const bool* finished, int active_rows, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, - int* expert_for_source_row, cudaStream_t stream) { + const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc3_expert_weights, + const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, + int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts, + int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, + T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { static constexpr bool scales_required = std::is_same::value || std::is_same::value; @@ -615,13 +736,14 @@ void CutlassMoeFCRunner::run_moe_fc( } } - configure_ws_ptrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts, k); + configure_ws_ptrs(workspace_ptr, static_cast(num_rows), static_cast(hidden_size), + static_cast(inter_size), static_cast(num_experts), static_cast(k)); topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, - source_rows_, num_rows, num_experts, k, stream); + source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream); const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows))); - sorter_.run((void*)fc1_result_, sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_, - permuted_rows_, k * num_rows, stream); + sorter_.run(reinterpret_cast(fc1_result_), sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, + source_rows_, permuted_rows_, k * num_rows, stream); initialize_moe_routing_kernelLauncher(input_activations, permuted_data_, permuted_rows_, expanded_source_row_to_expanded_dest_row, num_rows, active_rows, hidden_size, k, @@ -635,33 +757,63 @@ void CutlassMoeFCRunner::run_moe_fc( dispatch_activations(total_rows_before_expert_, num_experts, local_num_experts, local_experts_start_index, stream); } - // expanded_active_expert_rows is not used - moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size, - fc1_expert_weights, fc1_scales, fc1_expert_biases, - fc1_result_ + total_past_rows_ * inter_size, - total_rows_before_expert_ + local_experts_start_index, - expanded_active_expert_rows, inter_size, hidden_size, - local_num_experts, fc1_activation_type, stream); + moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, + fc1_expert_biases, fc1_result_ + total_past_rows_ * inter_size, + total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, + inter_size, hidden_size, local_num_experts, fc1_activation_type, stream); + + if (has_fc3_) { + if (scales_required) { + if (fc3_scales == nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for third matmul is a null pointer"); + } + } else { + if (fc3_scales != nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC3"); + } + } + if (fc3_expert_weights == nullptr) { + ORT_THROW("[FT Error][Run MoE FC] FC3 weights are null"); + } + moe_gemm_runner_.moe_gemm(permuted_data_ + total_past_rows_ * hidden_size, fc3_expert_weights, fc3_scales, + fc3_expert_biases, fc3_result_ + total_past_rows_ * inter_size, + total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, + inter_size, hidden_size, local_num_experts, stream); - moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size, - fc2_expert_weights, fc2_scales, + elementWiseMul(fc1_result_ + total_past_rows_ * inter_size, fc3_result_ + total_past_rows_ * inter_size, + static_cast(inter_size), static_cast(total_covered_rows_), stream); + } + + moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size, fc2_expert_weights, fc2_scales, nullptr, fc2_result + total_past_rows_ * hidden_size, - total_rows_before_expert_ + local_experts_start_index, - expanded_active_expert_rows, hidden_size, inter_size, local_num_experts, stream); + total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, + hidden_size, inter_size, local_num_experts, stream); } +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 +template +void CutlassMoeFCRunner::run_moe_fc(const T*, const T*, const WeightType*, const T*, const T*, + ActivationType, const WeightType*, const T*, const T*, + const WeightType*, const T*, int, const int, const int, int, + int, int, int k, char*, T*, T*, int*, int*, cudaStream_t) { + // MoE gemm only supports Volta+ architectures + ORT_THROW("[FT Error][Run MoE FC] MoE gemm only supports Volta+ architectures"); +} +#else template void CutlassMoeFCRunner::run_moe_fc( const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, - const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, - const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, - int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales, + const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc3_expert_weights, + const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, + int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts, + int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type, - fc2_expert_weights, fc2_scales, num_rows, hidden_size, inter_size, num_experts, local_num_experts, - local_experts_start_index, k, workspace_ptr, fc2_result, nullptr, num_rows, expert_scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, stream); + fc3_expert_weights, fc3_scales, fc3_expert_biases, fc2_expert_weights, fc2_scales, num_rows, hidden_size, + inter_size, num_experts, local_num_experts, local_experts_start_index, k, workspace_ptr, fc2_result, + nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, stream); } +#endif template void CutlassMoeFCRunner::compute_total_rows_before_expert(const int* sorted_indices, @@ -677,8 +829,8 @@ void CutlassMoeFCRunner::compute_total_rows_before_expert } template -void CutlassMoeFCRunner::dispatch_activations(int64_t* total_rows_before_expert, - int num_experts, int local_num_experts, +void CutlassMoeFCRunner::dispatch_activations(int64_t* total_rows_before_expert, int num_experts, + int local_num_experts, int local_experts_start_index, cudaStream_t stream) { total_rows_before_expert_host_.resize(num_experts); @@ -692,16 +844,15 @@ void CutlassMoeFCRunner::dispatch_activations(int64_t* to cudaEventCreateWithFlags(©_event, cudaEventDisableTiming); cudaEventRecord(copy_event, stream); - dispatch_activations_kernel<<>>(total_rows_before_expert, num_experts, - local_num_experts, local_experts_start_index); + dispatch_activations_kernel<<>>(total_rows_before_expert, num_experts, local_num_experts, + local_experts_start_index); get_total_rows_info(local_experts_start_index, local_num_experts, total_past_rows_, total_covered_rows_); } template void CutlassMoeFCRunner::get_total_rows_info(int64_t experts_start_index, - int64_t local_num_experts, - int64_t& total_past_rows, + int64_t local_num_experts, int64_t& total_past_rows, int64_t& total_covered_rows) { int64_t experts_end_index = experts_start_index + local_num_experts - 1; total_past_rows = 0; @@ -758,8 +909,8 @@ __global__ void initialize_moe_routing_kernel(const T* unpermuted_input, T* perm template void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, T* permuted_output, const int* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, int num_rows, - int active_rows, int cols, int k, cudaStream_t stream) { + int* expanded_source_row_to_expanded_dest_row, int num_rows, int active_rows, + int cols, int k, cudaStream_t stream) { const int blocks = num_rows * k; const int threads = std::min(cols, 1024); initialize_moe_routing_kernel @@ -813,9 +964,10 @@ __global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows, T* const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; const int expert_idx = expert_for_source_row[k_offset]; - const T* bias_ptr = bias + expert_idx * cols; + const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; - thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]); + thread_output = + thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + (bias_ptr ? bias_ptr[tid] : T(0))); } reduced_row_ptr[tid] = thread_output; } @@ -825,8 +977,8 @@ __global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows, T* template void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int num_rows, int cols, - int k, cudaStream_t stream) { + const int* expert_for_source_row, int num_rows, int cols, int k, + cudaStream_t stream) { const int blocks = num_rows; const int threads = std::min(cols, 1024); finalize_moe_routing_kernel<<>>( @@ -838,8 +990,8 @@ template void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int num_rows, int cols, - int k, cudaStream_t stream) { + const int* expert_for_source_row, int num_rows, int cols, int k, + cudaStream_t stream) { const int blocks = num_rows; const int threads = std::min(cols, 1024); finalize_moe_routing_kernel @@ -851,8 +1003,8 @@ template void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, const T* skip_2, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int num_rows, int cols, - int k, cudaStream_t stream) { + const int* expert_for_source_row, int num_rows, int cols, int k, + cudaStream_t stream) { const int blocks = num_rows; const int threads = std::min(cols, 1024); if (skip_2 == nullptr) { @@ -867,20 +1019,21 @@ void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* red } // ========================= TopK Softmax specializations =========================== -template void topk_gating_softmax_kernelLauncher(const float*, const bool*, float*, float*, int*, int*, int, - int, int, cudaStream_t); -template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, - int, int, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const float*, const bool*, float*, float*, int*, int*, int, int, int, + bool, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, int, int, + bool, cudaStream_t); // ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner; // ===================== Specializations for init routing ========================= -template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, - int, int, cudaStream_t); -template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, - int, int, cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, int, int, + cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, int, int, + cudaStream_t); // ==================== Specializations for final routing =================================== template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*, @@ -888,17 +1041,12 @@ template void finalize_moe_routing_kernelLauncher(const float*, float*, const fl template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const int*, const int*, int, int, int, cudaStream_t); template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, - const int*, const int*, int, int, int, - cudaStream_t); + const int*, const int*, int, int, int, cudaStream_t); template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, const int*, const int*, int, int, int, cudaStream_t); template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, - const float*, const int*, const int*, int, int, int, - cudaStream_t); + const float*, const int*, const int*, int, int, int, cudaStream_t); template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, - const half*, const int*, const int*, int, int, int, - cudaStream_t); + const half*, const int*, const int*, int, int, int, cudaStream_t); } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index f09471de1cc2..5eef6f95f482 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -16,8 +16,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "moe_gemm_kernels.h" @@ -26,6 +24,8 @@ #include "core/common/common.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" +#include "cutlass/numeric_types.h" + using namespace onnxruntime; namespace ort_fastertransformer { @@ -109,12 +109,13 @@ template class CutlassMoeFCRunner { public: - CutlassMoeFCRunner(int sm_version); + CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights); - size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k); + size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k); void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, + const WeightType* fc3_expert_weights, const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size, int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, @@ -122,6 +123,7 @@ class CutlassMoeFCRunner { void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, + const WeightType* fc3_expert_weights, const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size, int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, T* expert_scales, @@ -137,7 +139,8 @@ class CutlassMoeFCRunner { int64_t& total_covered_rows); private: - void configure_ws_ptrs(char* ws_ptr, int num_rows, int hidden_size, int inter_size, int num_experts, int k); + void configure_ws_ptrs(char* ws_ptr, size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, + size_t k); private: CubKeyValueSorter sorter_; @@ -154,12 +157,17 @@ class CutlassMoeFCRunner { int64_t* total_rows_before_expert_; T* fc1_result_; + T* fc3_result_; + + bool has_fc3_; + bool normalize_routing_weights_; // Cuda events contrib::cuda::AutoDestoryCudaEvent cuda_event_; int64_t total_past_rows_; int64_t total_covered_rows_; + // TODO: use pinned memory std::vector total_rows_before_expert_host_; }; @@ -167,13 +175,11 @@ class CutlassMoeFCRunner { template class CutlassMoeFCRunner::value>> { public: - CutlassMoeFCRunner(int sm_version); + CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights); - size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k) { + size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) { return 0; } }; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index 0da06192e266..dbd783c0cb11 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "moe.h" @@ -15,39 +13,33 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - MoE, \ - kMSDomain, \ - 1, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(0, 0) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - MoE); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + MoE, kMSDomain, 1, T, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), MoE); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) -using namespace ONNX_NAMESPACE; - template -MoE::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { -} +MoE::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {} template Status MoE::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* router_probs = context->Input(1); const Tensor* fc1_experts_weights = context->Input(2); - const Tensor* fc2_experts_weights = context->Input(3); - const Tensor* fc1_experts_bias_optional = context->Input(4); + const Tensor* fc1_experts_bias_optional = context->Input(3); + const Tensor* fc2_experts_weights = context->Input(4); const Tensor* fc2_experts_bias_optional = context->Input(5); + const Tensor* fc3_experts_weights_optional = context->Input(6); + const Tensor* fc3_experts_bias_optional = context->Input(7); MoEParameters moe_params; - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights, - fc1_experts_bias_optional, fc2_experts_bias_optional)); + MoEQuantType quant_type = MoEQuantType::None; + ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, + fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, + fc3_experts_weights_optional, fc3_experts_bias_optional)); typedef typename ToCudaType::MappedType CudaT; auto stream = context->GetComputeStream(); @@ -55,12 +47,12 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); const int sm = device_prop.major * 10 + device_prop.minor; - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, + normalize_routing_weights_); - size_t ws_size = - moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), - static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), - static_cast(k_)); + size_t ws_size = moe_runner.getWorkspaceSize( + static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), static_cast(k_)); size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT); size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT); size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); @@ -79,26 +71,29 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { IAllocatorUniquePtr expert_for_source_row = IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); - // fc1_scales and fc2_scales are used in quantized MoE - const CudaT* fc1_scales_ptr = nullptr; - const CudaT* fc2_scales_ptr = nullptr; - - moe_runner.run_moe_fc(reinterpret_cast(input->template Data()), - reinterpret_cast(router_probs->template Data()), - reinterpret_cast(fc1_experts_weights->template Data()), - std::move(fc1_scales_ptr), - fc1_experts_bias_optional == nullptr - ? nullptr - : reinterpret_cast(fc1_experts_bias_optional->template Data()), - activation_type_, reinterpret_cast(fc2_experts_weights->template Data()), - std::move(fc2_scales_ptr), static_cast(moe_params.num_rows), - static_cast(moe_params.hidden_size), static_cast(moe_params.inter_size), - static_cast(moe_params.num_experts), static_cast(moe_params.local_num_experts), - 0 /*local_experts_start_index_ used in sharded MoE*/, static_cast(k_), - reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()), - reinterpret_cast(expert_scales.get()), - reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), - reinterpret_cast(expert_for_source_row.get()), Stream(context)); + const CudaT* fc_scales_ptr = nullptr; + moe_runner.run_moe_fc( + reinterpret_cast(input->template Data()), + reinterpret_cast(router_probs->template Data()), + reinterpret_cast(fc1_experts_weights->DataRaw()), fc_scales_ptr, + fc1_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc1_experts_bias_optional->template Data()), + activation_type_, + fc3_experts_weights_optional == nullptr ? nullptr + : reinterpret_cast(fc3_experts_weights_optional->DataRaw()), + fc_scales_ptr, + fc3_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_bias_optional->template Data()), + reinterpret_cast(fc2_experts_weights->DataRaw()), fc_scales_ptr, + static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), + static_cast(moe_params.local_num_experts), 0 /*local_experts_start_index_ used in sharded MoE*/, + static_cast(k_), reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), Stream(context)); Tensor* output = context->Output(0, input->Shape()); @@ -110,8 +105,7 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { reinterpret_cast(expert_scales.get()), reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), reinterpret_cast(expert_for_source_row.get()), static_cast(moe_params.num_rows), - static_cast(moe_params.hidden_size), - static_cast(k_), Stream(context)); + static_cast(moe_params.hidden_size), static_cast(k_), Stream(context)); return Status::OK(); } @@ -119,5 +113,3 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h index 710b914f0633..c4d8c4dc64c5 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" @@ -26,5 +24,3 @@ class MoE final : public CudaKernel, public MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index dc8b9d57f79f..4a407fa1b215 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -1,11 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "core/common/common.h" +#include "core/framework/tensor_shape.h" #include "core/framework/op_kernel.h" #include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h" @@ -15,27 +14,36 @@ namespace cuda { enum class MoEParallelType { None = 0, - ExpertSlicing = 1, + EP = 1, + TP = 2, + EPAndTP = 3, +}; + +enum class MoEQuantType { + None = 0, + UINT4 = 1, }; struct MoEParameters { + MoEParameters() {} + explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {} int64_t num_rows; int64_t num_experts; int64_t local_num_experts; int64_t hidden_size; int64_t inter_size; + MoEParallelType parallel_type; + int64_t tensor_shards{1}; }; class MoEBase { public: - Status CheckInputs(MoEParameters& parameters, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc2_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_bias_optional) const { + Status CheckInputs(MoEParameters& parameters, MoEQuantType& quant_type, const Tensor* input, + const Tensor* router_probs, const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias_optional, const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias_optional, const Tensor* fc3_experts_weights_optional, + const Tensor* fc3_experts_bias_optional) const { const auto& input_dims = input->Shape().GetDims(); const auto& router_probs_dims = router_probs->Shape().GetDims(); const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); @@ -45,7 +53,7 @@ class MoEBase { int64_t hidden_size = input_dims[input_dims.size() - 1]; int64_t local_num_experts = fc1_experts_weights_dims[0]; int64_t num_experts = router_probs_dims[1]; - int64_t inter_size = fc1_experts_weights_dims[2]; + int64_t inter_size = fc2_experts_weights_dims[1]; if (fc1_experts_weights_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ", @@ -63,20 +71,21 @@ class MoEBase { if (fc2_experts_weights_dims[1] != inter_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims[1] must be equal to inter_size, got ", - fc2_experts_weights_dims[1], - " and ", inter_size); + fc2_experts_weights_dims[1], " and ", inter_size); } - if (fc1_experts_weights_dims[2] != inter_size) { + + const int64_t coe = quant_type == MoEQuantType::UINT4 ? 2 : 1; + if (fc1_experts_weights_dims[2] != inter_size / coe) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims[2] must be equal to inter_size, got ", - fc1_experts_weights_dims[2], - " and ", inter_size); + fc1_experts_weights_dims[2], " and ", inter_size); } - if (fc2_experts_weights_dims[2] != hidden_size) { + if (fc2_experts_weights_dims[2] != hidden_size / coe) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", fc2_experts_weights_dims[2], " and ", hidden_size); } + if (router_probs_dims.size() != 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ", router_probs_dims.size()); @@ -85,12 +94,6 @@ class MoEBase { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ", router_probs_dims[0], " and ", num_rows); } - if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is set but fc2_experts_bias is not set"); - } - if (fc1_experts_bias_optional == nullptr && fc2_experts_bias_optional != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is not set but fc2_experts_bias is set"); - } if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) { const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims(); const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims(); @@ -105,42 +108,99 @@ class MoEBase { if (fc1_experts_bias_dims[0] != local_num_experts) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims[0] must be equal to local_num_experts, got ", - fc1_experts_bias_dims[0], - " and ", local_num_experts); + fc1_experts_bias_dims[0], " and ", local_num_experts); } if (fc2_experts_bias_dims[0] != num_experts) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_bias_dims[0] must be equal to num_experts, got ", - fc2_experts_bias_dims[0], + "fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0], " and ", num_experts); } if (fc1_experts_bias_dims[1] != inter_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_bias_dims[1] must be equal to inter_size, got ", - fc1_experts_bias_dims[1], + "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1], " and ", inter_size); } if (fc2_experts_bias_dims[1] != hidden_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", - fc2_experts_bias_dims[1], + "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", fc2_experts_bias_dims[1], " and ", hidden_size); } } + if (fc3_experts_weights_optional != nullptr && + fc3_experts_weights_optional->Shape().GetDims() != fc1_experts_weights_dims) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc3_experts_weights_dims must be equal to fc1_experts_weights_dims, got ", + fc3_experts_weights_optional->Shape(), " and ", TensorShape(fc1_experts_weights_dims)); + } + + if (fc3_experts_bias_optional != nullptr && fc1_experts_bias_optional != nullptr && + fc3_experts_bias_optional->Shape().GetDims() != fc1_experts_bias_optional->Shape().GetDims()) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, "fc3_experts_bias_dims must be equal to fc1_experts_bias_dims, got ", + fc3_experts_bias_optional->Shape(), " and ", fc1_experts_bias_optional->Shape()); + } + parameters.num_rows = num_rows; parameters.num_experts = num_experts; parameters.local_num_experts = local_num_experts; parameters.hidden_size = hidden_size; parameters.inter_size = inter_size; if (num_experts == local_num_experts) { - parameters.parallel_type = MoEParallelType::None; + if (parameters.tensor_shards == 1) { + parameters.parallel_type = MoEParallelType::None; + } else { + parameters.parallel_type = MoEParallelType::TP; + } } else if (num_experts > local_num_experts) { - parameters.parallel_type = MoEParallelType::ExpertSlicing; + if (parameters.tensor_shards == 1) { + parameters.parallel_type = MoEParallelType::EP; + } else { + parameters.parallel_type = MoEParallelType::EPAndTP; + } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_experts must be greater than or equal to local_num_experts, got ", - num_experts, " and ", local_num_experts); + "num_experts must be greater than or equal to local_num_experts, got ", num_experts, + " and ", local_num_experts); + } + + return Status::OK(); + } + + Status CheckInputScales(const Tensor* fc1_experts_scales, const Tensor* fc2_experts_scales, + const Tensor* fc3_experts_scales, int64_t num_experts, int64_t hidden_size, + int64_t inter_size) const { + const auto& fc1_experts_scales_dims = fc1_experts_scales->Shape().GetDims(); + const auto& fc2_experts_scales_dims = fc2_experts_scales->Shape().GetDims(); + + if (fc1_experts_scales_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales must be 2D, got ", + fc1_experts_scales->Shape().GetDims().size()); + } + if (fc1_experts_scales_dims[0] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ", + fc1_experts_scales_dims[0], " and ", num_experts); + } + if (fc1_experts_scales_dims[1] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to inter_size, got ", + fc1_experts_scales_dims[1], " and ", inter_size); + } + if (fc2_experts_scales_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ", + fc2_experts_scales->Shape().GetDims().size()); + } + if (fc2_experts_scales_dims[0] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[0] must be equal to num_experts, got ", + fc2_experts_scales_dims[0], " and ", num_experts); + } + if (fc2_experts_scales_dims[1] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[1] must be equal to hidden_size, got ", + fc2_experts_scales_dims[1], " and ", hidden_size); + } + if (fc3_experts_scales != nullptr && fc1_experts_scales_dims != fc3_experts_scales->Shape().GetDims()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc3_experts_scales must be equal to fc1_experts_scales, got ", + fc3_experts_scales->Shape(), " and ", TensorShape(fc1_experts_scales_dims)); } return Status::OK(); @@ -163,8 +223,11 @@ class MoEBase { } else { ORT_THROW("Unsupported MoE activation type: ", activation_type_str); } + + normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault("normalize_routing_weights", 0) == 1; } + bool normalize_routing_weights_; int64_t k_; ort_fastertransformer::ActivationType activation_type_; }; @@ -172,5 +235,3 @@ class MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 705f2d49fe2b..168c69c69f00 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -106,6 +106,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* past_tensor = context->Input(8); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + ORT_RETURN_IF_ERROR(CheckInputs(input, weights, bias, @@ -152,7 +154,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { CudaT dequant_scale; CudaT input_scale = *(reinterpret_cast(input_scale_tensor->Data())); CudaT weight_scale = *(reinterpret_cast(weight_scale_tensor->Data())); - if (sizeof(T) == 2) { + if constexpr (sizeof(T) == 2) { dequant_scale = __float2half(__half2float(input_scale) * __half2float(weight_scale)); } else { dequant_scale = input_scale * weight_scale; diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 6b66f1d84e22..265adf22eeb6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -2,10 +2,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include #include #include +#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" @@ -21,7 +23,7 @@ namespace cuda { __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) { half2 scale_half2 = {scale, scale}; - half zp_adjust = -scale * __short2half_rn(zp); + half zp_adjust = -scale * zp; half2 zp_adjust2 = {zp_adjust, zp_adjust}; alignas(16) half2 results[4]; @@ -56,41 +58,95 @@ __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, f } template -__global__ void Dequantize4BitsKernel( +__global__ void Dequantize4BitsKernelReOrder( T* output, const uint8_t* quant_data, const T* scale_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int block_size, - int blocks_per_K, - int blocks_per_threadblock, - int total_blks, - int shift) { - int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); - if (block_id >= total_blks) { + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (group_id >= total_groups) { return; } - int n_idx = block_id / blocks_per_K; - int kb_idx = block_id % blocks_per_K; - int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); + // T __shared__ zero_points_after_reorder[];//K + // T __shared__ scales_after_reorder[]; // K + // const int num_r_per_thread = k / 256; + + const int zero_point_shape_x = (groups_per_K + 1) / 2; + const int scales_shape_x = groups_per_K; + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); + T* output_i = output + element_offset; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * 8) & (block_size - 1)); + for (int i = 0; i < 8; i++) { + int32_t rid = reorder_idx_with_off[i]; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * __short2half_rn(zp); + output_i[i] = __uint2half_rn((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * T(zp); + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +__global__ void Dequantize4BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const ZeroT* zero_points, + int block_size, + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int block_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (block_id >= total_groups) { + return; + } + int element_offset = block_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); T scale = *(scale_data + block_id); - uint8_t zp = 8; - if (zero_points) { - zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2]; - zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + T zero_point_value; + if constexpr (std::is_same_v) { + const int scales_shape_x = groups_per_K; + const int zero_point_shape_x = (groups_per_K + 1) / 2; + int kb_idx = block_id % scales_shape_x; + int n_idx = block_id / scales_shape_x; + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + kb_idx / 2]; + zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + zero_point_value = static_cast(zp); + } else { + zero_point_value = zero_points? *(zero_points + block_id):static_cast(8); } output = output + element_offset; - DequantizeEightElements(quant_value, scale, static_cast(zp), output); + DequantizeEightElements(quant_value, scale, zero_point_value, output); } -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2] + const ZeroT* zero_points, // shape: [N, (block_per_K + 1)/2] + const int32_t* reorder_idx, int k, int n, int block_size, @@ -98,47 +154,79 @@ Status Dequantize4Bits( // k is padded and equal to block_per_K * block_size ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); constexpr int element_per_thread = 8; - int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; - int blocks_per_K = k / block_size; - int total_blks = n * blocks_per_K; - int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); - int shift = static_cast(log2f(float(block_size))); - - Dequantize4BitsKernel<<>>( - output, - quant_data, - scales_data, - zero_points, - block_size, - blocks_per_K, - blocks_per_threadblock, - total_blks, - shift); + int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_K = k / block_size; + int total_groups = n * groups_per_K; // total elemenets in quant_data + int groups_per_grid = static_cast(CeilDiv(total_groups, groups_per_threadblock)); + if (!reorder_idx || std::is_same_v) { + Dequantize4BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } else { + // static_assert(std::is_same_v, "ZeroT must be uint8_t"); + Dequantize4BitsKernelReOrder<<>>( + output, + quant_data, + scales_data, + (const uint8_t*)zero_points, + reorder_idx, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } return Status::OK(); } -template Status Dequantize4Bits( +template Status Dequantize4Bits( float* output, const uint8_t* quant_data, const float* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); -template Status Dequantize4Bits( +template Status Dequantize4Bits( half* output, const uint8_t* quant_data, const half* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); +template Status Dequantize4Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const float* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - +template Status Dequantize4Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const half* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); /////////////////////////////////////////////////////////////////////////////// // A more general block-wise dequantization implementation that supports // different block sizes and block orientations (row-wise/column-wise). diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index f9c09c55fd89..580b5087f3fa 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -7,18 +7,18 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, + const ZeroT* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - /** * @brief Dequantize a block-wise quantized matrix, and store the result in a * column major matrix for use in subsequent GEMM. This implementation supports diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index bbcb7de99781..0534ed6dc7fc 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -117,7 +117,8 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { &zero, reinterpret_cast(Y->MutableData()), helper.Ldc(), - GetDeviceProp())); + GetDeviceProp(), + UseTF32())); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 5b0e61e19701..1cec6f6a12f1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -1,15 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// -// This module define MatMulFp32Q4 operator, it is basically -// matmul float32 with right hand side being a 2-D matrix -// pre-packed and block-compacted into int4 -// - -#include "core/common/safeint.h" -#include "core/providers/cuda/cuda_kernel.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.h" + +#include + +#include "core/common/status.h" +#include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" #include "matmul_nbits.cuh" #include "dequantize_blockwise.cuh" @@ -19,40 +16,19 @@ namespace contrib { namespace cuda { using namespace onnxruntime::cuda; -template -class MatMulNBits final : public CudaKernel { - public: - MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { - ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); - ORT_ENFORCE(nbits_ == 4, - "Only 4b quantization is supported for MatMulNBits op," - " additional bits support is planned."); - } - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - int64_t K_; - int64_t N_; - int64_t block_size_; - int64_t nbits_; - bool column_wise_quant_blk_{true}; -}; - template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const Tensor* b = ctx->Input(1); const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); + const Tensor* reorder_idx = ctx->Input(4); const auto* a_data = a->Data(); const uint8_t* blob_data = b->Data(); const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); typedef typename ToCudaType::MappedType CudaT; @@ -67,76 +43,99 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - bool is_4bit_done = TryMatMul4Bits( - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - blob_data, - reinterpret_cast(scales_data), - zero_points_data, - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), - SafeInt(block_size_), - SafeInt(GetDeviceProp().sharedMemPerBlock), - static_cast(ctx->GetComputeStream()->GetHandle())); - if (!is_4bit_done) { - int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; - IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); - auto* b_data = b_data_ptr.get(); - if (column_wise_quant_blk_) { - // column-wise block + bool is_4bit_done = (reorder_idx_data == nullptr) && + (!zero_points || !zero_points->IsDataType()) && + TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + static_cast(zero_points_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + SafeInt(GetDeviceProp().sharedMemPerBlock), + static_cast(ctx->GetComputeStream()->GetHandle())); + + if (is_4bit_done) { + return Status::OK(); + } + + int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; + IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); + auto* b_data = b_data_ptr.get(); + if (column_wise_quant_blk_) { + if (reorder_idx) { + ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); + } + // column-wise block + if ((zero_points && zero_points->IsDataType())) { ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, + (const CudaT*)zero_points_data, + reorder_idx_data, SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), static_cast(ctx->GetComputeStream()->GetHandle()))); } else { - // row-wise block - K_padded = K_; - - ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, - SafeInt(block_size_), - column_wise_quant_blk_, - SafeInt(K_), + (const uint8_t*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), SafeInt(N_), + SafeInt(block_size_), static_cast(ctx->GetComputeStream()->GetHandle()))); } + } else { + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const uint8_t*)zero_points_data, + SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } #if 0 - cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); - T* b_data_cpu = new T[K_ * N_]; - cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); - delete[] b_data_cpu; +cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); +T* b_data_cpu = new T[K_ * N_]; +cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); +delete[] b_data_cpu; #endif - const CudaT alpha = ToCudaType::FromFloat(1.f); - const CudaT zero = ToCudaType::FromFloat(0.f); - - if (helper.OutputOffsets().size() == 1) { - CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( - GetCublasHandle(ctx), - CUBLAS_OP_T, - CUBLAS_OP_N, - SafeInt(helper.N()), - SafeInt(helper.M()), - SafeInt(helper.K()), - &alpha, - reinterpret_cast(b_data), - SafeInt(K_padded), - reinterpret_cast(a_data), - helper.Lda(transa), - &zero, - reinterpret_cast(Y->MutableData()), - helper.Ldc(), - GetDeviceProp())); - } + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + if (helper.OutputOffsets().size() == 1) { + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_data), + SafeInt(K_padded), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp(), + UseTF32())); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index 67384957d8dd..d4d583906b7f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -89,7 +89,7 @@ __device__ __forceinline__ void Convert8xInt4To8xHalfs(uint32_t value, half2* ha asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(kOneSixteenth), "r"(kNeg64)); } -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -120,7 +120,7 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, sums_half2[3] = sums_half2[3] + v3 * (*(reinterpret_cast(&(vec_permuted.w)))); } #else -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -144,7 +144,7 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, } #endif -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { float4 a_vec_0 = *(reinterpret_cast(a)); float4 a_vec_1 = *(reinterpret_cast(a + 4)); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h new file mode 100644 index 000000000000..f5c2c6c4e4fd --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define MatMulNBits operator, it is basically +// matmul float with right hand side being a 2-D matrix +// pre-packed and block-compacted into int4 +// +#pragma once +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulNBits final : public CudaKernel { + public: + MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; + bool column_wise_quant_blk_{true}; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc new file mode 100644 index 000000000000..7bb0945615d3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/quantization/moe_quantization.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL() \ + ONNX_OPERATOR_KERNEL_EX(QMoE, kMSDomain, 1, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(0, 0) \ + .TypeConstraint("T", BuildKernelDefConstraints()) \ + .TypeConstraint("T1", BuildKernelDefConstraints()), \ + QMoE); + +REGISTER_KERNEL() + +namespace { +template +struct ToCudaTypeWrapper : public ToCudaType {}; + +template <> +struct ToCudaTypeWrapper { + using MappedType = uint8_t; +}; + +template <> +struct ToCudaTypeWrapper { + using MappedType = cutlass::uint4b_t; +}; +} // anonymous namespace + +QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {} + +Status QMoE::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* router_probs = context->Input(1); + const Tensor* fc1_experts_weights = context->Input(2); + const Tensor* fc1_scales = context->Input(3); + const Tensor* fc1_experts_bias_optional = context->Input(4); + const Tensor* fc2_experts_weights = context->Input(5); + const Tensor* fc2_scales = context->Input(6); + const Tensor* fc2_experts_bias_optional = context->Input(7); + const Tensor* fc3_experts_weights_optional = context->Input(8); + const Tensor* fc3_scales_optional = context->Input(9); + const Tensor* fc3_experts_bias_optional = context->Input(10); + + MoEParameters moe_params; + MoEQuantType quant_type = MoEQuantType::UINT4; + ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, + fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, + fc3_experts_weights_optional, fc3_experts_bias_optional)); + ORT_RETURN_IF_ERROR(CheckInputScales(fc1_scales, fc2_scales, fc3_scales_optional, moe_params.num_experts, + moe_params.hidden_size, moe_params.inter_size)); + + // Support int4 only at the moment. We can add uint8 if needed. + static constexpr bool use_quint4x2 = true; + using T = MLFloat16; + using CudaT = typename ToCudaType::MappedType; + using CudaWeightT = typename ToCudaTypeWrapper::MappedType; + + auto stream = context->GetComputeStream(); + + auto& device_prop = GetDeviceProp(); + const int sm = device_prop.major * 10 + device_prop.minor; + + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, + normalize_routing_weights_); + + size_t ws_size = moe_runner.getWorkspaceSize( + static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), static_cast(k_)); + size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT); + size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT); + size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); + size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, stream); + IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, stream); + IAllocatorUniquePtr expert_scales = + IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, stream); + IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = + IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream); + IAllocatorUniquePtr expert_for_source_row = + IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); + + moe_runner.run_moe_fc( + reinterpret_cast(input->template Data()), + reinterpret_cast(router_probs->template Data()), + reinterpret_cast(fc1_experts_weights->DataRaw()), + fc1_scales == nullptr ? nullptr : reinterpret_cast(fc1_scales->template Data()), + fc1_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc1_experts_bias_optional->template Data()), + activation_type_, + fc3_experts_weights_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_weights_optional->DataRaw()), + fc3_scales_optional == nullptr ? nullptr + : reinterpret_cast(fc3_scales_optional->template Data()), + fc3_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_bias_optional->template Data()), + reinterpret_cast(fc2_experts_weights->DataRaw()), + fc2_scales == nullptr ? nullptr : reinterpret_cast(fc2_scales->template Data()), + static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), + static_cast(moe_params.local_num_experts), 0 /*local_experts_start_index_ used in sharded MoE*/, + static_cast(k_), reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), Stream(context)); + + Tensor* output = context->Output(0, input->Shape()); + + ort_fastertransformer::finalize_moe_routing_kernelLauncher( + reinterpret_cast(fc2_output.get()), reinterpret_cast(output->template MutableData()), + fc2_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc2_experts_bias_optional->template Data()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), static_cast(moe_params.num_rows), + static_cast(moe_params.hidden_size), static_cast(k_), Stream(context)); + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h new file mode 100644 index 000000000000..7b68d2d082de --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" +#include "contrib_ops/cuda/moe/moe_base.h" +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +class QMoE final : public CudaKernel, public MoEBase { + public: + explicit QMoE(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* ctx) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 3cecebedae2f..12835978536e 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -142,7 +142,7 @@ inline void debug_print([[maybe_unused]] const T* arr, std::cout << "========" << name << std::endl; for (size_t i = 0; i < sz; i++) { if (i % w == 0) std::cout << std::endl; - if (std::is_same().value) { + if constepxr (std::is_same::value) { std::cout << (int)buf[i] << ", "; } else { std::cout << buf[i] << ", "; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu index f4d5a7b404a6..fd4b51f40fb4 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu @@ -151,7 +151,7 @@ QOrderBatchInt8MatrixTransposeKernel(const int8_t* src, const int8_t* dst, const } } -Status QOrderBatchTransposeInt8Matrix(cudaStream_t stream, const cudaDeviceProp& device_prop, +Status QOrderBatchTransposeInt8Matrix(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/, const int batch_size, const int rows, const int cols, const int8_t* input, int8_t* output) { ORT_ENFORCE(rows % 4 == 0 && cols % 4 == 0, "Matrix rows and cols must be divisible by 4!"); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu index baff8e76ec73..e6ac0bc8a517 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu @@ -389,7 +389,7 @@ QOrderDequantizeKernel_Strict(const int8_t* __restrict__ src, const __half* __re } } -Status QOrderDequantize_Strict(cudaStream_t stream, const cudaDeviceProp& device_prop, +Status QOrderDequantize_Strict(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/, const int8_t* src, __half* dst, float scale, size_t N) { ORT_RETURN_IF(N & 0x3LL, "N can not divide by 4!"); diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index 2a90e4911f28..08cbb145a6f6 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -49,6 +49,7 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU .TypeConstraint("T", {DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc index b31f5d243e00..4cfa89a4d58c 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc @@ -203,23 +203,19 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } -void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); + DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); + DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1) const { - Print(name, reinterpret_cast(tensor), dim0, dim1); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); } void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const { @@ -227,9 +223,14 @@ void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); +} + +void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, true); } void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const { @@ -242,6 +243,11 @@ void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int d DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); } +void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, true); +} + void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); @@ -252,22 +258,31 @@ void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, i DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); } -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const { - Print(name, reinterpret_cast(tensor), dim0, dim1, dim2); +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const { - Print(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3); +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1) const { + Print(name, reinterpret_cast(tensor), dim0, dim1); +} + +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const { + Print(name, reinterpret_cast(tensor), dim0, dim1, dim2); +} + +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const { + Print(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3); } void CudaTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const { @@ -301,43 +316,52 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } #else -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const { } void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int) const { } +void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { +} + void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const half*, int, int) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int) const { } void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int, int) const { diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h index 264ecd7cfe2f..773401f79531 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h @@ -16,20 +16,31 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::transformers::ICons public: CudaTensorConsoleDumper() = default; virtual ~CudaTensorConsoleDumper() {} - void Print(const char* name, const float* tensor, int dim0, int dim1) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; + void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const half* tensor, int dim0, int dim1) const; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; + void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override; + void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; + + void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; + void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; + + void Print(const char* name, const float* tensor, int dim0, int dim1) const override; void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override; void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + + void Print(const char* name, const half* tensor, int dim0, int dim1) const; void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const; void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; + + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + + void Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const; + void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const; + void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; void Print(const char* name, int index, bool end_line) const override; diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index dbd7fb010462..eb1943b59d97 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -1,11 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + +// cub.cuh includes device/dispatch_radix_sort.cuh which has assignment in conditional expressions +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4706) +#endif +#include +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#include + #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" -#include "cub/util_type.cuh" -#include -#include + #include "contrib_ops/cuda/bert/utils.cuh" #include "contrib_ops/cuda/transformers/generation_cuda_impl.h" @@ -307,12 +318,13 @@ __device__ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_ return beams_[beams_count_ - 1].score < current_score; } +template __device__ void BeamHypotheses::Output( int top_k, int max_length, int pad_token_id, int32_t* sequences, // buffer of shape (num_return_sequences, max_length) - float* sequences_scores) // buffer of shape (num_return_sequences) or empty + T* sequences_scores) // buffer of shape (num_return_sequences) or empty { // Copy the top_k beams into the sequences for (int index = 0; index < top_k; index++) { @@ -327,7 +339,7 @@ __device__ void BeamHypotheses::Output( target[i] = pad_token_id; if (sequences_scores) - sequences_scores[index] = item.score; + sequences_scores[index] = (T)item.score; } } @@ -501,13 +513,14 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp next_beam_tokens.data()); } +template __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, const int32_t* sequences_buffer, int sequence_length, BeamHypotheses* beam_hyps_, const float* final_beam_scores, int32_t* output, - float* sequence_scores) { + T* sequence_scores) { int batch_index = blockIdx.x * blockDim.x + threadIdx.x; if (batch_index >= state.batch_size_) return; @@ -534,6 +547,7 @@ __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, sequence_scores ? sequence_scores + batch_index * state.num_return_sequences_ : nullptr); } +template void LaunchBeamSearchScorer_Finalize(int batch_size, BeamScorerState& state, gsl::span sequences, @@ -541,7 +555,7 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, gsl::span beam_hyps, gsl::span final_beam_scores, gsl::span output, - gsl::span sequence_scores, + gsl::span sequence_scores, cudaStream_t stream) { BeamSearchScorer_Finalize<<<1, batch_size, 0, stream>>>(state, sequences.data(), @@ -552,6 +566,58 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, sequence_scores.data()); } +template void LaunchBeamSearchScorer_Finalize( + int batch_size, + BeamScorerState& state, + gsl::span sequences, + int sequence_length, + gsl::span beam_hyps, + gsl::span final_beam_scores, + gsl::span output, + gsl::span sequence_scores, + cudaStream_t stream); + +template void LaunchBeamSearchScorer_Finalize<__half>( + int batch_size, + BeamScorerState& state, + gsl::span sequences, + int sequence_length, + gsl::span beam_hyps, + gsl::span final_beam_scores, + gsl::span output, + gsl::span<__half> sequence_scores, + cudaStream_t stream); + +template +__global__ void FloatConvertAndCopyKernel(const float* src, T* dst, size_t total_elements) { + int64_t index = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (index < total_elements) { + dst[index] = (T)src[index]; + } +} + +template +void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream) { + ORT_ENFORCE(final_scores.size() == output_scores.size()); + constexpr unsigned ThreadPerBlock = 256; + unsigned num_blocks = (unsigned)((final_scores.size() + (ThreadPerBlock - 1))/ ThreadPerBlock); + + typedef typename ToCudaType::MappedType CudaT; + + FloatConvertAndCopyKernel<<>>( + final_scores.data(), (CudaT*)output_scores.data(), final_scores.size()); +} + +template void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream); + +template void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream); + __global__ void AddProbsKernel(float* log_probs, float* cum_log_probs, const int vocab_size, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 5ed5949196b2..281cb6c72597 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -65,11 +65,12 @@ struct BeamHypotheses { __device__ bool CanImprove(float best_sum_logprobs, int current_length) const; // Output results - __device__ void Output(int top_k, // number of sequences to return - int max_length, // max sequence length - int pad_token_id, // pad token - int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length) - float* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) + template + __device__ void Output(int top_k, // number of sequences to return + int max_length, // max sequence length + int pad_token_id, // pad token + int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length) + T* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) }; struct BeamScorerState { @@ -110,6 +111,7 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp gsl::span next_beam_indices, cudaStream_t stream); +template void LaunchBeamSearchScorer_Finalize(int batch_size, BeamScorerState& state, gsl::span sequences, @@ -117,9 +119,14 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, gsl::span beam_hyps_, gsl::span final_beam_scores, gsl::span output, - gsl::span sequence_scores, + gsl::span sequence_scores, cudaStream_t stream); +template +void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream); + void LaunchNextTokenKernel(const int64_t* next_token_indices, int32_t* next_indices, int32_t* next_tokens, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 380d561bbb23..7adc2fe0a67e 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -424,7 +424,7 @@ Status ProcessLogits(const OrtValue& logits, // const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper); if (step == 1 && is_whisper_model && parameters->no_speech_probs) { cuda::LaunchSaveNoSpeechProbs( - (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream); + (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token_id, cuda_stream); } // NOTE: currently we treat extra decoding ids are same @@ -469,7 +469,15 @@ Status ProcessLogits(const OrtValue& logits, // cudaMemcpyDeviceToHost, cuda_stream)); constexpr int max_initial_timestamp_index = 50; - onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index); + // Token ids are passed below in the order that they appear in the tokenizer + onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, + parameters->decoder_start_token_id, + parameters->translate_token_id, + parameters->transcribe_token_id, + parameters->start_of_lm_token_id, + parameters->no_timestamps_token_id, + parameters->beginning_timestamp_token_id, + max_initial_timestamp_index); onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); @@ -620,6 +628,8 @@ struct CudaBeamSearchScorer : transformers::IBeamScorer { Tensor* output_sequences, Tensor* output_sequence_scores) override; + void OutputScores(gsl::span& final_scores, Tensor* output_scores) override; + bool IsDone() const override { return false; } // For CUDA we speculatively run the next step while we wait for the GPU to report status. We use 'IsDoneLater()' for this bool IsDoneLater() const override; @@ -632,7 +642,6 @@ struct CudaBeamSearchScorer : transformers::IBeamScorer { } gsl::span GetNextIndicesGPU() override { return next_beam_indices_; } - private: mutable cuda::AutoDestoryCudaEvent event_process_complete_; IAllocatorUniquePtr state_cpu_; IAllocatorUniquePtr state_gpu_; @@ -743,22 +752,58 @@ bool CudaBeamSearchScorer::IsDoneLater() const { return state_cpu_->not_done_count_ == 0; } +template +void CudaOutputSequenceScores(CudaBeamSearchScorer* scorer, + transformers::ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { + // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length). + gsl::span output{output_sequences->MutableData(), static_cast(output_sequences->Shape().Size())}; + + // Score of each sequence, with shape (batch_size * num_return_sequences). + using CudaT = typename ToCudaType::MappedType; + gsl::span sequence_scores; + if (output_sequence_scores) { + sequence_scores = gsl::span{(CudaT*)output_sequence_scores->MutableData(), static_cast(output_sequence_scores->Shape().Size())}; + } + + cuda::LaunchBeamSearchScorer_Finalize(scorer->state_cpu_->batch_size_, + *scorer->state_gpu_, + sequences.GetCurrentDeviceSequences(), + sequences.GetSequenceLength(), + scorer->beam_hyps_, + final_beam_scores, + output, + sequence_scores, + scorer->stream_); +} + void CudaBeamSearchScorer::Finalize(transformers::ISequences& sequences, gsl::span& final_beam_scores, Tensor* output_sequences, Tensor* output_sequence_scores) { ORT_ENFORCE(output_sequences != nullptr); - // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length). - gsl::span output{output_sequences->MutableData(), static_cast(output_sequences->Shape().Size())}; - - // Score of each sequence, with shape (batch_size * num_return_sequences). - gsl::span sequence_scores; - if (output_sequence_scores) { - sequence_scores = gsl::span{output_sequence_scores->MutableData(), static_cast(output_sequence_scores->Shape().Size())}; + if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType()) { + CudaOutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); + } else { + ORT_ENFORCE(output_sequence_scores->IsDataType()); + CudaOutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); } +} - cuda::LaunchBeamSearchScorer_Finalize(state_cpu_->batch_size_, *state_gpu_, sequences.GetCurrentDeviceSequences(), sequences.GetSequenceLength(), beam_hyps_, final_beam_scores, output, sequence_scores, stream_); +void CudaBeamSearchScorer::OutputScores(gsl::span& final_scores, Tensor* output_scores) { + if (output_scores) { + if (output_scores->IsDataType()) { + gsl::span target(output_scores->MutableData(), output_scores->Shape().Size()); + cuda::LaunchBeamSearchScoreCopy(final_scores, target, stream_); + } else { + ORT_ENFORCE(output_scores->IsDataType()); + gsl::span target(output_scores->MutableData(), output_scores->Shape().Size()); + cuda::LaunchBeamSearchScoreCopy(final_scores, target, stream_); + } + } } std::unique_ptr CreateBeamScorer(const transformers::IGenerationParameters& parameters, diff --git a/onnxruntime/contrib_ops/js/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/js/bert/rotary_embedding.cc new file mode 100644 index 000000000000..7ee168e27f6f --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/rotary_embedding.cc @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "rotary_embedding.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX(RotaryEmbedding, kMSDomain, 1, kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/rotary_embedding.h b/onnxruntime/contrib_ops/js/bert/rotary_embedding.h new file mode 100644 index 000000000000..376b4e7082fb --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/rotary_embedding.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; + +class RotaryEmbedding final : public JsKernel { + public: + explicit RotaryEmbedding(const OpKernelInfo& info) : JsKernel(info) { + int64_t interleaved = info.GetAttrOrDefault("interleaved", 0); + int64_t num_heads = info.GetAttrOrDefault("num_heads", 0); + int64_t rotary_embedding_dim = info.GetAttrOrDefault("rotary_embedding_dim", 0); + float scale = info.GetAttrOrDefault("scale", 1.0); + + JSEP_INIT_KERNEL_ATTRIBUTE(RotaryEmbedding, ({ + "interleaved" : !!$1, + "numHeads" : $2, + "rotaryEmbeddingDim" : $3, + "scale" : $4, + }), + static_cast(interleaved), static_cast(num_heads), + static_cast(rotary_embedding_dim), scale); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/fast_gelu.cc b/onnxruntime/contrib_ops/js/fast_gelu.cc new file mode 100644 index 000000000000..62c538318160 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fast_gelu.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "fast_gelu.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + FastGelu, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + FastGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/fast_gelu.h b/onnxruntime/contrib_ops/js/fast_gelu.h new file mode 100644 index 000000000000..68c7892741c6 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fast_gelu.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; +JSEP_KERNEL_IMPL(FastGelu, FastGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 498a9f5679eb..a6f8aebc2d1e 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -8,12 +8,17 @@ namespace contrib { namespace js { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -24,13 +29,20 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + BuildKernelCreateInfo, + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/contrib_ops/js/layer_norm.cc b/onnxruntime/contrib_ops/js/layer_norm.cc new file mode 100644 index 000000000000..814543a9905e --- /dev/null +++ b/onnxruntime/contrib_ops/js/layer_norm.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_data_types.h" +#include "core/providers/js/operators/layer_norm.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kOnnxDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", onnxruntime::js::JsepSupportedFloatTypes()) + .TypeConstraint("U", onnxruntime::js::JsepSupportedFloatTypes()), + onnxruntime::js::LayerNorm); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc new file mode 100644 index 000000000000..888db0fd161f --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/js/quantization/matmul_nbits.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", JsepSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h new file mode 100644 index 000000000000..cca2c4757765 --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; + +class MatMulNBits final : public JsKernel { + public: + MatMulNBits(const OpKernelInfo& info) : JsKernel(info), + K_{narrow(info.GetAttr("K"))}, + N_{narrow(info.GetAttr("N"))}, + accuracy_level_{info.GetAttrOrDefault("accuracy_level", 0)}, + nbits_{narrow(info.GetAttr("bits"))}, + block_size_{narrow(info.GetAttr("block_size"))} { + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)), + "Block size must be a power of 2 and greater than or equal to 16."); + JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({ + "k" : $1, + "n" : $2, + "accuracyLevel" : $3, + "bits" : $4, + "blockSize" : $5 + }), + static_cast(K_), + static_cast(N_), + static_cast(accuracy_level_), + static_cast(nbits_), + static_cast(block_size_)); + } + + private: + const size_t K_; + const size_t N_; + const int64_t accuracy_level_; + const size_t nbits_; + const size_t block_size_; +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/skip_layer_norm.cc b/onnxruntime/contrib_ops/js/skip_layer_norm.cc index f949326e1dc9..dc2c4ab75f2f 100644 --- a/onnxruntime/contrib_ops/js/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/js/skip_layer_norm.cc @@ -14,10 +14,16 @@ ONNX_OPERATOR_KERNEL_EX( kMSDomain, 1, kJsExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T", JsepSupportedFloatTypes()) - .TypeConstraint("U", JsepSupportedFloatTypes()), - SkipLayerNorm); + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + SkipLayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SkipSimplifiedLayerNormalization, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + SkipLayerNorm); } // namespace js } // namespace contrib diff --git a/onnxruntime/contrib_ops/js/skip_layer_norm.h b/onnxruntime/contrib_ops/js/skip_layer_norm.h index c3011e96ae29..ead5146aa96d 100644 --- a/onnxruntime/contrib_ops/js/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/js/skip_layer_norm.h @@ -11,19 +11,20 @@ namespace js { using onnxruntime::js::JsKernel; +template class SkipLayerNorm final : public JsKernel { public: SkipLayerNorm(const OpKernelInfo& op_kernel_info) : JsKernel(op_kernel_info) { - ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); - ORT_ENFORCE(epsilon_ >= 0); + float epsilon; + ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon).IsOK()); + ORT_ENFORCE(epsilon >= 0); JSEP_INIT_KERNEL_ATTRIBUTE(SkipLayerNormalization, ({ - "epsilon" : $1 + "epsilon" : $1, + "simplified" : !!$2 }), - epsilon_); + epsilon, + static_cast(simplified)); } - - private: - float epsilon_; }; } // namespace js diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh index 0599318a4022..be8508670e4b 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh @@ -31,7 +31,7 @@ using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecializatio using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface +using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; // the implementation static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; @@ -141,6 +141,35 @@ std::vector, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +// fp16, biased, non-masked +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +// fp16, biased, fp16 masked, basically, two bias +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu index 181e47f012c9..2e32a6594d16 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< return instances; } +using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu index 1577bdf397fa..91da8d9e1f9a 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< return instances; } +using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu index 14de59234356..b08123be1897 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< return instances; } +using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 78983ac95e67..54dda4bfa6d2 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -732,122 +732,154 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp -auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { +template +auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { constexpr const int kNumBiasBuffer = static_cast(USE_BIAS) + static_cast(USE_MASK); using Nop = ck::tensor_operation::element_wise::PassThrough; using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp; + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), + "attention mode is not supported, got ", params->attention->mode); + if constexpr (USE_BIAS) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->bias_buffer == nullptr, "biased version only support input with bias"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->bias_buffer != nullptr, "non-biased version only support input without bias"); + } + if constexpr (USE_MASK) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), + "mask type is not supported, got ", params->attention->mask_type); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->mask_index_buffer == nullptr, "masked version only support input with mask"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); + } + + auto attn = params->attention; + const int& G0 = attn->batch_size; + const int& G1 = attn->num_heads; + const int& M = attn->sequence_length; + const int& N = attn->total_sequence_length; + const int& K = attn->head_size; + const int& O = attn->v_head_size; + { + auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); + ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); + } + + auto [qs, ks, vs] = GetQkvStrides(attn); + std::vector q_buffer_lengths = {G0, G1, M, K}; + std::vector q_buffer_strides = qs.template ForBNSHCoord>(); + std::vector k_buffer_lengths = {G0, G1, N, K}; + std::vector k_buffer_strides = ks.template ForBNSHCoord>(); + std::vector v_buffer_lengths = {G0, G1, O, N}; + std::vector v_buffer_strides = vs.template ForBNHSCoord>(); + std::vector out_buffer_lengths = {G0, G1, M, O}; + std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 + + std::array bias_buffers{}; + std::array, kNumBiasBuffer> bias_lengths{}; + std::array, kNumBiasBuffer> bias_strides{}; + if constexpr (USE_BIAS) { + bias_buffers[0] = const_cast(params->bias_buffer); + bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) + bias_strides[0] = {G1 * M * N, M * N, N, 1}; + } + if constexpr (USE_MASK) { + bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; + bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) + if (params->mask_index_dims.size() == 2) { // [B,T] + bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; + } else if (params->mask_index_dims.size() == 3) { // [B,S,T] + bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; + } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] + bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; + } else { + ORT_ENFORCE(false, "Unreachable"); + } + } + + auto arg = impl->MakeArgumentPointer( + params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, + bias_buffers, // Gemm1 bias, as attention mask + {}, // Gemm2 bias + q_buffer_lengths, q_buffer_strides, + k_buffer_lengths, k_buffer_strides, + v_buffer_lengths, v_buffer_strides, + out_buffer_lengths, out_buffer_strides, + bias_lengths, bias_strides, + {}, + {}, + Nop{}, + Nop{}, + Acc0ElementOp{params->scale}, + Nop{}, + Nop{}); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support the params"); + + if constexpr (USE_MASK) { + ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); + } + + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); +} + +template +auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; using D0DataType = typename ck::detail::tuple_concat< std::conditional_t, ck::Tuple<>>, std::conditional_t, ck::Tuple<>>>::type; - constexpr static auto MaskingSpec = + constexpr static auto MaskingSpecMaskDisabled = ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + constexpr static auto MaskingSpecMaskOutUpperTriangle = + ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; + + std::vector>>> + ret; - std::vector>>> ret; for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpec>()) { + CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) { auto type_string = impl->GetTypeString(); auto invoker = impl->MakeInvokerPointer(); auto op = [impl = std::move(impl), invoker = std::move(invoker)]( const GemmSoftmaxGemmPermuteParams* params) -> Status { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), - "attention mode is not supported, got ", params->attention->mode); - if constexpr (USE_BIAS) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer == nullptr, "biased version only support input with bias"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer != nullptr, "non-biased version only support input without bias"); - } - if constexpr (USE_MASK) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), - "mask type is not supported, got ", params->attention->mask_type); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer == nullptr, "masked version only support input with mask"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); - } + params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled"); - auto attn = params->attention; - const int& G0 = attn->batch_size; - const int& G1 = attn->num_heads; - const int& M = attn->sequence_length; - const int& N = attn->total_sequence_length; - const int& K = attn->head_size; - const int& O = attn->v_head_size; - { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); - } + return GetArgAndRunInvoker(impl, invoker, params); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); + } - auto [qs, ks, vs] = GetQkvStrides(attn); - std::vector q_buffer_lengths = {G0, G1, M, K}; - std::vector q_buffer_strides = qs.template ForBNSHCoord>(); - std::vector k_buffer_lengths = {G0, G1, N, K}; - std::vector k_buffer_strides = ks.template ForBNSHCoord>(); - std::vector v_buffer_lengths = {G0, G1, O, N}; - std::vector v_buffer_strides = vs.template ForBNHSCoord>(); - std::vector out_buffer_lengths = {G0, G1, M, O}; - std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 - - std::array bias_buffers{}; - std::array, kNumBiasBuffer> bias_lengths{}; - std::array, kNumBiasBuffer> bias_strides{}; - if constexpr (USE_BIAS) { - bias_buffers[0] = const_cast(params->bias_buffer); - bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - bias_strides[0] = {G1 * M * N, M * N, N, 1}; - } - if constexpr (USE_MASK) { - bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; - bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - if (params->mask_index_dims.size() == 2) { // [B,T] - bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; - } else if (params->mask_index_dims.size() == 3) { // [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else { - ORT_ENFORCE(false, "Unreachable"); - } - } + for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) { + auto type_string = impl->GetTypeString(); - auto arg = impl->MakeArgumentPointer( - params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, - bias_buffers, // Gemm1 bias, as attention mask - {}, // Gemm2 bias - q_buffer_lengths, q_buffer_strides, - k_buffer_lengths, k_buffer_strides, - v_buffer_lengths, v_buffer_strides, - out_buffer_lengths, out_buffer_strides, - bias_lengths, bias_strides, - {}, - {}, - Nop{}, - Nop{}, - Acc0ElementOp{params->scale}, - Nop{}, - Nop{}); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - - if constexpr (USE_MASK) { - ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); - } - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); + auto invoker = impl->MakeInvokerPointer(); + auto op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GemmSoftmaxGemmPermuteParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle"); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->attention->sequence_length != params->attention->total_sequence_length, + "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle"); + + return GetArgAndRunInvoker(impl, invoker, params); }; ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); } + return ret; } #endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc deleted file mode 100644 index 9cb414e4e898..000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/fast_gelu.h" - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/miopen_common.h" -#include "contrib_ops/cpu/bert/bias_gelu_helper.h" -#include "contrib_ops/rocm/bert/elementwise.h" -#include "contrib_ops/rocm/bert/transformer_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - FastGelu, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - FastGelu); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) - -using namespace ONNX_NAMESPACE; - -template -Status FastGelu::ComputeInternal(OpKernelContext* context) const { - ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context)); - - const Tensor* input = context->Input(0); - const Tensor* bias = context->Input(1); - Tensor* output = context->Output(0, input->Shape()); - - int64_t input_length = input->Shape().Size(); - if (input_length == 0) { - return Status::OK(); - } - int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); - typedef typename ToHipType::MappedType HipT; - - const HipT* input_buffer = reinterpret_cast(input->Data()); - const HipT* bias_buffer = (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr; - return LaunchElementwiseKernel( - GetTuningContext(), context->GetComputeStream(), - input_buffer, static_cast(input_length), - bias_buffer, static_cast(bias_length), - reinterpret_cast(output->MutableData())); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h deleted file mode 100644 index 42bfe5a0b024..000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class FastGelu final : public RocmKernel { - public: - FastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {} - Status ComputeInternal(OpKernelContext* ctx) const override; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index 6f98312e4067..09e7d61b71db 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -68,6 +68,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL) != 0LL; + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; using HipT = typename ToHipType::MappedType; using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; @@ -121,8 +122,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, - &attn, - num_heads_, mask_filter_value_, scale_, + &attn, num_heads_, + mask_filter_value_, scale_, false, /*is_unidirectional_*/ past_present_share_buffer_, false, device_prop.maxThreadsPerBlock)); if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h index 84d8b76bbfeb..1d676d7a7bca 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h @@ -25,6 +25,7 @@ class MultiHeadAttention final : public RocmKernel { float mask_filter_value_; float scale_; bool past_present_share_buffer_{false}; + bool is_unidirectional_{false}; // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc deleted file mode 100644 index e82e15a304f4..000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define GROUP_NORM_TYPES float, MLFloat16 - -ONNX_OPERATOR_KERNEL_EX( - GroupNorm, kMSDomain, 1, kRocmExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); - -using namespace ONNX_NAMESPACE; - -namespace { -template -struct DispatchGroupNorm { - Status operator()(RocmTuningContext* tuning_ctx, - Stream* stream, - Tensor* output, - const Tensor* input, - const Tensor* gamma, - const Tensor* beta, - void* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_swish_activation) { - typedef typename ToHipType::MappedType HipT; - return LaunchGroupNormKernel( - tuning_ctx, - stream, - reinterpret_cast(output->MutableData()), - reinterpret_cast(input->Data()), - gamma->Data(), - beta->Data(), - workspace, - epsilon, - batch_size, - num_channels, - height, - width, - num_groups, - use_swish_activation); - } -}; - -} // namespace - -GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { - epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); - ORT_ENFORCE(epsilon_ >= 0); - - int64_t num_groups; - ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK()); - ORT_ENFORCE(num_groups >= 0); - num_groups_ = static_cast(num_groups); - - int64_t activation; - ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK()); - ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish - use_swish_activation_ = (activation == 1); - - channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); -} - -Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, - bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { - is_packed = false; - return Status::OK(); -} - -Status GroupNorm::ComputeInternal(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* gamma = context->Input(1); - const Tensor* beta = context->Input(2); - Tensor* output = context->Output(0, input->Shape()); - - if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "only the channels_last layout is supported"); - } - - const auto& input_dims = input->Shape().GetDims(); - if (input_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 4 dimensions, got ", input_dims.size()); - } - - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != input_dims[3]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in gamma and input does not match"); - } - - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != input_dims[3]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in beta and input does not match"); - } - - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisible by num_groups"); - } - - if (context->GetUseDeterministicCompute()) { - static std::once_flag log_warning; - std::call_once(log_warning, []() { - LOGS_DEFAULT(WARNING) << "GroupNorm has no deterministic GPU kernel, its outputs may still be nondeterministic."; - }); - } - - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); - - utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(GetTuningContext(), context->GetComputeStream(), - output, input, gamma, beta, workspace.get(), - epsilon_, - batch_size, - num_channels, - height, - width, - num_groups_, - use_swish_activation_); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index fb7091592c16..d0a0d09fcbae 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -26,13 +26,18 @@ namespace rocm { using onnxruntime::rocm::CKDataTypeAdaptor; -using Swish = ck::tensor_operation::element_wise::Swish; +// The SiLU function is a special case of Swish function, +// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as: +// SiLU(x) = x * sigmoid(x) +// Swish(x) = x * sigmoid(bx) +// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here. +using Silu = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; constexpr int Rank = 5; constexpr int NumReduceDim = 3; -template +template auto GetCKGroupNormNHWCTypeStringAndOps() { using XDataType = typename CKDataTypeAdaptor::type; using YDataType = typename CKDataTypeAdaptor::type; @@ -40,26 +45,30 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { using GammaDataType = float; using BetaDataType = float; - using Activation = std::conditional_t; + using Activation = std::conditional_t; - std::vector>>> ret; + std::vector>>> ret; for (auto&& impl : internal::GetDeviceGroupNormInstances()) { - std::string swish_suffix = WithSwish ? "_Swish" : "_Pass"; - auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix; + std::string silu_suffix = WithSilu ? "_Silu" : "_Pass"; + auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix; auto invoker = impl->MakeInvokerPointer(); - auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams* params) -> Status { - if constexpr (WithSwish) { + auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), + "Input skip or bias is not supported by composable kernel."); + if constexpr (WithSilu) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->withSwish, "Swish version only support groupnorm with swish"); + !params->use_silu, "Silu version only support groupnorm with silu"); } else { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->withSwish, "Pass version only support groupnorm without swish"); + params->use_silu, "Pass version only support groupnorm without silu"); } - std::vector in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup}; - std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1}; - std::vector gamma_beta_strides{0, 0, 0, params->cPerGroup, 1}; + std::vector in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group}; + std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, + params->c, params->channels_per_group, 1}; + std::vector gamma_beta_strides{0, 0, 0, params->channels_per_group, 1}; std::vector reduce_dims{1, 2, 4}; auto activation = Activation{}; diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh index 19b081881dce..4cb371fdcf96 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -18,7 +18,7 @@ namespace internal { using F16 = ck::half_t; using F32 = float; -using Swish = ck::tensor_operation::element_wise::Swish; +using Silu = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface @@ -101,9 +101,9 @@ GetDeviceGroupNormInstances() { template <> std::vector>> + F16, F32, F32, F16, F32, Silu, 5, 3>>> GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Swish, 5, 3>(); + F16, F32, F32, F16, F32, Silu, 5, 3>(); template <> std::vector std::vector>> + F32, F32, F32, F32, F32, Silu, 5, 3>>> GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Swish, 5, 3>(); + F32, F32, F32, F32, F32, Silu, 5, 3>(); template <> std::vector -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, - device_normalization_f16_instances{}); + device_normalization_f16_instances{}); return instances; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu index 9b0ccab17b4c..ceb53ed442ab 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -11,12 +11,12 @@ namespace rocm { namespace internal { template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, - device_normalization_f32_instances{}); + device_normalization_f32_instances{}); return instances; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h index 008ae20b0561..7cff640db2f3 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h @@ -8,110 +8,47 @@ #include "core/providers/rocm/cu_inc/common.cuh" #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/rocm/diffusion/group_norm_common_base.h" namespace onnxruntime { namespace contrib { namespace rocm { -using onnxruntime::rocm::CeilDiv; - -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; - for (int32_t i = 1; i <= std::sqrt(n); i++) { - if (n % i == 0) { - int32_t divisor1 = n / i; - int32_t divisor2 = i; - - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; - } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; - } - } - } - return maxDivisor; -} - template -struct GroupNormNHWCParams : OpParams { - GroupNormNHWCParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, T* dst, float* redBuffer, const T* src, const float* gamma, - const float* beta, int32_t n, int32_t h, int32_t w, int32_t c, int32_t groups, float epsilon, bool withSwish) - : OpParams(tuning_ctx, stream), dst(dst), src(src), gamma(gamma), beta(beta), redBuffer(redBuffer), epsilon(epsilon), n(n), h(h), w(w), c(c), groups(groups), withSwish(withSwish) { - int32_t maxBlocksPerHW = 1024; - switch (c) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; - } - - hw = h * w; - const int32_t blocksPerHW = findMaxDivisor(hw, maxBlocksPerHW); - hwPerBlock = CeilDiv(hw, blocksPerHW); - cPerGroup = c / groups; - hwc = hw * c; - invHWC = 1.F / (float)(hw * cPerGroup); - groupsPerBlock = cPerBlock / cPerGroup; - } +struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams { + GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx, + onnxruntime::Stream* ort_stream, + T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + float* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) + : OpParams(tuning_ctx, ort_stream), + GroupNormNHWCParams(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, + num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {} std::string Signature() const override { - std::string swish_suffix = withSwish ? "_Swish" : "_Pass"; - std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups) + swish_suffix; + std::string silu_suffix = this->use_silu ? "_silu" : "_pass"; + std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip"; + std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast"; + std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias"; + std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + + std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix + + skip_suffix + broadcast_suffix + bias_suffix; return sig; } - - // The output buffer. Layout NHWC. - T* dst; - // The input buffer. Layout NHWC. - T const* src; - // The gamma scaling factor. - float const* gamma; - // The beta term to add in GN. - float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; - float epsilon; - - // The number of instances in the batch. - int32_t n; - // The height and width of each activation map. - int32_t h; - int32_t w; - // The number of channels. - int32_t c; - // The number of groups. - int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; - - // Precomputed values and parameters to control the execution of the kernels. - - // The number of activations per instance (h * w) and the number of - // activations per block. - int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; - - // The precomputed stride between instances. - int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; - // The precomputed number of groups per block. - int32_t groupsPerBlock; }; } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu index dbd5009e6367..142aaf14e8d2 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu @@ -15,9 +15,12 @@ namespace rocm { template Status LaunchGroupNormKernel( RocmTuningContext* tuning_ctx, - Stream* stream, + Stream* ort_stream, T* output, + T* add_out, const T* input, + const T* skip, + const T* bias, const float* gamma, const float* beta, void* workspace, @@ -27,19 +30,26 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool use_swish_activation) { - if (batch_size > static_cast(kMaxGroupNormBatchSize)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only support batch_size <= 32. Got", batch_size); - } + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCTunableParams params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta, + reinterpret_cast(workspace), epsilon, batch_size, num_channels, + height, width, num_groups, use_silu, broadcast_skip, channels_per_block); - if (num_groups != static_cast(kGroupNormNumberOfGroups)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only num_groups=32 is supported. Got", num_groups); + if (params.channels_per_block % params.channels_per_group != 0 || + params.channels_per_block > kMaxSize || + (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in ROCM does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - GroupNormNHWCParams params(tuning_ctx, stream, output, reinterpret_cast(workspace), input, gamma, beta, - batch_size, height, width, num_channels, num_groups, epsilon, use_swish_activation); + HIP_RETURN_IF_ERROR(hipMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle())); if (tuning_ctx->IsTunableOpEnabled()) { static GroupNormNHWCTunableOp op; @@ -50,14 +60,17 @@ Status LaunchGroupNormKernel( } template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, - float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + half* add_out, const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, - float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + float* add_out, const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); + } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h deleted file mode 100644 index a0f7e0aca5de..000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/common/common.h" -#include "core/common/status.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::tunable::RocmTuningContext; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { - // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; -} - -template -Status LaunchGroupNormKernel( - RocmTuningContext* tuning_ctx, - Stream* stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh deleted file mode 100644 index d6322a12a936..000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// The ROCm kernel is modified from TensorRT 8.5. -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -static inline __device__ __host__ float sigmoid(float x) { - return 1.F / (1.F + expf(-x)); -} - -struct GroupSums { - // Is it the 1st element of the group? - int32_t flag; - // The sum. - float sum; - // The sum of squares. - float sumSq; -}; - -struct GroupSumsOp { - inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { - GroupSums dst; - dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); - dst.flag = a.flag + b.flag; - return dst; - } -}; - -template -inline __device__ void UpdateSum(const T* src, int64_t offset, U& sum, U& sumSq) { - using VecT = onnxruntime::rocm::aligned_vector; - const VecT input_v = *reinterpret_cast(src + offset); - -#pragma unroll - for (int i = 0; i < ILP; i++) { - const U val = static_cast(input_v.val[i]); - sum += val; - sumSq += val * val; - } -} - -template -__global__ void groupNormNHWCSumKernel(const T* src, float* redBuffer, int32_t cPerBlock, int32_t hwPerBlock, int32_t hw, - int32_t hwc, int32_t c, int32_t cPerGroup, int32_t groups, int32_t groupsPerBlock) { - // The object in charge of doing the sums for the different blocks. - typedef hipcub::BlockScan BlockScan; - - // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[ThreadsPerBlock]; - - // The instance in the batch. - int32_t ni = blockIdx.z; - // The channel loaded by that thread (ILP channels per thread). - int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP; - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + hwPerBlock, hw); - - // The sums. - float sum = 0.F; - float sumSq = 0.F; - - // Iterate over the activations to compute the sums. - if (ci < c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * hwc + static_cast(hwi) * c + ci; - UpdateSum(src, offset, sum, sumSq); - } - } - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * ILP / cPerGroup; - int32_t cj = threadIdx.x * ILP - cPerGroup * gi; - - // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; - - // Do the segmented scan. - GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); - - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == cPerGroup - ILP) { // ILP channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The global group index. - int32_t gj = blockIdx.x * groupsPerBlock + threadIdx.x; - - // Threads that have nothing left to do, exit. - if (threadIdx.x >= groupsPerBlock || gj >= groups) { - return; - } - - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. - atomicAdd(&redBuffer[(2 * ni + 0) * groups + gj], sums.x); - atomicAdd(&redBuffer[(2 * ni + 1) * groups + gj], sums.y); -} - -template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, U mean, U invStdDev, - const U* gamma_v, const U* beta_v, bool swish) { - using VecT = onnxruntime::rocm::aligned_vector; - const VecT input_v = *reinterpret_cast(src + offset); - VecT output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - U val = static_cast(input_v.val[i]); - val = (val - mean) * invStdDev; - val = gamma_v[i] * val + beta_v[i]; - - if (swish) { - val = val * sigmoid(val); - } - output_v.val[i] = static_cast(val); - } - *(reinterpret_cast(dst + offset)) = output_v; -} - -template -__global__ void groupNormNHWCScaleKernel(T* dst, const T* src, const float* gamma, const float* beta, const float* redBuffer, float epsilon, int32_t c, int32_t cPerBlock, - int32_t cPerGroup, int32_t groups, int32_t hwc, float invHWC, int32_t hw, int32_t hwPerBlock, bool withSwish) { - // The channel loaded by that thread (ILP channels per thread for F16x2). - int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP; - if (ci >= c) { - return; - } - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / cPerGroup; - - // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; - if (gi < groups) { - sum = redBuffer[(2 * ni + 0) * groups + gi]; - sumSq = redBuffer[(2 * ni + 1) * groups + gi]; - } - - using VecF = onnxruntime::rocm::aligned_vector; - - const VecF gamma_v = *reinterpret_cast(gamma + ci); - const VecF beta_v = *reinterpret_cast(beta + ci); - - // Compute the mean. - float mean = sum * invHWC; - // Compute the variance. - float var = sumSq * invHWC - (mean * mean); - // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var + epsilon); - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + hwPerBlock, hw); - - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * hwc + hwi * c + ci; - - // Fetch ILP channels per thread. - computeGroupNorm(src, dst, offset, mean, invStdDev, gamma_v.val, beta_v.val, withSwish); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index b7b9441ac997..c6ca16bfdfc8 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -20,21 +20,21 @@ namespace rocm { namespace { -template +template std::string GetGroupNormTritonGroupName() { std::string ret = "GroupNormTriton_"; - std::string swish_suffix = WithSwish ? "Swish_" : "Pass_"; - ret += swish_suffix; + std::string silu_suffix = WithSilu ? "Silu_" : "Pass_"; + ret += silu_suffix; ret += GetDataTypeName(); return ret; } } // namespace -template +template auto GetTritonGroupNormNHWCTypeStringAndOps() { - std::vector>>> ret; - auto group_name = GetGroupNormTritonGroupName(); + std::vector>>> ret; + auto group_name = GetGroupNormTritonGroupName(); auto* kernel_list = GetOrtTritonKernelByGroup(group_name); if (kernel_list == nullptr) { return ret; @@ -45,36 +45,50 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { auto* metadata = GetOrtTritonKernelMetadata(i); auto block_size = metadata->constants.at("BLOCK_SIZE"); auto hw_size = metadata->constants.at("HW_SIZE"); - auto impl = [i, block_size, hw_size](const GroupNormNHWCParams* params) -> Status { + auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->cPerGroup > block_size || params->cPerGroup * 2 <= block_size, - "Arg block_size (", block_size, ") is not the next power of 2 of cPerGroup (", params->cPerGroup, ")."); + params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, + "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", + params->channels_per_group, ")."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); - if constexpr (WithSwish) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->withSwish, "Swish version does not support GN w/o swish."); + if constexpr (WithSilu) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu."); } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->withSwish, "Pass version does not support GN w/ swish."); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu."); } // Construct args for launch kernel struct { - void* X; - void* Y; + const void* src; + const void* skip; + const void* bias; + void* out; + void* add_out; const void* gamma; const void* beta; int hw; int c; int c_per_group; float eps; + bool has_skip; + bool has_bias; + bool broadcast_skip; } args = { - (void*)params->src, + (const void*)params->src, + (const void*)params->skip, + (const void*)params->bias, (void*)params->dst, + (void*)params->skip_workspace, (const void*)params->gamma, (const void*)params->beta, params->hw, params->c, - params->cPerGroup, - params->epsilon}; + params->channels_per_group, + params->epsilon, + params->skip != nullptr, + params->bias != nullptr, + params->broadcast_skip, + }; // Grid dim is (batch_count, groups, 1) return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index 56b3a030b289..5ba96ebc117f 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -12,16 +12,22 @@ @triton.jit def group_norm_kernel( input_ptr, + skip_ptr, + bias_ptr, output_ptr, + add_out_ptr, gamma_ptr, beta_ptr, img_size, c, c_per_group, eps, + has_skip, + has_bias, + broadcast_skip, BLOCK_SIZE: tl.constexpr, HW_SIZE: tl.constexpr, - ACTIVATION_SWISH: tl.constexpr, + ACTIVATION_SILU: tl.constexpr, ): row_x = tl.program_id(0) row_y = tl.program_id(1) @@ -36,14 +42,35 @@ def group_norm_kernel( offsets = hw[:, None] * c + cols[None, :] mask = (cols < c_per_group)[None, :] + bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + if has_skip: + add_out_ptr += row_x * stride + row_y * c_per_group + if broadcast_skip: + broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group + bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + else: + skip_ptr += row_x * stride + row_y * c_per_group + if has_bias: + bias_ptr += row_y * c_per_group + bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + # Calculate mean and variance _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): x_ptr = input_ptr + i * HW_SIZE * c a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip and not broadcast_skip: + s_ptr = skip_ptr + i * HW_SIZE * c + s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + a += s + if has_bias or broadcast_skip: + a += bias _sum += a _square_sum += a * a + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + tl.store(add_y_ptr + offsets, a, mask=mask) # Set axis=None (or leave it unspecified) to reduce all axes. # TODO: In older Triton we have to reduce an axis at a time, but in our case @@ -57,12 +84,16 @@ def group_norm_kernel( gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): - x_ptr = input_ptr + i * HW_SIZE * c y_ptr = output_ptr + i * HW_SIZE * c - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + else: + x_ptr = input_ptr + i * HW_SIZE * c + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) x_hat = (x - group_mean) * rstd y = x_hat * gamma + beta - if ACTIVATION_SWISH: + if ACTIVATION_SILU: y *= tl.sigmoid(y) tl.store(y_ptr + offsets, y, mask=mask) @@ -71,27 +102,27 @@ def group_norm_kernel( # blocks = [16, 32, 64, 128, 256, 512] # hw_sizes = [8, 16, 32, 64, 128, 256, 512] # but this will result in too many functions and slow down the compilation. -with_swish = [True, False] +with_silu = [True, False] dtypes = ["fp32", "fp16"] blocks = [16, 32, 64, 128] hw_sizes = [8, 16, 32, 64, 128, 256] warps = [1, 2, 4, 8, 16] name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32" +sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" group_pattern = "GroupNormTriton_{}_{}" def get_function_table(): func_table = [] - for swish, dtype, hw_size, warp, b in product(with_swish, dtypes, hw_sizes, warps, blocks): - swish_suffix = "Swish" if swish else "Pass" - name = name_pattern.format(swish_suffix, dtype, b, hw_size, warp) - group = group_pattern.format(swish_suffix, dtype) - sig = sig_pattern.format(dtype, dtype) + for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks): + silu_suffix = "Silu" if silu else "Pass" + name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) + group = group_pattern.format(silu_suffix, dtype) + sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) kwargs = { "num_warps": warp, - "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SWISH": int(swish)}, + "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, } func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs} func_table.append(func_desc) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h index 25d820f7ed32..e6831f764b41 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h @@ -20,115 +20,117 @@ namespace rocm { using onnxruntime::rocm::GPU_WARP_SIZE; template -void groupNormNHWCSum(const GroupNormNHWCParams* params) { - // Make sure the values are as we expect. - ORT_ENFORCE(params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0); - +void GroupNormNHWCSum(const GroupNormNHWCTunableParams* params) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params->c / params->cPerBlock; + grid.x = DivUp(params->c, params->channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.y = DivUp(params->hw, params->hw_per_block); // The number of instances. grid.z = params->n; -#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ - groupNormNHWCSumKernel \ - <<StreamHandle()>>>( \ - params->src, params->redBuffer, params->cPerBlock, \ - params->hwPerBlock, params->hw, params->hwc, params->c, \ - params->cPerGroup, params->groups, params->groupsPerBlock); \ +#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ + GroupNormNHWCSumKernel \ + <<StreamHandle()>>>( \ + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \ + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \ + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \ break; - switch (params->cPerBlock) { - case 320: - LAUNCH_GROUPNORM_SUM(256, 2) - case 480: - LAUNCH_GROUPNORM_SUM(256, 2) + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params->threads_per_block) { case 256: - LAUNCH_GROUPNORM_SUM(128, 2) + LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) case 128: - LAUNCH_GROUPNORM_SUM(64, 2) + LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) default: ORT_NOT_IMPLEMENTED("Not implemented"); } } template -Status GroupNormNHWCSumOp(const GroupNormNHWCParams* params) { +Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams* params) { dim3 grid; - grid.x = params->c / params->cPerBlock; - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); grid.z = params->n; - groupNormNHWCSumKernel + GroupNormNHWCSumKernel <<StreamHandle()>>>( - params->src, params->redBuffer, params->cPerBlock, params->hwPerBlock, - params->hw, params->hwc, params->c, params->cPerGroup, params->groups, params->groupsPerBlock); + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); return HIP_CALL(hipGetLastError()); } template -void groupNormNHWCScale(const GroupNormNHWCParams* params) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params->c % params->cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0); - +void GroupNormNHWCScale(const GroupNormNHWCTunableParams* params) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params->c / params->cPerBlock; + grid.x = DivUp(params->c, params->channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.y = DivUp(params->hw, params->hw_per_block); // The number of instances. grid.z = params->n; -#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ - groupNormNHWCScaleKernel \ - <<StreamHandle()>>>( \ - params->dst, params->src, params->gamma, params->beta, \ - params->redBuffer, params->epsilon, params->c, params->cPerBlock, \ - params->cPerGroup, params->groups, params->hwc, params->invHWC, \ - params->hw, params->hwPerBlock, params->withSwish); \ +#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ + GroupNormNHWCScaleKernel \ + <<StreamHandle()>>>( \ + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \ + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \ + params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \ + params->hw, params->hw_per_block, params->use_silu); \ break; - switch (params->cPerBlock) { - case 320: - LAUNCH_GROUPNORM_SCALE(256, 2) - case 480: - LAUNCH_GROUPNORM_SCALE(256, 2) + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params->threads_per_block) { case 256: - LAUNCH_GROUPNORM_SCALE(128, 2) + LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) case 128: - LAUNCH_GROUPNORM_SCALE(64, 2) + LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) default: ORT_NOT_IMPLEMENTED("Not implemented"); } } template -Status GroupNormNHWCScaleOp(const GroupNormNHWCParams* params) { +Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams* params) { dim3 grid; - grid.x = params->c / params->cPerBlock; - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); grid.z = params->n; - groupNormNHWCScaleKernel + GroupNormNHWCScaleKernel <<StreamHandle()>>>( - params->dst, params->src, params->gamma, params->beta, params->redBuffer, params->epsilon, params->c, params->cPerBlock, - params->cPerGroup, params->groups, params->hwc, params->invHWC, params->hw, params->hwPerBlock, params->withSwish); + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, + params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, + params->use_silu); return HIP_CALL(hipGetLastError()); } template class GroupNormNHWCOp { public: - Status operator()(const GroupNormNHWCParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle())); + Status operator()(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); auto status = GroupNormNHWCSumOp(params); ORT_RETURN_IF_ERROR(status); HIP_RETURN_IF_ERROR(hipGetLastError()); @@ -138,29 +140,30 @@ class GroupNormNHWCOp { return Status::OK(); } - Status IsSupported(const GroupNormNHWCParams* params) { + Status IsSupported(const GroupNormNHWCTunableParams* params) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0), - "The number of channels (", params->c, ") or the number of channels per group (", params->cPerGroup, + !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0), + "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group, ") isn't divisible by the number of vector size: ", VecSize); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock % params->cPerGroup == 0 && - params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0), - "The value of attributes don't meet the requirements."); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock <= ThreadsPerBlock * VecSize && - params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize && + params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), "Configuration: Threads (", ThreadsPerBlock, "), vector size (", - VecSize, ") is redundant for the number of channels per group: ", params->cPerBlock); + VecSize, ") is redundant for the number of channels per group: ", + params->channels_per_block); return Status::OK(); } }; template -Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle())); - groupNormNHWCSum(params); +Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); + GroupNormNHWCSum(params); HIP_RETURN_IF_ERROR(hipGetLastError()); - groupNormNHWCScale(params); + GroupNormNHWCScale(params); HIP_RETURN_IF_ERROR(hipGetLastError()); return Status::OK(); } @@ -178,30 +181,30 @@ Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { ADD_OP_FOR_ALL_VEC_SIZE(name, 320) template -class GroupNormNHWCTunableOp : public TunableOp> { +class GroupNormNHWCTunableOp : public TunableOp> { public: GroupNormNHWCTunableOp() { this->RegisterOp(GroupNormNHWCStaticSelection); ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) #ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } diff --git a/onnxruntime/contrib_ops/rocm/fused_conv.cc b/onnxruntime/contrib_ops/rocm/fused_conv.cc index d597e0d57fbc..63804f79a32f 100644 --- a/onnxruntime/contrib_ops/rocm/fused_conv.cc +++ b/onnxruntime/contrib_ops/rocm/fused_conv.cc @@ -76,7 +76,12 @@ struct FNVHash { void HashConvolutionDescriptor(miopenConvolutionDescriptor_t cdesc) { int spatial_dim = 1; #if ROCM_VERSION >= 50500 - miopenGetConvolutionSpatialDim(cdesc, &spatial_dim); + MIOPEN_CALL(miopenGetConvolutionSpatialDim(cdesc, &spatial_dim)); + std::vector pads{spatial_dim}; + std::vector strides{spatial_dim}; + std::vector dilations{spatial_dim}; + miopenConvolutionMode_t mode; + MIOPEN_CALL(miopenGetConvolutionNdDescriptor(cdesc, spatial_dim, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)); #else // Previous versions of MIOpen doesn't provide API to probe the dimension of a // miopenConvolutionDescriptor_t, so we have to guess. @@ -100,11 +105,12 @@ struct FNVHash { pads.resize(spatial_dim); strides.resize(spatial_dim); dilations.resize(spatial_dim); +#endif (*this) << spatial_dim; (*this) << pads; (*this) << strides; (*this) << dilations; -#endif + (*this) << mode; } private: @@ -313,6 +319,8 @@ class FusedConv : public onnxruntime::rocm::Conv { auto ret = miopenCompileFusionPlan(handle, fusion->plan); if (miopenStatusSuccess == ret) { fusion->compiled_on.insert(handle); + } else { + return ret; } return miopenStatusSuccess; } diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 55cd6a1d112f..e19a976f3141 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -93,6 +93,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Samp class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -150,7 +151,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather); #endif -#if defined(USE_MPI) && defined(ORT_USE_NCCL) +#ifdef ORT_USE_NCCL class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll); @@ -246,6 +247,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -309,7 +311,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif -#if defined(USE_MPI) && defined(ORT_USE_NCCL) +#ifdef ORT_USE_NCCL BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index fcf9c2b03dea..be881f6bc4bc 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -30,6 +30,10 @@ #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #endif // ARM #endif // Linux @@ -48,6 +52,13 @@ #if defined(CPUINFO_SUPPORTED) #include +#if defined(CPUIDINFO_ARCH_ARM) +namespace onnxruntime { +// The following function is declared in "core/common/cpuid_uarch.h" but we cannot include the whole header file because +// some of its symbols are conflict with +void decodeMIDR(uint32_t midr, uint32_t uarch[1]); +} // namespace onnxruntime +#endif #else #include "core/common/cpuid_uarch.h" #endif // CPUINFO_SUPPORTED @@ -138,16 +149,12 @@ void CPUIDInfo::ArmLinuxInit() { // Pytorch CPUINFO only works on ARM linux or android // Assuming no hyper-threading, no NUMA groups #ifdef CPUINFO_SUPPORTED - pytorch_cpuinfo_init_ = cpuinfo_initialize(); - if (!pytorch_cpuinfo_init_) { - LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features."; - return; - } is_hybrid_ = cpuinfo_get_uarchs_count() > 1; has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); @@ -177,6 +184,7 @@ void CPUIDInfo::ArmLinuxInit() { has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); #endif } @@ -233,51 +241,24 @@ void CPUIDInfo::ArmWindowsInit() { lastUarch = uarch; } } - - switch (lastUarch) { - case cpuinfo_uarch_cortex_a55: - case cpuinfo_uarch_cortex_a55r0: - case cpuinfo_uarch_cortex_a76: - case cpuinfo_uarch_neoverse_n1: - case cpuinfo_uarch_cortex_a77: - case cpuinfo_uarch_exynos_m4: - case cpuinfo_uarch_exynos_m5: - has_fp16_ = true; - break; - default: - break; - } - if (!has_fp16_) { - /* - * Detecting fp16 support. Different cores should have the same instruction set. - * So we just check the first ID_AA64PFR0_EL1 - * Op0(0b11), Op1(0b000), CRn(0b0000), CRm(0b0100), Op2(0b000), - */ - uint64_t ID_AA64PFR0_EL1; - unsigned long valsize = sizeof(uint64_t); - auto retCode = ::RegGetValueA( - HKEY_LOCAL_MACHINE, - "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", - "CP 4020", RRF_RT_REG_QWORD, nullptr, - &ID_AA64PFR0_EL1, &valsize); - if (retCode == ERROR_SUCCESS) { - // AdvSIMD, bits [23:20] - auto advSimd = ID_AA64PFR0_EL1 >> 20; - if ((advSimd & 0xfULL) == 1) { - has_fp16_ = true; - } - } - } #endif /* Application Family or OneCore Family */ has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); #else has_arm_neon_dot_ = false; #endif - has_fp16_ |= has_arm_neon_dot_; - /* TODO: implement them when hw+sw is available for testing these features */ - has_arm_neon_i8mm_ = false; - has_arm_sve_i8mm_ = false; + + if (pytorch_cpuinfo_init_) { + has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); + has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + } else { + has_fp16_ = false; + has_arm_neon_i8mm_ = false; + has_arm_sve_i8mm_ = false; + has_arm_neon_bf16_ = false; + } } #endif /* (arm or arm64) and windows */ @@ -297,5 +278,21 @@ uint32_t CPUIDInfo::GetCurrentCoreIdx() const { return 0xFFFFFFFF; // don't know how to get core index #endif } - +CPUIDInfo::CPUIDInfo() { +#ifdef CPUIDINFO_ARCH_X86 + X86Init(); +#elif defined(CPUIDINFO_ARCH_ARM) +#if CPUINFO_SUPPORTED + pytorch_cpuinfo_init_ = cpuinfo_initialize(); + if (!pytorch_cpuinfo_init_) { + LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features."; + } +#endif +#ifdef __linux__ + ArmLinuxInit(); +#elif defined(_WIN32) + ArmWindowsInit(); +#endif /* (arm or arm64) and windows */ +#endif +} } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index a15c75104b83..a3936b4bd11a 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -30,6 +30,7 @@ class CPUIDInfo { bool HasArmNeonDot() const { return has_arm_neon_dot_; } bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } uint32_t GetCurrentCoreIdx() const; @@ -92,17 +93,7 @@ class CPUIDInfo { } private: - CPUIDInfo() { -#ifdef CPUIDINFO_ARCH_X86 - X86Init(); -#elif defined(CPUIDINFO_ARCH_ARM) -#ifdef __linux__ - ArmLinuxInit(); -#elif defined(_WIN32) - ArmWindowsInit(); -#endif /* (arm or arm64) and windows */ -#endif - } + CPUIDInfo(); bool has_amx_bf16_{false}; bool has_avx_{false}; bool has_avx2_{false}; @@ -125,15 +116,18 @@ class CPUIDInfo { bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; #ifdef CPUIDINFO_ARCH_X86 void X86Init(); - #elif defined(CPUIDINFO_ARCH_ARM) + // Now the following var is only used in ARM build, but later one we may expand the usage. + bool pytorch_cpuinfo_init_{false}; +#endif + #ifdef __linux__ - bool pytorch_cpuinfo_init_{false}; void ArmLinuxInit(); #elif defined(_WIN32) @@ -141,7 +135,6 @@ class CPUIDInfo { void ArmWindowsInit(); #endif /* (arm or arm64) and windows */ -#endif }; } // namespace onnxruntime diff --git a/onnxruntime/core/common/flatbuffers.h b/onnxruntime/core/common/flatbuffers.h new file mode 100644 index 000000000000..0d61e1038a82 --- /dev/null +++ b/onnxruntime/core/common/flatbuffers.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#if defined(__GNUC__) +#include "onnxruntime_config.h" +#pragma GCC diagnostic push + +#ifdef HAS_SHORTEN_64_TO_32 +#pragma GCC diagnostic ignored "-Wshorten-64-to-32" +#endif +#endif + +#include "flatbuffers/flatbuffers.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif \ No newline at end of file diff --git a/onnxruntime/core/common/logging/logging.cc b/onnxruntime/core/common/logging/logging.cc index 6c6e2f48557e..eac9a7fa0808 100644 --- a/onnxruntime/core/common/logging/logging.cc +++ b/onnxruntime/core/common/logging/logging.cc @@ -12,6 +12,8 @@ #ifdef _WIN32 #include +#include "core/platform/windows/logging/etw_sink.h" +#include "core/common/logging/sinks/composite_sink.h" #else #include #if defined(__MACH__) || defined(__wasm__) || defined(_AIX) @@ -243,5 +245,36 @@ unsigned int GetProcessId() { #endif } +std::unique_ptr EnhanceLoggerWithEtw(std::unique_ptr existingLogger, logging::Severity originalSeverity, + logging::Severity etwSeverity) { +#ifdef _WIN32 + auto& manager = EtwRegistrationManager::Instance(); + if (manager.IsEnabled()) { + auto compositeSink = std::make_unique(); + compositeSink->AddSink(std::move(existingLogger), originalSeverity); + compositeSink->AddSink(std::make_unique(), etwSeverity); + return compositeSink; + } else { + return existingLogger; + } +#else + // On non-Windows platforms, just return the existing logger + (void)originalSeverity; + (void)etwSeverity; + return existingLogger; +#endif // _WIN32 +} + +Severity OverrideLevelWithEtw(Severity originalSeverity) { +#ifdef _WIN32 + auto& manager = logging::EtwRegistrationManager::Instance(); + if (manager.IsEnabled() && + (manager.Keyword() & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { + return manager.MapLevelToSeverity(); + } +#endif // _WIN32 + return originalSeverity; +} + } // namespace logging } // namespace onnxruntime diff --git a/onnxruntime/core/common/logging/sinks/composite_sink.h b/onnxruntime/core/common/logging/sinks/composite_sink.h index f27abb9e6aad..9d18eb527ffd 100644 --- a/onnxruntime/core/common/logging/sinks/composite_sink.h +++ b/onnxruntime/core/common/logging/sinks/composite_sink.h @@ -5,6 +5,8 @@ #include #include +#include +#include #include "core/common/logging/isink.h" #include "core/common/logging/logging.h" @@ -27,20 +29,31 @@ class CompositeSink : public ISink { /// Adds a sink. Takes ownership of the sink (so pass unique_ptr by value). ///
/// The sink. + /// The min severity to send a message to that sink /// This instance to allow chaining. - CompositeSink& AddSink(std::unique_ptr sink) { - sinks_.push_back(std::move(sink)); + CompositeSink& AddSink(std::unique_ptr sink, logging::Severity severity) { + sinks_with_severity_.emplace_back(std::move(sink), severity); return *this; } + /// + /// Gets a const reference to the collection of sinks and min severity for that sink + /// + /// A const reference to the vector pair of unique_ptr to ISink and severity. + const std::vector, logging::Severity>>& GetSinks() const { + return sinks_with_severity_; + } + private: void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) override { - for (auto& sink : sinks_) { - sink->Send(timestamp, logger_id, message); + for (auto& sink_pair : sinks_with_severity_) { + if (message.Severity() >= sink_pair.second) { + sink_pair.first->Send(timestamp, logger_id, message); + } } } - std::vector> sinks_; + std::vector, logging::Severity>> sinks_with_severity_; }; } // namespace logging } // namespace onnxruntime diff --git a/onnxruntime/core/common/logging/sinks/ostream_sink.cc b/onnxruntime/core/common/logging/sinks/ostream_sink.cc index 0db3d8709d48..a120138d1d15 100644 --- a/onnxruntime/core/common/logging/sinks/ostream_sink.cc +++ b/onnxruntime/core/common/logging/sinks/ostream_sink.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include "core/common/logging/sinks/ostream_sink.h" -#include "date/date.h" namespace onnxruntime { namespace logging { @@ -24,7 +23,7 @@ struct Color { void OStreamSink::SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) { // operator for formatting of timestamp in ISO8601 format including microseconds - using date::operator<<; + using timestamp_ns::operator<<; // Two options as there may be multiple calls attempting to write to the same sink at once: // 1) Use mutex to synchronize access to the stream. diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h index eca1221e84cb..716eed1afec5 100644 --- a/onnxruntime/core/common/string_utils.h +++ b/onnxruntime/core/common/string_utils.h @@ -65,5 +65,24 @@ inline std::string TrimString(std::string s) { return s; } +/** + * @brief A consistent way to construct the full qualified op name. + */ +inline std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) { + return MakeString(domain, "::", op_type); +} + +/** + * Use this simple hash to generate unique int by given string input. + */ +inline uint32_t GetHashFromString(const std::string& str_value) { + uint32_t hash = 0; + for (char const& c : str_value) { + hash = hash * 101 + c; + } + + return hash; +} + } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/common/utf8_util.h b/onnxruntime/core/common/utf8_util.h index 218309f7198d..583aaf0a47cf 100644 --- a/onnxruntime/core/common/utf8_util.h +++ b/onnxruntime/core/common/utf8_util.h @@ -8,8 +8,13 @@ namespace onnxruntime { namespace utf8_util { -// Returns the number of bytes in the utf8 character -// by analyzing its leading byte +/// +/// Checks the extension bytes and returns a number of +/// bytes in the UTF-8 character +/// +/// +/// result +/// false if the char len is greater than 4 otherwise true inline bool utf8_bytes(unsigned char ch, size_t& len) { if ((ch & 0x80) == 0) { len = 1; diff --git a/onnxruntime/core/flatbuffers/checkpoint_version.h b/onnxruntime/core/flatbuffers/checkpoint_version.h index 6cad27c35024..e6ee20bf508c 100644 --- a/onnxruntime/core/flatbuffers/checkpoint_version.h +++ b/onnxruntime/core/flatbuffers/checkpoint_version.h @@ -13,7 +13,9 @@ namespace onnxruntime { // The format includes support for the ModuleState (stores the module parameters), OptimizerGroups // (stores the optimizer states), and PropertyBag // (stores custom user properties with support for int64, float and strings). -constexpr const int kCheckpointVersion = 1; +// Version 2: Introduces the On-Device Training nominal checkpoint state. +// Changes include the addition of the is_nominal_state field in the checkpoint's ModuleState. +constexpr const int kCheckpointVersion = 2; /** * @brief Check if the given checkpoint version is supported in this build diff --git a/onnxruntime/core/flatbuffers/flatbuffers_utils.h b/onnxruntime/core/flatbuffers/flatbuffers_utils.h index 55bde0b2df80..76860d6ab1db 100644 --- a/onnxruntime/core/flatbuffers/flatbuffers_utils.h +++ b/onnxruntime/core/flatbuffers/flatbuffers_utils.h @@ -5,7 +5,7 @@ #include -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/common.h" #include "core/common/path_string.h" diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py index 2be826fee2cc..19c6b1b6f275 100644 --- a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py @@ -74,9 +74,17 @@ def FrozenParamsIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) return o == 0 -def ModuleStateStart(builder): builder.StartObject(2) + # ModuleState + def IsNominalState(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + +def ModuleStateStart(builder): builder.StartObject(3) def ModuleStateAddRequiresGradParams(builder, requiresGradParams): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(requiresGradParams), 0) def ModuleStateStartRequiresGradParamsVector(builder, numElems): return builder.StartVector(4, numElems, 4) def ModuleStateAddFrozenParams(builder, frozenParams): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(frozenParams), 0) def ModuleStateStartFrozenParamsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ModuleStateAddIsNominalState(builder, isNominalState): builder.PrependBoolSlot(2, isNominalState, 0) def ModuleStateEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/schema/README.md b/onnxruntime/core/flatbuffers/schema/README.md index 932478111ee6..96a2936c196a 100644 --- a/onnxruntime/core/flatbuffers/schema/README.md +++ b/onnxruntime/core/flatbuffers/schema/README.md @@ -21,7 +21,7 @@ e.g. - /build/Linux/Debug/_deps/flatbuffers-build/flatc It is possible to use another flatc as well, e.g., from a separate installation. Note that ONNX Runtime uses -FlatBuffers 1.12. +FlatBuffers 23.5.26. To update the flatbuffers schemas and generated files: 1. Modify [the ORT file format schema](ort.fbs) or [training checkpoint schema](ort_training_checkpoint.fbs). diff --git a/onnxruntime/core/flatbuffers/schema/ort.fbs.h b/onnxruntime/core/flatbuffers/schema/ort.fbs.h index e0f5342c2962..dc8a471f2d81 100644 --- a/onnxruntime/core/flatbuffers/schema/ort.fbs.h +++ b/onnxruntime/core/flatbuffers/schema/ort.fbs.h @@ -4,7 +4,7 @@ #ifndef FLATBUFFERS_GENERATED_ORT_ONNXRUNTIME_FBS_H_ #define FLATBUFFERS_GENERATED_ORT_ONNXRUNTIME_FBS_H_ -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" namespace onnxruntime { namespace fbs { @@ -562,8 +562,8 @@ struct DimensionValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_DIM_TYPE) && - VerifyField(verifier, VT_DIM_VALUE) && + VerifyField(verifier, VT_DIM_TYPE, 1) && + VerifyField(verifier, VT_DIM_VALUE, 8) && VerifyOffset(verifier, VT_DIM_PARAM) && verifier.VerifyString(dim_param()) && verifier.EndTable(); @@ -634,7 +634,7 @@ struct TensorTypeAndShape FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_ELEM_TYPE) && + VerifyField(verifier, VT_ELEM_TYPE, 4) && VerifyOffset(verifier, VT_SHAPE) && verifier.VerifyTable(shape()) && verifier.EndTable(); @@ -687,7 +687,7 @@ struct MapType FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_KEY_TYPE) && + VerifyField(verifier, VT_KEY_TYPE, 4) && VerifyOffset(verifier, VT_VALUE_TYPE) && verifier.VerifyTable(value_type()) && verifier.EndTable(); @@ -787,7 +787,7 @@ struct NodeEdge FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_NODE_INDEX) && + VerifyField(verifier, VT_NODE_INDEX, 4) && VerifyOffset(verifier, VT_INPUT_EDGES) && verifier.VerifyVector(input_edges()) && VerifyOffset(verifier, VT_OUTPUT_EDGES) && @@ -911,11 +911,11 @@ struct Node FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(doc_string()) && VerifyOffset(verifier, VT_DOMAIN) && verifier.VerifyString(domain()) && - VerifyField(verifier, VT_SINCE_VERSION) && - VerifyField(verifier, VT_INDEX) && + VerifyField(verifier, VT_SINCE_VERSION, 4) && + VerifyField(verifier, VT_INDEX, 4) && VerifyOffset(verifier, VT_OP_TYPE) && verifier.VerifyString(op_type()) && - VerifyField(verifier, VT_TYPE) && + VerifyField(verifier, VT_TYPE, 4) && VerifyOffset(verifier, VT_EXECUTION_PROVIDER_TYPE) && verifier.VerifyString(execution_provider_type()) && VerifyOffset(verifier, VT_INPUTS) && @@ -1174,7 +1174,7 @@ struct TypeInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_DENOTATION) && verifier.VerifyString(denotation()) && - VerifyField(verifier, VT_VALUE_TYPE) && + VerifyField(verifier, VT_VALUE_TYPE, 1) && VerifyOffset(verifier, VT_VALUE) && VerifyTypeInfoValue(verifier, value(), value_type()) && verifier.EndTable(); @@ -1259,7 +1259,7 @@ struct OperatorSetId FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_DOMAIN) && verifier.VerifyString(domain()) && - VerifyField(verifier, VT_VERSION) && + VerifyField(verifier, VT_VERSION, 8) && verifier.EndTable(); } }; @@ -1343,7 +1343,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(doc_string()) && VerifyOffset(verifier, VT_DIMS) && verifier.VerifyVector(dims()) && - VerifyField(verifier, VT_DATA_TYPE) && + VerifyField(verifier, VT_DATA_TYPE, 4) && VerifyOffset(verifier, VT_RAW_DATA) && verifier.VerifyVector(raw_data()) && VerifyOffset(verifier, VT_STRING_DATA) && @@ -1568,9 +1568,9 @@ struct Attribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(name()) && VerifyOffset(verifier, VT_DOC_STRING) && verifier.VerifyString(doc_string()) && - VerifyField(verifier, VT_TYPE) && - VerifyField(verifier, VT_F) && - VerifyField(verifier, VT_I) && + VerifyField(verifier, VT_TYPE, 4) && + VerifyField(verifier, VT_F, 4) && + VerifyField(verifier, VT_I, 8) && VerifyOffset(verifier, VT_S) && verifier.VerifyString(s()) && VerifyOffset(verifier, VT_T) && @@ -1759,12 +1759,12 @@ struct NodesToOptimizeIndices FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NODE_INDICES) && verifier.VerifyVector(node_indices()) && - VerifyField(verifier, VT_NUM_INPUTS) && - VerifyField(verifier, VT_NUM_OUTPUTS) && - VerifyField(verifier, VT_HAS_VARIADIC_INPUT) && - VerifyField(verifier, VT_HAS_VARIADIC_OUTPUT) && - VerifyField(verifier, VT_NUM_VARIADIC_INPUTS) && - VerifyField(verifier, VT_NUM_VARIADIC_OUTPUTS) && + VerifyField(verifier, VT_NUM_INPUTS, 4) && + VerifyField(verifier, VT_NUM_OUTPUTS, 4) && + VerifyField(verifier, VT_HAS_VARIADIC_INPUT, 1) && + VerifyField(verifier, VT_HAS_VARIADIC_OUTPUT, 1) && + VerifyField(verifier, VT_NUM_VARIADIC_INPUTS, 4) && + VerifyField(verifier, VT_NUM_VARIADIC_OUTPUTS, 4) && verifier.EndTable(); } }; @@ -1862,8 +1862,8 @@ struct DeprecatedNodeIndexAndKernelDefHash FLATBUFFERS_FINAL_CLASS : private fla } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_NODE_INDEX) && - VerifyField(verifier, VT_KERNEL_DEF_HASH) && + VerifyField(verifier, VT_NODE_INDEX, 4) && + VerifyField(verifier, VT_KERNEL_DEF_HASH, 8) && verifier.EndTable(); } }; @@ -2161,7 +2161,7 @@ struct Graph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_NODES) && verifier.VerifyVector(nodes()) && verifier.VerifyVectorOfTables(nodes()) && - VerifyField(verifier, VT_MAX_NODE_INDEX) && + VerifyField(verifier, VT_MAX_NODE_INDEX, 4) && VerifyOffset(verifier, VT_NODE_EDGES) && verifier.VerifyVector(node_edges()) && verifier.VerifyVectorOfTables(node_edges()) && @@ -2390,7 +2390,7 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_IR_VERSION) && + VerifyField(verifier, VT_IR_VERSION, 8) && VerifyOffset(verifier, VT_OPSET_IMPORT) && verifier.VerifyVector(opset_import()) && verifier.VerifyVectorOfTables(opset_import()) && @@ -2400,7 +2400,7 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(producer_version()) && VerifyOffset(verifier, VT_DOMAIN) && verifier.VerifyString(domain()) && - VerifyField(verifier, VT_MODEL_VERSION) && + VerifyField(verifier, VT_MODEL_VERSION, 8) && VerifyOffset(verifier, VT_DOC_STRING) && verifier.VerifyString(doc_string()) && VerifyOffset(verifier, VT_GRAPH) && @@ -2740,8 +2740,8 @@ struct ArgTypeAndIndex FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_ARG_TYPE) && - VerifyField(verifier, VT_INDEX) && + VerifyField(verifier, VT_ARG_TYPE, 1) && + VerifyField(verifier, VT_INDEX, 4) && verifier.EndTable(); } }; diff --git a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs index c8244b0a426f..94757fa6d5bf 100644 --- a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs +++ b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs @@ -8,6 +8,10 @@ namespace onnxruntime.fbs; table ModuleState { requires_grad_params:[Tensor]; frozen_params:[Tensor]; + // Nominal state just means that the Tensors in the ModuleState + // are empty. i.e. The tensors are treated as named entities + // without any meaningful data. + is_nominal_state:bool; } table ParameterOptimizerState { diff --git a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h index 48feebb19769..62e6cf74394e 100644 --- a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h +++ b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h @@ -4,7 +4,7 @@ #ifndef FLATBUFFERS_GENERATED_ORTTRAININGCHECKPOINT_ONNXRUNTIME_FBS_H_ #define FLATBUFFERS_GENERATED_ORTTRAININGCHECKPOINT_ONNXRUNTIME_FBS_H_ -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "ort.fbs.h" @@ -39,7 +39,8 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ModuleStateBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_REQUIRES_GRAD_PARAMS = 4, - VT_FROZEN_PARAMS = 6 + VT_FROZEN_PARAMS = 6, + VT_IS_NOMINAL_STATE = 8 }; const flatbuffers::Vector> *requires_grad_params() const { return GetPointer> *>(VT_REQUIRES_GRAD_PARAMS); @@ -47,6 +48,9 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector> *frozen_params() const { return GetPointer> *>(VT_FROZEN_PARAMS); } + bool is_nominal_state() const { + return GetField(VT_IS_NOMINAL_STATE, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_REQUIRES_GRAD_PARAMS) && @@ -55,6 +59,7 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_FROZEN_PARAMS) && verifier.VerifyVector(frozen_params()) && verifier.VerifyVectorOfTables(frozen_params()) && + VerifyField(verifier, VT_IS_NOMINAL_STATE, 1) && verifier.EndTable(); } }; @@ -69,6 +74,9 @@ struct ModuleStateBuilder { void add_frozen_params(flatbuffers::Offset>> frozen_params) { fbb_.AddOffset(ModuleState::VT_FROZEN_PARAMS, frozen_params); } + void add_is_nominal_state(bool is_nominal_state) { + fbb_.AddElement(ModuleState::VT_IS_NOMINAL_STATE, static_cast(is_nominal_state), 0); + } explicit ModuleStateBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -84,23 +92,27 @@ struct ModuleStateBuilder { inline flatbuffers::Offset CreateModuleState( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset>> requires_grad_params = 0, - flatbuffers::Offset>> frozen_params = 0) { + flatbuffers::Offset>> frozen_params = 0, + bool is_nominal_state = false) { ModuleStateBuilder builder_(_fbb); builder_.add_frozen_params(frozen_params); builder_.add_requires_grad_params(requires_grad_params); + builder_.add_is_nominal_state(is_nominal_state); return builder_.Finish(); } inline flatbuffers::Offset CreateModuleStateDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector> *requires_grad_params = nullptr, - const std::vector> *frozen_params = nullptr) { + const std::vector> *frozen_params = nullptr, + bool is_nominal_state = false) { auto requires_grad_params__ = requires_grad_params ? _fbb.CreateVector>(*requires_grad_params) : 0; auto frozen_params__ = frozen_params ? _fbb.CreateVector>(*frozen_params) : 0; return onnxruntime::fbs::CreateModuleState( _fbb, requires_grad_params__, - frozen_params__); + frozen_params__, + is_nominal_state); } struct ParameterOptimizerState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -194,8 +206,8 @@ struct OptimizerGroup FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_GROUP_NAME) && verifier.VerifyString(group_name()) && - VerifyField(verifier, VT_STEP) && - VerifyField(verifier, VT_INITIAL_LEARNING_RATE) && + VerifyField(verifier, VT_STEP, 8) && + VerifyField(verifier, VT_INITIAL_LEARNING_RATE, 4) && VerifyOffset(verifier, VT_OPTIMIZER_STATES) && verifier.VerifyVector(optimizer_states()) && verifier.VerifyVectorOfTables(optimizer_states()) && @@ -277,7 +289,7 @@ struct IntProperty FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && - VerifyField(verifier, VT_VALUE) && + VerifyField(verifier, VT_VALUE, 8) && verifier.EndTable(); } }; @@ -341,7 +353,7 @@ struct FloatProperty FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && - VerifyField(verifier, VT_VALUE) && + VerifyField(verifier, VT_VALUE, 4) && verifier.EndTable(); } }; @@ -560,7 +572,7 @@ struct Checkpoint FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_VERSION) && + VerifyField(verifier, VT_VERSION, 4) && VerifyOffset(verifier, VT_MODULE_STATE) && verifier.VerifyTable(module_state()) && VerifyOffset(verifier, VT_OPTIMIZER_GROUPS) && diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index ea7a6432a750..95e5380675df 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -175,14 +175,12 @@ class PlannerImpl { size_t num_logic_streams_{0}; std::vector> stream_nodes_; - InlinedVector node_stream_map_; // dependence_graph_ keeps the dependencies combining model graph and logic streams // e.g. dependence_graph_[downstream_node] = [upstream_node_0, upstream_node_1, upstream_node_2 ...] // upstream_node_0 and upstream_node_1 are the immmediate upstream nodes of downstream_node // upstream_node_2 is the immediate nodes ahead of downstream_node in the same logic stream InlinedHashMap> dependence_graph_; - InlinedHashMap> value_consumer_map_; InlinedHashMap value_node_map_; // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: @@ -295,7 +293,7 @@ class PlannerImpl { } #endif - // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node. + // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node. bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input, bool* is_strided_tensor) { *is_strided_tensor = false; @@ -530,6 +528,7 @@ class PlannerImpl { // Initialize allocation plan: plan_.allocation_plan.resize(num_ml_values); + for (int i = 0; static_cast(i) < num_ml_values; i++) AllocPlan(i).reused_buffer = i; } bool HasExternalOutputs(const Node& node) const { @@ -1065,7 +1064,8 @@ class PlannerImpl { // build the consumer list for each value int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; - value_consumer_map_.reserve(num_ml_values); + InlinedHashMap> value_consumer_map; + value_consumer_map.reserve(num_ml_values); // iterate each stream from back, so the first element is the last consumer in single stream case for (auto& stream : stream_nodes_) { @@ -1078,10 +1078,10 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer - value_consumer_map_[origin].insert(node_index); + value_consumer_map[origin].insert(node_index); } } return Status::OK(); @@ -1138,8 +1138,8 @@ class PlannerImpl { std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl; allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; - value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); found_reusable = true; break; @@ -1168,8 +1168,8 @@ class PlannerImpl { allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; - value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); continue; } // if @@ -1187,11 +1187,11 @@ class PlannerImpl { OrtValueIndex input_arg_index{}; if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() && allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) { - if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { + if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = input_arg_index; - value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(input_arg_index); } } @@ -1266,7 +1266,7 @@ class PlannerImpl { } bool all_covered = true; - for (auto consumer : value_consumer_map_[output_idx_global]) { + for (auto consumer : value_consumer_map[output_idx_global]) { if (deps->find(consumer) == deps->end()) { all_covered = false; break; @@ -1277,9 +1277,9 @@ class PlannerImpl { allocation_plan[downstream_value].reused_buffer = output_idx_global; get_reused = true; // add new consumer for the value to be reused - value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]); - value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(), - value_consumer_map_[downstream_value].end()); + value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]); + value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(), + value_consumer_map[downstream_value].end()); node_iter = size_iter->second.erase(node_iter); if (size_iter->second.empty()) { local_iter->second.erase(size_iter); @@ -1342,8 +1342,9 @@ class PlannerImpl { ort_value_usecount.reserve(ort_value_info_.size()); #endif for (size_t i = 0; i < stream_nodes_.size(); ++i) { - // compute use count first + // compute use count first. TODO(leca): call ComputeReuseCount() only once is enough! ORT_RETURN_IF_ERROR(ComputeReuseCount()); + for (int j = 0; static_cast(j) < ort_value_info_.size(); j++) Buffer(j) = j; #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) if (i == 0) { for (auto ort_value_info : ort_value_info_) { @@ -1693,8 +1694,8 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer value_consumers[origin].push_back(node_index); } @@ -1721,9 +1722,9 @@ class PlannerImpl { // we actually can do better if all the consumers depends on the last consumer. // will optimize it later bool is_all_consumer_same_stream = true; - auto stream_idx = node_stream_map_[value_consumers[i][0]]; + auto stream_idx = plan_.node_stream_map_[value_consumers[i][0]]; for (size_t j = 1; j < value_consumers[i].size(); ++j) { - if (node_stream_map_[value_consumers[i][j]] != stream_idx) { + if (plan_.node_stream_map_[value_consumers[i][j]] != stream_idx) { is_all_consumer_same_stream = false; break; } @@ -1748,10 +1749,10 @@ class PlannerImpl { const PathString& /*partition_config_file*/) { if (graph_viewer_.NumberOfNodes() > 0) { stream_nodes_.push_back({}); - node_stream_map_.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); + plan_.node_stream_map_.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); for (auto node_index : graph_viewer_.GetNodesInTopologicalOrder()) { stream_nodes_[0].push_back(node_index); - node_stream_map_[node_index] = 0; + plan_.node_stream_map_[node_index] = 0; } num_logic_streams_ = 1; } @@ -1773,7 +1774,12 @@ class PlannerImpl { execution_plan.emplace_back(std::make_unique(node_device_mem_location)); // 2. add steps to the execution plan for (auto node_index : stream_nodes_[0]) { +#if defined(ORT_MINIMAL_BUILD) execution_plan[0]->steps_.emplace_back(std::make_unique(node_index)); +#else + execution_plan[0]->steps_.emplace_back(std::make_unique(node_index, + graph_viewer_.GetNode(node_index)->Name())); +#endif } } else { // graph with no nodes. e.g. subgraph of If might return the input as-is or a constant value from an initializer @@ -1790,10 +1796,10 @@ class PlannerImpl { auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file); auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder()); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); - node_stream_map_.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); + plan_.node_stream_map_.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); for (size_t i = 0; i < stream_nodes_.size(); ++i) { for (auto node_index : stream_nodes_[i]) { - node_stream_map_[node_index] = i; + plan_.node_stream_map_[node_index] = i; } } num_logic_streams_ = stream_nodes_.size(); @@ -1856,7 +1862,7 @@ class PlannerImpl { auto* node = graph_viewer_.GetNode(node_index); for (auto it = node->OutputNodesBegin(); it != node->OutputNodesEnd(); ++it) { // if the output node is not in the same stream, generate a trigger point - if (node_stream_map_[it->Index()] != i + if (plan_.node_stream_map_[it->Index()] != i #ifdef ENABLE_TRAINING // Do not insert Barrier/TriggerDownStream step if the producer and consumer are in different sides of yieldOp // As in this case producer will surely be ready before the consumer is running. @@ -1889,9 +1895,9 @@ class PlannerImpl { // 2. the consumer is in the same stream(non-cpu device), but it consumes a CPU tensor from an non-shape op. // for example, a resize cuda kernel consumer a tensor from MemCpyToHost cuda kernel on the same stream. // in this case, the FIFO can't guarantee the cpu tensor is ready when resize kernel is launching - OrtDevice::DeviceType output_arg_device = plan_.allocation_plan[output_arg_idx].location.Type(); + OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type(); WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device); - if ((node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) { + if ((plan_.node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) { if (node_to_notification.find(node_index) == node_to_notification.end()) { node_to_notification[node_index] = plan_.notification_owners.size(); plan_.notification_owners.push_back(i); @@ -1903,7 +1909,7 @@ class PlannerImpl { } // output->Exists } // for each output if (output_consumed_in_subgraph) { - const auto downstream = node_stream_map_[it->Index()]; + const auto downstream = plan_.node_stream_map_[it->Index()]; if (downstream != i) { auto downstream_device = execution_plan[downstream]->device_.Type(); WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, downstream_device); @@ -1929,7 +1935,7 @@ class PlannerImpl { onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType(); const IExecutionProvider* ep = execution_providers.Get(exec_provider_name); auto node_device_mem_location = ep->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault); - ORT_ENFORCE(execution_plan[node_stream_map_[node_index]]->device_.Type() == node_device_mem_location.Type()); + ORT_ENFORCE(execution_plan[plan_.node_stream_map_[node_index]]->device_.Type() == node_device_mem_location.Type()); } } @@ -1978,8 +1984,12 @@ class PlannerImpl { // add dependency for model graph dependence_graph_[it->Index()].insert(node_index); } - // push launch kernel command +// push launch kernel command +#if defined(ORT_MINIMAL_BUILD) execution_plan[i]->steps_.emplace_back(std::make_unique(node_index)); +#else + execution_plan[i]->steps_.emplace_back(std::make_unique(node_index, graph_viewer_.GetNode(node_index)->Name())); +#endif // check if any notification generated by this node, if yes, push a activate auto notification_it = node_to_notification.find(node_index); if (notification_it != node_to_notification.end()) { @@ -2003,7 +2013,7 @@ class PlannerImpl { if (!node_output->Exists()) continue; OrtValueIndex output_idx_global; ORT_THROW_IF_ERROR(ort_value_name_idx_map_.GetIdx(node_output->Name(), output_idx_global)); - plan_.value_to_stream_map[output_idx_global] = node_stream_map_[node_index]; + plan_.value_to_stream_map[output_idx_global] = plan_.node_stream_map_[node_index]; value_node_map_[output_idx_global] = node_index; } } @@ -2079,7 +2089,7 @@ class PlannerImpl { } // trigger downstream for (auto it = node->OutputNodesBegin(); it != node->OutputNodesEnd(); ++it) { - auto stream_idx = node_stream_map_[it->Index()]; + auto stream_idx = plan_.node_stream_map_[it->Index()]; if (stream_idx != i) { auto node_it = std::find(stream_nodes_[stream_idx].begin(), stream_nodes_[stream_idx].end(), it->Index()); int offset = static_cast(std::distance(stream_nodes_[stream_idx].begin(), node_it)); diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h index e16b90ded338..5e4cd9f62f11 100644 --- a/onnxruntime/core/framework/bfc_arena.h +++ b/onnxruntime/core/framework/bfc_arena.h @@ -482,7 +482,7 @@ class BFCArena : public IAllocator { Bin* BinForSize(size_t bytes) { return BinFromIndex(BinNumForSize(bytes)); } - char bins_space_[sizeof(Bin) * kNumBins]; + alignas(Bin) char bins_space_[sizeof(Bin) * kNumBins]; // The size of the current region allocation. SafeInt curr_region_allocation_bytes_; diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index d9c49dc6bea1..32a5f749af08 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -204,6 +204,14 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtDevice& info) const { Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); } +#ifdef ENABLE_TRAINING +void IExecutionFrame::ReleaseAllMLValues() { + for (size_t ort_value_idx = 0; ort_value_idx < all_values_.size(); ort_value_idx++) { + all_values_[ort_value_idx] = OrtValue(); + } +} +#endif + Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast(ort_value_idx) >= all_values_size_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx); @@ -223,7 +231,8 @@ void IExecutionFrame::Init(gsl::span feed_mlvalue_idxs, gsl::span& initializers, const std::function& is_initializer_sparse_func, gsl::span fetches) { - ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size()); + ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size(), "Get feed size: ", feeds.size(), " but expected feed size: ", + feed_mlvalue_idxs.size()); ORT_ENFORCE(fetches.empty() || fetches.size() == fetch_mlvalue_idxs_.size()); // Need this for sparse conversions in host memory @@ -830,7 +839,20 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const { // This method is not thread safe! // Return S_OK and nullptr if index map to a value that is an unused optional input/output Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) { +#ifdef ENABLE_TRAINING + try { + auto status = AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); + return status; + } catch (const std::exception& e) { + LOGS(session_state_.Logger(), WARNING) + << "Exception caught when allocating memory for ort_value with index: " << ort_value_idx + << "so clean up all OrtValues"; + ReleaseAllMLValues(); + return Status(ONNXRUNTIME, FAIL, e.what()); + } +#else return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); +#endif } void ExecutionFrame::VerifyOutputSizes(int output_index, const Node& node, const TensorShape& output_shape) { diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index 1576c16684fa..18d210ffd48f 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -67,6 +67,8 @@ class IExecutionFrame { const std::unordered_map& initializers); Status GetOutputs(gsl::span fetch_mlvalue_idxs, std::vector& fetches); + // if OOM happens, then release all values, so session can run next batch. + void ReleaseAllMLValues(); #endif // TO DO: make it thread safe diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index 7f8009216ce3..b39924d4c3ff 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -35,77 +35,4 @@ common::Status IExecutionProvider::Compile(const std::vector& } #endif - -int IExecutionProvider::ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_viewer, - HashValue& model_hash) { - model_hash = 0; - - // find the top level graph - const Graph* cur_graph = &graph_viewer.GetGraph(); - while (cur_graph->IsSubgraph()) { - cur_graph = cur_graph->ParentGraph(); - } - - uint32_t instance_hash[4] = {0, 0, 0, 0}; - - const Graph& main_graph = *cur_graph; - - // hash the bytes in the Graph instance. we can't just use the address as a new Graph instance may use - // the same memory (unit tests prove this can occur). the raw bytes of the Graph instance should be a unique - // fingerprint for the instance that can use used as the key to the hash of the model path/contents. - MurmurHash3::x86_128(&main_graph, gsl::narrow_cast(sizeof(Graph)), instance_hash[0], &instance_hash); - HashValue graph_instance_hash = instance_hash[0] | (uint64_t(instance_hash[1]) << 32); - - // if we've already hashed this main graph instance use the cached value - auto entry = main_graph_hash_.find(graph_instance_hash); - if (entry != main_graph_hash_.cend()) { - model_hash = entry->second; - } else { - uint32_t hash[4] = {0, 0, 0, 0}; - - // prefer path the model was loaded from - // this may not be available if the model was loaded from a stream or in-memory bytes - const auto& model_path_str = main_graph.ModelPath().ToPathString(); - if (!model_path_str.empty()) { - MurmurHash3::x86_128(model_path_str.data(), gsl::narrow_cast(model_path_str.size()), hash[0], &hash); - } else { - auto hash_str = [&hash](const std::string& str) { - MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); - }; - - // fingerprint the main graph by hashing graph inputs and the ordered outputs from each node - for (const auto* node_arg : main_graph.GetInputsIncludingInitializers()) { - hash_str(node_arg->Name()); - } - - // note: process nodes in order defined in model to be deterministic - for (const auto& node : main_graph.Nodes()) { - for (const auto* node_arg : node.OutputDefs()) { - if (node_arg->Exists()) { - hash_str(node_arg->Name()); - } - } - } - } - - model_hash = hash[0] | (uint64_t(hash[1]) << 32); - - main_graph_hash_[graph_instance_hash] = model_hash; - } - - // return the current unique id, and increment to update - return model_metadef_id_[model_hash]++; -} - -int IExecutionProvider::GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const { - ORT_ENFORCE(metadef_id_generator_, - "IExecutionProvider constructor must be called with true for use_metadef_id_creator"); - - // if the EP is shared across multiple sessions there's a very small potential for concurrency issues. - // use a lock when generating an id to be paranoid - static OrtMutex mutex; - std::lock_guard lock(mutex); - return metadef_id_generator_->GenerateId(graph_viewer, model_hash); -} - } // namespace onnxruntime diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index d97953fd9d5e..dc45cad692b6 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -3,7 +3,6 @@ #pragma once -// #include #include #include #include @@ -13,7 +12,10 @@ #include "core/graph/graph_viewer.h" #include "core/common/logging/logging.h" #ifdef _WIN32 +#include +#include #include "core/platform/tracing.h" +#include "core/platform/windows/telemetry.h" #endif namespace onnxruntime { @@ -43,20 +45,62 @@ class ExecutionProviders { exec_provider_options_[provider_id] = providerOptions; #ifdef _WIN32 + LogProviderOptions(provider_id, providerOptions, false); + + // Register callback for ETW capture state (rundown) + WindowsTelemetry::RegisterInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + for (size_t i = 0; i < exec_providers_.size(); ++i) { + const auto& provider_id = exec_provider_ids_[i]; + + auto it = exec_provider_options_.find(provider_id); + if (it != exec_provider_options_.end()) { + const auto& options = it->second; + + LogProviderOptions(provider_id, options, true); + } + } + } + }); +#endif + + exec_provider_ids_.push_back(provider_id); + exec_providers_.push_back(p_exec_provider); + return Status::OK(); + } + +#ifdef _WIN32 + void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) { for (const auto& config_pair : providerOptions) { TraceLoggingWrite( telemetry_provider_handle, "ProviderOptions", + TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(provider_id.c_str(), "ProviderId"), TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value")); + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBool(captureState, "isCaptureState")); } -#endif - - exec_provider_ids_.push_back(provider_id); - exec_providers_.push_back(p_exec_provider); - return Status::OK(); } +#endif const IExecutionProvider* Get(const onnxruntime::Node& node) const { return Get(node.GetExecutionProviderType()); diff --git a/onnxruntime/core/framework/execution_steps.cc b/onnxruntime/core/framework/execution_steps.cc index df19236d037c..b647833cfd37 100644 --- a/onnxruntime/core/framework/execution_steps.cc +++ b/onnxruntime/core/framework/execution_steps.cc @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + #include "core/framework/execution_steps.h" #include "core/framework/sequential_executor.h" + namespace onnxruntime { + BarrierStep::BarrierStep(size_t id, NodeIndex node_index) : SequentialExecutionPlan::ExecutionStep(node_index), barrier_id_{id} {} @@ -16,8 +19,8 @@ Status BarrierStep::Execute(StreamExecutionContext& ctx, } std::string BarrierStep::ToString() const { - return ::onnxruntime::MakeString("Set a barrier with id: ", - barrier_id_, ", count: ", 2, "."); + // Set a barrier with id: barrier_id_, count: 2. + return MakeString("Barrier - BarrierId: ", barrier_id_, ", Count: ", 2); } WaitOnEPStep::WaitOnEPStep(WaitNotificationFn handle, @@ -42,11 +45,17 @@ Status WaitOnEPStep::Execute(StreamExecutionContext& ctx, } std::string WaitOnEPStep::ToString() const { - return ::onnxruntime::MakeString("WaitOnEPStep: wait on notification with id: ", - notification_idx_, ". "); + // Wait on notification with notification_idx_ + return MakeString("WaitOnEP - NotificationId: ", notification_idx_); } -LaunchKernelStep::LaunchKernelStep(NodeIndex index) : SequentialExecutionPlan::ExecutionStep(index) {} +#if defined(ORT_MINIMAL_BUILD) +LaunchKernelStep::LaunchKernelStep(NodeIndex index) + : SequentialExecutionPlan::ExecutionStep(index) {} +#else +LaunchKernelStep::LaunchKernelStep(NodeIndex index, std::string_view node_name) + : SequentialExecutionPlan::ExecutionStep(index), node_name_(node_name) {} +#endif Status LaunchKernelStep::Execute(StreamExecutionContext& ctx, size_t stream_idx, @@ -61,13 +70,17 @@ Status LaunchKernelStep::Execute(StreamExecutionContext& ctx, return Status::OK(); } #endif - onnxruntime::Status status = ExecuteKernel(ctx, node_index_, stream_idx, terminate_flag, session_scope); + Status status = ExecuteKernel(ctx, node_index_, stream_idx, terminate_flag, session_scope); continue_flag = status.IsOK(); return status; } std::string LaunchKernelStep::ToString() const { - return ::onnxruntime::MakeString("Launch kernel with node id: ", node_index_, ". "); +#if defined(ORT_MINIMAL_BUILD) + return MakeString("LaunchKernel - ", "NodeIndex: ", node_index_); +#else + return MakeString("LaunchKernel - ", "NodeIndex: ", node_index_, ", Name: ", node_name_); +#endif } ActivateNotificationStep::ActivateNotificationStep( @@ -89,12 +102,12 @@ Status ActivateNotificationStep::Execute(StreamExecutionContext& ctx, } std::string ActivateNotificationStep::ToString() const { - return ::onnxruntime::MakeString("ActivateNotificationStep: activate notification with id: ", - notification_idx_, ". "); + // Activate notification with id: notification_idx_ + return MakeString("ActivateNotification - NotificationId: ", notification_idx_); } -TriggerDownstreamStep::TriggerDownstreamStep(size_t trigger_point_index, NodeIndex node_index) : SequentialExecutionPlan::ExecutionStep(node_index), - trigger_point_index_(trigger_point_index) {} +TriggerDownstreamStep::TriggerDownstreamStep(size_t trigger_point_index, NodeIndex node_index) + : SequentialExecutionPlan::ExecutionStep(node_index), trigger_point_index_(trigger_point_index) {} Status TriggerDownstreamStep::Execute(StreamExecutionContext& ctx, size_t /*stream_idx*/, @@ -107,7 +120,8 @@ Status TriggerDownstreamStep::Execute(StreamExecutionContext& ctx, } std::string TriggerDownstreamStep::ToString() const { - return ::onnxruntime::MakeString("TriggerDownstreamStep: trigger downstream of trigger point: ", - trigger_point_index_, "."); + // Trigger downstream of trigger point: trigger_point_index_. + return MakeString("TriggerDownstream - TriggerPointIndex: ", trigger_point_index_); } + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/execution_steps.h b/onnxruntime/core/framework/execution_steps.h index b67b58390082..545dabc56b27 100644 --- a/onnxruntime/core/framework/execution_steps.h +++ b/onnxruntime/core/framework/execution_steps.h @@ -44,7 +44,11 @@ class WaitOnEPStep : public SequentialExecutionPlan::ExecutionStep { class LaunchKernelStep : public SequentialExecutionPlan::ExecutionStep { public: +#if defined(ORT_MINIMAL_BUILD) LaunchKernelStep(NodeIndex index); +#else + LaunchKernelStep(NodeIndex index, std::string_view node_name); +#endif Status Execute(StreamExecutionContext& ctx, size_t stream_idx, @@ -53,6 +57,11 @@ class LaunchKernelStep : public SequentialExecutionPlan::ExecutionStep { bool& continue_flag) override; std::string ToString() const override; + +#if !defined(ORT_MINIMAL_BUILD) + private: + std::string node_name_; +#endif }; class ActivateNotificationStep : public SequentialExecutionPlan::ExecutionStep { diff --git a/onnxruntime/core/framework/feeds_fetches_manager.h b/onnxruntime/core/framework/feeds_fetches_manager.h index 75cb7485a6e3..c2c1be64f3e1 100644 --- a/onnxruntime/core/framework/feeds_fetches_manager.h +++ b/onnxruntime/core/framework/feeds_fetches_manager.h @@ -25,7 +25,7 @@ enum class DeviceCopyCheck { }; struct DeviceCopyChecks { - DeviceCopyCheck status = DeviceCopyCheck::Unknown; ///< Overall status. If NoCopy no input or output copies are needed + DeviceCopyCheck status = DeviceCopyCheck::Unknown; ///< Overall status. NoCopy means input_copy_needed and output_copy_needed are both NoCopy DeviceCopyCheck input_copy_needed = DeviceCopyCheck::Unknown; DeviceCopyCheck output_copy_needed = DeviceCopyCheck::Unknown; }; @@ -73,6 +73,9 @@ struct FeedsFetchesInfo { struct MLValueCopyInfo { OrtDevice source_device{}; OrtDevice target_device{}; // default is CPU + + // if all the consume ops are from the same stream, this variable is the stream index; otherwise -1 + int unique_stream_index_consumes_it = -1; }; class FeedsFetchesManager { diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index e4fe0c756454..90ee8a46f66a 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -16,6 +16,7 @@ #include "core/graph/function_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" +#include "core/session/onnxruntime_session_options_config_keys.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -634,6 +635,98 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide return Status::OK(); } +static Status CreateEpContextModel(const ExecutionProviders& execution_providers, + const Graph& graph, + const std::string& ep_context_path, + const logging::Logger& logger) { + InlinedVector all_ep_context_nodes; + for (const auto& ep : execution_providers) { + const InlinedVector ep_context_nodes = ep->GetEpContextNodes(); + all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end()); + } + + if (all_ep_context_nodes.size() < 1) { + return Status::OK(); + } + + auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { + for (auto& node : all_ep_context_nodes) { + if (node_name == node->Name()) { + return std::make_pair(true, node); + } + } + return std::make_pair(false, static_cast(nullptr)); + }; + + onnxruntime::PathString context_cache_path; + PathString model_pathstring = graph.ModelPath().ToPathString(); + + if (!ep_context_path.empty()) { + context_cache_path = ToPathString(ep_context_path); + } else if (!model_pathstring.empty()) { + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } + + { +#ifdef _WIN32 + std::wifstream fs(context_cache_path); +#else + std::ifstream fs(context_cache_path); +#endif + ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); + } + + Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + graph.DomainToVersionMap(), {}, logger); + auto& ep_graph = ep_context_model.MainGraph(); + ep_graph.SetDescription(graph.Description()); + + // Set inputs outputs explicitly to make sure the order is same as the user model. + auto inputs = graph.GetInputs(); + auto outputs = graph.GetOutputs(); + + InlinedVector ep_graph_inputs; + ep_graph_inputs.reserve(inputs.size()); + for (auto& input : inputs) { + auto input_arg = graph.GetNodeArg(input->Name()); + auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + ep_graph_inputs.push_back(&ep_graph_input_arg); + } + + InlinedVector ep_graph_outputs; + ep_graph_outputs.reserve(outputs.size()); + for (auto& output : outputs) { + auto output_arg = graph.GetNodeArg(output->Name()); + auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + ep_graph_outputs.push_back(&ep_graph_output_arg); + } + + ep_graph.SetInputs(ep_graph_inputs); + ep_graph.SetOutputs(ep_graph_outputs); + + for (const auto& node : graph.Nodes()) { + // the fused node and EPContext node has same node name + auto ep_context_node = get_ep_context_node(node.Name()); + // Use EpContext node created by the EPs if name matched, otherwise use node from original model + if (ep_context_node.first) { + ep_graph.AddNode(*ep_context_node.second); + } else { + ep_graph.AddNode(node); + } + } + + // handle initializers + for (const auto& initialized_tensor : graph.GetAllInitializedTensors()) { + if (ep_graph.GetNodeArg(initialized_tensor.first) != nullptr) { + ep_graph.AddInitializedTensor(*initialized_tensor.second); + } + } + + ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); + + return Status::OK(); +} + static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager) { @@ -840,6 +933,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, + const ConfigOptions& config_options, + const logging::Logger& logger, Mode mode, const layout_transformation::DebugGraphFn& debug_graph_fn) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. @@ -886,7 +981,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, #if !defined(ORT_MINIMAL_BUILD) ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_)); + + bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + if (ep_context_enabled) { + ORT_RETURN_IF_ERROR(CreateEpContextModel(providers_, graph, ep_context_path, logger)); + } #else + ORT_UNUSED_PARAMETER(config_options); + ORT_UNUSED_PARAMETER(logger); return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! defined(ORT_MINIMAL_BUILD) } else { diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 4fc85c258826..d1ef193cf152 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -13,6 +13,7 @@ namespace onnxruntime { class ExecutionProviders; class KernelRegistryManager; class Model; +struct ConfigOptions; class GraphPartitioner { public: @@ -31,6 +32,8 @@ class GraphPartitioner { // Run partitioning. Status Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, + const ConfigOptions& config_options, + const logging::Logger& logger, Mode mode = Mode::kNormal, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index b2ef85311958..f8ccdb8fb023 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -24,7 +24,8 @@ Status KernelRegistryManager::CreateKernel(const Node& node, session_state.GetConstantInitializedTensors(), session_state.GetOrtValueNameIdxMap(), session_state.GetDataTransferMgr(), - session_state.GetAllocators()); + session_state.GetAllocators(), + session_state.GetSessionOptions().config_options); return kernel_create_info.kernel_create_func(session_state.GetMutableFuncMgr(), kernel_info, out); } diff --git a/onnxruntime/core/framework/kernel_type_str_resolver.h b/onnxruntime/core/framework/kernel_type_str_resolver.h index 31a806dd5229..fea2a6ef3a43 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver.h +++ b/onnxruntime/core/framework/kernel_type_str_resolver.h @@ -7,7 +7,7 @@ #include #include -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/graph/onnx_protobuf.h" diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc index 4f5fa9910b5d..423307b4c8fc 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc +++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc @@ -5,7 +5,7 @@ #include "core/framework/kernel_type_str_resolver_utils.h" -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/common.h" #include "core/flatbuffers/schema/ort.fbs.h" @@ -53,200 +53,240 @@ Status AddLayoutTransformationRequiredOpsToKernelTypeStrResolver(KernelTypeStrRe // clang-format off constexpr uint8_t kLayoutTransformationRequiredOpsKernelTypeStrResolverBytes[] = { 0x10, 0x00, 0x00, 0x00, 0x6b, 0x74, 0x73, 0x72, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, - 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, - 0x4c, 0x0b, 0x00, 0x00, 0xac, 0x08, 0x00, 0x00, 0xd0, 0x0a, 0x00, 0x00, 0x10, 0x06, 0x00, 0x00, - 0xa8, 0x07, 0x00, 0x00, 0x18, 0x03, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, - 0x44, 0x07, 0x00, 0x00, 0x9c, 0x01, 0x00, 0x00, 0xf8, 0x07, 0x00, 0x00, 0x78, 0x09, 0x00, 0x00, - 0x14, 0x01, 0x00, 0x00, 0x50, 0x06, 0x00, 0x00, 0x60, 0x02, 0x00, 0x00, 0xf4, 0x08, 0x00, 0x00, - 0x8c, 0x03, 0x00, 0x00, 0x9c, 0x02, 0x00, 0x00, 0x84, 0x06, 0x00, 0x00, 0xcc, 0x03, 0x00, 0x00, - 0x60, 0x05, 0x00, 0x00, 0xb8, 0x01, 0x00, 0x00, 0x1c, 0x03, 0x00, 0x00, 0x08, 0x04, 0x00, 0x00, - 0xe0, 0x09, 0x00, 0x00, 0x8c, 0xf4, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, - 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xf4, 0xff, 0xff, - 0x08, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xda, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xf4, 0xff, 0xff, - 0xd8, 0xf4, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, - 0x60, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, - 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, 0x10, 0xf5, 0xff, 0xff, 0xa4, 0x0a, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfc, 0xf4, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x2c, 0xf5, 0xff, 0xff, 0xb0, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x4e, 0xf5, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x48, 0xf5, 0xff, 0xff, 0xc8, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x38, 0xf5, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, - 0x30, 0xf5, 0xff, 0xff, 0x6c, 0xf5, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, - 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, - 0x31, 0x39, 0x00, 0x00, 0x9c, 0xf5, 0xff, 0xff, 0x3c, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc2, 0xf5, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x94, 0xf5, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xc4, 0xf5, 0xff, 0xff, - 0xe8, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xb4, 0xf5, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xac, 0xf5, 0xff, 0xff, - 0xe8, 0xf5, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, - 0x79, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, 0x10, 0xf6, 0xff, 0xff, 0xac, 0x05, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x88, 0x0d, 0x00, 0x00, + 0xec, 0x06, 0x00, 0x00, 0x68, 0x06, 0x00, 0x00, 0x1c, 0x08, 0x00, 0x00, 0xc8, 0x02, 0x00, 0x00, + 0x2c, 0x03, 0x00, 0x00, 0x80, 0x01, 0x00, 0x00, 0xc0, 0x09, 0x00, 0x00, 0xdc, 0x03, 0x00, 0x00, + 0x6c, 0x09, 0x00, 0x00, 0x64, 0x02, 0x00, 0x00, 0xbc, 0x0c, 0x00, 0x00, 0x04, 0x0d, 0x00, 0x00, + 0xd4, 0x00, 0x00, 0x00, 0x10, 0x04, 0x00, 0x00, 0x04, 0x05, 0x00, 0x00, 0x68, 0x08, 0x00, 0x00, + 0x70, 0x03, 0x00, 0x00, 0xf0, 0x0d, 0x00, 0x00, 0x8c, 0x04, 0x00, 0x00, 0x6c, 0x05, 0x00, 0x00, + 0x94, 0x0a, 0x00, 0x00, 0x44, 0x0c, 0x00, 0x00, 0x28, 0x07, 0x00, 0x00, 0xc4, 0x05, 0x00, 0x00, + 0xc0, 0x09, 0x00, 0x00, 0x08, 0x0a, 0x00, 0x00, 0xb8, 0x08, 0x00, 0x00, 0x90, 0x01, 0x00, 0x00, + 0x5c, 0x07, 0x00, 0x00, 0xbc, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x24, 0xf2, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, + 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, + 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x60, 0xf2, 0xff, 0xff, 0x64, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x36, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xf8, 0xf5, 0xff, 0xff, 0x34, 0xf6, 0xff, 0xff, - 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, - 0x50, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, + 0x4e, 0xf2, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xf2, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x88, 0xf2, 0xff, 0xff, 0x10, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd8, 0xf2, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x70, 0xf2, 0xff, 0xff, 0xac, 0xf2, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, + 0x65, 0x61, 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, 0xe0, 0xf2, 0xff, 0xff, 0xb8, 0x0a, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xbc, 0xf2, 0xff, 0xff, + 0xf8, 0xf2, 0xff, 0xff, 0xcc, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe6, 0xf2, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x50, 0xf3, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x20, 0xf3, 0xff, 0xff, 0x50, 0x0a, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6c, 0xf3, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x3c, 0xf3, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x64, 0xf3, 0xff, 0xff, + 0xd4, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xb0, 0xf3, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x80, 0xf3, 0xff, 0xff, 0x90, 0x0c, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6e, 0xf3, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x68, 0xf3, 0xff, 0xff, 0xa4, 0xf3, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, - 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0x74, 0xf6, 0xff, 0xff, - 0x38, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x64, 0xf6, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x5c, 0xf6, 0xff, 0xff, - 0x98, 0xf6, 0xff, 0xff, 0x40, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xbe, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x90, 0xf6, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xc0, 0xf6, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, - 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0xe4, 0xf6, 0xff, 0xff, - 0x2c, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x0a, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xcc, 0xf6, 0xff, 0xff, - 0x08, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, - 0x73, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x30, 0xf7, 0xff, 0xff, 0xe0, 0x08, 0x00, 0x00, + 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0xe4, 0xf3, 0xff, 0xff, + 0xe0, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xd2, 0xf3, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x3c, 0xf4, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x0c, 0xf4, 0xff, 0xff, 0x8c, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x5c, 0xf4, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0xf4, 0xf3, 0xff, 0xff, 0x30, 0xf4, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x36, 0x00, 0x00, 0x00, 0x00, + 0x58, 0xf4, 0xff, 0xff, 0xb0, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x46, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x40, 0xf4, 0xff, 0xff, 0x7c, 0xf4, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, + 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0xa4, 0xf4, 0xff, 0xff, + 0x94, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xf0, 0xf4, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xc0, 0xf4, 0xff, 0xff, 0x50, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x56, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x18, 0xf7, 0xff, 0xff, 0x54, 0xf7, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, - 0x78, 0xf7, 0xff, 0xff, 0x98, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9e, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x60, 0xf7, 0xff, 0xff, 0x9c, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, - 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x4e, 0x68, 0x77, 0x63, 0x4d, 0x61, - 0x78, 0x50, 0x6f, 0x6f, 0x6c, 0x3a, 0x31, 0x00, 0xd0, 0xf7, 0xff, 0xff, 0x40, 0x08, 0x00, 0x00, + 0xae, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xa8, 0xf4, 0xff, 0xff, 0xe4, 0xf4, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x38, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, + 0x31, 0x31, 0x00, 0x00, 0x0c, 0xf5, 0xff, 0xff, 0x04, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfa, 0xf4, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xf4, 0xf4, 0xff, 0xff, 0x30, 0xf5, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x88, 0xf5, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x58, 0xf5, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, + 0x3a, 0x31, 0x00, 0x00, 0x7c, 0xf5, 0xff, 0xff, 0x94, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6a, 0xf5, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x64, 0xf5, 0xff, 0xff, 0xa0, 0xf5, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, + 0xc8, 0xf5, 0xff, 0xff, 0x48, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb6, 0xf5, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xb0, 0xf5, 0xff, 0xff, 0xec, 0xf5, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, + 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x20, 0xf6, 0xff, 0xff, 0xa4, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xf6, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xf7, 0xff, 0xff, 0xf4, 0xf7, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x31, - 0x00, 0x00, 0x00, 0x00, 0x1c, 0xf8, 0xff, 0xff, 0xf4, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x42, 0xf8, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x04, 0xf8, 0xff, 0xff, 0x40, 0xf8, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, - 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0x00, 0x00, - 0x68, 0xf8, 0xff, 0xff, 0xa8, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8e, 0xf8, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x50, 0xf8, 0xff, 0xff, 0x8c, 0xf8, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x07, 0x00, 0x00, 0x00, 0xf4, 0x00, 0x00, 0x00, 0xc8, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, - 0x0c, 0x01, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, - 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, - 0x74, 0x3a, 0x51, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, 0x76, 0x3a, 0x31, 0x00, - 0xd8, 0xf8, 0xff, 0xff, 0xdc, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xc4, 0xf8, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xf4, 0xf8, 0xff, 0xff, - 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x33, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x22, 0xf9, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0xf4, 0xf8, 0xff, 0xff, 0x07, 0x00, 0x00, 0x00, 0x24, 0xf9, 0xff, 0xff, - 0xe4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x10, 0xf9, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, 0x40, 0xf9, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x38, 0xf9, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, - 0x68, 0xf9, 0xff, 0xff, 0x70, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xf9, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, - 0x60, 0xf9, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x90, 0xf9, 0xff, 0xff, 0x1c, 0x05, 0x00, 0x00, + 0x0e, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x78, 0xf6, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x48, 0xf6, 0xff, 0xff, 0x50, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x24, 0xf6, 0xff, 0xff, 0x60, 0xf6, 0xff, 0xff, 0x10, 0x07, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xac, 0xf6, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x7c, 0xf6, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0xa4, 0xf6, 0xff, 0xff, + 0xc8, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xf0, 0xf6, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xc0, 0xf6, 0xff, 0xff, 0x50, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x80, 0xf9, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x78, 0xf9, 0xff, 0xff, 0xb4, 0xf9, 0xff, 0xff, - 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x34, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa8, 0xf9, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0xd8, 0xf9, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x38, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, - 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x04, 0xfa, 0xff, 0xff, - 0x84, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xf0, 0xf9, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x20, 0xfa, 0xff, 0xff, 0xf0, 0x05, 0x00, 0x00, + 0xae, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xa8, 0xf6, 0xff, 0xff, 0xe4, 0xf6, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, + 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x14, 0xf7, 0xff, 0xff, + 0xb0, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x6c, 0xf7, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0x3c, 0xf7, 0xff, 0xff, 0x5c, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8c, 0xf7, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x24, 0xf7, 0xff, 0xff, 0x60, 0xf7, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x31, 0x00, + 0x88, 0xf7, 0xff, 0xff, 0x88, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x76, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x70, 0xf7, 0xff, 0xff, 0xac, 0xf7, 0xff, 0xff, 0xc0, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xf7, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xc8, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, + 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0xf0, 0xf7, 0xff, 0xff, 0x20, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x46, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x08, 0xfa, 0xff, 0xff, 0x44, 0xfa, 0xff, 0xff, - 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, - 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, - 0x31, 0x31, 0x00, 0x00, 0x6c, 0xfa, 0xff, 0xff, 0xc4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xfa, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x88, 0xfa, 0xff, 0xff, 0x88, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xae, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x70, 0xfa, 0xff, 0xff, 0xac, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, - 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, 0xd0, 0xfa, 0xff, 0xff, 0x40, 0x05, 0x00, 0x00, + 0xde, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xd8, 0xf7, 0xff, 0xff, 0x14, 0xf8, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x44, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, + 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, + 0x48, 0xf8, 0xff, 0xff, 0x50, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x98, 0xf8, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x30, 0xf8, 0xff, 0xff, 0x6c, 0xf8, 0xff, 0xff, 0x58, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x5a, 0xf8, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xc4, 0xf8, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x94, 0xf8, 0xff, 0xff, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, + 0x64, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, + 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, + 0x00, 0x00, 0x00, 0x00, 0xcc, 0xf8, 0xff, 0xff, 0xc8, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb6, 0xf8, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xe8, 0xf8, 0xff, 0xff, 0x28, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x38, 0xf9, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0xd0, 0xf8, 0xff, 0xff, 0x0c, 0xf9, 0xff, 0xff, 0x60, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xf9, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x28, 0xf9, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, + 0x73, 0x65, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, 0x50, 0xf9, 0xff, 0xff, 0xc0, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xf6, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xfa, 0xff, 0xff, 0xf4, 0xfa, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, - 0x00, 0x00, 0x00, 0x00, 0x1c, 0xfb, 0xff, 0xff, 0xf4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x42, 0xfb, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x04, 0xfb, 0xff, 0xff, 0x40, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, - 0x68, 0xfb, 0xff, 0xff, 0xa8, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8e, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x50, 0xfb, 0xff, 0xff, 0x8c, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, - 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x36, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xfb, 0xff, 0xff, - 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe2, 0xfb, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0xa4, 0xfb, 0xff, 0xff, 0xe0, 0xfb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, - 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, - 0x08, 0xfc, 0xff, 0xff, 0x08, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2e, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xf0, 0xfb, 0xff, 0xff, 0x2c, 0xfc, 0xff, 0xff, 0x04, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x18, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x48, 0xfc, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, - 0x24, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x3e, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xf9, 0xff, 0xff, 0x74, 0xf9, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, + 0x74, 0x3a, 0x4e, 0x68, 0x77, 0x63, 0x4d, 0x61, 0x78, 0x50, 0x6f, 0x6f, 0x6c, 0x3a, 0x31, 0x00, + 0xa8, 0xf9, 0xff, 0xff, 0x68, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x96, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x90, 0xf9, 0xff, 0xff, 0xcc, 0xf9, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, + 0x72, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfa, 0xff, 0xff, 0x98, 0x03, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x50, 0xfa, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xe8, 0xf9, 0xff, 0xff, 0x24, 0xfa, 0xff, 0xff, + 0xa0, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x12, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x7c, 0xfa, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x4c, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, - 0x31, 0x30, 0x00, 0x00, 0x7c, 0xfc, 0xff, 0xff, 0x30, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xfc, 0xff, 0xff, 0x94, 0xfc, 0xff, 0xff, - 0x44, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xba, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x8c, 0xfc, 0xff, 0xff, - 0x02, 0x00, 0x00, 0x00, 0xbc, 0xfc, 0xff, 0xff, 0x4c, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa8, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0xd8, 0xfc, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x4c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, - 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x39, - 0x00, 0x00, 0x00, 0x00, 0x0c, 0xfd, 0xff, 0xff, 0xcc, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x32, 0xfd, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x04, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x34, 0xfd, 0xff, 0xff, - 0x78, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x24, 0xfd, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x1c, 0xfd, 0xff, 0xff, - 0x58, 0xfd, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x40, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, - 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x80, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x61, 0x78, 0x65, 0x73, 0x00, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x78, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0xa8, 0xfd, 0xff, 0xff, 0x68, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xce, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x90, 0xfd, 0xff, 0xff, 0xcc, 0xfd, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, - 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, - 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x79, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x28, 0xfe, 0xff, 0xff, 0x84, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0xff, 0x40, 0xfe, 0xff, 0xff, 0x98, 0x00, 0x00, 0x00, + 0x32, 0x31, 0x00, 0x00, 0x7c, 0xfa, 0xff, 0xff, 0x48, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6a, 0xfa, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xd4, 0xfa, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xa4, 0xfa, 0xff, 0xff, + 0xf4, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xf4, 0xfa, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x8c, 0xfa, 0xff, 0xff, + 0xc8, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, + 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, 0xf4, 0xfa, 0xff, 0xff, + 0x1c, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xe2, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xdc, 0xfa, 0xff, 0xff, + 0x18, 0xfb, 0xff, 0xff, 0x54, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x64, 0xfb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x34, 0xfb, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x34, + 0x00, 0x00, 0x00, 0x00, 0x5c, 0xfb, 0xff, 0xff, 0xac, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x4a, 0xfb, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x44, 0xfb, 0xff, 0xff, 0x80, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, 0xa4, 0xfb, 0xff, 0xff, + 0x6c, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x92, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x8c, 0xfb, 0xff, 0xff, + 0xc8, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, + 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0x00, 0x00, 0xf0, 0xfb, 0xff, 0xff, 0x20, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x66, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xfe, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, - 0x68, 0xfe, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x2c, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, - 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, - 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0xa4, 0xfe, 0xff, 0xff, - 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x31, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9c, 0xfe, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x94, 0xfe, 0xff, 0xff, 0xd0, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x32, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfe, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xd0, 0xfe, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x09, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, - 0x28, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x20, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x50, 0xff, 0xff, 0xff, 0xc0, 0x00, 0x00, 0x00, + 0xde, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xd8, 0xfb, 0xff, 0xff, 0x14, 0xfc, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, + 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x40, 0xfc, 0xff, 0xff, 0xd0, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x76, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xff, 0xff, 0xff, 0x74, 0xff, 0xff, 0xff, - 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x84, 0x00, 0x00, 0x00, - 0x24, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, - 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, - 0x00, 0x00, 0x00, 0x00, 0xac, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x07, 0x00, 0x00, 0x00, 0x78, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xa4, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xd4, 0xff, 0xff, 0xff, - 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, - 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, + 0x2e, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x28, 0xfc, 0xff, 0xff, 0x64, 0xfc, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x61, 0x78, 0x65, 0x73, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xbc, 0xfc, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x8c, 0xfc, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, + 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xfc, 0xff, 0xff, + 0x5c, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xa2, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xfc, 0xff, 0xff, + 0xd8, 0xfc, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0xa8, 0x00, 0x00, 0x00, 0xd0, 0x00, 0x00, 0x00, 0xfc, 0x00, 0x00, 0x00, 0x28, 0x01, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, + 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x4c, + 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, 0x76, 0x3a, 0x31, 0x00, 0x24, 0xfd, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, + 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x7c, 0xfd, 0xff, 0xff, + 0x04, 0x00, 0x00, 0x00, 0x4c, 0xfd, 0xff, 0xff, 0x20, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x98, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x68, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x79, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xc0, 0xfd, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, 0x90, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x31, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe8, 0xfd, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x80, 0xfd, 0xff, 0xff, 0xbc, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x54, 0x32, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x14, 0xfe, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, 0x1c, 0xfe, 0xff, 0xff, + 0x03, 0x00, 0x00, 0x00, 0xec, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x54, 0x33, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xe2, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x4c, 0xfe, 0xff, 0xff, + 0x07, 0x00, 0x00, 0x00, 0x1c, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x54, 0x34, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x70, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x40, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, + 0x68, 0xfe, 0xff, 0xff, 0xa8, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x56, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x50, 0xfe, 0xff, 0xff, 0x8c, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xfe, 0xff, 0xff, + 0x54, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xa2, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xfe, 0xff, 0xff, + 0xd8, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, + 0x79, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf6, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xf0, 0xfe, 0xff, 0xff, 0x2c, 0xff, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, + 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, 0x64, 0xff, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x78, 0x5f, 0x73, 0x63, + 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xbc, 0xff, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x8c, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x7e, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb0, 0xff, 0xff, 0xff, 0x60, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0xa0, 0xff, 0xff, 0xff, 0xdc, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, + 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, }; // clang-format on diff --git a/onnxruntime/core/framework/model_metadef_id_generator.cc b/onnxruntime/core/framework/model_metadef_id_generator.cc new file mode 100644 index 000000000000..e51c6ebc2997 --- /dev/null +++ b/onnxruntime/core/framework/model_metadef_id_generator.cc @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include +#include "model_metadef_id_generator.h" +#include "core/platform/ort_mutex.h" +#include "core/graph/graph_viewer.h" +#include "core/framework/murmurhash3.h" + +namespace onnxruntime { +int ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_viewer, + HashValue& model_hash) const { + // if the EP is shared across multiple sessions there's a very small potential for concurrency issues. + // use a lock when generating an id to be paranoid + static OrtMutex mutex; + std::lock_guard lock(mutex); + model_hash = 0; + + // find the top level graph + const Graph* cur_graph = &graph_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + + uint32_t instance_hash[4] = {0, 0, 0, 0}; + + const Graph& main_graph = *cur_graph; + + // hash the bytes in the Graph instance. we can't just use the address as a new Graph instance may use + // the same memory (unit tests prove this can occur). the raw bytes of the Graph instance should be a unique + // fingerprint for the instance that can use used as the key to the hash of the model path/contents. + MurmurHash3::x86_128(&main_graph, gsl::narrow_cast(sizeof(Graph)), instance_hash[0], &instance_hash); + HashValue graph_instance_hash = instance_hash[0] | (uint64_t(instance_hash[1]) << 32); + + // if we've already hashed this main graph instance use the cached value + auto entry = main_graph_hash_.find(graph_instance_hash); + if (entry != main_graph_hash_.cend()) { + model_hash = entry->second; + } else { + uint32_t hash[4] = {0, 0, 0, 0}; + + // prefer path the model was loaded from + // this may not be available if the model was loaded from a stream or in-memory bytes + const auto& model_path_str = main_graph.ModelPath().ToPathString(); + if (!model_path_str.empty()) { + MurmurHash3::x86_128(model_path_str.data(), gsl::narrow_cast(model_path_str.size()), hash[0], &hash); + } else { + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + // fingerprint the main graph by hashing graph inputs and the ordered outputs from each node + for (const auto* node_arg : main_graph.GetInputsIncludingInitializers()) { + hash_str(node_arg->Name()); + } + + // note: process nodes in order defined in model to be deterministic + for (const auto& node : main_graph.Nodes()) { + for (const auto* node_arg : node.OutputDefs()) { + if (node_arg->Exists()) { + hash_str(node_arg->Name()); + } + } + } + } + + model_hash = hash[0] | (uint64_t(hash[1]) << 32); + + main_graph_hash_[graph_instance_hash] = model_hash; + } + + // return the current unique id, and increment to update + return model_metadef_id_[model_hash]++; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/model_metadef_id_generator.h b/onnxruntime/core/framework/model_metadef_id_generator.h new file mode 100644 index 000000000000..82f68c42b5c3 --- /dev/null +++ b/onnxruntime/core/framework/model_metadef_id_generator.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include "core/common/basic_types.h" +namespace onnxruntime { +class GraphViewer; + +/// +/// helper to generate ids that are unique to model and deterministic, even if the execution provider is shared across +/// multiple sessions. +/// +class ModelMetadefIdGenerator { + public: + /** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance. + The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models. + @param graph_viewer[in] Graph viewer that GetCapability was called with. Can be for the main graph or nested graph. + @param model_hash[out] Returns the hash for the main (i.e. top level) graph in the model. + This is created using the model path if available, + or the model input names and the output names from all nodes in the main graph. + */ + int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const; + + private: + // mutable as these are caches so we can minimize the hashing required on each usage of GenerateId + mutable std::unordered_map main_graph_hash_; // map graph instance hash to model contents hash + mutable std::unordered_map model_metadef_id_; // current unique id for model +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc new file mode 100644 index 000000000000..174942b9033d --- /dev/null +++ b/onnxruntime/core/framework/node_unit.cc @@ -0,0 +1,359 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + +#include "node_unit.h" +#include "core/graph/graph_viewer.h" + +namespace onnxruntime { + +namespace { + +enum class QLinearOpType : uint8_t { + Unknown, // Unknown or not a linear quantized op + DequantizeLinear, + QuantizeLinear, + QLinearConv, + QLinearMatMul, + QLinearAdd, + QLinearSigmoid, + QLinearAveragePool, + QLinearMul, + QLinearReduceMean, + QLinearConcat, + QLinearGlobalAveragePool, + QLinearLeakyRelu, +}; + +QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) { + const auto& op_type = node.OpType(); + if (op_type == "DequantizeLinear") + return QLinearOpType::DequantizeLinear; + else if (op_type == "QuantizeLinear") + return QLinearOpType::QuantizeLinear; + else if (op_type == "QLinearConv") + return QLinearOpType::QLinearConv; + else if (op_type == "QLinearMatMul") + return QLinearOpType::QLinearMatMul; + else if (op_type == "QLinearAdd") + return QLinearOpType::QLinearAdd; + else if (op_type == "QLinearSigmoid") + return QLinearOpType::QLinearSigmoid; + else if (op_type == "QLinearAveragePool") + return QLinearOpType::QLinearAveragePool; + else if (op_type == "QLinearMul") + return QLinearOpType::QLinearMul; + else if (op_type == "QLinearReduceMean") + return QLinearOpType::QLinearReduceMean; + else if (op_type == "QLinearConcat") + return QLinearOpType::QLinearConcat; + else if (op_type == "QLinearGlobalAveragePool") + return QLinearOpType::QLinearGlobalAveragePool; + else if (op_type == "QLinearLeakyRelu") + return QLinearOpType::QLinearLeakyRelu; + + return QLinearOpType::Unknown; +} + +// Ops have 1 input +bool IsUnaryQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearSigmoid || + type == QLinearOpType::QLinearAveragePool || + type == QLinearOpType::QLinearGlobalAveragePool || + type == QLinearOpType::QLinearLeakyRelu || + type == QLinearOpType::QLinearReduceMean; +} + +// Ops have 2 inputs +bool IsBinaryQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearConv || + type == QLinearOpType::QLinearMatMul || + type == QLinearOpType::QLinearAdd || + type == QLinearOpType::QLinearMul; +} + +// Ops have 1 or more inputs +bool IsVariadicQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearConcat; +} + +const std::vector GetQDQIONodes(const GraphViewer& graph_viewer, + const QDQ::NodeGroup& node_group, bool is_input) { + std::vector io_nodes; + const auto& src_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes; + io_nodes.reserve(src_nodes.size()); + for (const auto& node_idx : src_nodes) { + io_nodes.push_back(graph_viewer.GetNode(node_idx)); + } + + return io_nodes; +} + +// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup +std::vector GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group, bool is_input) { + const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes; + const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs(); + const size_t target_node_io_defs_size = target_node_io_defs.size(); + + // Find all the quantized IO defs and indices (for the input/output of the target node) + std::unordered_map quantized_io_defs; + quantized_io_defs.reserve(target_node_io_defs_size); + + auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin(); + auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd(); + + for (; cur != end; ++cur) { + const Node& node = cur->GetNode(); + + // If we can find the node index in the dq or q nodes this is a quantized input/output + if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) { + const auto node_inputs = node.InputDefs(); + const auto& node_attrs = node.GetAttributes(); + + // Get the Q or DQ axis attribute if available. + std::optional axis; + if (auto entry = node_attrs.find("axis"); entry != node_attrs.end()) { + axis = entry->second.i(); + } + + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*node_inputs[1], node_inputs.size() == 3 ? node_inputs[2] : nullptr, axis}; + + if (is_input) { + // DQ is input to the target node, use the DstArgIndex + auto idx = cur->GetDstArgIndex(); + // This is a DQ node, we are using x, x_scale, x_zp (input[0, 1, 2]) + quantized_io_defs.insert({idx, NodeUnitIODef{*node_inputs[0], quant_param}}); + } else { + // Q is output of the target node, use the SrcArgIndex + auto idx = cur->GetSrcArgIndex(); + // This is a Q node, we are using y (output[0]), y_scale, y_zp (input[1, 2]) + const auto node_outputs = node.OutputDefs(); + quantized_io_defs.insert({idx, NodeUnitIODef{*node_outputs[0], quant_param}}); + } + } + } + + // Construct the IODefs for this QDQ NodeGroup + std::vector io_defs; + io_defs.reserve(target_node_io_defs_size); + for (size_t i = 0; i < target_node_io_defs_size; i++) { + // If we can find the NodeUnitIODef for this index, this is a quantized input/output + if (quantized_io_defs.find(i) != quantized_io_defs.cend()) { + io_defs.push_back(std::move(quantized_io_defs.at(i))); + } else { + // This is a regular input + io_defs.push_back({*target_node_io_defs[i], std::nullopt}); + } + } + + return io_defs; +} + +} // namespace + +Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer, + const Node& target_node, + gsl::span dq_nodes, + gsl::span q_nodes) { + // Within a QDQ node group, a target node input is the only consumer of each DQ. + // This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications + // may have happened since. Verify that this is still true. + for (const auto* dq_node : dq_nodes) { + const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node); + ORT_RETURN_IF(dq_produces_graph_output, + "QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(), + ", target node: ", target_node.Name()); + + const bool dq_has_single_output_edge_to_target = + dq_node->GetOutputEdgesCount() == 1 && + dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index(); + ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, + "QDQ node group cannot have DQ that doesn't have a single output edge to the target node. " + "DQ node: ", + dq_node->Name(), ", target node: ", target_node.Name()); + } + + // an output from the target node can have either Q consumers or direct consumers. it cannot have both. + // this must be checked on a per output basis. + // e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ + // node group with a quantized TopK, an int64_t indices value will be produced and can provide a graph output. + if (!q_nodes.empty()) { + auto cur_edge = target_node.OutputEdgesBegin(); + auto end_edge = target_node.OutputEdgesEnd(); + std::vector output_consumers(target_node.OutputDefs().size(), nullptr); + + for (; cur_edge != end_edge; ++cur_edge) { + auto output_idx = cur_edge->GetSrcArgIndex(); + const Node& this_consumer = cur_edge->GetNode(); + const Node* existing_consumer = output_consumers[output_idx]; + + if (existing_consumer != nullptr) { + // another edge for this output. either both are Q or both are not. + bool valid = true; + if (existing_consumer->OpType() == "QuantizeLinear") { + valid = this_consumer.OpType() == "QuantizeLinear"; + } else { + valid = this_consumer.OpType() != "QuantizeLinear"; + } + + ORT_RETURN_IF_NOT(valid, + "QDQ node group cannot have an output from the target node being consumed by a Q node and " + "a non-Q node. target node: ", + target_node.Name()); + } else { + output_consumers[output_idx] = &this_consumer; + } + } + + const auto& graph_outputs = graph_viewer.GetOutputs(); + for (size_t idx = 0, end = output_consumers.size(); idx < end; ++idx) { + // any output with a Q cannot be a graph output as it will disappear if the QDQ node unit is converted to + // a quantized op. + if (output_consumers[idx] != nullptr && output_consumers[idx]->OpType() == "QuantizeLinear") { + const auto& output_name = target_node.OutputDefs()[idx]->Name(); + bool is_graph_output = std::any_of(graph_outputs.begin(), graph_outputs.end(), + [&output_name](const NodeArg* node_arg) { + return node_arg->Name() == output_name; + }); + ORT_RETURN_IF(is_graph_output, + "QDQ node group cannot have an output from the target node that is consumed by a Q node and " + "a graph output. target node: ", + target_node.Name(), " output idx:", idx); + } + } + } + + return Status::OK(); +} +NodeUnit::NodeUnit(const Node& node) + : target_node_(node), + type_(Type::SingleNode), + input_edge_count_(node.GetInputEdgesCount()) { + InitForSingleNode(); +} + +NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group) + : dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)}, + target_node_(*graph_viewer.GetNode(node_group.target_node)), + q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)}, + type_(Type::QDQGroup), + inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)}, + outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} { + ORT_THROW_IF_ERROR(QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, dq_nodes_, q_nodes_)); + + input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0), + [](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); }); + + // add edges for inputs that are not from DQ nodes. there is one edge to each DQ node. + // other inputs could come from initializers or graph inputs (no edges) or other nodes (edge). + input_edge_count_ += target_node_.GetInputEdgesCount() - dq_nodes_.size(); + + // create output edges. each target node output either goes to Q node/s or non-Q node/s. + // ValidateNodeGroupQDQNodes ensures this. + auto cur_edge = target_node_.OutputEdgesBegin(); + auto end_edge = target_node_.OutputEdgesEnd(); + for (; cur_edge != end_edge; ++cur_edge) { + const Node& node = cur_edge->GetNode(); + + // if node is in q_nodes we hide the Q node. + if (std::find(q_nodes_.cbegin(), q_nodes_.cend(), &node) != q_nodes_.cend()) { + auto src_idx = cur_edge->GetSrcArgIndex(); + auto q_cur_edge = node.OutputEdgesBegin(); + auto q_end_edge = node.OutputEdgesEnd(); + for (; q_cur_edge != q_end_edge; ++q_cur_edge) { + output_edges_.insert(Node::EdgeEnd{q_cur_edge->GetNode(), src_idx, q_cur_edge->GetDstArgIndex()}); + } + } else { + // non-Q node, or Q node that isn't in the QDQ node group (unexpected but may be possible). add as-is. + output_edges_.insert(*cur_edge); + } + } +} + +const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } +const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); } +const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); } +int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); } +NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); } +const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); } +ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); } + +void NodeUnit::InitForSingleNode() { + const auto& input_defs = target_node_.InputDefs(); + const auto& output_defs = target_node_.OutputDefs(); + auto qlinear_type = GetQLinearOpType(target_node_); + if (qlinear_type == QLinearOpType::Unknown || IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support + // Not a Qlinear op, add all inputs / outputs + auto add_all_io = [](std::vector& defs, + const ConstPointerContainer>& node_defs) { + defs.reserve(node_defs.size()); + + for (const auto def : node_defs) { + defs.push_back(NodeUnitIODef{*def, std::nullopt}); + } + }; + + add_all_io(inputs_, input_defs); + add_all_io(outputs_, output_defs); + } else if (IsUnaryQLinearOp(qlinear_type)) { + // Unary QLinear Op has 5 inputs + // x, x_scale, x_zp, y_scale, y_zp (optional) + inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], + NodeUnitIODef::QuantParam{*input_defs[3], + input_defs.size() > 4 ? input_defs[4] : nullptr}}); + + } else if (IsBinaryQLinearOp(qlinear_type)) { + // Binary QLinear Op has 9 inputs + // x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B + inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); + inputs_.push_back(NodeUnitIODef{*input_defs[3], NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}}); + + if (input_defs.size() == 9) { // has Bias + inputs_.push_back(NodeUnitIODef{*input_defs[8], std::nullopt}); // for Bias the scale and zp are optional + } + + outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}}); + + } else if (qlinear_type == QLinearOpType::DequantizeLinear) { + // DequantizeLinear has 3 inputs + // x, x_scale, x_zp + // output is not quantized + inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3 + ? input_defs[2] + : nullptr}}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt}); + + } else if (qlinear_type == QLinearOpType::QuantizeLinear) { + // QuantizeLinear the input is not quantized and has 3 inputs + // x, y_scale, y_zp (optional) + // The output is quantized + inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3 + ? input_defs[2] + : nullptr}}); + } else { + ORT_THROW("The QLinear op [", static_cast(qlinear_type), "] is not supported"); + } +} + +Node::EdgeConstIterator NodeUnit::OutputEdgesBegin() const { + return (type_ == Type::SingleNode) ? target_node_.OutputEdgesBegin() : output_edges_.begin(); +} + +Node::EdgeConstIterator NodeUnit::OutputEdgesEnd() const { + return (type_ == Type::SingleNode) ? target_node_.OutputEdgesEnd() : output_edges_.end(); +} + +std::vector NodeUnit::GetAllNodesInGroup() const noexcept { + std::vector all_nodes = dq_nodes_; + all_nodes.push_back(&target_node_); + all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end()); + return all_nodes; +} + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/shared/node_unit/node_unit.h b/onnxruntime/core/framework/node_unit.h similarity index 51% rename from onnxruntime/core/providers/shared/node_unit/node_unit.h rename to onnxruntime/core/framework/node_unit.h index b47204ca3c42..a168495f12eb 100644 --- a/onnxruntime/core/providers/shared/node_unit/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -3,6 +3,9 @@ #pragma once +// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include #include #include @@ -18,17 +21,31 @@ class NodeArg; class Path; namespace QDQ { -struct NodeGroup; -} +// Struct to represent a DequantizeLinear -> Op -> QuantizeLinear node group +struct NodeGroup { + std::vector dq_nodes; + std::vector q_nodes; + NodeIndex target_node; + + // Validator to check if the set of nodes can form a valid QDQ NodeGroup. + // Checks target node is only consumer of each DQ, and that the outputs remain valid if the QDQ node group was to + // be converted into a single node with a quantized operator. + static Status CanCreateNodeGroup(const GraphViewer& graph_viewer, + const Node& target_node, + gsl::span dq_nodes, + gsl::span q_nodes); +}; +} // namespace QDQ // Definition of one input or output // If the optional quant_param is present, then this is a quantized input, // otherwise this is a regular input struct NodeUnitIODef { - // The quantization parameter, scale is manadatory, and zero_point is optional + // The quantization parameter. Scale is mandatory. Zero-point and axis are optional. struct QuantParam { const NodeArg& scale; const NodeArg* zero_point{nullptr}; + std::optional axis{std::nullopt}; }; const NodeArg& node_arg; @@ -69,26 +86,33 @@ class NodeUnit { const std::vector& GetQNodes() const noexcept { return q_nodes_; } std::vector GetAllNodesInGroup() const noexcept; - Node::EdgeConstIterator OutputEdgesBegin(size_t index) const; - Node::EdgeConstIterator OutputEdgesEnd(size_t index) const; + /// Number of input edges to the logical node. For a QDQ node this is the count of input edges to the DQ nodes + /// plus any other edges to the target node for inputs that are not via a DQ node. + size_t InputEdgeCount() const { return input_edge_count_; } + + // output edges. src index is for outputs of the target node. dest index and node is for consumer of node unit + // output. any Q nodes are hidden. + Node::EdgeConstIterator OutputEdgesBegin() const; + Node::EdgeConstIterator OutputEdgesEnd() const; private: - const std::vector q_nodes_; // q-nodes for this NodeUnit - const std::vector dq_nodes_; // dq nodes for this NodeUnit, not all inputs + // Initialization for a NodeUnit that contains a single node + void InitForSingleNode(); + + const std::vector dq_nodes_; // dq nodes for this NodeUnit, not necessarily all inputs const Node& target_node_; + const std::vector q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs const Type type_; std::vector inputs_; std::vector outputs_; - // Initializing for a single Node - void InitForSingleNode(); -}; + size_t input_edge_count_; // total number of input edges -// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup) -// And return a map to quick query the NodeUnit which contains the given Node, -// Note, the value of the map is owned by the vector of std::unique_ptr -std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer); + // output edges, hiding any Q nodes involved. src_idx will be value from target node. only used for QDQ node group. + Node::EdgeSet output_edges_; +}; } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/op_kernel_info.cc b/onnxruntime/core/framework/op_kernel_info.cc index 841fdb585f0d..28793dae36d2 100644 --- a/onnxruntime/core/framework/op_kernel_info.cc +++ b/onnxruntime/core/framework/op_kernel_info.cc @@ -15,7 +15,8 @@ OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node, const std::unordered_map& constant_initialized_tensors, const OrtValueNameIdxMap& ort_value_name_idx_map, const DataTransferManager& data_transfer_mgr, - const AllocatorMap& allocators) + const AllocatorMap& allocators, + const ConfigOptions& config_options) : OpNodeProtoHelper(&proto_helper_context_), node_(node), kernel_def_(kernel_def), @@ -24,15 +25,22 @@ OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node, ort_value_name_idx_map_(ort_value_name_idx_map), data_transfer_mgr_(data_transfer_mgr), proto_helper_context_(node), - allocators_(allocators) {} + allocators_(allocators), + config_options_(config_options) { +} OpKernelInfo::OpKernelInfo(const OpKernelInfo& other) : OpKernelInfo(other.node_, other.kernel_def_, *other.execution_provider_, other.constant_initialized_tensors_, - other.ort_value_name_idx_map_, other.data_transfer_mgr_, other.allocators_) {} + other.ort_value_name_idx_map_, other.data_transfer_mgr_, + other.allocators_, other.config_options_) { +} AllocatorPtr OpKernelInfo::GetAllocator(OrtMemType mem_type) const { auto it = allocators_.find(execution_provider_->GetOrtDeviceByMemType(mem_type)); - if (it != allocators_.end()) return it->second; + if (it != allocators_.end()) { + return it->second; + } + return nullptr; } diff --git a/onnxruntime/core/framework/sequential_execution_plan.h b/onnxruntime/core/framework/sequential_execution_plan.h index 3152154e52d7..62c66bc6f336 100644 --- a/onnxruntime/core/framework/sequential_execution_plan.h +++ b/onnxruntime/core/framework/sequential_execution_plan.h @@ -203,6 +203,8 @@ struct SequentialExecutionPlan : public ExecutionPlanBase { } return count; } + + InlinedVector node_stream_map_; }; // Output details of an execution plan: diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index ba68bc1d7d83..0cc7294a4649 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -181,7 +181,7 @@ class SessionScope { } auto& logger = session_state_.Logger(); - LOGS(logger, VERBOSE) << "Begin execution"; + VLOGS(logger, 0) << "Begin execution"; const SequentialExecutionPlan& seq_exec_plan = *session_state_.GetExecutionPlan(); const auto& exec_plan_vec = seq_exec_plan.execution_plan; VLOGS(logger, 1) << "Size of execution plan vector: " << exec_plan_vec.size(); @@ -306,18 +306,20 @@ class KernelScope { #endif #ifdef ENABLE_NVTX_PROFILE - auto& node = kernel_.Node(); - profile::NvtxRangeCreator& forward_range = session_scope_.forward_range_; - profile::NvtxRangeCreator& backward_range = session_scope_.backward_range_; - if (node.Description() != "Backward pass" && !forward_range.IsBeginCalled()) { - // Start timing forward pass when encountering the first forward node. - forward_range.Begin(); - } else if (node.Description() == "Backward pass" && !backward_range.IsBeginCalled() && - forward_range.IsBeginCalled()) { - // Start timing backward pass when encountering the first backward node. - // In the meanwhile, forward range ends. - forward_range.End(); - backward_range.Begin(); + { + auto& node = kernel_.Node(); + profile::NvtxRangeCreator& forward_range = session_scope_.forward_range_; + profile::NvtxRangeCreator& backward_range = session_scope_.backward_range_; + if (node.Description() != "Backward pass" && !forward_range.IsBeginCalled()) { + // Start timing forward pass when encountering the first forward node. + forward_range.Begin(); + } else if (node.Description() == "Backward pass" && !backward_range.IsBeginCalled() && + forward_range.IsBeginCalled()) { + // Start timing backward pass when encountering the first backward node. + // In the meanwhile, forward range ends. + forward_range.End(); + backward_range.Begin(); + } } #endif @@ -515,7 +517,7 @@ onnxruntime::Status ExecuteKernel(StreamExecutionContext& ctx, return Status(status.Category(), status.Code(), msg_string); } ctx.RecycleNodeInputs(idx); - LOGS(logger, VERBOSE) << "stream " << stream_idx << " launch kernel with idx " << idx; + VLOGS(logger, 0) << "stream " << stream_idx << " launch kernel with idx " << idx; return Status::OK(); } @@ -531,7 +533,7 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span< const bool only_execute_path_to_fetches, bool single_thread_mode) { auto* execution_plan = session_state.GetExecutionPlan(); - LOGS(logger, VERBOSE) << "Number of streams: " << execution_plan->execution_plan.size(); + VLOGS(logger, 0) << "Number of streams: " << execution_plan->execution_plan.size(); int32_t valid_streams = 0; for (auto& stream : execution_plan->execution_plan) { if (stream && stream->steps_.size() > 0) diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 40c59cfcf699..796a018ac0f6 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -65,6 +65,11 @@ struct FreeDimensionOverride { * Configuration information for a session. */ struct SessionOptions { +#if defined(__wasm__) && defined(__EMSCRIPTEN_PTHREADS__) + static constexpr bool DEFAULT_USE_PER_SESSION_THREADS = false; +#else + static constexpr bool DEFAULT_USE_PER_SESSION_THREADS = true; +#endif ExecutionMode execution_mode = ExecutionMode::ORT_SEQUENTIAL; // set the execution order of the graph @@ -129,7 +134,8 @@ struct SessionOptions { // By default the session uses its own set of threadpools, unless this is set to false. // Use this in conjunction with the CreateEnvWithGlobalThreadPools API. - bool use_per_session_threads = true; + bool use_per_session_threads = DEFAULT_USE_PER_SESSION_THREADS; + bool thread_pool_allow_spinning = true; // Deterministic compute is likely not as performant. This option is default to false. diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 51bb02918d82..e318c9a8238c 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -8,7 +8,7 @@ #include #include -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/gsl.h" @@ -259,8 +259,8 @@ class SessionState { * \param p_node0 Nullable * \param kci0 Nullable */ - NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0, const OrtDevice& device0) - : index(index0), p_node(p_node0), kci(kci0), device(&device0) {} + NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0, const OrtDevice& device0, int stream_index0 = -1) + : index(index0), p_node(p_node0), kci(kci0), device(&device0), stream_index(stream_index0) {} size_t index; // Nullable @@ -268,6 +268,7 @@ class SessionState { // Nullable const KernelCreateInfo* kci = nullptr; const OrtDevice* device = nullptr; + int stream_index; }; using NameNodeInfoMapType = InlinedHashMap>; diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index df11fe8302ae..692ca0877253 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -367,6 +367,7 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& for (auto& node : graph.Nodes()) { const KernelCreateInfo& kci = session_state.GetNodeKernelCreateInfo(node.Index()); + int stream_index = static_cast(exec_plan->node_stream_map_[node.Index()]); ORT_RETURN_IF_ERROR( onnxruntime::Node::ForEachWithIndex( @@ -379,8 +380,7 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& int arg_index; ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg.Name(), arg_index)); const auto& device = exec_plan->GetLocation(arg_index); - - SessionState::NodeInfo node_info(index, &node, &kci, device); + SessionState::NodeInfo node_info(index, &node, &kci, device, stream_index); if (IsArgNameInInputsOutputs(arg.Name(), graph_inputs)) { ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(arg.Name(), node_info)); @@ -419,7 +419,7 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& int arg_index; ORT_RETURN_IF_ERROR(name_to_id.GetIdx(input_def->Name(), arg_index)); auto& device = exec_plan->GetLocation(arg_index); - SessionState::NodeInfo node_info(std::numeric_limits::max(), &node, &kci, device); + SessionState::NodeInfo node_info(std::numeric_limits::max(), &node, &kci, device, stream_index); ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(input_def->Name(), node_info)); } } diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc index 4ff5ee5db865..dd7f4d35b34b 100644 --- a/onnxruntime/core/framework/stream_execution_context.cc +++ b/onnxruntime/core/framework/stream_execution_context.cc @@ -168,7 +168,7 @@ void StreamExecutionContext::RecycleNodeInputs(onnxruntime::NodeIndex node_index for (auto idx : execution_plan->node_release_list[node_index]) { if (--release_plan_[idx] == 0) { ORT_ENFORCE(frame_.ReleaseMLValue(static_cast(execution_plan->release_actions[idx].value_index)).IsOK()); - LOGS(*logger_, VERBOSE) << "ort value " << execution_plan->release_actions[idx].value_index << " released"; + VLOGS(*logger_, 0) << "ort value " << execution_plan->release_actions[idx].value_index << " released"; } } } @@ -181,11 +181,13 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess } #ifdef USE_CANN + // Leave it to CANN EP to fill the gap if they want to use run_options + static onnxruntime::RunOptions run_options; // For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool, // which is different from CUDA Runtime API, but similar to CUDA Driver API. auto& execution_providers = ctx.GetSessionState().GetExecutionProviders(); for (auto& xp : execution_providers) { - auto status = xp->OnRunStart(); + auto status = xp->OnRunStart(run_options); if (!status.IsOK()) { ctx.SetStatus(status); return; diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index fd32aaedcc2e..8a2db6d5728a 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -7,6 +7,10 @@ #include #include +#if defined(__wasm__) +#include +#endif + #include "core/common/gsl.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" @@ -769,6 +773,7 @@ static void DeleteCharArray(void* param) noexcept { delete[] arr; } +#if !defined(__wasm__) static Status GetFileContent( const Env& env, const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, void*& raw_buffer, OrtCallback& deleter) { @@ -797,6 +802,7 @@ static Status GetFileContent( raw_buffer = buffer.release(); return Status::OK(); } +#endif Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, @@ -819,6 +825,69 @@ Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, ext_data_len = raw_data_safe_len; ext_data_deleter = OrtCallback{nullptr, nullptr}; } else { +#if defined(__wasm__) + ORT_RETURN_IF(file_offset < 0 || file_offset + raw_data_safe_len >= 4294967296, + "External initializer: ", tensor_proto.name(), + " offset: ", file_offset, " size to read: ", static_cast(raw_data_safe_len), + " are out of bounds or can not be read in full (>4GB)."); + + auto buffer = std::make_unique(raw_data_safe_len); + ext_data_deleter = OrtCallback{DeleteCharArray, buffer.get()}; + ext_data_buf = buffer.release(); + ext_data_len = raw_data_safe_len; + + // In WebAssembly, try use a simplified preloaded file map in WebAssembly when available. + auto err_code = EM_ASM_INT(({ + // If available, "Module.MountedFiles" is a Map for all preloaded files. + if (typeof Module == 'undefined' || !Module.MountedFiles) { + return 1; // "Module.MountedFiles" is not available. + } + let fileName = UTF8ToString($0 >>> 0); + if (fileName.startsWith('./')) { + fileName = fileName.substring(2); + } + const fileData = Module.MountedFiles.get(fileName); + if (!fileData) { + return 2; // File not found in preloaded files. + } + const offset = $1 >>> 0; + const length = $2 >>> 0; + const buffer = $3 >>> 0; + + if (offset + length > fileData.byteLength) { + return 3; // Out of bounds. + } + + try { + // Copy the file data (fileData,offset,length) into WebAssembly memory (HEAPU8,buffer,length). + HEAPU8.set(fileData.subarray(offset, offset + length), buffer); + return 0; + } catch { + return 4; + } + }), + external_data_file_path.c_str(), + static_cast(file_offset), + static_cast(raw_data_safe_len), + ext_data_buf); + const char* err_msg; + switch (err_code) { + case 0: + return Status::OK(); + case 1: + err_msg = "Module.MountedFiles is not available."; + break; + case 2: + err_msg = "File not found in preloaded files."; + break; + case 3: + err_msg = "Out of bounds."; + break; + default: + err_msg = "Unknown error occurred in memory copy."; + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", external_data_file_path, "\", error: ", err_msg); +#else size_t file_length; // error reporting is inconsistent across platforms. Make sure the full path we attempted to open is included. auto status = env.GetFileLength(external_data_file_path.c_str(), file_length); @@ -836,6 +905,7 @@ Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, ORT_RETURN_IF_ERROR(GetFileContent(env, external_data_file_path.c_str(), file_offset, raw_data_safe_len, ext_data_buf, ext_data_deleter)); ext_data_len = raw_data_safe_len; +#endif } return Status::OK(); diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 23fe5e1cd3d9..0c4d498fae9e 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -270,6 +270,15 @@ static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session } copy_info.target_device = *node_info.device; + copy_info.unique_stream_index_consumes_it = node_info.stream_index; + ORT_RETURN_IF(node_info.stream_index < 0, "node_info.stream_index < 0"); + for (size_t i = 1; i < node_info_vec.size(); i++) { + ORT_RETURN_IF(node_info_vec[i].stream_index < 0, "node_info_vec[i].stream_index < 0"); + if (node_info_vec[i].stream_index != node_info.stream_index) { + copy_info.unique_stream_index_consumes_it = -1; + break; + } + } #ifdef ENABLE_TRAINING } else { @@ -441,11 +450,12 @@ static void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager static common::Status CopyInputsAcrossDevices(const SessionState& session_state, gsl::span orig_feeds, std::vector& new_feeds, - gsl::span copy_info, - gsl::span feed_streams) { +#ifdef ORT_ENABLE_STREAM + DeviceStreamCollection* device_stream_collection, +#endif + gsl::span copy_info) { size_t num_feeds = orig_feeds.size(); ORT_ENFORCE(copy_info.size() == num_feeds); - ORT_ENFORCE(feed_streams.size() == num_feeds); new_feeds.resize(num_feeds); std::vector batched_data_transfers; @@ -453,14 +463,32 @@ static common::Status CopyInputsAcrossDevices(const SessionState& session_state, std::vector batched_sparse_data_transfers; #endif + std::unordered_set stream_to_flush; for (size_t idx = 0; idx < num_feeds; ++idx) { + Stream* copy_this_feed = nullptr; +#ifdef ORT_ENABLE_STREAM + if (device_stream_collection) { + if (copy_info[idx].unique_stream_index_consumes_it < 0) { + for (size_t i = 0; i < device_stream_collection->NumStreams(); i++) { + Stream* stream = device_stream_collection->GetStream(i); + if (stream && stream->GetDevice().Type() == copy_info[idx].target_device.Type()) { + copy_this_feed = stream; + stream_to_flush.insert(stream); + break; + } + } + } else { + copy_this_feed = device_stream_collection->GetStream(copy_info[idx].unique_stream_index_consumes_it); + } + } +#endif #if !defined(DISABLE_SPARSE_TENSORS) ORT_RETURN_IF_ERROR(BatchOrCopyMLValue(session_state, copy_info[idx], orig_feeds[idx], new_feeds[idx], - feed_streams[idx], + copy_this_feed, &batched_data_transfers, &batched_sparse_data_transfers)); #else ORT_RETURN_IF_ERROR(BatchOrCopyMLValue(session_state, copy_info[idx], orig_feeds[idx], new_feeds[idx], - feed_streams[idx], + copy_this_feed, &batched_data_transfers)); #endif } @@ -479,10 +507,7 @@ static common::Status CopyInputsAcrossDevices(const SessionState& session_state, // TODO: this sync is because the graph inputs can be consumed by multiple stream, // but we can only place the MemCpyAsync on one of the stream. Ideally we should make // other stream wait on the event of the memory copy stream, instead of host sync stream. - std::unordered_set visited; - for (auto* stream : feed_streams) { - if (stream && visited.insert(stream).second) stream->Flush(); - } + for (const auto& stream : stream_to_flush) stream->Flush(); return Status::OK(); } @@ -640,33 +665,12 @@ ExecuteGraphImpl(const SessionState& session_state, if (device_copy_checks.input_copy_needed == DeviceCopyCheck::Copy) { const auto& feed_copy_info = feeds_fetches_manager.GetFeedsDeviceCopyInfo(); - InlinedVector feed_streams; - feed_streams.reserve(feed_copy_info.size()); - // TODO: we can pre-calculate the stream index for graph inputs in execution plan + auto status = CopyInputsAcrossDevices(session_state, feeds, device_feeds, #ifdef ORT_ENABLE_STREAM - for (auto& copy_info : feed_copy_info) { - auto& device = copy_info.target_device; - bool found = false; - if (device_stream_collection) { - size_t num_streams = device_stream_collection->NumStreams(); - for (size_t i = 0; i < num_streams; i++) { - Stream* stream = device_stream_collection->GetStream(i); - if (stream && stream->GetDevice().Type() == device.Type()) { - feed_streams.push_back(stream); - found = true; - break; - } - } - } - if (!found) - feed_streams.push_back(nullptr); - } -#else - for (size_t i = 0; i < feed_copy_info.size(); ++i) { - feed_streams.push_back(nullptr); - } + device_stream_collection, #endif - ORT_RETURN_IF_ERROR(CopyInputsAcrossDevices(session_state, feeds, device_feeds, feed_copy_info, feed_streams)); + feed_copy_info); + ORT_RETURN_IF_ERROR(status); feeds_to_use = device_feeds; } @@ -819,27 +823,7 @@ common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsF if (device_copy_checks.input_copy_needed == DeviceCopyCheck::Copy) { const auto& feed_copy_info = feeds_fetches_manager.GetFeedsDeviceCopyInfo(); - InlinedVector feed_streams; - feed_streams.reserve(feed_copy_info.size()); - // TODO: we can pre-calculate the stream index for graph inputs in execution plan - for (auto& copy_info : feed_copy_info) { - auto& device = copy_info.target_device; - bool found = false; - if (device_stream_collection) { - size_t num_streams = device_stream_collection->NumStreams(); - for (size_t i = 0; i < num_streams; i++) { - Stream* stream = device_stream_collection->GetStream(i); - if (stream && stream->GetDevice().Type() == device.Type()) { - feed_streams.push_back(stream); - found = true; - break; - } - } - } - if (!found) - feed_streams.push_back(nullptr); - } - ORT_RETURN_IF_ERROR(CopyInputsAcrossDevices(session_state, feeds, device_feeds, feed_copy_info, feed_streams)); + ORT_RETURN_IF_ERROR(CopyInputsAcrossDevices(session_state, feeds, device_feeds, device_stream_collection, feed_copy_info)); p_feeds = device_feeds; } @@ -1015,9 +999,19 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) } #ifdef ENABLE_ATEN + // For ATen node, we assume that all tensor inputs are on device, all non-tensor inputs are on CPU, + // except those specified in attribute cpu_input_args; if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && node.Domain() == kPytorchAtenDomain) { const auto& attrs = node.GetAttributes(); + if (auto entry = attrs.find("cpu_input_args"); entry != attrs.end()) { + const auto& attr = entry->second; + if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(), + [index](int64_t arg) { return static_cast(index) == arg; })) { + return true; + } + } + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); std::string op_name = attrs.at("operator").s(); std::string overload_name = ""; @@ -1025,7 +1019,7 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) overload_name = attrs.at("overload_name").s(); } - return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true); + return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, true); } #else ORT_UNUSED_PARAMETER(node); @@ -1040,9 +1034,19 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index } #ifdef ENABLE_ATEN + // For ATen node, we assume that all tensor outputs are on device, all non-tensor outputs are on CPU, + // except those specified in attribute cpu_output_args; if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && node.Domain() == kPytorchAtenDomain) { const auto& attrs = node.GetAttributes(); + if (auto entry = attrs.find("cpu_output_args"); entry != attrs.end()) { + const auto& attr = entry->second; + if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(), + [index](int64_t arg) { return static_cast(index) == arg; })) { + return true; + } + } + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); std::string op_name = attrs.at("operator").s(); std::string overload_name = ""; @@ -1050,7 +1054,7 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index overload_name = attrs.at("overload_name").s(); } - return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false); + return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, false); } #else ORT_UNUSED_PARAMETER(node); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index ea67218b5c92..adfa1b61e192 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -14,11 +14,12 @@ using namespace ::ONNX_NAMESPACE; namespace ONNX_NAMESPACE { -void matmulShapeInference( +namespace defs::math::utils { +void MatMulShapeInference( ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx); - +} // namespace defs::math::utils } // namespace ONNX_NAMESPACE namespace onnxruntime { @@ -260,12 +261,22 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& *output_shape.add_dim() = query_dims[2]; updateOutputShape(ctx, 0, output_shape); } else { - fail_shape_inference("Missing input 2 (value)"); + ONNX_NAMESPACE::TensorShapeProto output_shape; + int64_t num_heads = getAttribute(ctx, "num_heads", 0); + int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0); + int64_t hidden_size = query_dims[2].dim_value(); + int64_t head_size = hidden_size / (num_heads + 2 * kv_num_heads); + *output_shape.add_dim() = query_dims[0]; + *output_shape.add_dim() = query_dims[1]; + output_shape.add_dim()->set_dim_value(head_size * num_heads); + updateOutputShape(ctx, 0, output_shape); } } if (ctx.getNumOutputs() > 1) { // has present output if (hasInputShape(ctx, past_key_index)) { + // auto& query_shape = getInputShape(ctx, 0); + // auto& query_dims = query_shape.dim(); auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); if (past_dims.size() != 4) { @@ -273,8 +284,7 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& } ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1); ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); + // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not } } } @@ -333,6 +343,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Whether to use rotary position embedding. Default value is 0.", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("rotary_embedding_dim", + "Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size", + AttributeProto::INT, + OPTIONAL_VALUE) .Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f", AttributeProto::FLOAT, @@ -923,6 +937,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("unidirectional", + "Whether every token can only attend to previous tokens. Default value is 0.", + AttributeProto::INT, + static_cast(0)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)", @@ -1007,18 +1025,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", AttributeProto::INT, static_cast(-1)) + .Attr("do_rotary", + "Whether to use rotary position embedding. Default value is 0.", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("rotary_interleaved", + "Rotate using interleaved pattern. Default value is 0 (False).", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "query", - "Query with shape (batch_size, sequence_length, hidden_size)", + "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" + "(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).", "T") .Input(1, "key", "Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ", - "T") + "T", + OpSchema::Optional) .Input(2, "value", "Value with shape (batch_size, kv_sequence_length, kv_hidden_size)", - "T") + "T", + OpSchema::Optional) .Input(3, "past_key", "past state key with support for format BNSH. When past_key uses same tensor as present_key" @@ -1039,6 +1068,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "total_sequence_length", "Scalar tensor of total sequence length (past + new).", "M") + .Input(7, + "cos_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T", + OpSchema::Optional) + .Input(8, + "sin_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", @@ -1055,7 +1094,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", "T") - .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.") + .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { GroupQueryAttentionTypeAndShapeInference(ctx, 3); @@ -1141,6 +1180,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Rotate using interleaved pattern. Default value is 0 (False).", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("rotary_embedding_dim", + "Rotary embedding dimension. Default value is 0.", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("num_heads", + "Number of attention heads. Default value is 0. Must use with rotary_embedding_dim", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "input", "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)", @@ -1151,23 +1198,88 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "M") .Input(2, "cos_cache", - "2D tensor with shape (max_sequence_length, head_size / 2).", + "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)", "T") .Input(3, "sin_cache", - "2D tensor with shape (max_sequence_length, head_size / 2).", + "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)", "T") .Output(0, "output", "tensor with same shape as input.", "T") - .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); propagateShapeFromInputToOutput(ctx, 0, 0); })); +constexpr const char* GemmaRotaryEmbedding_ver1_doc = R"DOC( +GemmaRotaryEmbedding is the implementation of below part of rotary positional embeddings (RoPE). It implements below from modeling_gemma.py. + +Here's onnxscript that was tested + +from onnxscript import FLOAT, FLOAT16, script +from onnxscript import opset18 as op + +@script() +def gemma_rotary_embedding(emb: FLOAT["bs", "seq_len", "dim"], q: FLOAT16["bs", "num_heads", "seq_len", "dim"], q_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"], k: FLOAT16["bs", "num_heads", "seq_len", "dim"], k_rot: FLOAT16["bs", "num_heads", "seq_len", "dim"]): + sin_val = op.Sin(emb) + casted_sin = op.Cast(sin_val, to=10) # for fp16 mix-precision training. Other types are not supported. + cos_val = op.Cos(emb) + casted_cos = op.Cast(cos_val, to=10) + unsqueezed_sin = op.Unsqueeze(casted_sin, [1]) + unsqueezed_cos = op.Unsqueeze(casted_cos, [1]) + q_embed = (q * casted_cos) + (q_rot * casted_sin) + k_embed = (k * casted_cos) + (k_rot * casted_sin) + return q_embed, k_embed + +onnx_model = gemma_rotary_embedding.to_model_proto() + + +)DOC"; +ONNX_MS_OPERATOR_SET_SCHEMA( + GemmaRotaryEmbedding, 1, + OpSchema() + .SetDoc(GemmaRotaryEmbedding_ver1_doc) + .Input(0, + "emb", + "embeddding - 3D tensor with shape (batch_size, seq_len, dim)", + "U") + .Input(1, + "q", + "q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)", + "T") + .Input(2, + "q_rot", + "half rotated q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)", + "T") + .Input(3, + "k", + "k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)", + "T") + .Input(4, + "k_rot", + "k state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)", + "T") + .Output(0, + "output1", + "4D tensor with shape (batch_size, num_heads, seq_len, dim)", + "T") + .Output(1, + "output2", + "4D tensor with shape (batch_size, num_heads, seq_len, dim)", + "T") + .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float16 tensors.") + .TypeConstraint("U", {"tensor(float)"}, "Constrain input 0 type to float tensors") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 1, 0); + propagateElemTypeFromInputToOutput(ctx, 1, 1); + propagateShapeFromInputToOutput(ctx, 1, 0); + propagateShapeFromInputToOutput(ctx, 1, 1); + })); + constexpr const char* EmbedLayerNormalization_ver1_doc = R"DOC( EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing. The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding, @@ -1281,7 +1393,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Output(3, "input_skip_bias_sum", "Sum of the input and skip inputs (and bias if it exists) with shape (batch_size, sequence_length, hidden_size).", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.") .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference)); ONNX_MS_OPERATOR_SET_SCHEMA( SkipSimplifiedLayerNormalization, 1, @@ -1330,7 +1442,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.") .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference)); constexpr const char* NGramRepeatBlock_ver1_doc = R"DOC( Enforce no repetition of n-grams. Scores are set to `-inf` for tokens that form a repeated n-gram if added to the back of the input_ids. @@ -1398,7 +1510,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Constrain input and output types to float or half tensors.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); - ONNX_NAMESPACE::matmulShapeInference(ctx, 0, 1); + ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1); })); constexpr const char* RemovePadding_ver1_doc = R"DOC( diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 4aa43f5de1cd..a0ca2e45f153 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -91,10 +91,18 @@ void RegisterCollectiveOps() { "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) + .Attr("normalize_routing_weights", + "Whether to normalize routing weights", + AttributeProto::INT, + static_cast(0)) .Attr("local_experts_start_index", "The start index of local experts", AttributeProto::INT, - static_cast(-1)) + static_cast(0)) + .Attr("tensor_shards", + "Tensor parallelism config. The number of shards for each expert weight and bias", + AttributeProto::INT, + static_cast(1)) .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or " @@ -106,22 +114,32 @@ void RegisterCollectiveOps() { "T") .Input(2, "fc1_experts_weights", - "3D input tensor with shape (local_num_experts, hidden_size, inter_size)", + "3D input tensor with shape (local_num_experts, hidden_size, local_inter_size)", "T") .Input(3, - "fc2_experts_weights", - "3D input tensor with shape (local_num_experts, inter_size, hidden_size)", - "T") - .Input(4, "fc1_experts_bias", - "2D optional input tensor with shape (local_num_experts, inter_size)", + "2D optional input tensor with shape (local_num_experts, local_inter_size)", "T", OpSchema::Optional) + .Input(4, + "fc2_experts_weights", + "3D input tensor with shape (local_num_experts, local_inter_size, hidden_size)", + "T") .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) + .Input(6, + "fc3_experts_weights", + "3D optional input tensor with shape (local_num_experts, hidden_size, local_inter_size)", + "T", + OpSchema::Optional) + .Input(7, + "fc3_experts_bias", + "2D optional input tensor with shape (local_num_experts, local_inter_size)", + "T", + OpSchema::Optional) .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or " diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 54eb43753931..0f364b888006 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -39,10 +39,13 @@ void convPoolShapeInference( bool use_dilation, bool require_kernel_shape, int input1Idx, int input2Idx); -void matmulShapeInference( + +namespace defs::math::utils { +void MatMulShapeInference( ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx); +} void convTransposeWithDynamicPadsShapeInference(InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); @@ -1163,7 +1166,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) @@ -1188,7 +1191,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .SetDoc("Beam Search for whisper model, especiall with cross_qk features etc.") .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) - .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts (i.e. the start of transcription token id)", AttributeProto::INT, static_cast(-1)) + .Attr("translate_token_id", "The id of the translate task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("start_of_lm_token_id", "The id of the token that indicates LM starts", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_speech_token_id", + "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", + AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_timestamps_token_id", "The id of the token that indicates no timestamps", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("beginning_timestamp_token_id", "The id of the first timestamp", AttributeProto::INT, OPTIONAL_VALUE) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast(2)) @@ -1203,27 +1214,24 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "If not provided, it will be inferred from the decoder subgraph's output shape", AttributeProto::INT, static_cast(-1)) .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("no_speech_token", - "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", - AttributeProto::INT, OPTIONAL_VALUE) .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") .Input(5, "length_penalty", - "Exponential penalty to the length. Default value 1.0 means no penalty." - "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." + "Exponential penalty to the length. Default value 1.0 means no penalty. " + "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. " "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) .Input(12, "cross_qk_layer_head", - "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all" + "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all " "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]", "I", OpSchema::Optional) .Input(13, "extra_decoding_ids", @@ -1231,23 +1239,23 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) " "are treated as stop of the extra_decoding_ids for corresponding batch.", "I", OpSchema::Optional) + .Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", - "Processed beam scores for each vocabulary token at each generation step." - "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." + "Processed beam scores for each vocabulary token at each generation step. " + "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. " "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) .Output(3, "cross_qk", "Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, " - "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers," - "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]." + "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, " + "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. " "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]", "V", OpSchema::Optional) .Output(4, "non_speech_probs", - "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token." - "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." - "The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]", + "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. " + "The shape of non_speech_probs is [B]", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") @@ -1321,7 +1329,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") @@ -1362,7 +1370,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) @@ -1377,8 +1385,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, constexpr const char* MoE_ver1_doc = R"DOC( Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, - GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) - usually uses top 32 experts. + GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) + usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral). )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, @@ -1386,16 +1394,77 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, .SetDoc(MoE_ver1_doc) .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) + .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") - .Input(3, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T") - .Input(4, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T") .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) + .Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, hidden_size, inter_size)", "T", OpSchema::Optional) + .Input(7, "fc3_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +ONNX_MS_OPERATOR_SET_SCHEMA( + QMoE, 1, + OpSchema() + .SetDoc("Int4 MoE") + .Attr("activation_type", + "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", + AttributeProto::STRING, + std::string("relu")) + .Attr("k", + "Number of top experts to select from expert pool", + AttributeProto::INT, + static_cast(1)) + .Attr("normalize_routing_weights", + "Whether to normalize routing weights", + AttributeProto::INT, + static_cast(0)) + .Input(0, + "input", + "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " + "(batch_size, sequence_length, hidden_size)", + "T") + .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") + .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size / 2)", "T1") + .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size)", "T") + .Input(4, + "fc1_experts_bias", + "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(5, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size / 2)", "T1") + .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T") + .Input(7, + "fc2_experts_bias", + "2D optional input tensor with shape (num_experts, hidden_size)", + "T", + OpSchema::Optional) + .Input(8, + "fc3_experts_weights", + "3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)", + "T1", + OpSchema::Optional) + .Input(9, + "fc3_scales", + "2D optional input tensor with shape (num_experts, inter_size)", + "T", + OpSchema::Optional) + .Input(10, + "fc3_experts_bias", + "2D optional input tensor with shape (num_experts, inter_size)", + "T", + OpSchema::Optional) + .Output(0, + "output", + "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " + "(batch_size, sequence_length, hidden_size)", + "T") + .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") + .TypeConstraint("T1", {"tensor(uint8)"}, "Constrain weights type to uint8 tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, OpSchema() .Input(0, "X", "input", "T") @@ -1893,7 +1962,7 @@ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy- // Right now we only support int32 y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::INT32); - ONNX_NAMESPACE::matmulShapeInference(ctx, 0, 1); + ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1); })); /** @@ -3230,6 +3299,11 @@ void RegisterContribSchemas() { "(Optional) SDK version used to convert the model.", AttributeProto::STRING, OPTIONAL_VALUE) + .Attr( + "hardware_architecture", + "(Optional) Hardware architecture.", + AttributeProto::STRING, + OPTIONAL_VALUE) .Attr( "partition_name", "(Optional) partitioned graph name.", @@ -3333,22 +3407,23 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7 And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. -Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: -- n_blocks_per_col = (K + block_size - 1) / block_size -- blob_size = block_size / 8 * bits + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) + For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. + - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. + 4bit example: + |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) + - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. + 3bit example: + |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. + The last uint_8 may have some bits unused. - For a block blob. It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] -Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: - - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - - [N * n_blocks_per_col] if bits > 4 - +Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. + - [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)] + If zero_points has same type as A, it's not packed and has the same shape as Scales. )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) @@ -3367,12 +3442,15 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored "type T1.", AttributeProto::INT, static_cast(0)) .Input(0, "A", "The input tensor, not quantized", "T1") - .Input(1, "B", "1-dimensional data blob", "T2") + .Input(1, "B", "1 or 2 dimensional data blob", "T2") .Input(2, "scales", "quantization scale", "T1") - .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional) + .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional) + .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional) .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") - .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.") + .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.") + .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference propagateElemTypeFromInputToOutput(ctx, 0, 0); @@ -3460,6 +3538,8 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 /*min_arity*/ 1) .Attr("operator", "Name of ATen operator.", AttributeProto::STRING) .Attr("overload_name", "Overload name of ATen operator.", AttributeProto::STRING, false) + .Attr("cpu_input_args", "CPU input argument indices.", AttributeProto::INTS, false) + .Attr("cpu_output_args", "CPU output argument indices.", AttributeProto::INTS, false) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor."); #endif diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index c8960578f9e3..6bf19654a3ce 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -106,6 +106,7 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -206,6 +209,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 4313fae767fe..47f61a43458e 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -22,7 +22,9 @@ void RNNShapeInference(InferenceContext& ctx); void convTransposeShapeInference(InferenceContext& ctx); void convPoolShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool use_dilation, bool require_kernel_shape, int input1Idx, int input2Idx); -void matmulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx); +namespace defs::math::utils { + void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx); +} } // namespace ONNX_NAMESPACE @@ -400,7 +402,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input B data type to 8-bit integer tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); - ONNX_NAMESPACE::matmulShapeInference(ctx, 0, 1); + ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1); })); ONNX_MS_OPERATOR_SET_SCHEMA( @@ -434,11 +436,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Output(0, "Y", "Matrix multiply results from A * B", "T3") .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input A data type to 8-bit integer tensor.") .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input B data type to 8-bit integer tensor.") - .TypeConstraint("T3", {"tensor(float)"}, + .TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input a_scale, b_scale and output Y data type as float tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 2, 0); - ONNX_NAMESPACE::matmulShapeInference(ctx, 0, 1); + ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1); })); ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1129,7 +1131,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("S", {"tensor(float)"}, "Constrain bias and scales to float32") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); - ONNX_NAMESPACE::matmulShapeInference(ctx, 0, 2); + ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 2); })); static const char* Attention_QOrdered_doc = R"DOC( diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc index eeef20e9dff5..8b1812f62be2 100644 --- a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc +++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc @@ -114,6 +114,45 @@ void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& c } } +void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx) { + propagateShapeAndTypeFromFirstInput(ctx); + + auto stash_type = ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + if (ctx.getNumOutputs() > 1) { + auto output_type = ctx.getOutputType(1); + output_type->mutable_tensor_type()->set_elem_type(static_cast(stash_type)); + } + if (ctx.getNumOutputs() > 2) { + auto output_type = ctx.getOutputType(2); + output_type->mutable_tensor_type()->set_elem_type(static_cast(stash_type)); + } + if (ctx.getNumOutputs() > 3) { + propagateElemTypeFromInputToOutput(ctx, 0, 3); + } + if (!hasNInputShapes(ctx, 1)) { + return; + } + auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); + int64_t input_ndim = input_shape.dim_size(); + int axis = static_cast(input_ndim - 1); + + if (ctx.getNumOutputs() > 1) { + auto mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); + mean_shape->CopyFrom(input_shape); + mean_shape->mutable_dim(axis)->set_dim_value(1); + } + + if (ctx.getNumOutputs() > 2) { + auto inv_std_dev_shape = ctx.getOutputType(2)->mutable_tensor_type()->mutable_shape(); + inv_std_dev_shape->CopyFrom(input_shape); + inv_std_dev_shape->mutable_dim(axis)->set_dim_value(1); + } + + if (ctx.getNumOutputs() > 3) { + propagateShapeFromInputToOutput(ctx, 0, 3); + } +} + // Shape inference for Attention and QAttention void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index) { // Input 0, 1, 2 are input, weights and bias. diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h index 93cf5b304f65..6eb06af15309 100644 --- a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h +++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h @@ -13,5 +13,6 @@ namespace onnxruntime { namespace contrib { void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index); void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx); +void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx); } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index baebe2420073..2220b9cd1db7 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,16 +1818,36 @@ void Graph::ReverseDFSFrom(gsl::span from, } } +template +struct VisitorPriorityQueue { + using ComparatorType = std::function; + std::list list_; + const ComparatorType comparator_ = nullptr; + VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} + + void push(T node) { + list_.insert( + std::upper_bound(list_.begin(), list_.end(), node, comparator_), + node); + } + bool empty() { return list_.empty(); } + T top() { return list_.back(); } + void pop() { list_.pop_back(); } +}; + #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { - std::unordered_map in_degree; - std::priority_queue, decltype(comp)> to_visit(comp); - std::vector topo_order; + InlinedVector in_degree(MaxNodeIndex(), 0); + InlinedVector topo_order; + VisitorPriorityQueue to_visit(comp); + + auto number_of_nodes = NumberOfNodes(); + topo_order.reserve(number_of_nodes); for (auto& node : Nodes()) { size_t input_edge_count = node.GetInputEdgesCount(); - in_degree.insert({node.Index(), input_edge_count}); + in_degree[node.Index()] = input_edge_count; if (input_edge_count == 0) { to_visit.push(&node); } @@ -1844,16 +1864,17 @@ void Graph::KahnsTopologicalSort(const std::function& enter, } for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) { - in_degree[node_it->Index()]--; + auto& node_in_degree = in_degree[node_it->Index()]; + node_in_degree--; - if (in_degree[node_it->Index()] == 0) { + if (node_in_degree == 0) { to_visit.push(&*node_it); } } topo_order.push_back(current->Index()); } - if (NumberOfNodes() != static_cast(topo_order.size())) { + if (number_of_nodes != static_cast(topo_order.size())) { ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } } @@ -2367,8 +2388,14 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso inferred_type = existing_type; } else { // This should not happen: indicates incompleteness in ONNX inference. + std::stringstream ss; + ss << "index=" << operand_index; + for (auto it = op_formal_parameter.GetTypes().begin(); it != op_formal_parameter.GetTypes().end(); ++it) { + ss << "," << *(*it); + } Status status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, - "Node (" + node_name + ") output arg (" + output_def->Name() + ") type inference failed"); + "Node (" + node_name + ") Op (" + node.OpType() + ") output arg (" + + output_def->Name() + ") type inference failed, inferred types: " + ss.str()); return status; } @@ -2550,15 +2577,23 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { // Node verification. auto& node = *GetNode(node_index); - NodeProto node_proto; - node.ToProto(node_proto); const auto& node_name = node.Name(); if (!node.Op()) { { auto status = Status::OK(); ORT_TRY { - checker::check_node(node_proto, ctx, lsc); + // if this is first Graph::Resolve call, we may have a NodeProto that was set on the Node so we can skip + // the ToProto call. + if (const NodeProto* orig_node_proto = node.GetOriginalNodeProto(); orig_node_proto) { + checker::check_node(*orig_node_proto, ctx, lsc); + // clear original as we don't know if the node will be modified once the Graph::Resolve completes. + node.SetOriginalNodeProto(nullptr); + } else { + NodeProto node_proto; + node.ToProto(node_proto); + checker::check_node(node_proto, ctx, lsc); + } } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { @@ -2630,8 +2665,8 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { NO_CHANGE_ON_SYNC_FLAG(ORT_RETURN_IF_ERROR(InferAndVerifyTypeMatch(node, *p_op, options))); // Accumulate output names of the iterated Node - for (auto& output_name : node_proto.output()) { - lsc.output_names.insert(output_name); + for (const auto& output : node.OutputDefs()) { + lsc.output_names.insert(output->Name()); } } @@ -2792,12 +2827,13 @@ Status Graph::Resolve(const ResolveOptions& options) { graph.GraphProtoSyncNeeded(false); } + // set num_resolves_ here so the graph and any subgraphs all have the same value + ++graph.num_resolves_; + return Status::OK(); }; ORT_RETURN_IF_ERROR(ForThisAndAllSubgraphs(all_subgraphs, finalize_func)); - ++num_resolves_; - return Status::OK(); } @@ -2836,7 +2872,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { const gsl::not_null tensor_added{graph_proto_->add_initializer()}; *(tensor_added) = tensor; - name_to_initial_tensor_[tensor.name()] = tensor_added; + name_to_initial_tensor_.emplace(tensor.name(), tensor_added); SetGraphResolveNeeded(); if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) { // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. @@ -3095,13 +3131,25 @@ Node& Graph::AddNode(const NodeProto& node_proto, attributes[attr.name()] = attr; } - return AddNode(node_proto.name(), - node_proto.op_type(), - node_proto.doc_string(), - input_defs, - output_defs, - &attributes, - node_proto.domain()); + Node& new_node = AddNode(node_proto.name(), + node_proto.op_type(), + node_proto.doc_string(), + input_defs, + output_defs, + &attributes, + node_proto.domain()); + + // Perf optimization: temporarily set NodeProto in Node so we don't need to call Node::ToProto prior to + // calling onnx::check_node + // NOTE: We don't handle a node with kOnnxDomainAlias. The entry in schema_registry_ uses kOnnxDomain, + // and that's what onnx::check_node uses during validation. + // The Node ctor automatically converts kOnnxDomainAlias to kOnnxDomain to handle this. + // node_proto is const so we can't do the same here. + if (node_proto.domain() != kOnnxDomainAlias) { + new_node.SetOriginalNodeProto(&node_proto); + } + + return new_node; } static flatbuffers::Offset>> diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 8e962403556d..2314a5228f83 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -3,7 +3,7 @@ #include "graph_flatbuffers_utils.h" -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/narrow.h" #include "core/flatbuffers/flatbuffers_utils.h" @@ -392,6 +392,14 @@ Status LoadOrtTensorOrtFormat(const fbs::Tensor& fbs_tensor, const AllocatorPtr ort_tensor = onnxruntime::Tensor( tensor_dtype, TensorShape(tensor_dims->data(), tensor_dims->size()), allocator); + if (fbs_tensor.raw_data()->size() == 0U) { + // Empty tensor. Nothing to unpack. + // This check is necessary because an empty ort tensor will return a size of 1. + // As a result, the following call to UnpackTensor will fail since the src and + // dst sizes do not match (0 and 1 elements). + return Status::OK(); + } + // The tensor proto is used as a dummy here. The actual data is stored in the raw_data field of the flatbuffer. // The data is copied from the raw_data field to the ort_tensor. ONNX_NAMESPACE::TensorProto unused_tensor_proto; diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.h b/onnxruntime/core/graph/graph_flatbuffers_utils.h index b625cbf3ca49..9c55dad3c41e 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.h +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.h @@ -5,7 +5,7 @@ #include -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/status.h" #include "core/graph/ort_format_load_options.h" diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index cf78040ea5ac..119d420066a8 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const { struct PriorityNodeCompare { inline bool IsHighPri(const Node* n) const { // local statics so we can compare std::strings in the checks - static const std::string shape_op("Shape"); - static const std::string size_op("Size"); + static constexpr std::string_view shape_op("Shape"); + static constexpr std::string_view size_op("Size"); const auto& op_type = n->OpType(); return op_type == shape_op || op_type == size_op; @@ -26,15 +26,20 @@ struct PriorityNodeCompare { // If return true, n2 will be output first bool operator()(const Node* n1, const Node* n2) const { // nodes in global high priority list will be output first - if (IsHighPri(n1) != IsHighPri(n2)) { - return IsHighPri(n2); + const bool isN1HighPri = IsHighPri(n1); + const bool isN2HighPri = IsHighPri(n2); + if (isN1HighPri != isN2HighPri) { + return isN2HighPri; } // nodes with lower priority value will be output first - if (n1->Priority() != n2->Priority()) { - return n1->Priority() > n2->Priority(); + const auto n1_priority = n1->Priority(); + const auto n2_priority = n2->Priority(); + if (n1_priority != n2_priority) { + return n1_priority > n2_priority; } +#ifdef ENABLE_TRAINING // nodes of forward pass will be output first auto n1_attrs = n1->GetAttributes(); auto n2_attrs = n2->GetAttributes(); @@ -45,6 +50,7 @@ struct PriorityNodeCompare { if (n1_is_forward != n2_is_forward) { return n2_is_forward > n1_is_forward; } +#endif // otherwise, nodes with lower index will be output first return n1->Index() > n2->Index(); @@ -212,6 +218,8 @@ const std::string& GraphViewer::Description() const noexcept { bool GraphViewer::GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const { + value = nullptr; + // if we are using filtered subgraph, the initializer has to be part of the subgraph if (filter_info_ != nullptr && filtered_initializers_.find(tensor_name) == filtered_initializers_.cend()) return false; diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 4ce6660b794b..a774d5fe3446 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -8,7 +8,7 @@ #include #include -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/path.h" #include "core/graph/graph_viewer.h" diff --git a/onnxruntime/core/graph/op_identifier_utils.h b/onnxruntime/core/graph/op_identifier_utils.h index 8a9351a2d0dd..f7b1198c3197 100644 --- a/onnxruntime/core/graph/op_identifier_utils.h +++ b/onnxruntime/core/graph/op_identifier_utils.h @@ -3,7 +3,7 @@ #pragma once -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/graph/op_identifier.h" diff --git a/onnxruntime/core/graph/runtime_optimization_record_container.h b/onnxruntime/core/graph/runtime_optimization_record_container.h index a28b19e786de..75750c2b9698 100644 --- a/onnxruntime/core/graph/runtime_optimization_record_container.h +++ b/onnxruntime/core/graph/runtime_optimization_record_container.h @@ -9,7 +9,7 @@ #include #include -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/common.h" #include "core/graph/runtime_optimization_record.h" diff --git a/onnxruntime/core/mickey/README.md b/onnxruntime/core/mickey/README.md index 7e8d30cd1805..735ec4b80daf 100644 --- a/onnxruntime/core/mickey/README.md +++ b/onnxruntime/core/mickey/README.md @@ -4,3 +4,7 @@ Playful name for a template library of high performance cuda code that are often shared by various AI operators. The intention is to make this header files only, with no binary impact unless it is instantiated where it is needed. + +Currently cuda code are scattered in multiple locations in the repo. +Hopefully this can be the starting point of consolidating all cuda +code. diff --git a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h new file mode 100644 index 000000000000..52bff7e40dbe --- /dev/null +++ b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h @@ -0,0 +1,208 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blk_q4/f16_gemm_sm80.h + * + * Abstract: + * Entry point for Q4F16 GEMM kernel for SM80 devices. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass_ext/q4gemm/device/quantb_gemm.h" + +namespace onnxruntime { +namespace cuda { + +// +// This is the implementation of the quantized GEMM kernel for 16b float x blocked quantized 4b data type +// +template < + typename ElementDequant_, // <- data type of dequantized elements for gemm, fp16 or bf16 + typename QuantBlocking_, // <- weights block per scale, cutlass::MatrixShape + bool SmallM, // <- true if M <= 16 + bool kHasQuantOffset> +struct BlkQ4F16GemmImpl { + // + // Type definitions + // + + using ElementDequant = ElementDequant_; + using QuantBlocking = QuantBlocking_; + + static_assert(sizeof(ElementDequant) == 2, "q4f16gemm kerenl only support 16b operands!"); + + // Data types that are fixed for this kernel + using ElementAccumulator = float; + using ElementComputeEpilogue = ElementAccumulator; + using ElementInputA = ElementDequant; + using ElementOutput = ElementDequant; + + using ElementW = uint8_t; // <- Weight is int4, uint8 for two of them + + // We pack 4 weights into one 16b element, so as to leverage cutlass tile iterators + // for async shared memory loading and minimize bank conflict + using ElementWPack = ElementDequant; + + using ElementQScale = ElementDequant; // <- data type of quantization scale + using ElementQOffset = uint8_t; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputWPack = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + // Layout of quantization scale and offset, oriented to be loaded using less instructions + // in a warp tile + using LayoutInputQScale = + typename std::conditional::type; // <- layout of quantization scale + + using ShapeMMAThreadBlock = + typename std::conditional, + cutlass::gemm::GemmShape<128, 256, 64>>::type; + + static constexpr int MinN = QuantBlocking::kColumn > 32 ? QuantBlocking::kColumn : 32; + using ShapeMMAWarp = + typename std::conditional, + cutlass::gemm::GemmShape<64, 64, 64>>::type; + + using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + + // This code section describes the epilogue part of the kernel + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + + // Number of pipelines you want to use + static constexpr int NumStages = 3; + + using Gemm = cutlass::gemm::device::QuantBGemm< + ElementInputA, + LayoutInputA, + ElementWPack, + LayoutInputWPack, + ElementQScale, + typename std::conditional::type, + LayoutInputQScale, + QuantBlocking, + ElementOutput, + LayoutOutput, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ShapeMMAThreadBlock, + ShapeMMAWarp, + ShapeMMAOp, + EpilogueOp, + SwizzleThreadBlock, + NumStages>; + + using Arguments = typename Gemm::Arguments; + + // Invoke gemm kernel (the version with quantization offset) + static cutlass::Status run( + cudaStream_t stream, + const cutlass::gemm::GemmCoord& problem_size_, + cutlass::TensorRef ref_A_, + cutlass::TensorRef ref_B_, + cutlass::TensorRef ref_Qscale_, + cutlass::TensorRef ref_Qoffset_, + cutlass::TensorRef ref_C_, + cutlass::TensorRef ref_D_, + typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) { + if constexpr (!kHasQuantOffset) { + return cutlass::Status::kErrorNotSupported; + } else { + if constexpr (ShapeMMAThreadBlock::kM == 16) { + if (problem_size_.m() > 16) { + // For M > 16, the caller should have picked the + // kernel with bigger M + return cutlass::Status::kErrorNotSupported; + } + } + + // Construct Gemm arguments + Arguments args{ + problem_size_, + ref_A_, + ref_B_, + ref_Qscale_, + ref_Qoffset_, + ref_C_, + ref_D_, + epilogue_}; + + Gemm gemm_op; + + // Check if this GEMM can be run or not + cutlass::Status status = gemm_op.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Launch the CUTLASS GEMM kernel. + return gemm_op(args, nullptr, stream); + } + } + + // Invoke gemm kernel (the version without quantization offset) + static cutlass::Status run( + cudaStream_t stream, + const cutlass::gemm::GemmCoord& problem_size_, + cutlass::TensorRef ref_A_, + cutlass::TensorRef ref_B_, + cutlass::TensorRef ref_Qscale_, + cutlass::TensorRef ref_C_, + cutlass::TensorRef ref_D_, + typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) { + if constexpr (kHasQuantOffset) { + return cutlass::Status::kErrorNotSupported; + } else { + if constexpr (ShapeMMAThreadBlock::kM == 16) { + if (problem_size_.m() > 16) { + // For M > 16, the caller should have picked the + // kernel with bigger M + return cutlass::Status::kErrorNotSupported; + } + } + + // Construct Gemm arguments + Arguments args{ + problem_size_, + ref_A_, + ref_B_, + ref_Qscale_, + ref_C_, + ref_D_, + epilogue_}; + + Gemm gemm_op; + + // Check if this GEMM can be run or not + cutlass::Status status = gemm_op.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Launch the CUTLASS GEMM kernel. + return gemm_op(args, nullptr, stream); + } + } +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h similarity index 94% rename from onnxruntime/core/mickey/blk_q4/prepack_sm80.h rename to onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h index e291ab39e8aa..c81b4967d271 100644 --- a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h @@ -3,7 +3,7 @@ * Licensed under the MIT License. * * Module Name: - * prepack_sm80.h + * blk_q4/f16_prepack_sm80.h * * Abstract: * Prepack weights and quantization parameters (scales and offsets) for @@ -110,8 +110,8 @@ struct BlockwiseQuantization { static void prepack_weights( int rows, int columns, - const gsl::span& weights, // <- int4 weights, column major - const gsl::span& weights_prepacked // <- int4 prepacked weights tensor, same size buffer + gsl::span weights, // <- int4 weights, column major + gsl::span weights_prepacked // <- int4 prepacked weights tensor, same size buffer ) { ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0 && (rows % QuantBlocking::kRow) == 0 && @@ -171,10 +171,10 @@ struct BlockwiseQuantization { static void prepack_quant_scales( size_t rows, size_t columns, - const gsl::span& scales, // <- quant scales, column major layout - const gsl::span& scales_prepacked // <- quant scales prepacked, same size buffer + gsl::span scales, // <- quant scales, column major layout + gsl::span scales_prepacked // <- quant scales prepacked, same size buffer ) { - auto meta_shape = get_quant_meta_shape(rows, columns); + auto meta_shape = get_quant_meta_shape(static_cast(rows), static_cast(columns)); ORT_ENFORCE(scales.size() == size_t(meta_shape.product()), "Quantization scale tensor shape mismatch!"); ORT_ENFORCE(scales_prepacked.size() == size_t(meta_shape.product()), @@ -241,10 +241,10 @@ struct BlockwiseQuantization { static void prepack_quant_offsets( size_t rows, size_t columns, - const gsl::span& offsets, // <- quant offsets, int4, column major layout - const gsl::span& offsets_prepacked // <- quant offsets prepacked, double size buffer + gsl::span offsets, // <- quant offsets, int4, column major layout + gsl::span offsets_prepacked // <- quant offsets prepacked, double size buffer ) { - auto meta_shape = get_quant_meta_shape(rows, columns); + auto meta_shape = get_quant_meta_shape(static_cast(rows), static_cast(columns)); ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0, "Does not support odd number of rows or columns!"); diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h new file mode 100644 index 000000000000..38795291b032 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h @@ -0,0 +1,481 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_gemm.h + * @brief Modified from cutlass/gemm/device/gemm.h, boilerplate code passing input pointers to the kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm.h" + +#include "cutlass_ext/q4gemm/kernel/default_quantb_gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! A specialized GEMM operator for quantized B GEMM. + + It is modified from cutlass::gemm::device::Gemm. Both this class and the original Gemm class + are pretty much boilerplate code that construct the Gemm kernel class, and pass parameters + and controls to it. The only difference is that this class has a few more template parameters + to support quantization. + + This implementation pretty much follows the design of cutlass. But this class seems to be + just a wrapper of the Gemm kernel class. Consider combining them in future iterations. + +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout type for quant scales and offsets + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute> +class QuantBGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + // Quantization Parameters + static_assert(std::is_same::value, + "LayoutB, i.e. packed weights must appear ColumnMajor."); + static_assert(InstructionShape::kK == 16, + "InstructionShape::kK must be a multiple of 16 (2 tiles), required by 4b weight packing layout."); + using ElementQScale = ElementQScale_; + using ElementQOffset = ElementQOffset_; + using LayoutQMeta = LayoutQMeta_; + using QuantBlocking = QuantBlocking_; + static constexpr bool kHasQOffset = !(std::is_same::value); + + // TODO(chenfucn): consider moving to uint4_t or smaller for QOffset + static_assert(!kHasQOffset || std::is_same::value, "QOffset must be uint8_t"); + + /// Define the kernel + using GemmKernel = typename kernel::DefaultQuantBGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementQScale, + ElementQOffset, + LayoutQMeta, + QuantBlocking, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + GatherA, + GatherB, + ScatterD, + PermuteDLayout + >::GemmKernel; + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + TensorRef ref_Qscale; + TensorRef ref_Qoffset; + + typename EpilogueOutputOp::Params epilogue; + + // split-K parallelism (etc.) are not yet supported, keeping this for future extension + int split_k_slices{1}; + // For gather+scatter operations + int const *gather_A_indices{nullptr}; + int const *gather_B_indices{nullptr}; + int const *scatter_D_indices{nullptr}; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0) {} + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(!kHasQOffset); + } + + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_Qoffset_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_Qoffset(ref_Qoffset_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(kHasQOffset); + } + }; + + private: + /// Kernel parameters object + typename GemmKernel::Params params_; + + public: + /// Constructs the GEMM. + QuantBGemm() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.epilogue, + static_cast(workspace), + args.gather_A_indices, + args.gather_B_indices, + args.scatter_D_indices + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_Qscale.reset(args.ref_Qscale.non_const_ref().data()); + params_.ref_Qoffset.reset(args.ref_Qoffset.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.output_op = args.epilogue; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + std::cerr << "Failed to obtain maximum shared memory size " << smem_size << " for kernel: " + << cudaGetErrorString(result) << "\n"; + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h new file mode 100644 index 000000000000..2f4460bb59e9 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h @@ -0,0 +1,255 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_gemm.h + * @brief Modified from cutlass/gemm/kernel/default_gemm.h. templates for combining + * threadblock-scoped matrix multiply-add with the appropriate + * threadblock-scoped epilogue. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass_ext/q4gemm/kernel/quantb_gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#include "cutlass/layout/permute.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" +#endif //CUTLASS_ARCH_WMMA_ENABLED + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout type for quant scales and offsets + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Access granularity of quant scales in units of elements + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute, + /// + typename Enable = void +> +struct DefaultQuantBGemm; + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale, + /// Element type for quant offsets + typename ElementQOffset, + /// Layout type for quant scales + typename LayoutQMeta, + /// Blocking dimensions for quantization + typename QuantBlocking, + /// Access granularity of quant scales in units of elements + typename ElementC, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout +> +struct DefaultQuantBGemm { + + static_assert((platform::is_same::value + || platform::is_same>::value), + "Epilogue in the kernel level must be row major"); + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultQuantBMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator, false, GatherA, GatherB, + PermuteALayout, PermuteBLayout>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using RegularEpilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; + + using Affine2Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpAffineRankN< + 2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + using Epilogue = typename platform::conditional::value, + RegularEpilogue, + Affine2Epilogue>::type; + + /// Define the kernel-level GEMM operator. + using GemmKernel = kernel::QuantBGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h new file mode 100644 index 000000000000..6e5ad8f40614 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h @@ -0,0 +1,462 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_gemm.h + * @brief Modified from cutlass/gemm/kernel/gemm.h. + * Template for a pipelined GEMM kernel. Does not compute batching or support split-K. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. +> +struct QuantBGemm { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using OutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + static constexpr bool kHasQOffset = Mma::kHasQOffset; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorQScale::Params params_QScale; + typename Mma::IteratorQScale::TensorRef ref_QScale; + typename Mma::IteratorQOffset::Params params_QOffset; + typename Mma::IteratorQOffset::TensorRef ref_QOffset; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename OutputOp::Params output_op; + int *semaphore; + int gemm_k_size; // how many k vectors are processed by this threadblock + // For gather+scatter operations + int const *gather_A_indices; + int const *gather_B_indices; + int const *scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + typename OutputOp::Params output_op = typename OutputOp::Params(), + int *workspace = nullptr, + int const *gather_A_indices = nullptr, + int const *gather_B_indices = nullptr, + int const *scatter_D_indices = nullptr + ): + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(ref_A.layout()), + ref_A(ref_A), + params_B(ref_B.layout()), + ref_B(ref_B), + params_QScale(ref_QScale.layout()), + ref_QScale(ref_QScale), + params_QOffset(ref_QOffset.layout()), + ref_QOffset(ref_QOffset), + params_C(ref_C.layout()), + ref_C(ref_C), + params_D(ref_D.layout()), + ref_D(ref_D), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) { + int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + QuantBGemm() { } + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D) { + + // TODO check problem_size K, N must be multiple of QuantBlocking + + static int const kAlignmentA = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (problem_size.k() % Mma::Shape::kK != 0) { + // Currently we don't support this case due to the way + // predicate iterator works, it loads the partial tile + // in the first iteration and then the full tile in the + // remaining iterations. This will cause the blockwise + // quantization parameters to go out of step with the + // weights. We can fix this by adding a predicate iterator + // that loads the full tile in the first iterations and + // then the partial tile in the last iteration. + return Status::kErrorInvalidProblem; + } + + int qscale_k = problem_size.k() / Mma::QuantBlocking::kRow; + int qscale_n = problem_size.n() / Mma::QuantBlocking::kColumn; + if ((qscale_k == 0) || (qscale_k * Mma::QuantBlocking::kRow != problem_size.k())) { + // partial block not supported + return Status::kErrorInvalidProblem; + } + if ((qscale_n == 0) || (qscale_n * Mma::QuantBlocking::kColumn != problem_size.n())) { + // partial block not supported + return Status::kErrorInvalidProblem; + } + + if (!TensorRef_aligned(ref_QScale, Mma::IteratorQScale::AccessType::kElements)) { + return Status::kErrorMisalignedOperand; + } + + if constexpr(kHasQOffset) { + if (!TensorRef_aligned(ref_QOffset, Mma::IteratorQOffset::AccessType::kElements)) { + return Status::kErrorMisalignedOperand; + } + } + + if (!TensorRef_aligned(ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{ + (threadblock_tile_offset.k() * params.gemm_k_size) / 2, + (threadblock_tile_offset.n() * Mma::Shape::kN) / 2 + }; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min( + params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k/2, params.problem_size.n()/2}, + thread_idx, + tb_offset_B, + params.gather_B_indices); + + const int qscale_k = problem_size_k / Mma::QuantBlocking::kRow; + const int qscale_n = params.problem_size.n() / Mma::QuantBlocking::kColumn; + + // should have been verified by can_implement() + assert((qscale_k > 0) && (qscale_k * Mma::QuantBlocking::kRow == problem_size_k)); + assert((qscale_n > 0) && (qscale_n * Mma::QuantBlocking::kColumn == params.problem_size.n())); + + cutlass::MatrixCoord tb_offset_QScale{ + threadblock_tile_offset.k() * (params.gemm_k_size/Mma::QuantBlocking::kRow), + threadblock_tile_offset.n() * (Mma::Shape::kN/Mma::QuantBlocking::kColumn) + }; + + typename Mma::IteratorQScale iterator_QScale( + params.params_QScale, + params.ref_QScale.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale, + nullptr); + + typename Mma::IteratorQOffset iterator_QOffset( + params.params_QOffset, + params.ref_QOffset.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + const int warp_idx = canonical_warp_idx(); + const int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_QScale, iterator_QOffset, accumulators); + } + + // + // Epilogue + // + + OutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h new file mode 100644 index 000000000000..0af604f090e1 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h @@ -0,0 +1,248 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma.h + * @brief Modified from cutlass/gemm/threadblock/default_mma.h. + * Defining global memory data layout and iterators, combinging with mma core and + * pipelined GEMM kernel. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h" +#include "cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout for quant scales and offsets + typename LayoutQMeta_, + /// Blocking size for quantization + typename QuantBlocking_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute + > +struct DefaultQuantBMma; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale, + /// Element type for quant offsets + typename ElementQOffset, + /// Layout for quant scales and offsets + typename LayoutQMeta, + /// Blocking size for quantization + typename QuantBlocking, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout + > +struct DefaultQuantBMma { + + static_assert(platform::is_same::value + || platform::is_same>::value, + "simt epilogue must be row major"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultQuantBMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA, PermuteALayout>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>; + + // Define iterators over tiles from the quant scales + using ThreadMapQScale = typename MmaCore::IteratorThreadMapQScale; + using AccessTypeQScale = + cutlass::Array; + using IteratorQScale = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + typename MmaCore::ThreadblockQShape, + ElementQScale, LayoutQMeta, 0, ThreadMapQScale, AccessTypeQScale>; + + using ThreadMapQOffset = typename MmaCore::IteratorThreadMapQOffset; + using AccessTypeQOffset = + cutlass::Array; + using IteratorQOffset = + cutlass::transform::threadblock::OptionalPredicatedTileAccessIterator< + typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQMeta, + 0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::QuantBMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, IteratorQScale, typename MmaCore::SmemIteratorQScale, + cutlass::arch::CacheOperation::Global, IteratorQOffset, + typename MmaCore::SmemIteratorQOffset, cutlass::arch::CacheOperation::Global, + ElementAccumulator, LayoutC, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h new file mode 100644 index 000000000000..ad322f650520 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma_core.h + * @brief Modified from cutlass/gemm/threadblock/default_mma_core.h. + * Defining data layout in shared memory, and its iterators. + */ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" + +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" +#include "cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template defininng default matrix multiply operators inferred from threadblock tile size, +/// global memory data layout, and target math instruction. +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Element data type of quant scale + typename ElementQScale, + /// Element data type of quant offset + typename ElementQOffset, + /// Layout of quant scale + typename LayoutQMeta, + /// Blocking dimensions for quantization + typename QuantBlocking, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DefaultQuantBMmaCore; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Element data type of quant scale + typename ElementQScale_, + /// Element data type of quant offset + typename ElementQOffset_, + /// Layout of quant scale + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultQuantBMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + + using ElementQScale = ElementQScale_; + using ElementQOffset = ElementQOffset_; + using LayoutQMeta = LayoutQMeta_; + using QuantBlocking = QuantBlocking_; + + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + static int const kWarpThreadArrangementContiguousB = + (Shape::kK / 2) / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK/2>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + using SmemLayoutQScale = LayoutQMeta; + using SmemLayoutQOffset = LayoutQMeta; + + /// Threadblock-level quantization meta data shape + using ThreadblockQShape = MatrixShape; + static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow"); + static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn"); + static_assert(ThreadblockQShape::kCount > 0, "QuantBlocking too big to fit in a thread block!"); + static_assert(QuantBlocking::kRow == 1 || QuantBlocking::kColumn == 1, + "Only support single column or row quantize blocking!"); + static_assert(QuantBlocking::kColumn != 1 || std::is_same::value, + "Quant scale matrix's major dimension must have more elements, to facilitate fast loading!"); + + /// Threadblock-level quantization meta data shape in pitch-linear layout + using TBQPitchLinearShape = typename std::conditional< + std::is_same::value, + layout::PitchLinearShape, + layout::PitchLinearShape>::type; + + /// By default we would like to use 128b load. However, we can't load more than + /// a column at a time in a column major layout. + static int const kElementsPerAccessQScale = + (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous + ? TBQPitchLinearShape::kContiguous + : (kAccessSizeInBits / sizeof_bits::value); + + /// quant scale is tiny. Not all threads are needed. + static int const kAccessCntQScale = ThreadblockQShape::kCount / kElementsPerAccessQScale; + static int const kThreadsQScale = (kAccessCntQScale > kThreads) ? kThreads : kAccessCntQScale; + + using IteratorThreadMapQScale = transform::PitchLinearStripminedThreadMap< + TBQPitchLinearShape, kThreadsQScale, kElementsPerAccessQScale>; + + using SmemIteratorQScale = transform::threadblock::RegularTileAccessIterator< + ThreadblockQShape, ElementQScale, SmemLayoutQScale, 1, IteratorThreadMapQScale>; + + static int const kElementsPerAccessQOffset = + (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous + ? TBQPitchLinearShape::kContiguous + : (kAccessSizeInBits / sizeof_bits::value); + static int const kAccessCntQOffset = ThreadblockQShape::kCount / kElementsPerAccessQOffset; + static int const kThreadsQOffset = (kAccessCntQOffset > kThreads) ? kThreads : kAccessCntQOffset; + + using IteratorThreadMapQOffset = transform::PitchLinearStripminedThreadMap< + TBQPitchLinearShape, kThreadsQOffset, kElementsPerAccessQOffset>; + + using SmemIteratorQOffset = transform::threadblock::OptionalRegularTileAccessIterator< + ThreadblockQShape, ElementQOffset, SmemLayoutQOffset, 1, IteratorThreadMapQOffset, kThreads>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultQuantBMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementQScale, SmemLayoutQScale, ElementQOffset, SmemLayoutQScale, QuantBlocking, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h new file mode 100644 index 000000000000..6f27a692a3a2 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h @@ -0,0 +1,314 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file optional_predicated_tile_access_iter.h + * @brief Templates for loading and storing optional tiles of matrix data. + * This iterator is just a wrapper of PredicatedTileAccessIterator, with + * the option to turn it off at compile time and minimize its runtime + * footprint. Also, it utilize the higher numbered threads in the + * threadblock when the iterator can not utilize all the threads. + */ + +#pragma once + +#include + +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + + +//////////////////////////////////////////////////////////////////////////////// + +/// Optional 2-D matrix data loader, when element is std::monostate, the +/// iterator becomes no-op with minimal runtime footprint. Also, it utilize the +/// higher numbered threads in the threadblock when the iterator can not utilize +/// all the threads. +/// +template < + /// Tile shape of the iterator + typename Shape_, + /// Element data type of the iterator, no-op when it is std::monostate + typename Element_, + /// Layout of the source matrix + typename Layout_, + int AdvanceRank_, + typename ThreadMap_, + typename AccessType_, + /// Number of threads in the threadblock, when provided, the iterator + /// will utilize the higher numbered threads + int kThreadBlockSize_ = -1> +class OptionalPredicatedTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + static constexpr int kAdvanceRank = AdvanceRank_; + static constexpr int kThreadblockSize = kThreadBlockSize_; + + static_assert(!std::is_same::value, + "Disabled Iterator failed to match the specialized version below."); + static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + + using Base = PredicatedTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using Mask = typename Base::Mask; + using TensorCoord = typename Base::TensorCoord; + using TensorRef = typename Base::TensorRef; + using Params = typename Base::Params; + using Pointer = typename Base::Pointer; + + static constexpr int kAccessesPerVector = Base::kAccessesPerVector; + + CUTLASS_HOST_DEVICE + static int flip_thread_id(int thread_id){ + if constexpr (kThreadblockSize > 0) { + return kThreadblockSize - 1 - thread_id; + } + return thread_id; + } + + public: + Base base_; + + /// Default constructor + OptionalPredicatedTileAccessIterator(): base_() {}; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : base_(params, pointer, extent, flip_thread_id(thread_id), threadblock_offset) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : OptionalPredicatedTileAccessIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + base_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + base_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + base_.add_tile_offset(tile_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return base_.get(); + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator &operator++() { + ++base_; + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator operator++(int) { + OptionalPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + base_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + base_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + base_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + base_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return base_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for the disabled version +/// Reduce runtime overhead +/// +template < + /// Tile shape of the iterator + typename Shape_, + typename Layout_, + int AdvanceRank_, + typename ThreadMap_, + typename AccessType_, + int kThreadBlockSize_> +class OptionalPredicatedTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = std::monostate; + using Layout = Layout_; + static int const kAdvanceRank = AdvanceRank_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + static constexpr int kThreadblockSize = kThreadBlockSize_; + + using Base = PredicatedTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using Mask = typename Base::Mask; + using TensorCoord = typename Base::TensorCoord; + using TensorRef = typename Base::TensorRef; + using Params = typename Base::Params; + using Pointer = typename Base::Pointer; + + static constexpr int kAccessesPerVector = Base::kAccessesPerVector; + + public: + std::monostate base_; + + /// Default constructor + OptionalPredicatedTileAccessIterator(): base_() {}; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : base_() {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : base_() {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) {} + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return nullptr; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator &operator++() { + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator operator++(int) { + return *this; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) {} + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() {} + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) {} + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) {} + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return false; } +}; + +//////////////////////////////////////////////////////////////////////////////// +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h new file mode 100644 index 000000000000..4b0ae5317f8b --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h @@ -0,0 +1,224 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file optional_regular_tile_access_iter.h + * @brief Templates implementing the address computation of storing of tiles + * from pitch-linear rank=2 tensors. + * + * This iterator is just a wrapper of RegularTileAccessIterator, with the + * option to turn it off at compile time and minimize its runtime footprint. + * Also, it utilize the higher numbered threads in the threadblock when the + * iterator can not utilize all the threads. + * + * Must be used in conjunction with OptionalPredicatedTileAccessIterator, + * with the same template parameters. + */ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Optional 2-D tile iterator, when element is std::monostate, the iterator +/// becomes no-op with minimal runtime footprint. Also, it utilize the higher +/// numbered threads in the threadblock when the iterator can not utilize all +/// the threads. +/// +template < + /// Tile shape of the iterator + typename Shape_, + typename Element_, + typename Layout_, + int AdvanceRank, + typename ThreadMap_, + /// Number of threads in the threadblock, when not -1, the iterator + /// will utilize the higher numbered threads + int ThreadblockSize_ = -1, + int Alignment = + sizeof_bits::value * ThreadMap_::kElementsPerAccess / 8> +class OptionalRegularTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + static constexpr int kAlignment = Alignment; + static constexpr int kThreadblockSize = ThreadblockSize_; + + static_assert(!std::is_same::value, + "Disabled Iterator failed to match the specialized template"); + static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + + using Base = RegularTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using TensorRef = typename Base::TensorRef; + using TensorCoord = typename Base::TensorCoord; + using AccessType = typename Base::AccessType; + + CUTLASS_HOST_DEVICE + static int flip_thread_id(int thread_id){ + if constexpr (kThreadblockSize > 0) { + return kThreadblockSize - 1 - thread_id; + } + return thread_id; + } + + private: + + Base base_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : base_(ref, flip_thread_id(thread_id)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + base_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + base_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + return base_.get(); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator &operator++() { + ++base_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + base_.add_tile_offset(coord); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization when Element is std::monostate, the iterator becomes no-op +/// +template < + typename Shape_, + typename Layout_, + int AdvanceRank, + typename ThreadMap_, + int ThreadblockSize_, + int Alignment> +class OptionalRegularTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = std::monostate; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + static constexpr int kAlignment = Alignment; + static constexpr int kThreadblockSize = ThreadblockSize_; + + using Base = RegularTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using TensorRef = typename Base::TensorRef; + using TensorCoord = typename Base::TensorCoord; + using AccessType = typename Base::AccessType; + + private: + + std::monostate base_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : base_() {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) {} + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + return nullptr; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator &operator++() { + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator operator++(int) { + return *this; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) {} +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h new file mode 100644 index 000000000000..28364cc34f2d --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h @@ -0,0 +1,1290 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_mma_multistage.h + * @brief Modified from cutlass/gemm/threadblock/mma_multistage.h. + * Added the quantized data memory pipeline, dequantization, and feeding + * to tensor cores. Mainloop pipeline is heavily modified. + */ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// +namespace{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Utilities for printing layout for the prepacked weights and quantization parameters +/// +template< + /// Data type of the prepacked weights + typename ElementWeight, + /// Data type of the quant scales + typename ElementQScale, + /// Data type of the quant offsets + typename ElementQOffset> +struct QuantBLayoutDebug{ + static constexpr bool debug_smem = true; + static constexpr bool debug_fragment = true; + ElementWeight* smem_b_ptr_; + ElementQScale* smem_qscale_ptr_; + ElementQOffset* smem_qoffset_ptr_; + int warp_id_; + int lane_id_; + int block_id_; + + template + CUTLASS_DEVICE + static void print_fragment(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + static_assert(Size % 4 == 0, "Size must be multiple of 4"); + if constexpr (debug_fragment){ + if (block_id == 1 && warp_id == 0){ + const Element* ptr = reinterpret_cast(&frag); + for (int i = 0; i < Size/4; i++, ptr+=4){ + if constexpr(std::is_integral::value){ + printf("T%.2d%c%d, %3d, %3d, %3d, %3d\n", + threadIdx.x, label, i, + ptr[0], ptr[1], ptr[2], ptr[3]); + } else { + printf("T%.2d%c%d, %.3f, %.3f, %.3f, %.3f\n", + threadIdx.x, label, i, + float(ptr[0]), float(ptr[1]), float(ptr[2]), float(ptr[3])); + } + } + } + } + } + + template + CUTLASS_DEVICE + static void print_as_int4(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + constexpr int I8Size = Size * cutlass::sizeof_bits::value / 8; + static_assert(I8Size % 2 == 0, "Size must be multiple of 4"); + if constexpr (debug_fragment){ + if (block_id == 1 && warp_id == 0){ + const uint8_t* ptr = reinterpret_cast(&frag); + for (int i = 0; i < I8Size/2; i++, ptr+=2){ + printf("T%.2dW%d, %d, %d, %d, %d\n", threadIdx.x, i, ptr[0] & 0x0f, ptr[0] >> 4, ptr[1] & 0x0f, ptr[1] >> 4); + } + } + } + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dummy type when quant offset is not used, to avoid compilation error, +/// and reduce runtime footprint +/// +struct DummyType{ + std::monostate dummy_; + public: + DummyType() = default; + + CUTLASS_HOST_DEVICE + void* data() const { + return nullptr; + } + + CUTLASS_HOST_DEVICE + std::monostate& operator[](int /*idx */) { + return dummy_; + } +}; + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class QuantBMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + static constexpr bool kHasQOffset = !std::is_same::value; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the prepacked weights + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // Tensor reference to the quantization scales + using TensorRefQScale = TensorRef; + using TensorRefQOffset = TensorRef; + + // Block size of the quantization (one set of quantization parameters per block of weights) + using QuantBlocking = typename Operator::QuantBlocking; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the prepacked weights in shared memory + using ShapeB = + MatrixShape; + + /// Shape of the quantization parameter matrix in shared memory + /// Validation done in mma core class ThreadblockQShape + using ShapeQScale = + MatrixShape<(Shape::kK / QuantBlocking::kRow) * kStages, + Shape::kN / QuantBlocking::kColumn>; + + using BufTypeQOffset = std::conditional_t, + DummyType>; + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for prepacked weights + AlignedBuffer operand_B; + + /// Buffer for quantization scales + AlignedBuffer operand_QScale; + + /// Buffer for quantization offsets + BufTypeQOffset operand_QOffset; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + CUTLASS_HOST_DEVICE + static typename Operator::SmemLayoutQScale LayoutQMeta() { + return Operator::SmemLayoutQScale::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); + } + + CUTLASS_HOST_DEVICE + static typename Operator::SmemLayoutQOffset LayoutQOffset() { + return Operator::SmemLayoutQOffset::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the prepacked weights + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + + /// Returns a TensorRef to the quantization scales + CUTLASS_HOST_DEVICE + TensorRefQScale operand_QScale_ref() { + return TensorRefQScale{operand_QScale.data(), LayoutQMeta()}; + } + + CUTLASS_HOST_DEVICE + TensorRefQOffset operand_QOffset_ref() { + if constexpr (!kHasQOffset){ + return TensorRefQOffset(); + } else { + return TensorRefQOffset{operand_QOffset.data(), LayoutQOffset()}; + } + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + /// Iterator to load a warp-scoped tile of quant scales from shared memory + typename Operator::IteratorQMeta warp_tile_iterator_QScale_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + QuantBMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), + warp_tile_iterator_QScale_(shared_storage.operand_QScale_ref(), + shared_storage.operand_QOffset_ref(), lane_idx) + {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over tiles of quant scales in global memory + typename IteratorQScale_, + /// Iterators over tiles of quant scales in shared memory + typename SmemIteratorQScale_, + /// Cache operation for quant scales + cutlass::arch::CacheOperation::Kind CacheOpQScale, + /// Iterators over tiles of quant scales in global memory + typename IteratorQOffset_, + /// Iterators over tiles of quant scales in shared memory + typename SmemIteratorQOffset_, + /// Cache operation for quant scales + cutlass::arch::CacheOperation::Kind CacheOpQOffset, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class QuantBMmaMultistage : + public QuantBMmaBase { +public: + ///< Base class + using Base = QuantBMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using IteratorQScale = IteratorQScale_; + using IteratorQOffset = IteratorQOffset_; + using SmemIteratorQScale = SmemIteratorQScale_; + using SmemIteratorQOffset = SmemIteratorQOffset_; + using QuantBlocking = typename Base::QuantBlocking; + + static cutlass::arch::CacheOperation::Kind const kCacheOpQScale = CacheOpQScale; + static cutlass::arch::CacheOperation::Kind const kCacheOpQOffset = CacheOpQOffset; + static constexpr bool kHasQOffset = Base::kHasQOffset; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of packed weights + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const AsyncCopyIterationsPerStageQScale = + IteratorQScale::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of quant scale + static int const kAccessesPerGroupQScale = + (AsyncCopyIterationsPerStageQScale + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const AsyncCopyIterationsPerStageQOffset = + IteratorQOffset::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of quant offset + static int const kAccessesPerGroupQOffset = + (AsyncCopyIterationsPerStageQOffset + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical + // accuracy, where each mainloop iteration first accumulates into a temporary + // set of freshly-cleared accumulators, which are subsequently added to the + // final accumulator set. + static bool const kStagedAccumulation = arch::UseStagedAccumulation::value; + }; + + private: + + + // Structure encapsulating pipeline state live from one iteration to the next + struct PipeState { + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + /// Temporary accumulator to facilitate staged-accumulation + FragmentC tmp_accum_; + + /// Pair of A fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentA warp_loaded_frag_A_[2]; + + /// Pair of B fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentB warp_loaded_frag_B_; + WarpTransformedFragmentB warp_transformed_frag_B_[2]; + + using WarpLoadedFragmentQScale = typename Operator::FragmentQScale; + WarpLoadedFragmentQScale warp_loaded_frag_QScale_; + + using WarpLoadedFragmentQOffset = typename std::conditional::type; + WarpLoadedFragmentQOffset warp_loaded_frag_QOffset_; + }; + + + private: + + // + // Data members + // + + /// Warp-level MMA operator + Operator warp_mma_; + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of quant meta data to shared memory + SmemIteratorQScale smem_iterator_QScale_; + SmemIteratorQOffset smem_iterator_QOffset_; + + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + /// very small meta data tensor require less threads to load + bool const should_load_qscale_; + bool const should_load_qoffset_; + + /// Shared memory pointers for debug dumping + static constexpr bool debug_layout = false; + using LayoutDebugType = typename std::conditional, + std::monostate>::type; + LayoutDebugType layout_debug_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + QuantBMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_QScale_(shared_storage.operand_QScale_ref(), thread_idx), + smem_iterator_QOffset_(shared_storage.operand_QOffset_ref(), thread_idx), + should_load_qscale_(thread_idx < IteratorQScale::ThreadMap::kThreads), + should_load_qoffset_(thread_idx >= IteratorQOffset::kThreadblockSize - IteratorQOffset::ThreadMap::kThreads), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + if constexpr(debug_layout){ + layout_debug_.smem_b_ptr_ = shared_storage.operand_B_ref().data(); + layout_debug_.smem_qscale_ptr_ = shared_storage.operand_QScale_ref().data(); + if constexpr(kHasQOffset){ + layout_debug_.smem_qoffset_ptr_ = shared_storage.operand_QOffset_ref().data(); + } else { + layout_debug_.smem_qoffset_ptr_ = nullptr; + } + layout_debug_.warp_id_ = warp_idx; + layout_debug_.lane_id_ = lane_idx; + layout_debug_.block_id_ = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + } + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_QScale_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Advance shared memory read-iterators to the next stage + CUTLASS_DEVICE + void advance_smem_read_stage() + { + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + this->warp_tile_iterator_QScale_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + + smem_read_stage_idx_ = 0; + } + } + + /// Advance global memory read-iterators and shared memory write-iterators to the stage + CUTLASS_DEVICE + void advance_smem_write_stage( + IteratorA &iterator_A, + IteratorB &iterator_B, + IteratorQScale &iterator_QScale, + IteratorQOffset &iterator_QOffset) + { + // Advance global iterators + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + iterator_QScale.add_tile_offset({1, 0}); + + // Advance shared iterators + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + smem_iterator_QScale_.add_tile_offset({1, 0}); + + if constexpr (kHasQOffset) { + iterator_QOffset.add_tile_offset({1, 0}); + smem_iterator_QOffset_.add_tile_offset({1, 0}); + } + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_iterator_QScale_.add_tile_offset({-Base::kStages, 0}); + if constexpr (kHasQOffset) { + smem_iterator_QOffset_.add_tile_offset({-Base::kStages, 0}); + } + smem_write_stage_idx_ = 0; + } + } + + CUTLASS_DEVICE + void copy_qscale_tiles(IteratorQScale &iterator_QScale){ + // Quant scale matrix is 1/block_size of the B matrix, for a 64x64 warp tile, + // it's only 64x64/block_size elements. For blocking size 16 ~ 64, it only + // takes 4 ~ 16 cp.async instructions to load. One warp has 32 threads, so + // it should be loaded in less than one cp.async instruction per thread. + // Even less for quant offset matrix. + static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, + "Quant scale should be loaded in one shot!"); + static_assert(IteratorQScale::kAccessesPerVector == 1, + "Quant scale should 1 access per vector!"); + + // Async Copy for quantization scale + typename IteratorQScale::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QScale_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQScale::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QScale.get(), iterator_QScale.valid()); + } + + CUTLASS_DEVICE + void copy_qoffset_tiles(IteratorQOffset & iterator_QOffset) { + static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, + "Quant offset should be loaded in one shot!"); + static_assert(IteratorQOffset::kAccessesPerVector == 1, + "Quant offset should 1 access per vector!"); + + if constexpr(kHasQOffset) { + // Async Copy for quantization offset + typename IteratorQOffset::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QOffset_.get()); + + constexpr int kSrcBytes = sizeof_bits::value * + IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, + int group_start = 0) { + auto group_start_A = group_start * Detail::kAccessesPerGroupA; + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + auto group_start_B = group_start * Detail::kAccessesPerGroupB; + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching + /// the global fragments needed by the first kStages-1 threadblock mainloop iterations + CUTLASS_DEVICE + void prologue( + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Async Copy for quantization scale + static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, "Quant scale should be loaded in one shot!"); + static_assert(IteratorQScale::kAccessesPerVector == 1, "Quant scale should 1 access per vector!"); + + typename IteratorQScale::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QScale_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQScale::ThreadMap::kElementsPerAccess / 8; + + auto gmem_ptr = iterator_QScale.get(); + + cutlass::arch::cp_async( + dst_ptr, gmem_ptr, iterator_QScale.valid()); + + if constexpr (kHasQOffset) { + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + + // Async Copy for quantization offset + static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, "Quant offset should be loaded in one shot!"); + static_assert(IteratorQOffset::kAccessesPerVector == 1, "Quant offset should 1 access per vector!"); + typename IteratorQOffset::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QOffset_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + } + + // Move to the next write stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + + /// Wait until we have at least one completed global fetch stage + CUTLASS_DEVICE + void gmem_wait() + { + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + if constexpr(debug_layout) { + if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1) { + if (threadIdx.x == 0){ + printf("stage: %d\n", smem_write_stage_idx_); + } + cutlass::debug::dump_shmem(layout_debug_.smem_qscale_ptr_, Base::SharedStorage::ShapeQScale::kCount); + if constexpr(kHasQOffset){ + cutlass::debug::dump_shmem(layout_debug_.smem_qoffset_ptr_, Base::SharedStorage::ShapeQScale::kCount); + } + } + } + } + + /// Perform a threadblock mainloop iteration of matrix multiply-accumulate + CUTLASS_DEVICE + void mac_loop_iter( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Loading next warp-level tiles from shared memory. This can be skipped on the very + // last iteration where: + // (gemm_k_iterations == (1 - Base::kStages)) && (warp_mma_k == (Base::kWarpGemmIterations - 1)) + // However, evaluating this condition seems more expensive than simply loading the tiles + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // All warp-tiles issue their share of global->shared fragment copies + copy_tiles_and_advance( + iterator_A, + iterator_B, + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + if (warp_mma_k == 0) { + copy_qscale_tiles(iterator_QScale); + } + if (warp_mma_k == 1) { + copy_qoffset_tiles(iterator_QOffset); + } + + // The second-to-last warp-tile also moves to the next global fetch stage + if (warp_mma_k == Base::kWarpGemmIterations - 2) { + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + } + + } + } + + /// Specialized mainloop iteration of matrix multiply-accumulate, for small M + CUTLASS_DEVICE + void mac_loop_iter_small_m( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // In the case of small M, memory latency dominates. We try to move uses far + // from their definitions to hide latency. + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + + // Loading next warp-level tiles from shared memory. + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // All warp-tiles issue their share of global->shared fragment copies + copy_tiles_and_advance( + iterator_A, + iterator_B, + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + // The second-to-last warp-tile also moves to the next global fetch stage + if (warp_mma_k == Base::kWarpGemmIterations - 2) { + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + copy_qscale_tiles(iterator_QScale); + copy_qoffset_tiles(iterator_QOffset); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + } + + } + } + + + /// Perform the specified number of threadblock mainloop iterations of matrix + /// multiply-accumulate. Assumes prologue has been initiated. + CUTLASS_DEVICE + void gemm_iters( + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over QScale operand in global memory + IteratorQOffset &iterator_QOffset) ///< [in|out] iterator over QOffset operand in global memory + { + PipeState pipe_state; + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset) { + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + // Load first warp-tile's B fragment from shared memory + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + // Load first warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); + ++this->warp_tile_iterator_A_; + + copy_tiles_and_advance(iterator_A, iterator_B, 0); + + if constexpr(Shape::kM > 32) { + // the case of bigger m + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, 0); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[0], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[0], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } else { + // the case of small m + copy_qscale_tiles(iterator_QScale); + copy_qoffset_tiles(iterator_QOffset); + } + + if (Detail::kStagedAccumulation) { + pipe_state.tmp_accum_.clear(); + } + + // Mainloop + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + if constexpr(Shape::kM > 32) { + mac_loop_iter( + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); + } else { + mac_loop_iter_small_m( + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); + } + } + + if (Detail::kStagedAccumulation) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + } + + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } + + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over quant scales in global memory + IteratorQScale iterator_QScale, + ///< Iterator over quant offsets in global memory + IteratorQOffset iterator_QOffset, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // Prologue (start fetching iterations of global fragments into shared memory) + prologue(iterator_A, iterator_B, iterator_QScale, iterator_QOffset, gemm_k_iterations); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Initialize destination accumulators with source accumulators + accum = src_accum; + + // Perform the MAC-iterations + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h new file mode 100644 index 000000000000..2c49888c9450 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h @@ -0,0 +1,112 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma_tensor_op.h + * @brief Modified from cutlass/gemm/warp/default_mma_tensor_op.h + * Default warp-level GEMM operators selected by data type, size, and layouts of operands. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Data type of quant scales + typename ElementQScale, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale, + /// Data type of quant offsets + typename ElementQOffset, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset, + /// Blocking size of quantization + typename QuantBlocking, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Operator describing the tensor operation + typename Operator_ = arch::OpMultiplyAdd, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false> +struct DefaultQuantBMmaTensorOp { + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::QuantBMmaTensorOp< + WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, QuantBlocking, ElementC, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h new file mode 100644 index 000000000000..26239161cf8a --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -0,0 +1,882 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file quantb_meta_mma_tensor_op_tile_iterator.h + * @brief Templates for loading quantization meta data for operand B + * from shared memory to fragments. This is meant to be used in + * lock step with the operand B tile iterator. Containing logic + * to figure out the operand B layout in the tensor core, + * and deliver each meta data element to its corresponding + * operand B element for dequantization. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/fast_math.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace{ + +struct b32_pair{ + uint32_t a; + uint32_t b; +}; + +struct fp16_quad{ + cutlass::half_t a; + cutlass::half_t b; + cutlass::half_t c; + cutlass::half_t d; +}; + +struct b16_quad{ + int16_t a; + int16_t b; + int16_t c; + int16_t d; +}; + +union b64 { + uint64_t single; + b32_pair pair; + b16_quad quard; + fp16_quad fp16_quad; +}; + +static_assert(sizeof(b64) == 8, "b64 should be 64 bits"); + +/// Convert packed 4b weights into fp16(weight + 16) +/// Current bit hacking only supports fp16, need to add bf16 later. +/// +template +CUTLASS_DEVICE +void weights2Half(cutlass::Array const &weights, + cutlass::Array& dest) +{ + static_assert(Size % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); + uint32_t* dest_pair = reinterpret_cast(dest.data()); + const uint32_t* w_oct = reinterpret_cast(weights.data()); + + CUTLASS_PRAGMA_UNROLL + for (int oct_idx = 0; oct_idx < Size/8; oct_idx++, w_oct++, dest_pair += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + // static_cast(16 + weight) + // 4b weights are prepacked into [0, 2, 4, 6, 1, 3, 5, 7], so that adjacent weights + // are in different 16b half words, making it easier to convert to fp16. + asm volatile( + "{\n\t" + " shl.b32 %0, %4, 6;\n" + " shl.b32 %1, %4, 2;\n" + " shr.u32 %2, %4, 2;\n" + " shr.u32 %3, %4, 6;\n" + " lop3.b32 %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0x4c00 + " lop3.b32 %1, %1, 0x03c003c0, 0x4c004c00, 0xea;\n" + " lop3.b32 %2, %2, 0x03c003c0, 0x4c004c00, 0xea;\n" + " lop3.b32 %3, %3, 0x03c003c0, 0x4c004c00, 0xea;\n" + "}\n" + : "=r"(dest_pair[0]), "=r"(dest_pair[1]), + "=r"(dest_pair[2]), "=r"(dest_pair[3]) + : "r"(*w_oct)); +#else + assert(0); +#endif + } + +} + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +// Traits to describe the layout of quantization meta data layout in a MMA fragment +// Since operand B is quantized on a per block basis, it's one meta data per block. + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTile{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ArchMmaOperator = ArchMmaOperator_; + + static_assert(Threads == 32, "This iterator should work in a warp only."); + + /// Shape of the curresponding operand B tile iterator + using TileShapeB = MatrixShape; + + // Tensor core operand B layout is a column major 4x8 tile, divided + // into 32 threads (T0 ~ T31) as shown below. Each element of the tile is 32b, + // so for fp16 it becomes 8 x 8, and int8 it becomes 16 x 8. + // T0 | T4 | T8 | T12 | T16 | T20 | T24 | T28 + // T1 | T5 | T9 | T13 | T17 | T21 | T25 | T29 + // T2 | T6 | T10 | T14 | T18 | T22 | T26 | T30 + // T3 | T7 | T11 | T15 | T19 | T23 | T27 | T31 + using CoreTile = layout::PitchLinearShape<4, 8>; + + /// Each thread holds a 32b fragment per tile: for half precision, it's 2 elements, 4 elements for int8 + static int const kNumBsPerCoreTileFragement = 32 / sizeof_bits::value; + + /// Each mma instruction can process either 1 or 2 tensor core operand B tiles (stacked on the k dimension) + static int const kBTilesPerMma = + sizeof_bits::value * ArchMmaOperator::FragmentB::kElements / 32; + static_assert(kBTilesPerMma == 1 || kBTilesPerMma == 2, "Only support 1 or 2 operand B tiles per mma."); + + /// Each operand B tile iterator load covers a number of mma instructions + static int const kMmaIterationsB = WarpShapeB::kColumn / ArchMmaOperator::Shape::kN; + + /// Number of B elements a fragment of meta data should cover + static int const kExpandedSize = kNumBsPerCoreTileFragement * kBTilesPerMma * kMmaIterationsB; + + // Now we figure out how many meta data elements to load for each TileShapeB + + /// Number of meta elements per CoreTile. + static int const kCoreTileFragementSize = (kNumBsPerCoreTileFragement + BlockingShape::kRow - 1) / BlockingShape::kRow; + + /// Number of core tiles per mma instruction, different from kBTilesPerMma when blocking size on K dimension + /// exceeds the tile depth, so two tiles share the same meta data + static int const kTilesPerMma = ((kBTilesPerMma == 2) && + (BlockingShape::kRow <= kNumBsPerCoreTileFragement * CoreTile::kContiguous)) + ? 2 : 1; + + /// stride to reach the meta data for the next CoreTile on the K dimension + static int const kKTileStride = (kNumBsPerCoreTileFragement * CoreTile::kContiguous + BlockingShape::kRow - 1) / BlockingShape::kRow; + + /// Stride on N dimension should be the tile width, shrunk by blocking size on this dimension. + static int const kNStride = (CoreTile::kStrided + BlockingShape::kColumn - 1) / BlockingShape::kColumn; + + /// On N dimension, how many tiles share the same meta data + static int const kNRepeats = (BlockingShape::kColumn + CoreTile::kStrided - 1) / CoreTile::kStrided; + + /// Each fragment should cover kMmaIterationsB number of mma intructions on the N dimension. + /// When blocking size on this dimension exceeds the tile width, multiple iterations + /// would share the same data. + static int const kMmaIterations = (kMmaIterationsB + kNRepeats - 1) / kNRepeats; + + static int const kFragementSize = kCoreTileFragementSize * kTilesPerMma * kMmaIterations; + + CUTLASS_DEVICE + static MatrixCoord lane_position(int lane_id) { + if constexpr(kNumBsPerCoreTileFragement == 2 + && kBTilesPerMma == 2 + && BlockingShape::kRow == 1){ + // Optimize for a special case of: + // 16b gemm (kNumBsPerCoreTileFragement == 2) + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking + // The scale and offset tensors are prepacked to reduce the number of load instructions. + return make_Coord((lane_id % CoreTile::kContiguous) * 4, + lane_id / CoreTile::kContiguous); + } else { + return make_Coord((lane_id % CoreTile::kContiguous) * kNumBsPerCoreTileFragement, + lane_id / CoreTile::kContiguous); + } + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is to load quantization meta data for operand B from +/// shared memory to fragments (hopefully allocated to registers by compilers). +/// Examples of meta data include scale or offsets. The operand B matrix is +/// quantized on a per block basis, meaning one element of meta data per block. +/// +/// This is meant to be used in lock step with the operand B tile iterator. +/// So all parameters are logical positions in the operand B tiles. +/// The goal here is to deliver each meta data element to its corresponding +/// operand B element for dequantization. As a result, we need to figure +/// out the operand B layout in the tensor core. +/// +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the quant scales + typename ElementScale_, + /// Layout of the quant scales + typename LayoutScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Layout of quant offsets + typename LayoutOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads, + /// Number of partitions along K dimension + int PartitionsK_ = 1> +class QuantBMetaMmaTensorOpTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column major layout + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTensorOpTileIterator{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ElementScale = ElementScale_; + using Layout = cutlass::layout::ColumnMajor; + using ElementOffset = ElementOffset_; + using ArchMmaOperator = ArchMmaOperator_; + + static constexpr bool kHasOffset = !(std::is_same::value); + + static_assert(BlockingShape::kRow == 1 && BlockingShape::kColumn > 1, + "Only support row blocking for column major layout"); + + using MetaTile = QuantBMetaMmaTile; + + /// Number of MMA instructions for this tile + static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB; + + /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8 + static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement; + + /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension) + static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma; + + /// Number of B elements a fragment of meta data should cover + static constexpr int kExpandedSize = MetaTile::kExpandedSize; + + /// Number of meta elements per core tile fragment + static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize; + + /// stride for reaching the next core tile (if there is one) on the K dimension + static constexpr int kKTileStride = MetaTile::kKTileStride; + + /// do we need to load meta data for the next core tile on the K dimension? + static constexpr int kTilesPerMma = MetaTile::kTilesPerMma; + + static constexpr int kNStride = MetaTile::kNStride; + static constexpr int kNRepeats = MetaTile::kNRepeats; + static constexpr int kMmaIterations = MetaTile::kMmaIterations; + + using TensorRefScale = TensorRef; + using TensorRefOffset = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using FragmentScale = Array; + using FragmentOffset = typename std::conditional, + std::monostate>::type; + + using AccessTypeScale = Array; + using AccessTypeOffset = Array; + +private: + + ElementScale *pointer_; + Layout layout_; + + ElementOffset *pointer_offset_; + Layout layout_offset_; + + TensorCoord lane_position_; + +public: + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator() { } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator( + TensorRefScale const &ref, + TensorRefOffset const &ref_offset, + int lane_idx + ): + pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)){} + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(FragmentScale &frag, FragmentOffset &frag_offset) { + if constexpr(kNumBsPerCoreTileFragement == 2 + && kBTilesPerMma == 2){ + // Optimize for a special case of: + // 16b gemm (kNumBsPerCoreTileFragement == 2) + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking (BlockingShape::kRow == 1) + // The scale and offset tensors are prepacked to reduce the number of load instructions needed + const int row = lane_position_.row(); + const int column = lane_position_.column() / BlockingShape::kColumn; + + Array *dst_ptr = reinterpret_cast*>(frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + Array *src_ptr = reinterpret_cast*>(pointer_ + layout_({row, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + + if constexpr(kHasOffset){ + Array *dst_ptr_offset = reinterpret_cast*>(frag_offset.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + Array *src_ptr_offset = reinterpret_cast*>(pointer_offset_ + layout_offset_({row, c})); + *dst_ptr_offset = *src_ptr_offset; + dst_ptr_offset++; + } + } + + } else { + // Other cases, offsets and scales are not prepacked. + + const int row = lane_position_.row() / BlockingShape::kRow; + const int column = lane_position_.column() / BlockingShape::kColumn; + + AccessTypeScale* dst_ptr = reinterpret_cast(frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + AccessTypeScale* src_ptr = reinterpret_cast(pointer_ + layout_({r, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + } + + if constexpr(kHasOffset){ + AccessTypeOffset* dst_ptr = reinterpret_cast(frag_offset.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + AccessTypeOffset* src_ptr = reinterpret_cast(pointer_offset_ + layout_offset_({r, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + } + } + } + } + + template + CUTLASS_HOST_DEVICE + static Array debug_expand(Array const &frag){ + Array ret; + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; + ret[out_idx] = frag[idx]; + out_idx++; + } + } + } + return ret; + } + + CUTLASS_HOST_DEVICE + static void dequant(FragmentScale const &scales, + FragmentOffset const &fragment_offsets, + Array const &weights, + Array& dest){ + static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm."); + static_assert(kExpandedSize % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); + + // First convert 4b weight into fp16(weight + 16) + weights2Half(weights, dest); + + if constexpr(kBTilesPerMma == 2){ + // Optimize for a special case of: + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking (BlockingShape::kRow == 1) + + uint32_t* dest_pair = reinterpret_cast(dest.data()); + const b64* scales_ptr = reinterpret_cast(scales.data()); + [[maybe_unused]] const ElementOffset* fragment_offsets_ptr = nullptr; + if constexpr(kHasOffset) { fragment_offsets_ptr = fragment_offsets.data(); } + + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + // dequantize: d = scale * (weight - offset) + // to use FMA, d = scale * weight + (scale * (-offset)) + + [[maybe_unused]] b64 offsets{0}; + if constexpr(kHasOffset) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + const uint32_t* p = reinterpret_cast(fragment_offsets_ptr); + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, %4, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, %4, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(offsets.pair.a), "=r"(offsets.pair.b) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(p[0])); +#else + assert(0); +#endif + + fragment_offsets_ptr += 4; + } else { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(offsets.pair.a), "=r"(offsets.pair.b) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); +#else + offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast(-16-8); + offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast(-16-8); + offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast(-16-8); + offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast(-16-8); +#endif + } + + CUTLASS_PRAGMA_UNROLL + for (int n_r = 0; n_r < kNRepeats; n_r++){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " fma.rn.f16x2 %0, %2, %0, %4;\n" // dest = scale * (16 + weight) + (scale * (-16 - offset)) + " fma.rn.f16x2 %1, %3, %1, %5;\n" + "}\n" + : "+r"(dest_pair[0]), "+r"(dest_pair[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(offsets.pair.a), "r"(offsets.pair.b)); +#else + assert(0); +#endif + dest_pair += 2; + } + scales_ptr++; + } + + } else { + // unoptiomized path for other cases, very slow + int out_idx = 0; + ElementScale offset; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; + ElementScale s = scales[idx]; + if constexpr(kHasOffset){ + offset = s * static_cast(-16 - static_cast(fragment_offsets[idx])); + } else { + offset = s * static_cast(-16-8); + } + dest[out_idx] = s * dest[out_idx] + offset; + out_idx++; + } + } + } + + } + + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator++() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; + int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn; + lane_position_ += TensorCoord(rows, columns); + return *this; + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row major layout + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTensorOpTileIterator{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ElementScale = ElementScale_; + using ElementOffset = ElementOffset_; + using Layout = cutlass::layout::RowMajor; + using ArchMmaOperator = ArchMmaOperator_; + + static constexpr bool kHasOffset = !(std::is_same::value); + + static_assert(BlockingShape::kColumn == 1 && BlockingShape::kRow > 1, + "Only support column blocking for row major layout"); + + using MetaTile = QuantBMetaMmaTile; + + /// Number of MMA instructions for this tile + static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB; + + /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8 + static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement; + + /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension) + static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma; + + /// Number of B elements a fragment of meta data should cover + static constexpr int kExpandedSize = MetaTile::kExpandedSize; + + /// Number of meta elements per core tile fragment + static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize; + + /// stride for reaching the next core tile (if there is one) on the K dimension + static constexpr int kKTileStride = MetaTile::kKTileStride; + + /// do we need to load meta data for the next core tile on the K dimension? + static constexpr int kTilesPerMma = MetaTile::kTilesPerMma; + + static constexpr int kNStride = MetaTile::kNStride; + static constexpr int kNRepeats = MetaTile::kNRepeats; + static constexpr int kMmaIterations = MetaTile::kMmaIterations; + + using TensorRefScale = TensorRef; + using TensorRefOffset = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using FragmentScale = Array; + using FragmentOffset = typename std::conditional, + std::monostate>::type; + +private: + + ElementScale *pointer_; + Layout layout_; + + ElementOffset *pointer_offset_; + Layout layout_offset_; + + TensorCoord lane_position_; + +public: + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator() { } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator( + TensorRefScale const &ref, + TensorRefOffset const &ref_offset, + int lane_idx + ): + pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)) + {} + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(FragmentScale &frag, FragmentOffset &frag_offset) { + const int row = lane_position_.row() / BlockingShape::kRow; + const int column = lane_position_.column() / BlockingShape::kColumn; + static_assert(kTilesPerMma * kCoreTileFragementSize == 1, "Only support one meta data per core tile"); + + ElementScale* src_ptr = pointer_ + layout_({row, column}); + ElementScale* dst_ptr = frag.data(); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + dst_ptr[n_idx] = src_ptr[n_idx * kNStride]; + } + + if constexpr(kHasOffset){ + ElementOffset* src_ptr_offset = pointer_offset_ + layout_offset_({row, column}); + ElementOffset* dst_ptr_offset = frag_offset.data(); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + dst_ptr_offset[n_idx] = src_ptr_offset[n_idx * kNStride]; + } + } + } + + template + CUTLASS_HOST_DEVICE + static Array debug_expand(Array const &frag){ + Array ret; + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int col = elem_idx + mma_tile_idx * kCoreTileFragementSize; + int idx = col * kMmaIterations + n_idx; + ret[out_idx] = frag[idx]; + out_idx++; + } + } + } + return ret; + } + + CUTLASS_HOST_DEVICE + static void dequant(FragmentScale const &scales, + FragmentOffset const &offsets, + Array const &weights, + Array& dest){ + static_assert(kNRepeats == 1, "This is implied by BlockingShape::kColumn == 1"); + static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm now."); + + // First convert 4b weight into fp16(weight + 16) + weights2Half(weights, dest); + + ElementScale addon[kMmaIterationsB]; + if constexpr (kMmaIterationsB % 4 == 0) { + const b64* scales_ptr = reinterpret_cast(scales.data()); + uint32_t* addon_ptr = reinterpret_cast(addon); + if constexpr(kHasOffset){ + const uint32_t* p = reinterpret_cast(offsets.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [d, c, b, a] --> [d, b, c, a] + " prmt.b32 rb2, %4, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, rb2, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(p[0])); +#else + assert(0); +#endif + scales_ptr++; + p++; + addon_ptr += 2; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); +#else + assert(0); +#endif + scales_ptr++; + addon_ptr += 2; + } + } + } else if constexpr (kMmaIterationsB % 2 == 0) { + if constexpr (kHasOffset){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + const uint32_t* scales_ptr = reinterpret_cast(scales.data()); + uint32_t* addon_ptr = reinterpret_cast(addon); + // possible buffer over read 2 bytes here. + const uint32_t* p = reinterpret_cast(offsets.data()); + + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [?, ?, b, a] --> [?, b, ?, a] + " prmt.b32 rb2, %2, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - offset) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0]) + "r"(p[0])); +#else + assert(0); +#endif + } else { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - 8) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0])); +#else + assert(0); +#endif + } + } else { + // kMmaIterationsB == 1 + if constexpr(kHasOffset){ + uint8_t zp = offsets[0]; + addon[0] = scales[0] * static_cast(-16 - static_cast(zp)); + } else { + addon[0] = scales[0] * static_cast(-16-8); + } + } + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + dest[out_idx] = scales[n_out] * dest[out_idx] + addon[n_out]; + dest[out_idx + 1] = scales[n_out] * dest[out_idx + 1] + addon[n_out]; + out_idx += 2; + } + } + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator++() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; + int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn; + lane_position_ += TensorCoord(rows, columns); + return *this; + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// +} // namespace warp +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h new file mode 100644 index 000000000000..f29cedf326a4 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h @@ -0,0 +1,361 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_mma_tensor_op.h + * @brief Modified from cutlass/gemm/warp/mma_tensor_op.h + * Templates implementing warp-level matrix multiply-accumulate operations + * targeting tensor cores. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Data type of quant scales + typename ElementQScale_, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale_, + /// Data type of quant offsets + typename ElementQOffset_, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset_, + /// Blocking dimensions of quantization + typename QuantBlocking_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool +> +class QuantBMmaTensorOp { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + +public: + + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kA, ElementA, LayoutA, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kB, ElementB, LayoutB, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + // warp B MatrixShape<64, 64>, + // layout B cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<16, 64>, + // instruction op shape cutlass::MatrixShape<16, 8>, + // kPartitionsK 1 + // FragmentB::kElements 32 + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; // cutlass::Array + + /// Storage for transformed B tile + /// When loading weights, we packed 4 int4 weights into one 2-byte-element, when expanded + /// we multiply the number of elements by 4. + /// TODO: make sure ArchMmaOperator::ElementB same as dequantized ElementB + /// and change the transform function below to perform dequantization + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator< + MatrixShape, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + using ElementQScale = ElementQScale_; + using SmemLayoutQScale = SmemLayoutQScale_; + using QuantBlocking = QuantBlocking_; + + using ElementQOffset = ElementQOffset_; + using SmemLayoutQOffset = SmemLayoutQOffset_; + + /// Iterates over the quantization parameters in memory + using WarpQScaleShape = MatrixShape<(Shape::kK / QuantBlocking::kRow), (Shape::kN / QuantBlocking::kColumn)>; + static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow"); + static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn"); + static_assert(WarpQScaleShape::kCount > 0, "QuantBlocking too big to fit in a warp block!"); + + // TODO This is an expanding iterator, it needs to replicate the quantization parameters + // to all threads in the warp. + using IteratorQMeta = QuantBMetaMmaTensorOpTileIterator< + MatrixShape, QuantBlocking, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, + ArchMmaOperator, kThreadCount, kPartitionsK>; + + using FragmentQScale = typename IteratorQMeta::FragmentScale; + using FragmentQOffset = typename IteratorQMeta::FragmentOffset; + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN + >; + +public: + + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + QuantBMmaTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + TransformedFragmentA const &A, + TransformedFragmentB const &B, + FragmentC const &C + ) const { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + D = C; + + MmaOperandA const *ptr_A = reinterpret_cast(&A); + MmaOperandB const *ptr_B = reinterpret_cast(&B); + MmaOperandC *ptr_D = reinterpret_cast(&D); + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + // The visitation order is like + // _ + // | | | | + // | | | | + // |_| |_| + // + // Down Up Down Up + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma( + ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } + #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + // The visitation order is like + // _________ + // _________| + // |_________ + // __________| + // + // Right Left Right Left + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } + #else + assert(0); + #endif + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentB &dst_B, + FragmentB const &B, + FragmentQScale const &scales, + FragmentQOffset const &offsets) const { + + Array const *ptr_B = + reinterpret_cast const *>(&B); + IteratorQMeta::dequant(scales, offsets, *ptr_B, dst_B); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index bdd4dba521eb..ce7838556fbf 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1614,6 +1614,119 @@ MlasHalfGemmConvertPackB( void* PackedB ); +#if defined(__aarch64__) && defined(__linux__) +/** + * @brief Whether current CPU supports Bfloat16(bf16) acceleration. + */ +bool MLASCALL +MlasBf16AccelerationSupported(); + +/** + * @brief Interface for bf16 gemm post processors. + * + * Example implementation of this interface includes activations, + * conversion from single precision to precision, etc. + * + * SBGEMM is computed tile by tile. When a tile of result matrix + * is produced, the method Process() is called to process this tile. + * Parameters of this method describe the location and shape of the + * tile. + */ +class MLAS_SBGEMM_POSTPROCESSOR +{ + public: + virtual void Process(float*, /**< the address of matrix to process */ + size_t, /**< the start row index of matrix */ + size_t, /**< the start col index of matrix */ + size_t, /**< the element count per row to process */ + size_t, /**< the element count per col to process */ + size_t /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_SBGEMM_POSTPROCESSOR() {} +}; + +/** + * @brief bfloat16 precision activation functions, with optional sum tensor. + * Supplied sum tensor must be the same layout as the GEMM output tensor. + * And the supplied sum tensor will be added to the tensor before activation. + */ +class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR +{ + public: + MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr) + : Activation_(Activation), SumBuf_(SumBuf) + { + } + + void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc) + const override; + + private: + const MLAS_ACTIVATION& Activation_; + const float* SumBuf_; +}; + +/** + * @brief Data parameters for bfloat16 precision GEMM routine + * All except C are [in] parameters + */ +struct MLAS_SBGEMM_DATA_PARAMS { + const void* A = nullptr; /**< address of A */ + const void* B = nullptr; /**< address of B */ + const float* Bias = nullptr; /**< address of Bias, vector size N */ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; + bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ + bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ +}; + +/** + * @brief Bfloat16 precision Batched GEMM: C = A * B + Bias + * Either B can be either fp32 or bf16 + * + * Note: We only support uniform batching, so shapes and types of the + * input must be same across all parameter blocks. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return + */ +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr); + +/** + * @brief For bfloat16 precision GEMM, returns size of the + * packing buffer needed for right hand side + * @param[in] N Number of columns + * @param[in] K Number of rows + * @return size of the packing buffer, + * 0 if operation not supported + */ +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K); + +/** + * @brief For bfloat16 precision GEMM, convert the float matrix B + * to blfoat16 precision and pack it into a packing buffer + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix + */ +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); +#endif + /** * @brief Indirect Depthwise convolution for fp16 * @param Input Supplies the indirect buffer for NHWC input diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 1e83dd1cec40..32e9cc98106d 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -23,19 +23,34 @@ Module Name: #include "mlas.h" #include "mlas_gemm_postprocessor.h" +/** + * @brief Define compute types of block quantization, in order of decreasing accuracy. + */ +typedef enum { + CompUndef = 0, /*!< undef */ + CompFp32, /*!< input fp32, accumulator fp32 */ + CompFp16, /*!< input fp16, accumulator fp16 */ + CompBf16, /*!< input bf16, accumulator fp32 */ + CompInt8, /*!< input int8, accumulator int32 */ + + // special values that should be the first and last actual values + + CompMostAccurate = CompUndef, + CompLeastAccurate = CompInt8, +} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; + /** * @brief Data parameters for float/n-bit quantized int GEMM routine. */ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const float* A = nullptr; ///< address of A (float32 matrix) - size_t lda = 0; ///< leading dimension of A - const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) - const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block - const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - bool IsBPacked = false; ///< whether B values are packed in an optimized format for the computation - const float* Bias = nullptr; ///< optional address of Bias, vector size N - float* C = nullptr; ///< address of result matrix - size_t ldc = 0; ///< leading dimension of C + const float* A = nullptr; ///< address of A (float32 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) + const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + const float* Bias = nullptr; ///< optional address of Bias, vector size N + float* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; @@ -46,13 +61,26 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { * A must be a float32 matrix * B must be a quantized and packed n-bit int matrix * + * Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called. + * + * Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether + * MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with + * MlasSQNBitGemmPackQuantBData(). + * + * Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should + * point to an intermediate workspace buffer. + * * @param[in] M row size of matrix A and C * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B * @param[in] BatchN number of batches * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] Workspace Address of intermediate workspace buffer. + If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a + buffer with at least that many bytes. Otherwise, it may be nullptr. * @param[in] ThreadPool optional thread pool to use */ void MLASCALL @@ -63,158 +91,91 @@ MlasSQNBitGemmBatch( size_t BatchN, size_t BlkBitWidth, size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, MLAS_THREADPOOL* ThreadPool = nullptr ); /** * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. + * * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ bool MLASCALL MlasIsSQNBitGemmAvailable( size_t BlkBitWidth, - size_t BlkLen + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** - * @brief Define compute types of block quantization - */ -typedef enum { - CompUndef = 0, /*!< undef */ - CompFp32 = 1, /*!< input fp32, accumulator fp32 */ - CompFp16 = 2, /*!< input fp16, accumulator fp16 */ - CompBf16 = 3, /*!< input bf16, accumulator fp32 */ - CompInt8 = 4 /*!< input int8, accumulator int32 */ -} MLAS_SQNBIT_COMPUTE_TYPE; - -/** - * @brief Data parameters for NBits GEMM routine - * C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * All except C are [in] parameters - */ -struct MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS { - const float* A = nullptr; /**< address of A (float32 matrix)*/ - const void* B = nullptr; /**< address of B (packed nbits blob)*/ - float* C = nullptr; /**< address of result matrix */ - size_t lda = 0; /**< leading dimension of A */ - size_t ldc = 0; /**< leading dimension of C*/ -}; - -/** - * @brief Compute the byte size of the parameter combination + * @brief Gets the size in bytes of the intermediate workspace buffer required by the float32/quantized n-bit int GEMM + * implementation. If zero, no intermediate workspace is required. * - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @return size of the packing buffer, 0 if the operation is not yet supported. + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasNBitsGemmPackBSize( - size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE comp_type -); - -/** - * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. - * - * @param PackedBuf packed data buffer - * @param QData quantized data buffer - * @param Scale scale pointer - * @param Zp zero point pointer - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization (default 4) - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor - * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where - * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up - * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale - * (is_asym is false) and Zp(is_asym is true). - * @param thread_pool - */ -void MLASCALL -MlasNBitsGemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, +MlasSQNBitGemmBatchWorkspaceSize( + size_t M, size_t N, size_t K, - size_t ldb, - size_t block_size, - int nbits, - bool is_asym, - bool last_call, - MLAS_SQNBIT_COMPUTE_TYPE comp_type, - MLAS_THREADPOOL* thread_pool -); - -/** - * @brief Unpack and dequantize to fp32 - * - * @param FpData unpacked float32 data - * @param PackedBuf quantized and packed data - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param thread_pool - */ -void MLASCALL -MlasNBitsGemmUnPackB( - float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool + size_t BatchN, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** - * @brief Get the workspace size required by computation. + * @brief Gets the size in bytes of the packed quantized B data. + * If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of + * this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch(). + * If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to + * MlasSQNBitGemmBatch(). * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @return Workspace size in bytes + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasSQNBitsGemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams +MlasSQNBitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** - * @brief Batched GEMM: C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob + * @brief Packs the quantized B data in a format that the kernel expects. * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] WorkSpace temporary buffer - * @param[in] ThreadPool - * @return + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + * @param[in] QuantBData quantized B data + * @param[out] PackedQuantBData packed quantized B data + * @param[in] ThreadPool optional thread pool to use */ void MLASCALL -MlasSQNBitsGemmBatchPackedB( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - void* WorkSpace, +MlasSQNBitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const void* QuantBData, + void* PackedQuantBData, MLAS_THREADPOOL* ThreadPool = nullptr ); diff --git a/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S new file mode 100644 index 000000000000..e424c30515e9 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S @@ -0,0 +1,907 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + SbgemmKernelNeon.s + +Abstract: + + This module implements the kernels for the bfloat16 half precision matrix/matrix + multiply operation (SBGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the sbgemm kernel. d8-d15, x19-x30 need save +// + .equ .LMlasSbgemmKernel_backup_x19_x20, 0 + .equ .LMlasSbgemmKernel_backup_x21_x22, 16 + .equ .LMlasSbgemmKernel_backup_x23_x24, 32 + .equ .LMlasSbgemmKernel_backup_x25_x26, 48 + .equ .LMlasSbgemmKernel_backup_x27_x28, 64 + .equ .LMlasSbgemmKernel_backup_d8_d9, 80 + .equ .LMlasSbgemmKernel_backup_d10_d11, 96 + .equ .LMlasSbgemmKernel_backup_d12_d13, 112 + .equ .LMlasSbgemmKernel_backup_d14_d15, 128 + .equ .LMlasSbgemmKernel_SavedRegisters, 144 + .equ .LMlasSbgemmKernel_SavedRegisters_Neg, -144 + + +// +// ClearRowAccumulators +// +// Generates the code to clear the accumulators for a single row of the output +// block. +// + + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + mov v\Vec1Reg\().16b,v0.16b +.if \Columns\() > 2 + mov v\Vec2Reg\().16b,v1.16b +.endif +.if \Columns\() > 4 + mov v\Vec3Reg\().16b,v2.16b +.endif +.if \Columns\() > 6 + mov v\Vec4Reg\().16b,v3.16b +.endif + + .endm + +// +// InitBlockAccumulators +// +// Generates the code to init the accumulators for a single row of the output +// block. +// + + .macro InitBlockAccumulators Mode, Columns, Rows + + //check if the Bias != nullptr + cbz x8,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd + + ld1 {v14.4s},[x8],#16 // load Bias[0] + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast Bias + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + ld1 {v15.4s},[x8],#16 // load Bias[4] + dup v4.4s,v15.s[0] // broadcast Bias + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + b .L\Mode\().PopulateAccumulators\Columns\().x\Rows\() + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd: + eor v0.16b,v0.16b,v0.16b // No bias, reset regs + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b + +.L\Mode\().PopulateAccumulators\Columns\().x\Rows\(): + InitRowAccumulators \Columns\(),16,17,18,19 +.if \Rows\() > 2 + InitRowAccumulators \Columns\(),20,21,22,23 +.endif +.if \Rows\() > 4 + InitRowAccumulators \Columns\(),24,25,26,27 +.endif +.if \Rows\() > 6 + InitRowAccumulators \Columns\(),28,29,30,31 +.endif + + .endm + +// LoadMatrixAElementsBy8 +// +// Generates the code to load 4 or 8 elements from matrix A. +// + .macro LoadMatrixAElementsBy8 Rows + + ldr q8,[x0],#16 + bfcvtn v8.4h, v8.4s +.if \Rows\() > 1 + ldr q1,[x10],#16 + bfcvtn2 v8.8h, v1.4s +.endif + +.if \Rows\() > 2 + ldr q9,[x11],#16 + bfcvtn v9.4h, v9.4s +.endif +.if \Rows\() > 3 + ldr q1,[x12],#16 + bfcvtn2 v9.8h, v1.4s +.endif + +.if \Rows\() > 4 + ldr q10,[x20],#16 + bfcvtn v10.4h, v10.4s +.endif +.if \Rows\() > 5 + ldr q1,[x21],#16 + bfcvtn2 v10.8h, v1.4s +.endif + +.if \Rows\() > 6 + ldr q11,[x22],#16 + bfcvtn v11.4h, v11.4s +.endif +.if \Rows\() > 7 + ldr q1,[x23],#16 + bfcvtn2 v11.8h, v1.4s +.endif + + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + bfmmla v\Vec1Reg\().4s, \MatrixAReg\().8h, v4.8h +.if \Columns\() > 2 + bfmmla v\Vec2Reg\().4s, \MatrixAReg\().8h, v5.8h +.endif +.if \Columns\() > 4 + bfmmla v\Vec3Reg\().4s, \MatrixAReg\().8h, v6.8h +.endif +.if \Columns\() > 6 + bfmmla v\Vec4Reg\().4s, \MatrixAReg\().8h, v7.8h +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(),\Columns\(),\Rows\() + + add x10,x0,x6,lsl #2 // compute matrix A plus 1 row +.if \Rows\() > 2 + add x11,x10,x6,lsl #2 // compute matrix A plus 2 rows + add x12,x11,x6,lsl #2 // compute matrix A plus 3 rows +.endif +.if \Rows\() > 4 + add x20,x12,x6,lsl #2 // compute matrix A plus 4 rows + add x21,x20,x6,lsl #2 // compute matrix A plus 5 rows +.endif +.if \Rows\() > 6 + add x22,x21,x6,lsl #2 // compute matrix A plus 6 rows + add x23,x22,x6,lsl #2 // compute matrix A plus 7 rows +.endif + sub x9,x3,#4 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#4 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#4 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + fadd v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + fadd v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x26 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),8,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),4,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasSbgemmCopyPackB or MlasSbgemmTransposePackB. + + C (x2) - Supplies the address of matrix C. + + CountK (x3) - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda (x6) - Supplies the first dimension of matrix A. + + ldc (x7) - Supplies the first dimension of matrix C. + + Bias - Supplies the address of Bias Vector [1xn] + + +Return Value: + + Returns the number of rows handled. + +--*/ + .macro SbgemmKernelNeonFunction Mode + + FUNCTION_ENTRY MlasSbgemmKernel\Mode\() + + ldr x8, [sp, #0] //Bias vector + + stp x19, x20, [sp, #.LMlasSbgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + + add x13,x2,x7,lsl #2 // compute matrix C plus 1 row + add x14,x13,x7,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x7,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x7,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x7,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x7,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x7,lsl #2 // compute matrix C plus 7 rows + + mov x26,x0 // save matrix A +// +// Process 8 rows of the matrices. +// + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasSbgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + SbgemmKernelNeonFunction Zero + SbgemmKernelNeonFunction Add diff --git a/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx512F.asm b/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx512F.asm new file mode 100644 index 000000000000..3e83bc852f55 --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx512F.asm @@ -0,0 +1,103 @@ +;++ +; +;Copyright (c) Microsoft Corporation. All rights reserved. +; +;Licensed under the MIT License. +; +;Module Name: +; +; SoftmaxKernelAvx512F.asm +; +;Abstract: +; +; This module implements the kernels for the single precision softmax +; operation. +; +; This implementation uses AVX512F instructions. +; +;-- + + .xlist +INCLUDE mlasi.inc + .list + + EXTERN MlasMinimumF32Value:NEAR + +;++ +; +;Routine Description: +; +; This routine implements a vectorized kernel to find the maximum value of +; the supplied buffer. +; +;Arguments: +; +; Input (rcx) - Supplies the input buffer. +; +; N (rdx) - Supplies the number of elements to process. +; +;Return Value: +; +; Returns the maximum value of the supplied buffer. +; +;-- + + LEAF_ENTRY MlasReduceMaximumF32KernelAvx512F, _TEXT + + vbroadcastss zmm0,DWORD PTR [MlasMinimumF32Value] + test rdx,rdx + jz ExitKernel + cmp rdx,16 + jb ProcessRemainingCountBy1 + cmp rdx,64 + jb ProcessRemainingCountBy16 + vmovaps zmm1,zmm0 + vmovaps zmm2,zmm0 + vmovaps zmm3,zmm0 + +ProcessRemainingCountBy64: + vmaxps zmm0,zmm0,ZMMWORD PTR [rcx] + vmaxps zmm1,zmm1,ZMMWORD PTR [rcx+16*4] + sub rdx,64 + vmaxps zmm2,zmm2,ZMMWORD PTR [rcx+32*4] + vmaxps zmm3,zmm3,ZMMWORD PTR [rcx+48*4] + add rcx,64*4 ; advance input by 64 elements + cmp rdx,64 + jae ProcessRemainingCountBy64 + vmaxps zmm0,zmm0,zmm1 ; reduce to single vector + vmaxps zmm2,zmm2,zmm3 + vmaxps zmm0,zmm0,zmm2 + +ProcessRemainingCountBy16: + cmp rdx,16 + jb ProcessRemainingCountLessThan16 + vmaxps zmm0,zmm0,ZMMWORD PTR [rcx] + sub rdx,16 + add rcx,16*4 ; advance input by 16 elements + jmp ProcessRemainingCountBy16 + +ProcessRemainingCountLessThan16: + vextractf32x8 ymm1,zmm0,1 ; reduce to single scalar + vmaxps ymm0,ymm0,ymm1 + vextractf128 xmm1,ymm0,1 + vmaxps xmm0,xmm0,xmm1 + vshufps xmm1,xmm0,xmm0,0EEh + vmaxps xmm0,xmm0,xmm1 + vshufps xmm1,xmm0,xmm0,055h + vmaxss xmm0,xmm0,xmm1 + test rdx,rdx + jz ExitKernel + +ProcessRemainingCountBy1: + vmaxss xmm0,xmm0,DWORD PTR [rcx] + add rcx,4 ; advance input by 1 element + dec edx + jnz ProcessRemainingCountBy1 + +ExitKernel: + vzeroupper + ret + + LEAF_END MlasReduceMaximumF32KernelAvx512F, _TEXT + + END diff --git a/onnxruntime/core/mlas/lib/amx_common.h b/onnxruntime/core/mlas/lib/amx_common.h index 3eb0700932fa..caf94af02362 100644 --- a/onnxruntime/core/mlas/lib/amx_common.h +++ b/onnxruntime/core/mlas/lib/amx_common.h @@ -18,7 +18,7 @@ Module Name: #include "mlasi.h" -#ifdef WIN32 +#ifdef _WIN32 #define tile_dpbssd(dst, src1, src2) _tile_dpbssd(dst, src1, src2) #define tile_dpbsud(dst, src1, src2) _tile_dpbsud(dst, src1, src2) diff --git a/onnxruntime/core/mlas/lib/jblas_defs.h b/onnxruntime/core/mlas/lib/jblas_defs.h deleted file mode 100644 index 9cd1711a3ffd..000000000000 --- a/onnxruntime/core/mlas/lib/jblas_defs.h +++ /dev/null @@ -1,73 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - ---*/ - -#pragma once - -#include "jblas/jit_blas_prologue_b.h" -#include "jblas/jit_blas_wrapper.h" - -namespace jblas -{ - -/* -Name conversion explaination: -Fp32: comp type, determined by GemmCore, can be any jblas::gemm::SCorexxx(float GemmCore) -S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(also support other integer and float weight -classes) -F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and -jblas::epilogue::gemm::AccumulatorWriteBackFp32. - -Tips: jblas::epilogue::gemm::CompFp32BlockEpilogue is a fixed class for all fp32 accumulator GemmCores. -*/ -template -using tLauncher_Fp32_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< - GemmCore_T::ISA, - GemmCore_T, - jblas::prologue_a::gemm::ActivationKBlockBaseF32, - jblas::prologue_b::gemm::WeightKBlockS4, - jblas::epilogue::gemm::CompFp32BlockEpilogue, - jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - -/* -Name conversion explaination: -Int8: comp type, determined by GemmCore, can be any jblas::gemm::ICorexxx(integer GemmCore) -S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(support integer weight classes only) -F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and -jblas::epilogue::gemm::AccumulatorWriteBackFp32. - -Tips: jblas::epilogue::gemm::CompInt8BlockEpilogue is a fixed class for all int32 accumulator GemmCores. -*/ -template -using tLauncher_Int8_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< - GemmCore_T::ISA, - GemmCore_T, - jblas::prologue_a::gemm::ActivationF32KBlockQuantize, - jblas::prologue_b::gemm::WeightKBlockS4, - jblas::epilogue::gemm::CompInt8BlockEpilogue, - jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - -using tAVX512F = jblas::gemm::SCoreRowNAvx512f<48, 8>; -using tAMX_BF16 = jblas::gemm::HCoreRowNAmxbf16<64, 16>; -using tAVX512_FP16 = jblas::gemm::HCoreRowNAvx512fp16<96, 8>; -using tAVX_VNNI = jblas::gemm::ICoreRowNAvxvnni<48, 2>; // TODO(Yu) use 24x4 for higher efficiency -using tAVX512_VNNI = jblas::gemm::ICoreRowNAvx512vnni<48, 8>; -using tAMX_INT8_US = jblas::gemm::ICoreRowNAmxint8<64, 16>; -using tAMX_INT8_SS = jblas::gemm::ICoreRowNAmxint8SS<64, 16>; -using tAVX2 = jblas::gemm::SCoreRowNAvx2<48, 2>; // TODO(Yu) use 24x4 for higher efficiency - -class ORTThreading : public jblas::parallel::IThreading -{ - public: - ORTThreading(void* tp); - void parallel_for(const jblas::parallel::thread_func& func) override; - void set_threads(int nthreads) override { assert(0); } - void sync() override { assert(0); } - void* mTp; -}; - -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.cpp b/onnxruntime/core/mlas/lib/jblas_gemm.cpp deleted file mode 100644 index f3cae3186c28..000000000000 --- a/onnxruntime/core/mlas/lib/jblas_gemm.cpp +++ /dev/null @@ -1,534 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - jblas_gemm.cpp - -Abstract: - - Currently only support Q4 gemm. ---*/ - -#include "jblas_gemm.h" - -#include "jblas_defs.h" -#include "mlasi.h" - -using namespace jblas; - -jblas::ORTThreading::ORTThreading(void* tp) - : IThreading(MLAS_THREADPOOL::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) -{ -} - -void -jblas::ORTThreading::parallel_for(const jblas::parallel::thread_func& func) -{ - MlasTrySimpleParallel(reinterpret_cast(mTp), mThreadNum, [&](ptrdiff_t tid) { - func(static_cast(tid)); - }); -} - -template -static void -JblasSQ4GemmCompF32( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc, - int8_t* WorkSpace, - jblas::parallel::IThreading* th -) -{ - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - if (M <= 16) { - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = tLauncher_Fp32_S4_F32F32; - static Launcher kernel; - auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); - if (B->mIsAsym) { - reduceA.assign(WorkSpace); - ORTThreading single(nullptr); - kernel.mProA.reduce({A, lda_}, &reduceA, M_, K_, &single); - } - typename Launcher::BEpiParam blkargs{ - B->template SPtr(), B->mScaT, B->mCStep, B->template ZPtr(), - reduceA.template get(), reduceA.lda}; - - typename Launcher::Param args{M_, N_, K_, B->mBlockSize, {A, lda_}, {B}, blkargs, {C, ldc_}}; - jblas::parallel::GemmKBlockRun(kernel, args, th); - } else { - using Parallel = jblas::parallel::gemm::SchedulerBase; - using Launcher = jblas::wrapper::gemm::LauncherBase< - GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, - jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - static Launcher kernel; - - typename Launcher::Param args{M_, N_, K_, {A, lda_}, {B}, {C, ldc_}}; - jblas::parallel::GemmBaseRun(kernel, args, th); - } -} - -template -static void -JblasSQ4GemmCompInt8( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc, - int8_t* WorkSpace, - jblas::parallel::IThreading* th -) -{ - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = tLauncher_Int8_S4_F32F32; - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - static Launcher kernel; - auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->mIsAsym); - quanA.assign(WorkSpace); - if (M <= 16) { - ORTThreading single(nullptr); - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); - } else { - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); - } - typename Launcher::Param args{ - M_, - N_, - K_, - B->mBlockSize, - {A, lda_, &quanA}, - {B}, - {B->template SPtr(), B->mScaT, B->mCStep, quanA.template SPtr(), quanA.mCStep, - quanA.template ZPtr(), B->template RPtr(), B->mRedT, B->template ZPtr(), - quanA.template RPtr(), B->mBlockSize}, - {C, ldc_}}; - jblas::parallel::GemmKBlockRun(kernel, args, th); -} - -bool -JblasSQ4GemmBatchDriver( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - int8_t* WorkSpace, - MLAS_THREADPOOL* ThreadPool -) -{ - GetCPUDevice(); - ORTThreading orth(ThreadPool); - bool processed = true; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto kptr = reinterpret_cast(ptr); - auto coretype = ptr->mCoreId; - auto NTile = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT - ); - auto CType = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT - ); - if (CType == uint32_t(gemm::CompType::COMP_FP32)) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - JblasSQ4GemmCompF32( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - JblasSQ4GemmCompF32( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { - if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { - if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - } - } else { - processed = false; - break; - } - } - return processed; -} - -template -static size_t -JblasSQ4GemmCompF32WorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc -) -{ - auto M_ = static_cast(M); - auto K_ = static_cast(K); - (void)(N); - (void)(lda); - (void)(ldc); - if (M <= 16) { - using Launcher = tLauncher_Fp32_S4_F32F32; - static Launcher kernel; - if (B->mIsAsym) { - auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); - return reduceA.mSize; - } - return 0; - } else { - using Launcher = jblas::wrapper::gemm::LauncherBase< - GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, - jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - static Launcher kernel; - return 0; - } - return 0; -} - -template -static size_t -JblasSQ4GemmCompInt8WorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc -) -{ - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = tLauncher_Int8_S4_F32F32; - static Launcher kernel; - (void)(N); - (void)(lda); - (void)(ldc); - auto quanA = kernel.mProA.createStorage( - static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->mIsAsym - ); - return quanA.mSize; -} - -size_t -JblasSQ4GemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -) -{ - GetCPUDevice(); - size_t size = 0; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto kptr = reinterpret_cast(ptr); - auto coretype = ptr->mCoreId; - auto NTile = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT - ); - auto CType = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT - ); - if (CType == uint32_t(gemm::CompType::COMP_FP32)) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - size = std::max( - JblasSQ4GemmCompF32WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - size = std::max( - JblasSQ4GemmCompF32WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { - if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { - if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - } - } - } - return size; -} - -template -static size_t -JblasQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) -{ - static T launcher; - auto stor = launcher.mProB.createStorage( - static_cast(N), static_cast(K), static_cast(block_size), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, - JBLAS_DTYPE::BF16, isAsym - ); - // TODO(Yu) support more scale dtype - return stor.mSize; -} - -size_t -JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType) -{ - GetCPUDevice(); - if (K % BlkSize != 0) { - return 0; - } - // from low precision to high precision - switch (CompType) { - case CompInt8: - if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - case CompBf16: - case CompFp16: - case CompFp32: - case CompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - break; - default: - return 0; - } - return 0; -} - -template -static void -JblasQ4GemmPackBImpl( - void* PackedBuf, - size_t BlkSize, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - bool IsAsym, - bool lastCall, - size_t ldb, - MLAS_THREADPOOL* ThreadPool -) -{ - static T JblasKernel; - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto stor = JblasKernel.mProB.createStorage( - N_, K_, static_cast(BlkSize), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, IsAsym - ); - stor.assign(reinterpret_cast(PackedBuf)); - ORTThreading orth(ThreadPool); - JblasKernel.mProB.packNbitsWeight(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); - if (lastCall) { - JblasKernel.mProB.reduceWeight(&stor, &orth); - } -} - -bool -JblasQ4GemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -) -{ - GetCPUDevice(); - // explicit statement fall through. - switch (CompType) { - case CompInt8: - if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - case CompBf16: - case CompFp16: - case CompFp32: - case CompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - default: - return false; - } - return false; -} - -bool -JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) -{ - auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); - auto uptr = std::unique_ptr(ptr); - ORTThreading orth(ThreadPool); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto ldb_ = static_cast(ldb); - GetCPUDevice(); - if (ptr) { - if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto NTile = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT - ); - auto CType = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT - ); - if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } - } - if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) { - if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } - } - if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_SS_INT32)) { - if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } - } - } - return true; - } - return false; -} diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.h b/onnxruntime/core/mlas/lib/jblas_gemm.h deleted file mode 100644 index 044dc5e849a0..000000000000 --- a/onnxruntime/core/mlas/lib/jblas_gemm.h +++ /dev/null @@ -1,61 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - jblas_gemm.h - -Abstract: - - Currently only support Q4 gemm. ---*/ - -#pragma once - -#include "mlas_qnbit.h" - -size_t -JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType); - -bool -JblasQ4GemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -); - -bool -JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb - , MLAS_THREADPOOL* ThreadPool); - -bool -JblasSQ4GemmBatchDriver( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - int8_t* WorkSpace, - MLAS_THREADPOOL* ThreadPool -); - -size_t -JblasSQ4GemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 7bb8b17031a8..4b93dde1bcef 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -193,6 +193,8 @@ class MLASCPUIDInfo bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + private: MLASCPUIDInfo(); @@ -200,6 +202,7 @@ class MLASCPUIDInfo bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -357,6 +360,20 @@ size_t #else +#if defined(__aarch64__) && defined(__linux__) +typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( + const float* A, + const bfloat16_t* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + const float* Bias +); +#endif + typedef size_t (MLASCALL MLAS_GEMM_FLOAT_KERNEL)( @@ -727,6 +744,10 @@ extern "C" { #else MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; +#if defined(__aarch64__) && defined(__linux__) + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; +#endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; #endif @@ -825,6 +846,7 @@ extern "C" { MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32Kernel; #if defined(MLAS_TARGET_AMD64) MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelAvx; + MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelAvx512F; MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32KernelAvx; #endif @@ -856,6 +878,10 @@ extern "C" { #define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 +#if defined(__aarch64__) && defined(__linux__) +#define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#endif + // // Single-threaded single precision matrix/matrix multiply operation. // diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 8329a34f1338..a53c5085b10c 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -60,6 +60,10 @@ MLASCPUIDInfo::MLASCPUIDInfo() #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() { @@ -70,6 +74,8 @@ MLASCPUIDInfo::MLASCPUIDInfo() has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); } #endif @@ -415,6 +421,7 @@ Return Value: this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelAvx512F; this->ComputeExpF32Kernel = MlasComputeExpF32KernelAvx512F; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelAvx512F; + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelAvx512F; this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelAvx512F; this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8KernelAvx512F; this->NchwcBlockSize = 16; @@ -482,7 +489,6 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. @@ -512,6 +518,9 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; + + // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; } #if defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h new file mode 100644 index 000000000000..de7fd72fad45 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -0,0 +1,399 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm.h + +Abstract: + + This module defines the set of template functions to implement bfloat16 + precision matrix/matrix multiply operation (SBGEMM). + + To implement a new kernel, template functions below need to be specialized: + MlasSBGemmConvertPackB + MlasSBGemmPackedBOffset + MlasSBGemmPackedBLeadingDim + MlasSBGemmKernel + + MlasSBGemmOperation is the shared kernel driver. + + A kernel type should define the following constants: + bool PackNeeded; Whether B needs to be packed + size_t KernelMaxM; Max # rows the vectorized kernel can process + size_t PackedK; Packed alignment on the K dim (power of 2) + size_t PackedN; Packed alignment on the n dim (power of 2) + MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#pragma once + +#include +#include + +#include "mlasi.h" + +/** + * @brief Define the default striding parameters for + * the bfloat16 precision gemm operation + */ +struct MLAS_SBGEMM_STRIDES { + size_t M; + size_t N; + size_t K; +}; + +/** + * @brief Convert fp32 matrix B to bf16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B in fp32 + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack + */ +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Find the location of PackedB[StartK, StartN] + * + * @tparam KernelType + * @param PackedB + * @param DimN Total columns of the packing buffer + * @param DimK Total rows of the packing buffer + * @param StartN + * @param StartK + * @return Address of PackedB[StartK, StartN] + */ +template +MLAS_FORCEINLINE const bfloat16_t* +MlasSBGemmPackedBOffset( + const bfloat16_t* PackedB, size_t DimN, size_t DimK, size_t StartN, size_t StartK +) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return PackedB + StartK * DimN + StartN; +} + +/** + * @brief leading dimension of the packed B buffer + * Related to how B is packed + * @tparam KernelType + * @param DimN + * @param DimK + * @return leading dimension of the packed B buffer + */ +template +MLAS_FORCEINLINE size_t +MlasSBGemmPackedBLeadingDim(size_t DimN, size_t DimK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return DimN; +} + +template +void +MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK, const float* A, const size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode); + +template +MLAS_FORCEINLINE void +MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t PackedStrideN = Strides.N; + size_t PackedStrideK = Strides.K; + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + const size_t SliceStartN = RangeStartN + n; + CountN = std::min(RangeCountN - n, PackedStrideN); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + bool ZeroMode = (k == 0); + CountK = std::min(K - k, PackedStrideK); + + const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; + float* c = C + n; + const float* pbias = ((nullptr == Bias) ? nullptr : Bias + RangeStartN + n); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, pb, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor) + ->Process(C + n, M, SliceStartN, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + // + // Compute the strides to step through slices of the input matrices. + // + // Expand the N stride if K is small or expand the K stride if N is small + // for better utilization of the B panel. Avoid changing the K stride if + // the A panel needs to be used for transposing. + // + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t StrideN = Strides.N; + size_t StrideK = Strides.K; + + if (N >= K) { + while (StrideK / 2 >= K) { + StrideN *= 2; + StrideK /= 2; + } + } else { + while (StrideN > 16 && StrideN / 2 >= N) { + StrideK *= 2; + StrideN /= 2; + } + } + + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(bfloat16_t)); + MlasThreadedBufAlloc(packBSize); + uint8_t* p = ThreadedBufHolder.get(); + auto* PanelB = reinterpret_cast(p); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < N; n += CountN) { + CountN = std::min(N - n, StrideN); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, StrideK); + + // + // Copy a panel of matrix B to a local packed buffer. + // + MlasSBGemmConvertPackB(PanelB, B + n + k * ldb, ldb, CountN, CountK); + + auto* c = C + n; + const float* pbias = + ((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart + + bool ZeroMode = (k == 0); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor)->Process(C + n, M, N, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId) +{ + const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; + const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; + + // + // Partition the operation along the M dimension. + // + size_t RangeStartM; + size_t RangeCountM; + + MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); + + // + // Partition the operation along the N dimension. + // + size_t RangeStartN; + size_t RangeCountN; + + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); + + RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + RangeCountN = std::min(N - RangeStartN, RangeCountN); + + // + // Dispatch the partitioned operation. + // + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const float* A = (const float*)DataParams->A + RangeStartM * lda; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const float* bias = DataParams->Bias; + + if (!DataParams->BIsfp32) { + MlasSBGemmPackedOperation( + RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A, + lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor + ); + } else { + const size_t ldb = DataParams->ldb; + const float* B = (const float*)DataParams->B + RangeStartN; + MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor); + } +} + +// +// dispatch structure. +// +typedef void(MLAS_SBGEMM_OPERATION)(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId); + +typedef void(MLAS_SBGEMM_CONVERTPACKB_ROUTINE)( + bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Hardware dependent dispatch for half precision GEMM + */ +struct MLAS_SBGEMM_DISPATCH { + MLAS_SBGEMM_OPERATION* Operation; /**< HalfGemm driver */ + MLAS_SBGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ + size_t PackedK; + size_t PackedN; + size_t StrideM; + size_t BufOverRead; +}; + +extern const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon; + +MLAS_FORCEINLINE +const MLAS_SBGEMM_DISPATCH* +MlasSBGemmGetDispatch() +{ +#if defined(MLAS_TARGET_ARM64) + return &MlasSBGemmDispatchNeon; +#else + std::cerr << "SBGemm Kernel is supported only on ARM64 platform."; + exit(1); +#endif +} + +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K) +{ + // + // Compute the number of bytes required to hold the packed buffer. + // + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return 0; + + const auto padding = dispatch->BufOverRead; + const auto PackedK = dispatch->PackedK; + const auto PackedN = dispatch->PackedN; + + const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); + const size_t AlignedN = (N + PackedN - 1) & ~(PackedN - 1); + const size_t BytesRequired = AlignedN * AlignedK * sizeof(bfloat16_t) + padding; + const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); + const size_t AlignedBytesRequired = + (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + + return AlignedBytesRequired; +} + +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + dispatch->ConvertPackBRoutine((bfloat16_t*)PackedB, B, ldb, N, K); +} + +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool) +{ + const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + MLAS_SBGEMM_OPERATION* operation = dispatch->Operation; + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. + // + + const double Complexity = double(M) * double(N) * double(K); + + ptrdiff_t TargetThreadCount; + + if (Complexity < double(MLAS_SBGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { + TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; + } else { + TargetThreadCount = GetMlasPlatform().MaximumThreadCount; + } + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + // + // Segment the operation across multiple threads. + // + // N.B. Currently, the operation is segmented as a 1D partition, which + // works okay for operations involving skinny matrices. + // + ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchN - 1) / BatchN; + ptrdiff_t ThreadCountM; + ptrdiff_t ThreadCountN; + + if (N > M) { + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); + } + + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; + + } else { + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); + } + + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; + } + + MlasTrySimpleParallel( + ThreadPool, ThreadsPerGemm * static_cast(BatchN), [=](ptrdiff_t tid) { + ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + operation(ThreadCountM, ThreadCountN, M, N, K, &(Data[GemmIdx]), ThreadIdx); + } + ); +} +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp new file mode 100644 index 000000000000..a6a73996c548 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp @@ -0,0 +1,362 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm_kernel_neon.cpp + +Abstract: + + This module implements bfloat16 precision GEMM kernel for neon. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#include "arm_neon.h" +#include "mlasi.h" +#include "sbgemm.h" + +struct MLAS_SBGEMM_KERNEL_NEON { + static constexpr bool PackNeeded = true; + static constexpr size_t KernelMaxM = 8; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 4; + static constexpr size_t PackedN = MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + static constexpr MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; // M:N:K +}; + +bool MLASCALL +MlasBf16AccelerationSupported() +{ +#if defined(MLAS_TARGET_ARM64) + return MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_BF16(); +#else + return false; +#endif +} + +/* + This routine converts fp32 to bf16 and copies elements from the source + matrix to the destination packed buffer. + + 4x2 elements from the source matrix are unrolled to be physically + contiguous for better locality inside the SBGEMM kernels. The remaining + rows and columns are padded to 4 and 2 alignment. +*/ +MLAS_FORCEINLINE +void +MlasSBGemmConvertCopyPackB(bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK) +{ + // + // Copy data from matrix B into the destination buffer 4x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const float* b = B; + int y = static_cast(CountK); + + while (y > 0) { + MLAS_FLOAT32X4 t0_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t0_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_h = MlasZeroFloat32x4(); + + if (y >= 4) { + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + t3_l = MlasLoadFloat32x4(&b[ldb * 3]); + t3_h = MlasLoadFloat32x4(&b[ldb * 3 + 4]); + } else { + switch (y) { + case 3: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + break; + case 2: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + break; + case 1: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + break; + } + } + + float32x4x2_t z0_l = vzipq_f32(t0_l, t2_l); + float32x4x2_t z1_l = vzipq_f32(t1_l, t3_l); + float32x4x2_t o0_l = vzipq_f32(z0_l.val[0], z1_l.val[0]); + float32x4x2_t o1_l = vzipq_f32(z0_l.val[1], z1_l.val[1]); + t0_l = o0_l.val[0]; + t1_l = o0_l.val[1]; + t2_l = o1_l.val[0]; + t3_l = o1_l.val[1]; + + bfloat16x8_t t0t1_l_4h = vcvtq_low_bf16_f32(t0_l); + bfloat16x8_t t0t1_l_8h = vcvtq_high_bf16_f32(t0t1_l_4h, t1_l); + + bfloat16x8_t t2t3_l_4h = vcvtq_low_bf16_f32(t2_l); + bfloat16x8_t t2t3_l_8h = vcvtq_high_bf16_f32(t2t3_l_4h, t3_l); + + vst1q_bf16(&D[0], t0t1_l_8h); + vst1q_bf16(&D[8], t2t3_l_8h); + + float32x4x2_t z0_h = vzipq_f32(t0_h, t2_h); + float32x4x2_t z1_h = vzipq_f32(t1_h, t3_h); + float32x4x2_t o0_h = vzipq_f32(z0_h.val[0], z1_h.val[0]); + float32x4x2_t o1_h = vzipq_f32(z0_h.val[1], z1_h.val[1]); + t0_h = o0_h.val[0]; + t1_h = o0_h.val[1]; + t2_h = o1_h.val[0]; + t3_h = o1_h.val[1]; + + bfloat16x8_t t0t1_h_4h = vcvtq_low_bf16_f32(t0_h); + bfloat16x8_t t0t1_h_8h = vcvtq_high_bf16_f32(t0t1_h_4h, t1_h); + + bfloat16x8_t t2t3_h_4h = vcvtq_low_bf16_f32(t2_h); + bfloat16x8_t t2t3_h_8h = vcvtq_high_bf16_f32(t2t3_h_4h, t3_h); + + vst1q_bf16(&D[16], t0t1_h_8h); + vst1q_bf16(&D[24], t2t3_h_8h); + + D += 32; + b += ldb * 4; + y -= 4; + }; + B += 8; + CountN -= 8; + } + + // + // Special case the handling of the remaining columns less than 8 elements + // wide. + // + if (CountN > 0) { + int y = static_cast(CountK); + while (y > 0) { + const float* b = B; + size_t b_inc = 0; + if ((CountN & 4) != 0) { + MLAS_FLOAT32X4 t0 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3 = MlasZeroFloat32x4(); + if (y >= 4) { + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + t3 = MlasLoadFloat32x4(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + break; + case 2: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + break; + case 1: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + break; + } + } + + float32x4x2_t z0 = vzipq_f32(t0, t2); + float32x4x2_t z1 = vzipq_f32(t1, t3); + float32x4x2_t o0 = vzipq_f32(z0.val[0], z1.val[0]); + float32x4x2_t o1 = vzipq_f32(z0.val[1], z1.val[1]); + + t0 = o0.val[0]; + t1 = o0.val[1]; + t2 = o1.val[0]; + t3 = o1.val[1]; + + bfloat16x8_t t0t1_4h = vcvtq_low_bf16_f32(t0); + bfloat16x8_t t0t1_8h = vcvtq_high_bf16_f32(t0t1_4h, t1); + + bfloat16x8_t t2t3_4h = vcvtq_low_bf16_f32(t2); + bfloat16x8_t t2t3_8h = vcvtq_high_bf16_f32(t2t3_4h, t3); + + vst1q_bf16(&D[0], t0t1_8h); + vst1q_bf16(&D[8], t2t3_8h); + + D += 16; + b += 4; + b_inc += 4; + } + + if ((CountN & 2) != 0) { + float32x2_t t0 = {0x0, 0x0}; + float32x2_t t1 = {0x0, 0x0}; + float32x2_t t2 = {0x0, 0x0}; + float32x2_t t3 = {0x0, 0x0}; + + if (y >= 4) { + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + t3 = vld1_f32(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + break; + case 2: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + break; + case 1: + t0 = vld1_f32(&b[ldb * 0]); + break; + } + } + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 2; + b_inc += 2; + } + if ((CountN & 1) != 0) { + float a = 0.0f; + float b = 0.0f; + float c = 0.0f; + float d = 0.0f; + + if (y >= 4) { + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + d = *(float*)(&B[ldb * 3 + b_inc]); + } else { + switch (y) { + case 3: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + break; + case 2: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + break; + case 1: + a = *(float*)(&B[ldb * 0 + b_inc]); + break; + } + } + + float32x2_t t0 = {a, 0x0}; + float32x2_t t1 = {b, 0x0}; + float32x2_t t2 = {c, 0x0}; + float32x2_t t3 = {d, 0x0}; + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 1; + b_inc += 1; + } + B += 4 * ldb; + y -= 4; + } + } +} + +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + const auto PackedN = dispatch->PackedN; + + const size_t AlignedN = (CountN + PackedN - 1) & ~(PackedN - 1); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t K_block_size; + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + + for (size_t k = 0; k < CountK; k += K_block_size) { + K_block_size = std::min(CountK - k, Strides.K); + + MlasSBGemmConvertCopyPackB((bfloat16_t*)PackedB, B + k * ldb, ldb, CountN, K_block_size); + PackedB = (bfloat16_t*)PackedB + AlignedN * K_block_size; + } +} + +template <> +MLAS_FORCEINLINE void +MlasSBGemmKernel(size_t CountM, size_t CountN, size_t CountK, const float* A, size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode) +{ + while (CountM > 0) { + size_t RowsHandled; + if (ZeroMode) { + RowsHandled = MlasSbgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } else { + RowsHandled = MlasSbgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } + C += ldc * RowsHandled; + A += lda * RowsHandled; + CountM -= RowsHandled; + } +} + +const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon = { + MlasSBGemmOperation, + MlasSBGemmConvertPackB, + MLAS_SBGEMM_KERNEL_NEON::PackedK, + MLAS_SBGEMM_KERNEL_NEON::PackedN, + MLAS_SBGEMM_KERNEL_NEON::KernelMaxM, + 32 // kernel may read beyond buffer end by 32 bytes +}; +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 7f1d1b084aec..38c31c884176 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -11,38 +11,535 @@ Module Name: Abstract: This module implements the float/quantized n-bit integer matrix - multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch. + multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch, + as well as some SQNBitGemm-related query functions. --*/ #include "sqnbitgemm.h" -#ifdef MLAS_JBLAS -#include "jblas_gemm.h" -#endif + +#include + +namespace +{ + +enum SQNBitGemmVariant { + SQNBitGemmVariantInvalid = -1, + + // Valid variants + + SQNBitGemmVariant_BitWidth4_CompFp32 = 0, + SQNBitGemmVariant_BitWidth4_CompInt8, + + // End of valid variants + + // Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values. + // Its value is used as an array size. + SQNBitGemmVariantCount, +}; + +SQNBitGemmVariant +GetSQNBitGemmVariant( + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + if (BlkBitWidth == 4 && + (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { + if (ComputeType == CompFp32 || + ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 + return SQNBitGemmVariant_BitWidth4_CompFp32; + } else if (ComputeType == CompInt8) { + return SQNBitGemmVariant_BitWidth4_CompInt8; + } + } + + return SQNBitGemmVariantInvalid; +} + +} // namespace + +bool MLASCALL +MlasIsSQNBitGemmAvailable( + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return false; + } + + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + + switch (Variant) { + case SQNBitGemmVariant_BitWidth4_CompFp32: { + return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && + Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; + } + case SQNBitGemmVariant_BitWidth4_CompInt8: { + return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr && + Dispatch->QuantizeARow_CompInt8 != nullptr; + } + default: { + return false; + } + } +} + +namespace +{ + +size_t +SQNBitGemmWorkspaceAlignment(SQNBitGemmVariant Variant) +{ + switch (Variant) { + case SQNBitGemmVariant_BitWidth4_CompInt8: { + return Q8BlkAlignment(); + } + default: { + return 1; + } + } +} + +size_t +SQNBitGemmPerGemmWorkspaceSize( + SQNBitGemmVariant Variant, + size_t M, + size_t N, + size_t K, + size_t BlkLen +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + + switch (Variant) { + case SQNBitGemmVariant_BitWidth4_CompInt8: { + // workspace buffer is used for block quantization of A to int8 + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + return PerGemmWorkspaceSize; + } + default: { + return 0; + } + } +} + +size_t +SQNBitGemmPerGemmWorkspaceStride( + SQNBitGemmVariant Variant, + size_t M, + size_t N, + size_t K, + size_t BlkLen +) +{ + const auto Size = SQNBitGemmPerGemmWorkspaceSize(Variant, M, N, K, BlkLen); + const auto Alignment = SQNBitGemmWorkspaceAlignment(Variant); + return MlasDivRoundup(Size, Alignment) * Alignment; +} + +} // namespace + +size_t MLASCALL +MlasSQNBitGemmBatchWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + + const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); + if (PerGemmWorkspaceStride == 0) { + return 0; + } + + const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant); + + const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; + + return WorkspaceSize + Alignment - 1; +} + +size_t MLASCALL +MlasSQNBitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return 0; + } + + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->SQ4BitGemmPackQuantBDataSize( + N, K, BlkLen, ComputeType + ); + } + + return 0; +} + +void MLASCALL +MlasSQNBitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const void* QuantBData, + void* PackedQuantBData, + MLAS_THREADPOOL* ThreadPool +) +{ + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return; + } + + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBData != nullptr) { + Dispatch->SQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBData), + ThreadPool + ); + return; + } +} namespace { -// Get quantization variant based on `BlkBitWidth` and `BlkLen`. -// Return -1 if the input values are unsupported. -int32_t -GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen) +MLAS_FORCEINLINE void +AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) +{ + for (size_t m = 0; m < CountM; m++) { + const float* bias = Bias; + float* sum = C; + for (size_t n = 0; n < CountN; n += 4) { + if (CountN - n < 4) { + for (size_t nn = n; nn < CountN; nn++) { + *sum += *bias; + sum++; + bias++; + } + break; + } + + MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); + acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); + MlasStoreFloat32x4(sum, acc_x); + bias += 4; + sum += 4; + } + C += ldc; + } +} + +typedef void(SQNBitGemmFn)( + size_t BlkLen, + size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* PerGemmWorkspace, + size_t RangeStartM, + size_t RangeCountM, + size_t RangeStartN, + size_t RangeCountN +); + +void +SQ4BitGemm_CompFp32( + const size_t BlkLen, + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + constexpr size_t BlkBitWidth = 4; + + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace); + + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const float* A = DataParams->A + RangeStartM * lda; + + const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const float* a_row = A; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + return; + } + + constexpr size_t StrideN = 32; + size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, StrideN); + + // + // Step through each slice of matrix A along the M dimension. + // + const float* a_row = A; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32( + BlkLen, + dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks + ); + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true + ); +#else + auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); +#endif + + if (bias) { + AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); + } + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + RowsHandled, CountN, ldc + ); + } + + c_blk += ldc * RowsHandled; + a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + } +} + +void +SQ4BitGemm_CompInt8( + const size_t BlkLen, + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) { - int32_t type = -1; - if (BlkBitWidth == 4 && BlkLen == 16) { - type = QuantVariant_BitWidth4_BlockSize16; - } else if (BlkBitWidth == 4 && BlkLen == 32) { - type = QuantVariant_BitWidth4_BlockSize32; - } else if (BlkBitWidth == 4 && BlkLen == 64) { - type = QuantVariant_BitWidth4_BlockSize64; - } else if (BlkBitWidth == 4 && BlkLen == 128) { - type = QuantVariant_BitWidth4_BlockSize128; - } else if (BlkBitWidth == 4 && BlkLen == 256) { - type = QuantVariant_BitWidth4_BlockSize256; + constexpr size_t BlkBitWidth = 4; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + const size_t lda = k_blks * Q8BlkSize(BlkLen); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; + + const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + return; } - return type; + // This is a naive M > 1 implementation that repeatedly calls the M=1 kernel. + // TODO Replace it with an optimized implementation. + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + for (size_t m = 0; m < RangeCountM; ++m) { + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + + c_blk += ldc; + a_row += lda; + } + } } +typedef void(InitializeWorkspaceFn)( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +); + +void +InitializeWorkspace_CompInt8( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + + const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); + + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); + + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } + }); +} + +struct Operations { + InitializeWorkspaceFn* InitializeWorkspace = nullptr; + SQNBitGemmFn* SQNBitGemm = nullptr; +}; + +constexpr auto OperationMap = []() { + std::array ops; + + ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32; + + ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8; + ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8; + + return ops; +}(); + } // namespace void MLASCALL @@ -53,17 +550,43 @@ MlasSQNBitGemmBatch( const size_t BatchN, const size_t BlkBitWidth, const size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, MLAS_THREADPOOL* ThreadPool ) { - const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen); - MLAS_SQNBIT_GEMM_OPERATION* const Operation = GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant]; + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + assert(Variant != SQNBitGemmVariantInvalid); + + // + // Ensure `Workspace` has correct alignment. + // + if (Workspace != nullptr) { + const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant); + const uintptr_t WorkspaceAddress = reinterpret_cast(Workspace); + Workspace = reinterpret_cast( + (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1)) + ); + } + + const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); + + if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; + InitializeWorkspaceOperation != nullptr) { + InitializeWorkspaceOperation( + M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool + ); + } + + const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { - auto Data = &DataParams[gemm_i]; - Operation(K, Data, 0, M, 0, N); + const auto* Data = &DataParams[gemm_i]; + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); } return; } @@ -112,7 +635,10 @@ MlasSQNBitGemmBatch( MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { const auto gemm_i = tid / ThreadsPerGemm; const auto blk_i = tid % ThreadsPerGemm; - auto Data = &DataParams[gemm_i]; + const auto* Data = &DataParams[gemm_i]; + void* PerGemmWorkspace = reinterpret_cast( + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride + ); const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; @@ -123,149 +649,6 @@ MlasSQNBitGemmBatch( const size_t RangeStartN = ThreadIdN * StrideN; const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); }); } - -bool MLASCALL -MlasIsSQNBitGemmAvailable( - size_t BlkBitWidth, - size_t BlkLen -) -{ - const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen); - if (QuantVariant == -1) { - return false; - } - - if (GetMlasPlatform().SQNBitGemmDispatch == nullptr || - GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant] == nullptr) { - return false; - } - - return true; -} - -size_t MLASCALL -MlasNBitsGemmPackBSize( - size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType -) -{ -#ifdef MLAS_JBLAS - if (nbits == 4) { - auto jsize = JblasQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); - if (jsize) { - return jsize; - } - } -#endif - (void)(N); - (void)(K); - (void)(BlkSize); - (void)(nbits); - (void)(isAsym); - (void)(CompType); - return 0; -} - -void MLASCALL -MlasNBitsGemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - int nbits, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -) -{ -#ifdef MLAS_JBLAS - if (nbits == 4) { - if (JblasQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { - return; - } - } -#endif - (void)(PackedBuf); - (void)(QData); - (void)(Scale); - (void)(Zp); - (void)(N); - (void)(K); - (void)(ldb); - (void)(BlkSize); - (void)(nbits); - (void)(isAsym); - (void)(lastCall); - (void)(CompType); - (void)(ThreadPool); -} - -void MLASCALL -MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) -{ -#ifdef MLAS_JBLAS - if (JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { - return; - } -#endif - (void)(FpData); - (void)(PackedBuf); - (void)(N); - (void)(K); - (void)(ldb); - (void)(ThreadPool); -} - -size_t MLASCALL -MlasSQNBitsGemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -) -{ -#ifdef MLAS_JBLAS - return JblasSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); -#endif - (void)(M); - (void)(N); - (void)(K); - (void)(BatchN); - (void)(DataParams); - return 0; -} - -void MLASCALL -MlasSQNBitsGemmBatchPackedB( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - void* WorkSpace, - MLAS_THREADPOOL* ThreadPool -) -{ - GetMlasPlatform(); -#ifdef MLAS_JBLAS - if (JblasSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { - // PackedWeight is created by jblas - return; - } -#endif - (void)(M); - (void)(N); - (void)(K); - (void)(BatchN); - (void)(DataParams); - (void)(WorkSpace); - (void)(ThreadPool); -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index f8f7dcd43699..3992bc3e452a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -10,98 +10,23 @@ Module Name: Abstract: - This module includes: + This module includes kernel function prototypes and helper functions for + implementing SQNBitGemm. - - Declaration of the set of template functions used to implement a kernel - for a matrix/matrix multiplication, A*B, where A is a float matrix and B is - a n-bit quantized integer matrix (QNBitGemm). - - - A shared kernel driver function template, MlasSQNBitGemmOperation. - - - Kernel dispatch structure. - - The B matrix is block quantized, which means that its values are grouped - into blocks which each have one scale and optional zero point. Each - quantized value in B is n-bits wide. + SQNBitGemm is a matrix/matrix multiplication, A*B, where A is a float + matrix and B is a n-bit quantized integer matrix. B is block quantized, + meaning values of B are divided into blocks and each block has its own + scale and optional zero point. --*/ #pragma once +#include + #include "mlas_qnbit.h" #include "mlasi.h" -// -// Kernel implementation template declarations -// - -/** - * @brief Multiply float matrix A with quantized n-bit integer matrix B. - * B is block quantized and column major. - * This kernel handles the special case where M, the number of rows of A and C, is 1. - * - * @tparam BlkBitWidth Bit width of each value in a block. - * @tparam BlkLen Number of values in a block. - * @tparam KernelType Hardware-specific kernel type. - * - * @param A Supplies the A matrix. - * @param QuantBData Supplies the quantized B matrix block data. - * @param QuantBScale Supplies the quantized B matrix block scale values. - * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. - * @param[out] C Supplies the output C matrix. - * @param CountN Number of columns of B and C. - * @param CountK Number of columns of A and rows of B. - * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. - * @param Bias Bias vector of length N. - */ -template -MLAS_FORCEINLINE void -MlasSQNBitGemmM1Kernel( - const float* A, - const uint8_t* QuantBData, - const float* QuantBScale, - const uint8_t* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -); - -/** - * @brief Dequantize B into the format expected by the Sgemm kernel. - * B is block quantized and column major. - * This is equivalent to dequantizing B and then running - * MlasSgemmCopyPackB. - * - * @tparam BlkBitWidth Bit width of each value in a block. - * @tparam BlkLen Number of values in a block. - * @tparam KernelType Hardware-specific kernel type. - * - * @param[out] FpData Supplies the output buffer for the dequantized B float data. - * @param QuantBData Supplies the quantized B matrix block data. - * @param QuantBScale Supplies the quantized B matrix block scale values. - * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. - * @param CountN Number of columns of B. - * @param CountK Number of rows of B. - * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. - */ -template -MLAS_FORCEINLINE void -MlasQNBitBlkDequantBForSgemm( - float* FpData, - const uint8_t* QuantBData, - const float* QuantBScale, - const uint8_t* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB -); - -// -// MlasQNBitGemmOperation and helpers -// - constexpr MLAS_FORCEINLINE size_t MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) { @@ -119,169 +44,201 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) } } -MLAS_FORCEINLINE void -MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) +// +// Quantized int8 block helpers. +// + +MLAS_FORCEINLINE +const float& +Q8BlkScale(const std::byte* BlkPtr) { - for (size_t m = 0; m < CountM; m++) { - const float* bias = Bias; - float* sum = C; - for (size_t n = 0; n < CountN; n += 4) { - if (CountN - n < 4) { - for (size_t nn = n; nn < CountN; nn++) { - *sum += *bias; - sum++; - bias++; - } - break; - } - - MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); - acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); - MlasStoreFloat32x4(sum, acc_x); - bias += 4; - sum += 4; - } - C += ldc; - } + return *reinterpret_cast(BlkPtr); } -template -MLAS_FORCEINLINE void MLASCALL -MlasSQNBitGemmOperation( - const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN -) +MLAS_FORCEINLINE +float& +Q8BlkScale(std::byte* BlkPtr) { - const size_t lda = DataParams->lda; - const size_t ldc = DataParams->ldc; - - const size_t k_blks = MlasDivRoundup(K, BlkLen); - const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); - - const float* A = DataParams->A + RangeStartM * lda; - - const uint8_t* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; - const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; - const uint8_t* QuantBZeroPoint = - (DataParams->QuantBZeroPoint == nullptr) - ? nullptr - : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; - - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - - const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; - - if (RangeCountM == 1) { - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - const float* a_row = A; - const uint8_t* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const uint8_t* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - MlasSQNBitGemmM1Kernel( - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } - } - return; - } + return *reinterpret_cast(BlkPtr); +} - constexpr size_t StrideN = 32; - size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); - MlasThreadedBufAlloc(bufsize); - auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); - // - // Step through each slice of matrix B along the N dimension. - // +MLAS_FORCEINLINE +const int8_t* +Q8BlkData(const std::byte* BlkPtr) +{ + return reinterpret_cast(BlkPtr + sizeof(float)); +} - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, StrideN); - - // - // Step through each slice of matrix A along the M dimension. - // - const float* a_row = A; - const uint8_t* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const uint8_t* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - MlasQNBitBlkDequantBForSgemm( - dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks - ); - - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) - auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true - ); -#else - auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); -#endif - - if (bias) { - MlasAddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); - } - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, - RowsHandled, CountN, ldc - ); - } - - c_blk += ldc * RowsHandled; - a_row += lda * RowsHandled; - RowsRemaining -= RowsHandled; - } - } +MLAS_FORCEINLINE +int8_t* +Q8BlkData(std::byte* BlkPtr) +{ + return reinterpret_cast(BlkPtr + sizeof(float)); +} + +MLAS_FORCEINLINE +constexpr size_t +Q8BlkSize(size_t BlkLen) +{ + const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t); + // Currently, the strictest alignment requirement of a block is for a float. + // Ensure contiguous blocks are suitably aligned. + assert(BlkSize % alignof(float) == 0); + return BlkSize; +} + +MLAS_FORCEINLINE +constexpr size_t +Q8BlkAlignment() +{ + return alignof(float); } // // Kernel dispatch structure. // -typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)( - size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - size_t RangeStartM, - size_t RangeCountM, - size_t RangeStartN, - size_t RangeCountN -); - -enum QuantVariant { - QuantVariant_BitWidth4_BlockSize16, - QuantVariant_BitWidth4_BlockSize32, - QuantVariant_BitWidth4_BlockSize64, - QuantVariant_BitWidth4_BlockSize128, - QuantVariant_BitWidth4_BlockSize256, - QuantVariantCount, // Keep this element last and ensure that its value is the number of other QuantVariant values. - // Its value is used as an array size. -}; - struct MLAS_SQNBIT_GEMM_DISPATCH { - MLAS_SQNBIT_GEMM_OPERATION* Operations[QuantVariantCount] = { - // Initialized to nullptrs. Overwrite in hardware-specific kernel implementation. - }; + // + // Quantized B data packing function prototypes. + // + + /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). */ + typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; + + /** Packs quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBData(). */ + typedef void(SQ4BitGemmPackQuantBData_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool + ); + + SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + + // + // CompFp32 kernel function prototypes. + // + + /** + * @brief Multiply float matrix A with quantized 4-bit integer matrix B. + * B is block quantized and column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + */ + typedef void(SQ4BitGemmM1Kernel_CompFp32_Fn)( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias + ); + + SQ4BitGemmM1Kernel_CompFp32_Fn* SQ4BitGemmM1Kernel_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is a quantized 4-bit integer matrix that is block quantized and column major. + * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. + * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(Q4BitBlkDequantBForSgemm_CompFp32_Fn)( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr; + + // + // CompInt8 kernel function prototypes. + // + + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + */ + typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias + ); + + SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr; + + /** + * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param CountK Number of columns of A. + * @param[out] QuantA Supplies the output quantized A matrix. + * Binary data containing block quantized int8 data and scale values. + */ + typedef void(QuantizeARow_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA + ); + + QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 63afe57dd913..9d7b0ae06e22 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -15,19 +15,114 @@ Module Name: --*/ -#include "sqnbitgemm.h" - #include #include #include #include +#include "sqnbitgemm.h" + // -// Hardware-specific kernel type. +// Quantized B data packing function implementation. +// + +namespace +{ + +size_t +SQ4BitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + + constexpr size_t BlkBitWidth = 4; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; +} + +void +SQ4BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t Iterations = N * BlockCountK; // one iteration per block + + const size_t SubBlkLen = (ComputeType == CompInt8) + ? ((BlkLen == 16) ? 16 : 32) + : 16; + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset; + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; + + for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { + for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += SubBlkDataSize; + PackedQuantBData += SubBlkDataSize; + } + } + ); +} + +} // namespace + +// +// General helpers. // -struct MLAS_SQNBIT_GEMM_KERNEL_NEON { -}; namespace { @@ -70,7 +165,7 @@ FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) template MLAS_FORCEINLINE void -LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4]) +LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) { static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); @@ -101,13 +196,23 @@ LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4]) } } -template +} // namespace + +// +// CompFp32 kernel implementation. +// + +namespace +{ + +template MLAS_FORCEINLINE void -ComputeDotProducts( +ComputeDotProducts_BlkBitWidth4_CompFp32( + size_t BlkLen, const float* ARowPtr, - const uint8_t* QuantBDataColPtr, + const std::byte* QuantBDataColPtr, const float* QuantBScaleColPtr, - const uint8_t* QuantBZeroPointColPtr, + const std::byte* QuantBZeroPointColPtr, float* SumPtr, size_t CountK, size_t StrideQuantBData, @@ -116,8 +221,13 @@ ComputeDotProducts( const float* BiasPtr ) { + constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); + const uint8x8_t LowMask = vdup_n_u8(0x0F); // Manual conversion to float takes place in two steps: @@ -135,9 +245,10 @@ ComputeDotProducts( float32x4_t acc[NCols]{}; - const uint8_t* QuantBData = QuantBDataColPtr; + const std::byte* QuantBData = QuantBDataColPtr; const float* QuantBScale = QuantBScaleColPtr; - size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true for (size_t k = 0; k < CountK; k += BlkLen) { const size_t k_blk_len = std::min(CountK - k, BlkLen); @@ -147,52 +258,42 @@ ComputeDotProducts( [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } ); - float offset[NCols]; // Includes zero point and float conversion offset of 16. - if (QuantBZeroPointColPtr != nullptr) { + [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset of 16. + // only used if HasZeroPoint == true + if constexpr (HasZeroPoint) { UnrolledLoop([&](size_t i) { - const uint8_t zp_packed = + const std::byte zp_packed = QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - const uint8_t zp = ((QuantBZeroPointIdx & 1) == 1) ? (zp_packed >> 4) : (zp_packed & 0x0F); - offset[i] = 16.0f + zp; - }); - } else { - UnrolledLoop([&](size_t i) { - constexpr float zp = 8.0f; - offset[i] = 16.0f + zp; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = 16.0f + std::to_integer(zp); }); } - constexpr size_t SubBlkLen = 16; // number of block elements to process in one iteration - for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { // load A row vector elements // load `SubBlkLen` elements from A, padded with 0's if there aren't enough const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); float32x4_t av[4]{}; - LoadData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); + LoadFloatData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); // load B column vectors uint8x8_t bv_packed[NCols]; + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; UnrolledLoop([&](size_t i) { - const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; - bv_packed[i] = vld1_u8(QuantBData + i * StrideQuantBData + b_data_block_offset); - }); - - uint8x8_t bv_u8_unzipped[NCols][2]; - UnrolledLoop([&](size_t i) { - bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8_unzipped[i][1] = vand_u8(vshr_n_u8(bv_packed[i], 4), LowMask); + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); }); uint8x8_t bv_u8[NCols][2]; UnrolledLoop([&](size_t i) { - bv_u8[i][0] = vzip1_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]); - bv_u8[i][1] = vzip2_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]); + bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); }); - // dequantize B - // shift left 3 and widen to 16 bits uint16x8_t bv_u16[NCols][2]; UnrolledLoop([&](size_t i) { @@ -221,10 +322,17 @@ ComputeDotProducts( }); // subtract float conversion offset (16) and zero point - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(offset[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(offset[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(16.0f + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } // multiply by scale UnrolledLoop([&](size_t i) { @@ -241,7 +349,9 @@ ComputeDotProducts( // increment pointers to next block QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); QuantBScale += 1; - QuantBZeroPointIdx += 1; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } } if constexpr (NCols == 4) { @@ -262,19 +372,14 @@ ComputeDotProducts( } } -} // namespace - -// -// MlasSQNBitGemmKernel and helpers. -// - -template -MLAS_FORCEINLINE void -MlasSQNBitGemmM1KernelNeon( +template +void +SQ4BitGemmM1Kernel_CompFp32_Impl( + size_t BlkLen, const float* A, - const uint8_t* QuantBData, + const std::byte* QuantBData, const float* QuantBScale, - const uint8_t* QuantBZeroPoint, + const std::byte* QuantBZeroPoint, float* C, size_t CountN, size_t CountK, @@ -282,6 +387,7 @@ MlasSQNBitGemmM1KernelNeon( const float* Bias ) { + constexpr size_t BlkBitWidth = 4; constexpr size_t NCols = 4; const float* ARowPtr = A; @@ -295,16 +401,17 @@ MlasSQNBitGemmM1KernelNeon( const float* BiasPtr = Bias; - const uint8_t* QuantBDataColPtr = QuantBData; + const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; - const uint8_t* QuantBZeroPointColPtr = QuantBZeroPoint; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; float* SumPtr = CRowPtr; int64_t nblk = static_cast(CountN) - NCols; while (nblk >= 0) { - ComputeDotProducts( + ComputeDotProducts_BlkBitWidth4_CompFp32( + BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr @@ -314,7 +421,7 @@ MlasSQNBitGemmM1KernelNeon( QuantBDataColPtr += NCols * StrideQuantBData; QuantBScaleColPtr += NCols * StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; } @@ -327,7 +434,8 @@ MlasSQNBitGemmM1KernelNeon( // left over columns less than `NCols`? nblk += NCols; for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts( + ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( + BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr @@ -337,7 +445,7 @@ MlasSQNBitGemmM1KernelNeon( QuantBDataColPtr += StrideQuantBData; QuantBScaleColPtr += StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += StrideQuantBZeroPoint; } @@ -346,59 +454,70 @@ MlasSQNBitGemmM1KernelNeon( } } -#define SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(BlkBitWidth, BlkLen) \ - template <> \ - MLAS_FORCEINLINE void \ - MlasSQNBitGemmM1Kernel( \ - const float* A, \ - const uint8_t* QuantBData, \ - const float* QuantBScale, \ - const uint8_t* QuantBZeroPoint, \ - float* C, \ - size_t CountN, \ - size_t CountK, \ - size_t BlockStrideQuantB, \ - const float* Bias \ - ) \ - { \ - return MlasSQNBitGemmM1KernelNeon( \ - A, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, \ - BlockStrideQuantB, Bias \ - ); \ +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); } +} -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 16) -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 32) -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 64) -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 128) -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 256) - -#undef SPECIALIZE_SQNBIT_GEMM_M1_KERNEL - -// -// MlasQNBitBlkDequantBForSgemm and helpers. -// - -template MLAS_FORCEINLINE void -MlasQNBitBlkDequantBForSgemmNeon( +Q4BitBlkDequantBForSgemm_CompFp32( + size_t BlkLen, float* FpData, - const uint8_t* QuantBData, + const std::byte* QuantBData, const float* QuantBScale, - const uint8_t* QuantBZeroPoint, + const std::byte* QuantBZeroPoint, size_t CountN, size_t CountK, size_t BlockStrideQuantB ) { auto impl0_reference = [&]() { - static_assert(BlkBitWidth == 4); + constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; float* Dst = FpData; - const uint8_t* QuantBDataCol = QuantBData; + const std::byte* QuantBDataCol = QuantBData; const float* QuantBScaleCol = QuantBScale; - const uint8_t* QuantBZeroPointCol = QuantBZeroPoint; + const std::byte* QuantBZeroPointCol = QuantBZeroPoint; for (size_t n = 0; n < CountN; n += 16) { const size_t nnlen = std::min(CountN - n, size_t{16}); @@ -407,20 +526,26 @@ MlasQNBitBlkDequantBForSgemmNeon( for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) { const size_t kklen = std::min(CountK - k, BlkLen); - const uint8_t* b_data = + const std::byte* b_data = QuantBDataCol + k_blk_idx * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const float b_s = QuantBScaleCol[k_blk_idx]; const uint8_t b_z = (QuantBZeroPointCol != nullptr) ? ((k_blk_idx & 1) == 1) - ? QuantBZeroPointCol[k_blk_idx / 2] >> 4 - : QuantBZeroPointCol[k_blk_idx / 2] & 0x0F + ? std::to_integer(QuantBZeroPointCol[k_blk_idx / 2] >> 4) + : std::to_integer(QuantBZeroPointCol[k_blk_idx / 2] & std::byte{0x0F}) : 8; for (size_t kk = 0; kk < kklen; ++kk) { - const uint8_t b_packed = b_data[kk / 2]; - const uint8_t b_byte = ((kk & 1) == 1) ? b_packed >> 4 : b_packed & 0x0F; - const float b_value = (b_byte - b_z) * b_s; + const size_t packed_idx = kk % SubBlkLen; + + const bool is_low_half = packed_idx < (SubBlkLen / 2); + const size_t packed_byte_idx = packed_idx % (SubBlkLen / 2); + const size_t packed_range_offset = (kk / SubBlkLen) * (SubBlkLen / 2); + + const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx]; + const std::byte b_byte = is_low_half ? (b_packed & std::byte{0x0F}) : (b_packed >> 4); + const float b_value = (std::to_integer(b_byte) - b_z) * b_s; Dst[(k + kk) * 16 + nn] = b_value; } @@ -448,31 +573,651 @@ MlasQNBitBlkDequantBForSgemmNeon( impl0_reference(); } -#define SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(BlkBitWidth, BlkLen) \ - template <> \ - MLAS_FORCEINLINE void \ - MlasQNBitBlkDequantBForSgemm( \ - float* FpData, \ - const uint8_t* QuantBData, \ - const float* QuantBScale, \ - const uint8_t* QuantBZeroPoint, \ - size_t CountN, \ - size_t CountK, \ - size_t BlockStrideQuantB \ - ) \ - { \ - MlasQNBitBlkDequantBForSgemmNeon( \ - FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB \ - ); \ +// +// CompInt8 kernel implementation. +// + +template +MLAS_FORCEINLINE void +QuantizeBlock( + size_t BlkLen, + const float* A, + size_t ElementCount, + std::byte* QuantA +) +{ + static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0); + + assert(BlkLen % SubBlkLen == 0); + + // + // Scan block values first to determine scale. + // + + float amax = 0.0f; // max of absolute values of A block + + size_t k; + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[SubBlkLen / 4]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + float32x4_t abs_a[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { + abs_a[i] = vabsq_f32(a[i]); + }); + + // find amax of SubBlkLen elements + for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { + for (size_t i = 0; i < interval; ++i) { + abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); + } + } + + // update existing amax + amax = std::max(amax, vmaxvq_f32(abs_a[0])); + } + + constexpr float range_max = (1 << 7) - 1; + const float scale = amax / range_max; + const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; + + Q8BlkScale(QuantA) = scale; + + // + // Compute quantized block values. + // + + int8_t* QuantAData = Q8BlkData(QuantA); + + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[SubBlkLen / 4]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + UnrolledLoop([&](size_t i) { + a[i] = vmulq_n_f32(a[i], scale_reciprocal); + }); + + int32x4_t a_s32[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { + a_s32[i] = vcvtaq_s32_f32(a[i]); + }); + + UnrolledLoop([&](size_t i) { + QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); + QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); + QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); + QuantAData[k + i * 4 + 3] = static_cast(vgetq_lane_s32(a_s32[i], 3)); + }); + } + + // + // Zero out any remaining sub-block elements. + // + + for (; k < BlkLen; k += SubBlkLen) { + const int8x16_t Zeros = vdupq_n_s8(0); + UnrolledLoop([&](size_t i) { + vst1q_s8(QuantAData + k + i * 16, Zeros); + }); + } +} + +void MLASCALL +QuantizeARow_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + const float* ADataBlkPtr = A; + std::byte* QuantABlkPtr = QuantA; + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); + + ADataBlkPtr += BlkLen; + QuantABlkPtr += Q8BlkSize(BlkLen); + } +} + +template +void +SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 16; + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); + + for (size_t n = 0; n < CountN; ++n) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8 + ); + const int8x16_t bzp1 = vdupq_n_s8( + HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] >> 4) : 8 + ); + + // load A + const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av1 = vld1q_s8(Q8BlkData(QuantABlk1)); + + // load B + const uint8x16_t bv_packed01 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + + const uint8x16_t bv_lo01 = vandq_u8(bv_packed01, LowMaskU8x16); + const uint8x16_t bv_hi01 = vshrq_n_u8(bv_packed01, 4); + + int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(vget_low_u8(bv_lo01), vget_low_u8(bv_hi01))); + int8x16_t bv1 = vreinterpretq_s8_u8(vcombine_u8(vget_high_u8(bv_lo01), vget_high_u8(bv_hi01))); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp0); + bv1 = vsubq_s8(bv1, bzp1); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); + const int32x4_t dot1 = vdotq_s32(vdupq_n_s32(0), av1, bv1); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 8 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8 + ); + + // load A + const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); + + // load B + const uint8x8_t bv_packed0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); + + const uint8x8_t bv_lo0 = vand_u8(bv_packed0, LowMaskU8x8); + const uint8x8_t bv_hi0 = vshr_n_u8(bv_packed0, 4); + + int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(bv_lo0, bv_hi0)); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp0); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +template +void +SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 32; + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + for (size_t n = 0; n < CountN; ++n) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + const int8x16_t bzp1 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 + ); + + // load A + const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); + const int8x16_t av_lo1 = vld1q_s8(Q8BlkData(QuantABlk1)); + const int8x16_t av_hi1 = vld1q_s8(Q8BlkData(QuantABlk1) + 16); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); + + int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + int8x16_t bv_lo1 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); + int8x16_t bv_hi1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); + + // subtract B zero point + bv_lo0 = vsubq_s8(bv_lo0, bzp0); + bv_hi0 = vsubq_s8(bv_hi0, bzp0); + bv_lo1 = vsubq_s8(bv_lo1, bzp1); + bv_hi1 = vsubq_s8(bv_hi1, bzp1); + + // quantized dot product + int32x4_t dot0{}, dot1{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); + dot1 = vdotq_s32(vdotq_s32(dot1, av_lo1, bv_lo1), av_hi1, bv_hi1); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPoint) & std::byte{0x0F}) : 8 + ); + + // load A + const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + + int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + + // subtract B zero point + bv_lo0 = vsubq_s8(bv_lo0, bzp0); + bv_hi0 = vsubq_s8(bv_hi0, bzp0); + + // quantized dot product + int32x4_t dot0{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; } +} -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 16) -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 32) -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 64) -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 128) -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256) +template +void +SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen > 32); + assert(BlkLen % 32 == 0); + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + // process blocks in 32-element sub-blocks + const size_t SubBlksPerBlk = BlkLen / 32; + + for (size_t n = 0; n < CountN; ++n) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { + // compute combined scale + const float32x4_t scale = vdupq_n_f32(Q8BlkScale(QuantAPtr) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp = [&]() -> int8x16_t { + if constexpr (HasZeroPoint) { + return vdupq_n_s8( + ((k_blk_idx & 1) == 0) ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) + : std::to_integer((*QuantBZeroPointPtr) >> 4) + ); + } else { + return vdupq_n_s8(8); + } + }(); + + const int8_t* QuantADataPtr = Q8BlkData(QuantAPtr); + + for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; sub_blk_idx += 2) { + // load A + const int8x16_t av0 = vld1q_s8(QuantADataPtr + 0); + const int8x16_t av1 = vld1q_s8(QuantADataPtr + 16); + const int8x16_t av2 = vld1q_s8(QuantADataPtr + 32); + const int8x16_t av3 = vld1q_s8(QuantADataPtr + 48); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); + + int8x16_t bv0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + int8x16_t bv2 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); + int8x16_t bv3 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp); + bv1 = vsubq_s8(bv1, bzp); + bv2 = vsubq_s8(bv2, bzp); + bv3 = vsubq_s8(bv3, bzp); + + // quantized dot product + int32x4_t dot0{}, dot1{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av0, bv0), av1, bv1); + dot1 = vdotq_s32(vdotq_s32(dot1, av2, bv2), av3, bv3); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale); + + // increment block data pointers to next sub-block + QuantADataPtr += 16 * 4; + QuantBDataPtr += 16 * 2; + } + + // increment other block pointers + + QuantAPtr += Q8BlkSize(BlkLen); + QuantBScalePtr += 1; + + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; + } + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } -#undef SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } +} + +MLAS_FORCEINLINE +void +SQ4BitGemmM1Kernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t /*CountK*/, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } +} + +} // namespace // // Kernel dispatch structure definition. @@ -480,10 +1225,15 @@ SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256) const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { MLAS_SQNBIT_GEMM_DISPATCH d; - d.Operations[QuantVariant_BitWidth4_BlockSize16] = MlasSQNBitGemmOperation<4, 16, MLAS_SQNBIT_GEMM_KERNEL_NEON>; - d.Operations[QuantVariant_BitWidth4_BlockSize32] = MlasSQNBitGemmOperation<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>; - d.Operations[QuantVariant_BitWidth4_BlockSize64] = MlasSQNBitGemmOperation<4, 64, MLAS_SQNBIT_GEMM_KERNEL_NEON>; - d.Operations[QuantVariant_BitWidth4_BlockSize128] = MlasSQNBitGemmOperation<4, 128, MLAS_SQNBIT_GEMM_KERNEL_NEON>; - d.Operations[QuantVariant_BitWidth4_BlockSize256] = MlasSQNBitGemmOperation<4, 256, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32; + + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + return d; }(); diff --git a/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp b/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp index 955b7c5deee9..43a12b37e4ff 100644 --- a/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp +++ b/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp @@ -171,11 +171,9 @@ Return Value: if (k > 0) { Row0AElements0 = a[0]; - Row0AElements1 = a[1]; if (ProcessTwoRows) { Row1AElements0 = a[lda]; - Row1AElements1 = a[lda + 1]; } BElements0 = MlasLoadFloat32x4(B + 0); diff --git a/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx512F.S b/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx512F.S new file mode 100644 index 000000000000..db9728604656 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx512F.S @@ -0,0 +1,101 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SoftmaxKernelAvx512F.s + +Abstract: + + This module implements the kernels for the single precision softmax + operation. + + This implementation uses AVX512F instructions. + +--*/ + +#include "asmmacro.h" + + .intel_syntax noprefix + + .text + +/*++ + +Routine Description: + + This routine implements a vectorized kernel to find the maximum value of + the supplied buffer. + +Arguments: + + Input (rdi) - Supplies the input buffer. + + N (rsi) - Supplies the number of elements to process. + +Return Value: + + Returns the maximum value of the supplied buffer. + +--*/ + + FUNCTION_ENTRY MlasReduceMaximumF32KernelAvx512F + + vbroadcastss zmm0,DWORD PTR C_UNDERSCORE(MlasMinimumF32Value)[rip] + test rsi,rsi + jz .LReduceMaximum.ExitKernel + cmp rsi,16 + jb .LReduceMaximum.ProcessRemainingCountBy1 + cmp rsi,64 + jb .LReduceMaximum.ProcessRemainingCountBy16 + vmovaps zmm1,zmm0 + vmovaps zmm2,zmm0 + vmovaps zmm3,zmm0 + +.LReduceMaximum.ProcessRemainingCountBy64: + vmaxps zmm0,zmm0,ZMMWORD PTR [rdi] + vmaxps zmm1,zmm1,ZMMWORD PTR [rdi+16*4] + sub rsi,64 + vmaxps zmm2,zmm2,ZMMWORD PTR [rdi+32*4] + vmaxps zmm3,zmm3,ZMMWORD PTR [rdi+48*4] + add rdi,64*4 # advance input by 64 elements + cmp rsi,64 + jae .LReduceMaximum.ProcessRemainingCountBy64 + vmaxps zmm0,zmm0,zmm1 # reduce to single vector + vmaxps zmm2,zmm2,zmm3 + vmaxps zmm0,zmm0,zmm2 + +.LReduceMaximum.ProcessRemainingCountBy16: + cmp rsi,16 + jb .LReduceMaximum.ProcessRemainingCountLessThan16 + vmaxps zmm0,zmm0,ZMMWORD PTR [rdi] + sub rsi,16 + add rdi,16*4 # advance input by 16 elements + jmp .LReduceMaximum.ProcessRemainingCountBy16 + +.LReduceMaximum.ProcessRemainingCountLessThan16: + vextractf32x8 ymm1,zmm0,1 # reduce to single scalar + vmaxps ymm0,ymm0,ymm1 + vextractf128 xmm1,ymm0,1 + vmaxps xmm0,xmm0,xmm1 + vshufps xmm1,xmm0,xmm0,0xEE + vmaxps xmm0,xmm0,xmm1 + vshufps xmm1,xmm0,xmm0,0x55 + vmaxss xmm0,xmm0,xmm1 + test rsi,rsi + jz .LReduceMaximum.ExitKernel + +.LReduceMaximum.ProcessRemainingCountBy1: + vmaxss xmm0,xmm0,DWORD PTR [rdi] + add rdi,4 # advance input by 1 element + dec esi + jnz .LReduceMaximum.ProcessRemainingCountBy1 + +.LReduceMaximum.ExitKernel: + vzeroupper + ret + + .end diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format deleted file mode 100644 index 84b876706161..000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format +++ /dev/null @@ -1,7 +0,0 @@ -Language: Cpp -BasedOnStyle: Google -DerivePointerAlignment: false -ColumnLimit: 120 -SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -SortIncludes: false diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt deleted file mode 100644 index 5d9c5edf45a9..000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt +++ /dev/null @@ -1,33 +0,0 @@ -cmake_minimum_required(VERSION 3.5) - -project(jblas LANGUAGES CXX VERSION 0.1.0) - -file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp) -file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) - -add_library(${PROJECT_NAME} INTERFACE) -add_library(${PROJECT_NAME}::${PROJECT_NAME} ALIAS ${PROJECT_NAME}) - -target_include_directories( - ${PROJECT_NAME} INTERFACE - "$" - "$" -) - -if(WIN32) - target_compile_definitions(${PROJECT_NAME} INTERFACE _CRT_SECURE_NO_WARNINGS NOMINMAX) - target_compile_options(${PROJECT_NAME} INTERFACE /wd4068 /wd4849 /wd6262 /wd4702 /wd4100) - #4068 ignore unroll and GCC flags - #4849 ignore collapse - #6262 ignore stack too large - #4702 unreachable code(false warning on constexpr condition) - #4100 unreferenced formal parameter - - target_link_options(${PROJECT_NAME} INTERFACE /STACK:3145728) #Stack requires up to L2 cache size -endif(WIN32) - - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17) diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h deleted file mode 100644 index 143adb771760..000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h +++ /dev/null @@ -1,303 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include -#include -#include "xbyak/xbyak.h" -#include "xbyak/xbyak_util.h" - -#define OFFSET(field) offsetof(params, field) - -namespace jblas { - -namespace xbyak { -class JitBase : protected Xbyak::CodeGenerator { - protected: - JitBase(size_t size = 16 * 1024) : CodeGenerator(size) {} - - void load32(const Xbyak::Reg64& reg, const Xbyak::Address& addr) { - xor_(reg, reg); - mov(reg.cvt32(), addr); - } - - void vreg_push(const Xbyak::Reg64& baseaddr) { -#ifdef _WIN32 - for (int i = 0; i < 10; i++) { - movaps(xword[baseaddr + i * 16], Xbyak::Xmm(6 + i)); - } -#endif - } - - void vreg_pop(const Xbyak::Reg64& baseaddr) { -#ifdef _WIN32 - for (int i = 0; i < 10; i++) { - movaps(Xbyak::Xmm(6 + i), xword[baseaddr + i * 16]); - } -#endif - } - - void padto_le(const Xbyak::Reg64& _src, int padding) { - // _src=_src/padding*padding - if (padding == 1) { - return; - } - for (int i = 1; i < 16; i++) { - if ((1 << i) == padding) { - shr(_src, i); - shl(_src, i); - return; - } - } - assert(0); - } - - void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total, - const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { - inLocalLabel(); - lea(_tmp, _total); - sub(_tmp, _pos); - cmp(_tmp, N); - jb(".maskflag"); - cmp(_tmp, 0); - jl(".zeroflag"); - uint64_t allmask = (static_cast(1) << N) - 1; - if (N == 64) { - allmask = static_cast(-1); - } - mov(_tmp, allmask); - kmovq(_msk, _tmp); - jmp(".maskend"); - L(".maskflag"); - mov(_tmp1, 1); - shlx(_tmp1, _tmp1, _tmp); - sub(_tmp1, 1); - kmovq(_msk, _tmp1); - jmp(".maskend"); - L(".zeroflag"); - mov(_tmp1, 0); - kmovq(_msk, _tmp1); - L(".maskend"); - outLocalLabel(); - } - void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Reg64& _total, - const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { - generate_Nbitsmask(_msk, _pos, ptr[_total], _tmp, _tmp1, N); - } -}; - -class JitAvx : protected JitBase { - protected: - static int constexpr VBits = 256; - static int constexpr VecBytes = VBits / 8; - static int constexpr RegCount = 16; - typedef Xbyak::Ymm vreg_t; -}; - -class JitAvx2 : protected JitAvx { - protected: - static int constexpr VBits = 256; - typedef Xbyak::Ymm vreg_t; - void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); } - - void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) { - vpmovzxwd(dst, addr); - vpslld(dst, dst, 16); - } -}; - -class JitAvx512f : protected JitAvx2 { - protected: - static int constexpr VBits = 512; - static int constexpr VecBytes = VBits / 8; - static int constexpr RegCount = 32; - typedef Xbyak::Zmm vreg_t; - - void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); } - - void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) { - vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]); - vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]); - vshuff32x4(src_2regs[0], tmp_2reg[0], tmp_2reg[1], 0 | (1 << 2) | (0 << 4) | (1 << 6)); - vshuff32x4(src_2regs[0], src_2regs[0], src_2regs[0], 0 | (2 << 2) | (1 << 4) | (3 << 6)); - vshuff32x4(src_2regs[1], tmp_2reg[0], tmp_2reg[1], 2 | (3 << 2) | (2 << 4) | (3 << 6)); - vshuff32x4(src_2regs[1], src_2regs[1], src_2regs[1], 0 | (2 << 2) | (1 << 4) | (3 << 6)); - } - - void transpose16x16_4B(Xbyak::Zmm* src, Xbyak::Zmm* tmp, const int N = 16) { - for (int i = 0; i < 8; ++i) { - vpunpckldq(tmp[2 * i + 0], src[2 * i], src[2 * i + 1]); - vpunpckhdq(tmp[2 * i + 1], src[2 * i], src[2 * i + 1]); - } - - for (int i = 0; i < 4; ++i) { - vpunpcklqdq(src[4 * i + 0], tmp[4 * i + 0], tmp[4 * i + 2]); - vpunpckhqdq(src[4 * i + 1], tmp[4 * i + 0], tmp[4 * i + 2]); - vpunpcklqdq(src[4 * i + 2], tmp[4 * i + 1], tmp[4 * i + 3]); - vpunpckhqdq(src[4 * i + 3], tmp[4 * i + 1], tmp[4 * i + 3]); - } - - for (int i = 0; i < 2; ++i) { - vshufi32x4(tmp[8 * i + 0], src[8 * i + 0], src[8 * i + 4], 0x88); - vshufi32x4(tmp[8 * i + 1], src[8 * i + 1], src[8 * i + 5], 0x88); - vshufi32x4(tmp[8 * i + 2], src[8 * i + 2], src[8 * i + 6], 0x88); - vshufi32x4(tmp[8 * i + 3], src[8 * i + 3], src[8 * i + 7], 0x88); - vshufi32x4(tmp[8 * i + 4], src[8 * i + 0], src[8 * i + 4], 0xdd); - vshufi32x4(tmp[8 * i + 5], src[8 * i + 1], src[8 * i + 5], 0xdd); - vshufi32x4(tmp[8 * i + 6], src[8 * i + 2], src[8 * i + 6], 0xdd); - vshufi32x4(tmp[8 * i + 7], src[8 * i + 3], src[8 * i + 7], 0xdd); - } - - // last step and move out - for (int i = 0; i < N; ++i) { - vshufi32x4(src[i], tmp[i % 8], tmp[8 + i % 8], i < 8 ? 0x88 : 0xdd); - } - } - - void interleave_4rows_6regs(Xbyak::Zmm* src_4regs, Xbyak::Zmm* tmp_regs, const Xbyak::Opmask* masks) { - vpunpcklbw(tmp_regs[0], src_4regs[0], src_4regs[1]); - vpunpckhbw(tmp_regs[1], src_4regs[0], src_4regs[1]); - vpunpcklbw(tmp_regs[2], src_4regs[2], src_4regs[3]); - vpunpckhbw(tmp_regs[3], src_4regs[2], src_4regs[3]); - - vpunpcklwd(tmp_regs[4], tmp_regs[0], tmp_regs[2]); - vpunpckhwd(tmp_regs[5], tmp_regs[0], tmp_regs[2]); - vpunpcklwd(tmp_regs[0], tmp_regs[1], tmp_regs[3]); - vpunpckhwd(tmp_regs[2], tmp_regs[1], tmp_regs[3]); - vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (4 << 4) | 4); - vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (4 << 4) | 4); - vmovups(src_4regs[0], tmp_regs[1]); - vshuff32x4(src_4regs[0] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); - vmovups(src_4regs[1], tmp_regs[3]); - vshuff32x4(src_4regs[1] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); - vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (14 << 4) | 14); - vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (14 << 4) | 14); - vmovups(src_4regs[2], tmp_regs[1]); - vshuff32x4(src_4regs[2] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); - vmovups(src_4regs[3], tmp_regs[3]); - vshuff32x4(src_4regs[3] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); - } - - void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { - vpsrld(_fp32, _fp32, 16); - vpmovdw(_bf16, _fp32); - } - - void loadbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Address& addr) { - vpmovzxwd(dst, addr); - vpslld(dst, dst, 16); - } - - void broadcastbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Reg64& tmp, const Xbyak::Address& addr) { - mov(tmp.cvt16(), addr); - shl(tmp.cvt32(), 16); - vpbroadcastd(dst, tmp.cvt32()); - } - - void store_fp32_bf16(const Xbyak::Zmm& _fp32, const Xbyak::Address& _add) { - auto bf16 = Xbyak::Ymm(_fp32.getIdx()); - cvt_fp32_bf16(bf16, _fp32); - vmovups(_add, bf16); - } -}; - -class JitAvx512_bf16 : protected JitAvx512f {}; - -class JitAvx512_fp16 : protected JitAvx512f {}; - -class JitAvx512vnni : protected JitAvx512f { - protected: - void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { - vpdpbusds(x1, x2, op, Xbyak::EvexEncoding); - } -}; - -class JitAvxvnni : protected JitAvx2 { - protected: - void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { - vpdpbusds(x1, x2, op, Xbyak::VexEncoding); - } -}; - -class JitAmxtile : protected JitAvx512f { - public: - struct alignas(64) tileconfig_t { - uint8_t palette_id; - uint8_t reserved[15]; - uint16_t colb[16]; - uint8_t rows[16]; - }; - static int constexpr TileCount = 8; - - typedef long long (*configure_t)(void*); - - static void generate_config(Xbyak::CodeGenerator* g) { - Xbyak::util::StackFrame st(g, 1, 0, 0); - auto& parambase = st.p[0]; - g->ldtilecfg(g->ptr[parambase]); - } - - static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, - int CNum) { - // Filling tile configure structure. Could be done offline. - tc.palette_id = 1; - // Configure C tiles - int t = 0; - for (; t < CNum; ++t) { - tc.rows[t] = static_cast(TILE_M); - tc.colb[t] = static_cast(TILE_N * 4); - } - // Configure A tiles - for (; t < CNum + ANum; ++t) { - tc.rows[t] = static_cast(TILE_M); - tc.colb[t] = static_cast(TILE_K * elesize); - } - // Configure B tile. B effectively has 64 rows and 16 columns. - int kpack = 4 / elesize; - for (; t < CNum + ANum + BNum; ++t) { - tc.rows[t] = static_cast(TILE_K / kpack); - tc.colb[t] = static_cast(TILE_N * 4); - } - } -}; - -class JitAmxbf16 : protected JitAmxtile { - protected: - void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { vcvtneps2bf16(_bf16, _fp32); } -}; - -class JitAmxint8 : protected JitAmxtile { - protected: - template - void _tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3); -}; -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbssd(x1, x2, x3); -} -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbsud(x1, x2, x3); -} -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbusd(x1, x2, x3); -} -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbuud(x1, x2, x3); -} -} // namespace xbyak -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h deleted file mode 100644 index 8ecf3535c17f..000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include -enum JBLAS_CODE { - JblasSuccess = 0, - JblasInvalidParam = 1, - JblasInvalidISA = 2, - JblasRuntimeError = 4, - JblasNotSupport = 8, -}; -enum JBLAS_ISA : uint32_t { - JblasNoSIMD = 0, - JblasAVX, - JblasAVX2, - JblasAVX_VNNI, - JblasAVX512F, - JblasAVX512_VNNI, - JblasAMX_BF16, - JblasAMX_INT8, - JblasAVX512_FP16, - JblasAVX512_BF16, -}; -enum class JBLAS_DTYPE : uint32_t { - EleBitsMask = 0xff, - EleBitsUndef = 0, - EleBits4 = 4, - EleBits8 = 8, - EleBits16 = 16, - EleBits32 = 32, - EleBits64 = 64, - TypeMask = 0xff00, - TypeFloat = 0 << 8, - TypeInt = 1 << 8, - SubTypeMask = 0xff0000, - SubType0 = 0 << 16, - SubType1 = 1 << 16, - SubType2 = 2 << 16, - F64 = EleBits64 | TypeFloat, - F32 = EleBits32 | TypeFloat, - F16 = EleBits16 | TypeFloat, - BF16 = EleBits16 | TypeFloat | SubType1, - F8_E4M3 = EleBits8 | TypeFloat, - F8_E5M2 = EleBits8 | TypeFloat | SubType1, - F8_E3M4 = EleBits8 | TypeFloat | SubType2, - S8 = EleBits8 | TypeInt, - U8 = EleBits8 | TypeInt | SubType1, - S4_CLIP = EleBits4 | TypeInt, - S4_FULLRANGE = EleBits4 | TypeInt | SubType1, - F4_E2M1 = EleBits4 | TypeFloat, - F4_BNB = EleBits4 | TypeFloat | SubType1, - F4_NF4 = EleBits4 | TypeFloat | SubType2, - S32 = EleBits32 | TypeInt, - U32 = EleBits32 | TypeInt | SubType1, -}; - -enum JBLAS_LAYOUT { JblasRowMajor = 101, JblasColMajor = 102 }; -enum JBLAS_TRANSPOSE { - JblasNoTrans = 111, - JblasTrans = 112, - JblasConjTrans = 113, -}; -enum JBLAS_ELTWISEOP { - GELU, - SWISH, - TANH, - EXP, - LOW_PRECISION_EXP, - RELU, - LINEAR, -}; - -enum class JBLAS_PROLOGUEB_IDS : uint32_t { - Undef = (uint32_t)-1, - Begin = 0, - NormalBegin = Begin, - WeightPack = NormalBegin, - NormalEnd, - KBlockBegin = NormalEnd, - WeightKBlockS8 = KBlockBegin, - WeightKBlockS4, - WeightKBlockF4, - KBlockEnd, - End, -}; diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h deleted file mode 100644 index 5cac1080bc61..000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include "jit_blas.h" -#include "xbyak/xbyak_util.h" - -namespace jblas { - -namespace device { - -struct X64_ISA { - int64_t MMX : 1; // 0 - int64_t SSE : 1; // 1 - int64_t SSE2 : 1; // 2 - int64_t SSE3 : 1; // 3 - int64_t SSSE3 : 1; // 4 - int64_t SSE41 : 1; // 5 - int64_t SSE42 : 1; // 6 - int64_t AVX : 1; // 7 - int64_t F16C : 1; // 8 - int64_t FMA : 1; // 9 - int64_t AVX2 : 1; // 10 - int64_t AVX_VNNI : 1; // 11 - int64_t AVX_VNNI_INT8 : 1; // 12 - int64_t AVX_NE_CONVERT : 1; // 13 - int64_t AVX_IFMA : 1; // 14 - int64_t AVX512F : 1; // 15 - int64_t AVX512BW : 1; // 16 - int64_t AVX512CD : 1; // 17 - int64_t AVX512DQ : 1; // 18 - int64_t AVX512ER : 1; // 19 - int64_t AVX512IFMA52 : 1; // 20 - int64_t AVX512PF : 1; // 21 - int64_t AVX512VL : 1; // 22 - int64_t AVX512VPOPCNTDQ : 1; // 23 - int64_t AVX512_4FMAPS : 1; // 24 - int64_t AVX512_4VNNIW : 1; // 25 - int64_t AVX512_BF16 : 1; // 26 - int64_t AVX512_BITALG : 1; // 27 - int64_t AVX512_VBMI : 1; // 28 - int64_t AVX512_VBMI2 : 1; // 29 - int64_t AVX512_VNNI : 1; // 30 - int64_t AVX512_VP2INTERSECT : 1; // 31 - int64_t AVX512_FP16 : 1; // 32 - int64_t AMX_TILE : 1; // 33 - int64_t AMX_BF16 : 1; // 34 - int64_t AMX_INT8 : 1; // 35 - int64_t AMX_FP16 : 1; // 36 - int64_t AMX_COMPLEX : 1; // 37 - int64_t reserved : (64 - 38); -}; - -class AVX2_Default { - public: - static constexpr bool MMX = 1; - static constexpr bool SSE = 1; - static constexpr bool SSE2 = 1; - static constexpr bool SSE3 = 1; - static constexpr bool SSSE3 = 1; - static constexpr bool SSE41 = 1; - static constexpr bool SSE42 = 1; - static constexpr bool AVX = 1; - static constexpr bool F16C = 1; - static constexpr bool FMA = 1; - static constexpr bool AVX2 = 1; - static constexpr bool AVX_VNNI = 0; - static constexpr bool AVX_VNNI_INT8 = 0; - static constexpr bool AVX_NE_CONVERT = 0; - static constexpr bool AVX_IFMA = 0; - static constexpr bool AVX512F = 0; - static constexpr bool AVX512BW = 0; - static constexpr bool AVX512CD = 0; - static constexpr bool AVX512DQ = 0; - static constexpr bool AVX512ER = 0; - static constexpr bool AVX512IFMA52 = 0; - static constexpr bool AVX512PF = 0; - static constexpr bool AVX512VL = 0; - static constexpr bool AVX512VPOPCNTDQ = 0; - static constexpr bool AVX512_4FMAPS = 0; - static constexpr bool AVX512_4VNNIW = 0; - static constexpr bool AVX512_BF16 = 0; - static constexpr bool AVX512_BITALG = 0; - static constexpr bool AVX512_VBMI = 0; - static constexpr bool AVX512_VBMI2 = 0; - static constexpr bool AVX512_VNNI = 0; - static constexpr bool AVX512_VP2INTERSECT = 0; - static constexpr bool AVX512_FP16 = 0; - static constexpr bool AMX_TILE = 0; - static constexpr bool AMX_BF16 = 0; - static constexpr bool AMX_INT8 = 0; - static constexpr bool AMX_FP16 = 0; - static constexpr bool AMX_COMPLEX = 0; -}; - -class AVX512_VNNI_Default { - public: - static constexpr bool MMX = 1; - static constexpr bool SSE = 1; - static constexpr bool SSE2 = 1; - static constexpr bool SSE3 = 1; - static constexpr bool SSSE3 = 1; - static constexpr bool SSE41 = 1; - static constexpr bool SSE42 = 1; - static constexpr bool AVX = 1; - static constexpr bool F16C = 1; - static constexpr bool FMA = 1; - static constexpr bool AVX2 = 1; - static constexpr bool AVX_VNNI = 0; - static constexpr bool AVX_VNNI_INT8 = 0; - static constexpr bool AVX_NE_CONVERT = 0; - static constexpr bool AVX_IFMA = 0; - static constexpr bool AVX512F = 1; - static constexpr bool AVX512BW = 1; - static constexpr bool AVX512CD = 1; - static constexpr bool AVX512DQ = 1; - static constexpr bool AVX512ER = 0; - static constexpr bool AVX512IFMA52 = 0; - static constexpr bool AVX512PF = 0; - static constexpr bool AVX512VL = 1; - static constexpr bool AVX512VPOPCNTDQ = 0; - static constexpr bool AVX512_4FMAPS = 0; - static constexpr bool AVX512_4VNNIW = 0; - static constexpr bool AVX512_BF16 = 0; - static constexpr bool AVX512_BITALG = 0; - static constexpr bool AVX512_VBMI = 0; - static constexpr bool AVX512_VBMI2 = 0; - static constexpr bool AVX512_VNNI = 1; - static constexpr bool AVX512_VP2INTERSECT = 0; - static constexpr bool AVX512_FP16 = 0; - static constexpr bool AMX_TILE = 0; - static constexpr bool AMX_BF16 = 0; - static constexpr bool AMX_INT8 = 0; - static constexpr bool AMX_FP16 = 0; - static constexpr bool AMX_COMPLEX = 0; -}; - -class SapphireRapids { - public: - static constexpr bool MMX = 1; - static constexpr bool SSE = 1; - static constexpr bool SSE2 = 1; - static constexpr bool SSE3 = 1; - static constexpr bool SSSE3 = 1; - static constexpr bool SSE41 = 1; - static constexpr bool SSE42 = 1; - static constexpr bool AVX = 1; - static constexpr bool F16C = 1; - static constexpr bool FMA = 1; - static constexpr bool AVX2 = 1; - static constexpr bool AVX_VNNI = 0; - static constexpr bool AVX_VNNI_INT8 = 0; - static constexpr bool AVX_NE_CONVERT = 0; - static constexpr bool AVX_IFMA = 0; - static constexpr bool AVX512F = 1; - static constexpr bool AVX512BW = 1; - static constexpr bool AVX512CD = 1; - static constexpr bool AVX512DQ = 1; - static constexpr bool AVX512ER = 0; - static constexpr bool AVX512IFMA52 = 0; - static constexpr bool AVX512PF = 0; - static constexpr bool AVX512VL = 1; - static constexpr bool AVX512VPOPCNTDQ = 0; - static constexpr bool AVX512_4FMAPS = 0; - static constexpr bool AVX512_4VNNIW = 0; - static constexpr bool AVX512_BF16 = 0; - static constexpr bool AVX512_BITALG = 0; - static constexpr bool AVX512_VBMI = 0; - static constexpr bool AVX512_VBMI2 = 0; - static constexpr bool AVX512_VNNI = 1; - static constexpr bool AVX512_VP2INTERSECT = 0; - static constexpr bool AVX512_FP16 = 0; - static constexpr bool AMX_TILE = 1; - static constexpr bool AMX_BF16 = 1; - static constexpr bool AMX_INT8 = 1; - static constexpr bool AMX_FP16 = 0; - static constexpr bool AMX_COMPLEX = 0; -}; - -template -class isa_base { - public: - static bool constexpr avx = ISA_T >= JblasAVX; - static bool constexpr avx2 = ISA_T >= JblasAVX2; - static bool constexpr avx512f = ISA_T >= JblasAVX512F; - static bool constexpr avx512_vnni = ISA_T >= JblasAVX512_VNNI; - static bool constexpr avx512_fp16 = ISA_T >= JblasAVX512_FP16; - static bool constexpr amx_bf16 = ISA_T >= JblasAMX_BF16; - static bool constexpr amx_int8 = ISA_T >= JblasAMX_INT8; -}; - -class CpuDevice { - public: - inline void setThreads(int _nth) { - if (_nth <= 0) { - numthreads = numcores; - } else { - numthreads = std::min(numcores, _nth); - } - } - inline int getThreads() { return numthreads; } - inline int getCores() { return numcores; } - inline uint32_t getL2CacheSize() { return L2Cache; } - inline uint32_t getL1CacheSize() { return L1Cache; } - inline bool AVX() { return mHasAVX; } - inline bool AVX2() { return mHasAVX2; } - inline bool AVX_VNNI() { return mHasAVX_VNNI; } - inline bool AVX512F() { return mHasAVX512F; } - inline bool AVX512_VNNI() { return mHasAVX512_VNNI; } - inline bool AMX_INT8() { return mHasAMX_INT8; } - inline bool AMX_BF16() { return mHasAMX_BF16; } - inline bool AVX512_BF16() { return mHasAVX512_BF16; } - inline bool AVX512_FP16() { return mHasAVX512_FP16; } -#define ADD_FLAG(isa) mHas##isa = _cpu.has(_cpu.t##isa) - CpuDevice() { - static Xbyak::util::Cpu _cpu; - L1Cache = _cpu.getDataCacheSize(0); - L2Cache = _cpu.getDataCacheSize(1); - ADD_FLAG(AVX); - ADD_FLAG(AVX2); - ADD_FLAG(AVX512F); - ADD_FLAG(AVX512_VNNI); - ADD_FLAG(AVX_VNNI); - ADD_FLAG(AMX_BF16); - ADD_FLAG(AMX_INT8); - ADD_FLAG(AVX512_BF16); - ADD_FLAG(AVX512_FP16); - numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); - numthreads = numcores; - } - - static CpuDevice* getInstance() { - static CpuDevice instance; - return &instance; - } - - void print() { - printf( - "AVX:%d AVX2:%d AVX512F:%d AVX_VNNI:%d AVX512_VNNI:%d AMX_INT8:%d AMX_BF16:%d AVX512_BF16:%d AVX512_FP16:%d\n", - mHasAVX, mHasAVX2, mHasAVX512F, mHasAVX_VNNI, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512_BF16, - mHasAVX512_FP16); - } -#undef ADD_FLAG - - protected: - uint32_t L2Cache, L1Cache; - bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512_BF16, - mHasAVX512_FP16; - int numcores; - int numthreads; -}; - -#define GetCPUDevice() auto _cd = jblas::device::CpuDevice::getInstance(); - -class CpuBase { - public: - CpuBase() { - GetCPUDevice(); - mL2Cache = _cd->getL2CacheSize(); - mL1Cache = _cd->getL1CacheSize(); - mNumThreads = _cd->getThreads(); - } - size_t mL2Cache, mL1Cache; - int mNumThreads; -}; -} // namespace device -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h deleted file mode 100644 index ceb7a545092d..000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h +++ /dev/null @@ -1,329 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include "jit_base.h" -#include "jit_blas.h" -#include "jit_blas_utils.h" -#include "kernel_wrapper.h" - -namespace jblas { -namespace epilogue { -namespace gemm { - -template -class AccumulatorWriteBack { - public: - using SType = _SRC_T; - using DType = _DST_T; - struct Param { - DType* C; - int ldc; - void* elt_const_v; - }; - - template - JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize, Eltops... ops) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - bool constexpr Valid = !std::is_same::value || std::is_same::value; - static_assert(Valid, "fp32 to bf16 conversion only."); - if constexpr (std::is_same::value) { - return kernel::wrapper::Memcpy2DFp32CvtBf16::template forward( - const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); - } else if constexpr (std::is_same, std::tuple>::value) { - return kernel::wrapper::Memcpy2DFp16CvtFp32::template forward( - const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); - } else if constexpr (sizeof(SType) == sizeof(DType)) { - return kernel::wrapper::Memcpy2D::template forward(cacheptr, cptr, M, N, cachestep, - _param.ldc, _param.elt_const_v, ops...); - } else { - assert(false); - } - } -}; - -template -class CustomAccumulatorWriteBackWithEltop { - public: - struct Param { - _DST_T* C; - int ldc; - void* elt_const_v; - }; - JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) { - return kernel::wrapper::Memcpy2D::template forward1(cacheptr, cptr, M, N, cachestep, - _param.ldc, _param.elt_const_v); - } else { - assert(false); - } - } -}; -template -using AccumulatorWriteBackFp32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackInt32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackBf16 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp16 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack; - -template -using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop; - -template -using AccumulatorWriteBackWithSwishFp32 = CustomAccumulatorWriteBackWithEltop; - -template -class AlphaBetaProcessFp32 { - public: - struct Param { - float *C, *D; - int ldc, ldd; - float alpha, beta; - }; - - JBLAS_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto DOffset = M_offset * _param.ldd + N_offset; - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - auto dptr = _param.D + DOffset; - return kernel::wrapper::AlphaBetaF32F32::template forward(_param.alpha, cacheptr, cachestep, _param.beta, - dptr, _param.ldd, cptr, _param.ldc, M, N); - } -}; - -template -class CompFp32BlockEpilogue { - public: - struct Param { - void* scales; - JBLAS_DTYPE scaledtype; - int ldsb; - int8_t* zps = nullptr; - float* reduce = nullptr; - int ldra; - }; - JBLAS_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, - const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, - size_t cachesize) { - auto ret = JblasNotSupport; - if (_param.scaledtype == JBLAS_DTYPE::F32) { - ret = kernel::wrapper::CompFp32BlockScale::template forward( - reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, - cachestep, M, N); - assert(ret == JblasSuccess); - if (_param.zps != nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::forward_wei( - dstptr, cachestep, M, N, _param.zps + K_offset * _param.ldsb + N_offset, - reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, _param.ldra, - _param.reduce + M_offset * _param.ldra + K_offset); - } - assert(ret == JblasSuccess); - return ret; - } else if (_param.scaledtype == JBLAS_DTYPE::BF16) { - ret = kernel::wrapper::CompFp32BlockScale::template forward( - reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, - cachestep, M, N); - assert(_param.zps == nullptr); - assert(ret == JblasSuccess); - return ret; - } - return JblasNotSupport; - } -}; - -template -class DequantInt32ToFp32 { - public: - struct Param { - float* C; - int ldc; - int ldsa; - float* scalesA; - float* scalesB; - }; - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - return kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, - _param.scalesA + M_offset * _param.ldsa, _param.ldsa, - _param.scalesB + N_offset); - } -}; - -template -class CompInt8BlockEpilogue { - public: - struct Param { - void* scalesB; - JBLAS_DTYPE scaleBdtype; - int ldsb; - float* scalesA; - int ldsa; - // optional if A asym - uint8_t* zpA = nullptr; - void* reduceB = nullptr; - JBLAS_DTYPE reduceBdtype = JBLAS_DTYPE::F32; - // optional if B asym - int8_t* zpB = nullptr; - float* reduceA = nullptr; - int K = 1; - }; - JBLAS_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, - const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, - size_t cachesize) { - JBLAS_CODE ret = JblasNotSupport; - float* scab = nullptr; - size_t ScaleBTmpSize = N * sizeof(float); - size_t ReduceBTmpSize = N * sizeof(float); - assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize)); - if (_param.scaleBdtype == JBLAS_DTYPE::BF16) { - auto scache = reinterpret_cast(tmpcache); - ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( - reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N, - false); - assert(ret == JblasSuccess); - scab = scache; - } else if (_param.scaleBdtype == JBLAS_DTYPE::F32) { - scab = reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb; - } - float* redb = nullptr; - if (_param.reduceB) { - if (_param.reduceBdtype == JBLAS_DTYPE::BF16) { - auto rcache = reinterpret_cast(reinterpret_cast(tmpcache) + ScaleBTmpSize); - ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( - reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N, - false); - assert(ret == JblasSuccess); - redb = rcache; - } else if (_param.reduceBdtype == JBLAS_DTYPE::F32) { - redb = reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb; - } - } - ret = kernel::wrapper::DequanS32Fp32::template forward( - srcptr, cachestep, reinterpret_cast(const_cast(srcptr)), cachestep, M, N, - _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab); - assert(ret == JblasSuccess); - ret = kernel::wrapper::AccumulateFp32::template forward(reinterpret_cast(srcptr), cachestep, - dstptr, cachestep, M, N); - assert(ret == JblasSuccess); - - if (_param.zpA == nullptr) { - if (_param.zpB == nullptr) { - return ret; - } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( - dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa, - _param.reduceA + M_offset * _param.ldsa + K_offset); - } - } else { - if (_param.zpB == nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( - dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, - _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb); - } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( - dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, - _param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab, - _param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb); - } - } - return ret; - } -}; - -template -class ZpDequantInt32ToFp32 { - public: - struct Param { - // necessary - float* C; - int ldc; - int ldsa; - float* scalesA; - float* scalesB; - // optional if A asym - uint8_t* zpA = nullptr; - float* reduceB = nullptr; - // optional if B asym - int8_t* zpB = nullptr; - float* reduceA = nullptr; - int K = 1; - }; - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - auto ret = kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, - _param.scalesA + M_offset * _param.ldsa, - _param.ldsa, _param.scalesB + N_offset); - if (ret != JblasSuccess) { - return ret; - } - if (_param.zpA == nullptr && _param.zpB == nullptr) { - return ret; - } else if (_param.zpA != nullptr && _param.zpB == nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( - cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa, - _param.ldsa, _param.reduceB + N_offset); - } else if (_param.zpA == nullptr && _param.zpB != nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( - cptr, _param.ldc, M, N, _param.zpB + N_offset, _param.scalesB + N_offset, _param.ldsa, - _param.reduceA + M_offset * _param.ldsa); - } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( - cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.zpB + N_offset, - _param.scalesA + M_offset * _param.ldsa, _param.scalesB + N_offset, _param.ldsa, _param.K, - _param.reduceA + M_offset * _param.ldsa, _param.reduceB + N_offset); - } - return ret; - } -}; - -template -class AlphaBetaProcessS32U8 { - public: - struct Param { - uint8_t* C; - int ldc; - float alpha; - float scaleAcc, scaleC; - int zpC; - }; - - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - return kernel::wrapper::QuanOutS32U32::template forward(_param.alpha, cacheptr, cachestep, cptr, _param.ldc, - M, N, _param.scaleAcc, _param.scaleC, _param.zpC); - } -}; - -} // namespace gemm -} // namespace epilogue -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h deleted file mode 100644 index 364da9223940..000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h +++ /dev/null @@ -1,2699 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include "jit_blas_utils.h" -#include "jit_base.h" - -namespace jblas { -namespace gemm { -enum class CompType : uint32_t { - COMP_FP32 = 0, - COMP_BF16_FP32 = 1, - COMP_FP16_FP16 = 2, - COMP_INT_START = 3, - COMP_INT8_US_INT32 = COMP_INT_START, - COMP_INT8_UU_INT32 = 4, - COMP_INT8_SS_INT32 = 5, - COMP_INT8_SU_INT32 = 6, - COMP_INT16_SS_INT32 = 7, - COMP_INT8_US_FP32 = 8, - COMP_INT8_UU_FP32 = 9, - COMP_INT8_SS_FP32 = 10, - COMP_INT8_SU_FP32 = 11, -}; - -class CoreAttr { - public: - // INT32=LSB|**8bits:NTile**||**8bits:PackRow**||**8bits:CompType**||**8bits:Reserve**| - static uint32_t constexpr NTILE_MASK = 0xff, NTILE_SHIFT = 0, PACKROW_MASK = 0xff00, PACKROW_SHIFT = 8, - COMP_MASK = 0xff0000, COMP_SHIFT = 16, ISA_MASK = 0xff000000, ISA_SHIFT = 24; - - static inline uint32_t get_mask_val(uint32_t raw, uint32_t mask, uint32_t shift) { return (raw & mask) >> shift; } - static constexpr uint32_t make_core_id(uint32_t NTile, uint32_t PackRow, uint32_t CompType, uint32_t ISA) { - return (NTile << NTILE_SHIFT) | (PackRow << PACKROW_SHIFT) | (CompType << COMP_SHIFT) | (ISA << ISA_SHIFT); - } - - static void parse_id(uint32_t id, uint32_t* vals) { - vals[0] = get_mask_val(id, NTILE_MASK, NTILE_SHIFT); - vals[1] = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); - vals[2] = get_mask_val(id, COMP_MASK, COMP_SHIFT); - vals[3] = get_mask_val(id, ISA_MASK, ISA_SHIFT); - } - - static const char* to_str(uint32_t id) { - static char tmp[128]; - uint32_t vals[4]; - parse_id(id, vals); - sprintf(tmp, "N%d_PACK%d_COMP%d_ISA%d", vals[0], vals[1], vals[2], vals[3]); - return tmp; - } - - static inline size_t get_bsize(uint32_t id) { - auto packrow = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); - return size_t(4 / packrow); - } -}; - -namespace code { - -template -class Avx2N8P1 : protected jblas::xbyak::JitAvx2 { - public: - static int constexpr RegLen = 8, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX2; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; - typedef float AType; - typedef float BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { - public: - static int constexpr RegLen = 16, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; - typedef float AType; - typedef float BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512fp16N32P1 : protected jblas::xbyak::JitAvx512_fp16 { - public: - static int constexpr RegLen = 32, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_FP16; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP16_FP16; - typedef utils::fp16 AType; - typedef utils::fp16 BType; - typedef utils::fp16 CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastw(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vpbroadcastw(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512bf16N16P2 : protected jblas::xbyak::JitAvx512_bf16 { - public: - static int constexpr RegLen = 16, PackRow = 2; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 2; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_BF16; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; - typedef utils::bf16 AType; - typedef utils::bf16 BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { - public: - static int constexpr RegLen = 16, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; - typedef uint8_t AType; - typedef int8_t BType; - typedef int32_t CType; - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - - protected: - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _kunroll) { - for (int kk = 0; kk < _kunroll; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { - public: - static int constexpr RegLen = 8, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX_VNNI; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; - typedef uint8_t AType; - typedef int8_t BType; - typedef int32_t CType; - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - protected: - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _kunroll) { - for (int kk = 0; kk < _kunroll; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Amxbf16N16P2 : protected jblas::xbyak::JitAmxbf16 { - public: - static int constexpr RegLen = 16, PackRow = 2; - static_assert(_NTILE % RegLen == 0); - static_assert(_MTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen; - static_assert(NRegs * MRegs + 2 <= TileCount); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_BF16; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; - typedef utils::bf16 AType; - typedef utils::bf16 BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - void* workspace; - }; - typedef long long (*func_t)(params*); - - int TmpRegCount = RegCount; - int TmpReg = 0; - int CTileCount = 0, ATileCount = 0, BTileCount = 0; - int CTile = 0, ATile = 0, BTile = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_tmp3; - Xbyak::Reg64 reg_ret = rax; - - void assign_regs() { - CTileCount = NRegs * MRegs; - auto tile_re = TileCount - CTileCount; - if (tile_re - 1 >= NRegs) { - BTileCount = NRegs; - ATileCount = tile_re - BTileCount; - } else if (tile_re - 1 >= MRegs) { - ATileCount = MRegs; - BTileCount = tile_re - ATileCount; - } else { - ATileCount = 1; - BTileCount = tile_re - ATileCount; - } - CTile = 0; - ATile = CTile + CTileCount; - BTile = ATile + ATileCount; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_tmp3 = st.t[10]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int kunrll) { - auto& reg_Bstride = reg_tmp1; - mov(reg_Bstride, NTILE * 4); - int mtiles = _mtile / RegLen; - - for (int kk = 0; kk < kunrll; kk++) { - auto& reg_Atmp = reg_tmp2; - if (mtiles == 1) { - reg_Atmp = reg_matAptr; - } else { - mov(reg_Atmp, reg_matAptr); - } - if (BTileCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - } - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } else { - if (ATileCount == mtiles) { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - for (int mm = 0; mm < mtiles; mm++) { - tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); - } - } - } else { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < CTileCount; i++) { - tilezero(Xbyak::Tmm(CTile + i)); - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - int mtnum = _mtile / 16; - for (int mm = 0; mm < mtnum; mm++) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); - } - if (mm != mtnum - 1) { - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - } - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_tmp, dword[parambase + OFFSET(workspace)]); - mov(reg_tmp1, NTILE * 4); - for (int mm = 0; mm < MRegs; mm++) { - for (int i = 0; i < NRegs; i++) { - tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); - } - } - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - int zunroll = TmpRegCount / NRegs; - for (int i = 0; i < _mtile; i += zunroll) { - int m_re = utils::remainsize(i, _mtile, zunroll); - for (int im = 0; im < m_re; im++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - } - outLocalLabel(); - } -}; - -template -class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 { - public: - static int constexpr RegLen = 16, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static_assert(_MTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen; - static_assert(NRegs * MRegs + 2 <= TileCount); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_INT8; - static uint32_t constexpr COMPUTE = - (uint32_t)(std::is_same_v - ? std::is_same_v ? CompType::COMP_INT8_SS_INT32 : CompType::COMP_INT8_SU_INT32 - : std::is_same_v ? CompType::COMP_INT8_US_INT32 - : CompType::COMP_INT8_UU_INT32); - using AType = AT; - using BType = BT; - typedef int32_t CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - void* workspace; - }; - typedef long long (*func_t)(params*); - - int TmpRegCount = RegCount; - int TmpReg = 0; - int CTileCount = 0, ATileCount = 0, BTileCount = 0; - int CTile = 0, ATile = 0, BTile = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_tmp3; - Xbyak::Reg64 reg_ret = rax; - - void assign_regs() { - CTileCount = NRegs * MRegs; - auto tile_re = TileCount - CTileCount; - if (tile_re - 1 >= NRegs) { - BTileCount = NRegs; - ATileCount = tile_re - BTileCount; - } else if (tile_re - 1 >= MRegs) { - ATileCount = MRegs; - BTileCount = tile_re - ATileCount; - } else { - ATileCount = 1; - BTileCount = tile_re - ATileCount; - } - CTile = 0; - ATile = CTile + CTileCount; - BTile = ATile + ATileCount; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_tmp3 = st.t[10]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int kunrll) { - auto& reg_Bstride = reg_tmp1; - mov(reg_Bstride, NTILE * 4); - int mtiles = _mtile / RegLen; - - for (int kk = 0; kk < kunrll; kk++) { - auto& reg_Atmp = reg_tmp2; - if (mtiles == 1) { - reg_Atmp = reg_matAptr; - } else { - mov(reg_Atmp, reg_matAptr); - } - if (BTileCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - } - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } else { - if (ATileCount == mtiles) { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - for (int mm = 0; mm < mtiles; mm++) { - _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); - } - } - } else { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < CTileCount; i++) { - tilezero(Xbyak::Tmm(CTile + i)); - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - int mtnum = _mtile / 16; - for (int mm = 0; mm < mtnum; mm++) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); - } - if (mm != mtnum - 1) { - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - } - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_tmp, dword[parambase + OFFSET(workspace)]); - mov(reg_tmp1, NTILE * 4); - for (int mm = 0; mm < MRegs; mm++) { - for (int i = 0; i < NRegs; i++) { - tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); - } - } - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - int zunroll = TmpRegCount / NRegs; - for (int i = 0; i < _mtile; i += zunroll) { - int m_re = utils::remainsize(i, _mtile, zunroll); - for (int im = 0; im < m_re; im++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - } - outLocalLabel(); - } -}; -template -using Amxint8N16P4US = Amxint8N16P4; - -template -using Amxint8N16P4SS = Amxint8N16P4; - -class AmxConfigure : protected jblas::xbyak::JitAmxtile { - public: - typedef long long (*func_t)(tileconfig_t*); - - static void configure(int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, int CNum) { - static AmxConfigure code; - tileconfig_t cfg; - std::memset(&cfg, 0, sizeof(cfg)); - configure_tiles(cfg, TILE_M, TILE_N, TILE_K, elesize, ANum, BNum, CNum); - code.mKernel(&cfg); - } - - protected: - AmxConfigure() { - generate_config(this); - mKernel = getCode(); - } - - func_t mKernel = nullptr; -}; - -namespace kblock { -// optimize for kblock gemm, each block size in k dimension has dequant operation -// all accumulators use fp32 dtype. -template -class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { - public: - static int constexpr RegLen = 16, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; - typedef float AType; - typedef float BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { - public: - static int constexpr RegLen = 16, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1 - NRegs) / (NRegs * 2) : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_FP32; - typedef uint8_t AType; - typedef int8_t BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - uint8_t* zpA; - float* scaleA; - int ldsa; - float* scaleB; - float* reduceB; - int ldsb; - int k; - int n; - int kblock; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_iterkb; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_tmp3; - Xbyak::Reg64 reg_tmp4; - Xbyak::Reg64 reg_ret = rax; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = NRegs; - CReg = 0; - CF32Reg = CReg + CRegCount; - BReg = CF32Reg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg < RegCount); - TmpRegCount = RegCount - TmpReg; - assert(TmpRegCount >= 1); - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 13, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_iterkb = st.t[12]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_tmp3 = st.t[10]; - reg_tmp4 = st.t[11]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - xor_(reg_iterkb, reg_iterkb); - L(".kloop"); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j)); - } - } - xor_(reg_tmp2, reg_tmp2); - load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]); - mov(reg_tmp, reg_tmp3); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kbloop", T_NEAR); - L(".unkbloop"); - generate_fma(_mtile, KUNROLL, reg_tmp1); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_tmp2, KUNROLL * KTILE); - cmp(reg_tmp2, reg_tmp); - jb(".unkbloop"); - cmp(reg_tmp, reg_tmp3); - jge(".kend", T_NEAR); - L(".kbloop"); - generate_fma(_mtile, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_tmp2, 1 * KTILE); - cmp(reg_tmp2, reg_tmp3); - jb(".kbloop"); - L(".kend"); - add(reg_iterk, reg_tmp2); - generate_f32_accumulate(_mtile); - generate_zp_correction(_mtile); - inc(reg_iterkb); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) { - for (int kk = 0; kk < _ktile; kk++) { - lea(tmp, ptr[reg_matAptr + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void generate_f32_accumulate(int _mtile) { - load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]); - imul(reg_tmp, reg_iterkb); - mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); - - mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]); - lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]); - load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]); - for (int i = 0; i < NRegs; i++) { - vmovups(Xbyak::Zmm(BReg + i), ptr[reg_tmp2 + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(Xbyak::Zmm(TmpReg), ptr[reg_tmp]); - lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); - for (int i = 0; i < NRegs; i++) { - vcvtdq2ps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); - vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(TmpReg), Xbyak::Zmm(BReg + i)); - vmulps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(AReg)); - vaddps(Xbyak::Zmm(CF32Reg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); - } - } - } - - void generate_zp_correction(int _mtile) { - load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]); - imul(reg_tmp1, reg_iterkb); - mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); - auto& reg_redB = reg_tmp2; - - mov(reg_tmp, ptr[parambase + OFFSET(zpA)]); - lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(AType)]); - auto& reg_zpA = reg_tmp; - - mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]); - lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]); - auto& reg_scaleA = reg_tmp1; - - load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]); - auto& reg_ldsa = reg_tmp3; - for (int i = 0; i < NRegs; i++) { - vmovups(Xbyak::Zmm(BReg + i), ptr[reg_redB + i * VecBytes]); - } - - for (int i = 0; i < _mtile; i++) { - vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); - vpmovzxbd(Xbyak::Zmm(AReg), Xbyak::Xmm(AReg)); - vcvtdq2ps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg)); - vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), zword_b[reg_scaleA]); - for (int j = 0; j < NRegs; j++) { - vmulps(Xbyak::Zmm(CReg + j), Xbyak::Zmm(AReg), Xbyak::Zmm(BReg + j)); - vsubps(Xbyak::Zmm(CF32Reg + i * NRegs + j), Xbyak::Zmm(CReg + j)); - } - lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); - lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); - } - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -} // namespace kblock -} // namespace code -template