diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 734ce3a8587d..235968219a34 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -19,9 +19,6 @@ on: runner-env: required: true type: string - gcs-dir: - required: true - type: string build-dir: required: true type: string @@ -41,6 +38,14 @@ on: required: true type: string +env: + # This duplicates the variable from ci.yml. The variable needs to be in env + # instead of the outputs of setup because it contains the run attempt and we + # want that to be the current attempt, not whatever attempt the setup step + # last ran in. It therefore can't be passed in via inputs because the env + # context isn't available there. + GCS_DIR: gs://iree-github-actions-${{ github.event_name == 'pull_request' && 'presubmit' || 'postsubmit' }}-artifacts/${{ github.run_id }}/${{ github.run_attempt }} + jobs: build_suites: runs-on: @@ -89,7 +94,7 @@ jobs: target: - platform: "riscv" architecture: "rv64" - docker_image: "gcr.io/iree-oss/riscv@sha256:720bc0215d8462ea14352edc22710a6ce4c0c1daff581d179dd173885f1d8a35" + docker_image: "gcr.io/iree-oss/riscv@sha256:d6f0e293a50faf5abbd564c1d1bb9dc6456d7ce93d07b131c381fa64c1daed62" outputs: benchmark-tools-gcs-artifacts: ${{ toJSON(steps.upload.outputs) }} env: @@ -141,7 +146,7 @@ jobs: id: upload env: BENCHMARK_TOOLS_ARCHIVE: ${{ steps.archive.outputs.benchmark-tools-archive }} - BENCHMARK_TOOLS_GCS_ARTIFACT: ${{ inputs.gcs-dir }}/${{ steps.archive.outputs.benchmark-tools-archive }} + BENCHMARK_TOOLS_GCS_ARTIFACT: ${{ env.GCS_DIR }}/${{ steps.archive.outputs.benchmark-tools-archive }} run: | gcloud alpha storage cp "${BENCHMARK_TOOLS_ARCHIVE}" "${BENCHMARK_TOOLS_GCS_ARTIFACT}" echo "::set-output name=${PLATFORM}-${ARCHITECTURE}-benchmark-tools-gcs-artifact::${BENCHMARK_TOOLS_GCS_ARTIFACT}" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 652db665e45c..541948000073 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,6 +54,10 @@ env: # See note above regarding lack of proper variables. Also see note about # pseudo-ternary hack. _CI_STAGE: ${{ github.event_name == 'pull_request' && 'presubmit' || 'postsubmit' }} + # This needs to be in env instead of the outputs of setup because it contains + # the run attempt and we want that to be the current attempt, not whatever + # attempt the setup step last ran in. + GCS_DIR: gs://iree-github-actions-${{ github.event_name == 'pull_request' && 'presubmit' || 'postsubmit' }}-artifacts/${{ github.run_id }}/${{ github.run_attempt }} # Jobs are organized into groups and topologically sorted by dependencies jobs: @@ -68,7 +72,6 @@ jobs: should-run: ${{ steps.should-run.outputs.should-run }} # Variables for dependent jobs. See comment at top. runner-env: prod - gcs-dir: gs://iree-github-actions-${{ env._CI_STAGE }}-artifacts/${{ github.run_id }}/${{ github.run_attempt }} runner-group: ${{ env._CI_STAGE }} # Note that we can't flip the condition here because 0 is falsey. See # comment at top. @@ -151,7 +154,7 @@ jobs: id: upload env: BUILD_DIR_ARCHIVE: ${{ steps.archive.outputs.build-dir-archive }} - BUILD_DIR_GCS_ARTIFACT: ${{ needs.setup.outputs.gcs-dir }}/${{ steps.archive.outputs.build-dir-archive }} + BUILD_DIR_GCS_ARTIFACT: ${{ env.GCS_DIR }}/${{ steps.archive.outputs.build-dir-archive }} run: | gcloud alpha storage cp "${BUILD_DIR_ARCHIVE}" "${BUILD_DIR_GCS_ARTIFACT}" echo "::set-output name=build-dir-gcs-artifact::${BUILD_DIR_GCS_ARTIFACT}" @@ -180,7 +183,7 @@ jobs: run: | ./build_tools/github_actions/docker_run.sh \ --env "IREE_BAZEL_WRITE_REMOTE_CACHE=${IREE_BAZEL_WRITE_REMOTE_CACHE}" \ - gcr.io/iree-oss/frontends-swiftshader@sha256:41e516b8c1b432e3c02896c4bf4b7f06df6a67371aa167b88767b8d4d2018ea6 \ + gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9 \ ./build_tools/bazel/build_core.sh test_all: @@ -343,7 +346,7 @@ jobs: ./build_tools/github_actions/docker_run.sh \ --env "IREE_BAZEL_WRITE_REMOTE_CACHE=${IREE_BAZEL_WRITE_REMOTE_CACHE}" \ --env "IREE_TF_BINARIES_OUTPUT_DIR=${IREE_TF_BINARIES_OUTPUT_DIR}" \ - gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01 \ + gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9 \ build_tools/cmake/build_tf_binaries.sh echo "::set-output name=binaries-dir::${IREE_TF_BINARIES_OUTPUT_DIR}" - name: "Creating archive of binaries" @@ -358,7 +361,7 @@ jobs: id: upload env: BINARIES_ARCHIVE: ${{ steps.archive.outputs.binaries-archive }} - BINARIES_GCS_ARTIFACT: ${{ needs.setup.outputs.gcs-dir }}/${{ steps.archive.outputs.binaries-archive }} + BINARIES_GCS_ARTIFACT: ${{ env.GCS_DIR }}/${{ steps.archive.outputs.binaries-archive }} run: | gcloud alpha storage cp "${BINARIES_ARCHIVE}" "${BINARIES_GCS_ARTIFACT}" echo "::set-output name=binaries-gcs-artifact::${BINARIES_GCS_ARTIFACT}" @@ -398,7 +401,7 @@ jobs: - name: "Running TF integrations tests" run: | ./build_tools/github_actions/docker_run.sh \ - gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01 \ + gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9 \ build_tools/cmake/run_tf_tests.sh \ "${BUILD_DIR}" @@ -440,7 +443,7 @@ jobs: --env IREE_LLVM_CPU_DISABLE=1 \ --gpus all \ --env NVIDIA_DRIVER_CAPABILITIES=all \ - gcr.io/iree-oss/frontends-nvidia@sha256:e934ed09e9e60c28ebe11a02f37a993dd975db40118d410c4279d0fa2d4e6b9a \ + gcr.io/iree-oss/frontends-nvidia@sha256:28cd43f36b1ca0633bbd915911abe6d22b4aa16093f074e87016305322a0eba1 \ bash -euo pipefail -c \ "./build_tools/scripts/check_cuda.sh ./build_tools/scripts/check_vulkan.sh @@ -494,9 +497,10 @@ jobs: if: needs.setup.outputs.should-run == 'true' uses: ./.github/workflows/benchmarks.yml with: + # env.GCS_DIR is also duplicated in this workflow. See the note there on + # why this is. runner-group: ${{ needs.setup.outputs.runner-group }} runner-env: ${{ needs.setup.outputs.runner-env }} - gcs-dir: ${{ needs.setup.outputs.gcs-dir }} build-dir: ${{ needs.build_all.outputs.build-dir }} build-dir-archive: ${{ needs.build_all.outputs.build-dir-archive }} build-dir-gcs-artifact: ${{ needs.build_all.outputs.build-dir-gcs-artifact }} @@ -536,7 +540,7 @@ jobs: build_tools/github_actions/docker_run.sh \ --env "ANDROID_ABI=${ANDROID_ABI}" \ --env "IREE_HOST_BINARY_ROOT=${BUILD_DIR}/install" \ - gcr.io/iree-oss/android@sha256:9bc723fc707a18bd0c1be9c12e01ea5bb7c7d77f607427879e10ffcffd7d2bb5 \ + gcr.io/iree-oss/android@sha256:76c2a52dcd6d07601227b965ac87d021c1d2d5e2d01f46ad58da28c89267f2ab \ build_tools/cmake/build_android.sh riscv32: @@ -570,7 +574,7 @@ jobs: --env "BUILD_RISCV_DIR=${BUILD_RISCV_DIR}" \ --env "BUILD_PRESET=test" \ --env "IREE_HOST_BINARY_ROOT=${BUILD_DIR}/install" \ - gcr.io/iree-oss/riscv@sha256:720bc0215d8462ea14352edc22710a6ce4c0c1daff581d179dd173885f1d8a35 \ + gcr.io/iree-oss/riscv@sha256:d6f0e293a50faf5abbd564c1d1bb9dc6456d7ce93d07b131c381fa64c1daed62 \ bash -euo pipefail -c \ "./build_tools/cmake/build_riscv.sh && tests/riscv32/smoke.sh" @@ -614,7 +618,7 @@ jobs: --env "IREE_HOST_BINARY_ROOT=${BUILD_DIR}/install" \ --env "IREE_IMPORT_TFLITE_BIN=${TF_BINARIES_DIR}/iree-import-tflite" \ --env "LLVM_BIN_DIR=${BUILD_DIR}/third_party/llvm-project/llvm/bin" \ - gcr.io/iree-oss/riscv@sha256:720bc0215d8462ea14352edc22710a6ce4c0c1daff581d179dd173885f1d8a35 \ + gcr.io/iree-oss/riscv@sha256:d6f0e293a50faf5abbd564c1d1bb9dc6456d7ce93d07b131c381fa64c1daed62 \ bash -euo pipefail -c \ "./build_tools/cmake/build_riscv.sh && ./build_tools/cmake/test_riscv64.sh" diff --git a/CMakeLists.txt b/CMakeLists.txt index 434de60a8db2..416578f3aa74 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,6 +85,7 @@ option(IREE_BUILD_MICROBENCHMARKS "Builds IREE microbenchmark suites." OFF) option(IREE_BUILD_EXPERIMENTAL_REMOTING "Builds experimental remoting support." OFF) option(IREE_BUILD_EXPERIMENTAL_VMVX_MMT4D "Enables MMT4D methods in the VMVX module." OFF) option(IREE_BUILD_EXPERIMENTAL_WEB_SAMPLES "Builds experimental web samples." OFF) +option(IREE_BUILD_EXPERIMENTAL_PYTHON_GENERATED_BENCHMARKS "Builds IREE benchmark suites generated by the python benchmark framework." OFF) #------------------------------------------------------------------------------- # Runtime HAL Driver Options @@ -368,6 +369,7 @@ include(iree_benchmark_suite) include(iree_microbenchmark_suite) include(iree_hal_cts_test_suite) include(iree_static_linker_test) +include(iree_fetch_artifact) set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) @@ -703,7 +705,7 @@ endif() # IREE top-level targets #------------------------------------------------------------------------------- -if(IREE_BUILD_BENCHMARKS) +if(IREE_BUILD_BENCHMARKS OR IREE_BUILD_EXPERIMENTAL_PYTHON_GENERATED_BENCHMARKS) # Add top-level custom targets to drive generating benchmark suites. # iree-benchmark-import-models imports benchmark models from their source @@ -746,7 +748,7 @@ add_subdirectory(build_tools/embed_data/) # Note: Test deps are not built as part of all (use the iree-test-deps target). add_subdirectory(tests EXCLUDE_FROM_ALL) -if(IREE_BUILD_BENCHMARKS) +if(IREE_BUILD_BENCHMARKS OR IREE_BUILD_EXPERIMENTAL_PYTHON_GENERATED_BENCHMARKS) find_program(IREE_IMPORT_TFLITE_PATH iree-import-tflite) if(IREE_IMPORT_TFLITE_PATH) message(STATUS "Found ${IREE_IMPORT_TFLITE_PATH} to generate benchmark artifacts") @@ -829,6 +831,12 @@ set(IREE_PUBLIC_INCLUDE_DIRS "${IREE_COMMON_INCLUDE_DIRS}" add_subdirectory(build_tools/benchmarks) +#------------------------------------------------------------------------------- +# IREE build tools python modules +#------------------------------------------------------------------------------- + +add_subdirectory(build_tools/python) + #------------------------------------------------------------------------------- # Samples #------------------------------------------------------------------------------- diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index 0e9c88bf96c6..08986b49d598 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -1 +1,5 @@ iree_add_all_subdirs() + +if (IREE_BUILD_EXPERIMENTAL_PYTHON_GENERATED_BENCHMARKS) + include(generated_benchmark_suites.cmake) +endif() diff --git a/benchmarks/TFLite/android-adreno.cmake b/benchmarks/TFLite/android-adreno.cmake index d737ecec496f..43989d573974 100644 --- a/benchmarks/TFLite/android-adreno.cmake +++ b/benchmarks/TFLite/android-adreno.cmake @@ -87,7 +87,7 @@ iree_benchmark_suite( "GPU-Adreno" COMPILATION_FLAGS ${ANDROID_ADRENO_GPU_COMPILATION_FLAGS} - "--iree-flow-enable-fuse-padding-into-consumer-ops" + "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops" BENCHMARK_TOOL iree-benchmark-module CONFIG @@ -129,7 +129,7 @@ iree_benchmark_suite( "GPU-Adreno" COMPILATION_FLAGS ${ANDROID_ADRENO_GPU_COMPILATION_FLAGS} - "--iree-flow-enable-fuse-padding-into-consumer-ops" + "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops" "--iree-hal-benchmark-dispatch-repeat-count=16" BENCHMARK_TOOL iree-benchmark-module diff --git a/benchmarks/TFLite/android-mali.cmake b/benchmarks/TFLite/android-mali.cmake index 766b5ff71ea2..a849e0f0fdb1 100644 --- a/benchmarks/TFLite/android-mali.cmake +++ b/benchmarks/TFLite/android-mali.cmake @@ -114,7 +114,7 @@ iree_benchmark_suite( "GPU-Mali-Valhall" COMPILATION_FLAGS ${ANDROID_MALI_GPU_COMPILATION_FLAGS} - "--iree-flow-enable-fuse-padding-into-consumer-ops" + "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops" BENCHMARK_TOOL iree-benchmark-module CONFIG @@ -140,7 +140,7 @@ iree_benchmark_suite( "--iree-input-type=tosa" "--iree-flow-demote-f32-to-f16" "--iree-vulkan-target-triple=valhall-unknown-android31" - "--iree-flow-enable-fuse-padding-into-consumer-ops" + "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops" BENCHMARK_TOOL iree-benchmark-module CONFIG @@ -180,7 +180,7 @@ iree_benchmark_suite( "GPU-Mali-Valhall" COMPILATION_FLAGS ${ANDROID_MALI_GPU_COMPILATION_FLAGS} - "--iree-flow-enable-fuse-padding-into-consumer-ops" + "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops" "--iree-hal-benchmark-dispatch-repeat-count=32" BENCHMARK_TOOL iree-benchmark-module @@ -209,7 +209,7 @@ iree_benchmark_suite( "--iree-input-type=tosa" "--iree-flow-demote-f32-to-f16" "--iree-vulkan-target-triple=valhall-unknown-android31" - "--iree-flow-enable-fuse-padding-into-consumer-ops" + "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops" "--iree-hal-benchmark-dispatch-repeat-count=32" BENCHMARK_TOOL iree-benchmark-module diff --git a/benchmarks/generated_benchmark_suites.cmake b/benchmarks/generated_benchmark_suites.cmake new file mode 100644 index 000000000000..56189f430b25 --- /dev/null +++ b/benchmarks/generated_benchmark_suites.cmake @@ -0,0 +1,451 @@ +################################################################################ +# Autogenerated by build_tools/benchmarks/generate_cmake_benchmark_suites.py # +# To update the benchmarks, modify the files in build_tools/benchmarks/suites/ # +# and regenerate this file. # +################################################################################ + +################################################################################ +# Defines the required variables # +################################################################################ +iree_package_name(_PACKAGE_NAME) +set(_ROOT_ARTIFACTS_DIR "${IREE_BINARY_DIR}/benchmark_suites") +set(_MODEL_ARTIFACTS_DIR "${_ROOT_ARTIFACTS_DIR}/models") +set(_IREE_ARTIFACTS_DIR "${_ROOT_ARTIFACTS_DIR}/iree") + +################################################################################ +# Below is generated by build_tools/benchmarks/suites/cmake_rule_generator.py # +################################################################################ +# Fetch the model from "https://storage.googleapis.com/iree-model-artifacts/deeplabv3.tflite" +iree_fetch_artifact( + NAME + "model-c36c63b0-220a-4d78-8ade-c45ce47d89d3" + SOURCE_URL + "https://storage.googleapis.com/iree-model-artifacts/deeplabv3.tflite" + OUTPUT + "${_MODEL_ARTIFACTS_DIR}/c36c63b0-220a-4d78-8ade-c45ce47d89d3_DeepLabV3_fp32.tflite" + UNPACK +) + +# Fetch the model from "https://storage.googleapis.com/iree-model-artifacts/mobile_ssd_v2_float_coco.tflite" +iree_fetch_artifact( + NAME + "model-0e466f69-91d6-4e50-b62b-a82b6213a231" + SOURCE_URL + "https://storage.googleapis.com/iree-model-artifacts/mobile_ssd_v2_float_coco.tflite" + OUTPUT + "${_MODEL_ARTIFACTS_DIR}/0e466f69-91d6-4e50-b62b-a82b6213a231_MobileSSD_fp32.tflite" + UNPACK +) + +# Fetch the model from "https://storage.googleapis.com/iree-model-artifacts/posenet.tflite" +iree_fetch_artifact( + NAME + "model-5afc3014-d29d-4e88-a840-fbaf678acf2b" + SOURCE_URL + "https://storage.googleapis.com/iree-model-artifacts/posenet.tflite" + OUTPUT + "${_MODEL_ARTIFACTS_DIR}/5afc3014-d29d-4e88-a840-fbaf678acf2b_PoseNet_fp32.tflite" + UNPACK +) + +# Fetch the model from "https://storage.googleapis.com/iree-model-artifacts/mobilebert-baseline-tf2-float.tflite" +iree_fetch_artifact( + NAME + "model-cc69d69f-6d1f-4a1a-a31e-e021888d0d28" + SOURCE_URL + "https://storage.googleapis.com/iree-model-artifacts/mobilebert-baseline-tf2-float.tflite" + OUTPUT + "${_MODEL_ARTIFACTS_DIR}/cc69d69f-6d1f-4a1a-a31e-e021888d0d28_MobileBertSquad_fp32.tflite" + UNPACK +) + +# Fetch the model from "https://storage.googleapis.com/iree-model-artifacts/mobilebert-baseline-tf2-quant.tflite" +iree_fetch_artifact( + NAME + "model-e3997104-a3d2-46b4-9fbf-39069906d123" + SOURCE_URL + "https://storage.googleapis.com/iree-model-artifacts/mobilebert-baseline-tf2-quant.tflite" + OUTPUT + "${_MODEL_ARTIFACTS_DIR}/e3997104-a3d2-46b4-9fbf-39069906d123_MobileBertSquad_int8.tflite" + UNPACK +) + +# Fetch the model from "https://storage.googleapis.com/iree-model-artifacts/mobilebertsquad.tflite" +iree_fetch_artifact( + NAME + "model-73a0402e-271b-4aa8-a6a5-ac05839ca569" + SOURCE_URL + "https://storage.googleapis.com/iree-model-artifacts/mobilebertsquad.tflite" + OUTPUT + "${_MODEL_ARTIFACTS_DIR}/73a0402e-271b-4aa8-a6a5-ac05839ca569_MobileBertSquad_fp16.tflite" + UNPACK +) + +# Fetch the model from "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v1_224_1.0_float.tflite" +iree_fetch_artifact( + NAME + "model-78eab9e5-9ff1-4769-9b55-933c81cc9a0f" + SOURCE_URL + "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v1_224_1.0_float.tflite" + OUTPUT + "${_MODEL_ARTIFACTS_DIR}/78eab9e5-9ff1-4769-9b55-933c81cc9a0f_MobileNetV1_fp32.0_float.tflite" + UNPACK +) + +# Fetch the model from "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_1.0_224.tflite" +iree_fetch_artifact( + NAME + "model-7d45f8e5-bb5e-48d0-928d-8f125104578f" + SOURCE_URL + "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_1.0_224.tflite" + OUTPUT + "${_MODEL_ARTIFACTS_DIR}/7d45f8e5-bb5e-48d0-928d-8f125104578f_MobileNetV2_fp32.0_224.tflite" + UNPACK +) + +# Fetch the model from "https://storage.googleapis.com/iree-model-artifacts/MobileNetV3SmallStaticBatch.tflite" +iree_fetch_artifact( + NAME + "model-58855e40-eba9-4a71-b878-6b35e3460244" + SOURCE_URL + "https://storage.googleapis.com/iree-model-artifacts/MobileNetV3SmallStaticBatch.tflite" + OUTPUT + "${_MODEL_ARTIFACTS_DIR}/58855e40-eba9-4a71-b878-6b35e3460244_MobileNetV3Small_fp32.tflite" + UNPACK +) + +# Fetch the model from "https://storage.googleapis.com/iree-model-artifacts/person_detect.tflite" +iree_fetch_artifact( + NAME + "model-bc1338be-e3df-44fd-82e4-40ba9560a073" + SOURCE_URL + "https://storage.googleapis.com/iree-model-artifacts/person_detect.tflite" + OUTPUT + "${_MODEL_ARTIFACTS_DIR}/bc1338be-e3df-44fd-82e4-40ba9560a073_PersonDetect_int8.tflite" + UNPACK +) + +# Fetch the model from "https://storage.googleapis.com/iree-model-artifacts/efficientnet_lite0_int8_2.tflite" +iree_fetch_artifact( + NAME + "model-4a6f545e-1b4e-41a5-9236-792aa578184b" + SOURCE_URL + "https://storage.googleapis.com/iree-model-artifacts/efficientnet_lite0_int8_2.tflite" + OUTPUT + "${_MODEL_ARTIFACTS_DIR}/4a6f545e-1b4e-41a5-9236-792aa578184b_EfficientNet_int8.tflite" + UNPACK +) + +# Fetch the model from "https://storage.googleapis.com/iree-model-artifacts/minilm-l12-h384-uncased-seqlen128-tf-model.tar.gz" +iree_fetch_artifact( + NAME + "model-ecf5c970-ee97-49f0-a4ed-df1f34e9d493" + SOURCE_URL + "https://storage.googleapis.com/iree-model-artifacts/minilm-l12-h384-uncased-seqlen128-tf-model.tar.gz" + OUTPUT + "${_MODEL_ARTIFACTS_DIR}/ecf5c970-ee97-49f0-a4ed-df1f34e9d493_MiniLML12H384Uncased" + UNPACK +) + +# Import the TFLite model "${_MODEL_ARTIFACTS_DIR}/c36c63b0-220a-4d78-8ade-c45ce47d89d3_DeepLabV3_fp32.tflite" +iree_import_tflite_model( + TARGET_NAME "${_PACKAGE_NAME}_iree-import-model-c36c63b0-220a-4d78-8ade-c45ce47d89d3" + SOURCE "${_MODEL_ARTIFACTS_DIR}/c36c63b0-220a-4d78-8ade-c45ce47d89d3_DeepLabV3_fp32.tflite" + OUTPUT_MLIR_FILE "${_IREE_ARTIFACTS_DIR}/c36c63b0-220a-4d78-8ade-c45ce47d89d3_DeepLabV3_fp32/DeepLabV3_fp32.mlir" +) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "${_PACKAGE_NAME}_iree-import-model-c36c63b0-220a-4d78-8ade-c45ce47d89d3") + +# Import the TFLite model "${_MODEL_ARTIFACTS_DIR}/0e466f69-91d6-4e50-b62b-a82b6213a231_MobileSSD_fp32.tflite" +iree_import_tflite_model( + TARGET_NAME "${_PACKAGE_NAME}_iree-import-model-0e466f69-91d6-4e50-b62b-a82b6213a231" + SOURCE "${_MODEL_ARTIFACTS_DIR}/0e466f69-91d6-4e50-b62b-a82b6213a231_MobileSSD_fp32.tflite" + OUTPUT_MLIR_FILE "${_IREE_ARTIFACTS_DIR}/0e466f69-91d6-4e50-b62b-a82b6213a231_MobileSSD_fp32/MobileSSD_fp32.mlir" +) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "${_PACKAGE_NAME}_iree-import-model-0e466f69-91d6-4e50-b62b-a82b6213a231") + +# Import the TFLite model "${_MODEL_ARTIFACTS_DIR}/5afc3014-d29d-4e88-a840-fbaf678acf2b_PoseNet_fp32.tflite" +iree_import_tflite_model( + TARGET_NAME "${_PACKAGE_NAME}_iree-import-model-5afc3014-d29d-4e88-a840-fbaf678acf2b" + SOURCE "${_MODEL_ARTIFACTS_DIR}/5afc3014-d29d-4e88-a840-fbaf678acf2b_PoseNet_fp32.tflite" + OUTPUT_MLIR_FILE "${_IREE_ARTIFACTS_DIR}/5afc3014-d29d-4e88-a840-fbaf678acf2b_PoseNet_fp32/PoseNet_fp32.mlir" +) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "${_PACKAGE_NAME}_iree-import-model-5afc3014-d29d-4e88-a840-fbaf678acf2b") + +# Import the TFLite model "${_MODEL_ARTIFACTS_DIR}/cc69d69f-6d1f-4a1a-a31e-e021888d0d28_MobileBertSquad_fp32.tflite" +iree_import_tflite_model( + TARGET_NAME "${_PACKAGE_NAME}_iree-import-model-cc69d69f-6d1f-4a1a-a31e-e021888d0d28" + SOURCE "${_MODEL_ARTIFACTS_DIR}/cc69d69f-6d1f-4a1a-a31e-e021888d0d28_MobileBertSquad_fp32.tflite" + OUTPUT_MLIR_FILE "${_IREE_ARTIFACTS_DIR}/cc69d69f-6d1f-4a1a-a31e-e021888d0d28_MobileBertSquad_fp32/MobileBertSquad_fp32.mlir" +) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "${_PACKAGE_NAME}_iree-import-model-cc69d69f-6d1f-4a1a-a31e-e021888d0d28") + +# Import the TFLite model "${_MODEL_ARTIFACTS_DIR}/e3997104-a3d2-46b4-9fbf-39069906d123_MobileBertSquad_int8.tflite" +iree_import_tflite_model( + TARGET_NAME "${_PACKAGE_NAME}_iree-import-model-e3997104-a3d2-46b4-9fbf-39069906d123" + SOURCE "${_MODEL_ARTIFACTS_DIR}/e3997104-a3d2-46b4-9fbf-39069906d123_MobileBertSquad_int8.tflite" + OUTPUT_MLIR_FILE "${_IREE_ARTIFACTS_DIR}/e3997104-a3d2-46b4-9fbf-39069906d123_MobileBertSquad_int8/MobileBertSquad_int8.mlir" +) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "${_PACKAGE_NAME}_iree-import-model-e3997104-a3d2-46b4-9fbf-39069906d123") + +# Import the TFLite model "${_MODEL_ARTIFACTS_DIR}/73a0402e-271b-4aa8-a6a5-ac05839ca569_MobileBertSquad_fp16.tflite" +iree_import_tflite_model( + TARGET_NAME "${_PACKAGE_NAME}_iree-import-model-73a0402e-271b-4aa8-a6a5-ac05839ca569" + SOURCE "${_MODEL_ARTIFACTS_DIR}/73a0402e-271b-4aa8-a6a5-ac05839ca569_MobileBertSquad_fp16.tflite" + OUTPUT_MLIR_FILE "${_IREE_ARTIFACTS_DIR}/73a0402e-271b-4aa8-a6a5-ac05839ca569_MobileBertSquad_fp16/MobileBertSquad_fp16.mlir" +) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "${_PACKAGE_NAME}_iree-import-model-73a0402e-271b-4aa8-a6a5-ac05839ca569") + +# Import the TFLite model "${_MODEL_ARTIFACTS_DIR}/78eab9e5-9ff1-4769-9b55-933c81cc9a0f_MobileNetV1_fp32.0_float.tflite" +iree_import_tflite_model( + TARGET_NAME "${_PACKAGE_NAME}_iree-import-model-78eab9e5-9ff1-4769-9b55-933c81cc9a0f" + SOURCE "${_MODEL_ARTIFACTS_DIR}/78eab9e5-9ff1-4769-9b55-933c81cc9a0f_MobileNetV1_fp32.0_float.tflite" + OUTPUT_MLIR_FILE "${_IREE_ARTIFACTS_DIR}/78eab9e5-9ff1-4769-9b55-933c81cc9a0f_MobileNetV1_fp32/MobileNetV1_fp32.mlir" +) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "${_PACKAGE_NAME}_iree-import-model-78eab9e5-9ff1-4769-9b55-933c81cc9a0f") + +# Import the TFLite model "${_MODEL_ARTIFACTS_DIR}/7d45f8e5-bb5e-48d0-928d-8f125104578f_MobileNetV2_fp32.0_224.tflite" +iree_import_tflite_model( + TARGET_NAME "${_PACKAGE_NAME}_iree-import-model-7d45f8e5-bb5e-48d0-928d-8f125104578f" + SOURCE "${_MODEL_ARTIFACTS_DIR}/7d45f8e5-bb5e-48d0-928d-8f125104578f_MobileNetV2_fp32.0_224.tflite" + OUTPUT_MLIR_FILE "${_IREE_ARTIFACTS_DIR}/7d45f8e5-bb5e-48d0-928d-8f125104578f_MobileNetV2_fp32/MobileNetV2_fp32.mlir" +) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "${_PACKAGE_NAME}_iree-import-model-7d45f8e5-bb5e-48d0-928d-8f125104578f") + +# Import the TFLite model "${_MODEL_ARTIFACTS_DIR}/58855e40-eba9-4a71-b878-6b35e3460244_MobileNetV3Small_fp32.tflite" +iree_import_tflite_model( + TARGET_NAME "${_PACKAGE_NAME}_iree-import-model-58855e40-eba9-4a71-b878-6b35e3460244" + SOURCE "${_MODEL_ARTIFACTS_DIR}/58855e40-eba9-4a71-b878-6b35e3460244_MobileNetV3Small_fp32.tflite" + OUTPUT_MLIR_FILE "${_IREE_ARTIFACTS_DIR}/58855e40-eba9-4a71-b878-6b35e3460244_MobileNetV3Small_fp32/MobileNetV3Small_fp32.mlir" +) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "${_PACKAGE_NAME}_iree-import-model-58855e40-eba9-4a71-b878-6b35e3460244") + +# Import the TFLite model "${_MODEL_ARTIFACTS_DIR}/bc1338be-e3df-44fd-82e4-40ba9560a073_PersonDetect_int8.tflite" +iree_import_tflite_model( + TARGET_NAME "${_PACKAGE_NAME}_iree-import-model-bc1338be-e3df-44fd-82e4-40ba9560a073" + SOURCE "${_MODEL_ARTIFACTS_DIR}/bc1338be-e3df-44fd-82e4-40ba9560a073_PersonDetect_int8.tflite" + OUTPUT_MLIR_FILE "${_IREE_ARTIFACTS_DIR}/bc1338be-e3df-44fd-82e4-40ba9560a073_PersonDetect_int8/PersonDetect_int8.mlir" +) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "${_PACKAGE_NAME}_iree-import-model-bc1338be-e3df-44fd-82e4-40ba9560a073") + +# Import the TFLite model "${_MODEL_ARTIFACTS_DIR}/4a6f545e-1b4e-41a5-9236-792aa578184b_EfficientNet_int8.tflite" +iree_import_tflite_model( + TARGET_NAME "${_PACKAGE_NAME}_iree-import-model-4a6f545e-1b4e-41a5-9236-792aa578184b" + SOURCE "${_MODEL_ARTIFACTS_DIR}/4a6f545e-1b4e-41a5-9236-792aa578184b_EfficientNet_int8.tflite" + OUTPUT_MLIR_FILE "${_IREE_ARTIFACTS_DIR}/4a6f545e-1b4e-41a5-9236-792aa578184b_EfficientNet_int8/EfficientNet_int8.mlir" +) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "${_PACKAGE_NAME}_iree-import-model-4a6f545e-1b4e-41a5-9236-792aa578184b") + +# Import the Tensorflow model "${_MODEL_ARTIFACTS_DIR}/ecf5c970-ee97-49f0-a4ed-df1f34e9d493_MiniLML12H384Uncased" +iree_import_tf_model( + TARGET_NAME "${_PACKAGE_NAME}_iree-import-model-ecf5c970-ee97-49f0-a4ed-df1f34e9d493" + SOURCE "${_MODEL_ARTIFACTS_DIR}/ecf5c970-ee97-49f0-a4ed-df1f34e9d493_MiniLML12H384Uncased" + ENTRY_FUNCTION "predict" + OUTPUT_MLIR_FILE "${_IREE_ARTIFACTS_DIR}/ecf5c970-ee97-49f0-a4ed-df1f34e9d493_MiniLML12H384Uncased/MiniLML12H384Uncased.mlir" +) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "${_PACKAGE_NAME}_iree-import-model-ecf5c970-ee97-49f0-a4ed-df1f34e9d493") + +# Compile the module "${_IREE_ARTIFACTS_DIR}/c36c63b0-220a-4d78-8ade-c45ce47d89d3_DeepLabV3_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" +iree_bytecode_module( + NAME + "iree-module-c36c63b0-220a-4d78-8ade-c45ce47d89d3-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" + MODULE_FILE_NAME + "${_IREE_ARTIFACTS_DIR}/c36c63b0-220a-4d78-8ade-c45ce47d89d3_DeepLabV3_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" + SRC + "${_IREE_ARTIFACTS_DIR}/c36c63b0-220a-4d78-8ade-c45ce47d89d3_DeepLabV3_fp32/DeepLabV3_fp32.mlir" + FLAGS + --iree-hal-target-backends=llvm-cpu;--iree-input-type=tosa;--iree-llvm-target-triple=x86_64-unknown-linux-gnu;--iree-llvm-target-cpu=cascadelake + DEPENDS + "${_PACKAGE_NAME}_iree-import-model-c36c63b0-220a-4d78-8ade-c45ce47d89d3" +) +# Mark dependency so that we have one target to drive them all. +add_dependencies(iree-benchmark-suites "${_PACKAGE_NAME}_iree-module-c36c63b0-220a-4d78-8ade-c45ce47d89d3-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9") + +# Compile the module "${_IREE_ARTIFACTS_DIR}/0e466f69-91d6-4e50-b62b-a82b6213a231_MobileSSD_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" +iree_bytecode_module( + NAME + "iree-module-0e466f69-91d6-4e50-b62b-a82b6213a231-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" + MODULE_FILE_NAME + "${_IREE_ARTIFACTS_DIR}/0e466f69-91d6-4e50-b62b-a82b6213a231_MobileSSD_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" + SRC + "${_IREE_ARTIFACTS_DIR}/0e466f69-91d6-4e50-b62b-a82b6213a231_MobileSSD_fp32/MobileSSD_fp32.mlir" + FLAGS + --iree-hal-target-backends=llvm-cpu;--iree-input-type=tosa;--iree-llvm-target-triple=x86_64-unknown-linux-gnu;--iree-llvm-target-cpu=cascadelake + DEPENDS + "${_PACKAGE_NAME}_iree-import-model-0e466f69-91d6-4e50-b62b-a82b6213a231" +) +# Mark dependency so that we have one target to drive them all. +add_dependencies(iree-benchmark-suites "${_PACKAGE_NAME}_iree-module-0e466f69-91d6-4e50-b62b-a82b6213a231-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9") + +# Compile the module "${_IREE_ARTIFACTS_DIR}/5afc3014-d29d-4e88-a840-fbaf678acf2b_PoseNet_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" +iree_bytecode_module( + NAME + "iree-module-5afc3014-d29d-4e88-a840-fbaf678acf2b-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" + MODULE_FILE_NAME + "${_IREE_ARTIFACTS_DIR}/5afc3014-d29d-4e88-a840-fbaf678acf2b_PoseNet_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" + SRC + "${_IREE_ARTIFACTS_DIR}/5afc3014-d29d-4e88-a840-fbaf678acf2b_PoseNet_fp32/PoseNet_fp32.mlir" + FLAGS + --iree-hal-target-backends=llvm-cpu;--iree-input-type=tosa;--iree-llvm-target-triple=x86_64-unknown-linux-gnu;--iree-llvm-target-cpu=cascadelake + DEPENDS + "${_PACKAGE_NAME}_iree-import-model-5afc3014-d29d-4e88-a840-fbaf678acf2b" +) +# Mark dependency so that we have one target to drive them all. +add_dependencies(iree-benchmark-suites "${_PACKAGE_NAME}_iree-module-5afc3014-d29d-4e88-a840-fbaf678acf2b-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9") + +# Compile the module "${_IREE_ARTIFACTS_DIR}/cc69d69f-6d1f-4a1a-a31e-e021888d0d28_MobileBertSquad_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" +iree_bytecode_module( + NAME + "iree-module-cc69d69f-6d1f-4a1a-a31e-e021888d0d28-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" + MODULE_FILE_NAME + "${_IREE_ARTIFACTS_DIR}/cc69d69f-6d1f-4a1a-a31e-e021888d0d28_MobileBertSquad_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" + SRC + "${_IREE_ARTIFACTS_DIR}/cc69d69f-6d1f-4a1a-a31e-e021888d0d28_MobileBertSquad_fp32/MobileBertSquad_fp32.mlir" + FLAGS + --iree-hal-target-backends=llvm-cpu;--iree-input-type=tosa;--iree-llvm-target-triple=x86_64-unknown-linux-gnu;--iree-llvm-target-cpu=cascadelake + DEPENDS + "${_PACKAGE_NAME}_iree-import-model-cc69d69f-6d1f-4a1a-a31e-e021888d0d28" +) +# Mark dependency so that we have one target to drive them all. +add_dependencies(iree-benchmark-suites "${_PACKAGE_NAME}_iree-module-cc69d69f-6d1f-4a1a-a31e-e021888d0d28-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9") + +# Compile the module "${_IREE_ARTIFACTS_DIR}/e3997104-a3d2-46b4-9fbf-39069906d123_MobileBertSquad_int8/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" +iree_bytecode_module( + NAME + "iree-module-e3997104-a3d2-46b4-9fbf-39069906d123-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" + MODULE_FILE_NAME + "${_IREE_ARTIFACTS_DIR}/e3997104-a3d2-46b4-9fbf-39069906d123_MobileBertSquad_int8/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" + SRC + "${_IREE_ARTIFACTS_DIR}/e3997104-a3d2-46b4-9fbf-39069906d123_MobileBertSquad_int8/MobileBertSquad_int8.mlir" + FLAGS + --iree-hal-target-backends=llvm-cpu;--iree-input-type=tosa;--iree-llvm-target-triple=x86_64-unknown-linux-gnu;--iree-llvm-target-cpu=cascadelake + DEPENDS + "${_PACKAGE_NAME}_iree-import-model-e3997104-a3d2-46b4-9fbf-39069906d123" +) +# Mark dependency so that we have one target to drive them all. +add_dependencies(iree-benchmark-suites "${_PACKAGE_NAME}_iree-module-e3997104-a3d2-46b4-9fbf-39069906d123-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9") + +# Compile the module "${_IREE_ARTIFACTS_DIR}/73a0402e-271b-4aa8-a6a5-ac05839ca569_MobileBertSquad_fp16/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" +iree_bytecode_module( + NAME + "iree-module-73a0402e-271b-4aa8-a6a5-ac05839ca569-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" + MODULE_FILE_NAME + "${_IREE_ARTIFACTS_DIR}/73a0402e-271b-4aa8-a6a5-ac05839ca569_MobileBertSquad_fp16/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" + SRC + "${_IREE_ARTIFACTS_DIR}/73a0402e-271b-4aa8-a6a5-ac05839ca569_MobileBertSquad_fp16/MobileBertSquad_fp16.mlir" + FLAGS + --iree-hal-target-backends=llvm-cpu;--iree-input-type=tosa;--iree-llvm-target-triple=x86_64-unknown-linux-gnu;--iree-llvm-target-cpu=cascadelake + DEPENDS + "${_PACKAGE_NAME}_iree-import-model-73a0402e-271b-4aa8-a6a5-ac05839ca569" +) +# Mark dependency so that we have one target to drive them all. +add_dependencies(iree-benchmark-suites "${_PACKAGE_NAME}_iree-module-73a0402e-271b-4aa8-a6a5-ac05839ca569-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9") + +# Compile the module "${_IREE_ARTIFACTS_DIR}/78eab9e5-9ff1-4769-9b55-933c81cc9a0f_MobileNetV1_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" +iree_bytecode_module( + NAME + "iree-module-78eab9e5-9ff1-4769-9b55-933c81cc9a0f-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" + MODULE_FILE_NAME + "${_IREE_ARTIFACTS_DIR}/78eab9e5-9ff1-4769-9b55-933c81cc9a0f_MobileNetV1_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" + SRC + "${_IREE_ARTIFACTS_DIR}/78eab9e5-9ff1-4769-9b55-933c81cc9a0f_MobileNetV1_fp32/MobileNetV1_fp32.mlir" + FLAGS + --iree-hal-target-backends=llvm-cpu;--iree-input-type=tosa;--iree-llvm-target-triple=x86_64-unknown-linux-gnu;--iree-llvm-target-cpu=cascadelake + DEPENDS + "${_PACKAGE_NAME}_iree-import-model-78eab9e5-9ff1-4769-9b55-933c81cc9a0f" +) +# Mark dependency so that we have one target to drive them all. +add_dependencies(iree-benchmark-suites "${_PACKAGE_NAME}_iree-module-78eab9e5-9ff1-4769-9b55-933c81cc9a0f-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9") + +# Compile the module "${_IREE_ARTIFACTS_DIR}/7d45f8e5-bb5e-48d0-928d-8f125104578f_MobileNetV2_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" +iree_bytecode_module( + NAME + "iree-module-7d45f8e5-bb5e-48d0-928d-8f125104578f-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" + MODULE_FILE_NAME + "${_IREE_ARTIFACTS_DIR}/7d45f8e5-bb5e-48d0-928d-8f125104578f_MobileNetV2_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" + SRC + "${_IREE_ARTIFACTS_DIR}/7d45f8e5-bb5e-48d0-928d-8f125104578f_MobileNetV2_fp32/MobileNetV2_fp32.mlir" + FLAGS + --iree-hal-target-backends=llvm-cpu;--iree-input-type=tosa;--iree-llvm-target-triple=x86_64-unknown-linux-gnu;--iree-llvm-target-cpu=cascadelake + DEPENDS + "${_PACKAGE_NAME}_iree-import-model-7d45f8e5-bb5e-48d0-928d-8f125104578f" +) +# Mark dependency so that we have one target to drive them all. +add_dependencies(iree-benchmark-suites "${_PACKAGE_NAME}_iree-module-7d45f8e5-bb5e-48d0-928d-8f125104578f-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9") + +# Compile the module "${_IREE_ARTIFACTS_DIR}/58855e40-eba9-4a71-b878-6b35e3460244_MobileNetV3Small_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" +iree_bytecode_module( + NAME + "iree-module-58855e40-eba9-4a71-b878-6b35e3460244-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" + MODULE_FILE_NAME + "${_IREE_ARTIFACTS_DIR}/58855e40-eba9-4a71-b878-6b35e3460244_MobileNetV3Small_fp32/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" + SRC + "${_IREE_ARTIFACTS_DIR}/58855e40-eba9-4a71-b878-6b35e3460244_MobileNetV3Small_fp32/MobileNetV3Small_fp32.mlir" + FLAGS + --iree-hal-target-backends=llvm-cpu;--iree-input-type=tosa;--iree-llvm-target-triple=x86_64-unknown-linux-gnu;--iree-llvm-target-cpu=cascadelake + DEPENDS + "${_PACKAGE_NAME}_iree-import-model-58855e40-eba9-4a71-b878-6b35e3460244" +) +# Mark dependency so that we have one target to drive them all. +add_dependencies(iree-benchmark-suites "${_PACKAGE_NAME}_iree-module-58855e40-eba9-4a71-b878-6b35e3460244-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9") + +# Compile the module "${_IREE_ARTIFACTS_DIR}/bc1338be-e3df-44fd-82e4-40ba9560a073_PersonDetect_int8/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" +iree_bytecode_module( + NAME + "iree-module-bc1338be-e3df-44fd-82e4-40ba9560a073-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" + MODULE_FILE_NAME + "${_IREE_ARTIFACTS_DIR}/bc1338be-e3df-44fd-82e4-40ba9560a073_PersonDetect_int8/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" + SRC + "${_IREE_ARTIFACTS_DIR}/bc1338be-e3df-44fd-82e4-40ba9560a073_PersonDetect_int8/PersonDetect_int8.mlir" + FLAGS + --iree-hal-target-backends=llvm-cpu;--iree-input-type=tosa;--iree-llvm-target-triple=x86_64-unknown-linux-gnu;--iree-llvm-target-cpu=cascadelake + DEPENDS + "${_PACKAGE_NAME}_iree-import-model-bc1338be-e3df-44fd-82e4-40ba9560a073" +) +# Mark dependency so that we have one target to drive them all. +add_dependencies(iree-benchmark-suites "${_PACKAGE_NAME}_iree-module-bc1338be-e3df-44fd-82e4-40ba9560a073-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9") + +# Compile the module "${_IREE_ARTIFACTS_DIR}/4a6f545e-1b4e-41a5-9236-792aa578184b_EfficientNet_int8/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" +iree_bytecode_module( + NAME + "iree-module-4a6f545e-1b4e-41a5-9236-792aa578184b-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" + MODULE_FILE_NAME + "${_IREE_ARTIFACTS_DIR}/4a6f545e-1b4e-41a5-9236-792aa578184b_EfficientNet_int8/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" + SRC + "${_IREE_ARTIFACTS_DIR}/4a6f545e-1b4e-41a5-9236-792aa578184b_EfficientNet_int8/EfficientNet_int8.mlir" + FLAGS + --iree-hal-target-backends=llvm-cpu;--iree-input-type=tosa;--iree-llvm-target-triple=x86_64-unknown-linux-gnu;--iree-llvm-target-cpu=cascadelake + DEPENDS + "${_PACKAGE_NAME}_iree-import-model-4a6f545e-1b4e-41a5-9236-792aa578184b" +) +# Mark dependency so that we have one target to drive them all. +add_dependencies(iree-benchmark-suites "${_PACKAGE_NAME}_iree-module-4a6f545e-1b4e-41a5-9236-792aa578184b-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9") + +# Compile the module "${_IREE_ARTIFACTS_DIR}/ecf5c970-ee97-49f0-a4ed-df1f34e9d493_MiniLML12H384Uncased/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" +iree_bytecode_module( + NAME + "iree-module-ecf5c970-ee97-49f0-a4ed-df1f34e9d493-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" + MODULE_FILE_NAME + "${_IREE_ARTIFACTS_DIR}/ecf5c970-ee97-49f0-a4ed-df1f34e9d493_MiniLML12H384Uncased/e7e18b0f-c72d-4f1c-89b1-5afee70df6e9.vmfb" + SRC + "${_IREE_ARTIFACTS_DIR}/ecf5c970-ee97-49f0-a4ed-df1f34e9d493_MiniLML12H384Uncased/MiniLML12H384Uncased.mlir" + FLAGS + --iree-hal-target-backends=llvm-cpu;--iree-input-type=mhlo;--iree-llvm-target-triple=x86_64-unknown-linux-gnu;--iree-llvm-target-cpu=cascadelake + DEPENDS + "${_PACKAGE_NAME}_iree-import-model-ecf5c970-ee97-49f0-a4ed-df1f34e9d493" +) +# Mark dependency so that we have one target to drive them all. +add_dependencies(iree-benchmark-suites "${_PACKAGE_NAME}_iree-module-ecf5c970-ee97-49f0-a4ed-df1f34e9d493-e7e18b0f-c72d-4f1c-89b1-5afee70df6e9") + +################################################################################ diff --git a/build_tools/bazel/iree.bazelrc b/build_tools/bazel/iree.bazelrc index 8937507efb92..4cf79c2b6a8e 100644 --- a/build_tools/bazel/iree.bazelrc +++ b/build_tools/bazel/iree.bazelrc @@ -268,7 +268,7 @@ build:remote_cache_bazel_ci --config=_remote_cache_base # specific docker container the CI Bazel builds are run in. The image URL is # included for clarity and so that this reference is automatically updated by # manage_images.py -build:remote_cache_bazel_ci --host_platform_remote_properties_override='properties:{name:"cache-silo-key" value:"gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01"}' +build:remote_cache_bazel_ci --host_platform_remote_properties_override='properties:{name:"cache-silo-key" value:"gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9"}' ############################################################################### # Configuration for uploading build results to Result Store UI diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py index 0b53eb835b6f..2e60681649ef 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py @@ -39,6 +39,7 @@ "@llvm-project//llvm:IPO": ["LLVMipo"], "@llvm-project//lld": ["${IREE_LLD_TARGET}"], "@llvm-project//llvm:FileCheck": ["FileCheck"], + "@llvm-project//llvm:not": ["not"], # MLIR "@llvm-project//mlir:AllPassesAndDialects": ["MLIRAllDialects"], "@llvm-project//mlir:DialectUtils": [""], diff --git a/build_tools/benchmarks/CMakeLists.txt b/build_tools/benchmarks/CMakeLists.txt index b3cc01e2f53b..67e281874433 100644 --- a/build_tools/benchmarks/CMakeLists.txt +++ b/build_tools/benchmarks/CMakeLists.txt @@ -40,7 +40,6 @@ function(benchmark_tool_py_test) endfunction() add_subdirectory(common) -add_subdirectory(suites) benchmark_tool_py_test( NAME diff --git a/build_tools/benchmarks/comparisons/mobilebert_int8_commands.py b/build_tools/benchmarks/comparisons/mobilebert_int8_commands.py new file mode 100644 index 000000000000..daed4c4fb767 --- /dev/null +++ b/build_tools/benchmarks/comparisons/mobilebert_int8_commands.py @@ -0,0 +1,190 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +from typing import Optional + +from common.benchmark_command import * +from common.benchmark_command_factory import BenchmarkCommandFactory + +_DEFAULT_NUM_BENCHMARK_RUNS = 50 +_DEFAULT_NUM_THREADS = 1 + + +class TfliteMobilebertInt8(TFLiteBenchmarkCommand): + """ Specializes the benchmark command to use TFLite. """ + + def __init__(self, + benchmark_binary: str, + model_name: str, + model_path: str, + test_data_dir: str, + driver: str = "cpu", + num_threads: int = _DEFAULT_NUM_THREADS, + num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, + taskset: Optional[str] = None): + super().__init__(benchmark_binary, + model_name, + model_path, + num_threads, + num_runs, + taskset=taskset) + self.driver = driver + self.args.append("--input_layer=input_ids,segment_ids,input_mask") + self.args.append("--input_layer_value_files=input_ids:" + test_data_dir + + "/input_word_id.bin,segment_ids:" + test_data_dir + + "/input_type_id.bin,input_mask:" + test_data_dir + + "/input_mask.bin") + self.args.append("--input_layer_shape=1,384:1,384:1,384") + + +class IreeMobilebertInt8(IreeBenchmarkCommand): + """ Specializes the benchmark command to use IREE. """ + + def __init__(self, + benchmark_binary: str, + model_name: str, + model_path: str, + driver: str = "local-task", + num_threads: int = _DEFAULT_NUM_THREADS, + num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, + taskset: Optional[str] = None): + super().__init__(benchmark_binary, + model_name, + model_path, + num_threads, + num_runs, + taskset=taskset) + self.driver = driver + self.args.append("--entry_function=main") + self.args.append( + '--function_input="1x384xi32=101 2129 2116 19576 2015 2106 3854 4679 2486 1029 102 1996 14169 2165 2019 2220 2599 1999 3565 4605 2753 1998 2196 11145 1012 8446 2001 3132 2011 7573 1005 1055 3639 1010 2029 14159 2032 2698 2335 1998 3140 2032 2046 2093 20991 2015 1010 2164 1037 19576 2029 2027 6757 2005 1037 7921 1012 7573 15674 3854 4679 2001 2315 3565 4605 12041 1010 3405 2274 3948 10455 1010 1016 13714 14918 1010 1998 2048 3140 19576 2015 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' + ) + self.args.append( + '--function_input="1x384xi32=0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' + ) + self.args.append( + '--function_input="1x384xi32=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' + ) + + +class MobilebertInt8CommandFactory(BenchmarkCommandFactory): + """ Generates `BenchmarkCommand` objects specific to running MobileBert.""" + + def __init__(self, base_dir: str): + self._model_name = "mobilebert-baseline-tf2-quant" + self._base_dir = base_dir + self._iree_benchmark_binary_path = os.path.join(base_dir, + "iree-benchmark-module") + self._tflite_benchmark_binary_path = os.path.join(base_dir, + "benchmark_model") + self._tflite_model_path = os.path.join(self._base_dir, "models", "tflite", + self._model_name + ".tflite") + self._tflite_test_data_dir = os.path.join(self._base_dir, "test_data", + "squad") + + def generate_benchmark_commands(self, device: str, + driver: str) -> list[BenchmarkCommand]: + if device == "desktop" and driver == "cpu": + return self._generate_cpu(device) + elif device == "desktop" and driver == "gpu": + return self._generate_gpu("cuda") + elif device == "mobile" and driver == "cpu": + return self._generate_cpu(device) + elif device == "mobile" and driver == "gpu": + return self._generate_gpu("vulkan") + else: + print("Warning! Not a valid configuration.") + return [] + + def _generate_cpu(self, device: str): + # Generate TFLite benchmarks. + tflite_mobilebert = TfliteMobilebertInt8(self._tflite_benchmark_binary_path, + self._model_name, + self._tflite_model_path, + self._tflite_test_data_dir, + driver="cpu") + + tflite_mobilebert_noxnn = TfliteMobilebertInt8( + self._tflite_benchmark_binary_path, + self._model_name + "_noxnn", + self._tflite_model_path, + self._tflite_test_data_dir, + driver="cpu") + tflite_mobilebert_noxnn.args.append("--use_xnnpack=false") + + # Generate IREE benchmarks. + driver = "local-task" + backend = "llvm-cpu" + iree_model_path = os.path.join(self._base_dir, "models", "iree", backend, + self._model_name + ".vmfb") + iree_mobilebert = IreeMobilebertInt8(self._iree_benchmark_binary_path, + self._model_name, + iree_model_path, + driver=driver) + commands = [tflite_mobilebert, tflite_mobilebert_noxnn, iree_mobilebert] + + # Test mmt4d only on mobile. + if device == "mobile": + model_mmt4d_name = self._model_name + "_mmt4d" + iree_mmt4d_model_path = os.path.join(self._base_dir, "models", "iree", + backend, model_mmt4d_name + ".vmfb") + iree_mmt4d_mobilebert = IreeMobilebertInt8( + self._iree_benchmark_binary_path, + model_mmt4d_name, + iree_mmt4d_model_path, + driver=driver) + commands.append(iree_mmt4d_mobilebert) + + model_im2col_mmt4d_name = self._model_name + "_im2col_mmt4d" + iree_im2col_mmt4d_model_path = os.path.join( + self._base_dir, "models", "iree", backend, + model_im2col_mmt4d_name + ".vmfb") + iree_im2col_mmt4d_mobilebert = IreeMobilebertInt8( + self._iree_benchmark_binary_path, + model_im2col_mmt4d_name, + iree_im2col_mmt4d_model_path, + driver=driver) + commands.append(iree_im2col_mmt4d_mobilebert) + + return commands + + def _generate_gpu(self, driver: str): + tflite_mobilebert = TfliteMobilebertInt8(self._tflite_benchmark_binary_path, + self._model_name, + self._tflite_model_path, + self._tflite_test_data_dir, + driver="gpu") + tflite_mobilebert.args.append("--gpu_precision_loss_allowed=false") + + tflite_mobilebert_noxnn = TfliteMobilebertInt8( + self._tflite_benchmark_binary_path, + self._model_name + "_noxnn", + self._tflite_model_path, + self._tflite_test_data_dir, + driver="gpu") + tflite_mobilebert_noxnn.args.append("--gpu_precision_loss_allowed=false") + tflite_mobilebert_noxnn.args.append("--use_xnnpack=false") + + iree_model_path = os.path.join(self._base_dir, "models", "iree", driver, + self._model_name + ".vmfb") + iree_mobilebert = IreeMobilebertInt8(self._iree_benchmark_binary_path, + self._model_name, + iree_model_path, + driver=driver) + + iree_padfuse_model_path = os.path.join(self._base_dir, "models", "iree", + driver, + self._model_name + "_padfuse.vmfb") + iree_padfuse_mobilebert = IreeMobilebertInt8( + self._iree_benchmark_binary_path, + self._model_name + "_padfuse", + iree_padfuse_model_path, + driver=driver) + return [ + tflite_mobilebert, tflite_mobilebert_noxnn, iree_mobilebert, + iree_padfuse_mobilebert + ] diff --git a/build_tools/benchmarks/comparisons/run_benchmarks.py b/build_tools/benchmarks/comparisons/run_benchmarks.py index 8cdfd2b74501..edbceb65f89d 100644 --- a/build_tools/benchmarks/comparisons/run_benchmarks.py +++ b/build_tools/benchmarks/comparisons/run_benchmarks.py @@ -27,6 +27,7 @@ from common.benchmark_runner import * from common.utils import * from mobilebert_fp32_commands import * +from mobilebert_int8_commands import * from simple_commands import * @@ -109,12 +110,23 @@ def main(args): # Create factories for all models to be benchmarked. command_factory = [] command_factory.append(MobilebertFP32CommandFactory(args.base_dir)) + command_factory.append(MobilebertInt8CommandFactory(args.base_dir)) command_factory.append( SimpleCommandFactory(args.base_dir, "mobilenet_v2_1.0_224", "1x224x224x3xf32")) command_factory.append( SimpleCommandFactory(args.base_dir, "mobilenet_v2_224_1.0_uint8", "1x224x224x3xui8", "input", "1,224,224,3")) + command_factory.append( + SimpleCommandFactory(args.base_dir, "deeplabv3", "1x257x257x3xf32")) + command_factory.append( + SimpleCommandFactory(args.base_dir, "person_detect", "1x96x96x1xi8")) + command_factory.append( + SimpleCommandFactory(args.base_dir, "ssd_mobilenet_v2_static_1.0_int8", + "1x320x320x3xi8")) + command_factory.append( + SimpleCommandFactory(args.base_dir, "resnet_v2_101_1_default_1", + "1x299x299x3xf32")) if args.mode == "desktop": results_path = os.path.join(args.output_dir, "results.csv") diff --git a/build_tools/benchmarks/comparisons/setup_desktop.sh b/build_tools/benchmarks/comparisons/setup_desktop.sh index e227f21d24a4..582bb6835449 100644 --- a/build_tools/benchmarks/comparisons/setup_desktop.sh +++ b/build_tools/benchmarks/comparisons/setup_desktop.sh @@ -23,7 +23,14 @@ mkdir "${ROOT_DIR}/output" wget https://storage.googleapis.com/iree-model-artifacts/tflite_squad_test_data.zip -O /tmp/tflite_squad_test_data.zip unzip /tmp/tflite_squad_test_data.zip -d "${ROOT_DIR}/test_data/" -wget https://storage.googleapis.com/iree-model-artifacts/mobilebert_float_384_gpu.tflite -O "${ROOT_DIR}/models/tflite/mobilebert_float_384_gpu.tflite" +wget https://storage.googleapis.com/iree-model-artifacts/mobilebert-baseline-tf2-quant.tflite -P "${ROOT_DIR}/models/tflite/" +wget https://storage.googleapis.com/iree-model-artifacts/mobilebert_float_384_gpu.tflite -P "${ROOT_DIR}/models/tflite/" +wget https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_224_1.0_uint8.tflite -P "${ROOT_DIR}/models/tflite/" +wget https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_1.0_224.tflite -P "${ROOT_DIR}/models/tflite/" +wget https://storage.googleapis.com/iree-model-artifacts/deeplabv3.tflite -P "${ROOT_DIR}/models/tflite/" +wget https://storage.googleapis.com/iree-model-artifacts/person_detect.tflite -P "${ROOT_DIR}/models/tflite/" +wget https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_v2_static_1.0_int8.tflite -P "${ROOT_DIR}/models/tflite/" +wget https://storage.googleapis.com/iree-model-artifacts/resnet_v2_101_1_default_1.tflite -P "${ROOT_DIR}/models/tflite/" # Build IREE source. SOURCE_DIR=/tmp/github @@ -112,3 +119,4 @@ python3.9 build_tools/benchmarks/comparisons/run_benchmarks.py \ --output_dir=${ROOT_DIR}/output --mode=desktop cat "${ROOT_DIR}/output/results.csv" + diff --git a/build_tools/benchmarks/comparisons/setup_mobile.sh b/build_tools/benchmarks/comparisons/setup_mobile.sh index 92ef72ee0ff3..fb29d6c142a9 100644 --- a/build_tools/benchmarks/comparisons/setup_mobile.sh +++ b/build_tools/benchmarks/comparisons/setup_mobile.sh @@ -31,9 +31,14 @@ touch "${ROOT_DIR}/output/results.csv" wget https://storage.googleapis.com/iree-model-artifacts/tflite_squad_test_data.zip -O /tmp/tflite_squad_test_data.zip unzip /tmp/tflite_squad_test_data.zip -d "${ROOT_DIR}/test_data/" +wget https://storage.googleapis.com/iree-model-artifacts/mobilebert-baseline-tf2-quant.tflite -P "${ROOT_DIR}/models/tflite/" wget https://storage.googleapis.com/iree-model-artifacts/mobilebert_float_384_gpu.tflite -P "${ROOT_DIR}/models/tflite/" wget https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_224_1.0_uint8.tflite -P "${ROOT_DIR}/models/tflite/" wget https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_1.0_224.tflite -P "${ROOT_DIR}/models/tflite/" +wget https://storage.googleapis.com/iree-model-artifacts/deeplabv3.tflite -P "${ROOT_DIR}/models/tflite/" +wget https://storage.googleapis.com/iree-model-artifacts/person_detect.tflite -P "${ROOT_DIR}/models/tflite/" +wget https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_v2_static_1.0_int8.tflite -P "${ROOT_DIR}/models/tflite/" +wget https://storage.googleapis.com/iree-model-artifacts/resnet_v2_101_1_default_1.tflite -P "${ROOT_DIR}/models/tflite/" # Build IREE source. SOURCE_DIR=/tmp/github @@ -144,7 +149,7 @@ for i in $(ls ${ROOT_DIR}/models/tflite/); do --iree-input-type=tosa \ --iree-hal-target-backends=vulkan-spirv \ --iree-vulkan-target-triple=valhall-unknown-android31 \ - --iree-flow-enable-fuse-padding-into-consumer-ops \ + --iree-flow-enable-fuse-padding-into-linalg-consumer-ops \ --iree-llvm-debug-symbols=false \ --iree-vm-bytecode-module-strip-source-map=true \ --iree-vm-emit-polyglot-zip=false \ diff --git a/build_tools/benchmarks/generate_cmake_benchmark_suites.py b/build_tools/benchmarks/generate_cmake_benchmark_suites.py new file mode 100755 index 000000000000..0bbe9381049c --- /dev/null +++ b/build_tools/benchmarks/generate_cmake_benchmark_suites.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +## Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Generates a CMake file to build the benchmark suites.""" + +import sys +import pathlib +import argparse + +# Add build_tools python dir to the search path. +sys.path.insert(0, str(pathlib.Path(__file__).parent / ".." / "python")) + +import benchmarks.iree.definitions +from e2e_test_framework import cmake_rule_generator + +TEMPLATE_DIR = pathlib.Path(__file__).parent +GENERATED_BENCHMARK_SUITES_CMAKE_TEMPLATE = cmake_rule_generator.read_template_from_file( + TEMPLATE_DIR / "iree_generated_benchmark_suites_template.cmake") + + +def parse_arguments(): + """Parses command-line options.""" + + parser = argparse.ArgumentParser() + parser.add_argument("--output", + required=True, + help="Path to write the generated cmake file.") + + return parser.parse_args() + + +def main(args: argparse.Namespace): + compile_specs, _ = benchmarks.iree.definitions.generate() + benchmark_rules = cmake_rule_generator.generate_rules( + model_artifacts_dir="${_MODEL_ARTIFACTS_DIR}", + iree_artifacts_dir="${_IREE_ARTIFACTS_DIR}", + iree_compile_specs=compile_specs) + cmake_file = GENERATED_BENCHMARK_SUITES_CMAKE_TEMPLATE.substitute( + __BENCHMARK_RULES='\n'.join(benchmark_rules)) + with open(args.output, "w") as output_file: + output_file.write(cmake_file) + + +if __name__ == "__main__": + main(parse_arguments()) diff --git a/build_tools/benchmarks/iree_generated_benchmark_suites_template.cmake b/build_tools/benchmarks/iree_generated_benchmark_suites_template.cmake new file mode 100644 index 000000000000..f0d2ea447d7d --- /dev/null +++ b/build_tools/benchmarks/iree_generated_benchmark_suites_template.cmake @@ -0,0 +1,19 @@ +################################################################################ +# Autogenerated by build_tools/benchmarks/generate_cmake_benchmark_suites.py # +# To update the benchmarks, modify the files in build_tools/benchmarks/suites/ # +# and regenerate this file. # +################################################################################ + +################################################################################ +# Defines the required variables # +################################################################################ +iree_package_name(_PACKAGE_NAME) +set(_ROOT_ARTIFACTS_DIR "$${IREE_BINARY_DIR}/benchmark_suites") +set(_MODEL_ARTIFACTS_DIR "$${_ROOT_ARTIFACTS_DIR}/models") +set(_IREE_ARTIFACTS_DIR "$${_ROOT_ARTIFACTS_DIR}/iree") + +################################################################################ +# Below is generated by build_tools/benchmarks/suites/cmake_rule_generator.py # +################################################################################ +$__BENCHMARK_RULES +################################################################################ diff --git a/build_tools/benchmarks/suites/iree_download_artifact_template.cmake b/build_tools/benchmarks/suites/iree_download_artifact_template.cmake deleted file mode 100644 index 5dc8577c1661..000000000000 --- a/build_tools/benchmarks/suites/iree_download_artifact_template.cmake +++ /dev/null @@ -1,15 +0,0 @@ -# Fetch the model from "$__SOURCE_URL" -add_custom_command( - OUTPUT "$__OUTPUT_PATH" - COMMAND - "$${Python3_EXECUTABLE}" "$${IREE_ROOT_DIR}/build_tools/scripts/download_file.py" - "$__SOURCE_URL" -o "$__OUTPUT_PATH" - DEPENDS - "$${IREE_ROOT_DIR}/build_tools/scripts/download_file.py" - COMMENT "Downloading $__SOURCE_URL" -) -add_custom_target( - "$${_PACKAGE_NAME}_$__TARGET_NAME" - DEPENDS - "$__OUTPUT_PATH" -) diff --git a/build_tools/benchmarks/suites/models/tflite_models.py b/build_tools/benchmarks/suites/models/tflite_models.py deleted file mode 100644 index 5275315fae2d..000000000000 --- a/build_tools/benchmarks/suites/models/tflite_models.py +++ /dev/null @@ -1,20 +0,0 @@ -## Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Defines TFLite models.""" - -from .. import unique_ids -from ..definitions import common_definitions - -MOBILENET_V2 = common_definitions.Model( - id=unique_ids.MODEL_MOBILENET_V2, - name="mobilenet_v2", - tags=["f32", "imagenet"], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - # Mirror of https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite - source_url= - "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_1.0_224.tflite", - entry_function="main", - input_types=["1x224x224x3xf32"]) diff --git a/build_tools/benchmarks/suites/unique_ids.py b/build_tools/benchmarks/suites/unique_ids.py deleted file mode 100644 index a70eaf270693..000000000000 --- a/build_tools/benchmarks/suites/unique_ids.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""List of unique random IDs in the benchmark suites. - -Each ID should be generated from uuid.uuid4(). -""" - -# Models -MODEL_MOBILENET_V2 = "7d45f8e5-bb5e-48d0-928d-8f125104578f" - -# Devices -DEVICE_SPEC_GCP_C2_STANDARD_16 = "9a4804f1-b1b9-46cd-b251-7f16a655f782" - -# IREE benchmarks -IREE_COMPILE_CONFIG_LINUX_CASCADELAKE = "e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" -IREE_RUN_CONFIG_LOCAL_SYNC = "13fc65a9-e5dc-4cbb-9c09-25b0b08f4c03" -IREE_RUN_CONFIG_LOCAL_TASK_BASE = "c7c4a15e-b20c-4898-bb4a-864f34ff34b2" diff --git a/build_tools/buildkite/cmake/android/arm64-v8a/benchmark2.yml b/build_tools/buildkite/cmake/android/arm64-v8a/benchmark2.yml index 1debddbe10fd..47fdbe6b0465 100644 --- a/build_tools/buildkite/cmake/android/arm64-v8a/benchmark2.yml +++ b/build_tools/buildkite/cmake/android/arm64-v8a/benchmark2.yml @@ -9,7 +9,7 @@ steps: - label: "Build" commands: - - "docker run --user=$(id -u):$(id -g) --volume=\\${HOME?}:\\${HOME?} --volume=/etc/passwd:/etc/passwd:ro --volume=/etc/group:/etc/group:ro --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/frontends@sha256:bad174c580cdefaf435ce31a7df6bdd7f7cb7bfdcdff5d1acf40f630acf85bf5 build_tools/cmake/build_android_benchmark.sh" + - "docker run --user=$(id -u):$(id -g) --volume=\\${HOME?}:\\${HOME?} --volume=/etc/passwd:/etc/passwd:ro --volume=/etc/group:/etc/group:ro --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/frontends@sha256:7a7a6d2fce60f3db82bfd2f18316231f9e4662cd9307b079d5adfbb6e119b817 build_tools/cmake/build_android_benchmark.sh" - "tar --exclude='*.tar.gz' --exclude='*.tgz' --exclude='*.mlir' --exclude='*.tflite' --exclude='*tf-model' -czvf benchmark-suites-${BUILDKITE_BUILD_NUMBER}.tgz build-host/benchmark_suites" - "find build-host/benchmark_suites -name '*.mlir' | tar -czvf source-mlir-models-${BUILDKITE_BUILD_NUMBER}.tgz -T -" - "tar -czvf iree-android-tools-${BUILDKITE_BUILD_NUMBER}.tgz build-android/tools/iree-benchmark-module build-android-trace/tools/iree-benchmark-module build-android/tools/build_config.txt" diff --git a/build_tools/buildkite/cmake/android/arm64-v8a/pipeline.yml b/build_tools/buildkite/cmake/android/arm64-v8a/pipeline.yml index 92e5bc7b2a94..a05aae9d9f12 100644 --- a/build_tools/buildkite/cmake/android/arm64-v8a/pipeline.yml +++ b/build_tools/buildkite/cmake/android/arm64-v8a/pipeline.yml @@ -8,7 +8,7 @@ steps: - label: "build" commands: - "git submodule sync && git submodule update --init --jobs 8 --depth 1" - - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/android@sha256:9bc723fc707a18bd0c1be9c12e01ea5bb7c7d77f607427879e10ffcffd7d2bb5 build_tools/cmake/build_host_and_android.sh arm64-v8a" + - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/android@sha256:76c2a52dcd6d07601227b965ac87d021c1d2d5e2d01f46ad58da28c89267f2ab build_tools/cmake/build_host_and_android.sh arm64-v8a" - "tar --exclude='*.o' --exclude='*.a' -czvf build-artifacts.tgz build-android" agents: - "queue=build" diff --git a/build_tools/buildkite/cmake/linux/pipeline.yml b/build_tools/buildkite/cmake/linux/pipeline.yml index 0b2678901997..df2b2e381187 100644 --- a/build_tools/buildkite/cmake/linux/pipeline.yml +++ b/build_tools/buildkite/cmake/linux/pipeline.yml @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception env: - DOCKER_IMAGE: "gcr.io/iree-oss/frontends@sha256:bad174c580cdefaf435ce31a7df6bdd7f7cb7bfdcdff5d1acf40f630acf85bf5" + DOCKER_IMAGE: "gcr.io/iree-oss/frontends@sha256:7a7a6d2fce60f3db82bfd2f18316231f9e4662cd9307b079d5adfbb6e119b817" IREE_DOCKER_WORKDIR: "/usr/src/github/iree" steps: diff --git a/build_tools/buildkite/cmake/linux/x86_64/benchmark.yml b/build_tools/buildkite/cmake/linux/x86_64/benchmark.yml index 435b698cbd57..45af65555865 100644 --- a/build_tools/buildkite/cmake/linux/x86_64/benchmark.yml +++ b/build_tools/buildkite/cmake/linux/x86_64/benchmark.yml @@ -15,7 +15,7 @@ steps: --volume="$$PWD:$$IREE_DOCKER_WORKDIR" \ --workdir="$$IREE_DOCKER_WORKDIR" \ --rm \ - gcr.io/iree-oss/frontends@sha256:bad174c580cdefaf435ce31a7df6bdd7f7cb7bfdcdff5d1acf40f630acf85bf5 \ + gcr.io/iree-oss/frontends@sha256:7a7a6d2fce60f3db82bfd2f18316231f9e4662cd9307b079d5adfbb6e119b817 \ build_tools/cmake/build_linux_benchmark.sh tar --exclude="*.tar.gz" \ --exclude="*.tgz" \ diff --git a/build_tools/cmake/iree_benchmark_suite.cmake b/build_tools/cmake/iree_benchmark_suite.cmake index ee4e013c153b..9c38441040fe 100644 --- a/build_tools/cmake/iree_benchmark_suite.cmake +++ b/build_tools/cmake/iree_benchmark_suite.cmake @@ -185,6 +185,19 @@ function(iree_benchmark_suite) "BENCHMARK_MODES;BENCHMARK_TOOL;MODULES" ) + # Try to check if the compiler supports the TARGET_BACKEND. If + # IREE_HOST_BINARY_ROOT is defined, we are using a compiler binary, in which + # case we can't check it's supported backend just by looking at this build + # dir's cmake variables --- we would have to implement a configure-check + # executing that compiler. + if (NOT DEFINED IREE_HOST_BINARY_ROOT) + string(TOUPPER ${_RULE_TARGET_BACKEND} _UPPERCASE_TARGET_BACKEND) + string(REPLACE "-" "_" _NORMALIZED_TARGET_BACKEND ${_UPPERCASE_TARGET_BACKEND}) + if(NOT IREE_TARGET_BACKEND_${_NORMALIZED_TARGET_BACKEND}) + return() + endif() + endif() + iree_package_name(_PACKAGE_NAME) # Add the benchmark suite target. @@ -223,26 +236,23 @@ function(iree_benchmark_suite) # Update the source file to the downloaded-to place. string(REPLACE "/" ";" _SOURCE_URL_SEGMENTS "${_SOURCE_URL}") list(POP_BACK _SOURCE_URL_SEGMENTS _LAST_URL_SEGMENT) - set(_DOWNLOAD_TARGET "${_PACKAGE_NAME}_iree-download-benchmark-source-${_LAST_URL_SEGMENT}") + set(_DOWNLOAD_TARGET_NAME "iree-download-benchmark-source-${_LAST_URL_SEGMENT}") # Strip off gzip/tar suffix if present (downloader unpacks if necessary) string(REGEX REPLACE "(\.gz)|(\.tar\.gz)$" "" _SOURCE_FILE_BASENAME "${_LAST_URL_SEGMENT}") set(_MODULE_SOURCE "${_ROOT_ARTIFACTS_DIR}/${_SOURCE_FILE_BASENAME}") - if(NOT TARGET "${_DOWNLOAD_TARGET}") - add_custom_command( - OUTPUT "${_MODULE_SOURCE}" - COMMAND - "${Python3_EXECUTABLE}" "${IREE_ROOT_DIR}/build_tools/scripts/download_file.py" - "${_SOURCE_URL}" -o "${_MODULE_SOURCE}" - DEPENDS - "${IREE_ROOT_DIR}/build_tools/scripts/download_file.py" - COMMENT "Downloading ${_SOURCE_URL}" - ) - add_custom_target("${_DOWNLOAD_TARGET}" - DEPENDS "${_MODULE_SOURCE}" + if(NOT TARGET "${_PACKAGE_NAME}_${_DOWNLOAD_TARGET_NAME}") + iree_fetch_artifact( + NAME + "${_DOWNLOAD_TARGET_NAME}" + SOURCE_URL + "${_SOURCE_URL}" + OUTPUT + "${_MODULE_SOURCE}" + UNPACK ) endif() - set(_MODULE_SOURCE_TARGET "${_DOWNLOAD_TARGET}") + set(_MODULE_SOURCE_TARGET "${_PACKAGE_NAME}_${_DOWNLOAD_TARGET_NAME}") endif() # If the source is a TFLite file, import it. diff --git a/build_tools/cmake/iree_bytecode_module.cmake b/build_tools/cmake/iree_bytecode_module.cmake index 85c44d756295..6bab22561345 100644 --- a/build_tools/cmake/iree_bytecode_module.cmake +++ b/build_tools/cmake/iree_bytecode_module.cmake @@ -109,13 +109,13 @@ function(iree_bytecode_module) list(APPEND _OUTPUT_FILES "${_RULE_STATIC_LIB_PATH}" "${_STATIC_HDR_PATH}") endif() - if(CMAKE_SYSTEM_PROCESSOR STREQUAL "riscv" AND - RISCV_CPU STREQUAL "rv64" AND - NOT _RULE_FLAGS MATCHES "iree-llvm-target-triple") + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "riscv64" AND + CMAKE_SYSTEM_NAME STREQUAL "Linux" AND + NOT _RULE_FLAGS MATCHES "iree-llvm-target-triple") # RV64 Linux crosscompile toolchain can support iree-compile with # specific CPU flags. Add the llvm flags to support RV64 RVV codegen if # llvm-target-triple is not specified. - list(APPEND _RULE_FLAGS ${RISCV64_TEST_DEFAULT_LLVM_FLAGS}) + list(APPEND _ARGS ${RISCV64_TEST_DEFAULT_LLVM_FLAGS}) endif() if(_RULE_FRIENDLY_NAME) diff --git a/build_tools/cmake/iree_cc_test.cmake b/build_tools/cmake/iree_cc_test.cmake index 6bfc9fbebace..645532ee4040 100644 --- a/build_tools/cmake/iree_cc_test.cmake +++ b/build_tools/cmake/iree_cc_test.cmake @@ -160,7 +160,8 @@ function(iree_cc_test) TEST_TMPDIR=${_ANDROID_ABS_DIR}/test_tmpdir ) set_property(TEST ${_NAME_PATH} PROPERTY ENVIRONMENT ${_ENVIRONMENT_VARS}) - elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "riscv" AND RISCV_CPU STREQUAL "rv64") + elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "riscv64" AND + CMAKE_SYSTEM_NAME STREQUAL "Linux") # The test target needs to run within the QEMU emulator for RV64 Linux # crosscompile build or on-device. add_test( diff --git a/build_tools/cmake/iree_check_test.cmake b/build_tools/cmake/iree_check_test.cmake index eae1901101ad..98e08fa27a7b 100644 --- a/build_tools/cmake/iree_check_test.cmake +++ b/build_tools/cmake/iree_check_test.cmake @@ -42,8 +42,8 @@ function(iree_bytecode_module_for_iree_check_test_and_friends) list(APPEND _RULE_FLAGS "--iree-llvm-target-triple=${_TARGET_TRIPLE}") endif() - if(CMAKE_SYSTEM_PROCESSOR STREQUAL "riscv" AND - RISCV_CPU STREQUAL "rv64" AND + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "riscv64" AND + CMAKE_SYSTEM_NAME STREQUAL "Linux" AND _RULE_TARGET_BACKEND STREQUAL "llvm-cpu" AND NOT _RULE_FLAGS MATCHES "iree-llvm-target-triple") # RV64 Linux crosscompile toolchain can support iree_check_test with diff --git a/build_tools/cmake/iree_fetch_artifact.cmake b/build_tools/cmake/iree_fetch_artifact.cmake new file mode 100644 index 000000000000..b282b7db3e6d --- /dev/null +++ b/build_tools/cmake/iree_fetch_artifact.cmake @@ -0,0 +1,56 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# iree_fetch_artifact() +# +# Download file from URL. NEVER Use this rule to download from untrusted +# sources, it doesn't unpack the file safely. +# +# Parameters: +# NAME: Name of target (see Note). +# SOURCE_URL: Source URL to donwload the file. +# OUTPUT: Path to the output file or directory to unpack. +# UNPACK: When added, it will try to unpack the archive if supported. +# +# Note: +# By default, it will create a target named ${_PACKAGE_NAME}_${_RULE_NAME}. +function(iree_fetch_artifact) + cmake_parse_arguments( + _RULE + "UNPACK" + "NAME;SOURCE_URL;OUTPUT" + "" + ${ARGN} + ) + + set(_ARGS "${IREE_ROOT_DIR}/build_tools/scripts/download_file.py") + list(APPEND _ARGS "${_RULE_SOURCE_URL}") + list(APPEND _ARGS "-o") + list(APPEND _ARGS "${_RULE_OUTPUT}") + + if(_RULE_UNPACK) + list(APPEND _ARGS "--unpack") + endif() + + # TODO: CMake built-in file command can replace the python script. But python + # script also provides streaming unpack (doesn't use double space when + # unpacking). Need to evaluate if we want to replace. + add_custom_command( + OUTPUT "${_RULE_OUTPUT}" + COMMAND + "${Python3_EXECUTABLE}" + ${_ARGS} + DEPENDS + "${IREE_ROOT_DIR}/build_tools/scripts/download_file.py" + COMMENT "Downloading ${_RULE_SOURCE_URL}" + ) + + iree_package_name(_PACKAGE_NAME) + add_custom_target("${_PACKAGE_NAME}_${_RULE_NAME}" + DEPENDS + "${_RULE_OUTPUT}" + ) +endfunction() diff --git a/build_tools/cmake/iree_hal_cts_test_suite.cmake b/build_tools/cmake/iree_hal_cts_test_suite.cmake index 3e98e9eb2f4e..ef1790cc5a28 100644 --- a/build_tools/cmake/iree_hal_cts_test_suite.cmake +++ b/build_tools/cmake/iree_hal_cts_test_suite.cmake @@ -89,6 +89,11 @@ function(iree_hal_cts_test_suite) list(APPEND _TRANSLATE_FLAGS "--iree-llvm-target-triple=${_TARGET_TRIPLE}") endif() + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "riscv64" AND + CMAKE_SYSTEM_NAME STREQUAL "Linux") + list(APPEND _TRANSLATE_FLAGS ${RISCV64_TEST_DEFAULT_LLVM_FLAGS}) + endif() + # Skip if already created (multiple suites using the same compiler setting). iree_package_name(_PACKAGE_NAME) if(NOT TARGET ${_PACKAGE_NAME}_${_EXECUTABLES_TESTDATA_NAME}_c) diff --git a/build_tools/cmake/iree_native_test.cmake b/build_tools/cmake/iree_native_test.cmake index c386283b8f04..4915be4c4bae 100644 --- a/build_tools/cmake/iree_native_test.cmake +++ b/build_tools/cmake/iree_native_test.cmake @@ -113,7 +113,8 @@ function(iree_native_test) "TEST_TMPDIR=${_ANDROID_ABS_DIR}/test_tmpdir" ) set_property(TEST ${_TEST_NAME} PROPERTY ENVIRONMENT ${_ENVIRONMENT_VARS}) - elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "riscv" AND RISCV_CPU STREQUAL "rv64") + elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "riscv64" AND + CMAKE_SYSTEM_NAME STREQUAL "Linux") # The test target needs to run within the QEMU emulator for RV64 Linux # crosscompile build or on-device. add_test( @@ -143,5 +144,5 @@ function(iree_native_test) list(APPEND _RULE_LABELS "${_PACKAGE_PATH}") set_property(TEST ${_TEST_NAME} PROPERTY LABELS "${_RULE_LABELS}") set_property(TEST "${_TEST_NAME}" PROPERTY REQUIRED_FILES "${_RULE_DATA}") - set_property(TEST ${_TEST_NAME} PROPERTY TIMEOUT ${_RULE_ARGS}) + set_property(TEST ${_TEST_NAME} PROPERTY TIMEOUT ${_RULE_TIMEOUT}) endfunction() diff --git a/build_tools/cmake/iree_python.cmake b/build_tools/cmake/iree_python.cmake index d487313ddb5d..1b64c60f37ea 100644 --- a/build_tools/cmake/iree_python.cmake +++ b/build_tools/cmake/iree_python.cmake @@ -292,3 +292,36 @@ function(iree_py_test) ${_RULE_TIMEOUT} ) endfunction() + +# iree_build_tools_py_test() +# +# CMake function to test with build_tools python modules. +# +# Parameters: +# NAME: name of test +# SRC: Test source file +# ARGS: Command line arguments to the Python source file. +# LABELS: Additional labels to apply to the test. The package path is added +# automatically. +function(iree_build_tools_py_test) + cmake_parse_arguments( + _RULE + "" + "NAME;SRC" + "ARGS;LABELS" + ${ARGN} + ) + + iree_local_py_test( + NAME + "${_RULE_NAME}" + SRC + "${_RULE_SRC}" + ARGS + ${_RULE_ARGS} + LABELS + ${_RULE_LABELS} + PACKAGE_DIRS + "${IREE_ROOT_DIR}/build_tools/python" + ) +endfunction() diff --git a/build_tools/cmake/riscv.toolchain.cmake b/build_tools/cmake/riscv.toolchain.cmake index 8575a835bacd..590ac7ebf97b 100644 --- a/build_tools/cmake/riscv.toolchain.cmake +++ b/build_tools/cmake/riscv.toolchain.cmake @@ -46,6 +46,7 @@ set(RISCV_LINKER_FLAGS) set(RISCV_LINKER_FLAGS_EXE) if(RISCV_CPU STREQUAL "rv64") + set(CMAKE_SYSTEM_PROCESSOR riscv64) set(CMAKE_SYSTEM_NAME Linux) set(CMAKE_SYSTEM_LIBRARY_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/lib") set(RISCV_COMPILER_FLAGS "${RISCV_COMPILER_FLAGS} -march=rv64gc -mabi=lp64d") @@ -59,6 +60,7 @@ if(RISCV_CPU STREQUAL "rv64") "--riscv-v-vector-bits-min=512" CACHE INTERNAL "Default llvm codegen flags for testing purposes") elseif(RISCV_CPU STREQUAL "rv32-baremetal") + set(CMAKE_SYSTEM_PROCESSOR riscv32) set(CMAKE_SYSTEM_NAME Generic) set(CMAKE_CROSSCOMPILING ON CACHE BOOL "") set(CMAKE_C_STANDARD 11) diff --git a/build_tools/cmake/run_riscv64_test.sh b/build_tools/cmake/run_riscv64_test.sh index 833cf5ae48a8..0fe78b842ea3 100755 --- a/build_tools/cmake/run_riscv64_test.sh +++ b/build_tools/cmake/run_riscv64_test.sh @@ -15,9 +15,9 @@ set -e # A QEMU 64 Linux emulator must be available at the path specified by the # `QEMU_RV64_BIN` environment variable to run the artifacts under the emulator. -if [[ -z "${QEMU_RV64_BIN}" ]]; then - "${QEMU_RV64_BIN}" "-cpu rv64,x-v=true,x-k=true,vlen=512,elen=64,vext_spec=v1.0 \ - -L ${RISCV_RV64_LINUX_TOOLCHAIN_ROOT}/sysroot $*" +if [[ ! -z "${QEMU_RV64_BIN}" ]]; then + "${QEMU_RV64_BIN}" "-cpu" "rv64,x-v=true,x-k=true,vlen=512,elen=64,vext_spec=v1.0" \ + "-L" "${RISCV_RV64_LINUX_TOOLCHAIN_ROOT}/sysroot" $* fi # TODO(dcaballe): Add on-device run commands. diff --git a/build_tools/cmake/test_riscv64.sh b/build_tools/cmake/test_riscv64.sh index ee0bc7774eca..b091a83c2bc2 100755 --- a/build_tools/cmake/test_riscv64.sh +++ b/build_tools/cmake/test_riscv64.sh @@ -79,8 +79,10 @@ ctest --test-dir ${BUILD_RISCV_DIR}/runtime/ --timeout 900 --output-on-failure \ --no-tests=error --label-exclude \ '(^nokokoro$|^driver=vulkan$|^driver=cuda$|^vulkan_uses_vk_khr_shader_float16_int8$|^requires-filesystem$|^requires-dtz$)' -# Test e2e models. Excluding mobilebert and fp16 for now. +# Test e2e models. Excluding mobilebert, fp16, and lowering_config regression +# tests for now. +# TODO(#10462): Investigate the lowering_config test issue. ctest --test-dir ${BUILD_RISCV_DIR}/tests/e2e --timeout 900 --output-on-failure \ --no-tests=error --label-exclude \ '(^nokokoro$|^driver=vulkan$|^driver=cuda$|^vulkan_uses_vk_khr_shader_float16_int8$)' \ - -E '(bert|fp16)' + -E '(bert|fp16|regression_llvm-cpu_lowering_config)' diff --git a/build_tools/docker/README.md b/build_tools/docker/README.md index c1e8b647b338..cf5df8c97c51 100644 --- a/build_tools/docker/README.md +++ b/build_tools/docker/README.md @@ -18,12 +18,6 @@ To explore an image interactively, use `docker run`, e.g. docker run --interactive --tty --rm base ``` -Production versions of the images can be downloaded from GCR: - -```shell -docker pull gcr.io/iree-oss/base:prod -``` - You can find more information in the [official Docker docs](https://docs.docker.com/get-started/overview/). @@ -93,21 +87,4 @@ python3 build_tools/docker/manage_images.py --image all 4. Commit the changes and send a PR for review. The CI will use the updated digest references to test the new images. -5. Merge your PR after is approved and all CI tests pass. **Please remember to - complete the step below**. - -### Part 3. Updating the `:prod` tag - -Kokoro builds preload images tagged with `prod` on VM creation, so after -changing the images used, you should also update the images tagged as `prod` -in GCR. This also makes development significantly easier for others who need to -modify the `docker` images. - -6. We use `build_tools/docker/prod_digests.txt` as a source of truth for which - versions of the images on GCR should have the `:prod` tag. The following - command will ensure that you are at upstream HEAD on the `main` branch before - it updates the tags. - - ```shell - python3 build_tools/docker/manage_prod.py - ``` +5. Merge your PR after is approved and all CI tests pass. diff --git a/build_tools/docker/android/Dockerfile b/build_tools/docker/android/Dockerfile index e54d11dda407..87bfd6a4ae50 100644 --- a/build_tools/docker/android/Dockerfile +++ b/build_tools/docker/android/Dockerfile @@ -7,13 +7,13 @@ # An image for cross-compiling IREE towards Android. FROM gcr.io/iree-oss/base@sha256:5d43683c6b50aebe1fca6c85f2012f3b0fa153bf4dd268e8767b619b1891423a -ARG NDK_VERSION=r21d +ARG NDK_VERSION=r25b WORKDIR /install-ndk ENV ANDROID_NDK "/usr/src/android-ndk-${NDK_VERSION}" -RUN wget -q "https://dl.google.com/android/repository/android-ndk-${NDK_VERSION?}-linux-x86_64.zip" \ - && unzip -q "android-ndk-${NDK_VERSION?}-linux-x86_64.zip" -d /usr/src/ \ +RUN wget -q "https://dl.google.com/android/repository/android-ndk-${NDK_VERSION?}-linux.zip" \ + && unzip -q "android-ndk-${NDK_VERSION?}-linux.zip" -d /usr/src/ \ && rm -rf /install-ndk WORKDIR / diff --git a/build_tools/docker/frontends-nvidia/Dockerfile b/build_tools/docker/frontends-nvidia/Dockerfile index ed2bea82bf7e..d20ff057e3ae 100644 --- a/build_tools/docker/frontends-nvidia/Dockerfile +++ b/build_tools/docker/frontends-nvidia/Dockerfile @@ -8,7 +8,7 @@ # The NVidia drivers need to *exactly* match between the host machine and the # docker image. -FROM gcr.io/iree-oss/frontends@sha256:bad174c580cdefaf435ce31a7df6bdd7f7cb7bfdcdff5d1acf40f630acf85bf5 +FROM gcr.io/iree-oss/frontends@sha256:7a7a6d2fce60f3db82bfd2f18316231f9e4662cd9307b079d5adfbb6e119b817 # We use .deb files that we host because we have to pin the version exactly to # match the host machine and packages routinely dissapear from the Ubuntu diff --git a/build_tools/docker/frontends-swiftshader/Dockerfile b/build_tools/docker/frontends-swiftshader/Dockerfile index 13cc7b4b847b..299d55f883ee 100644 --- a/build_tools/docker/frontends-swiftshader/Dockerfile +++ b/build_tools/docker/frontends-swiftshader/Dockerfile @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -FROM gcr.io/iree-oss/frontends@sha256:bad174c580cdefaf435ce31a7df6bdd7f7cb7bfdcdff5d1acf40f630acf85bf5 +FROM gcr.io/iree-oss/frontends@sha256:7a7a6d2fce60f3db82bfd2f18316231f9e4662cd9307b079d5adfbb6e119b817 COPY --from=gcr.io/iree-oss/swiftshader@sha256:5027d56cdfee743d956bffd035668f7784166a486c48c74b42e5882cb0c289bf \ /swiftshader /swiftshader diff --git a/build_tools/docker/frontends/Dockerfile b/build_tools/docker/frontends/Dockerfile index 990a56b40cb2..e844c98e97c2 100644 --- a/build_tools/docker/frontends/Dockerfile +++ b/build_tools/docker/frontends/Dockerfile @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -FROM gcr.io/iree-oss/android@sha256:9bc723fc707a18bd0c1be9c12e01ea5bb7c7d77f607427879e10ffcffd7d2bb5 +FROM gcr.io/iree-oss/android@sha256:76c2a52dcd6d07601227b965ac87d021c1d2d5e2d01f46ad58da28c89267f2ab WORKDIR /install-kws diff --git a/build_tools/docker/manage_prod.py b/build_tools/docker/manage_prod.py deleted file mode 100755 index 37acd8ec643a..000000000000 --- a/build_tools/docker/manage_prod.py +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2020 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Uses prod_digests.txt to update GCR's :prod tags. - -Usage: - Pull all images that should have :prod tags, tag them with :prod and push - them to GCR. This will make sure that you are at upstream head on the main - branch before pushing: - python3 build_tools/docker/manage_prod.py -""" - -import os -import utils - -if __name__ == "__main__": - # Ensure the user has the correct authorization if they try to push to GCR. - utils.check_gcloud_auth() - - # Only allow the :prod tag to be pushed from the version of - # `prod_digests.txt` at upstream HEAD on the main branch. - utils.run_command( - [os.path.normpath("build_tools/scripts/git/git_update.sh"), "main"]) - - with open(utils.PROD_DIGESTS_PATH, "r") as f: - images_with_digests = [line.strip() for line in f.readlines()] - - for image_with_digest in images_with_digests: - image_url, _ = image_with_digest.split("@") - prod_image_url = f"{image_url}:prod" - - utils.run_command(["docker", "pull", image_with_digest]) - utils.run_command(["docker", "tag", image_with_digest, prod_image_url]) - utils.run_command(["docker", "push", prod_image_url]) diff --git a/build_tools/docker/prod_digests.txt b/build_tools/docker/prod_digests.txt index 70cf265e83b3..a3b16ba5660e 100644 --- a/build_tools/docker/prod_digests.txt +++ b/build_tools/docker/prod_digests.txt @@ -1,12 +1,12 @@ gcr.io/iree-oss/base@sha256:5d43683c6b50aebe1fca6c85f2012f3b0fa153bf4dd268e8767b619b1891423a gcr.io/iree-oss/swiftshader@sha256:5027d56cdfee743d956bffd035668f7784166a486c48c74b42e5882cb0c289bf gcr.io/iree-oss/samples@sha256:ea1bfce1c853e0b3d1afad094086535f903950dc81810024c4cf6347d90aea8a -gcr.io/iree-oss/frontends@sha256:bad174c580cdefaf435ce31a7df6bdd7f7cb7bfdcdff5d1acf40f630acf85bf5 -gcr.io/iree-oss/frontends-nvidia@sha256:e934ed09e9e60c28ebe11a02f37a993dd975db40118d410c4279d0fa2d4e6b9a -gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01 +gcr.io/iree-oss/frontends@sha256:7a7a6d2fce60f3db82bfd2f18316231f9e4662cd9307b079d5adfbb6e119b817 +gcr.io/iree-oss/frontends-nvidia@sha256:28cd43f36b1ca0633bbd915911abe6d22b4aa16093f074e87016305322a0eba1 +gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9 gcr.io/iree-oss/gradle-android@sha256:d9d0f880c3ac995b9e8a23bbf8079b80f6842851654016c5f362c747c09aaf93 -gcr.io/iree-oss/riscv@sha256:720bc0215d8462ea14352edc22710a6ce4c0c1daff581d179dd173885f1d8a35 +gcr.io/iree-oss/riscv@sha256:d6f0e293a50faf5abbd564c1d1bb9dc6456d7ce93d07b131c381fa64c1daed62 gcr.io/iree-oss/nvidia@sha256:7c2f56db65e656c15e6c96b5812a8275dd53c82bf41221192f9ba8a451aad870 gcr.io/iree-oss/emscripten@sha256:8ccc1c8de11919faf23aeaa585b13e5c5050952db76c101d6f61367280a3546f -gcr.io/iree-oss/android@sha256:9bc723fc707a18bd0c1be9c12e01ea5bb7c7d77f607427879e10ffcffd7d2bb5 +gcr.io/iree-oss/android@sha256:76c2a52dcd6d07601227b965ac87d021c1d2d5e2d01f46ad58da28c89267f2ab gcr.io/iree-oss/manylinux2014_x86_64-release@sha256:b09c10868f846308bad2eab253a77d0a3f097816c40342bc289d8e62509bc5f9 diff --git a/build_tools/docker/riscv/Dockerfile b/build_tools/docker/riscv/Dockerfile index e2449eb3ab78..a5a9956a306e 100644 --- a/build_tools/docker/riscv/Dockerfile +++ b/build_tools/docker/riscv/Dockerfile @@ -8,11 +8,11 @@ FROM gcr.io/iree-oss/base@sha256:5d43683c6b50aebe1fca6c85f2012f3b0fa153bf4dd268e8767b619b1891423a AS install-riscv WORKDIR /install-riscv -RUN wget "https://storage.googleapis.com/iree-shared-files/toolchain_iree_rvv-intrinsic.tar.gz" -RUN tar -xf "toolchain_iree_rvv-intrinsic.tar.gz" -C /usr/src/ -RUN wget "https://storage.googleapis.com/iree-shared-files/toolchain_iree_rv32.tar.gz" -RUN tar -xf "toolchain_iree_rv32.tar.gz" -C /usr/src/ -RUN wget "https://storage.googleapis.com/iree-shared-files/qemu-riscv.tar.gz" +RUN wget --no-verbose "https://storage.googleapis.com/iree-shared-files/toolchain_iree_20220918.tar.gz" +RUN tar -xf "toolchain_iree_20220918.tar.gz" -C /usr/src/ +RUN wget --no-verbose "https://storage.googleapis.com/iree-shared-files/toolchain_iree_rv32_20220918.tar.gz" +RUN tar -xf "toolchain_iree_rv32_20220918.tar.gz" -C /usr/src/ +RUN wget --no-verbose "https://storage.googleapis.com/iree-shared-files/qemu-riscv.tar.gz" RUN tar -xf "qemu-riscv.tar.gz" -C /usr/src/ FROM gcr.io/iree-oss/base@sha256:5d43683c6b50aebe1fca6c85f2012f3b0fa153bf4dd268e8767b619b1891423a AS final diff --git a/build_tools/github_actions/runner/gcp/create_image.sh b/build_tools/github_actions/runner/gcp/create_image.sh index a1006e2a66a9..517684c1cfcc 100755 --- a/build_tools/github_actions/runner/gcp/create_image.sh +++ b/build_tools/github_actions/runner/gcp/create_image.sh @@ -6,24 +6,68 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -set -euo pipefail +set -o errexit # Exit if any command fails +set -o errtrace # make ERR trap inherit +set -o pipefail # return error if any part of a pipe errors +set -o nounset # error if an undefined variable is used TIME_STRING="$(date +%Y-%m-%d-%s)" +SUCCESS_DELETE_INSTANCE=1 +FAILURE_DELETE_INSTANCE=0 + INSTANCE_NAME="${INSTANCE_NAME:-github-runner-template-cpu-${TIME_STRING}}" IMAGE_NAME="${IMAGE_NAME:-github-runner-cpu-${TIME_STRING}}" ZONE="${ZONE:-us-central1-a}" PROJECT=iree-oss BASE_IMAGE="${BASE_IMAGE:-projects/ubuntu-os-cloud/global/images/ubuntu-2204-jammy-v20220902}" +# The size of the base image +IMAGE_SIZE_GB=10 # It takes a little bit to bring up ssh on the instance. I haven't found a # better way to wait for this than just polling. MAX_IP_ATTEMPTS=5 MAX_SSH_ATTEMPTS=10 MAX_SCP_ATTEMPTS=5 +DELETE_INSTANCE_CMD=( + gcloud + compute + instances + delete + "${INSTANCE_NAME}" + --zone="${ZONE}" +) + +function cleanup_reminder() { + echo "Make sure to delete ${INSTANCE_NAME} when you're done debugging:" + echo "${DELETE_INSTANCE_CMD[@]}" +} + +function failure_exit() { + local exit_code="$?" + trap - INT ERR EXIT + if (( exit_code != 0 )); then + echo "Image creation was not successful." + if (( FAILURE_DELETE_INSTANCE==1 )); then + echo "Attempting to delete instance ${INSTANCE_NAME}" + "${DELETE_INSTANCE_CMD[@]}" --quiet + exit "${exit_code}" + else + cleanup_reminder + fi + fi + exit "${exit_code}" +} + +trap failure_exit INT ERR EXIT + SCRIPT_DIR="$(dirname -- "$( readlink -f -- "$0"; )")"; -CREATE_INSTANCE_ARGS=( +CREATE_INSTANCE_CMD=( + gcloud + compute + instances + create "${INSTANCE_NAME}" --project=iree-oss --zone="${ZONE}" @@ -37,7 +81,7 @@ CREATE_INSTANCE_ARGS=( --provisioning-model=STANDARD --no-service-account --no-scopes - --create-disk="boot=yes,device-name=${INSTANCE_NAME},image=${BASE_IMAGE},mode=rw,size=10,type=projects/${PROJECT}/zones/${ZONE}/diskTypes/pd-balanced" + --create-disk="boot=yes,device-name=${INSTANCE_NAME},image=${BASE_IMAGE},mode=rw,size=${IMAGE_SIZE_GB},type=projects/${PROJECT}/zones/${ZONE}/diskTypes/pd-balanced,auto-delete=yes" --no-shielded-secure-boot --shielded-vtpm --shielded-integrity-monitoring @@ -80,7 +124,6 @@ function wait_for_ssh() { while (( failed_attempts <= max_attempts )) && ! ssh_output="$(ssh_ping 2>&1)"; do echo -n '.' failed_attempts="$(( failed_attempts+1 ))" - sleep 1 done if (( failed_attempts > max_attempts )); then @@ -92,7 +135,7 @@ function wait_for_ssh() { function create_image() { echo "Creating instance for boot disk" - (set -x; gcloud compute instances create "${CREATE_INSTANCE_ARGS[@]}") + (set -x; "${CREATE_INSTANCE_CMD[@]}") # We could only use the ssh check below, but it's much nicer to know why an # an instance isn't responsive and this is something we can check first. @@ -100,46 +143,44 @@ function create_image() { wait_for_ip "${MAX_IP_ATTEMPTS}" wait_for_ssh "${MAX_SSH_ATTEMPTS}" - local log_file="$(mktemp)" - touch "${log_file}" echo "" - echo "Streaming startup logs from instance" - tail -f "${log_file}" & - local -i failed_scp_attempts=0 - local last_line="" - local scp_output="" - # Is waiting for a certain line in the logs kind of hacky? yes - # Is there a better way to do it? probably - # Does the better way involve a bunch of fiddling about? also probably - while (( failed_scp_attempts < MAX_SCP_ATTEMPTS )) && [[ "${last_line}" != "Setup complete" ]]; do - ret=0 - scp_output="$(gcloud compute scp \ - --zone="${ZONE}" \ - "${INSTANCE_NAME}:/startup.log" \ - "${log_file}" 2>&1)" || ret=$? - if (( ret != 0 )); then - failed_scp_attempts="$(( failed_scp_attempts+1 ))" - sleep 1 - else - last_line="$(tail --lines=1 "${log_file}")" - fi - done + local log_file="$(mktemp --tmpdir ${INSTANCE_NAME}.XXX.startup.log)" + echo "Streaming startup logs from instance to stdout and ${log_file}" + + # Get the PID of the startup script + local startup_pid="$(gcloud compute ssh "${INSTANCE_NAME}" --zone="${ZONE}" \ + --no-user-output-enabled \ + --command='systemctl show --property=ExecMainPID --value google-startup-scripts')" - if (( failed_scp_attempts >= MAX_SCP_ATTEMPTS )); then - echo "Was unable to copy logs from instance. Output from scp:" - echo "${scp_output}" + echo "" + echo "*******************" + + # -t forces a pseudo-tty which allows us to run tail with a follow + gcloud compute ssh "${INSTANCE_NAME}" --zone="${ZONE}" \ + --no-user-output-enabled \ + --ssh-flag="-t" \ + --command="tail --follow=name --retry --pid=${startup_pid} /startup.log" \ + | tee "${log_file}" + + echo "*******************" + echo "" + + local exit_code="$(gcloud compute ssh "${INSTANCE_NAME}" --command="cat /startup-exit.txt")" + + if [[ "${exit_code}" != +([0-9]) ]]; then + echo "Failed to retrieve exit code from startup script (got '${exit_code}')." exit 1 fi - if [[ "${last_line}" != "Setup complete" ]]; then - echo "Instance did not complete its setup. Please check the logs above." - exit 1 + if (( exit_code != 0 )); then + echo "Image setup failed with code '${exit_code}'. See logs above." + exit "${exit_code}" fi echo "Startup finished successfully." - echo "Deleting log file" + echo "Deleting remote log file" gcloud compute ssh "${INSTANCE_NAME}" --zone="${ZONE}" \ --no-user-output-enabled \ --command="sudo rm /startup.log" @@ -154,8 +195,13 @@ function create_image() { --source-disk="${INSTANCE_NAME}" \ --source-disk-zone="${ZONE}" - echo "Deleting instance" - gcloud compute instances delete "${INSTANCE_NAME}" --zone="${ZONE}" --quiet + if (( SUCCESS_DELETE_INSTANCE == 1 )); then + echo "Deleting instance" + "${DELETE_INSTANCE_CMD[@]}" --quiet + else + echo "Not deleting instance because SUCCESS_DELETE_INSTANCE=${SUCCESS_DELETE_INSTANCE}" + cleanup_reminder + fi echo "Successfully created image: ${IMAGE_NAME}" } diff --git a/build_tools/github_actions/runner/gcp/create_templates.sh b/build_tools/github_actions/runner/gcp/create_templates.sh index 9cff7d57c3f6..7bed83df3238 100755 --- a/build_tools/github_actions/runner/gcp/create_templates.sh +++ b/build_tools/github_actions/runner/gcp/create_templates.sh @@ -19,7 +19,13 @@ TEMPLATE_BASE_NAME="${TEMPLATE_BASE_NAME:-github-runner}" TEMPLATE_CONFIG_REPO="${TEMPLATE_CONFIG_REPO:-iree-org/iree}" TEMPLATE_CONFIG_REF="${TEMPLATE_CONFIG_REF:-$(git rev-parse HEAD)}" GPU_IMAGE="github-runner-gpu-2022-08-15-1660603500" -CPU_IMAGE="github-runner-2022-07-28-1659048799" +# Due to an early misconfiguration, this boot image is really large (1TB), so we +# need a 1TB boot disk. +# TODO(gcmn): Shrink the image and disk size. +GPU_DISK_SIZE_GB=1000 +CPU_IMAGE="github-runner-cpu-2022-09-22-1663865258" +# The image is only 10GB, but we need some space for Docker images and such. +CPU_DISK_SIZE_GB=100 if (( TESTING==0 )); then if [[ "${TEMPLATE_CONFIG_REPO}" != iree-org/iree ]]; then @@ -111,13 +117,13 @@ function create_template() { --machine-type=a2-highgpu-1g --maintenance-policy=TERMINATE --accelerator=count=1,type=nvidia-tesla-a100 - --create-disk="auto-delete=yes,boot=yes,image=projects/iree-oss/global/images/${GPU_IMAGE},mode=rw,size=1000,type=pd-balanced" + --create-disk="auto-delete=yes,boot=yes,image=projects/iree-oss/global/images/${GPU_IMAGE},mode=rw,size=${GPU_DISK_SIZE_GB},type=pd-balanced" ) elif [[ "${type}" == cpu ]]; then args+=( --machine-type=n1-standard-96 --maintenance-policy=MIGRATE - --create-disk="auto-delete=yes,boot=yes,image=projects/iree-oss/global/images/${CPU_IMAGE},mode=rw,size=1000,type=pd-balanced" + --create-disk="auto-delete=yes,boot=yes,image=projects/iree-oss/global/images/${CPU_IMAGE},mode=rw,size=${CPU_DISK_SIZE_GB},type=pd-balanced" ) else echo "Got unrecognized type '${type}'" >2 diff --git a/build_tools/github_actions/runner/gcp/image_setup.sh b/build_tools/github_actions/runner/gcp/image_setup.sh index f94224163e3d..4880180c4f95 100644 --- a/build_tools/github_actions/runner/gcp/image_setup.sh +++ b/build_tools/github_actions/runner/gcp/image_setup.sh @@ -17,8 +17,55 @@ set -o pipefail # return error if any part of a pipe errors set -o nounset # error if an undefined variable is used +GCLOUD_VERSION=402.0.0 +GCLOUD_ARCHIVE_DIGEST=a9902b57d4cba2ebb76d7354570813d3d8199c36b95a1111a1b7fea013beaaf9 + + +function save_exit_code() { + local exit_code="$?" + echo "${exit_code}" > /startup-exit.txt + trap - EXIT + exit "${exit_code}" +} + +trap save_exit_code EXIT INT TERM + +function apt_maybe_purge() { + # Remove and purge packages if they are installed and don't error if they're + # not or if they're not findable in the ppa. + local -a to_remove=() + for pkg in "$@"; do + ret=0 + if dpkg --status $pkg &> /dev/null; then + to_remove+=("${pkg}") + fi + done + if (( "${#to_remove[@]}" != 0 )); then + apt-get remove --purge --autoremove "${to_remove[@]}" + fi +} + function startup() { - #################################### APT ##################################### + # Shut down in 5 hours. Makes sure this instance doesn't hang around forever + # if setup fails. Someone can cancel the shutdown with `shutdown -c`. + nohup shutdown -h +300 & + cd / + + ########################### Create the runner user ########################### + + # GCE "helpfully" creates users for apparently any account that has ever + # logged in on any VM. Delete it if it's there. + userdel --force --remove runner || true + adduser --system --group "runner" + groupadd docker + usermod --append --groups docker runner + usermod --append --groups sudo runner + groups runner # Print out the groups of runner to verify this worked + + echo "enabling passwordless sudo for runner user" + echo "runner ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/99-runner + + #################################### Apt ##################################### # Disable apt prompts export DEBIAN_FRONTEND="noninteractive" @@ -32,7 +79,7 @@ function startup() { systemctl disable apt-daily-upgrade.service # Don't install documentation (except copyrights) since this is a CI system. - cat > /etc/dpkg/dpkg.cfg.d/github-actions < /etc/dpkg/dpkg.cfg.d/99-github-actions < /etc/apt/apt.conf.d/github-actions < /etc/apt/apt.conf.d/99-github-actions < /etc/sudoers.d/runner - - + ############################## Fix gcloud Installation Snap ############################### + + # Snap literally won't let you disable automatic updates. The only thing + # that's installed through snap here is the gcloud CLI, which we definitely + # don't want automatically updating (beyond our general desire to not + # automatically update on ephemeral machines). So we just delete snap entirely + # and install the CLI via apt (above) + systemctl stop snapd + apt_maybe_purge snapd gnome-software-plugin-snap + rm -rf /home/*/snap + rm -rf /root/snap + + curl --silent --fail --show-error --location \ + https://packages.cloud.google.com/apt/doc/apt-key.gpg \ + | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg + echo \ + "deb [arch=amd64 signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" \ + > /etc/apt/sources.list.d/google-cloud-sdk.list + apt-get update && apt-get install google-cloud-cli + + # This setting is now enabled by default. It sounds great, but unfortunately + # doing such an upload requires *delete* permissions on the bucket, which we + # deliberately do not give runners. For the life of me, I could not figure out + # how to use `gcloud config set` (the "proper" way to set properties) to work + # on the global properties. + cat <> /usr/lib/google-cloud-sdk/properties +[storage] +parallel_composite_upload_enabled = False +EOF ############################### Install Docker ############################### - # Remove Docker stuff that may already be installed, proceeding if they're not. - apt-get remove containerd docker docker-engine docker.io moby-engine moby-cli runc || true + # Remove Docker stuff that may already be installed by all its various names + apt_maybe_purge containerd docker docker-engine docker.io moby-engine moby-cli runc # Install the latest Docker - curl -sfSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg + curl --silent --fail --show-error --location \ + https://download.docker.com/linux/ubuntu/gpg \ + | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg echo \ - "deb [arch=amd64 signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu \ - $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list + "deb [arch=amd64 signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" \ + > /etc/apt/sources.list.d/docker.list apt-get update apt-get install docker-ce docker-ce-cli containerd.io @@ -134,7 +197,7 @@ EOF # Make sure the runner user can use docker runuser --user runner -- docker ps - ################################### Cleanup #################################### + ################################### Cleanup ################################## apt-get clean rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* @@ -150,7 +213,9 @@ EOF # And clear others find /var/log/ -type f -exec truncate -s 0 {} \; - # This specific log line is load bearing, as it's referenced in create_image.sh + echo "Disk usage after setup" + df -h / + echo "Setup complete" } diff --git a/build_tools/github_actions/runner/gcp/update_instance_groups.py b/build_tools/github_actions/runner/gcp/update_instance_groups.py index b5abdc57613d..bf174faf92a6 100755 --- a/build_tools/github_actions/runner/gcp/update_instance_groups.py +++ b/build_tools/github_actions/runner/gcp/update_instance_groups.py @@ -274,6 +274,7 @@ def parse_args(): " https://cloud.google.com/compute/docs/instance-groups/updating-migs." )) subparser_base.add_argument("--env", + "--environment", default="testing", help="The environment for the MIGs.", choices=["prod", "testing"]) diff --git a/build_tools/benchmarks/suites/models/model_groups.py b/build_tools/python/CMakeLists.txt similarity index 57% rename from build_tools/benchmarks/suites/models/model_groups.py rename to build_tools/python/CMakeLists.txt index e788493fdaff..7c892dee9746 100644 --- a/build_tools/benchmarks/suites/models/model_groups.py +++ b/build_tools/python/CMakeLists.txt @@ -1,10 +1,7 @@ -## Copyright 2022 The IREE Authors +# Copyright 2022 The IREE Authors # # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Defines the groups of models.""" -from . import tflite_models - -MOBILE = [tflite_models.MOBILENET_V2] +add_subdirectory(e2e_test_framework) diff --git a/build_tools/benchmarks/suites/__init__.py b/build_tools/python/__init__.py similarity index 100% rename from build_tools/benchmarks/suites/__init__.py rename to build_tools/python/__init__.py diff --git a/build_tools/benchmarks/suites/definitions/__init__.py b/build_tools/python/benchmarks/__init__.py similarity index 100% rename from build_tools/benchmarks/suites/definitions/__init__.py rename to build_tools/python/benchmarks/__init__.py diff --git a/build_tools/benchmarks/suites/device_specs/__init__.py b/build_tools/python/benchmarks/iree/__init__.py similarity index 100% rename from build_tools/benchmarks/suites/device_specs/__init__.py rename to build_tools/python/benchmarks/iree/__init__.py diff --git a/build_tools/benchmarks/suites/iree_benchmarks.py b/build_tools/python/benchmarks/iree/definitions.py similarity index 53% rename from build_tools/benchmarks/suites/iree_benchmarks.py rename to build_tools/python/benchmarks/iree/definitions.py index 82840f97dc79..e2650fb8f77d 100644 --- a/build_tools/benchmarks/suites/iree_benchmarks.py +++ b/build_tools/python/benchmarks/iree/definitions.py @@ -6,16 +6,34 @@ """Defines IREE benchmarks.""" import itertools -from typing import List, Tuple +from typing import List, Sequence, Tuple -from .device_specs import linux_x86_64_specs -from .models import model_groups -from .definitions import common_definitions, iree_definitions -from . import unique_ids +from e2e_test_framework.device_specs import linux_x86_64_specs +from e2e_test_framework.models import model_groups +from e2e_test_framework.definitions import common_definitions, iree_definitions +from e2e_test_framework import unique_ids MODULE_BENCHMARK_TOOL = "iree-benchmark-module" +def _generate_run_specs( + compile_specs: Sequence[iree_definitions.CompileSpec], + run_configs: Sequence[iree_definitions.RunConfig], + device_spec: common_definitions.DeviceSpec, + input_data: common_definitions.ModelInputData = common_definitions. + RANDOM_MODEL_INPUT_DATA, +) -> List[iree_definitions.RunSpec]: + """Generates the run specs from the product of compile specs and run configs. + """ + return [ + iree_definitions.RunSpec(compile_spec=compile_spec, + run_config=run_config, + target_device_spec=device_spec, + input_data=input_data) for compile_spec, + run_config in itertools.product(compile_specs, run_configs) + ] + + class Linux_x86_64_Benchmarks(object): """Benchmarks on x86_64 linux devices.""" @@ -33,31 +51,23 @@ class Linux_x86_64_Benchmarks(object): @classmethod def generate( cls - ) -> Tuple[List[iree_definitions.BenchmarkCompileSpec], - List[iree_definitions.BenchmarkRunSpec]]: + ) -> Tuple[List[iree_definitions.CompileSpec], + List[iree_definitions.RunSpec]]: """Generates IREE compile and run specs.""" default_run_configs = cls._generate_default_run_configs() - # Generate compile specs for mobile models. - mobile_model_compile_specs = [ - iree_definitions.BenchmarkCompileSpec( + compile_specs = [ + iree_definitions.CompileSpec( compile_config=cls.CASCADELAKE_COMPILE_CONFIG, model=model) - for model in model_groups.MOBILE + for model in model_groups.SMALL + model_groups.LARGE ] + run_specs = _generate_run_specs( + compile_specs=compile_specs, + run_configs=default_run_configs, + device_spec=linux_x86_64_specs.GCP_C2_STANDARD_16) - # Generate run specs for mobile models. - mobile_model_run_specs = [] - for compile_spec, run_config in itertools.product( - mobile_model_compile_specs, default_run_configs): - mobile_model_run_specs.append( - iree_definitions.BenchmarkRunSpec( - compile_spec=compile_spec, - run_config=run_config, - target_device_spec=linux_x86_64_specs.GCP_C2_STANDARD_16, - input_data=common_definitions.RANDOM_MODEL_INPUT_DATA)) - - return (mobile_model_compile_specs, mobile_model_run_specs) + return (compile_specs, run_specs) @staticmethod def _generate_default_run_configs() -> List[iree_definitions.RunConfig]: @@ -68,7 +78,7 @@ def _generate_default_run_configs() -> List[iree_definitions.RunConfig]: tags=["full-inference", "default-flags"], loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, driver=iree_definitions.RuntimeDriver.LOCAL_SYNC, - benchmark_tool=MODULE_BENCHMARK_TOOL) + tool=MODULE_BENCHMARK_TOOL) ] for thread_num in [1, 4, 8]: run_configs.append( @@ -77,6 +87,12 @@ def _generate_default_run_configs() -> List[iree_definitions.RunConfig]: tags=[f"{thread_num}-thread", "full-inference", "default-flags"], loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, driver=iree_definitions.RuntimeDriver.LOCAL_TASK, - benchmark_tool=MODULE_BENCHMARK_TOOL, + tool=MODULE_BENCHMARK_TOOL, extra_flags=[f"--task_topology_group_count={thread_num}"])) return run_configs + + +def generate( +) -> Tuple[List[iree_definitions.CompileSpec], List[iree_definitions.RunSpec]]: + """Generates all compile and run specs for IREE benchmarks.""" + return Linux_x86_64_Benchmarks.generate() diff --git a/build_tools/benchmarks/suites/CMakeLists.txt b/build_tools/python/e2e_test_framework/CMakeLists.txt similarity index 60% rename from build_tools/benchmarks/suites/CMakeLists.txt rename to build_tools/python/e2e_test_framework/CMakeLists.txt index ce00b7085ea5..9c73e8674e1f 100644 --- a/build_tools/benchmarks/suites/CMakeLists.txt +++ b/build_tools/python/e2e_test_framework/CMakeLists.txt @@ -4,11 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -################################################################################ -# Tests -################################################################################ - -benchmark_tool_py_test( +iree_build_tools_py_test( NAME cmake_rule_generator_test SRC diff --git a/build_tools/benchmarks/suites/models/__init__.py b/build_tools/python/e2e_test_framework/__init__.py similarity index 100% rename from build_tools/benchmarks/suites/models/__init__.py rename to build_tools/python/e2e_test_framework/__init__.py diff --git a/build_tools/benchmarks/suites/cmake_rule_generator.py b/build_tools/python/e2e_test_framework/cmake_rule_generator.py similarity index 88% rename from build_tools/benchmarks/suites/cmake_rule_generator.py rename to build_tools/python/e2e_test_framework/cmake_rule_generator.py index 3075ac87a591..5c2af7f79e1f 100644 --- a/build_tools/benchmarks/suites/cmake_rule_generator.py +++ b/build_tools/python/e2e_test_framework/cmake_rule_generator.py @@ -9,31 +9,31 @@ """ from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Sequence import os import pathlib import string import urllib.parse -from .definitions import common_definitions, iree_definitions -from . import iree_benchmarks - -TEMPLATE_DIR = pathlib.Path(__file__).parent +from e2e_test_framework.definitions import common_definitions, iree_definitions -def read_template_from_file(template_name: str) -> string.Template: - with open(TEMPLATE_DIR / template_name, "r") as f: - return string.Template(f.read()) +def read_template_from_file(template_path: pathlib.Path) -> string.Template: + return string.Template(template_path.read_text()) +TEMPLATE_DIR = pathlib.Path(__file__).parent DOWNLOAD_ARTIFACT_CMAKE_TEMPLATE = read_template_from_file( - "iree_download_artifact_template.cmake") + TEMPLATE_DIR / "iree_download_artifact_template.cmake") TFLITE_IMPORT_CMAKE_TEMPLATE = read_template_from_file( - "iree_tflite_import_template.cmake") + TEMPLATE_DIR / "iree_tflite_import_template.cmake") TF_IMPORT_CMAKE_TEMPLATE = read_template_from_file( - "iree_tf_import_template.cmake") + TEMPLATE_DIR / "iree_tf_import_template.cmake") IREE_BYTECODE_MODULE_CMAKE_TEMPLATE = read_template_from_file( - "iree_bytecode_module_template.cmake") + TEMPLATE_DIR / "iree_bytecode_module_template.cmake") + +# Archive extensions used to pack models. +ARCHIVE_FILE_EXTENSIONS = [".tar", ".gz"] @dataclass @@ -82,9 +82,15 @@ def add_model_rule(self, model: common_definitions.Model) -> ModelRule: target_name = f"model-{model.id}" model_url = urllib.parse.urlparse(model.source_url) - _, file_ext = os.path.splitext(model_url.path) - # Model path: /_. - model_path = f"{self._model_artifacts_dir}/{model.id}_{model.name}{file_ext}" + + # Drop the archive extensions. + file_exts = pathlib.PurePath(model_url.path).suffixes + while len(file_exts) > 0 and file_exts[-1] in ARCHIVE_FILE_EXTENSIONS: + file_exts.pop() + model_ext = "".join(file_exts) + + # Model path: /_ + model_path = f"{self._model_artifacts_dir}/{model.id}_{model.name}{model_ext}" if model_url.scheme == "https": cmake_rule = DOWNLOAD_ARTIFACT_CMAKE_TEMPLATE.substitute( @@ -259,10 +265,10 @@ def _generate_iree_compile_target_flags( return flags -def _generate_iree_benchmark_rules(common_rule_factory: CommonRuleFactory, - iree_artifacts_dir: str) -> List[str]: +def _generate_iree_rules( + common_rule_factory: CommonRuleFactory, iree_artifacts_dir: str, + compile_specs: Sequence[iree_definitions.CompileSpec]) -> List[str]: iree_rule_factory = IreeRuleFactory(iree_artifacts_dir) - compile_specs, _ = iree_benchmarks.Linux_x86_64_Benchmarks.generate() for compile_spec in compile_specs: model = compile_spec.model compile_config = compile_spec.compile_config @@ -280,8 +286,9 @@ def _generate_iree_benchmark_rules(common_rule_factory: CommonRuleFactory, return iree_rule_factory.generate_cmake_rules() -def generate_benchmark_rules(model_artifacts_dir: str, - iree_artifacts_dir: str) -> List[str]: +def generate_rules( + model_artifacts_dir: str, iree_artifacts_dir: str, + iree_compile_specs: Sequence[iree_definitions.CompileSpec]) -> List[str]: """Generates cmake rules to build benchmarks. Args: @@ -289,12 +296,13 @@ def generate_benchmark_rules(model_artifacts_dir: str, variable syntax in the path. iree_artifacts_dir: root directory to store generated IREE artifacts. Can contain CMake variable syntax in the path. + iree_compile_specs: compile specs for IREE targets. Returns: List of CMake rules. """ common_rule_factory = CommonRuleFactory(model_artifacts_dir) - iree_rules = _generate_iree_benchmark_rules(common_rule_factory, - iree_artifacts_dir) + iree_rules = _generate_iree_rules(common_rule_factory, iree_artifacts_dir, + iree_compile_specs) # Currently the rules are simple so the common rules can be always put at the # top. Need a topological sort once the dependency gets complicated. return common_rule_factory.generate_cmake_rules() + iree_rules diff --git a/build_tools/benchmarks/suites/cmake_rule_generator_test.py b/build_tools/python/e2e_test_framework/cmake_rule_generator_test.py similarity index 97% rename from build_tools/benchmarks/suites/cmake_rule_generator_test.py rename to build_tools/python/e2e_test_framework/cmake_rule_generator_test.py index 3ade14f0bab2..f7ca945a82f7 100644 --- a/build_tools/benchmarks/suites/cmake_rule_generator_test.py +++ b/build_tools/python/e2e_test_framework/cmake_rule_generator_test.py @@ -5,8 +5,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from suites.definitions import common_definitions, iree_definitions -from suites import cmake_rule_generator +from e2e_test_framework.definitions import common_definitions, iree_definitions +from e2e_test_framework import cmake_rule_generator import unittest diff --git a/build_tools/python/e2e_test_framework/definitions/__init__.py b/build_tools/python/e2e_test_framework/definitions/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/build_tools/benchmarks/suites/definitions/common_definitions.py b/build_tools/python/e2e_test_framework/definitions/common_definitions.py similarity index 100% rename from build_tools/benchmarks/suites/definitions/common_definitions.py rename to build_tools/python/e2e_test_framework/definitions/common_definitions.py diff --git a/build_tools/benchmarks/suites/definitions/iree_definitions.py b/build_tools/python/e2e_test_framework/definitions/iree_definitions.py similarity index 87% rename from build_tools/benchmarks/suites/definitions/iree_definitions.py rename to build_tools/python/e2e_test_framework/definitions/iree_definitions.py index 9ecce2eceffa..627df01c49bc 100644 --- a/build_tools/benchmarks/suites/definitions/iree_definitions.py +++ b/build_tools/python/e2e_test_framework/definitions/iree_definitions.py @@ -3,14 +3,14 @@ # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Classes for IREE benchmark definitions.""" +"""Classes for IREE compilation and run definitions.""" import dataclasses from dataclasses import dataclass from enum import Enum from typing import List -from . import common_definitions +from e2e_test_framework.definitions import common_definitions class TargetBackend(Enum): @@ -62,21 +62,21 @@ class RunConfig(object): tags: List[str] loader: RuntimeLoader driver: RuntimeDriver - benchmark_tool: str + tool: str extra_flags: List[str] = dataclasses.field(default_factory=list) @dataclass(frozen=True) -class BenchmarkCompileSpec(object): +class CompileSpec(object): """Describes a compile target to generate the module.""" compile_config: CompileConfig model: common_definitions.Model @dataclass(frozen=True) -class BenchmarkRunSpec(object): - """Describes a run target to be benchmarked.""" - compile_spec: BenchmarkCompileSpec +class RunSpec(object): + """Describes a run target.""" + compile_spec: CompileSpec run_config: RunConfig target_device_spec: common_definitions.DeviceSpec input_data: common_definitions.ModelInputData diff --git a/build_tools/python/e2e_test_framework/device_specs/__init__.py b/build_tools/python/e2e_test_framework/device_specs/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/build_tools/benchmarks/suites/device_specs/linux_x86_64_specs.py b/build_tools/python/e2e_test_framework/device_specs/linux_x86_64_specs.py similarity index 84% rename from build_tools/benchmarks/suites/device_specs/linux_x86_64_specs.py rename to build_tools/python/e2e_test_framework/device_specs/linux_x86_64_specs.py index 2366344ab746..9ae4c1c3974e 100644 --- a/build_tools/benchmarks/suites/device_specs/linux_x86_64_specs.py +++ b/build_tools/python/e2e_test_framework/device_specs/linux_x86_64_specs.py @@ -5,8 +5,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception """Defines x86_64 linux devices.""" -from .. import unique_ids -from ..definitions import common_definitions +from e2e_test_framework import unique_ids +from e2e_test_framework.definitions import common_definitions GCP_C2_STANDARD_16 = common_definitions.DeviceSpec( id=unique_ids.DEVICE_SPEC_GCP_C2_STANDARD_16, diff --git a/build_tools/benchmarks/suites/iree_bytecode_module_template.cmake b/build_tools/python/e2e_test_framework/iree_bytecode_module_template.cmake similarity index 100% rename from build_tools/benchmarks/suites/iree_bytecode_module_template.cmake rename to build_tools/python/e2e_test_framework/iree_bytecode_module_template.cmake diff --git a/build_tools/python/e2e_test_framework/iree_download_artifact_template.cmake b/build_tools/python/e2e_test_framework/iree_download_artifact_template.cmake new file mode 100644 index 000000000000..3c740898b9e8 --- /dev/null +++ b/build_tools/python/e2e_test_framework/iree_download_artifact_template.cmake @@ -0,0 +1,10 @@ +# Fetch the model from "$__SOURCE_URL" +iree_fetch_artifact( + NAME + "$__TARGET_NAME" + SOURCE_URL + "$__SOURCE_URL" + OUTPUT + "$__OUTPUT_PATH" + UNPACK +) diff --git a/build_tools/benchmarks/suites/iree_tf_import_template.cmake b/build_tools/python/e2e_test_framework/iree_tf_import_template.cmake similarity index 60% rename from build_tools/benchmarks/suites/iree_tf_import_template.cmake rename to build_tools/python/e2e_test_framework/iree_tf_import_template.cmake index c8b33d0e9b0a..364e4489c2bd 100644 --- a/build_tools/benchmarks/suites/iree_tf_import_template.cmake +++ b/build_tools/python/e2e_test_framework/iree_tf_import_template.cmake @@ -5,3 +5,5 @@ iree_import_tf_model( ENTRY_FUNCTION "$__ENTRY_FUNCTION" OUTPUT_MLIR_FILE "$__OUTPUT_PATH" ) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "$${_PACKAGE_NAME}_$__TARGET_NAME") diff --git a/build_tools/benchmarks/suites/iree_tflite_import_template.cmake b/build_tools/python/e2e_test_framework/iree_tflite_import_template.cmake similarity index 56% rename from build_tools/benchmarks/suites/iree_tflite_import_template.cmake rename to build_tools/python/e2e_test_framework/iree_tflite_import_template.cmake index 0e05e42c9336..82616f9fb559 100644 --- a/build_tools/benchmarks/suites/iree_tflite_import_template.cmake +++ b/build_tools/python/e2e_test_framework/iree_tflite_import_template.cmake @@ -4,3 +4,5 @@ iree_import_tflite_model( SOURCE "$__SOURCE_MODEL_PATH" OUTPUT_MLIR_FILE "$__OUTPUT_PATH" ) +# Mark dependency so users can import models without compiling them. +add_dependencies(iree-benchmark-import-models "$${_PACKAGE_NAME}_$__TARGET_NAME") diff --git a/build_tools/python/e2e_test_framework/models/__init__.py b/build_tools/python/e2e_test_framework/models/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/build_tools/python/e2e_test_framework/models/model_groups.py b/build_tools/python/e2e_test_framework/models/model_groups.py new file mode 100644 index 000000000000..43d23b71d72e --- /dev/null +++ b/build_tools/python/e2e_test_framework/models/model_groups.py @@ -0,0 +1,28 @@ +## Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Defines the groups of models.""" + +from e2e_test_framework.models import tf_models, tflite_models + +# Small models that require less computational resources. +SMALL = [ + tflite_models.DEEPLABV3_FP32, + tflite_models.MOBILESSD_FP32, + tflite_models.POSENET_FP32, + tflite_models.MOBILEBERT_FP32, + tflite_models.MOBILEBERT_INT8, + tflite_models.MOBILEBERT_FP16, + tflite_models.MOBILENET_V1, + tflite_models.MOBILENET_V2, + tflite_models.MOBILENET_V3SMALL, + tflite_models.PERSON_DETECT_INT8, + tflite_models.EFFICIENTNET_INT8, +] + +# Large models that require more computational resources. +LARGE = [ + tf_models.MINILM_L12_H384_UNCASED_INT32_SEQLEN128, +] diff --git a/build_tools/python/e2e_test_framework/models/tf_models.py b/build_tools/python/e2e_test_framework/models/tf_models.py new file mode 100644 index 000000000000..138296a7a642 --- /dev/null +++ b/build_tools/python/e2e_test_framework/models/tf_models.py @@ -0,0 +1,20 @@ +## Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Defines Tensorflow models.""" + +from e2e_test_framework import unique_ids +from e2e_test_framework.definitions import common_definitions + +MINILM_L12_H384_UNCASED_INT32_SEQLEN128 = common_definitions.Model( + id=unique_ids.MODEL_MINILM_L12_H384_UNCASED_INT32_SEQLEN128, + name="MiniLML12H384Uncased", + tags=["int32", "seqlen128"], + source_type=common_definitions.ModelSourceType.EXPORTED_TF, + # Converted from https://huggingface.co/microsoft/MiniLM-L12-H384-uncased/commit/44acabbec0ef496f6dbc93adadea57f376b7c0ec + source_url= + "https://storage.googleapis.com/iree-model-artifacts/minilm-l12-h384-uncased-seqlen128-tf-model.tar.gz", + entry_function="predict", + input_types=["1x128xi32", "1x128xi32", "1x128xi32"]) diff --git a/build_tools/python/e2e_test_framework/models/tflite_models.py b/build_tools/python/e2e_test_framework/models/tflite_models.py new file mode 100644 index 000000000000..4de8ba2ba715 --- /dev/null +++ b/build_tools/python/e2e_test_framework/models/tflite_models.py @@ -0,0 +1,131 @@ +## Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Defines TFLite models.""" + +from e2e_test_framework import unique_ids +from e2e_test_framework.definitions import common_definitions + +DEEPLABV3_FP32 = common_definitions.Model( + id=unique_ids.MODEL_DEEPLABV3_FP32, + name="DeepLabV3_fp32", + tags=["fp32"], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + # Mirror of https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/default/1 + source_url= + "https://storage.googleapis.com/iree-model-artifacts/deeplabv3.tflite", + entry_function="main", + input_types=["1x257x257x3xf32"]) + +MOBILESSD_FP32 = common_definitions.Model( + id=unique_ids.MODEL_MOBILESSD_FP32, + name="MobileSSD_fp32", + tags=["fp32"], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + # Mirror of https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/mobile_ssd_v2_float_coco.tflite + source_url= + "https://storage.googleapis.com/iree-model-artifacts/mobile_ssd_v2_float_coco.tflite", + entry_function="main", + input_types=["1x320x320x3xf32"]) + +POSENET_FP32 = common_definitions.Model( + id=unique_ids.MODEL_POSENET_FP32, + name="PoseNet_fp32", + tags=["fp32"], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + # Mirror of https://tfhub.dev/tensorflow/lite-model/posenet/mobilenet/float/075/1/default/1 + source_url= + "https://storage.googleapis.com/iree-model-artifacts/posenet.tflite", + entry_function="main", + input_types=["1x353x257x3xf32"]) + +MOBILEBERT_FP32 = common_definitions.Model( + id=unique_ids.MODEL_MOBILEBERT_FP32, + name="MobileBertSquad_fp32", + tags=["fp32"], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + # Mirror of https://tfhub.dev/iree/lite-model/mobilebert/fp32/1 + source_url= + "https://storage.googleapis.com/iree-model-artifacts/mobilebert-baseline-tf2-float.tflite", + entry_function="main", + input_types=["1x384xi32", "1x384xi32", "1x384xi32"]) + +MOBILEBERT_INT8 = common_definitions.Model( + id=unique_ids.MODEL_MOBILEBERT_INT8, + name="MobileBertSquad_int8", + tags=["int8"], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + # Mirror of https://tfhub.dev/iree/lite-model/mobilebert/int8/1 + source_url= + "https://storage.googleapis.com/iree-model-artifacts/mobilebert-baseline-tf2-quant.tflite", + entry_function="main", + input_types=["1x384xi32", "1x384xi32", "1x384xi32"]) + +MOBILEBERT_FP16 = common_definitions.Model( + id=unique_ids.MODEL_MOBILEBERT_FP16, + name="MobileBertSquad_fp16", + tags=["fp16"], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + # Mirror of https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 + source_url= + "https://storage.googleapis.com/iree-model-artifacts/mobilebertsquad.tflite", + entry_function="main", + input_types=["1x384xi32", "1x384xi32", "1x384xi32"]) + +MOBILENET_V1 = common_definitions.Model( + id=unique_ids.MODEL_MOBILENET_V1, + name="MobileNetV1_fp32", + tags=["fp32", "imagenet"], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + # Mirror of https://tfhub.dev/iree/lite-model/mobilenet_v1_100_224/fp32/1 + source_url= + "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v1_224_1.0_float.tflite", + entry_function="main", + input_types=["1x224x224x3xf32"]) + +MOBILENET_V2 = common_definitions.Model( + id=unique_ids.MODEL_MOBILENET_V2, + name="MobileNetV2_fp32", + tags=["fp32", "imagenet"], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + # Mirror of https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite + source_url= + "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_1.0_224.tflite", + entry_function="main", + input_types=["1x224x224x3xf32"]) + +MOBILENET_V3SMALL = common_definitions.Model( + id=unique_ids.MODEL_MOBILENET_V3SMALL, + name="MobileNetV3Small_fp32", + tags=["fp32", "imagenet"], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + # https://tfhub.dev/google/imagenet/mobilenet_v3_small_100_224/classification/5 + # Manually exported to tflite with static batch dimension + source_url= + "https://storage.googleapis.com/iree-model-artifacts/MobileNetV3SmallStaticBatch.tflite", + entry_function="main", + input_types=["1x224x224x3xf32"]) + +PERSON_DETECT_INT8 = common_definitions.Model( + id=unique_ids.MODEL_PERSON_DETECT_INT8, + name="PersonDetect_int8", + tags=["int8"], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + # Mirror of https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/models/person_detect.tflite + source_url= + "https://storage.googleapis.com/iree-model-artifacts/person_detect.tflite", + entry_function="main", + input_types=["1x96x96x1xi8"]) + +EFFICIENTNET_INT8 = common_definitions.Model( + id=unique_ids.MODEL_EFFICIENTNET_INT8, + name="EfficientNet_int8", + tags=["int8"], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + # Mirror of https://tfhub.dev/tensorflow/lite-model/efficientnet/lite0/int8/2 + source_url= + "https://storage.googleapis.com/iree-model-artifacts/efficientnet_lite0_int8_2.tflite", + entry_function="main", + input_types=["1x224x224x3xui8"]) diff --git a/build_tools/python/e2e_test_framework/unique_ids.py b/build_tools/python/e2e_test_framework/unique_ids.py new file mode 100644 index 000000000000..cfb749f6ddce --- /dev/null +++ b/build_tools/python/e2e_test_framework/unique_ids.py @@ -0,0 +1,31 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""List of unique random IDs in the benchmark suites. + +Each ID should be generated from uuid.uuid4(). +""" + +# Models +MODEL_DEEPLABV3_FP32 = "c36c63b0-220a-4d78-8ade-c45ce47d89d3" +MODEL_MOBILESSD_FP32 = "0e466f69-91d6-4e50-b62b-a82b6213a231" +MODEL_POSENET_FP32 = "5afc3014-d29d-4e88-a840-fbaf678acf2b" +MODEL_MOBILEBERT_FP32 = "cc69d69f-6d1f-4a1a-a31e-e021888d0d28" +MODEL_MOBILEBERT_INT8 = "e3997104-a3d2-46b4-9fbf-39069906d123" +MODEL_MOBILEBERT_FP16 = "73a0402e-271b-4aa8-a6a5-ac05839ca569" +MODEL_MOBILENET_V1 = "78eab9e5-9ff1-4769-9b55-933c81cc9a0f" +MODEL_MOBILENET_V2 = "7d45f8e5-bb5e-48d0-928d-8f125104578f" +MODEL_MOBILENET_V3SMALL = "58855e40-eba9-4a71-b878-6b35e3460244" +MODEL_PERSON_DETECT_INT8 = "bc1338be-e3df-44fd-82e4-40ba9560a073" +MODEL_EFFICIENTNET_INT8 = "4a6f545e-1b4e-41a5-9236-792aa578184b" +MODEL_MINILM_L12_H384_UNCASED_INT32_SEQLEN128 = "ecf5c970-ee97-49f0-a4ed-df1f34e9d493" + +# Devices +DEVICE_SPEC_GCP_C2_STANDARD_16 = "9a4804f1-b1b9-46cd-b251-7f16a655f782" + +# IREE benchmarks +IREE_COMPILE_CONFIG_LINUX_CASCADELAKE = "e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" +IREE_RUN_CONFIG_LOCAL_SYNC = "13fc65a9-e5dc-4cbb-9c09-25b0b08f4c03" +IREE_RUN_CONFIG_LOCAL_TASK_BASE = "c7c4a15e-b20c-4898-bb4a-864f34ff34b2" diff --git a/build_tools/scripts/download_file.py b/build_tools/scripts/download_file.py index 492ad1ca6e2a..7e7a6898cadc 100755 --- a/build_tools/scripts/download_file.py +++ b/build_tools/scripts/download_file.py @@ -34,6 +34,10 @@ def parse_arguments(): required=True, metavar="", help="Output file path") + parser.add_argument("--unpack", + action='store_true', + default=False, + help="Unpack the downloaded file if it's an archive.") return parser.parse_args() @@ -50,24 +54,26 @@ def main(args): f"Failed to download file with status {response.status} {response.msg}" ) - if args.source_url.endswith(".tar.gz"): - # Open tar.gz in the streaming mode. - with tarfile.open(fileobj=response, mode="r|*") as tar_file: - if os.path.exists(args.output): - shutil.rmtree(args.output) - os.makedirs(args.output) - tar_file.extractall(args.output) + if args.unpack: + if args.source_url.endswith(".tar.gz"): + # Open tar.gz in the streaming mode. + with tarfile.open(fileobj=response, mode="r|*") as tar_file: + if os.path.exists(args.output): + shutil.rmtree(args.output) + os.makedirs(args.output) + tar_file.extractall(args.output) + return + elif args.source_url.endswith(".gz"): + # Open gzip from a file-like object, which will be in the streaming mode. + with gzip.open(filename=response, mode="rb") as input_file: + with open(args.output, "wb") as output_file: + shutil.copyfileobj(input_file, output_file) + return - elif args.source_url.endswith(".gz"): - # Open gzip from a file-like object, which will be in the streaming mode. - with gzip.open(filename=response, mode="rb") as input_file: - with open(args.output, "wb") as output_file: - shutil.copyfileobj(input_file, output_file) - - else: - with open(args.output, "wb") as output_file: - # Streaming copy. - shutil.copyfileobj(response, output_file) + # Fallback to download the file only. + with open(args.output, "wb") as output_file: + # Streaming copy. + shutil.copyfileobj(response, output_file) if __name__ == "__main__": diff --git a/build_tools/scripts/integrate/README.md b/build_tools/scripts/integrate/README.md index 43788d7e36b8..c035a57e83a7 100644 --- a/build_tools/scripts/integrate/README.md +++ b/build_tools/scripts/integrate/README.md @@ -351,8 +351,8 @@ under docker, we can find the hash from CI log. An example from a log: ``` -[18:30:23 UTC] docker run --volume=/tmpfs/src/github/iree:/tmpfs/src/github/iree --workdir=/tmpfs/src/github/iree --rm --user=1003:1004 --volume=/tmpfs/fake_etc/group:/etc/group:ro --volume=/tmpfs/fake_etc/passwd:/etc/passwd:ro --volume=/tmpfs/fake_home:/home/kbuilder --volume=/home/kbuilder/.config/gcloud:/home/kbuilder/.config/gcloud:ro gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01 build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/core/build.sh -Unable to find image 'gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01' locally +[18:30:23 UTC] docker run --volume=/tmpfs/src/github/iree:/tmpfs/src/github/iree --workdir=/tmpfs/src/github/iree --rm --user=1003:1004 --volume=/tmpfs/fake_etc/group:/etc/group:ro --volume=/tmpfs/fake_etc/passwd:/etc/passwd:ro --volume=/tmpfs/fake_home:/home/kbuilder --volume=/home/kbuilder/.config/gcloud:/home/kbuilder/.config/gcloud:ro gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9 build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/core/build.sh +Unable to find image 'gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9' locally sha256:aeb8de9fb7af3913d385ec6b274320197d61aa7bc51a6e8bc0deba644da3e405: Pulling from iree-oss/frontends-swiftshader ``` @@ -360,7 +360,7 @@ You can find the hash tag from log and run the below command. It makes sure that you have the enviroment as same as CI bot and requires less local setup. ``` -docker run --interactive --tty --rm --volume=$PWD:/src/iree --workdir=/src/iree gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01 +docker run --interactive --tty --rm --volume=$PWD:/src/iree --workdir=/src/iree gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9 ``` To repro failures in `iree/e2e/`: diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD b/compiler/src/iree/compiler/Codegen/Common/BUILD index 31c60d3f2ce7..4641153ed84f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD @@ -180,16 +180,21 @@ iree_compiler_cc_library( "GPUDistributeSharedMemoryCopy.cpp", "GPUPipelining.cpp", "GPUVectorization.cpp", + "LinalgOpInfo.cpp", "MemrefCopyToLinalg.cpp", "PadDynamicAlloc.cpp", "RemoveTrivialLoops.cpp", "TestPartitionableLoopsInterface.cpp", "TileAndDistributeToWorkgroupsPass.cpp", "TileDispatchUsingInterface.cpp", + "UserConfig.cpp", "VectorReductionToGPU.cpp", - "VectorizeConv.cpp", "WorkGroupSwizzle.cpp", ], + hdrs = [ + "LinalgOpInfo.h", + "UserConfig.h", + ], deps = [ ":CommonPasses", ":TransformDialectInterpreterPass", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 59a6dac83f70..9e85fa2e6ddd 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -145,20 +145,24 @@ iree_cc_library( iree_cc_library( NAME Common + HDRS + "LinalgOpInfo.h" + "UserConfig.h" SRCS "DecomposeLinalgGeneric.cpp" "FoldAffineMinInDistributedLoops.cpp" "GPUDistributeSharedMemoryCopy.cpp" "GPUPipelining.cpp" "GPUVectorization.cpp" + "LinalgOpInfo.cpp" "MemrefCopyToLinalg.cpp" "PadDynamicAlloc.cpp" "RemoveTrivialLoops.cpp" "TestPartitionableLoopsInterface.cpp" "TileAndDistributeToWorkgroupsPass.cpp" "TileDispatchUsingInterface.cpp" + "UserConfig.cpp" "VectorReductionToGPU.cpp" - "VectorizeConv.cpp" "WorkGroupSwizzle.cpp" DEPS ::CommonPasses diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp index dc8d509dc0cc..ccd9ff1b0cd1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp @@ -254,7 +254,11 @@ static Value linearizeIndices(Value sourceValue, ValueRange indices, // First try to get the strides from the MemRef type itself. This applies to // cases where we have static shapes and only the leading dimension is // dynamic. - if (AffineMap linearLayoutMap = getStridedLinearLayoutMap(sourceType)) { + SmallVector strides; + int64_t offset; + if (succeeded(getStridesAndOffset(sourceType, strides, offset))) { + AffineMap linearLayoutMap = + makeStridedLinearLayoutMap(strides, offset, builder.getContext()); // Dynamic strides/offset will create symbols. There should be none for the // static case. if (linearLayoutMap.getNumSymbols() == 0) { diff --git a/compiler/src/iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp index fe95881b3df2..505c0eecb9cc 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp @@ -14,9 +14,11 @@ #include "iree/compiler/Codegen/Transforms/Transforms.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Codegen/Utils/MarkerUtils.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" @@ -24,14 +26,29 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#define DEBUG_TYPE "iree-gpu-distribute-shared-memory-copy" + using mlir::iree_compiler::IREE::LinalgExt::LinalgVectorizationPattern; using mlir::iree_compiler::IREE::LinalgExt::VectorizationPatterns; +/// Prints the given `funcOp` after a leading `step` comment header. +void debugPrint(mlir::func::FuncOp funcOp, const char *step) { + LLVM_DEBUG({ + llvm::dbgs() << "//--- " << step << " ---//\n"; + funcOp.print(llvm::dbgs(), mlir::OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); +} + //====---------------------------------------------------------------------===// // Pass to lower workgroup memory copy to distibuted // transfer_read/transfer_write ops. //====---------------------------------------------------------------------===// +// Markers for intermediate transformations. +static const llvm::StringRef kCopyToDistribute = "copy_to_distribute"; +static const llvm::StringRef kCopyDistributed = "copy_distributed"; + namespace mlir { namespace iree_compiler { @@ -49,14 +66,15 @@ static void populateTilingCopyToWorkgroupMemPatterns( // We tile to 4 as we want each thread to load 4 element in a cyclic // distribution. SmallVector tileSizesVal; - MemRefType lhsMemRefType = cast(operation) - .getOperand(0) + MemRefType dstMemRefType = cast(operation) + .getOutputOperand(0) + ->get() .getType() .cast(); - unsigned rank = lhsMemRefType.getRank(); + unsigned rank = dstMemRefType.getRank(); int copyTileSize = - copyVectorNumBits / lhsMemRefType.getElementTypeBitWidth(); + copyVectorNumBits / dstMemRefType.getElementTypeBitWidth(); for (unsigned i = 0; i < rank - 1; i++) { int64_t t = (rank - i) <= kNumGPUDims ? 1 : 0; tileSizesVal.push_back( @@ -89,26 +107,18 @@ static void populateTilingCopyToWorkgroupMemPatterns( StringAttr::get(patterns.getContext(), getVectorizeMarker()))); } -static void populateVectorizationPatterns(RewritePatternSet &patterns) { - VectorizationPatterns::insert( - patterns, linalg::LinalgVectorizationOptions(), - linalg::LinalgTransformationFilter(StringAttr::get( - patterns.getContext(), getCopyToWorkgroupMemoryMarker()))); -} - -/// Compute a vector size so that the numer of elements is equal to the flat +/// Compute a tile size so that the numer of iteraton is equal to the flat /// workgroup size. -static Optional> getGPUNativeVectorSize( - Operation *op, int64_t flatWorkgroupSize, - const llvm::SmallDenseSet &opsToIgnore) { - auto vt = dyn_cast(op); - if (!vt) return llvm::None; - if (opsToIgnore.count(vt)) return llvm::None; - if (!vt.permutation_map().isMinorIdentity()) return llvm::None; - ArrayRef shape = vt.getVectorType().getShape(); - int targetVectorSize = - copyVectorNumBits / vt.getVectorType().getElementTypeBitWidth(); - SmallVector unroll; +static Optional> getTileToDistributableSize( + linalg::GenericOp copyOp, int64_t flatWorkgroupSize) { + SmallVector shape = copyOp.getStaticLoopRanges(); + unsigned bitWidth = copyOp.getOutputOperand(0) + ->get() + .getType() + .cast() + .getElementTypeBitWidth(); + int targetVectorSize = copyVectorNumBits / bitWidth; + SmallVector unroll; assert(shape.back() % targetVectorSize == 0); int64_t threadsAvailable = flatWorkgroupSize; for (auto &dim : llvm::enumerate(llvm::reverse(shape))) { @@ -123,18 +133,131 @@ static Optional> getGPUNativeVectorSize( assert(threadsAvailable == 1); unroll.resize(shape.size(), 1); std::reverse(unroll.begin(), unroll.end()); - if (unroll == shape) return llvm::None; return unroll; } -static void populateVectorUnrollPatterns( - RewritePatternSet &patterns, int64_t flatWorkgroupSize, - const llvm::SmallDenseSet &opsToIgnore) { - auto getShape = [flatWorkgroupSize, &opsToIgnore](Operation *op) { - return getGPUNativeVectorSize(op, flatWorkgroupSize, opsToIgnore); +/// Pattern to tile copies using serial loops into a shape that can be +/// distributed onto thread. +static void populateTileToUnroll(RewritePatternSet &patterns, + int64_t flatWorkgroupSize) { + linalg::TileSizeComputationFunction wgCopyTileSizeFn = + [flatWorkgroupSize](OpBuilder &builder, Operation *operation) { + SmallVector tileSizesVal; + auto copyOp = dyn_cast(operation); + if (!copyOp) return tileSizesVal; + Optional> staticSize = + getTileToDistributableSize(copyOp, flatWorkgroupSize); + for (int64_t dim : *staticSize) { + tileSizesVal.push_back( + builder.create(operation->getLoc(), dim)); + } + return tileSizesVal; + }; + + auto tilingOptions = linalg::LinalgTilingOptions() + .setLoopType(linalg::LinalgTilingLoopType::Loops) + .setTileSizeComputationFunction(wgCopyTileSizeFn); + patterns.insert( + linalg::GenericOp::getOperationName(), patterns.getContext(), + tilingOptions, + linalg::LinalgTransformationFilter( + {StringAttr::get(patterns.getContext(), + getCopyToWorkgroupMemoryMarker())}, + StringAttr::get(patterns.getContext(), kCopyToDistribute))); +} + +/// Break up the flat id onto the static loop ranges. +SmallVector getIds(OpBuilder &b, Location loc, + ArrayRef parallelLoopRanges, + Value flatThreadId) { + SmallVector infos; + Value id = flatThreadId; + AffineExpr d0 = b.getAffineDimExpr(0); + for (Range r : llvm::reverse(parallelLoopRanges)) { + linalg::ProcInfo info; + auto offset = r.offset.dyn_cast(); + auto stride = r.stride.dyn_cast(); + auto size = r.size.dyn_cast(); + assert(offset && stride && size); + int64_t numThreadsDim = (size.cast().getInt() - + offset.cast().getInt()) / + stride.cast().getInt(); + Value dimId = id; + if (infos.size() != parallelLoopRanges.size() - 1) + dimId = makeComposedAffineApply(b, loc, d0 % numThreadsDim, {dimId}); + info.procId = dimId; + info.nprocs = b.create(loc, numThreadsDim); + info.distributionMethod = + linalg::DistributionMethod::CyclicNumProcsEqNumIters; + infos.push_back(info); + id = makeComposedAffineApply(b, loc, d0.floorDiv(numThreadsDim), {id}); + } + std::reverse(infos.begin(), infos.end()); + return infos; +} + +/// Return the shape of copy op that can be vectorized to a +/// transfer_read/transfer_write of size `targetVectorSize`. +SmallVector getNativeDstShape(linalg::GenericOp copyOp) { + unsigned bitWidth = copyOp.getOutputOperand(0) + ->get() + .getType() + .cast() + .getElementTypeBitWidth(); + int targetVectorSize = copyVectorNumBits / bitWidth; + SmallVector dstShape; + for (int64_t dim : copyOp.getStaticLoopRanges()) { + // Skip tiling of dimension of size 1 to simplify distribution. + dstShape.push_back(dim == 1 ? 0 : 1); + } + dstShape.back() = targetVectorSize; + return dstShape; +} + +/// Distribute linalg copy onto threads based on the flat id. +static void populateTilingAndDistribute(RewritePatternSet &patterns, + Value flatThreadId) { + linalg::TileSizeComputationFunction wgCopyTileSizeFn = + [](OpBuilder &builder, Operation *operation) { + SmallVector tileSizesVal; + auto copyOp = dyn_cast(operation); + if (!copyOp) return tileSizesVal; + SmallVector staticSize = getNativeDstShape(copyOp); + for (int64_t dim : staticSize) { + tileSizesVal.push_back( + builder.create(operation->getLoc(), dim)); + } + return tileSizesVal; + }; + auto getCopyThreadProcInfoFn = [flatThreadId]( + OpBuilder &builder, Location loc, + ArrayRef parallelLoopRanges) { + return getIds(builder, loc, parallelLoopRanges, flatThreadId); }; - vector::populateVectorUnrollPatterns( - patterns, vector::UnrollVectorOptions().setNativeShapeFn(getShape)); + linalg::LinalgLoopDistributionOptions copyInvocationDistributionOptions; + copyInvocationDistributionOptions.procInfo = getCopyThreadProcInfoFn; + + auto tilingOptions = + linalg::LinalgTilingOptions() + .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops) + .setTileSizeComputationFunction(wgCopyTileSizeFn) + .setDistributionOptions(copyInvocationDistributionOptions); + patterns.insert( + linalg::GenericOp::getOperationName(), patterns.getContext(), + tilingOptions, + linalg::LinalgTransformationFilter( + {StringAttr::get(patterns.getContext(), kCopyToDistribute)}, + StringAttr::get(patterns.getContext(), kCopyDistributed))); +} + +static void populateVectorizationPatterns(RewritePatternSet &patterns) { + VectorizationPatterns::insert( + patterns, linalg::LinalgVectorizationOptions(), + linalg::LinalgTransformationFilter( + {StringAttr::get(patterns.getContext(), + getCopyToWorkgroupMemoryMarker()), + StringAttr::get(patterns.getContext(), kCopyDistributed)}, + llvm::None)); } /// Return a flattened Id Value by combining the 3D gpu thread IDs. @@ -158,58 +281,6 @@ static Value createFlatId(func::FuncOp funcOp, return flatThreadId; } -/// Distribute a transfer read operations on the given thread ids. -static void distributeTransferRead( - func::FuncOp funcOp, Value flatThreadId, int64_t flatWorkgroupSize, - const llvm::SmallDenseSet &opsToIgnore) { - funcOp.walk([&](vector::TransferReadOp readOp) { - if (opsToIgnore.count( - cast(readOp.getOperation()))) - return WalkResult::advance(); - OpBuilder b(readOp); - Value id = flatThreadId; - SmallVector multiplier; - auto shape = readOp.getVectorType().getShape(); - int targetVectorSize = - copyVectorNumBits / readOp.getVectorType().getElementTypeBitWidth(); - SmallVector ids; - SmallVector exprs; - AffineExpr d0 = getAffineDimExpr(0, b.getContext()); - int64_t numThreads = flatWorkgroupSize; - for (auto &dim : llvm::enumerate(llvm::reverse(shape))) { - int64_t threads = - dim.index() == 0 ? (dim.value() / targetVectorSize) : dim.value(); - // If we don't need to distribute the dimension, skip it. - if (threads == 1) continue; - exprs.push_back(getAffineDimExpr(shape.size() - dim.index() - 1, - funcOp->getContext())); - multiplier.push_back(threads); - Value dimId = id; - assert(numThreads % threads == 0); - if (numThreads / threads > 1) { - dimId = - makeComposedAffineApply(b, funcOp.getLoc(), d0 % threads, {dimId}); - } - ids.push_back(dimId); - numThreads = numThreads / threads; - id = makeComposedAffineApply(b, funcOp.getLoc(), d0.floorDiv(threads), - {id}); - if (numThreads <= 1) break; - } - std::reverse(ids.begin(), ids.end()); - Optional ops = - vector::distributPointwiseVectorOp( - b, readOp, ids, multiplier, - AffineMap::get(shape.size(), 0, exprs, funcOp.getContext())); - if (ops.has_value()) { - SmallPtrSet extractOp({ops->extract, ops->insert}); - readOp.getResult().replaceAllUsesExcept(ops->insert.getResult(), - extractOp); - } - return WalkResult::advance(); - }); -} - /// Hoist allocations to the top of the loop if they have no dependencies. static void hoistAlloc(func::FuncOp funcOp) { SmallVector allocs; @@ -243,13 +314,38 @@ static void removeRedundantBarriers(func::FuncOp funcOp) { }); } +/// Return the number of iteration if it is static, otherwise returns 0. +static int64_t numIteration(scf::ForOp forOp) { + auto lbCstOp = forOp.getLowerBound().getDefiningOp(); + auto ubCstOp = forOp.getUpperBound().getDefiningOp(); + auto stepCstOp = forOp.getStep().getDefiningOp(); + if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 || + ubCstOp.value() < 0 || stepCstOp.value() < 0) + return 0; + int64_t tripCount = + mlir::ceilDiv(ubCstOp.value() - lbCstOp.value(), stepCstOp.value()); + return tripCount; +} + +/// Fully unroll all the static loops unless they are part of the ignore map. +static void UnrollSharedMemoryLoops( + func::FuncOp funcOp, const llvm::SmallDenseSet &loopsToIgnore) { + SmallVector forOpsToUnroll; + funcOp.walk([&](scf::ForOp forOp) { + if (!loopsToIgnore.count(forOp)) forOpsToUnroll.push_back(forOp); + }); + for (scf::ForOp forOp : llvm::reverse(forOpsToUnroll)) { + (void)loopUnrollByFactor(forOp, numIteration(forOp)); + } +} + namespace { class GPUDistributeSharedMemoryCopyPass : public GPUDistributeSharedMemoryCopyBase< GPUDistributeSharedMemoryCopyPass> { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() override { func::FuncOp funcOp = getOperation(); @@ -273,49 +369,54 @@ class GPUDistributeSharedMemoryCopyPass workgroupSize[0] * workgroupSize[1] * workgroupSize[2]; bool isAligned = llvm::all_of( copiesToWorkgroupMem, [flatWorkgroupSize](linalg::GenericOp copyOp) { - MemRefType lhsMemRefType = - copyOp.getOperand(0).getType().cast(); - auto shape = lhsMemRefType.getShape(); + MemRefType dstMemRefType = + copyOp.getOutputOperand(0)->get().getType().cast(); + auto shape = dstMemRefType.getShape(); int targetVectorSize = - copyVectorNumBits / lhsMemRefType.getElementTypeBitWidth(); + copyVectorNumBits / dstMemRefType.getElementTypeBitWidth(); return canPerformVectorAccessUsingAllThreads(shape, flatWorkgroupSize, targetVectorSize); }); + debugPrint(funcOp, "After initial IR cleanup"); + if (isAligned) { - // Ignore all the exisiting vector transfer ops. - llvm::SmallDenseSet opsToIgnore; - funcOp.walk([&](VectorTransferOpInterface transferOp) { - opsToIgnore.insert(transferOp); - }); - // Step 1. Vectorize the shared memory copy. - RewritePatternSet vectorizationPatterns(context); - populateVectorizationPatterns(vectorizationPatterns); + // Ignore all the exisiting loop + llvm::SmallDenseSet loopsToIgnore; + funcOp.walk([&](scf::ForOp loop) { loopsToIgnore.insert(loop); }); + + // Step 1. tile copies to get to a shape that can be distributed to + // 128bits per lane copies. + RewritePatternSet serialTilingPatterns(context); + populateTileToUnroll(serialTilingPatterns, flatWorkgroupSize); if (failed(applyPatternsAndFoldGreedily( - funcOp, std::move(vectorizationPatterns)))) { + funcOp, std::move(serialTilingPatterns)))) { return signalPassFailure(); } + debugPrint(funcOp, "After step 1: tiling"); - // Step 2. Unroll transfer_read/transfer_write to a vector with the number - // of element equal to `targetVectorSize * targetVectorSize`. The. - // transfer op generated can. then be distributed to a single op of target - // size. - RewritePatternSet vectorUnrollPatterns(context); - populateVectorUnrollPatterns(vectorUnrollPatterns, flatWorkgroupSize, - opsToIgnore); + // Calculate a flat id that will then be broken down during distribution. + Value flatId = createFlatId(funcOp, workgroupSize); + // Step 2. Distribute the linalg op onto threads. + RewritePatternSet tileAndDistributePatterns(context); + populateTilingAndDistribute(tileAndDistributePatterns, flatId); if (failed(applyPatternsAndFoldGreedily( - funcOp, std::move(vectorUnrollPatterns)))) { + funcOp, std::move(tileAndDistributePatterns)))) { return signalPassFailure(); } - // Step 3. Distribute the transfer ops onto the flat ids. - Value flatId = createFlatId(funcOp, workgroupSize); - distributeTransferRead(funcOp, flatId, flatWorkgroupSize, opsToIgnore); - // Propagate vector distribution to the chain of ops. - RewritePatternSet distributePatterns(context); - vector::populatePropagateVectorDistributionPatterns(distributePatterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, - std::move(distributePatterns)))) { + debugPrint(funcOp, "After step 2: thread distribution"); + + // Step 3. Vectorize the distributed copies. + RewritePatternSet vectorizationPatterns(context); + populateVectorizationPatterns(vectorizationPatterns); + if (failed(applyPatternsAndFoldGreedily( + funcOp, std::move(vectorizationPatterns)))) { return signalPassFailure(); } + debugPrint(funcOp, "After step 3: vectorization"); + + // Step4. Finally unroll all the loop created + UnrollSharedMemoryLoops(funcOp, loopsToIgnore); + debugPrint(funcOp, "After step 4: unrolling"); } else { // Fall back to basic tiling for cases where workgroup memory size is not // well aligned on the number of threads. @@ -328,6 +429,8 @@ class GPUDistributeSharedMemoryCopyPass funcOp, std::move(threadLevelTilingPatterns)))) { return signalPassFailure(); } + debugPrint(funcOp, "After tiling for unaligned case"); + // Apply canonicalization patterns. RewritePatternSet threadTilingCanonicalizationPatterns = linalg::getLinalgTilingCanonicalizationPatterns(context); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp b/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp index 109c1d008b41..f2c9ae51a89b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp @@ -7,11 +7,13 @@ #include "iree/compiler/Codegen/PassDetail.h" #include "iree/compiler/Codegen/Passes.h" #include "iree/compiler/Codegen/Utils/Utils.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/SideEffectUtils.h" //====---------------------------------------------------------------------===// // Pass to pipeline copy to shared memory for matmul op. @@ -23,6 +25,63 @@ namespace iree_compiler { static const StringLiteral kPipeliningLoopMarker = "__pipelining_K_loop__"; static const StringLiteral kPipeliningGlobalLoad = "__pipelining_global_load__"; +// Returns a new predicated operation to support unpeeled epilogue. Unpeeled +// epilogue needs to handle the last iterations within the mainloop which +// requires predicating operations, for e.g., OOB global memory access. This +// helper function predicates operations (where predication is avialable), +// checks if unpredicated operations are side-effect free and acceptable to +// execute speculatively. +static Operation* replaceOpWithPredicatedOp(Operation* op, Value pred, + PatternRewriter& rewriter) { + // Predication is only supported for AsyncCopyOp. Thus, for operations which + // are *not* AsyncCopyOp additional checks are requrired in order to be issued + // speculatively. + if (!isa(op)) { + // Return/execute the op if it is a side effect free. + if (mlir::isSideEffectFree(op)) return op; + // Return/execute the op if it is barrier, commit group, or ldmatrix op. + if (isa( + op)) + return op; + // Return/execute the op if it is a shared memory load. + if (auto loadOp = dyn_cast(op)) { + unsigned loadAddrSpace = + loadOp.getBase().getType().cast().getMemorySpaceAsInt(); + if (loadAddrSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) + return op; + } + // If we are here that means the operation does not have predication support + // and cannot be speculatively executed. Thus, unpeeled epilogue is not + // supported. + assert(false && + "Unpeeled epilogue not supported with a side-effect instruction " + "with no predication."); + } + + // Replace mainloop AsyncCopy with AsyncCopy(zfill) inline asm. + auto asyncCopyOp = dyn_cast(op); + auto loc = asyncCopyOp->getLoc(); + + // Create srcElement Value based on the pred. + // The next few lins generate the below code: + // srcElement = (pred) ? dstElements : 0; + Value dstElements = + rewriter.create(loc, asyncCopyOp.getDstElementsAttr()); + Value c0Index = rewriter.create(loc, 0); + auto srcElements = + rewriter.create(loc, pred, dstElements, c0Index); + auto asyncCopyZfillOp = rewriter.create( + loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()), + asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(), + asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements, + UnitAttr()); + + rewriter.eraseOp(asyncCopyOp); + + // Return the newly create predicated AsyncCopyZfillOp. + return asyncCopyZfillOp; +} + /// Helper to recursively add operation dependencies within `block` to `dep` /// set. static void addDepOps(llvm::SmallDenseSet& dep, Operation* op, @@ -84,7 +143,9 @@ static void setAsyncAnnotations(Operation* op, namespace { struct GPUPipeliningPass : public GPUPipeliningBase { - GPUPipeliningPass(unsigned depth) : depth(depth) {} + GPUPipeliningPass(bool epiloguePeeling, unsigned depth) : depth(depth) { + this->epiloguePeeling = epiloguePeeling; + } void runOnOperation() override { auto funcOp = getOperation(); MLIRContext* context = &getContext(); @@ -142,6 +203,17 @@ struct GPUPipeliningPass : public GPUPipeliningBase { }; options.getScheduleFn = getSchedule; options.annotateFn = setAnnotation; + + // Use un-peeled epilogue (i.e. epiloguePeeling=flase) only when predication + // is avialable a.k.a. AsyncCopyOp. + if (!epiloguePeeling) { + options.peelEpilogue = false; + options.predicateFn = [](Operation* op, Value pred, + PatternRewriter& rewriter) { + return replaceOpWithPredicatedOp(op, pred, rewriter); + }; + } + RewritePatternSet pipeliningPatterns(context); scf::populateSCFLoopPipeliningPatterns(pipeliningPatterns, options); if (failed(applyPatternsAndFoldGreedily(funcOp, @@ -155,9 +227,14 @@ struct GPUPipeliningPass : public GPUPipeliningBase { }; } // namespace +/// Pass options +/// epiloguePeeling - try enable/disable epilogue peeling. +/// true : Peel epilogue (no additional checks required) +/// false : Try and use unpeeled epilogue (check if predication is supported is +/// avialable) std::unique_ptr> createGPUPipeliningPass( - unsigned depth) { - return std::make_unique(depth); + bool epiloguePeeling, unsigned depth) { + return std::make_unique(epiloguePeeling, depth); } } // namespace iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPUVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GPUVectorization.cpp index b633279df024..4f8d29cc1110 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPUVectorization.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPUVectorization.cpp @@ -38,8 +38,8 @@ static void populateVectorizationPatterns(RewritePatternSet &patterns) { StringAttr::get(ctx, getVectorizeMarker())}, llvm::None); f.setMatchByDefault(); - VectorizationPatterns::insert(patterns, - opt, f); + VectorizationPatterns::insert(patterns, opt, f); patterns.add(ctx); patterns.add( ctx, f.addOpFilter(), opt); @@ -57,6 +57,18 @@ struct GPUVectorizationPass void runOnOperation() override { auto funcOp = getOperation(); MLIRContext *context = &getContext(); + + // Pre-process convolution ops. + RewritePatternSet decompositionPattern(funcOp.getContext()); + linalg::LinalgTransformationFilter f( + {StringAttr::get(context, getWorkgroupKTiledMarker())}, + StringAttr::get(context, getVectorizeMarker())); + f.setMatchByDefault(); + linalg::populateDecomposeConvolutionPatterns(decompositionPattern, f); + if (failed(applyPatternsAndFoldGreedily(funcOp, + std::move(decompositionPattern)))) + return signalPassFailure(); + RewritePatternSet vectorizationPatterns(context); populateVectorizationPatterns(vectorizationPatterns); if (generateContract) { diff --git a/compiler/src/iree/compiler/Codegen/Common/LinalgOpInfo.cpp b/compiler/src/iree/compiler/Codegen/Common/LinalgOpInfo.cpp new file mode 100644 index 000000000000..a3e6222d5fa0 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/LinalgOpInfo.cpp @@ -0,0 +1,129 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/LinalgOpInfo.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" + +using namespace mlir::linalg; + +namespace mlir { +namespace iree_compiler { + +/// Returns true if `map` is a tranpose. A transpose map is a projected +/// permutation with or without zeros in results where there exist at least two +/// dimensions di and dj such that di < dj and result_pos(di) > result_pos(dj). +/// Examples: +/// +/// (d0, d1, d2) -> (d0, d2) is not a transpose map. +/// (d0, d1, d2) -> (d2, d0) is a transpose map. +/// (d0, d1, d2) -> (d1, d2) is not a transpose map. +/// (d0, d1, d2) -> (d0, 0, d1) is not a transpose map. +/// (d0, d1, d2) -> (d2, 0, d1) is a transpose map. +/// (d0, d1, d2) -> (d1, 0) is not a transpose map. +/// +// TODO(dcaballe): Discern between "memcopy" transposes and "shuffle" +// transposes. +// TODO(dcaballe): Move to Affine utils? +static bool isTransposeMap(AffineMap map) { + // A transpose map must be a projected permutation with or without + // broadcasted/reduction dimensions. + if (!map.isProjectedPermutation(/*allowZeroInResults=*/true)) { + return false; + } + + // Check that the projected permutation has at least two result dimensions + // that are actually transposed by comparing its input position. + unsigned prevDim = 0; + for (AffineExpr expr : map.getResults()) { + if (auto constExpr = expr.dyn_cast()) { + // Constant zero expression, guaranteed by 'allowZeroInResults' above. + continue; + } else if (auto dimExpr = expr.dyn_cast()) { + if (prevDim > dimExpr.getPosition()) { + return true; + } + prevDim = dimExpr.getPosition(); + } else { + return false; + } + } + + return false; +} + +/// The default filter passes all op operands. +static bool defaultTransposeMapFilter(AffineMap map) { return true; } + +LinalgOpInfo::LinalgOpInfo(linalg::LinalgOp linalgOp) + : transposeMapFilter(defaultTransposeMapFilter) { + computeInfo(linalgOp); +} +LinalgOpInfo::LinalgOpInfo(linalg::LinalgOp linalgOp, + TransposeMapFilter transposeMapFilter) + : transposeMapFilter(transposeMapFilter) { + computeInfo(linalgOp); +} + +/// Returns true if a LinalgOp implements a transpose. +// TODO(dcaballe): +// * Consider transpose + reductions. +// * Consider input and output transposes. +static SmallVector computeTransposeInfo( + LinalgOp linalgOp, TransposeMapFilter transposeMapFilter) { + SmallVector transposeOperands; + + // Reductions are not supported. + if (linalgOp.getNumReductionLoops() > 0) { + return transposeOperands; + } + + // Inverse map to use transfer op permutation logic. + AffineMap outputInversedMap = inversePermutation( + linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0))); + + SmallVector inputInversedMaps; + for (OpOperand *linalgOperand : linalgOp.getInputOperands()) { + auto map = linalgOp.getTiedIndexingMap(linalgOperand); + if (!map.isProjectedPermutation(/*allowZeroInResults=*/true)) { + return transposeOperands; + } + AffineMap inverseMap = inverseAndBroadcastProjectedPermutation(map); + if (isTransposeMap(inverseMap) && transposeMapFilter(inverseMap)) { + transposeOperands.push_back(linalgOperand); + } + } + + // Multiple outputs are not supported yet. + if (linalgOp.getNumOutputs() != 1) { + return transposeOperands; + } + + if (isTransposeMap(outputInversedMap) && + transposeMapFilter(outputInversedMap)) { + transposeOperands.push_back(linalgOp.getOutputOperand(0)); + } + + return transposeOperands; +} + +static bool computeReductionInfo(LinalgOp linalgOp) { + return linalgOp.getNumReductionLoops() > 1; +} + +static bool computeDynamicInfo(LinalgOp linalgOp) { + return linalgOp.hasDynamicShape(); +} + +void LinalgOpInfo::computeInfo(LinalgOp linalgOp) { + transposeOperands = computeTransposeInfo(linalgOp, transposeMapFilter); + reductionTrait = computeReductionInfo(linalgOp); + dynamicTrait = computeDynamicInfo(linalgOp); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/LinalgOpInfo.h b/compiler/src/iree/compiler/Codegen/Common/LinalgOpInfo.h new file mode 100644 index 000000000000..34633538877a --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/LinalgOpInfo.h @@ -0,0 +1,52 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_COMMON_LINALGOPINFO_H_ +#define IREE_COMPILER_CODEGEN_COMMON_LINALGOPINFO_H_ +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" + +namespace mlir { + +namespace linalg { +class LinalgOp; +} + +namespace iree_compiler { + +/// Returns true if a map represents the appropriate transpose. Pass this into +/// the LinalgOpInfo for additional transpose granularity. +using TransposeMapFilter = std::function; + +class LinalgOpInfo { + public: + LinalgOpInfo(linalg::LinalgOp linalgOp); + LinalgOpInfo(linalg::LinalgOp linalgOp, + TransposeMapFilter transposeMapFilter); + + bool isTranspose() const { return !transposeOperands.empty(); } + bool isReduction() const { return reductionTrait; } + bool isDynamic() const { return dynamicTrait; } + + ArrayRef getTransposeOperands() const { + return transposeOperands; + } + + private: + void computeInfo(linalg::LinalgOp); + + TransposeMapFilter transposeMapFilter; + bool transposeTrait; + bool reductionTrait; + bool dynamicTrait; + SmallVector transposeOperands; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_CODEGEN_COMMON_LINALGOPINFO_H_ diff --git a/compiler/src/iree/compiler/Codegen/Common/RemoveTrivialLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/RemoveTrivialLoops.cpp index 707fbdeb3783..708f87e80bc7 100644 --- a/compiler/src/iree/compiler/Codegen/Common/RemoveTrivialLoops.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/RemoveTrivialLoops.cpp @@ -11,6 +11,7 @@ #include "iree/compiler/Codegen/Utils/Utils.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" @@ -97,48 +98,27 @@ static bool isWorkgroupLoop(const LoopTilingAndDistributionInfo &info) { }); } -/// Infer the number of workgroups by looking at the tiled loop and the number -/// of element per workgroups. +/// Infer the number of workgroups from exportOp. static SmallVector getNumWorkgroup( func::FuncOp funcOp, IREE::HAL::ExecutableExportOp exportOp) { - auto allLoops = getTiledAndDistributedLoopInfo(funcOp); - auto wgLoops = - llvm::to_vector<3>(llvm::make_filter_range(allLoops, isWorkgroupLoop)); - SmallVector workloadSize(wgLoops.size()); - for (LoopTilingAndDistributionInfo &tileInfo : wgLoops) { - if (tileInfo.processorDistributionDim >= workloadSize.size()) return {}; - if (!tileInfo.untiledLowerBound.is() || - !tileInfo.untiledUpperBound.is() || - !tileInfo.untiledStep.is()) { - continue; + SmallVector result; + + Block *body = exportOp.getWorkgroupCountBody(); + if (!body) return result; + + auto returnOp = cast(body->getTerminator()); + assert(returnOp.getNumOperands() == 3); + + for (unsigned i = 0; i < 3; ++i) { + Operation *defOp = returnOp.getOperand(i).getDefiningOp(); + if (auto indexOp = dyn_cast_or_null(defOp)) { + result.push_back(indexOp.value()); + } else { + return SmallVector(); } - int64_t lb = tileInfo.untiledLowerBound.get() - .cast() - .getInt(); - int64_t ub = tileInfo.untiledUpperBound.get() - .cast() - .getInt(); - int64_t step = - tileInfo.untiledStep.get().cast().getInt(); - if (step == 0) return SmallVector(); - workloadSize[tileInfo.processorDistributionDim] = (ub - lb) / step; } - auto translationInfo = getTranslationInfo(exportOp); - if (!translationInfo) return SmallVector(); - SmallVector workloadPerWorkgroup = - translationInfo.getWorkloadPerWorkgroupVals(); - if (workloadSize.size() != workloadPerWorkgroup.size()) { - return SmallVector(); - } - SmallVector numWorkgroups; - for (auto pair : llvm::zip(workloadSize, workloadPerWorkgroup)) { - auto workload = std::get<0>(pair); - auto size = std::get<1>(pair); - numWorkgroups.push_back(llvm::divideCeil(workload, size)); - } - numWorkgroups.resize(kNumMaxParallelDims, 1); - return numWorkgroups; + return result; } static LogicalResult removeOneTripTiledLoops(func::FuncOp funcOp, diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp index dcd8ee384d05..cb57f497e8c9 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp @@ -87,8 +87,7 @@ static LogicalResult lowerToUnitWorkgroupCount( static LogicalResult lowerDispatchWorkgroupCountFromDagRootOp( IREE::Flow::DispatchWorkgroupCountFromDagRootOp workgroupCountOp, ArrayRef computeOps, SmallVectorImpl &tileSizes, - SmallVector &interchange, - SmallVectorImpl &workloadPerWorkgroup) { + SmallVector &interchange) { auto workloadValues = workgroupCountOp.operands(); // Find the lowering configuration of the root operation. @@ -148,25 +147,46 @@ static LogicalResult lowerDispatchWorkgroupCountFromDagRootOp( llvm::DenseSet partitionableLoopsSet; partitionableLoopsSet.insert(partitionableLoops.begin(), partitionableLoops.end()); + for (auto workload : llvm::enumerate(workloadValues)) { if (!partitionableLoopsSet.count(workload.index())) { tileSizes[workload.index()] = 0; } - if (tileSizes[workload.index()] == 0) { + int64_t tileSize = tileSizes[workload.index()]; + + if (tileSize == 0) { numTiles.push_back(one); continue; } - if (tileSizes[workload.index()] == 1) { - numTiles.push_back(workload.value()); - continue; + + // When the loop range is known to be static, let's directly use it. + int64_t loopRange = ShapedType::kDynamicSize; + + if (auto linalgOp = dyn_cast(*rootOp)) { + loopRange = linalgOp.getStaticLoopRanges()[workload.index()]; + } + + if (loopRange != ShapedType::kDynamicSize) { + if (tileSize == 1) { + Value workload = builder.create(loc, loopRange); + numTiles.push_back(workload); + continue; + } + int64_t nTileI64 = (loopRange + tileSize - 1) / tileSize; + Value nTiles = builder.create(loc, nTileI64); + numTiles.push_back(nTiles); + } else { + if (tileSize == 1) { + numTiles.push_back(workload.value()); + continue; + } + AffineExpr s0; + bindSymbols(workgroupCountOp.getContext(), s0); + AffineMap numTilesMap = AffineMap::get(0, 1, s0.ceilDiv(tileSize)); + Value nTiles = + builder.create(loc, numTilesMap, workload.value()); + numTiles.push_back(nTiles); } - AffineExpr s0; - bindSymbols(workgroupCountOp.getContext(), s0); - AffineMap numTilesMap = - AffineMap::get(0, 1, s0.ceilDiv(tileSizes[workload.index()])); - Value nTiles = - builder.create(loc, numTilesMap, workload.value()); - numTiles.push_back(nTiles); } // If there is interchange, first apply interchange on the number of tiles. @@ -187,7 +207,6 @@ static LogicalResult lowerDispatchWorkgroupCountFromDagRootOp( // If the loop isnt tiled, skip it. if (tileSizes[partitionedLoop] == 0) continue; numWorkgroups.push_back(numTiles[partitionedLoop]); - workloadPerWorkgroup.push_back(tileSizes[partitionedLoop]); } numWorkgroups.resize(kNumMaxParallelDims, one); workgroupCountOp->replaceAllUsesWith(numWorkgroups); @@ -195,28 +214,6 @@ static LogicalResult lowerDispatchWorkgroupCountFromDagRootOp( return success(); } -/// Update the workload_per_wg value on the TranslationInfoAttr. -// TODO(ravishankarm): The workload_per_wg field should be deprecated. This -// is just transition before all dependencies on it can be removed. -static LogicalResult updateTranslationInfoAttr( - IREE::HAL::ExecutableExportOp exportOp, - ArrayRef workloadPerWorkgroup) { - IREE::Codegen::DispatchLoweringPassPipeline passPipeline = - IREE::Codegen::DispatchLoweringPassPipeline::CPUDefault; - if (auto translationInfo = getTranslationInfo(exportOp)) { - // Expect the `workload_per_wg` to be empty. - if (!translationInfo.getWorkloadPerWorkgroupVals().empty()) { - return exportOp.emitOpError( - "expected workload_per_wg to be empty at this stage"); - } - passPipeline = translationInfo.getDispatchLoweringPassPipeline(); - } - auto newTranslationInfoAttr = IREE::Codegen::TranslationInfoAttr::get( - exportOp.getContext(), passPipeline, workloadPerWorkgroup); - setTranslationInfo(exportOp, newTranslationInfoAttr); - return success(); -} - //===---------------------------------------------------------------------===// // Patterns and methods for tile and distribute of Linalg ops to workgroups. //===---------------------------------------------------------------------===// @@ -280,18 +277,13 @@ void TileAndDistributeToWorkgroupsPass::runOnOperation() { IREE::Flow::DispatchWorkgroupCountFromDagRootOp defaultWorkgroupCountOp = *(ops.begin()); - SmallVector tileSizes, interchange, workloadPerWorkgroup; + SmallVector tileSizes, interchange; if (failed(lowerDispatchWorkgroupCountFromDagRootOp( - defaultWorkgroupCountOp, computeOps, tileSizes, interchange, - workloadPerWorkgroup))) { + defaultWorkgroupCountOp, computeOps, tileSizes, interchange))) { defaultWorkgroupCountOp.emitOpError( "failed to lower default number of workgroups"); return signalPassFailure(); } - if (failed(updateTranslationInfoAttr(exportOp, workloadPerWorkgroup))) { - exportOp.emitOpError("failed to update translation info"); - return signalPassFailure(); - } // If there are no compute ops, nothing more to do. if (computeOps.empty()) continue; @@ -335,6 +327,13 @@ void TileAndDistributeToWorkgroupsPass::runOnOperation() { } } + // If tiling didn't happen because there are no tile sizes we are + // potentially left with a marker that will confuse the following passes so + // we remove the intermediate markers. + funcOp->walk([&](Operation *op) { + op->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker); + }); + LLVM_DEBUG({ llvm::dbgs() << "--- After Tile + Distribute ---\n"; funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD index 89f1f97ade46..0d99cfa7f46f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD @@ -62,13 +62,19 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen:PassHeaders", "//compiler/src/iree/compiler/Codegen/Common:CommonPasses", "//compiler/src/iree/compiler/Codegen/Interfaces:BufferizationInterfaces", + "//compiler/src/iree/compiler/Codegen/Utils", + "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithmeticDialect", + "@llvm-project//mlir:ArithmeticUtils", "@llvm-project//mlir:BufferizationDialect", "@llvm-project//mlir:BufferizationTransforms", + "@llvm-project//mlir:LinalgTransformOps", + "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:Pass", "@llvm-project//mlir:TransformDialect", ], diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt index 722ebbebcd0d..fe0912968c1f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt @@ -34,14 +34,20 @@ iree_cc_library( IREEDialectsTransforms IREELinalgTransformDialect LLVMSupport + MLIRAffineDialect MLIRArithmeticDialect + MLIRArithmeticUtils MLIRBufferizationDialect MLIRBufferizationTransforms + MLIRLinalgTransformOps + MLIRLinalgTransforms MLIRPass MLIRTransformDialect iree::compiler::Codegen::Common::CommonPasses iree::compiler::Codegen::Interfaces::BufferizationInterfaces iree::compiler::Codegen::PassHeaders + iree::compiler::Codegen::Utils + iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index 9f47507cb717..aeee5470f1ca 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -14,8 +14,14 @@ #include "iree/compiler/Codegen/Common/Transforms.h" #include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h" #include "iree/compiler/Codegen/Passes.h" +#include "iree/compiler/Codegen/Utils/Utils.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "llvm/ADT/StringSet.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Pass/PassManager.h" @@ -304,35 +310,46 @@ LogicalResult rewriteForeachThreadToWorkgroup( "scf.foreach_thread with rank > 3 does not lower to workgroup"); // Step 0. Outline the compute workload region and set up the workload - // operands. - auto maybeWorkgroupCounts = getNumThreads(rewriter, foreachThreadOp); - if (failed(maybeWorkgroupCounts) || - llvm::any_of(*maybeWorkgroupCounts, [](OpFoldResult ofr) { - return !getConstantIntValue(ofr).has_value(); - })) - return foreachThreadOp->emitError( - "unsupported dynamic workgroup_count atm --- need to slice out " - "workgroup_count computation into ExecutableExport::workgroup_count. " - "This region may require arbitrary computations and cannot magically " - "match what the `stream.cmd.dispatch` has already imposed on us at a " - "distance. For now we must specify the number of values properly when " - "applying the topLevel tile_to_foreach_thread_op"); - - SmallVector workgroupCounts; - for (OpFoldResult ofr : *maybeWorkgroupCounts) - workgroupCounts.push_back(getConstantIntValue(ofr).value()); - if (failed(populateWorkgroupCountComputingRegion(rewriter, foreachThreadOp, - exportOp))) - return foreachThreadOp->emitOpError( - "failed to populate workload region for dispatchOp: ") - << exportOp; + // operands, if this has not been done already. + // Using `transform.iree.tile_to_foreach_thread_and_workgroup_count_region` is + // the preferred way to set up tiling and workgroup_count region **at the same + // time**. + // + // The block of code below will be retired once there is enough confidence we + // can do everything without it. This includes in particular providing custom + // fusion heuristics at the flow level: at this time, the only way to fully + // control fusion of more advanced cases is to use the transform dialect at + // the flow level and explicitly match the ops we want to fuse. + // Once fusion is customizable enough in perpetuity, we can retire this. + if (exportOp.getWorkgroupCount().empty()) { + auto maybeWorkgroupCounts = getNumThreads(rewriter, foreachThreadOp); + if (failed(maybeWorkgroupCounts) || + llvm::any_of(*maybeWorkgroupCounts, [](OpFoldResult ofr) { + return !getConstantIntValue(ofr).has_value(); + })) + return foreachThreadOp->emitError( + "unsupported dynamic workgroup_count atm --- need to slice out " + "workgroup_count computation into ExecutableExport::workgroup_count. " + "This region may require arbitrary computations and cannot magically " + "match what the `stream.cmd.dispatch` has already imposed on us at a " + "distance. For now we must specify the number of values properly " + "when applying the topLevel tile_to_foreach_thread_op"); + SmallVector workgroupCounts; + for (OpFoldResult ofr : *maybeWorkgroupCounts) + workgroupCounts.push_back(getConstantIntValue(ofr).value()); + if (failed(populateWorkgroupCountComputingRegion(rewriter, foreachThreadOp, + exportOp))) { + return foreachThreadOp->emitOpError( + "failed to populate workload region for dispatchOp: ") + << exportOp; + } + } // Step 1. Create the workgroup id and count ops. Location loc = foreachThreadOp.getLoc(); BlockAndValueMapping bvm; SmallVector workgroupIdOps, workgroupCountOps; - for (int64_t rank : - llvm::seq(0, foreachThreadOp.getThreadIndices().size())) { + for (int64_t rank : llvm::seq(0, 3)) { workgroupIdOps.push_back( rewriter.create(loc, rank)); workgroupCountOps.push_back( @@ -357,9 +374,15 @@ LogicalResult rewriteForeachThreadToWorkgroup( // Step 4. RAUW thread indices to thread ops. SmallVector threadIndices = *getThreadIndices(rewriter, foreachThreadOp); + assert(workgroupIdOps.size() == 3 && "3 workgroup id ops are required"); + assert(threadIndices.size() == 3 && "3 thread id dimensions are required"); for (auto it : llvm::zip(threadIndices, workgroupIdOps)) { - if (!std::get<0>(it)) continue; - std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + Value val = std::get<0>(it); + if (!val) continue; + for (Operation *user : llvm::make_early_inc_range(val.getUsers())) { + rewriter.updateRootInPlace( + user, [&]() { user->replaceUsesOfWith(val, std::get<1>(it)); }); + } } // Step 5. Barriers omitted given unique topLevel scf::ForeachThreadOp. @@ -394,10 +417,6 @@ transform_dialect::ForeachThreadToWorkgroupOp::applyToOne( state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found"); return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); } - if (!exportOp.getWorkgroupCount().empty()) - return emitDefaultSilenceableFailure(target) - << "export op must have an empty workgroup count region that " - "the transform fills --- the transform is not applied"; scf::ForeachThreadOp topLevelForeachThreadOp; auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) { @@ -424,5 +443,144 @@ transform_dialect::ForeachThreadToWorkgroupOp::applyToOne( return DiagnosedSilenceableFailure(success()); } +void transform_dialect::ForeachThreadToWorkgroupOp::getEffects( + SmallVectorImpl &effects) { + transform::consumesHandle(getTarget(), effects); + transform::producesHandle(getTransformed(), effects); +} + +//===---------------------------------------------------------------------===// +// TileToForeachThreadAndWorkgroupCountRegion +//===---------------------------------------------------------------------===// + +/// Lower the ops within the workgroup count region of `exportOp` that +/// represents the workgroup count calculation, to the actual +/// computation that returns the number of workgroups. For now +/// this lowers the `flow.dispatch.workgroup_count_from_dag_root` op +/// to `ceilDiv(workload, tileSizes)`. +static LogicalResult lowerWorkgroupCountComputingRegion( + RewriterBase &rewriter, HAL::ExecutableExportOp exportOp, + ArrayRef tileSizes) { + Region &r = exportOp.getWorkgroupCount(); + if (!r.hasOneBlock()) { + return rewriter.notifyMatchFailure(exportOp, + "expected export op to have a workgroup " + "count region with a single block"); + } + auto workgroupCountOps = + r.front().getOps(); + if (!llvm::hasSingleElement(workgroupCountOps)) { + return rewriter.notifyMatchFailure( + exportOp, + "expected region to have a single " + "flow.dispatch.workgroup_count_from_dag_root op"); + } + auto workgroupCountOp = *workgroupCountOps.begin(); + auto workload = workgroupCountOp.getOperands(); + + if (tileSizes.size() > workload.size()) { + return rewriter.notifyMatchFailure( + exportOp, + "number of tile sizes overflow the dimension from the workload"); + } + + SmallVector workgroupCount; + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(workgroupCountOp); + Location loc = workgroupCountOp.getLoc(); + for (auto tileSize : llvm::enumerate(tileSizes)) { + if (isConstantIntValue(tileSize.value(), 0)) { + workgroupCount.push_back(workload[tileSize.index()]); + continue; + } + AffineExpr s0, s1; + bindSymbols(rewriter.getContext(), s0, s1); + auto m = AffineMap::get(0, 2, s0.ceilDiv(s1)); + OpFoldResult count = makeComposedFoldedAffineApply( + rewriter, loc, m, + ArrayRef{workload[tileSize.index()], tileSize.value()}); + workgroupCount.push_back(count); + } + workgroupCount = llvm::to_vector(llvm::reverse(workgroupCount)); + workgroupCount.resize(3, rewriter.getIndexAttr(1)); + rewriter.replaceOp(workgroupCountOp, getValueOrCreateConstantIndexOp( + rewriter, loc, workgroupCount)); + return success(); +} + +SmallVector transform_dialect:: + TileToForeachThreadAndWorkgroupCountRegion::getMixedNumThreads() { + return getMixedSizes(getStaticNumThreads(), getNumThreads()); +} + +SmallVector transform_dialect:: + TileToForeachThreadAndWorkgroupCountRegion::getMixedTileSizes() { + return getMixedSizes(getStaticTileSizes(), getTileSizes()); +} + +LogicalResult +transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::verify() { + if (getMixedNumThreads().empty() == getMixedTileSizes().empty()) + return emitOpError("either num_threads or tile_sizes must be specified"); + return success(); +} + +void transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::getEffects( + SmallVectorImpl &effects) { + transform::consumesHandle(getTarget(), effects); + transform::onlyReadsHandle(getTileSizes(), effects); + transform::onlyReadsHandle(getNumThreads(), effects); + transform::producesHandle(getResults(), effects); +} + +DiagnosedSilenceableFailure +transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::apply( + transform::TransformResults &transformResults, + transform::TransformState &state) { + ArrayRef targetOps = state.getPayloadOps(getTarget()); + assert(targetOps.size() == 1 && "expected single target op in payload"); + auto funcOp = targetOps.front()->getParentOfType(); + FailureOr exportOp = getEntryPoint(funcOp); + if (failed(exportOp)) { + state.getTopLevel()->emitOpError("couldn't find export op for func"); + return DiagnosedSilenceableFailure(reportUnknownTransformError(funcOp)); + } + + SmallVector mixedTileSizes = getMixedTileSizes(); + if (mixedTileSizes.empty()) { + exportOp.value()->emitOpError("require tile sizes to be specified"); + return DiagnosedSilenceableFailure( + reportUnknownTransformError(exportOp.value())); + } + + /// Lower the workgroup count region in keeping with the way dispatch + /// regions are created by default in IREEs compilation flow. + IRRewriter rewriter(getContext()); + if (failed(lowerWorkgroupCountComputingRegion(rewriter, exportOp.value(), + mixedTileSizes))) { + exportOp.value()->emitOpError("failed to lower workgroup count region"); + return DiagnosedSilenceableFailure( + reportUnknownTransformError(exportOp.value())); + } + + ArrayRef targets = state.getPayloadOps(getTarget()); + + // Result payload ops. + SmallVector tileOps; + SmallVector tiledOps; + + DiagnosedSilenceableFailure diag = transform::tileToForeachThreadOpImpl( + rewriter, state, cast(getOperation()), + targets, getMixedNumThreads(), getMixedTileSizes(), getThreadDimMapping(), + tileOps, tiledOps); + + if (!diag.succeeded()) return diag; + + transformResults.set(getForeachThreadOp().cast(), tileOps); + transformResults.set(getTiledOp().cast(), tiledOps); + + return DiagnosedSilenceableFailure(success()); +} + #define GET_OP_CLASSES #include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.cpp.inc" diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td index 0b7d3e13a642..7571fb7347ad 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td @@ -101,7 +101,7 @@ def IREEBufferizeOp : Op, TransformOpInterface, TransformEachOpTrait]> { let description = [{ @@ -153,4 +153,51 @@ def ForeachThreadToWorkgroupOp : Op, + TransformOpInterface]> { + let description = [{ + Wrapper around `structured.tile_to_foreach_thread_op` for use within IREE. + + In addition to tile and distribute using `scf.foreach_thread`, lowers the + the `workgroup_count` region of the export op corresponding to the parent + `func.func` of the target to return the number of workgroups. + Please see the doc of `structured.tile_to_foreach_thread_op` for full + description of op semantics. + }]; + + let arguments = (ins PDL_Operation:$target, + Variadic:$num_threads, + Variadic:$tile_sizes, + DefaultValuedAttr:$static_num_threads, + DefaultValuedAttr:$static_tile_sizes, + OptionalAttr:$thread_dim_mapping); + let results = (outs PDL_Operation:$foreach_thread_op, + PDL_Operation:$tiled_op); + let assemblyFormat = [{ + $target oilist( + `num_threads` custom($num_threads, + $static_num_threads, + "ShapedType::kDynamicSize") | + `tile_sizes` custom($tile_sizes, + $static_tile_sizes, + "ShapedType::kDynamicSize")) + (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict + }]; + let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformResults &transformResults, + ::mlir::transform::TransformState &state); + + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedNumThreads(); + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTileSizes(); + }]; +} + + #endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_COMMONEXTENSIONS diff --git a/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp b/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp new file mode 100644 index 000000000000..418a1a2f5817 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp @@ -0,0 +1,32 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/UserConfig.h" + +namespace mlir { +namespace iree_compiler { + +/// Propagate the configuration annotated in the incoming IR. +LogicalResult setUserConfig( + func::FuncOp entryPointFn, Operation *computeOp, + IREE::Codegen::CompilationInfoAttr compilationInfo) { + if (auto translationInfo = getTranslationInfo(entryPointFn)) { + return computeOp->emitOpError( + "multiple ops within dispatch trying to set the translation " + "info"); + } + + SmallVector workgroupSize = compilationInfo.getWorkgroupSizeVals(); + setTranslationInfo(entryPointFn, compilationInfo.getTranslationInfo(), + workgroupSize); + + setLoweringConfig(computeOp, compilationInfo.getLoweringConfig()); + eraseCompilationInfo(computeOp); + return success(); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/UserConfig.h b/compiler/src/iree/compiler/Codegen/Common/UserConfig.h new file mode 100644 index 000000000000..ef22f004c42d --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/UserConfig.h @@ -0,0 +1,17 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/LoweringConfig.h" + +namespace mlir { +namespace iree_compiler { + +/// Sets compilation configuration annotated in the incoming IR. +LogicalResult setUserConfig(func::FuncOp entryPointFn, Operation *computeOp, + IREE::Codegen::CompilationInfoAttr compilationInfo); + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp index 7101d36900b7..e9f1a4f4cdda 100644 --- a/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp @@ -50,7 +50,7 @@ static Value warpReduction(Location loc, OpBuilder &builder, Value input, .create(loc, laneVal, i, /*width=*/size, /*mode=*/gpu::ShuffleMode::XOR) - .result(); + .getShuffleResult(); laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); } return laneVal; diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorizeConv.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorizeConv.cpp deleted file mode 100644 index 578541cedfee..000000000000 --- a/compiler/src/iree/compiler/Codegen/Common/VectorizeConv.cpp +++ /dev/null @@ -1,391 +0,0 @@ -// Copyright 2020 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Codegen/PassDetail.h" -#include "iree/compiler/Codegen/Passes.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#define DEBUG_TYPE "iree-vectorize-conv" - -namespace mlir { -namespace iree_compiler { - -namespace { - -/// Vectorizes linalg.conv_2d_nhwc_hwcf for a single GPU -/// invocation. Therefore, the linalg.conv op should have a very specific form; -/// other patterns are expected to tile and distribute larger convolutions into -/// this form for a single GPU invocation. -/// -/// The linalg.conv op should follow: -/// - Filter: HfWfCiCo format -/// - Input : NHiWiCi format -/// - Output: NHoWoCo format -/// - For output: -/// - N must be 1. -/// - Co must be a multiple of 4. -/// - For input: -/// - Ci must be < 4. -/// - For filter: -/// - Hf must be 1. -/// - Wf must be 1. -/// - No dilation. -/// - No padding. -/// -/// Output channel is requried to be a multiple of 4 so that we can process -/// them with load4/store4, which is native to GPUs. Similarly for the input -/// channel size requirement. -struct VectorizeLinalgConv : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp, - PatternRewriter &rewriter) const override { - LLVM_DEBUG(llvm::dbgs() << "inspecting " << convOp << "\n"); - - // This pattern does not handle convolutions with dilation. - if (auto dilations = convOp.getDilations()) { - auto values = dilations.getValues(); - if (llvm::any_of(values, [](const APInt &value) { - return value.getSExtValue() != 1; - })) { - return failure(); - } - } - - Value input = convOp.image(); - Value filter = convOp.filter(); - Value output = convOp.getOutputs()[0]; - - auto inputType = input.getType().cast(); - auto filterType = filter.getType().cast(); - auto outputType = output.getType().cast(); - - // The filter/input/output view should have static sizes to vectorize. - if (!inputType.hasStaticShape() || !filterType.hasStaticShape() || - !outputType.hasStaticShape()) { - return failure(); - } - - auto filterShape = filterType.getShape(); - auto outputShape = outputType.getShape(); - - // The output batch dimension should be 1. - if (outputShape[0] != 1) return failure(); - - // We addtionally expect the filter height/width dimensions are both 1 to - // simplify vectorization. Other patterns can generate loops to create 1x1 - // filter subivews. - if (filterShape[0] != 1 || filterShape[1] != 1) return failure(); - - int64_t numInputChannels = filterShape[2]; - int64_t numOutputChannels = filterShape[3]; - if (numInputChannels > 4 || numOutputChannels % 4 != 0) return failure(); - - int64_t numOutputHeights = outputShape[1]; - int64_t numOutputWidths = outputShape[2]; - int64_t heightStride = convOp.getStrides().getValues()[0]; - int64_t widthStride = convOp.getStrides().getValues()[1]; - - // This invocation handles a batch of - // (numOutputHeights * numOutputWidths * numOutputChannels). - LLVM_DEBUG({ - llvm::dbgs() << "# output height: " << numOutputHeights << "\n"; - llvm::dbgs() << "# output width: " << numOutputWidths << "\n"; - llvm::dbgs() << "# output channels: " << numOutputChannels << "\n"; - llvm::dbgs() << "height stride: " << heightStride << "\n"; - llvm::dbgs() << "width stride: " << widthStride << "\n"; - }); - - MLIRContext *context = convOp.getContext(); - Location loc = convOp.getLoc(); - - Type elementType = filterType.getElementType(); - auto filterVectorType = - VectorType::get({numInputChannels, numOutputChannels}, elementType); - auto vector1x4Type = VectorType::get({1, 4}, elementType); - auto inputVectorType = VectorType::get({1, numInputChannels}, elementType); - Value zero = rewriter.createOrFold(loc, 0); - - // Load the entire filter subview. - SmallVector filterIndices(4, zero); - Value wholeFilter = rewriter.create( - loc, filterVectorType, filter, filterIndices); - - // Get filter slices so that later we can use them for dot product with the - // input. Both the height and width dimensions are 1; so we just need to - // loop over input and output channel dimensions. - SmallVector, 4> filterVectors(numInputChannels); - for (int ic = 0; ic < numInputChannels; ++ic) { - auto &thisInputChannel = filterVectors[ic]; - thisInputChannel.reserve(numOutputChannels / 4); - for (int oc = 0; oc < numOutputChannels / 4; ++oc) { - Value slice = rewriter.create( - loc, wholeFilter, /*offsets=*/ArrayRef({ic, oc * 4}), - /*sizes=*/ArrayRef({1, 4}), - /*strides=*/ArrayRef({1, 1})); - thisInputChannel.push_back(slice); - } - } - - // Build indexing maps for a later vector contraction op. - AffineExpr dim0 = getAffineDimExpr(0, context); // M - AffineExpr dim1 = getAffineDimExpr(1, context); // N - AffineExpr dim2 = getAffineDimExpr(2, context); // K - auto map02 = AffineMap::get(3, 0, {dim0, dim2}, context); - auto map21 = AffineMap::get(3, 0, {dim2, dim1}, context); - auto map01 = AffineMap::get(3, 0, {dim0, dim1}, context); - ArrayAttr indexingMaps = - rewriter.getAffineMapArrayAttr({map02, map21, map01}); - - // Also build iterator types for the vector contraction op. - ArrayAttr iterators = rewriter.getStrArrayAttr( - {getParallelIteratorTypeName(), getParallelIteratorTypeName(), - getReductionIteratorTypeName()}); - - // Compute the (numOutputHeights * numOutputWidths * numOutputChannels) - // batch. We only contribute numInputChannels accumulation along the - // reduction dimension. So read in the result from the output, compose a - // chain of numInputChannels vector dot operations, and then write out. - bool hasTensorSemantics = convOp.hasTensorSemantics(); - Value outputWrite = output; - for (int oh = 0; oh < numOutputHeights; ++oh) { - for (int ow = 0; ow < numOutputWidths; ++ow) { - // Read in the input vector for these 4 input channels a a batch. The - // input vector are used for computing all output channels so data can - // be reused. - SmallVector inputIndices(4, zero); - inputIndices[1] = rewriter.createOrFold( - loc, oh * heightStride); - inputIndices[2] = rewriter.createOrFold( - loc, ow * widthStride); - Value inputVector = rewriter.create( - loc, inputVectorType, input, inputIndices); - - for (int oc = 0; oc < numOutputChannels / 4; ++oc) { - // Read in the initial value for this output vector. - SmallVector outputIndices(4, zero); - outputIndices[1] = - rewriter.createOrFold(loc, oh); - outputIndices[2] = - rewriter.createOrFold(loc, ow); - outputIndices[3] = - rewriter.createOrFold(loc, oc * 4); - Value outputVector = rewriter.create( - loc, vector1x4Type, output, outputIndices); - - // Peform a chain of dot product and accumulation. - for (int i = 0; i < numInputChannels; ++i) { - auto inputSlice = rewriter.create( - loc, inputVector, /*offsets=*/ArrayRef({0, i}), - /*sizes=*/ArrayRef({1, 1}), - /*strides=*/ArrayRef({1, 1})); - outputVector = rewriter.create( - loc, inputSlice, filterVectors[i][oc], outputVector, - indexingMaps, iterators); - } - - // Write out the output vector. - auto writeOp = rewriter.create( - loc, outputVector, outputWrite, outputIndices); - if (hasTensorSemantics) outputWrite = writeOp.getResult(); - } - } - } - - if (hasTensorSemantics) { - rewriter.replaceOp(convOp, outputWrite); - } else { - rewriter.eraseOp(convOp); - } - - return success(); - } -}; - -/// Vectorizes linalg.depthwise_conv_2d_nhwc_hwc for a single GPU -/// invocation. Therefore, the linalg.depthwise_conv_2d_nhwc_hwc op -/// should have a very specific form; other patterns are expected to tile and -/// distribute larger convolutions into this form for a single GPU invocation. -/// -/// The linalg.depthwise_conv_2d_nhwc_hwc op should follow: -/// - Filter: HfWfC format -/// - Input : NHiWiC format -/// - Output: NHoWoC format -/// - For output: -/// - N must be 1. -/// - C must be a multiple of 4. -/// - For filter: -/// - Hf must be 1. -/// - Wf must be 1. -/// - No dilation. -/// - No padding. -/// -/// Channel is requried to be a multiple of 4 so that we can process them with -/// load4/store4, which is native to GPUs. -struct VectorizeLinalgDepthwiseConv - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp, - PatternRewriter &rewriter) const override { - LLVM_DEBUG(llvm::dbgs() << "inspecting " << convOp << "\n"); - - Value input = convOp.image(); - Value filter = convOp.filter(); - Value output = convOp.getOutputs()[0]; - - auto inputType = input.getType().cast(); - auto filterType = filter.getType().cast(); - auto outputType = output.getType().cast(); - - // The filter/input/output view should have static sizes to vectorize. - if (!inputType.hasStaticShape() || !filterType.hasStaticShape() || - !outputType.hasStaticShape()) { - return failure(); - } - - auto filterShape = filterType.getShape(); - auto outputShape = outputType.getShape(); - - // The output batch dimension should be 1. - if (outputShape[0] != 1) return failure(); - - // We addtionally expect the filter height/width dimensions are both 1 to - // simplify vectorization. Other patterns can generate loops to create 1x1 - // filter subivews. - if (filterShape[0] != 1 || filterShape[1] != 1) return failure(); - - int64_t numChannels = outputShape[3]; - if (numChannels % 4 != 0) return failure(); - - int64_t numOutputHeights = outputShape[1]; - int64_t numOutputWidths = outputShape[2]; - int64_t heightStride = convOp.getStrides().getValues()[0]; - int64_t widthStride = convOp.getStrides().getValues()[1]; - - // This invocation handles a batch of (numOutputHeights * numOutputWidths * - // numChannels). - LLVM_DEBUG({ - llvm::dbgs() << "# output height: " << numOutputHeights << "\n"; - llvm::dbgs() << "# output width: " << numOutputWidths << "\n"; - llvm::dbgs() << "# channels: " << numChannels << "\n"; - llvm::dbgs() << "height stride: " << heightStride << "\n"; - llvm::dbgs() << "width stride: " << widthStride << "\n"; - }); - - Location loc = convOp.getLoc(); - - Type elementType = filterType.getElementType(); - auto vector4Type = VectorType::get({1, 1, 1, 4}, elementType); - auto filterVectorType = VectorType::get({1, 1, numChannels}, elementType); - Value zero = rewriter.createOrFold(loc, 0); - - // Load the entire filter subview. - SmallVector filterIndices(3, zero); - Value wholeFilter = rewriter.create( - loc, filterVectorType, filter, filterIndices); - - // Compute the (numOutputHeights * numOutputWidths * numChannels) output - // batch. We only contribute numChannels accumulation along the reduction - // dimension. - bool hasTensorSemantics = convOp.hasTensorSemantics(); - Value outputWrite = output; - for (int oc = 0; oc < numChannels / 4; ++oc) { - Value filterVector = rewriter.create( - loc, wholeFilter, /*offsets=*/ArrayRef{0, 0, oc * 4}, - /*sizes=*/ArrayRef{1, 1, 4}, - /*strides=*/ArrayRef{1, 1, 1}); - filterVector = - rewriter.create(loc, vector4Type, filterVector); - - for (int oh = 0; oh < numOutputHeights; ++oh) { - for (int ow = 0; ow < numOutputWidths; ++ow) { - // Read in the initial value for this output vector. - SmallVector outputIndices(4, zero); - outputIndices[1] = - rewriter.createOrFold(loc, oh); - outputIndices[2] = - rewriter.createOrFold(loc, ow); - outputIndices[3] = - rewriter.createOrFold(loc, oc * 4); - Value outputVector = rewriter.create( - loc, vector4Type, output, outputIndices); - - // Read in the input vector for these 4 input channels a a batch. - SmallVector inputIndices(4, zero); - inputIndices[1] = rewriter.createOrFold( - loc, oh * heightStride); - inputIndices[2] = rewriter.createOrFold( - loc, ow * widthStride); - inputIndices[3] = - rewriter.createOrFold(loc, oc * 4); - Value inputVector = rewriter.create( - loc, vector4Type, input, inputIndices); - - // Peform element-wise product and accumulation. - outputVector = rewriter.create( - loc, inputVector, filterVector, outputVector); - - // Write out the output vector. - auto writeOp = rewriter.create( - loc, outputVector, outputWrite, outputIndices); - if (hasTensorSemantics) outputWrite = writeOp.getResult(); - } - } - } - - if (hasTensorSemantics) { - rewriter.replaceOp(convOp, outputWrite); - } else { - rewriter.eraseOp(convOp); - } - return success(); - } -}; - -struct LinalgToVectorVectorizeConvPass - : public LinalgToVectorVectorizeConvBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext *context = &getContext(); - RewritePatternSet patterns(&getContext()); - patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -void populateLinalgToVectorVectorizeConvPatterns(MLIRContext *context, - RewritePatternSet &patterns) { - patterns.insert(context); -} - -std::unique_ptr> -createLinalgToVectorVectorizeConvPass() { - return std::make_unique(); -} - -} // namespace iree_compiler -} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD b/compiler/src/iree/compiler/Codegen/Common/test/BUILD index a3c1bfa4421a..40e88b6f350d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD @@ -29,6 +29,7 @@ iree_lit_test_suite( "fold_affine_min_in_distributed_loops.mlir", "fold_tensor_extract_op.mlir", "forop_canonicalization.mlir", + "gpu_pipeline.mlir", "gpu_vectorization.mlir", "iree_comprehensive_bufferize.mlir", "pad_dynamic_alloc.mlir", @@ -40,7 +41,6 @@ iree_lit_test_suite( "transform_dialect_apply_pattern_op.mlir", "transpose_canonicalization.mlir", "type_propagation.mlir", - "vectorize_linalg_conv.mlir", "vectorize_tensor_pad.mlir", "warp_reduction.mlir", ], diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index 684fc1873a2f..75fe159be114 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -25,6 +25,7 @@ iree_lit_test_suite( "fold_affine_min_in_distributed_loops.mlir" "fold_tensor_extract_op.mlir" "forop_canonicalization.mlir" + "gpu_pipeline.mlir" "gpu_vectorization.mlir" "iree_comprehensive_bufferize.mlir" "pad_dynamic_alloc.mlir" @@ -36,7 +37,6 @@ iree_lit_test_suite( "transform_dialect_apply_pattern_op.mlir" "transpose_canonicalization.mlir" "type_propagation.mlir" - "vectorize_linalg_conv.mlir" "vectorize_tensor_pad.mlir" "warp_reduction.mlir" TOOLS diff --git a/compiler/src/iree/compiler/Codegen/Common/test/bufferize_copy_only_dispatches.mlir b/compiler/src/iree/compiler/Codegen/Common/test/bufferize_copy_only_dispatches.mlir index eb961bf18a4d..41297c556eb6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/bufferize_copy_only_dispatches.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/bufferize_copy_only_dispatches.mlir @@ -64,8 +64,8 @@ builtin.module { // CHECK-DAG: %[[SOURCE_SUBVIEW:.+]] = memref.subview %[[SOURCE]][0, 0, 0] [2, 1, 3] // CHECK-DAG: %[[DEST_SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0] [2, 1, 3] // CHECK: linalg.generic -// CHECK-SAME: ins(%[[SOURCE_SUBVIEW]] : memref<2x3xf32, #{{[a-zA-Z0-9]+}}>) -// CHECK-SAME: outs(%[[DEST_SUBVIEW]] : memref<2x3xf32, #{{[a-zA-Z0-9]+}}>) +// CHECK-SAME: ins(%[[SOURCE_SUBVIEW]] : memref<2x3xf32, strided<[24, 1]>>) +// CHECK-SAME: outs(%[[DEST_SUBVIEW]] : memref<2x3xf32, strided<[48, 1]>>) // ----- diff --git a/compiler/src/iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir b/compiler/src/iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir index adb183d36650..bbdd0ee532e1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-gpu-distribute-shared-memory-copy))))' --cse %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-gpu-distribute-shared-memory-copy))))' --fold-memref-alias-ops --canonicalize --cse %s | FileCheck %s // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4)> // CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16)> @@ -48,9 +48,9 @@ hal.executable private @shared_mem_cpy { // CHECK-DAG: %[[Y0:.*]] = affine.apply #[[$MAP0]]()[%[[TX]], %[[TY]], %[[TZ]]] // CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP1]]()[%[[TX]]] // CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[Y0]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32> + // CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3> // CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP2]]()[%[[TX]], %[[TY]], %[[TZ]]] // CHECK: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32> - // CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3> // CHECK: vector.transfer_write %[[R1]], %{{.*}}[%[[Y1]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3> linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], @@ -64,9 +64,9 @@ hal.executable private @shared_mem_cpy { // CHECK: %[[Y1:.*]] = affine.apply #[[$MAP3]]()[%[[TX]], %[[TY]], %[[TZ]]] // CHECK: %[[R2:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32> + // CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[Y1]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3> // CHECK: %[[Y2:.*]] = affine.apply #[[$MAP4]]()[%[[TX]], %[[TY]], %[[TZ]]] // CHECK: %[[R3:.*]] = vector.transfer_read %{{.*}}[%[[Y2]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32> - // CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[Y1]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3> // CHECK: vector.transfer_write %[[R3]], %{{.*}}[%[[Y2]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3> linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], @@ -80,11 +80,11 @@ hal.executable private @shared_mem_cpy { // CHECK: %[[X1:.*]] = affine.apply #[[$MAP5]]()[%[[TX]], %[[TY]], %[[TZ]]] // CHECK: %[[R4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32> + // CHECK: vector.transfer_write %[[R4]], %{{.*}}[%[[C0]], %[[X1]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3> // CHECK: %[[R5:.*]] = vector.transfer_read %{{.*}}[%[[C1]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32> + // CHECK: vector.transfer_write %[[R5]], %{{.*}}[%[[C1]], %[[X1]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3> // CHECK: %[[R6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32> - // CHECK: vector.transfer_write %[[R4]], %{{.*}}[%c0, %15] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3> - // CHECK: vector.transfer_write %[[R5]], %{{.*}}[%c1, %15] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3> - // CHECK: vector.transfer_write %[[R6]], %{{.*}}[%c2, %15] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3> + // CHECK: vector.transfer_write %[[R6]], %{{.*}}[%[[C2]], %[[X1]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3> linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} @@ -100,3 +100,61 @@ hal.executable private @shared_mem_cpy { } } } + +// ----- + +// CHECK-DAG: #[[$OFFSET_MAP:.+]] = affine_map<()[s0] -> (s0 * 4)> + +#pipeline_layout = #hal.pipeline.layout]> +]> + +hal.executable private @unaligned_shared_memory_copy { + hal.executable.variant @cuda, target = <"cuda", "cuda-nvptx-fb"> { + hal.executable.export @unaligned_shared_memory_copy layout(#pipeline_layout) attributes { + workgroup_size = [32: index, 8: index, 1:index] + } { + ^bb0(%arg0: !hal.device, %arg1 : index, %arg2 : index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + + // CHECK-LABEL: func.func @unaligned_shared_memory_copy + // CHECK-SAME: (%[[GLOBAL_MEM:.+]]: memref<56x32xf32, {{.+}}>, %[[SHARED_MEM:.+]]: memref<56x32xf32, 3>) + func.func @unaligned_shared_memory_copy( + %global : memref<56x32xf32, strided<[128, 1], offset: ?>>, %shared : memref<56x32xf32, 3>) { + + // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index + // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index + // CHECK-DAG: %[[C56:.+]] = arith.constant 56 : index + // CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index + + // CHECK-DAG: %[[TID_X:.+]] = gpu.thread_id x + // CHECK-DAG: %[[TID_Y:.+]] = gpu.thread_id y + + // CHECK: scf.for %[[IV_Y:.+]] = %[[TID_Y]] to %[[C56]] step %[[C8]] { + // CHECK: %[[OFFSET_X:.+]] = affine.apply #[[$OFFSET_MAP]]()[%[[TID_X]]] + // CHECK: scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C128]] { + // CHECK: %[[GLOBAL_SUBVIEW:.+]] = memref.subview %[[GLOBAL_MEM]][%[[IV_Y]], %[[IV_X]]] [1, 4] [1, 1] + // CHECK-SAME: : memref<56x32xf32, {{.+}}> to memref<1x4xf32, {{.+}}> + // CHECK: %[[SHARED_SUBVIEW:.+]] = memref.subview %[[SHARED_MEM]][%[[IV_Y]], %[[IV_X]]] [1, 4] [1, 1] + // CHECK-SAME: : memref<56x32xf32, 3> to memref<1x4xf32, strided<[32, 1], offset: ?>, 3> + // CHECK: linalg.generic + // CHECK-SAME: ins(%[[GLOBAL_SUBVIEW]] + // CHECK-SAME: outs(%[[SHARED_SUBVIEW]] + linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } + ins(%global : memref<56x32xf32, strided<[128, 1], offset: ?>>) + outs(%shared : memref<56x32xf32, 3>) + attrs = {__internal_linalg_transform__ = "copy_to_workgroup_memory"} { + ^bb0(%arg0: f32, %arg1: f32): + linalg.yield %arg0 : f32 + } + return + } + } + } +} diff --git a/compiler/src/iree/compiler/Codegen/Common/test/gpu_pipeline.mlir b/compiler/src/iree/compiler/Codegen/Common/test/gpu_pipeline.mlir new file mode 100644 index 000000000000..55bff1e80179 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/gpu_pipeline.mlir @@ -0,0 +1,54 @@ +// Test un-peeled epilogue generating AsyncCopyOp using zfill +// RUN: iree-opt --iree-gpu-pipelining=epilogue-peeling=false %s | FileCheck %s + +func.func @_matmul_f16_f16_dispatch_0_fill_3456x1024() { + %c2048 = arith.constant 2048 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp"> + %1 = gpu.thread_id x + %2 = gpu.thread_id y + %3 = gpu.thread_id z + %4 = memref.alloc() : memref<4x32x40xf16, 3> + %5 = memref.alloc() : memref<4x32x40xf16, 3> + %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<3456x2048xf16> + memref.assume_alignment %6, 64 : memref<3456x2048xf16> + %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<2048x1024xf16> + memref.assume_alignment %7, 64 : memref<2048x1024xf16> + %8 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<3456x1024xf16> + memref.assume_alignment %8, 64 : memref<3456x1024xf16> + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %9 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%1, %2, %3] + %10 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%1] + %11 = scf.for %arg0 = %c0 to %c2048 step %c32 iter_args(%arg1 = %0) -> (!gpu.mma_matrix<16x16xf16, "COp">) { + gpu.barrier + %14 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 8 - (s1 floordiv 4) * 32)>()[%arg0, %1] + %15 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 16 + s2 * 32 + s3 * 32 + s0 floordiv 4)>()[%1, %2, %3, %workgroup_id_y] + %16 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) mod 4)>(%arg0) + %17 = nvgpu.device_async_copy %6[%15, %14], %4[%16, %9, %10], 8 : memref<3456x2048xf16> to memref<4x32x40xf16, 3> + %18 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 16 + s3 * 32 + s1 floordiv 4)>()[%arg0, %1, %2, %3] + %19 = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1 * 32 - (s0 floordiv 4) * 32)>()[%1, %workgroup_id_x] + %20 = nvgpu.device_async_copy %7[%18, %19], %5[%16, %9, %10], 8 : memref<2048x1024xf16> to memref<4x32x40xf16, 3> + %21 = nvgpu.device_async_create_group %17, %20 + nvgpu.device_async_wait %21 + gpu.barrier + %22 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%2] + %23 = gpu.subgroup_mma_load_matrix %4[%16, %22, %c0] {leadDimension = 40 : index} : memref<4x32x40xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> + %24 = gpu.subgroup_mma_load_matrix %4[%16, %22, %c16] {leadDimension = 40 : index} : memref<4x32x40xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> + %25 = affine.apply affine_map<()[s0] -> ((s0 floordiv 32) * 16)>()[%1] + %26 = gpu.subgroup_mma_load_matrix %5[%16, %c0, %25] {leadDimension = 40 : index} : memref<4x32x40xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp"> + %27 = gpu.subgroup_mma_load_matrix %5[%16, %c16, %25] {leadDimension = 40 : index} : memref<4x32x40xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp"> + %28 = gpu.subgroup_mma_compute %23, %26, %arg1 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> + %29 = gpu.subgroup_mma_compute %24, %27, %28 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> + scf.yield %29 : !gpu.mma_matrix<16x16xf16, "COp"> + } + %12 = affine.apply affine_map<()[s0, s1] -> (s0 * 16 + s1 * 32)>()[%2, %workgroup_id_y] + %13 = affine.apply affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 32) * 16)>()[%1, %workgroup_id_x] + gpu.subgroup_mma_store_matrix %11, %8[%12, %13] {leadDimension = 1024 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<3456x1024xf16> + return +} +// CHECK-LABEL: func.func @_matmul_f16_f16_dispatch_0_fill_3456x1024 +// CHECK: %[[CP_ID:.*]] = nvgpu.device_async_copy %[[GMEMPTR:.*]][%[[IDX:.*]]%[[IDY:.*]]], %[[SMEMPTR:.*]][%[[IDK_S:.*]]%[[IDX_S:.*]]%[[IDY_S:.*]]], 8, %[[PRED:.*]] : memref<3456x2048xf16> to memref<4x32x40xf16, 3> \ No newline at end of file diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir index 710e66522378..10349026465b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir @@ -35,7 +35,6 @@ func.func @matmul() { // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, s0)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK: func.func @matmul() // CHECK-DAG: %[[M:.+]] = hal.interface.constant.load[0] // CHECK-DAG: %[[N:.+]] = hal.interface.constant.load[1] @@ -1837,9 +1836,9 @@ module { // CHECK-DAG: %[[OFFSET_X:.+]] = hal.interface.constant.load[1] // CHECK: scf.for %[[IV0:.+]] = // CHECK: %[[SRC_VIEW:.+]] = memref.subview %[[SRC]][%[[IV0]]] -// CHECK-SAME: : memref to memref -// CHECK: %[[DST_VIEW:.+]] = memref.subview %[[DST]][0, %{{.+}}] [1, %{{.+}}] -// CHECK-SAME: : memref to memref +// CHECK-SAME: : memref to memref> +// CHECK: %[[DST_VIEW:.+]] = memref.subview %[[DST]][0, %{{[a-zA-Z0-9]+}}] [1, %{{[a-zA-Z0-9]+}}] +// CHECK-SAME: : memref to memref> // CHECK: linalg.generic {{.*}} ins(%[[SRC_VIEW]] {{.*}} outs(%[[DST_VIEW]] // ----- diff --git a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir index fabefab0d52d..7f5ee28cfaa3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir @@ -6,9 +6,9 @@ func.func @dynamic_alloc(%id : index) { %cst = arith.constant dense<0.000000e+00> : vector<4xf32> %dim = affine.min affine_map<()[s0] -> (s0 * -64 + 7, 64)>()[%id] // CHECK: %[[A:.*]] = memref.alloc() : memref<1x64x32xf32, 3> -// CHECK: %[[S:.*]] = memref.subview %[[A]][0, 0, 0] [1, %{{.*}}, 32] [1, 1, 1] : memref<1x64x32xf32, 3> to memref<1x?x32xf32, #{{.*}}, 3> +// CHECK: %[[S:.*]] = memref.subview %[[A]][0, 0, 0] [1, %{{.*}}, 32] [1, 1, 1] : memref<1x64x32xf32, 3> to memref<1x?x32xf32, strided<[2048, 32, 1]>, 3> %0 = memref.alloc(%dim) : memref<1x?x32xf32, 3> -// CHECK: vector.store %{{.*}}, %[[S]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<1x?x32xf32, #{{.*}}, 3>, vector<4xf32> +// CHECK: vector.store %{{.*}}, %[[S]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<1x?x32xf32, strided<[2048, 32, 1]>, 3>, vector<4xf32> vector.store %cst, %0[%c0, %c0, %c0] : memref<1x?x32xf32, 3>, vector<4xf32> return } @@ -20,9 +20,9 @@ func.func @dynamic_alloc_max_0(%id : index) { %dim = affine.min affine_map<()[s0] -> (s0 * -64 + 7, 64)>()[%id] %dim1 = affine.max affine_map<()[s0] -> (s0, 0)>()[%dim] // CHECK: %[[A:.*]] = memref.alloc() : memref<1x64x32xf32, 3> -// CHECK: %[[S:.*]] = memref.subview %[[A]][0, 0, 0] [1, %{{.*}}, 32] [1, 1, 1] : memref<1x64x32xf32, 3> to memref<1x?x32xf32, #{{.*}}, 3> +// CHECK: %[[S:.*]] = memref.subview %[[A]][0, 0, 0] [1, %{{.*}}, 32] [1, 1, 1] : memref<1x64x32xf32, 3> to memref<1x?x32xf32, strided<[2048, 32, 1]>, 3> %0 = memref.alloc(%dim1) : memref<1x?x32xf32, 3> -// CHECK: vector.store %{{.*}}, %[[S]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<1x?x32xf32, #{{.*}}, 3>, vector<4xf32> +// CHECK: vector.store %{{.*}}, %[[S]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<1x?x32xf32, strided<[2048, 32, 1]>, 3>, vector<4xf32> vector.store %cst, %0[%c0, %c0, %c0] : memref<1x?x32xf32, 3>, vector<4xf32> return } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir index c01a7b7a319a..331df5a2208e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir @@ -58,7 +58,7 @@ hal.executable private @dispatch_0 { ]> // CHECK-LABEL: func.func @workgroup_tile_loop() -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info hal.executable private @workgroup_tile_loop { hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> { hal.executable.export @workgroup_tile_loop layout(#pipeline_layout) attributes { @@ -66,8 +66,8 @@ hal.executable private @workgroup_tile_loop { } { ^bb0(%arg0 : !hal.device, %arg1 : index): %c1 = arith.constant 1 : index - %0 = affine.apply affine_map<(d0) -> (d0 ceildiv 32)>(%arg1) - hal.return %0, %c1, %c1 : index, index, index + %c64 = arith.constant 64 : index + hal.return %c64, %c1, %c1 : index, index, index } builtin.module { func.func @workgroup_tile_loop() { @@ -97,7 +97,7 @@ hal.executable private @workgroup_tile_loop { ]> // CHECK-LABEL: func.func @workgroup_tile_loop_negative() -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info hal.executable private @workgroup_tile_loop_negative { hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> { hal.executable.export @workgroup_tile_loop_negative layout(#pipeline_layout) attributes { @@ -138,7 +138,7 @@ hal.executable private @workgroup_tile_loop_negative { // CHECK-LABEL: func.func @both_workgroup_and_workitem() // CHECK-NOT: scf.for // CHECK: gpu.barrier -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info hal.executable private @both_workgroup_and_workitem { hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> { hal.executable.export @both_workgroup_and_workitem layout(#pipeline_layout) attributes { @@ -146,9 +146,10 @@ hal.executable private @both_workgroup_and_workitem { workgroup_size = [8: index, 2: index, 1: index] } { ^bb0(%arg0 : !hal.device, %arg1: index, %arg2 : index, %arg3 : index): - %0 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg2) - %1 = affine.apply affine_map<(d0) -> (d0 ceildiv 32)>(%arg3) - hal.return %1, %0, %arg1 : index, index, index + %c1 = arith.constant 1 : index + %c14 = arith.constant 14 : index + %c112 = arith.constant 112: index + hal.return %c1, %c14, %c112 : index, index, index } builtin.module { func.func @both_workgroup_and_workitem() { @@ -200,7 +201,7 @@ hal.executable private @both_workgroup_and_workitem { #device_target_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>]}> #pipeline_layout = #hal.pipeline.layout, #hal.descriptor_set.binding<1, storage_buffer>, #hal.descriptor_set.binding<2, storage_buffer>]>]> #executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}> -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #map0 = affine_map<()[s0] -> (s0 ceildiv 4)> #map1 = affine_map<()[s0] -> (s0 * 4)> #map2 = affine_map<()[s0, s1] -> (-((s0 * -4 + 4) mod (s1 * 4)) + 4)> diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir index 1e18d9241e76..825756228852 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir @@ -53,7 +53,7 @@ hal.executable private @matmul_tensors { // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @matmul_tensors // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK-NEXT: (%[[DEVICE:.+]]: !hal.device, @@ -149,7 +149,7 @@ hal.executable private @add { } } // CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable private @add // CHECK: hal.executable.export public @add // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -222,7 +222,7 @@ hal.executable private @add4D { } } // CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @add4D // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK-NEXT: (%[[DEVICE:.+]]: !hal.device, @@ -293,7 +293,7 @@ hal.executable private @batch_matmul_tensors { } } // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @batch_matmul_tensors // CHECK-NEXT: (%[[DEVICE:.+]]: !hal.device, // CHECK-SAME: %[[WORKLOAD_0:[a-zA-Z0-9_]+]]: index @@ -355,18 +355,18 @@ hal.executable private @preset_config_matmul_tensors { } } } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 32)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 16)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @preset_config // CHECK-NEXT: (%[[DEVICE:.+]]: !hal.device, // CHECK-SAME: %[[WORKLOAD_0:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[WORKLOAD_1:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[WORKLOAD_2:[a-zA-Z0-9_]+]]: index) // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[WORKLOAD_0]]] -// CHECK-DAG: %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[WORKLOAD_1]]] -// CHECK: hal.return %[[D1]], %[[D0]], %[[C1]] +// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index +// CHECK: hal.return %[[C32]], %[[C4]], %[[C1]] // CHECK: func.func @preset_config() // CHECK: scf.for %[[IV0:.+]] = // CHECK: scf.for %[[IV1:.+]] = @@ -431,7 +431,7 @@ hal.executable public @copy_op { // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @copy_op // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK-NEXT: (%[[DEVICE:.+]]: !hal.device, @@ -517,7 +517,7 @@ hal.executable private @static_1d_fft_stage2 { } } // CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable private @static_1d_fft_stage2 // CHECK: hal.executable.export public @static_1d_fft_stage2 // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -567,7 +567,7 @@ hal.executable private @static_3d_fft_stage3 { } } // CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable private @static_3d_fft_stage3 // CHECK: hal.executable.export public @static_3d_fft_stage3 // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -650,7 +650,7 @@ hal.executable private @outs_fusion { } } // CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable private @outs_fusion // CHECK: hal.executable.export public @outs_fusion_fn // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -727,7 +727,7 @@ hal.executable private @conv { } } // CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable private @conv // CHECK: hal.executable.export public @conv // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -799,10 +799,10 @@ hal.executable private @conv_static { } } } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 40)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 ceildiv 48)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 20)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 40)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 48)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable private @conv_static // CHECK: hal.executable.export public @conv_static // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -813,10 +813,9 @@ hal.executable private @conv_static { // CHECK-SAME: %[[WORKLOAD_3:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[WORKLOAD_4:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[WORKLOAD_5:[a-zA-Z0-9_]+]]: index) -// CHECK-DAG: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[WORKLOAD_1]]] -// CHECK-DAG: %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[WORKLOAD_2]]] -// CHECK-DAG: %[[D2:.+]] = affine.apply #[[MAP2]]()[%[[WORKLOAD_3]]] -// CHECK: hal.return %[[D2]], %[[D1]], %[[D0]] : index, index, index +// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK: hal.return %[[C2]], %[[C2]], %[[C4]] : index, index, index // CHECK: func.func @conv_static() // CHECK: scf.for %[[IV0:.+]] = // CHECK: scf.for %[[IV1:.+]] = @@ -873,9 +872,9 @@ hal.executable private @generic_static { } } } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 16)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable private @generic_static // CHECK: hal.executable.export public @generic_static // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -883,9 +882,8 @@ hal.executable private @generic_static { // CHECK-SAME: %[[WORKLOAD_0:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[WORKLOAD_1:[a-zA-Z0-9_]+]]: index) // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[WORKLOAD_0]]] -// CHECK-DAG: %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[WORKLOAD_1]]] -// CHECK: hal.return %[[D1]], %[[D0]], %[[C1]] : index, index, index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK: hal.return %[[C3]], %[[C1]], %[[C1]] : index, index, index // CHECK: func.func @generic_static() // CHECK: scf.for %[[IV0:.+]] = // CHECK: scf.for %[[IV1:.+]] = @@ -938,9 +936,9 @@ hal.executable private @matmul_static { } } } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 28)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 28)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 8)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable private @matmul_static // CHECK: hal.executable.export public @matmul_static // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -949,9 +947,9 @@ hal.executable private @matmul_static { // CHECK-SAME: %[[WORKLOAD_1:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[WORKLOAD_2:[a-zA-Z0-9_]+]]: index) // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[WORKLOAD_0]]] -// CHECK-DAG: %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[WORKLOAD_1]]] -// CHECK: hal.return %[[D1]], %[[D0]], %[[C1]] : index, index, index +// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index +// CHECK: hal.return %[[C5]], %[[C7]], %[[C1]] : index, index, index // ----- @@ -999,9 +997,9 @@ hal.executable private @restrict_num_workgroups { } } } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 7)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 7)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable private @restrict_num_workgroups // CHECK: hal.executable.export public @restrict_num_workgroups // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -1012,9 +1010,10 @@ hal.executable private @restrict_num_workgroups { // CHECK-SAME: %[[WORKLOAD_3:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[WORKLOAD_4:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[WORKLOAD_5:[a-zA-Z0-9_]+]]: index) -// CHECK-DAG: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[WORKLOAD_2]]] -// CHECK-DAG: %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[WORKLOAD_3]]] -// CHECK: hal.return %[[D1]], %[[D0]], %[[WORKLOAD_1]] : index, index, index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index +// CHECK-DAG: %[[C9:.+]] = arith.constant 9 : index +// CHECK: hal.return %[[C9]], %[[C1]], %[[C7]] : index, index, index // ----- @@ -1071,16 +1070,16 @@ hal.executable private @reduction { } } } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable private @reduction // CHECK: hal.executable.export public @reduction // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK-NEXT: (%[[DEVICE:.+]]: !hal.device, // CHECK-SAME: %[[WORKLOAD:[a-zA-Z0-9_]+]]: index) // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[WORKLOAD]]] -// CHECK: hal.return %[[D0]], %[[C1]], %[[C1]] : index, index, index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK: hal.return %[[C2]], %[[C1]], %[[C1]] : index, index, index // CHECK: func.func @reduction // CHECK: scf.for %[[IV0:.+]] = // CHECK: %[[INIT:.+]] = linalg.init_tensor @@ -1140,7 +1139,7 @@ hal.executable private @gemm_unit_N { } } // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable private @gemm_unit_N // CHECK: hal.executable.export public @gemm_unit_N // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -1275,7 +1274,7 @@ hal.executable private @generic_unit_dims { } } // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable private @generic_unit_dims // CHECK: hal.executable.export public @generic_unit_dims // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -1463,15 +1462,15 @@ hal.executable private @rank_reduced_slice { } } } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)> +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)> // CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 + 10)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @rank_reduced_slice // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK-NEXT: %[[WORKLOAD:[a-zA-Z0-9]+]]: index // CHECK-DAG: %[[C1:.+]] = arith.constant 1 -// CHECK-DAG: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[WORKLOAD]]] -// CHECK: hal.return %[[D0]], %[[C1]], %[[C1]] +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 +// CHECK: hal.return %[[C5]], %[[C1]], %[[C1]] // CHECK: func.func @rank_reduced_slice() // CHECK-DAG: %[[SRC_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0) // CHECK-SAME: : !flow.dispatch.tensor @@ -1537,7 +1536,7 @@ hal.executable private @matmul_interchange { } // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @matmul_interchange // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK-NEXT: (%[[DEVICE:.+]]: !hal.device, diff --git a/compiler/src/iree/compiler/Codegen/Common/test/vectorize_linalg_conv.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vectorize_linalg_conv.mlir deleted file mode 100644 index 4ec8bd9b4697..000000000000 --- a/compiler/src/iree/compiler/Codegen/Common/test/vectorize_linalg_conv.mlir +++ /dev/null @@ -1,292 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-codegen-vectorize-linalg-conv --canonicalize -cse %s | FileCheck %s - -func.func @vectorize_conv(%filter: memref<1x1x3x4xf32>, %input: memref<1x2x2x3xf32>, %output: memref<1x2x2x4xf32>) { - linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} - ins (%input, %filter: memref<1x2x2x3xf32>, memref<1x1x3x4xf32>) - outs (%output: memref<1x2x2x4xf32>) - return -} - -// CHECK: #map0 = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> - -// CHECK: func.func @vectorize_conv -// CHECK-SAME: %[[FILTER_SUBVIEW:.+]]: memref<1x1x3x4xf32>, -// CHECK-SAME: %[[INPUT_SUBVIEW:.+]]: memref<1x2x2x3xf32>, -// CHECK-SAME: %[[OUTPUT_SUBVIEW:.+]]: memref<1x2x2x4xf32> - -// CHECK: %[[FLOAT_ZERO:.+]] = arith.constant 0.000000e+00 : f32 - -// Read in the filter and get slices -// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER_SUBVIEW]][%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<1x1x3x4xf32>, vector<3x4xf32> -// CHECK: %[[FILTER_0:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<3x4xf32> to vector<1x4xf32> -// CHECK: %[[FILTER_1:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<3x4xf32> to vector<1x4xf32> -// CHECK: %[[FILTER_2:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<3x4xf32> to vector<1x4xf32> - -// Handle batch #0 -// CHECK: %[[INPUT_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {in_bounds = [true, true]} : memref<1x2x2x3xf32>, vector<1x3xf32> -// CHECK: %[[OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {in_bounds = [true, true]} : memref<1x2x2x4xf32>, vector<1x4xf32> -// CHECK: %[[INPUT_0_0:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> -// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_0_0]], %[[FILTER_0]], %[[OUTPUT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: %[[INPUT_0_1:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> -// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_0_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: %[[INPUT_0_2:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> -// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_0_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x2x2x4xf32> - -// Handle batch #1 -// CHECK: %[[INPUT_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {in_bounds = [true, true]} : memref<1x2x2x3xf32>, vector<1x3xf32> -// CHECK: %[[OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {in_bounds = [true, true]} : memref<1x2x2x4xf32>, vector<1x4xf32> -// CHECK: %[[INPUT_1_0:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> -// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_1_0]], %[[FILTER_0]], %[[OUTPUT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: %[[INPUT_1_1:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> -// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_1_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: %[[INPUT_1_2:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> -// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_1_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x2x2x4xf32> - -// Handle batch #2 -// CHECK: %[[INPUT_2:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c1, %c0, %c0], %[[FLOAT_ZERO]] {in_bounds = [true, true]} : memref<1x2x2x3xf32>, vector<1x3xf32> -// CHECK: %[[OUTPUT_2:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c0], %[[FLOAT_ZERO]] {in_bounds = [true, true]} : memref<1x2x2x4xf32>, vector<1x4xf32> -// CHECK: %[[INPUT_2_0:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> -// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_2_0]], %[[FILTER_0]], %[[OUTPUT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: %[[INPUT_2_1:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> -// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_2_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: %[[INPUT_2_2:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> -// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_2_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x2x2x4xf32> - -// Handle batch #3 -// CHECK: %[[INPUT_3:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c1, %c1, %c0], %[[FLOAT_ZERO]] {in_bounds = [true, true]} : memref<1x2x2x3xf32>, vector<1x3xf32> -// CHECK: %[[OUTPUT_3:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c0], %[[FLOAT_ZERO]] {in_bounds = [true, true]} : memref<1x2x2x4xf32>, vector<1x4xf32> -// CHECK: %[[INPUT_3_0:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> -// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_3_0]], %[[FILTER_0]], %[[OUTPUT_3]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: %[[INPUT_3_1:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> -// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_3_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: %[[INPUT_3_2:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> -// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_3_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x2x2x4xf32> - -// ----- - -// CHECK-LABEL: func.func @do_not_vectorize_conv_with_non_1_batch -func.func @do_not_vectorize_conv_with_non_1_batch(%filter: memref<1x1x4x4xf32>, %input: memref<2x1x7x4xf32>, %output: memref<2x1x4x4xf32>) { - // CHECK: linalg.conv_2d_nhwc_hwcf - linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} - ins (%input, %filter: memref<2x1x7x4xf32>, memref<1x1x4x4xf32>) - outs (%output: memref<2x1x4x4xf32>) - return -} - -// ----- - -// CHECK-LABEL: func.func @do_not_vectorize_conv_with_non_1_filter_height -func.func @do_not_vectorize_conv_with_non_1_filter_height(%filter: memref<2x1x4x4xf32>, %input: memref<1x2x7x4xf32>, %output: memref<1x1x4x4xf32>) { - // CHECK: linalg.conv_2d_nhwc_hwcf - linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} - ins (%input, %filter: memref<1x2x7x4xf32>, memref<2x1x4x4xf32>) - outs (%output: memref<1x1x4x4xf32>) - return -} - -// ----- - -// CHECK-LABEL: func.func @do_not_vectorize_conv_with_non_1_filter_width -func.func @do_not_vectorize_conv_with_non_1_filter_width(%filter: memref<1x2x4x4xf32>, %input: memref<1x1x8x4xf32>, %output: memref<1x1x4x4xf32>) { - // CHECK: linalg.conv_2d_nhwc_hwcf - linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} - ins (%input, %filter: memref<1x1x8x4xf32>, memref<1x2x4x4xf32>) - outs (%output: memref<1x1x4x4xf32>) - return -} - -// ----- - -// CHECK-LABEL: func.func @do_not_vectorize_conv_with_non_1_dilation -func.func @do_not_vectorize_conv_with_non_1_dilation(%filter: memref<1x1x4x4xf32>, %input: memref<1x1x7x4xf32>, %output: memref<1x1x4x4xf32>) { - // CHECK: linalg.conv_2d_nhwc_hwcf - linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : vector<2xi64>, strides = dense<2> : vector<2xi64>} - ins (%input, %filter: memref<1x1x7x4xf32>, memref<1x1x4x4xf32>) - outs (%output: memref<1x1x4x4xf32>) - return -} - -// ----- - -func.func @vectorize_depthwise_conv(%input: memref<1x3x3x8xf32>, %filter: memref<1x1x8xf32>, %output: memref<1x2x2x8xf32>) { - linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<2> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%input, %filter : memref<1x3x3x8xf32>, memref<1x1x8xf32>) outs(%output : memref<1x2x2x8xf32>) - return -} - -// CHECK-LABEL: func.func @vectorize_depthwise_conv -// CHECK-SAME: %[[INPUT_SUBVIEW:.+]]: memref<1x3x3x8xf32>, -// CHECK-SAME: %[[FILTER_SUBVIEW:.+]]: memref<1x1x8xf32>, -// CHECK-SAME: %[[OUTPUT_SUBVIEW:.+]]: memref<1x2x2x8xf32> - -// CHECK: %[[FLOAT_ZERO:.+]] = arith.constant 0.000000e+00 : f32 - -// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER_SUBVIEW]][%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8xf32>, vector<1x1x8xf32> - -// Common filter #0 -// CHECK: %[[FILTER_0_SLICE:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x1x8xf32> to vector<1x1x4xf32> -// CHECK: %[[FILTER_0:.+]] = vector.shape_cast %[[FILTER_0_SLICE]] : vector<1x1x4xf32> to vector<1x1x1x4xf32> - -// CHECK: %[[OUTPUT_0_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x2x2x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[INPUT_0_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x3x3x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[FMA_0_0:.+]] = vector.fma %[[INPUT_0_0]], %[[FILTER_0]], %[[OUTPUT_0_0]] : vector<1x1x1x4xf32> -// CHECK: vector.transfer_write %[[FMA_0_0]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x2x2x8xf32> - -// CHECK: %[[OUTPUT_0_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x2x2x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[INPUT_0_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c2, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x3x3x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[FMA_0_1:.+]] = vector.fma %[[INPUT_0_1]], %[[FILTER_0]], %[[OUTPUT_0_1]] : vector<1x1x1x4xf32> -// CHECK: vector.transfer_write %[[FMA_0_1]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x2x2x8xf32> - -// CHECK: %[[OUTPUT_1_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x2x2x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[INPUT_1_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c2, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x3x3x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[FMA_1_0:.+]] = vector.fma %[[INPUT_1_0]], %[[FILTER_0]], %[[OUTPUT_1_0]] : vector<1x1x1x4xf32> -// CHECK: vector.transfer_write %[[FMA_1_0]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x2x2x8xf32> - -// CHECK: %[[OUTPUT_1_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x2x2x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[INPUT_1_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c2, %c2, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x3x3x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[FMA_1_1:.+]] = vector.fma %[[INPUT_1_1]], %[[FILTER_0]], %[[OUTPUT_1_1]] : vector<1x1x1x4xf32> -// CHECK: vector.transfer_write %[[FMA_1_1]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x2x2x8xf32> - -// Common filter #1 -// CHECK: %[[FILTER_1_SLICE:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [0, 0, 4], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x1x8xf32> to vector<1x1x4xf32> -// CHECK: %[[FILTER_1:.+]] = vector.shape_cast %[[FILTER_1_SLICE]] : vector<1x1x4xf32> to vector<1x1x1x4xf32> - -// CHECK: %[[OUTPUT_0_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c4], %cst {in_bounds = [true, true, true, true]} : memref<1x2x2x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[INPUT_0_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c0, %c4], %cst {in_bounds = [true, true, true, true]} : memref<1x3x3x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[FMA_0_0:.+]] = vector.fma %[[INPUT_0_0]], %[[FILTER_1]], %[[OUTPUT_0_0]] : vector<1x1x1x4xf32> -// CHECK: vector.transfer_write %[[FMA_0_0]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c4] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x2x2x8xf32> - -// CHECK: %[[OUTPUT_0_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c4], %cst {in_bounds = [true, true, true, true]} : memref<1x2x2x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[INPUT_0_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c2, %c4], %cst {in_bounds = [true, true, true, true]} : memref<1x3x3x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[FMA_0_1:.+]] = vector.fma %[[INPUT_0_1]], %[[FILTER_1]], %[[OUTPUT_0_1]] : vector<1x1x1x4xf32> -// CHECK: vector.transfer_write %[[FMA_0_1]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c4] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x2x2x8xf32> - -// CHECK: %[[OUTPUT_1_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c4], %cst {in_bounds = [true, true, true, true]} : memref<1x2x2x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[INPUT_1_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c2, %c0, %c4], %cst {in_bounds = [true, true, true, true]} : memref<1x3x3x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[FMA_1_0:.+]] = vector.fma %[[INPUT_1_0]], %[[FILTER_1]], %[[OUTPUT_1_0]] : vector<1x1x1x4xf32> -// CHECK: vector.transfer_write %[[FMA_1_0]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c4] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x2x2x8xf32> - -// CHECK: %[[OUTPUT_1_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c4], %cst {in_bounds = [true, true, true, true]} : memref<1x2x2x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[INPUT_1_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c2, %c2, %c4], %cst {in_bounds = [true, true, true, true]} : memref<1x3x3x8xf32>, vector<1x1x1x4xf32> -// CHECK: %[[FMA_1_1:.+]] = vector.fma %[[INPUT_1_1]], %[[FILTER_1]], %[[OUTPUT_1_1]] : vector<1x1x1x4xf32> -// CHECK: vector.transfer_write %[[FMA_1_1]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c4] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x2x2x8xf32> - -// ----- - -// CHECK-LABEL: func.func @do_not_vectorize_depthwise_conv_with_non_1_filter_height -func.func @do_not_vectorize_depthwise_conv_with_non_1_filter_height(%input: memref<1x2x3x4xf32>, %filter: memref<2x1x4xf32>, %output: memref<1x1x2x4xf32>) { - // CHECK: linalg.depthwise_conv_2d_nhwc_hwc - linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} - ins(%input, %filter : memref<1x2x3x4xf32>, memref<2x1x4xf32>) - outs(%output : memref<1x1x2x4xf32>) - return -} - -// ----- - -// CHECK-LABEL: func.func @do_not_vectorize_depthwise_conv_with_non_1_filter_width -func.func @do_not_vectorize_depthwise_conv_with_non_1_filter_width(%input: memref<1x1x4x4xf32>, %filter: memref<1x2x4xf32>, %output: memref<1x1x2x4xf32>) { - // CHECK: linalg.depthwise_conv_2d_nhwc_hwc - linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} - ins(%input, %filter : memref<1x1x4x4xf32>, memref<1x2x4xf32>) - outs(%output : memref<1x1x2x4xf32>) - return -} - -// ----- - -func.func @vectorize_conv(%filter: tensor<1x1x3x4xf32>, %input: tensor<1x2x2x3xf32>, %init: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> { - %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} - ins (%input, %filter: tensor<1x2x2x3xf32>, tensor<1x1x3x4xf32>) - outs (%init: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> - return %0 : tensor<1x2x2x4xf32> -} - -// CHECK-LABEL: func.func @vectorize_conv -// CHECK-SAME: %[[FILTER_TENSOR:.+]]: tensor<1x1x3x4xf32>, -// CHECK-SAME: %[[INPUT_TENSOR:.+]]: tensor<1x2x2x3xf32>, -// CHECK-SAME: %[[INIT_TENSOR:.+]]: tensor<1x2x2x4xf32> - -// CHECK: vector.transfer_read %[[FILTER_TENSOR]] -// CHECK-COUNT-3: vector.extract_strided_slice - -// CHECK: vector.transfer_read %[[INPUT_TENSOR]] -// CHECK: vector.transfer_read %[[INIT_TENSOR]] -// CHECK-COUNT-3: vector.contract -// CHECK: %[[WRITE0:.+]] = vector.transfer_write %{{.+}}, %[[INIT_TENSOR]] - -// CHECK: vector.transfer_read %[[INPUT_TENSOR]] -// CHECK: vector.transfer_read %[[INIT_TENSOR]] -// CHECK-COUNT-3: vector.contract -// CHECK: %[[WRITE1:.+]] = vector.transfer_write %{{.+}}, %[[WRITE0]] - -// CHECK: vector.transfer_read %[[INPUT_TENSOR]] -// CHECK: vector.transfer_read %[[INIT_TENSOR]] -// CHECK-COUNT-3: vector.contract -// CHECK: %[[WRITE2:.+]] = vector.transfer_write %{{.+}}, %[[WRITE1]] - -// CHECK: vector.transfer_read %[[INPUT_TENSOR]] -// CHECK: vector.transfer_read %[[INIT_TENSOR]] -// CHECK-COUNT-3: vector.contract -// CHECK: %[[WRITE3:.+]] = vector.transfer_write %{{.+}}, %[[WRITE2]] - -// ----- - -func.func @vectorize_depthwise_conv(%input: tensor<1x3x3x8xf32>, %filter: tensor<1x1x8xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> { - %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<2> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} - ins(%input, %filter : tensor<1x3x3x8xf32>, tensor<1x1x8xf32>) - outs(%init : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> - return %0 : tensor<1x2x2x8xf32> -} - -// CHECK-LABEL: func.func @vectorize_depthwise_conv -// CHECK-SAME: %[[INPUT_TENSOR:.+]]: tensor<1x3x3x8xf32>, -// CHECK-SAME: %[[FILTER_TENSOR:.+]]: tensor<1x1x8xf32>, -// CHECK-SAME: %[[INIT_TENSOR:.+]]: tensor<1x2x2x8xf32> - -// CHECK: vector.transfer_read %[[FILTER_TENSOR]] - -// CHECK: vector.transfer_read %[[INIT_TENSOR]] -// CHECK: vector.transfer_read %[[INPUT_TENSOR]] -// CHECK: vector.fma -// CHECK: %[[WRITE0:.+]] = vector.transfer_write %{{.+}}, %[[INIT_TENSOR]] - -// CHECK: vector.transfer_read %[[INIT_TENSOR]] -// CHECK: vector.transfer_read %[[INPUT_TENSOR]] -// CHECK: vector.fma -// CHECK: %[[WRITE1:.+]] = vector.transfer_write %{{.+}}, %[[WRITE0]] - -// CHECK: vector.transfer_read %[[INIT_TENSOR]] -// CHECK: vector.transfer_read %[[INPUT_TENSOR]] -// CHECK: vector.fma -// CHECK: %[[WRITE2:.+]] = vector.transfer_write %{{.+}}, %[[WRITE1]] - -// CHECK: vector.transfer_read %[[INIT_TENSOR]] -// CHECK: vector.transfer_read %[[INPUT_TENSOR]] -// CHECK: vector.fma -// CHECK: %[[WRITE3:.+]] = vector.transfer_write %{{.+}}, %[[WRITE2]] - -// CHECK: vector.transfer_read %[[INIT_TENSOR]] -// CHECK: vector.transfer_read %[[INPUT_TENSOR]] -// CHECK: vector.fma -// CHECK: %[[WRITE4:.+]] = vector.transfer_write %{{.+}}, %[[WRITE3]] - -// CHECK: vector.transfer_read %[[INIT_TENSOR]] -// CHECK: vector.transfer_read %[[INPUT_TENSOR]] -// CHECK: vector.fma -// CHECK: %[[WRITE5:.+]] = vector.transfer_write %{{.+}}, %[[WRITE4]] - -// CHECK: vector.transfer_read %[[INIT_TENSOR]] -// CHECK: vector.transfer_read %[[INPUT_TENSOR]] -// CHECK: vector.fma -// CHECK: %[[WRITE6:.+]] = vector.transfer_write %{{.+}}, %[[WRITE5]] - -// CHECK: vector.transfer_read %[[INIT_TENSOR]] -// CHECK: vector.transfer_read %[[INPUT_TENSOR]] -// CHECK: vector.fma -// CHECK: %[[WRITE7:.+]] = vector.transfer_write %{{.+}}, %[[WRITE6]] diff --git a/compiler/src/iree/compiler/Codegen/Common/test/vectorize_tensor_pad.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vectorize_tensor_pad.mlir index b191b3768e92..8cd4d37ac362 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/vectorize_tensor_pad.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/vectorize_tensor_pad.mlir @@ -4,7 +4,7 @@ // https://reviews.llvm.org/D117021 // Once it lands, this pattern can be replaced. -func.func @pad_tensor(%source: tensor<1x?x?x3xf32>, %low1: index, %low2: index, %high1: index, %high2: index) -> tensor<1x2x2x3xf32> { +func.func @tensor_pad(%source: tensor<1x?x?x3xf32>, %low1: index, %low2: index, %high1: index, %high2: index) -> tensor<1x2x2x3xf32> { %cst = arith.constant 0.0 : f32 %pad = tensor.pad %source low[0, %low1, %low2, 0] high[0, %high1, %high2, 0] { ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index): @@ -13,7 +13,7 @@ func.func @pad_tensor(%source: tensor<1x?x?x3xf32>, %low1: index, %low2: index, return %pad: tensor<1x2x2x3xf32> } -// CHECK-LABEL: func.func @pad_tensor +// CHECK-LABEL: func.func @tensor_pad // CHECK-SAME: (%[[SOURCE:.+]]: tensor<1x?x?x3xf32>, %[[LOW1:.+]]: index, %[[LOW2:.+]]: index, %{{.+}}: index, %{{.+}}: index) // CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp index 61c306ccda41..9c34ba3314e9 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp @@ -75,13 +75,10 @@ namespace Codegen { TranslationInfoAttr TranslationInfoAttr::get( MLIRContext *context, DispatchLoweringPassPipeline passPipeline, - ArrayRef workloadPerWorkgroup, unsigned softwarePipelineDepth) { + unsigned softwarePipelineDepth) { auto pipelineAttr = DispatchLoweringPassPipelineAttr::get(context, passPipeline); - ArrayAttr workloadPerWorkgroupAttr = - getI64IntegerArrayAttr(context, workloadPerWorkgroup); - return get(context, pipelineAttr, workloadPerWorkgroupAttr, - softwarePipelineDepth); + return get(context, pipelineAttr, softwarePipelineDepth); } DispatchLoweringPassPipeline @@ -89,14 +86,10 @@ TranslationInfoAttr::getDispatchLoweringPassPipeline() { return getPassPipeline().getValue(); } -SmallVector TranslationInfoAttr::getWorkloadPerWorkgroupVals() { - return getIntegerVals(getWorkloadPerWorkgroup()); -} - LogicalResult TranslationInfoAttr::verify( function_ref emitError, IREE::Codegen::DispatchLoweringPassPipelineAttr passPipeline, - ArrayAttr workloadPerWorkgroup, unsigned softwarePipelineDepth) { + unsigned softwarePipelineDepth) { if (!passPipeline) { return emitError() << "missing pass pipeline specification"; } @@ -219,7 +212,6 @@ LogicalResult CompilationInfoAttr::verify( } if (failed(TranslationInfoAttr::verify( emitError, translationInfo.getPassPipeline(), - translationInfo.getWorkloadPerWorkgroup(), translationInfo.getSoftwarePipelineDepth()))) { return failure(); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h index ad7ca62b6391..cc284bd7a8ac 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h @@ -84,13 +84,12 @@ inline void setTranslationInfo( inline void setTranslationInfo( func::FuncOp entryPointFn, IREE::Codegen::DispatchLoweringPassPipeline passPipeline, - ArrayRef workloadPerWorkgroup, ArrayRef workgroupSize, - unsigned softwarePipelineDepth = 0) { + ArrayRef workgroupSize, unsigned softwarePipelineDepth = 0) { FailureOr exportOp = getEntryPoint(entryPointFn); MLIRContext *context = entryPointFn.getContext(); - auto translationInfo = IREE::Codegen::TranslationInfoAttr::get( - context, passPipeline, workloadPerWorkgroup); + auto translationInfo = + IREE::Codegen::TranslationInfoAttr::get(context, passPipeline); setTranslationInfo(*exportOp, translationInfo, workgroupSize); } @@ -124,7 +123,7 @@ inline LogicalResult setOpConfigAndEntryPointFnTranslation( auto config = IREE::Codegen::LoweringConfigAttr::get(context, tileSizes); setLoweringConfig(op, config); auto translationInfo = IREE::Codegen::TranslationInfoAttr::get( - entryPointFn->getContext(), passPipeline, {}, softwarePipelineDepth); + entryPointFn->getContext(), passPipeline, softwarePipelineDepth); setTranslationInfo(entryPointFn, translationInfo, workgroupSize); return success(); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td index f0fc15ffff2d..06f12ac70a58 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td @@ -95,37 +95,27 @@ def IREECodegen_TranslationInfoAttr : The fields are - `passPipeline` : The pass pipeline to use. - - `workloadPerWorkgroup` : Specifies how much of the original - `workload` is handled by a workgroup along `x`, `y` and `z`. If - left empty it implies that that there is a single workgroup that - does the entire `workload`. }]; let assemblyFormat = [{ - `<` `` $passPipeline (`workload_per_wg` `=` $workloadPerWorkgroup^)? + `<` `` $passPipeline (`pipeline_depth` `=` $softwarePipelineDepth^)? `>` }]; let parameters = (ins AttrParameter<"IREE::Codegen::DispatchLoweringPassPipelineAttr", "Name of the pipeline to be invoked on the translation unit.">:$passPipeline, - DefaultValuedParameter<"ArrayAttr", "ArrayAttr::get($_ctxt, {})", - "The workload mapped to a single workgroup">:$workloadPerWorkgroup, - DefaultValuedParameter<"unsigned", "1", + OptionalParameter<"unsigned", "The software pipeline depth to be used">:$softwarePipelineDepth ); let builders = [ AttrBuilder<(ins "DispatchLoweringPassPipeline":$passPipeline, - CArg<"ArrayRef", "{}">:$workloadPerWorkgroup, CArg<"unsigned", "0">:$softwarePipelineDepth)> ]; let extraClassDeclaration = [{ // Returns the lowering pass pipeline set. DispatchLoweringPassPipeline getDispatchLoweringPassPipeline(); - - // Returns values of the workloadPerWorkgroup field if set. - SmallVector getWorkloadPerWorkgroupVals(); }]; let genVerifyDecl = 1; } @@ -201,10 +191,9 @@ def IREECodegen_CompilationInfoAttr : specifies the behaviour of the compilation path chosen with `TranslationInfoAttr`. This could be added in the future. Note: Typically the values used for the first-level tiling in - `LoweringConfigAttr` and `workload_per_wg` value in the - `TranslationInfoAttr` are the same since the first-level of tile + - distribute is already done at the `Flow` level. This verification - is also a TODO. + `LoweringConfigAttr` value in the `TranslationInfoAttr` are the + same since the first-level of tile + distribute is already done + at the `Flow` level. This verification is also a TODO. }]; let parameters = (ins AttrParameter<"LoweringConfigAttr", "">:$loweringConfig, @@ -222,7 +211,7 @@ def IREECodegen_CompilationInfoAttr : let builders = [ AttrBuilder<(ins "LoweringConfigAttr":$configAttr, "TranslationInfoAttr":$translationInfo, - "ArrayRef":$workloadPerWorkgroup)>, + "ArrayRef":$workgroupSize)>, ]; let extraClassDeclaration = [{ SmallVector getWorkgroupSizeVals(); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/test/lowering_config_attr.mlir b/compiler/src/iree/compiler/Codegen/Dialect/test/lowering_config_attr.mlir index d08065842ead..8d9b6965e4ae 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/test/lowering_config_attr.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/test/lowering_config_attr.mlir @@ -1,15 +1,5 @@ // RUN: iree-opt --split-input-file %s | FileCheck %s -module { - func.func @test() attributes { - lowring_config = #iree_codegen.translation_info} { - return - } -} -// CHECK: #translation = #iree_codegen.translation_info - -// ----- - module { func.func @test() attributes { translation_info = #iree_codegen.translation_info} { diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD index ef0af0bd09d1..1346ee55d701 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD @@ -20,15 +20,18 @@ iree_compiler_cc_library( "LLVMCPUAArch64VectorLowering.cpp", "LLVMCPUCheckIRBeforeLLVMConversion.cpp", "LLVMCPUEmitVectorizationRemarks.cpp", + "LLVMCPULinkExecutables.cpp", "LLVMCPULowerExecutableTarget.cpp", "LLVMCPUSynchronizeSymbolVisibility.cpp", "LLVMCPUUnfuseFMAOps.cpp", "Passes.cpp", + "TargetMLTransformInfo.cpp", "VectorContractCustomKernels.cpp", "VerifyLinalgTransformLegality.cpp", ], hdrs = [ "KernelDispatch.h", + "TargetMLTransformInfo.h", ], deps = [ "//compiler/src/iree/compiler/Codegen:PassHeaders", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index f5aee1ad07b9..046094813ceb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -15,16 +15,19 @@ iree_cc_library( LLVMCPU HDRS "KernelDispatch.h" + "TargetMLTransformInfo.h" SRCS "ConvertToLLVM.cpp" "KernelDispatch.cpp" "LLVMCPUAArch64VectorLowering.cpp" "LLVMCPUCheckIRBeforeLLVMConversion.cpp" "LLVMCPUEmitVectorizationRemarks.cpp" + "LLVMCPULinkExecutables.cpp" "LLVMCPULowerExecutableTarget.cpp" "LLVMCPUSynchronizeSymbolVisibility.cpp" "LLVMCPUUnfuseFMAOps.cpp" "Passes.cpp" + "TargetMLTransformInfo.cpp" "VectorContractCustomKernels.cpp" "VerifyLinalgTransformLegality.cpp" DEPS diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp index c8dc24af00c3..88169132a562 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp @@ -687,7 +687,7 @@ class ConvertHALEntryPointFuncOp : public ConvertToLLVMPattern { IntegerType::get(rewriter.getContext(), 32), abiInputTypes); auto llvmFuncOp = rewriter.create( stdFuncOp.getLoc(), stdFuncOp.getName(), llvmFuncType, - LLVM::Linkage::Internal, /*dso_local=*/false, /*cconv*/ LLVM::CConv::C, + LLVM::Linkage::External, /*dso_local=*/false, /*cconv*/ LLVM::CConv::C, funcAttrs); rewriter.inlineRegionBefore(stdFuncOp.getBody(), llvmFuncOp.getBody(), llvmFuncOp.end()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp index fa1e7958445c..7051d9703a0b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp @@ -9,6 +9,9 @@ #include #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/compiler/Codegen/Common/LinalgOpInfo.h" +#include "iree/compiler/Codegen/Common/UserConfig.h" +#include "iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h" #include "iree/compiler/Codegen/Transforms/Transforms.h" #include "iree/compiler/Codegen/Utils/Utils.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" @@ -177,10 +180,12 @@ static int64_t getVectorSize(func::FuncOp entryPointFn, ShapedType shapedType) { // tile sizes for vectorization/unrolling in one shot. static SmallVector getMinTilingSizesForEachDim( func::FuncOp entryPointFn, linalg::LinalgOp op, - unsigned maxUnrollFactor = 8) { + const LinalgOpInfo &linalgOpInfo, + const TargetMLTransformInfo &targetMLTransInfo) { unsigned numLoops = op.getNumLoops(); SmallVector minTileSizes(numLoops, 1); auto inputOutputOpOperands = op.getInputAndOutputOperands(); + for (auto map : llvm::enumerate(op.getIndexingMapsArray())) { // Check the fastest varying dimension of the operand. Set the vector size // of the corresponding loop to the vector size. @@ -194,16 +199,36 @@ static SmallVector getMinTilingSizesForEachDim( auto operandType = inputOutputOpOperands[map.index()]->get().getType().cast(); int64_t tileSize = getVectorSize(entryPointFn, operandType); - // Vectorization of reductions is driven by input tensors and considering - // the output's fastest varying dim leads to large unroll factors. We limit - // the tile size for this case to 'maxUnrollFactor'. - if (op.isOutputTensor(inputOutputOpOperands[map.index()]) && - op.getNumReductionLoops() > 0) - tileSize = std::min(tileSize, maxUnrollFactor); minTileSizes[fastestVaryingDim] = std::max(minTileSizes[fastestVaryingDim], tileSize); } + + // Limit unroll factor. For now, we assume the rightmost non-one tiled + // dimension is for vectorization and any other non-one dimension is for + // unrolling. + auto limitUnrollFactor = [&](int64_t maxUnrollFactor) { + int vecDim; + for (vecDim = minTileSizes.size() - 1; vecDim >= 0; --vecDim) { + if (minTileSizes[vecDim] > 1) { + break; + } + } + for (int unrollDim = vecDim - 1; unrollDim >= 0; --unrollDim) { + minTileSizes[unrollDim] = + std::min(minTileSizes[unrollDim], maxUnrollFactor); + } + }; + + if (linalgOpInfo.isTranspose()) { + // Limit unrolling on transpose operations. + // TODO(dcaballe): Consider input and output transposes. + limitUnrollFactor(targetMLTransInfo.defaultMaxTransposeUnrollFactor); + } else { + // Limit unrolling to the default target maximum. + limitUnrollFactor(targetMLTransInfo.defaultMaxUnrollFactor); + } + return minTileSizes; } @@ -334,7 +359,7 @@ static int64_t getMaxTileSize(int64_t lb, int64_t ub, int64_t maxSize, return maxSize; } int64_t dim = ub - lb; - if (dim < vectorSizeVal) return dim; + if (dim <= maxSize && dim < vectorSizeVal) return dim; int64_t scaledUB = std::min(maxSize, dim) / vectorSizeVal * vectorSizeVal; for (int64_t i = scaledUB; i > 0; i -= vectorSizeVal) { @@ -990,7 +1015,9 @@ static bool isSupportedTransposeOp(linalg::GenericOp genericOp) { /// Sets the default lowering configuration for a generic op to use /// CPUDoubleTilingExpert pipeline. static LogicalResult setDefaultGenericOpRootConfig( - func::FuncOp entryPointFn, linalg::GenericOp genericOp) { + func::FuncOp entryPointFn, linalg::GenericOp genericOp, + const LinalgOpInfo &linalgOpInfo, + const TargetMLTransformInfo &targetMLTransInfo) { if (getLoweringConfig(genericOp)) { return success(); } @@ -1003,8 +1030,8 @@ static LogicalResult setDefaultGenericOpRootConfig( DispatchLoweringPassPipeline::CPUDefault); } - SmallVector minTileSizes = - getMinTilingSizesForEachDim(entryPointFn, genericOp); + SmallVector minTileSizes = getMinTilingSizesForEachDim( + entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo); // For generic ops we'll use the default divided by 2 to control the stack // allocation limit See #9469 for example. SmallVector maxTileSizes(numLoops, defaultWorkgroupTileSize / 2); @@ -1047,8 +1074,10 @@ static LogicalResult setDefaultGenericOpRootConfig( /// Sets the lowering configuration for a generic op implementing a /// transposition to use CPUDoubleTilingExpert pipeline. -static LogicalResult setTransposeLikeOpRootConfig(func::FuncOp entryPointFn, - linalg::GenericOp genericOp) { +static LogicalResult setTransposeLikeOpRootConfig( + func::FuncOp entryPointFn, linalg::GenericOp genericOp, + const LinalgOpInfo &linalgOpInfo, + const TargetMLTransformInfo &targetMLTransInfo) { if (getLoweringConfig(genericOp)) { return success(); } @@ -1060,8 +1089,8 @@ static LogicalResult setTransposeLikeOpRootConfig(func::FuncOp entryPointFn, } unsigned numLoops = genericOp.getNumLoops(); - SmallVector minTileSizes = - getMinTilingSizesForEachDim(entryPointFn, genericOp); + SmallVector minTileSizes = getMinTilingSizesForEachDim( + entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo); SmallVector maxTileSizes(numLoops, defaultWorkgroupTileSize); if (llvm::all_of(minTileSizes, [](int64_t vs) { return vs == 1; })) { // Nothing to vectorize just lower to loops. @@ -1116,7 +1145,9 @@ static LogicalResult setTransposeLikeOpRootConfig(func::FuncOp entryPointFn, /// workload per workgroup to a larger number, which prevents runtime overheads /// from tiny dispatches. static LogicalResult setElementwiseGenericOpRootConfig( - func::FuncOp entryPointFn, linalg::GenericOp genericOp) { + func::FuncOp entryPointFn, linalg::GenericOp genericOp, + const LinalgOpInfo &linalgOpInfo, + const TargetMLTransformInfo &targetMLTransInfo) { if (getLoweringConfig(genericOp)) { return success(); } @@ -1126,8 +1157,8 @@ static LogicalResult setElementwiseGenericOpRootConfig( if (!linalg::isElementwise(genericOp)) return success(); // Set the flow level tiling to the default. - SmallVector minTileSizes = - getMinTilingSizesForEachDim(entryPointFn, genericOp); + SmallVector minTileSizes = getMinTilingSizesForEachDim( + entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo); SmallVector maxTileSizes(numLoops, defaultWorkgroupTileSize); SmallVector flowTileSizes = getDefaultDistributedLevelTileSizes(genericOp, minTileSizes, maxTileSizes, @@ -1193,24 +1224,28 @@ static LogicalResult setElementwiseGenericOpRootConfig( /// Sets the lowering configuration for a generic op to use /// CPUDoubleTilingExpert pipeline. -static LogicalResult setRootConfig(func::FuncOp entryPointFn, - linalg::GenericOp genericOp) { - if (failed(setTransposeLikeOpRootConfig(entryPointFn, genericOp)) || - failed(setElementwiseGenericOpRootConfig(entryPointFn, genericOp)) || - failed(setDefaultGenericOpRootConfig(entryPointFn, genericOp))) { +static LogicalResult setRootConfig( + func::FuncOp entryPointFn, linalg::GenericOp genericOp, + const LinalgOpInfo &linalgOpInfo, + const TargetMLTransformInfo &targetMLTransInfo) { + if (failed(setTransposeLikeOpRootConfig(entryPointFn, genericOp, linalgOpInfo, + targetMLTransInfo)) || + failed(setElementwiseGenericOpRootConfig( + entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo)) || + failed(setDefaultGenericOpRootConfig(entryPointFn, genericOp, + linalgOpInfo, targetMLTransInfo))) { return failure(); } return success(); } -/// Sets the lowering configuration for linalg.conv_2d_nhwc_hwcf and -/// linalg.depthwise_conv_2d_nhwc_hwc operations. +/// Sets lowering configuration for conv ops. See below for supported conv ops. static LogicalResult setConvRootConfig(func::FuncOp entryPointFn, linalg::LinalgOp convOp, ArrayRef targetTileSizes, int64_t vectorSize) { - if (!isa( - convOp.getOperation())) { + if (!isa(convOp.getOperation())) { return failure(); } @@ -1262,27 +1297,46 @@ static LogicalResult setConvRootConfig(func::FuncOp entryPointFn, static SmallVector getConvWorkgroupSizes(func::FuncOp entryPointFn, linalg::LinalgOp op, int64_t vectorSize) { - bool isSupported = - isa( - op.getOperation()); + bool isSupported = isa(op.getOperation()); (void)isSupported; - assert(isSupported && "expected conv with nhwc input and hwcf kernel/filter"); + assert(isSupported && "conv op is not supported"); SmallVector tileSizes; auto variantOp = getExecutableVariantOp(entryPointFn); assert(succeeded(variantOp) && "ExecutableVariantOp not found"); - if (isX86(*variantOp) || isRISCV(*variantOp)) { - tileSizes = {1, 1, 8, vectorSize * 2, 1, 1, 8}; - } - - if (isAArch64(*variantOp)) { - tileSizes = {1, 1, 32, 64, 1, 1, 16}; - } - - // Get default hard-coded tile sizes if we couldn't compute anything better. - if (tileSizes.empty()) { - tileSizes = {1, 1, vectorSize, vectorSize, 1, 1, vectorSize}; + if (isX86(*variantOp)) { + TypeSwitch(op.getOperation()) + .Case( + [&](auto op) { tileSizes = {1, 1, 8, vectorSize * 2, 1, 1, 8}; }) + .Case( + [&](auto op) { tileSizes = {1, 1, 8, vectorSize * 2, 1, 3}; }) + .Default([&](Operation *op) { llvm_unreachable("unsupported conv"); }); + } else if (isRISCV(*variantOp)) { + TypeSwitch(op.getOperation()) + .Case( + [&](auto op) { tileSizes = {1, 1, 8, vectorSize * 2, 1, 1, 8}; }) + .Case( + [&](auto op) { tileSizes = {1, 1, 8, vectorSize, 1, 3}; }) + .Default([&](Operation *op) { llvm_unreachable("unsupported conv"); }); + } else if (isAArch64(*variantOp)) { + TypeSwitch(op.getOperation()) + .Case( + [&](auto op) { tileSizes = {1, 1, 32, 64, 1, 1, 16}; }) + .Case( + [&](auto op) { tileSizes = {1, 1, 4, 4, 1, 4}; }) + .Default([&](Operation *op) { llvm_unreachable("unsupported conv"); }); + } else { + // Get default hard-coded tile sizes if we couldn't compute anything better. + TypeSwitch(op.getOperation()) + .Case([&](auto op) { + tileSizes = {1, 1, vectorSize, vectorSize, 1, 1, vectorSize}; + }) + .Case([&](auto op) { + tileSizes = {1, 1, vectorSize, vectorSize, 1, vectorSize}; + }) + .Default([&](Operation *op) { llvm_unreachable("unsupported conv"); }); } return tileSizes; @@ -1297,13 +1351,24 @@ static LogicalResult setRootConfig(func::FuncOp entryPointFn, return setConvRootConfig(entryPointFn, convOp, targetTileSizes, vectorSize); } +/// Sets the lowering configuration for linalg.conv_2d_nchw_fchw +/// operations. +static LogicalResult setRootConfig(func::FuncOp entryPointFn, + linalg::Conv2DNchwFchwOp convOp) { + int64_t vectorSize = + getVectorSize(entryPointFn, convOp.getResult(0).getType()); + SmallVector targetTileSizes = {1, vectorSize * 2, 1, 8, 8, 1, 1}; + return setConvRootConfig(entryPointFn, convOp, targetTileSizes, vectorSize); +} + /// Sets the lowering configuration for linalg.depthwise_conv_2d_nhwc_hwc /// operations. static LogicalResult setRootConfig(func::FuncOp entryPointFn, linalg::DepthwiseConv2DNhwcHwcOp convOp) { int64_t vectorSize = getVectorSize(entryPointFn, convOp.getResult(0).getType()); - SmallVector targetTileSizes = {1, 1, 8, vectorSize * 2, 1, 3}; + SmallVector targetTileSizes = + getConvWorkgroupSizes(entryPointFn, convOp, vectorSize); return setConvRootConfig(entryPointFn, convOp, targetTileSizes, vectorSize); } @@ -1356,16 +1421,21 @@ static LogicalResult setRootConfig( } /// Redirects to methods that set the configuration based on operation type. -static LogicalResult setRootConfigImpl(func::FuncOp entryPointFn, - Operation *op) { +static LogicalResult setRootConfigImpl( + func::FuncOp entryPointFn, Operation *op, + const TargetMLTransformInfo &targetMLTransInfo) { // Do not overwrite default configuration. if (getLoweringConfig(op)) return success(); // Redirect to individual operations. auto setRootConfigFn = [&](Operation *op) -> LogicalResult { return TypeSwitch(op) - .Case( + .Case([&](auto op) { + return setRootConfig(entryPointFn, op, LinalgOpInfo(op), + targetMLTransInfo); + }) + .Case( [&](auto op) { return setRootConfig(entryPointFn, op); }) .Case( [&](auto op) { return setRootConfig(entryPointFn, op); }) @@ -1451,7 +1521,10 @@ static LogicalResult setRootConfig(func::FuncOp entryPointFn, return failure(); } } else { - if (failed(setRootConfigImpl(entryPointFn, rootOperation))) { + auto targetMLTransInfo = + TargetMLTransformInfo::getTargetMLTransformInfo(*variantOp); + if (failed(setRootConfigImpl(entryPointFn, rootOperation, + targetMLTransInfo))) { return failure(); } } @@ -1460,7 +1533,6 @@ static LogicalResult setRootConfig(func::FuncOp entryPointFn, if (!getTranslationInfo(entryPointFn)) { // Fall back, just set the translation to CPUDefault. setTranslationInfo(entryPointFn, DispatchLoweringPassPipeline::CPUDefault, - /*workloadPerWorkgroup=*/ArrayRef{}, /*workgroupSize=*/ArrayRef{}); } @@ -1474,19 +1546,8 @@ static LogicalResult setTranslationInfoAndRootConfig( for (auto computeOp : computeOps) { if (IREE::Codegen::CompilationInfoAttr compilationInfo = getCompilationInfo(computeOp)) { - // If the function already has a translation, error out. - if (auto translationInfo = getTranslationInfo(entryPointFn)) { - return computeOp->emitOpError( - "multiple ops within dispatch trying to set the translation " - "info"); - } - - SmallVector workgroupSize = - compilationInfo.getWorkgroupSizeVals(); - setTranslationInfo(entryPointFn, compilationInfo.getTranslationInfo(), - workgroupSize); - setLoweringConfig(computeOp, compilationInfo.getLoweringConfig()); - eraseCompilationInfo(computeOp); + if (failed(setUserConfig(entryPointFn, computeOp, compilationInfo))) + return failure(); } } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp new file mode 100644 index 000000000000..5944ef02900f --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp @@ -0,0 +1,73 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/PassDetail.h" +#include "iree/compiler/Codegen/Passes.h" +#include "iree/compiler/Codegen/Utils/LinkingUtils.h" +#include "iree/compiler/Utils/ModuleUtils.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +struct LLVMCPULinkExecutablesPass + : public LLVMCPULinkExecutablesBase { + LLVMCPULinkExecutablesPass() = default; + void runOnOperation() override { + auto moduleOp = getOperation(); + auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); + + auto sourceExecutableOps = + llvm::to_vector<8>(moduleOp.getOps()); + if (sourceExecutableOps.size() <= 1) return; + + // Guess a module name, if needed, to make the output files readable. + auto moduleName = guessModuleName(moduleOp, "llvm_module"); + + // Create our new "linked" hal.executable. + std::string linkedExecutableName = + llvm::formatv("{0}_linked_{1}", moduleName, "llvm_cpu"); + auto linkedExecutableOp = moduleBuilder.create( + moduleOp.getLoc(), linkedExecutableName); + linkedExecutableOp.setVisibility( + sourceExecutableOps.front().getVisibility()); + auto executableBuilder = + OpBuilder::atBlockBegin(&linkedExecutableOp.getBlock()); + + // Gather all unique executable targets - we may have multiple. + auto executableTargetAttrs = gatherExecutableTargets(sourceExecutableOps); + for (auto executableTargetAttr : executableTargetAttrs) { + // Add our hal.executable.variant with an empty module. + auto linkedTargetOp = + executableBuilder.create( + moduleOp.getLoc(), executableTargetAttr.getSymbolNameFragment(), + executableTargetAttr); + auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock()); + targetBuilder.create(moduleOp.getLoc()); + + // Try linking together all executables in moduleOp. + if (failed(linkExecutablesInto( + moduleOp, sourceExecutableOps, linkedExecutableOp, linkedTargetOp, + [](mlir::ModuleOp moduleOp) { return moduleOp; }, + targetBuilder))) { + return signalPassFailure(); + } + } + } +}; + +} // namespace + +std::unique_ptr> +createLLVMCPULinkExecutablesPass() { + return std::make_unique(); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index 81ab44bb3837..b1da2e26a9c8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -151,17 +151,6 @@ LogicalResult verifyDoubleTilingExpertPassPipelineConfig( CPUDoubleTilingPadExpert); } - // Verify that the workload per workgroup is not set. - // TODO(ravishankarm): Remove workload_per_wg eventually. - SmallVector workloadPerWorkgroup = - translationInfo.getWorkloadPerWorkgroupVals(); - if (!workloadPerWorkgroup.empty()) { - return op->emitOpError( - "workload_per_wg expected to be empty since its internal " - "compiler implementation detail") - << kNumMaxParallelDims; - } - if (loweringConfig.getTileSizes().size() != static_cast(StrategyTilingLevel::NumStrategyTileLevels)) { return op->emitOpError("expected three tiling sizes, got ") @@ -262,6 +251,14 @@ LogicalResult verifyConvTileAndDecomposeExpertConfig( owSize = shape[2]; return success(); }) + .Case([&](auto) { + // Shape: N, OC, OH, OW, (IC), KH, KW + khSize = shape[5]; + kwSize = shape[6]; + ohSize = shape[2]; + owSize = shape[3]; + return success(); + }) .Default([&](auto) { return failure(); }); if (failed(isSizeExtracted)) { return op->emitOpError("unsupported conv types"); @@ -651,5 +648,11 @@ void buildLLVMCPUCodegenPassPipeline(OpPassManager &passManager) { }); } +// NOTE: this runs on the top-level program module containing all +// hal.executable ops. +void buildLLVMCPULinkingPassPipeline(OpPassManager &passManager) { + passManager.addPass(createLLVMCPULinkExecutablesPass()); +} + } // namespace iree_compiler } // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.cpp new file mode 100644 index 000000000000..bf26b161a615 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.cpp @@ -0,0 +1,38 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h" + +#include "iree/compiler/Codegen/Utils/Utils.h" + +using namespace mlir; +using namespace mlir::iree_compiler; + +namespace { + +struct RISCVTargetMLTransformInfo : TargetMLTransformInfo { + RISCVTargetMLTransformInfo() { + defaultMaxUnrollFactor = 8; + defaultMaxTransposeUnrollFactor = 1; + } +}; + +} // namespace + +namespace mlir { +namespace iree_compiler { + +const TargetMLTransformInfo TargetMLTransformInfo::getTargetMLTransformInfo( + IREE::HAL::ExecutableVariantOp variantOp) { + if (isRISCV(variantOp)) { + return RISCVTargetMLTransformInfo(); + } + + return TargetMLTransformInfo(); +}; + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h new file mode 100644 index 000000000000..bbdf4d3f93fa --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h @@ -0,0 +1,31 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_LLVMCPU_TARGETMLTRANSFORMINFO_H_ +#define IREE_COMPILER_CODEGEN_LLVMCPU_TARGETMLTRANSFORMINFO_H_ + +#include + +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" + +namespace mlir { +namespace iree_compiler { + +/// Holds target specific information to specialize ML transformations. +// TODO(dcaballe): Move to a Concept-Model implementation when it's worth it. +struct TargetMLTransformInfo { + unsigned defaultMaxUnrollFactor = 8; + unsigned defaultMaxTransposeUnrollFactor = + std::numeric_limits::max(); + + static const TargetMLTransformInfo getTargetMLTransformInfo( + IREE::HAL::ExecutableVariantOp variantOp); +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_CODEGEN_LLVMCPU_TARGETMLTRANSFORMINFO_H_ diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp index 53018303eb55..514b070acf50 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp @@ -47,9 +47,9 @@ static bool isMatrixTimesMatrixTransposed(vector::ContractionOp contractionOp) { SmallVector parallelIterators; SmallVector reductionIterators; for (int i = 0; i < 3; i++) { - if (isParallelIterator(iteratorTypes[i])) { + if (vector::isParallelIterator(iteratorTypes[i])) { parallelIterators.push_back(i); - } else if (isReductionIterator(iteratorTypes[i])) { + } else if (vector::isReductionIterator(iteratorTypes[i])) { reductionIterators.push_back(i); } else { return false; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/apply_scale_lowering.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/apply_scale_lowering.mlir index b53a58dd31f9..835008dea7e1 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/apply_scale_lowering.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/apply_scale_lowering.mlir @@ -13,7 +13,7 @@ target_triple = "riscv64-unknown-unknown-eabi-elf" }> #map = affine_map<()[s0] -> (s0 ceildiv 2)> -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info hal.executable private @apply_scale_no_vector_feature { hal.executable.variant public @embedded_elf_riscv_64, target = #executable_target_embedded_elf_riscv_64_ { hal.executable.export public @apply_scale_no_vector_feature ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} { @@ -41,7 +41,7 @@ hal.executable private @apply_scale_no_vector_feature { // 64-bit lowering is used by default if no vector features are provided. // TODO(diegocaballero): We shouldn't vectorize the code if no vector features // are provided. -// CHECK-LABEL: llvm.func internal @apply_scale_no_vector_feature +// CHECK-LABEL: llvm.func @apply_scale_no_vector_feature // CHECK: %[[ADD:.*]] = llvm.add %{{.*}}, %{{.*}} : vector<2xi64> // CHECK-NEXT: %[[SHR:.*]] = llvm.ashr %[[ADD]], %{{.*}} : vector<2xi64> // CHECK-NEXT: llvm.trunc %[[SHR]] : vector<2xi64> to vector<2xi32> @@ -61,7 +61,7 @@ hal.executable private @apply_scale_no_vector_feature { target_triple = "riscv64-unknown-unknown-eabi-elf" }> #map = affine_map<()[s0] -> (s0 ceildiv 2)> -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info hal.executable private @apply_scale_v { hal.executable.variant public @embedded_elf_riscv_64, target = #executable_target_embedded_elf_riscv_64_ { hal.executable.export public @apply_scale_v ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} { @@ -87,7 +87,7 @@ hal.executable private @apply_scale_v { } // 64-bit lowering is used with '+v'. -// CHECK-LABEL: llvm.func internal @apply_scale_v +// CHECK-LABEL: llvm.func @apply_scale_v // CHECK: %[[ADD:.*]] = llvm.add %{{.*}}, %{{.*}} : vector<2xi64> // CHECK-NEXT: %[[SHR:.*]] = llvm.ashr %[[ADD]], %{{.*}} : vector<2xi64> // CHECK-NEXT: llvm.trunc %[[SHR]] : vector<2xi64> to vector<2xi32> @@ -107,7 +107,7 @@ hal.executable private @apply_scale_v { target_triple = "riscv64-unknown-unknown-eabi-elf" }> #map = affine_map<()[s0] -> (s0 ceildiv 2)> -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info hal.executable private @apply_scale_zve64x { hal.executable.variant public @embedded_elf_riscv_64, target = #executable_target_embedded_elf_riscv_64_ { hal.executable.export public @apply_scale_zve64x ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} { @@ -133,7 +133,7 @@ hal.executable private @apply_scale_zve64x { } // 64-bit lowering is used with '+zve64x'. -// CHECK-LABEL: llvm.func internal @apply_scale_zve64x +// CHECK-LABEL: llvm.func @apply_scale_zve64x // CHECK: %[[ADD:.*]] = llvm.add %{{.*}}, %{{.*}} : vector<2xi64> // CHECK-NEXT: %[[SHR:.*]] = llvm.ashr %[[ADD]], %{{.*}} : vector<2xi64> // CHECK-NEXT: llvm.trunc %[[SHR]] : vector<2xi64> to vector<2xi32> @@ -153,7 +153,7 @@ hal.executable private @apply_scale_zve64x { target_triple = "riscv64-unknown-unknown-eabi-elf" }> #map = affine_map<()[s0] -> (s0 ceildiv 2)> -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info hal.executable private @apply_scale_zve32x { hal.executable.variant public @embedded_elf_riscv_64, target = #executable_target_embedded_elf_riscv_64_ { hal.executable.export public @apply_scale_zve32x ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} { @@ -181,7 +181,7 @@ hal.executable private @apply_scale_zve32x { // 32-bit lowering is used with '+zve32x'. Note that the 32-bit lowering // generates 64-bit mul operations that are decomposed into 32-bit operations by // the LLVM backend. -// CHECK-LABEL: llvm.func internal @apply_scale_zve32x +// CHECK-LABEL: llvm.func @apply_scale_zve32x // CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %{{.*}} : vector<2xi64> // CHECK-NEXT: %[[SHR:.*]] = llvm.lshr %{{.*}}, %{{.*}} : vector<2xi64> // CHECK-NEXT: llvm.trunc %[[SHR]] : vector<2xi64> to vector<2xi32> diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/convert_to_llvm.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/convert_to_llvm.mlir index 478d76e32ada..9e02499ac972 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/convert_to_llvm.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/convert_to_llvm.mlir @@ -7,7 +7,7 @@ builtin.module { } } // CHECK: llvm.func @extern_public() -// CHECK: llvm.func internal @entry_point( +// CHECK: llvm.func @entry_point( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !llvm.ptr> {llvm.align = 16 : i64, llvm.noalias} // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !llvm.ptr> {llvm.align = 16 : i64, llvm.noalias} // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !llvm.ptr> {llvm.align = 16 : i64, llvm.noalias}) -> i32 diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir index b2b28729211b..7a50bad790d5 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir @@ -2,7 +2,7 @@ llvm.func @sink(f32) -// CHECK-LABEL: llvm.func internal @binding_ptrs +// CHECK-LABEL: llvm.func @binding_ptrs func.func @binding_ptrs() { // CHECK-DAG: %[[C72:.+]] = llvm.mlir.constant(72 : index) : i64 %c72 = arith.constant 72 : index diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/hal_interface_constants.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/hal_interface_constants.mlir index ba10ff803e27..739b7a6d53fd 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/hal_interface_constants.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/hal_interface_constants.mlir @@ -2,7 +2,7 @@ llvm.func @sink(i64) -// CHECK-LABEL: llvm.func internal @constant_values +// CHECK-LABEL: llvm.func @constant_values func.func @constant_values() { // CHECK: %[[STATE:.+]] = llvm.load %arg1 : !llvm.ptr +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @restrict_num_workgroups // CHECK-SAME: translation_info = #[[TRANSLATION]] diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_riscv_launch_configuration.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_riscv_launch_configuration.mlir index d2a7ebc3486b..eb996dc8df0f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_riscv_launch_configuration.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_riscv_launch_configuration.mlir @@ -45,3 +45,56 @@ hal.executable private @matmul_riscv { // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK: linalg.matmul // CHECK-SAME: lowering_config = #[[CONFIG]] + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> +hal.executable private @thin_depthwise_conv_static { + hal.executable.variant public @embedded_elf_x86_64, target = #hal.executable.target< + "llvm-cpu", + "embedded-elf-riscv_32", { + cpu_features = "+m,+f", + data_layout = "e-m:e-p:32:32-i64:64-n32-S128", + native_vector_size = 0 : index, + target_triple = "riscv32-unknown-unknown-eabi-elf" + }> { + hal.executable.export public @thin_depthwise_conv_static layout(#pipeline_layout) + builtin.module { + func.func @thin_depthwise_conv_static() { + %cst = arith.constant 0.0 : f32 + %input_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) + : !flow.dispatch.tensor + %filter_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) + : !flow.dispatch.tensor + %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) + : !flow.dispatch.tensor + %input = flow.dispatch.tensor.load %input_binding, offsets = [0, 0, 0, 0], sizes = [1, 161, 161, 240], strides = [1, 1, 1, 1] + : !flow.dispatch.tensor -> tensor<1x57x57x72xf32> + %filter = flow.dispatch.tensor.load %filter_binding, offsets = [0, 0, 0], sizes = [3, 3, 240], strides = [1, 1, 1] + : !flow.dispatch.tensor -> tensor<3x3x72xf32> + %init = linalg.init_tensor [1, 28, 28, 72] : tensor<1x28x28x72xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x28x28x72xf32>) -> tensor<1x28x28x72xf32> + %conv = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins(%input, %filter : tensor<1x57x57x72xf32>, tensor<3x3x72xf32>) + outs(%fill : tensor<1x28x28x72xf32>) -> tensor<1x28x28x72xf32> + + flow.dispatch.tensor.store %conv, %result_binding, offsets = [0, 0, 0, 0], sizes = [1, 28, 28, 72], strides = [1, 1, 1, 1] + : tensor<1x28x28x72xf32> -> !flow.dispatch.tensor + return + } + } + } +} +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK: hal.executable.export public @thin_depthwise_conv_static +// CHECK-SAME: translation_info = #[[TRANSLATION]] +// CHECK: linalg.depthwise_conv_2d_nhwc_hwc +// CHECK-SAME: lowering_config = #[[CONFIG]] + diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir index b5fd7382a184..d00a17d87f49 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir @@ -653,6 +653,44 @@ hal.executable private @conv_static { // CHECK: linalg.conv_2d_nhwc_hwcf +// ----- + +#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}> +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> +hal.executable private @conv_nchw_static { + hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ { + hal.executable.export public @conv_nchw_static ordinal(0) layout(#pipeline_layout) + builtin.module { + func.func @conv_nchw_static() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 128, 30, 30], strides = [1, 1, 1, 1] : !flow.dispatch.tensor -> tensor<1x128x30x30xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [128, 128, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor -> tensor<128x128x3x3xf32> + %5 = linalg.init_tensor [1, 128, 28, 28] : tensor<1x128x28x28xf32> + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x128x28x28xf32>) -> tensor<1x128x28x28xf32> + %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<1x128x30x30xf32>, tensor<128x128x3x3xf32>) outs(%6 : tensor<1x128x28x28xf32>) -> tensor<1x128x28x28xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [1, 128, 28, 28], strides = [1, 1, 1, 1] : tensor<1x128x28x28xf32> -> !flow.dispatch.tensor + return + } + } + } +} + +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK: hal.executable.export public @conv_nchw_static +// CHECK-SAME: translation_info = #[[TRANSLATION]] +// CHECK: linalg.conv_2d_nchw_fchw + // ----- #pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> +hal.executable private @thin_depthwise_conv_static { + hal.executable.variant public @system_elf_x86_64, target = <"llvm-cpu", "system-elf-x86_64", { + data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", + native_vector_size = 64 : index, + target_triple = "x86_64-unknown-linux-gnu" + }> { + hal.executable.export public @thin_depthwise_conv_static layout(#pipeline_layout) + builtin.module { + func.func @thin_depthwise_conv_static() { + %cst = arith.constant 0.0 : f32 + %input_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) + : !flow.dispatch.tensor + %filter_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) + : !flow.dispatch.tensor + %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) + : !flow.dispatch.tensor + %input = flow.dispatch.tensor.load %input_binding, offsets = [0, 0, 0, 0], sizes = [1, 161, 161, 240], strides = [1, 1, 1, 1] + : !flow.dispatch.tensor -> tensor<1x57x57x72xf32> + %filter = flow.dispatch.tensor.load %filter_binding, offsets = [0, 0, 0], sizes = [3, 3, 240], strides = [1, 1, 1] + : !flow.dispatch.tensor -> tensor<3x3x72xf32> + %init = linalg.init_tensor [1, 28, 28, 72] : tensor<1x28x28x72xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x28x28x72xf32>) -> tensor<1x28x28x72xf32> + %conv = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins(%input, %filter : tensor<1x57x57x72xf32>, tensor<3x3x72xf32>) + outs(%fill : tensor<1x28x28x72xf32>) -> tensor<1x28x28x72xf32> + + flow.dispatch.tensor.store %conv, %result_binding, offsets = [0, 0, 0, 0], sizes = [1, 28, 28, 72], strides = [1, 1, 1, 1] + : tensor<1x28x28x72xf32> -> !flow.dispatch.tensor + return + } + } + } +} +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK: hal.executable.export public @thin_depthwise_conv_static +// CHECK-SAME: translation_info = #[[TRANSLATION]] +// CHECK: linalg.depthwise_conv_2d_nhwc_hwc +// CHECK-SAME: lowering_config = #[[CONFIG]] + +// ----- + #pipeline_layout = #hal.pipeline.layout, diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD index e0ecbe07138a..8501a1340ded 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD @@ -36,6 +36,7 @@ iree_compiler_cc_library( "ConvertToLLVM.h", "KernelConfig.h", "TilingUtils.h", + "TransposeUtils.h", ], deps = [ "//compiler/src/iree/compiler/Codegen:PassHeaders", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index 7dc467a590a2..9b04e5d12b93 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -17,6 +17,7 @@ iree_cc_library( "ConvertToLLVM.h" "KernelConfig.h" "TilingUtils.h" + "TransposeUtils.h" SRCS "ConvertToLLVM.cpp" "ConvertToNVVM.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index 9ded7eada797..99bfd92cc5cc 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -9,7 +9,10 @@ #include #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/compiler/Codegen/Common/LinalgOpInfo.h" +#include "iree/compiler/Codegen/Common/UserConfig.h" #include "iree/compiler/Codegen/Dialect/LoweringConfig.h" +#include "iree/compiler/Codegen/LLVMGPU/TransposeUtils.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -133,20 +136,6 @@ static bool supportsTensorCore(func::FuncOp entryPoint, linalg::LinalgOp op) { if (outputMap.getResult(i) != b.getAffineDimExpr(i)) return false; } } - // Check that we support converting any fused operation. When using the - // tensorcore pipeline we need to be sure we can generate MMA ops otherwise - // the code will be highly inneficent. - bool fusedOpSupported = true; - entryPoint.walk([&fusedOpSupported](linalg::GenericOp linalgOp) { - for (Operation &fusedOp : linalgOp.getOps()) { - if (!isa(fusedOp)) { - fusedOpSupported = false; - break; - } - } - }); - if (!fusedOpSupported) return false; return true; } @@ -439,25 +428,6 @@ static LogicalResult setRootDefaultConfig(func::FuncOp entryPoint, passPipeline, workgroupSize); } -/// Propagate the configuration annotated in the incoming IR. -static LogicalResult setUserConfig( - func::FuncOp entryPointFn, Operation *computeOp, - IREE::Codegen::CompilationInfoAttr compilationInfo) { - if (auto translationInfo = getTranslationInfo(entryPointFn)) { - return computeOp->emitOpError( - "multiple ops within dispatch trying to set the translation " - "info"); - } - - SmallVector workgroupSize = compilationInfo.getWorkgroupSizeVals(); - setTranslationInfo(entryPointFn, compilationInfo.getTranslationInfo(), - workgroupSize); - - setLoweringConfig(computeOp, compilationInfo.getLoweringConfig()); - eraseCompilationInfo(computeOp); - return success(); -} - /// Return the size of the given dimension in the linalg op. // TODO: this should be part of LinalgOp interface, the equivalent member // function currently only support the case where all the dimensions are static @@ -529,65 +499,50 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, return success(); } -/// Returns true if the index map represents a transpose that benefits from -/// shared mem. Currently supports 2D transposes. -static bool isSharedMemTranspose(AffineMap indexMap) { - if (!indexMap.isEmpty() && indexMap.isPermutation() && - indexMap.getNumInputs() == 2) { - // Ensure that the fasted moving dimension (the last one) is permuted, - // Otherwise shared memory promotion will not benefit the operation. - if (indexMap.getDimPosition(indexMap.getNumDims() - 1) != - indexMap.getNumDims() - 1) { - return true; - } - } - return false; +static bool hasTwoOrThreeLoopsInfo(linalg::LinalgOp linalgOp) { + return linalgOp.getNumParallelLoops() >= 2 && + linalgOp.getNumParallelLoops() <= 3; } -/// Returns true if the operation is a GenericOp implementing a 2D transpose. -static bool isTransposeOp(linalg::LinalgOp linalgOp) { - if (!isa(linalgOp)) return false; - // Check that the op has at least 2 parallel loops. - if (linalgOp.getNumParallelLoops() < 2) { - return false; - } +static LogicalResult setTransposeConfig(func::FuncOp entryPoint, + linalg::LinalgOp linalgOp) { + LinalgOpInfo opInfo(linalgOp, sharedMemTransposeFilter); - // Check that all the iterators are parallel. - if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) { - return false; + // Checks preconditions for shared mem transpose. + if (!opInfo.isTranspose() || opInfo.isDynamic() || opInfo.isReduction() || + !isa(linalgOp) || !hasTwoOrThreeLoopsInfo(linalgOp)) { + return failure(); } - // Only transpose static shapes - if (linalgOp.hasDynamicShape()) { - return false; - } + ArrayRef transposedOperands = opInfo.getTransposeOperands(); - // Check that at least one input operands is transposed. - bool hasPermutation = false; - for (auto indexMap : linalgOp.getIndexingMapsArray()) { - if (isSharedMemTranspose(indexMap)) { - hasPermutation = true; + // Determine the fastest moving dimensions for the source/destination indices + // of each transpose. These inform the tile sizes. + int64_t outputFastestDim = linalgOp.getNumLoops() - 1; + int64_t inputFastestDim = linalgOp.getTiedIndexingMap(transposedOperands[0]) + .getDimPosition(outputFastestDim); + // Ensure the other transposed operands match + for (int i = 1; i < transposedOperands.size(); ++i) { + if (inputFastestDim != linalgOp.getTiedIndexingMap(transposedOperands[i]) + .getDimPosition(outputFastestDim)) { + return failure(); } } - return hasPermutation; -} -static LogicalResult setTransposeConfig(func::FuncOp entryPoint, - Operation *op) { int32_t tileM = 32; int32_t tileN = 32; TileSizesListType tileSizes; - tileSizes.push_back({tileM, tileN}); - - // Check alignment with tile size for each transpose. - if (auto genericOp = dyn_cast(op)) { - auto loopRanges = genericOp.getStaticLoopRanges(); - for (auto loopRange : loopRanges) { - if (loopRange % 32 != 0) { - return failure(); - } - } - } else { + // Set all tile sizes to 1 except for fastest moving dimensions. + SmallVector tileSizesTemp(linalgOp.getNumLoops(), 1); + tileSizesTemp[outputFastestDim] = 32; + tileSizesTemp[inputFastestDim] = 32; + tileSizes.push_back(tileSizesTemp); + + // Check alignment with tile size for each transpose. Only the fastest moving + // dims need to match the transpose tile. + auto loopRanges = linalgOp.getStaticLoopRanges(); + if (loopRanges[outputFastestDim] % tileM != 0 || + loopRanges[inputFastestDim] % tileN != 0) { return failure(); } @@ -597,11 +552,103 @@ static LogicalResult setTransposeConfig(func::FuncOp entryPoint, std::array workgroupSize = {8, 32, 1}; return setOpConfigAndEntryPointFnTranslation( - entryPoint, op, tileSizes, + entryPoint, linalgOp, tileSizes, IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTransposeSharedMem, workgroupSize); } +static LogicalResult setConvolutionConfig(linalg::LinalgOp linalgOp, + const int64_t subgroupSize, + const int64_t bestTilingFactor) { + if (!isa(linalgOp)) { + return failure(); + } + Type inputType = linalgOp.getInputOperand(0)->get().getType(); + ArrayRef inputShape = inputType.cast().getShape(); + Type outputType = linalgOp.getOutputOperand(0)->get().getType(); + ArrayRef outputShape = outputType.cast().getShape(); + if (ShapedType::isDynamic(inputShape[3]) || + llvm::any_of(outputShape.drop_front(), ShapedType::isDynamic)) { + return failure(); + } + int64_t oh = outputShape[1], ow = outputShape[2], oc = outputShape[3]; + // The core idea is to distribute the convolution OH/OW/OC dimension to the + // workgroup Z/Y/X dimension, with each thread in a workgroup handling + // multiple vector elements. We try to 1) utilize all threads in a subgroup, + // and 2) handle an optimal tile size along each dimension. + int64_t residualThreads = subgroupSize; + int64_t residualTilingFactor = bestTilingFactor; + SmallVector workgroupSize(3, 1); // (X, Y, Z) + SmallVector workgroupTileSizes(4, 0); // (N, OH, OW, OC) + // Deduce the configuration for the OC dimension. + for (int64_t x = residualThreads; x >= 2; x >>= 1) { + // Handle 4 elements per thread for the innermost dimension. We need this + // for vectorized load. + int64_t chosenTileSize = 4; + if (oc % (x * chosenTileSize) == 0) { + workgroupSize[0] = x; + workgroupTileSizes[3] = x * chosenTileSize; + residualThreads /= x; + residualTilingFactor /= chosenTileSize; + break; + } + } + if (workgroupTileSizes[3] == 0) return failure(); + // Deduce the configruation for the OW and OH dimension. Try to make them even + // if possible given we typically have images with the same height and width. + bool tileToSquare = false; + unsigned log2Threads = llvm::Log2_64(residualThreads); + if (ow == oh && residualThreads != 1 && log2Threads % 2 == 0) { + int64_t yz = 1ll << (log2Threads / 2); + int64_t chosenTileSize = 1ll << (llvm::Log2_64(residualTilingFactor) / 2); + while (chosenTileSize >= 1 && ow % (yz * chosenTileSize) != 0) { + chosenTileSize >>= 1; + } + if (chosenTileSize != 0) { + workgroupSize[1] = workgroupSize[2] = yz; + workgroupTileSizes[2] = workgroupTileSizes[1] = yz * chosenTileSize; + tileToSquare = true; + } + } + // Otherwise treat OW and OH separately to allow them to have different number + // of threads and tiling size. + if (!tileToSquare) { + // Decide the tiling and distribution parameters for one dimension. + auto decideOneDim = [&](int64_t inputDim, int64_t &wgDimSize, + int64_t &wgTileSize) { + for (int64_t dim = residualThreads; dim >= 1; dim >>= 1) { + int64_t chosenTileSize = 0; + for (int64_t t = residualTilingFactor; t >= 1; t >>= 1) { + if (inputDim % (dim * t) == 0) { + chosenTileSize = t; + break; + } + } + if (chosenTileSize) { + wgDimSize = dim; + wgTileSize = dim * chosenTileSize; + residualThreads /= dim; + residualTilingFactor /= chosenTileSize; + return true; + } + } + return false; + }; + if (!decideOneDim(ow, workgroupSize[1], workgroupTileSizes[2]) || + !decideOneDim(oh, workgroupSize[2], workgroupTileSizes[1])) { + return failure(); + } + } + auto pipeline = IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorize; + TileSizesListType tileSizes; + // Add reduction tile sizes. + workgroupTileSizes.append({1, 1, 4}); + tileSizes.push_back(workgroupTileSizes); + auto funcOp = linalgOp->getParentOfType(); + return setOpConfigAndEntryPointFnTranslation(funcOp, linalgOp, tileSizes, + pipeline, workgroupSize); +} + static LogicalResult setRootConfig(func::FuncOp entryPointFn, Operation *computeOp) { if (!clGPUCodegenTransformDialectTileSizes.empty()) { @@ -629,8 +676,11 @@ static LogicalResult setRootConfig(func::FuncOp entryPointFn, if (succeeded(setWarpReductionConfig(entryPointFn, linalgOp))) { return success(); } - if (isTransposeOp(linalgOp) && - succeeded(setTransposeConfig(entryPointFn, linalgOp))) { + if (succeeded(setConvolutionConfig(linalgOp, 32, 16))) { + return success(); + } + auto genericOp = dyn_cast(computeOp); + if (genericOp && succeeded(setTransposeConfig(entryPointFn, genericOp))) { return success(); } } @@ -660,8 +710,6 @@ LogicalResult initGPULaunchConfig(ModuleOp moduleOp) { return funcOp.emitOpError("failed to get compute ops"); } - // If using sandbox passes, currently set the workload_per_wg to be - // empty for single-threaded execution. if (clGPUCodegenTransformDialectFileName.size() > 0) { auto translationInfo = IREE::Codegen::TranslationInfoAttr::get( moduleOp.getContext(), IREE::Codegen::DispatchLoweringPassPipeline:: @@ -700,7 +748,7 @@ LogicalResult initGPULaunchConfig(ModuleOp moduleOp) { // setTranslationInfo( // funcOp, // IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDistribute, - // /*workloadPerWorkgroup=*/{}, {1, 1, 1}); + // {1, 1, 1}); // continue; return funcOp.emitOpError("unable to find root operation"); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp index 6490236885a6..40d26105024f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp @@ -36,21 +36,15 @@ struct LLVMGPUDistributePass if (!isEntryPoint(funcOp)) return; auto workgroupSize = llvm::to_vector(llvm::map_range( - getEntryPoint(funcOp)->getWorkgroupSize().getValue(), + getEntryPoint(funcOp)->getWorkgroupSize().value(), [&](Attribute attr) { return attr.cast().getInt(); })); - SmallVector foreachOps; - funcOp.walk([&](scf::ForeachThreadOp foreachOp) { - foreachOps.push_back(foreachOp); - }); - for (scf::ForeachThreadOp op : foreachOps) { - IRRewriter rewriter(op->getContext()); - rewriter.setInsertionPoint(op); - if (failed( - rewriteForeachThreadToGpu(op, workgroupSize, rewriter, - /*syncAfterDistributefalse=*/false))) { - return signalPassFailure(); - } - } + + IRRewriter rewriter(funcOp->getContext()); + rewriter.setInsertionPoint(funcOp); + auto walkResult = mlir::linalg::rewriteMapNestedForeachThreadToGpuThreads( + rewriter, funcOp, workgroupSize, false); + + if (walkResult.wasInterrupted()) return signalPassFailure(); } }; } // namespace diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUMultiBuffering.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUMultiBuffering.cpp index 2cc64626b5e0..8bdbb6df0f91 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUMultiBuffering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUMultiBuffering.cpp @@ -29,8 +29,10 @@ struct LLVMGPUMultiBufferingPass // Collect all the alloc operations. funcOp.walk([&](memref::AllocOp allocOp) { // Skip allocations not used in a loop. - auto loop = allocOp->getUsers().begin()->getParentOfType(); - if (!loop) return WalkResult::advance(); + for (Operation* user : allocOp->getUsers()) { + auto loop = user->getParentOfType(); + if (!loop) return WalkResult::advance(); + } allocs.push_back(allocOp); return WalkResult::advance(); }); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUReduceBankConflicts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUReduceBankConflicts.cpp index 1e9d6b338e15..4768f4fcf8c9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUReduceBankConflicts.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUReduceBankConflicts.cpp @@ -61,7 +61,8 @@ struct LLVMGPUReduceBankConflictsPass // Collect all the alloc operations. funcOp.walk([&](memref::AllocOp allocOp) { if (allocOp.getType().getMemorySpaceAsInt() == - gpu::GPUDialect::getWorkgroupAddressSpace()) { + gpu::GPUDialect::getWorkgroupAddressSpace() && + allocOp.getType().hasStaticShape()) { sharedMemAllocs.push_back(allocOp); } }); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorAlloc.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorAlloc.cpp index ab92508ee423..5bd57d295593 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorAlloc.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorAlloc.cpp @@ -4,7 +4,9 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree/compiler/Codegen/Common/LinalgOpInfo.h" #include "iree/compiler/Codegen/LLVMGPU/TilingUtils.h" +#include "iree/compiler/Codegen/LLVMGPU/TransposeUtils.h" #include "iree/compiler/Codegen/PassDetail.h" #include "iree/compiler/Codegen/Passes.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -16,20 +18,59 @@ namespace mlir { namespace iree_compiler { -/// Filter to decide which ops need allocations. -static bool filter(Operation *op) { +/// Filter to decide which contract ops need allocations. +static bool contractOpFilter(Operation *op) { auto linalgOp = dyn_cast(op); if (!linalgOp) return false; // Can't promote dynamic shapes. if (linalgOp.hasDynamicShape()) return false; - return linalg::isaContractionOpInterface(op) && + SmallVector dims; + linalgOp.getParallelDims(dims); + SmallVector shapes = linalgOp.getStaticLoopRanges(); + // Don't promote vector*matrix kind of case. + int numNonUnitParallelLoop = 0; + for (unsigned parallelDim : dims) { + if (shapes[parallelDim] != 1) { + numNonUnitParallelLoop++; + } + } + return numNonUnitParallelLoop > 1 && linalg::isaContractionOpInterface(op) && linalgOp.getNumParallelLoops() >= 2 && linalgOp.getNumParallelLoops() <= 3; } +/// Filter to decide which transpose ops need allocations. +static bool transposeOpFilter(Operation *op) { + auto linalgOp = dyn_cast(op); + if (!linalgOp) return false; + LinalgOpInfo opInfo(linalgOp, sharedMemTransposeFilter); + return opInfo.isTranspose(); +} + +/// Returns true if the index map represents a transpose that benefits from +/// shared mem. +static bool isSharedMemTranspose(AffineMap indexMap) { + if (!indexMap.isEmpty() && indexMap.isPermutation()) { + // Ensure that the fasted moving dimension (the last one) is permuted, + // Otherwise shared memory promotion will not benefit the operation. + if (indexMap.getDimPosition(indexMap.getNumDims() - 1) != + indexMap.getNumDims() - 1) { + return true; + } + } + return false; +} + namespace { struct LLVMGPUTensorAllocPass : public LLVMGPUTensorAllocBase { + private: + GPUPromoteSharedMemPattern promoteSharedMemPattern = + GPUPromoteSharedMemPattern::ContractionOpPattern; + + public: + LLVMGPUTensorAllocPass(GPUPromoteSharedMemPattern promoteSharedMemPattern) + : promoteSharedMemPattern(promoteSharedMemPattern) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } @@ -43,29 +84,55 @@ struct LLVMGPUTensorAllocPass SmallVector opsToPromote; funcOp.walk([&](Operation *op) { - if (filter(op)) opsToPromote.push_back(op); + switch (promoteSharedMemPattern) { + case GPUPromoteSharedMemPattern::ContractionOpPattern: + if (contractOpFilter(op)) opsToPromote.push_back(op); + break; + case GPUPromoteSharedMemPattern::TransposeOpPattern: + if (transposeOpFilter(op)) opsToPromote.push_back(op); + break; + } }); for (Operation *op : opsToPromote) { OpBuilder builder(op); auto linalgOp = cast(op); bufferization::BufferizationOptions options; - // Promote all the input operands. - for (auto operand : linalgOp.getInputOperands()) { - FailureOr ret = bufferization::allocateTensorForShapedValue( - builder, op->getLoc(), operand->get(), false, options, true); - if (failed(ret)) { - return signalPassFailure(); - } - Value v = ret.getValue(); - operand->get().replaceAllUsesExcept(v, v.getDefiningOp()); + switch (promoteSharedMemPattern) { + case GPUPromoteSharedMemPattern::ContractionOpPattern: + // Promote all the input operands + for (auto operand : linalgOp.getInputOperands()) { + FailureOr ret = bufferization::allocateTensorForShapedValue( + builder, op->getLoc(), operand->get(), false, options, true); + if (failed(ret)) { + return signalPassFailure(); + } + Value v = ret.value(); + operand->set(v); + } + break; + + case GPUPromoteSharedMemPattern::TransposeOpPattern: + LinalgOpInfo opInfo(linalgOp, sharedMemTransposeFilter); + + for (auto operand : opInfo.getTransposeOperands()) { + FailureOr ret = bufferization::allocateTensorForShapedValue( + builder, op->getLoc(), operand->get(), false, options, true); + if (failed(ret)) { + return signalPassFailure(); + } + Value v = ret.value(); + operand->set(v); + } + break; } } } }; } // namespace -std::unique_ptr> createLLVMGPUTensorAlloc() { - return std::make_unique(); +std::unique_ptr> createLLVMGPUTensorAlloc( + GPUPromoteSharedMemPattern promoteSharedMemPattern) { + return std::make_unique(promoteSharedMemPattern); } } // namespace iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp index f465b0393651..fe87b590f3e2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp @@ -49,7 +49,9 @@ static Optional> unrollOrder(Operation *op) { // register. This is needed to get good performance on sm_80 target. // First make reduction the outer dimensions. for (auto iter : llvm::enumerate(contract.getIteratorTypes())) { - if (isReductionIterator(iter.value())) order.push_back(iter.index()); + if (vector::isReductionIterator(iter.value())) { + order.push_back(iter.index()); + } } llvm::SmallDenseSet dims; @@ -58,13 +60,15 @@ static Optional> unrollOrder(Operation *op) { } // Then parallel dimensions that are part of Lhs as we want to re-use Lhs. for (auto iter : llvm::enumerate(contract.getIteratorTypes())) { - if (isParallelIterator(iter.value()) && dims.count(iter.index())) + if (vector::isParallelIterator(iter.value()) && dims.count(iter.index())) { order.push_back(iter.index()); + } } // Then the remaining parallel loops. for (auto iter : llvm::enumerate(contract.getIteratorTypes())) { - if (isParallelIterator(iter.value()) && !dims.count(iter.index())) + if (vector::isParallelIterator(iter.value()) && !dims.count(iter.index())) { order.push_back(iter.index()); + } } return order; } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp index d72ae32dc1cd..683cdd7b4667 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp @@ -36,9 +36,6 @@ using mlir::iree_compiler::IREE::LinalgExt::TilingPatterns; namespace mlir { namespace iree_compiler { -/// Flag defined in Passes.cpp. -extern llvm::cl::opt llvmgpuUseMMASync; - /// Patterns for workgroup level tiling. Workgroup tiling is done at the flow /// level but we may have extra tiling for the reduction dimension. Therefore we /// tile again without distributing. @@ -181,38 +178,8 @@ static LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst) { return success(); } -/// Returns the indices of the transposed operands in a linalg generic. -static SmallVector getTransposedOperands(linalg::GenericOp linalgOp) { - // Determine which operands to promote: - SmallVector transposedOperands; - if (linalgOp.getNumParallelLoops() < 2) { - return transposedOperands; - } - for (auto indexValue : llvm::enumerate(linalgOp.getIndexingMapsArray())) { - int64_t opIndex = indexValue.index(); - auto indexMap = indexValue.value(); - if (!indexMap.isEmpty() && indexMap.isPermutation()) { - // Ensure that the fasted moving dimension (the last one) is permuted - // otherwise data isn't moved. - if (indexMap.getDimPosition(indexMap.getNumDims() - 1) != - indexMap.getNumDims() - 1) { - // Add operand to promote to list and mark the linalg for this - // promotion. - transposedOperands.push_back(opIndex); - } - } - } - return transposedOperands; -} - using PromotionFilterFunction = std::function; -/// Returns true if op is appropriate transpose for promotion. -static LogicalResult transposeFilter(Operation *op, - linalg::GenericOp promotedFilterOp) { - return success(op == promotedFilterOp.getOperation()); -} - /// Returns true if op is appropriate contract for promotion. static LogicalResult contractOpFilter(Operation *op) { auto linalgOp = dyn_cast(op); @@ -249,31 +216,88 @@ static void populatePromotionPatterns(MLIRContext *context, .addFilter(filterFunction)); } +static bool propagateCopyDestIntoProducerFill(memref::CopyOp copyOp) { + // Look for a fill Op writing into the copyOp source. + Operation *prevOp = copyOp->getPrevNode(); + while (prevOp) { + if (isSideEffectFree(prevOp)) { + prevOp = prevOp->getPrevNode(); + continue; + } + + auto fillOp = dyn_cast(prevOp); + if (!fillOp) break; + if (fillOp.output() != copyOp.getSource()) break; + // Move the fillOp and change the destination to the copy destination. + fillOp->moveBefore(copyOp); + fillOp.getOutputsMutable().assign(copyOp.getTarget()); + return true; + } + return false; +} + +// Split input/output operand from copy from shared memory into a separate +// input. +static void insertInputValueIntoGeneric(Value source, linalg::GenericOp op) { + SmallVector newOperands; + SmallVector maps; + for (OpOperand *in : op.getInputOperands()) { + newOperands.push_back(in->get()); + maps.push_back(op.getTiedIndexingMap(in)); + } + newOperands.push_back(source); + assert(op.getNumOutputs() == 1); + OpOperand *outOperand = op.getOutputOperand(0); + maps.push_back(op.getTiedIndexingMap(outOperand)); + maps.push_back(op.getTiedIndexingMap(outOperand)); + Location loc = op.getLoc(); + SmallVector iterTypes(op.getNumLoops(), + getParallelIteratorTypeName()); + OpBuilder builder(op); + auto newOp = builder.create( + loc, newOperands, outOperand->get(), maps, iterTypes); + newOp.getRegion().getBlocks().splice(newOp.getRegion().begin(), + op.getRegion().getBlocks()); + + Block &payload = newOp.getRegion().front(); + payload.addArgument(payload.getArguments().back().getType(), loc); + setMarker(newOp, getCopyToWorkgroupMemoryMarker()); +} + +/// Propagate the shared memory copy into the consumer op if it's a fully +/// parallel linalg.generic. +static bool propagateCopySourceIntoConsumerGeneric( + memref::CopyOp copyOp, SmallVector &toDelete) { + // Look for a generic Op reading the copyOp target. + Operation *nextOp = copyOp->getNextNode(); + while (nextOp) { + if (isSideEffectFree(nextOp)) { + nextOp = nextOp->getNextNode(); + continue; + } + auto consumer = dyn_cast(nextOp); + if (!consumer || consumer.getNumOutputs() != 1 || + !consumer.getTiedIndexingMap(consumer.getOutputOperand(0)).isIdentity()) + break; + if (*consumer.outputs().begin() != copyOp.getTarget()) break; + insertInputValueIntoGeneric(copyOp.getSource(), consumer); + toDelete.push_back(consumer); + return true; + } + return false; +} + /// Transformation to propagate FillOp + CopyOp to temp allocation. /// This is needed because we are doing promotion to shared memory on buffers. /// This is a fragile and temporary solution until we move to be able to do this /// kind of transformations on tensors. -static void propagateFillIntoPromotionAlloc(func::FuncOp funcOp) { +static void propagateSharedMemCopy(func::FuncOp funcOp) { SmallVector toDelete; funcOp.walk([&toDelete](memref::CopyOp copyOp) { if (hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) { - // Look for a fill Op writing into the copyOp source. - Operation *prevOp = copyOp->getPrevNode(); - while (prevOp) { - if (isSideEffectFree(prevOp)) { - prevOp = prevOp->getPrevNode(); - continue; - } - - auto fillOp = dyn_cast(prevOp); - if (!fillOp) break; - if (fillOp.output() != copyOp.getSource()) break; - // Move the fillOp and change the destination to the copy destination. - fillOp->moveBefore(copyOp); - fillOp.getOutputsMutable().assign(copyOp.getTarget()); + if (propagateCopyDestIntoProducerFill(copyOp) || + propagateCopySourceIntoConsumerGeneric(copyOp, toDelete)) toDelete.push_back(copyOp.getOperation()); - break; - } } }); for (Operation *op : toDelete) op->erase(); @@ -285,14 +309,10 @@ struct LLVMGPUTileAndDistributePass private: // Distribute the workloads to warp if true otherwise distribute to threads. bool distributeToWarp = false; - GPUPromoteSharedMemPattern promoteSharedMemPattern = - GPUPromoteSharedMemPattern::ContractionOpPattern; public: - LLVMGPUTileAndDistributePass( - bool distributeToWarp, GPUPromoteSharedMemPattern promoteSharedMemPattern) - : distributeToWarp(distributeToWarp), - promoteSharedMemPattern(promoteSharedMemPattern) {} + LLVMGPUTileAndDistributePass(bool distributeToWarp) + : distributeToWarp(distributeToWarp) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } @@ -303,7 +323,7 @@ struct LLVMGPUTileAndDistributePass // Promote C matrix and propagate the potential fill producer into the temp // allocation. This needs to be done before reduction tiling. - if (llvmgpuUseMMASync) { + { RewritePatternSet promotionPatterns(&getContext()); populatePromotionPatterns(context, promotionPatterns, contractOpFilter, {2}); @@ -311,7 +331,7 @@ struct LLVMGPUTileAndDistributePass std::move(promotionPatterns)))) { return signalPassFailure(); } - propagateFillIntoPromotionAlloc(funcOp); + propagateSharedMemCopy(funcOp); } // Tile again at the workgroup level since reduction dimension were @@ -336,35 +356,9 @@ struct LLVMGPUTileAndDistributePass if (flatWorkgroupSize > kWarpSize) { RewritePatternSet promotionPatterns(&getContext()); - switch (promoteSharedMemPattern) { - case GPUPromoteSharedMemPattern::ContractionOpPattern: populatePromotionPatterns(context, promotionPatterns, contractOpFilter, {0, 1}); - break; - case GPUPromoteSharedMemPattern::TransposeOpPattern: - funcOp.walk( - [&context, &promotionPatterns](linalg::GenericOp linalgOp) { - // Promotion patterns accept a fixed list of operands to promote - // before determine which op is being promoted. To support - // multiple linalg generic ops with different promoted operands, - // We walk each linalg generic op to determine which operands to - // promote, then create a filter that will only apply to it's - // configuration. - SmallVector operandsToPromote = - getTransposedOperands(linalgOp); - if (!operandsToPromote.empty()) { - populatePromotionPatterns( - context, promotionPatterns, - [linalgOp](Operation *op) -> LogicalResult { - return transposeFilter(op, linalgOp); - }, - operandsToPromote); - } - }); - - break; - } if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(promotionPatterns)))) { return signalPassFailure(); @@ -372,17 +366,17 @@ struct LLVMGPUTileAndDistributePass // Insert barriers before and after copies to workgroup memory and skip // insert barriers between back to back copy to workgroup memory. OpBuilder builder(&getContext()); - funcOp.walk([&builder](memref::CopyOp copyOp) { + funcOp.walk([&builder](Operation *copyOp) { if (hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) { Operation *prevOp = copyOp->getPrevNode(); if (!prevOp || !hasMarker(prevOp, getCopyToWorkgroupMemoryMarker())) { builder.setInsertionPoint(copyOp); - builder.create(copyOp.getLoc()); + builder.create(copyOp->getLoc()); } Operation *nextOp = copyOp->getNextNode(); if (!nextOp || !hasMarker(nextOp, getCopyToWorkgroupMemoryMarker())) { builder.setInsertionPointAfter(copyOp); - builder.create(copyOp.getLoc()); + builder.create(copyOp->getLoc()); } } }); @@ -442,9 +436,8 @@ struct LLVMGPUTileAndDistributePass } // namespace std::unique_ptr> createLLVMGPUTileAndDistribute( - bool distributeToWarp, GPUPromoteSharedMemPattern promoteSharedMemPattern) { - return std::make_unique( - distributeToWarp, promoteSharedMemPattern); + bool distributeToWarp) { + return std::make_unique(distributeToWarp); } } // namespace iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp index 5ab157166b0a..12315502b192 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using mlir::iree_compiler::IREE::LinalgExt::TilingPatterns; @@ -55,8 +56,9 @@ static void populateTilingReductionPatterns(RewritePatternSet &patterns) { StringAttr::get(context, getWorkgroupMemoryMarker())}, StringAttr::get(context, getWorkgroupKTiledMarker())); filter.setMatchByDefault(); - TilingPatterns::insert(patterns, tilingOptions, filter); + TilingPatterns::insert(patterns, tilingOptions, + filter); } LogicalResult tileReduction(func::FuncOp funcOp) { @@ -77,6 +79,8 @@ LogicalResult tileReduction(func::FuncOp funcOp) { linalg::getLinalgTilingCanonicalizationPatterns(funcOp.getContext()); populateAffineMinSCFCanonicalizationPattern( wgTilingCanonicalizationPatterns); + scf::populateSCFForLoopCanonicalizationPatterns( + wgTilingCanonicalizationPatterns); if (failed(applyPatternsAndFoldGreedily( funcOp, std::move(wgTilingCanonicalizationPatterns)))) { return failure(); @@ -147,24 +151,24 @@ struct LLVMGPUTileTensorPass auto funcOp = getOperation(); if (!isEntryPoint(funcOp)) return; - auto workgroupSize = llvm::to_vector<4>(llvm::map_range( - getEntryPoint(funcOp)->getWorkgroupSize().getValue(), - [&](Attribute attr) { return attr.cast().getInt(); })); - if (failed(tileParallelDims(funcOp, workgroupSize, distributeToWarp))) { + if (failed(tileReduction(funcOp))) { return signalPassFailure(); } LLVM_DEBUG({ - llvm::dbgs() << "--- After second level of tiling"; + llvm::dbgs() << "--- After tile reductions:"; funcOp.dump(); }); - if (failed(tileReduction(funcOp))) { + auto workgroupSize = llvm::to_vector<4>(llvm::map_range( + getEntryPoint(funcOp)->getWorkgroupSize().value(), + [&](Attribute attr) { return attr.cast().getInt(); })); + if (failed(tileParallelDims(funcOp, workgroupSize, distributeToWarp))) { return signalPassFailure(); } LLVM_DEBUG({ - llvm::dbgs() << "--- After tile reductions:"; + llvm::dbgs() << "--- After second level of tiling"; funcOp.dump(); }); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorToGPU.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorToGPU.cpp index 9726d50aa0cf..05b919785d0e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorToGPU.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorToGPU.cpp @@ -257,6 +257,7 @@ struct LLVMGPUVectorToGPUPass return signalPassFailure(); } RewritePatternSet patterns(funcOp.getContext()); + mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); populatePrepareVectorToMMAPatterns(patterns, llvmgpuUseMMASync); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index e050776ee12d..2481dd4563f8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -117,12 +117,20 @@ void addGPUVectorizationPassPipeline(OpPassManager &pm) { nestedModulePM.addNestedPass(createGPUVectorizationPass()); nestedModulePM.addNestedPass(createCanonicalizerPass()); nestedModulePM.addNestedPass(createCSEPass()); - nestedModulePM.addNestedPass( - createOptimizeVectorTransferPass()); // tensor to memref addBufferizePasses(nestedModulePM); nestedModulePM.addNestedPass(createLLVMGPUDistribute()); + + // Post bufferization optimizations. + nestedModulePM.addNestedPass( + createLoopInvariantCodeMotionPass()); + nestedModulePM.addNestedPass( + memref::createFoldMemRefAliasOpsPass()); + nestedModulePM.addNestedPass(createCanonicalizerPass()); + nestedModulePM.addNestedPass(createCSEPass()); + nestedModulePM.addNestedPass( + createOptimizeVectorTransferPass()); } void addGPUMatmulSimtPassPipeline(OpPassManager &pm) { @@ -188,16 +196,9 @@ void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm, if (pipelineDepth > 1) nestedModulePM.addNestedPass( createLLVMGPUMultiBuffering(pipelineDepth)); - nestedModulePM.addNestedPass(createMemrefCopyToLinalgPass()); - nestedModulePM.addNestedPass( - createGPUDistributeSharedMemoryCopy()); nestedModulePM.addPass(createCanonicalizerPass()); nestedModulePM.addPass(createCSEPass()); - if (!llvmgpuUseMMASync) { - nestedModulePM.addNestedPass( - createLLVMGPUReduceSharedMemoryBankConflicts()); - } nestedModulePM.addNestedPass( createRemoveSingleIterationLoopPass()); nestedModulePM.addNestedPass( @@ -213,6 +214,17 @@ void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm, nestedModulePM.addNestedPass( createOptimizeVectorTransferPass()); + // Distribute shared memory copies. + nestedModulePM.addNestedPass(createMemrefCopyToLinalgPass()); + nestedModulePM.addNestedPass( + createGPUDistributeSharedMemoryCopy()); + nestedModulePM.addNestedPass(createCanonicalizerPass()); + nestedModulePM.addNestedPass(createCSEPass()); + if (!llvmgpuUseMMASync) { + nestedModulePM.addNestedPass( + createLLVMGPUReduceSharedMemoryBankConflicts()); + } + // Vector -> MMA ops nestedModulePM.addNestedPass( memref::createFoldMemRefAliasOpsPass()); @@ -222,16 +234,33 @@ void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm, // Pipeline memory operations. nestedModulePM.addNestedPass( - createGPUPipeliningPass(pipelineDepth)); + createGPUPipeliningPass(/*epiloguePeeling=*/false, pipelineDepth)); } void addGPUTransposePassPipeline(OpPassManager &pm) { - tileAndBufferize(pm); - + tileAndDistributeToWorkgroup(pm); auto &nestedModulePM = pm.nest(); - // Distribute linalg onto threads within the workgroup. - nestedModulePM.addNestedPass(createLLVMGPUTileAndDistribute( - false, GPUPromoteSharedMemPattern::TransposeOpPattern)); + + nestedModulePM.addNestedPass( + createRemoveSingleIterationLoopPass()); + + nestedModulePM.addNestedPass( + createLLVMGPUTensorAlloc(GPUPromoteSharedMemPattern::TransposeOpPattern)); + nestedModulePM.addNestedPass(createLLVMGPUTileTensor(false)); + + // Linalg -> vector + nestedModulePM.addNestedPass(createGPUVectorizationPass()); + nestedModulePM.addNestedPass(createCanonicalizerPass()); + nestedModulePM.addNestedPass(createCSEPass()); + nestedModulePM.addNestedPass( + createOptimizeVectorTransferPass()); + + // tensor to memref + addBufferizePasses(nestedModulePM); + + // distribute foreach threads + nestedModulePM.addNestedPass(createLLVMGPUDistribute()); + nestedModulePM.addNestedPass(createMemrefCopyToLinalgPass()); nestedModulePM.addNestedPass( createGPUDistributeSharedMemoryCopy()); @@ -245,13 +274,6 @@ void addGPUTransposePassPipeline(OpPassManager &pm) { createRemoveSingleIterationLoopPass()); nestedModulePM.addPass(createCanonicalizerPass()); nestedModulePM.addPass(createCSEPass()); - - // Linalg -> vector - nestedModulePM.addNestedPass(createGPUVectorizationPass()); - nestedModulePM.addNestedPass(createCanonicalizerPass()); - nestedModulePM.addNestedPass(createCSEPass()); - nestedModulePM.addNestedPass( - createOptimizeVectorTransferPass()); } void addGPUWarpReductionPassPipeline(OpPassManager &pm) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD index 07bc3bd6e44c..7942aa920e23 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD @@ -66,6 +66,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:LinalgUtils", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:PDLDialect", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt index 2b214286b0d6..6902239f78c7 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt @@ -37,6 +37,7 @@ iree_cc_library( MLIRFuncDialect MLIRGPUOps MLIRIR + MLIRLinalgTransforms MLIRLinalgUtils MLIRMemRefDialect MLIRPDLDialect diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index 2f95f6a4264d..38b48dd46267 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -34,159 +35,6 @@ void mlir::iree_compiler::registerTransformDialectLLVMGPUExtension( registry.addExtensions(); } -// TODO: Maybe we need both a transform.iree.cpu.bufferize and a -// transform.iree.gpu.bufferize rather than a single common bufferize op? - -/// Apply the permutation `perm` to `vals; i.e. vals[i] is stored into -/// res[perm[i]] Return failure if perm is not a permutation. -// TODO: upstream as extraClassDeclaration once stabilized. -template -static FailureOr> permute(const SmallVector &vals, - ArrayRef perm) { - if (vals.size() != perm.size()) return failure(); - SmallVector result(vals.size()); - SmallVector seen(vals.size()); - for (auto [idx, val] : llvm::zip(perm, vals)) { - // Already seen, invalid thread_dim_mapping. - if (seen[idx]) return failure(); - result[idx] = val; - seen[idx] = true; - } - // Some not seen, invalid thread_dim_mapping. - if (!llvm::all_of(seen, [](bool b) { return b; })) return failure(); - return result; -} - -/// Helper to get apply the `thread_dim_mapping` permutation of a -/// `foreachThreadOp` to `values`. -// TODO: upstream as extraClassDeclaration once stabilized. -template -static FailureOr> getValuesPermutedByThreadMapping( - scf::ForeachThreadOp foreachThreadOp, const SmallVector &values) { - // Apply mapping permutation if specified. - auto mapping = foreachThreadOp.getThreadDimMapping(); - if (mapping && !mapping.empty()) { - auto maybePermuted = permute(values, extractFromI64ArrayAttr(mapping)); - if (failed(maybePermuted)) - return foreachThreadOp->emitError("invalid permutation"); - return *maybePermuted; - } - return values; -} - -/// Helper to get the `num_threads` of a `foreachThreadOp` after applying the -/// `thread_dim_mapping` permutation. -// TODO: upstream as extraClassDeclaration once stabilized. -static FailureOr> getNumThreads( - OpBuilder &b, scf::ForeachThreadOp foreachThreadOp) { - SmallVector threadCount = foreachThreadOp.getNumThreads(); - threadCount.resize(3, b.getIndexAttr(1)); - return getValuesPermutedByThreadMapping(foreachThreadOp, threadCount); -} - -/// Helper to get the thread indices of a `foreachThreadOp` after applying the -/// `thread_dim_mapping` permutation. -// TODO: upstream as extraClassDeclaration once stabilized. -static FailureOr> getThreadIndices( - OpBuilder &b, scf::ForeachThreadOp foreachThreadOp) { - SmallVector threadCount = foreachThreadOp.getThreadIndices(); - threadCount.resize(3, Value()); - return getValuesPermutedByThreadMapping(foreachThreadOp, threadCount); -} - -//===---------------------------------------------------------------------===// -// Patterns for ForeachThreadToGpu rewrite. -//===---------------------------------------------------------------------===// - -FailureOr> -mlir::iree_compiler::rewriteForeachThreadToGpu( - scf::ForeachThreadOp foreachThreadOp, - const SmallVector &globalWorkgroupSizes, RewriterBase &rewriter, - bool syncAfterDistribute) { - if (foreachThreadOp.getNumResults() > 0) - return foreachThreadOp->emitError( - "only bufferized scf.foreach_thread lowers to gpu.thread"); - if (foreachThreadOp.getNumThreads().size() > 3) - return foreachThreadOp->emitError( - "scf.foreach_thread with rank > 3 does not lower to gpu.thread"); - - auto maybeWorkgroupSizes = getNumThreads(rewriter, foreachThreadOp); - if (failed(maybeWorkgroupSizes) || - llvm::any_of(*maybeWorkgroupSizes, [](OpFoldResult ofr) { - return !getConstantIntValue(ofr).has_value(); - })) - return foreachThreadOp->emitError("unsupported dynamic workgroup size"); - - SmallVector workgroupSizes = llvm::to_vector(llvm::map_range( - *maybeWorkgroupSizes, - [](OpFoldResult ofr) { return getConstantIntValue(ofr).value(); })); - - // Step 1. Create the gpu.thread ops - Location loc = foreachThreadOp.getLoc(); - IndexType indexType = rewriter.getIndexType(); - - SmallVector gpuDims{gpu::Dimension::x, gpu::Dimension::y, - gpu::Dimension::z}; - SmallVector threadOps; - for (int64_t idx : llvm::seq(0, workgroupSizes.size())) { - threadOps.push_back( - rewriter.create(loc, indexType, gpuDims[idx])); - } - - // Step 2. Maybe create conditionals to predicate the region. - Value predicate; - for (auto [threadId, workgroupSize, globalWorkgroupSize] : - llvm::zip(threadOps, workgroupSizes, globalWorkgroupSizes)) { - if (workgroupSize > globalWorkgroupSize) { - return foreachThreadOp.emitOpError("workgroup size overflow: ") - << workgroupSize << " > " << globalWorkgroupSize; - } - if (workgroupSize == globalWorkgroupSize) continue; - Value tmpPredicate = rewriter.create( - loc, arith::CmpIPredicate::ult, threadId, - rewriter.create(loc, workgroupSize)); - predicate = - predicate ? rewriter.create(loc, predicate, tmpPredicate) - : tmpPredicate; - } - - // Step 3. Move the body of foreachThreadOp. - // Erase the terminator first, it will not be used. - rewriter.eraseOp(foreachThreadOp.getTerminator()); - Block *targetBlock; - Block::iterator insertionPoint; - if (predicate) { - // Step 3.a. If predicated, move at the beginning. - auto ifOp = - rewriter.create(loc, predicate, /*withElseRegion=*/false); - targetBlock = ifOp.thenBlock(); - insertionPoint = ifOp.thenBlock()->begin(); - } else { - // Step 3.a. Otherwise, move inline just before foreachThreadOp. - targetBlock = foreachThreadOp->getBlock(); - insertionPoint = Block::iterator(foreachThreadOp); - } - Block &sourceBlock = foreachThreadOp.getRegion().front(); - targetBlock->getOperations().splice(insertionPoint, - sourceBlock.getOperations()); - - // Step 4. RAUW thread indices to thread ops. - SmallVector threadIndices = - *getThreadIndices(rewriter, foreachThreadOp); - for (auto it : llvm::zip(threadIndices, threadOps)) { - if (!std::get<0>(it)) continue; - std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); - } - - // Step 5. syncthreads. - if (syncAfterDistribute) rewriter.create(loc); - - // Step 6. Erase old op. - rewriter.eraseOp(foreachThreadOp); - - return *maybeWorkgroupSizes; -} - //===---------------------------------------------------------------------===// // IREE-specific LLVMGPU transformations. //===---------------------------------------------------------------------===// @@ -195,7 +43,7 @@ mlir::iree_compiler::rewriteForeachThreadToGpu( // reuse most of the code and not require a static number of threads. // TODO: synchronizations for imperfectly nested stuff. DiagnosedSilenceableFailure -transform_dialect::ForeachThreadToGpuAndTranslationInfo::applyToOne( +transform_dialect::MapNestedForeachThreadToGpuThreads::applyToOne( func::FuncOp target, SmallVectorImpl &results, transform::TransformState &state) { if (!isa(state.getTopLevel())) { @@ -219,13 +67,8 @@ transform_dialect::ForeachThreadToGpuAndTranslationInfo::applyToOne( // TODO: no magic constant but IREE uses this extensively. workgroupSize.resize(/*size=*/3, /*value=*/1); SimplePatternRewriter rewriter(target); - auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) { - rewriter.setInsertionPoint(foreachThreadOp); - if (failed(rewriteForeachThreadToGpu(foreachThreadOp, workgroupSize, - rewriter))) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); + auto walkResult = mlir::linalg::rewriteMapNestedForeachThreadToGpuThreads( + rewriter, target, workgroupSize, true); if (walkResult.wasInterrupted()) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); @@ -481,7 +324,7 @@ static Value warpReduction(Location loc, OpBuilder &builder, Value input, .create(loc, laneVal, i, /*width=*/size, /*mode=*/gpu::ShuffleMode::XOR) - .result(); + .getShuffleResult(); laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); } return laneVal; @@ -625,7 +468,8 @@ static void warpSyncronizationFn(Location loc, OpBuilder &builder, static void populateWarpExecuteOnLane0ToScf( Operation *target, RewritePatternSet &patterns, - vector::WarpExecuteOnLane0LoweringOptions options, PatternBenefit benefit) { + const vector::WarpExecuteOnLane0LoweringOptions &options, + PatternBenefit benefit) { assert(target->hasTrait()); vector::populateWarpExecuteOnLane0OpToScfForPattern(patterns, options, benefit); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td index 209d2a700d71..9280954d30a6 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td @@ -13,8 +13,8 @@ include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" -def ForeachThreadToGpuAndTranslationInfo : - Op], legacy_sync}> +#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}> +#pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> +module attributes {hal.device.targets = [#device_target_cuda]} { + hal.executable private @conv2d_1x230x230x3_7x7x3x64_dispatch_0 { + hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb { + hal.executable.export public @conv2d_1x230x230x3_7x7x3x64 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @conv2d_1x230x230x3_7x7x3x64() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 230, 230, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor -> tensor<1x230x230x3xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [7, 7, 3, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor -> tensor<7x7x3x64xf32> + %5 = linalg.init_tensor [1, 112, 112, 64] : tensor<1x112x112x64xf32> + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> + %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%3, %4 : tensor<1x230x230x3xf32>, tensor<7x7x3x64xf32>) outs(%6 : tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [1, 112, 112, 64], strides = [1, 1, 1, 1] : tensor<1x112x112x64xf32> -> !flow.dispatch.tensor + return + } + } + } + } +} + +// CHECK-LABEL: func.func @conv2d_1x230x230x3_7x7x3x64 +// CHECK-NOT: vector.transfer_write +// CHECK-NOT: vector.transfer_read +// CHECK: scf.for +// CHECK: scf.for +// CHECK-COUNT-2: vector.transfer_read +// CHECK-COUNT-4: vector.contract +// CHECK: scf.yield %{{.*}} : vector<4x4xf32> +// CHECK: scf.yield %{{.*}} : vector<4x4xf32> +// CHECK: vector.transfer_write {{.*}} : vector<4x4xf32>, memref<1x112x112x64xf32> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir index 287731d1c26c..03c778409c53 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir @@ -1,7 +1,7 @@ // RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-tile-and-distribute))))' %s | FileCheck %s #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb"> #pipeline_layout = #hal.pipeline.layout // CHECK-DAG: %[[BUFFER1:.+]] = memref.alloc() : memref<2x4xf32, 3> +// CHECK-DAG: %[[BUFFER2:.+]] = memref.alloc() : memref<2x256xf32, 3> // CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C1024]] step %[[C4]] { // CHECK: gpu.barrier -// CHECK: memref.copy {{.*}}, {{.*}} {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<2x4xf32, #{{.*}}> to memref<2x4xf32, 3> +// CHECK: memref.copy {{.*}}, {{.*}} {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<2x4xf32, strided<[1024, 1], offset: ?>> to memref<2x4xf32, 3> // CHECK-NOT: gpu.barrier -// CHECK: memref.copy {{.*}}, {{.*}} {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<4x256xf32, #{{.*}}> to memref<4x256xf32, 3> +// CHECK: memref.copy {{.*}}, {{.*}} {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<4x256xf32, strided<[1024, 1], offset: ?>> to memref<4x256xf32, 3> // CHECK: gpu.barrier // CHECK: scf.for %[[IND0:.+]] = %{{.*}} to %[[C2]] step %[[C2]] { // CHECK: scf.for %[[IND1:.+]] = %{{.*}} to %[[C256]] step %[[C256]] { -// CHECK-DAG: %[[A:.+]] = memref.subview %[[BUFFER1]][%[[IND0]], 0] [2, 4] [1, 1] : memref<2x4xf32, 3> to memref<2x4xf32, #{{.*}}, 3> -// CHECK-DAG: %[[B:.+]] = memref.subview %[[BUFFER0]][0, %[[IND1]]] [4, 4] [1, 1] : memref<4x256xf32, 3> to memref<4x4xf32, #{{.*}}, 3> -// CHECK-DAG: %[[C:.+]] = memref.subview %{{.*}}[%[[IND0]], %[[IND1]]] [2, 4] [1, 1] : memref<2x256xf32, #{{.*}}> to memref<2x4xf32, #{{.*}}> -// CHECK: linalg.matmul {__internal_linalg_transform__ = "vectorize", {{.*}}} ins(%[[A]], %[[B]] : memref<2x4xf32, #{{.*}}, 3>, memref<4x4xf32, #{{.*}}, 3>) outs(%[[C]] : memref<2x4xf32, #{{.*}}>) +// CHECK-DAG: %[[A:.+]] = memref.subview %[[BUFFER1]][%[[IND0]], 0] [2, 4] [1, 1] : memref<2x4xf32, 3> to memref<2x4xf32, strided<[4, 1], offset: ?>, 3> +// CHECK-DAG: %[[B:.+]] = memref.subview %[[BUFFER0]][0, %[[IND1]]] [4, 4] [1, 1] : memref<4x256xf32, 3> to memref<4x4xf32, strided<[256, 1], offset: ?>, 3> +// CHECK-DAG: %[[C:.+]] = memref.subview %[[BUFFER2]][%[[IND0]], %[[IND1]]] [2, 4] [1, 1] : memref<2x256xf32, 3> to memref<2x4xf32, strided<[256, 1], offset: ?>, 3> +// CHECK: linalg.matmul {__internal_linalg_transform__ = "vectorize", {{.*}}} ins(%[[A]], %[[B]] : memref<2x4xf32, strided<[4, 1], offset: ?>, 3>, memref<4x4xf32, strided<[256, 1], offset: ?>, 3>) outs(%[[C]] : memref<2x4xf32, strided<[256, 1], offset: ?>, 3>) // CHECK: } // CHECK: } +// CHECK: gpu.barrier +// CHECK: memref.copy {{.*}}, {{.*}} {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<2x256xf32, 3> to memref<2x256xf32, +// CHECK: gpu.barrier // ----- -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb"> #pipeline_layout = #hal.pipeline.layout, memref<1x32x4xf32, #{{.*}}, 3>) outs(%{{.*}} : memref<1x1x4xf32, #{{.*}}>) +// CHECK: linalg.batch_matmul {__internal_linalg_transform__ = "vectorize", {{.*}}} ins(%{{.*}}, %{{.*}} : memref<1x1x32xf32, strided<[256, 32, 1], offset: ?>, 3>, memref<1x32x4xf32, strided<[1024, 32, 1], offset: ?>, 3>) outs(%{{.*}} : memref<1x1x4xf32, strided<[256, 32, 1], offset: ?>, 3>) // CHECK: } // CHECK: } +// CHECK: gpu.barrier +// CHECK: memref.copy {{.*}}, {{.*}} {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<1x8x32xf32, 3> to memref<1x8x32xf32 +// CHECK: gpu.barrier // ----- #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb"> #pipeline_layout = #hal.pipeline.layout // CHECK-DAG: %[[BUFFER1:.+]] = memref.alloc() : memref<2x4xf32, 3> +// CHECK-DAG: %[[BUFFER2:.+]] = memref.alloc() : memref<2x32xf32, 3> // CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C1024]] step %[[C4]] { // CHECK: gpu.barrier -// CHECK: memref.copy {{.*}}, {{.*}} {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<2x4xf32, #{{.*}}> to memref<2x4xf32, 3> +// CHECK: memref.copy {{.*}}, {{.*}} {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<2x4xf32, strided<[1024, 1], offset: ?>> to memref<2x4xf32, 3> // CHECK-NOT: gpu.barrier -// CHECK: memref.copy {{.*}}, {{.*}} {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<4x32xf32, #{{.*}}> to memref<4x32xf32, 3> +// CHECK: memref.copy {{.*}}, {{.*}} {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<4x32xf32, strided<[1024, 1], offset: ?>> to memref<4x32xf32, 3> // CHECK: gpu.barrier // CHECK: scf.for %[[IND0:.+]] = %{{.*}} to %[[C2]] step %[[C8]] { // CHECK: scf.for %[[IND1:.+]] = %{{.*}} to %[[C32]] step %[[C64]] { -// CHECK-DAG: %[[A:.+]] = memref.subview %[[BUFFER1]][%[[IND0]], 0] [1, 4] [1, 1] : memref<2x4xf32, 3> to memref<1x4xf32, #{{.*}}, 3> -// CHECK-DAG: %[[B:.+]] = memref.subview %[[BUFFER0]][0, %[[IND1]]] [4, 1] [1, 1] : memref<4x32xf32, 3> to memref<4x1xf32, #{{.*}}, 3> -// CHECK-DAG: %[[C:.+]] = memref.subview %{{.*}}[%[[IND0]], %[[IND1]]] [1, 1] [1, 1] : memref<2x32xf32, #{{.*}}> to memref<1x1xf32, #{{.*}}> -// CHECK: linalg.matmul {__internal_linalg_transform__ = "vectorize", {{.*}}} ins(%[[A]], %[[B]] : memref<1x4xf32, #{{.*}}, 3>, memref<4x1xf32, #{{.*}}, 3>) outs(%[[C]] : memref<1x1xf32, #{{.*}}>) +// CHECK-DAG: %[[A:.+]] = memref.subview %[[BUFFER1]][%[[IND0]], 0] [1, 4] [1, 1] : memref<2x4xf32, 3> to memref<1x4xf32, strided<[4, 1], offset: ?>, 3> +// CHECK-DAG: %[[B:.+]] = memref.subview %[[BUFFER0]][0, %[[IND1]]] [4, 1] [1, 1] : memref<4x32xf32, 3> to memref<4x1xf32, strided<[32, 1], offset: ?>, 3> +// CHECK-DAG: %[[C:.+]] = memref.subview %[[BUFFER2]][%[[IND0]], %[[IND1]]] [1, 1] [1, 1] : memref<2x32xf32, 3> to memref<1x1xf32, strided<[32, 1], offset: ?>, 3> +// CHECK: linalg.matmul {__internal_linalg_transform__ = "vectorize", {{.*}}} ins(%[[A]], %[[B]] : memref<1x4xf32, strided<[4, 1], offset: ?>, 3>, memref<4x1xf32, strided<[32, 1], offset: ?>, 3>) outs(%[[C]] : memref<1x1xf32, strided<[32, 1], offset: ?>, 3>) // CHECK: } // CHECK: } +// CHECK: gpu.barrier +// CHECK: memref.copy {{.*}}, {{.*}} {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<2x32xf32, 3> to memref<2x32xf32 +// CHECK: gpu.barrier // ----- @@ -294,7 +307,7 @@ hal.executable @reduction_dispatch { // ----- -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb"> #pipeline_layout = #hal.pipeline.layout -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb"> #pipeline_layout = #hal.pipeline.layout, - translation_info = , + translation_info = , workgroup_size = [16, 8, 1]> #pipeline_layout = #hal.pipeline.layout } // CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @_lowering_config_test_dispatch_1 // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK-SAME: workgroup_size = [16 : index, 8 : index, 1 : index] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/illegal_configuration.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/illegal_configuration.mlir index 3e4a70f546b5..3e57a92bb349 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/illegal_configuration.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/illegal_configuration.mlir @@ -289,7 +289,7 @@ hal.executable private @matmul_tensors { // ----- #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb"> #pipeline_layout = #hal.pipeline.layout) -> !llvm.struct<(i32, i32, i32, i32) // CHECK-COUNT-2: nvvm.wmma.mma -// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr, !llvm.ptr, i32, i32) -> !llvm.void // CHECK: nvvm.cp.async.commit.group // CHECK: llvm.br -// CHECK: nvvm.cp.async.wait.group 3 -// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) -// CHECK-COUNT-2: nvvm.wmma.mma -// CHECK: nvvm.cp.async.wait.group 2 -// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) -// CHECK-COUNT-2: nvvm.wmma.mma -// CHECK: nvvm.cp.async.wait.group 1 -// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) -// CHECK-COUNT-2: nvvm.wmma.mma -// CHECK: nvvm.cp.async.wait.group 0 -// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) -// CHECK-COUNT-2: nvvm.wmma.mma -// CHECK-COUNT-8: llvm.fadd -// CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr, f32, f32, f32, f32, f32, f32, f32, f32 +// CHECK-NOT: nvvm.wmma.mma +// CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr, f32, f32, f32, f32, f32, f32, f32, f32 +// CHECK: vvm.barrier0 +// CHECK: llvm.load {{.*}} : !llvm.ptr, 3> +// CHECK: llvm.fadd {{.*}} : vector<4xf32> +// CHECK: llvm.store {{.*}} : !llvm.ptr> +// CHECK: llvm.load {{.*}} : !llvm.ptr, 3> +// CHECK: llvm.fadd {{.*}} : vector<4xf32> +// CHECK: llvm.store {{.*}} : !llvm.ptr> // mma.sync case: // MMASYNC-LABEL: hal.executable public @mma_fused @@ -506,30 +501,22 @@ hal.executable @mma_fused { // MMASYNC: nvvm.cp.async.wait.group 3 // MMASYNC-COUNT-4: nvvm.ldmatrix{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32) // MMASYNC-COUNT-8: nvvm.mma.sync -// MMASYNC-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// MMASYNC-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr, !llvm.ptr, i32, i32) -> !llvm.void // MMASYNC: nvvm.cp.async.commit.group // MMASYNC: llvm.br -// MMASYNC: nvvm.cp.async.wait.group 3 -// MMASYNC-COUNT-4: nvvm.ldmatrix{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32) -// MMASYNC-COUNT-8: nvvm.mma.sync -// MMASYNC: nvvm.cp.async.wait.group 2 -// MMASYNC-COUNT-4: nvvm.ldmatrix{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32) -// MMASYNC-COUNT-8: nvvm.mma.sync -// MMASYNC: nvvm.cp.async.wait.group 1 -// MMASYNC-COUNT-4: nvvm.ldmatrix{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32) -// MMASYNC-COUNT-8: nvvm.mma.sync -// MMASYNC: nvvm.cp.async.wait.group 0 -// MMASYNC-COUNT-4: nvvm.ldmatrix{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32) -// MMASYNC-COUNT-8: nvvm.mma.sync +// MMASYNC-NOT: nvvm.mma.sync // MMASYNC-COUNT-4: llvm.store {{.*}} : !llvm.ptr, 3> -// MMASYNC-COUNT-2: llvm.load {{.*}} : !llvm.ptr, 3> -// MMASYNC-COUNT-2: llvm.store {{.*}} : !llvm.ptr> - -// C matrix promotion prevent efficient fusion with matmul consumer, this needs -// to be fixed to get good performance. -// MMASYNC-COUNT-32: llvm.load {{.*}} : !llvm.ptr> -// MMASYNC-COUNT-32: llvm.fadd {{.*}} : vector<8xf32> -// MMASYNC-COUNT-32: llvm.store {{.*}} : !llvm.ptr> +// MMASYNC-COUNT: llvm.load {{.*}} : !llvm.ptr, 3> +// MMASYNC-COUNT: llvm.store {{.*}} : !llvm.ptr> +// MMASYNC-COUNT: llvm.load {{.*}} : !llvm.ptr, 3> +// MMASYNC-COUNT: llvm.store {{.*}} : !llvm.ptr> +// MMASYNC-COUNT: nvvm.barrier0 +// MMASYNC-COUNT: llvm.load {{.*}} : !llvm.ptr, 3> +// MMASYNC-COUNT: llvm.fadd {{.*}} : vector<4xf32> +// MMASYNC-COUNT: llvm.store {{.*}} : !llvm.ptr> +// MMASYNC-COUNT: llvm.load {{.*}} : !llvm.ptr, 3> +// MMASYNC-COUNT: llvm.fadd {{.*}} : vector<4xf32> +// MMASYNC-COUNT: llvm.store {{.*}} : !llvm.ptr> @@ -601,23 +588,16 @@ hal.executable @mma_fused_fp16 { // CHECK: nvvm.cp.async.wait.group 3 // CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) // CHECK-COUNT-1: nvvm.wmma.mma -// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr, !llvm.ptr, i32, i32) -> !llvm.void // CHECK: nvvm.cp.async.commit.group // CHECK: llvm.br -// CHECK: nvvm.cp.async.wait.group 3 -// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -// CHECK-COUNT-1: nvvm.wmma.mma -// CHECK: nvvm.cp.async.wait.group 2 -// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -// CHECK-COUNT-1: nvvm.wmma.mma -// CHECK: nvvm.cp.async.wait.group 1 -// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -// CHECK-COUNT-1: nvvm.wmma.mma -// CHECK: nvvm.cp.async.wait.group 0 -// CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -// CHECK-COUNT-1: nvvm.wmma.mma -// CHECK-COUNT-4: llvm.fadd -// CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16> +// CHECK-NOT: nvvm.wmma.mma +// CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16> +// CHECK: vvm.barrier0 +// CHECK: llvm.load {{.*}} : !llvm.ptr, 3> +// CHECK: llvm.fadd {{.*}} : vector<8xf16> +// CHECK: llvm.store {{.*}} : !llvm.ptr> +// CHECK: vvm.barrier0 // mma.sync case: // MMASYNC-LABEL: hal.executable public @mma_fused_fp16 @@ -635,21 +615,10 @@ hal.executable @mma_fused_fp16 { // MMASYNC: nvvm.cp.async.wait.group 3 // MMASYNC-COUNT-4: nvvm.ldmatrix {{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32)> // MMASYNC-COUNT-8: nvvm.mma.sync -// MMASYNC-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// MMASYNC-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr, !llvm.ptr, i32, i32) -> !llvm.void // MMASYNC: nvvm.cp.async.commit.group // MMASYNC: llvm.br -// MMASYNC: nvvm.cp.async.wait.group 3 -// MMASYNC-COUNT-4: nvvm.ldmatrix {{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32)> -// MMASYNC-COUNT-8: nvvm.mma.sync -// MMASYNC: nvvm.cp.async.wait.group 2 -// MMASYNC-COUNT-4: nvvm.ldmatrix {{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32)> -// MMASYNC-COUNT-8: nvvm.mma.sync -// MMASYNC: nvvm.cp.async.wait.group 1 -// MMASYNC-COUNT-4: nvvm.ldmatrix {{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32)> -// MMASYNC-COUNT-8: nvvm.mma.sync -// MMASYNC: nvvm.cp.async.wait.group 0 -// MMASYNC-COUNT-4: nvvm.ldmatrix {{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32)> -// MMASYNC-COUNT-8: nvvm.mma.sync +// MMASYNC-NOT: nvvm.mma.sync // MMASYNC-COUNT-4: llvm.store {{.*}} : !llvm.ptr, 3> // MMASYNC: llvm.load {{.*}} : !llvm.ptr, 3> // MMASYNC: llvm.store {{.*}} : !llvm.ptr> @@ -723,22 +692,16 @@ hal.executable @mma_fused_fp16 { // CHECK: nvvm.cp.async.wait.group 3 // CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) // CHECK-COUNT-2: nvvm.wmma.mma -// CHECK-COUNT-2: nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16 +// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr, !llvm.ptr, i32, i32) -> !llvm.void // CHECK: nvvm.cp.async.commit.group // CHECK: llvm.br -// CHECK: nvvm.cp.async.wait.group 3 -// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) -// CHECK-COUNT-2: nvvm.wmma.mma -// CHECK: nvvm.cp.async.wait.group 2 -// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) -// CHECK-COUNT-2: nvvm.wmma.mma -// CHECK: nvvm.cp.async.wait.group 1 -// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) -// CHECK-COUNT-2: nvvm.wmma.mma -// CHECK: nvvm.cp.async.wait.group 0 -// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) -// CHECK-COUNT-2: nvvm.wmma.mma -// CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr, f32, f32, f32, f32, f32, f32, f32, f32 +// CHECK-NOT: nvvm.wmma.mma +// CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr, f32, f32, f32, f32, f32, f32, f32, f32 +// CHECK: vvm.barrier0 +// CHECK: llvm.load {{.*}} : !llvm.ptr, 3> +// CHECK: llvm.store {{.*}} : !llvm.ptr> +// CHECK: llvm.load {{.*}} : !llvm.ptr, 3> +// CHECK: llvm.store {{.*}} : !llvm.ptr> // ----- @@ -799,16 +762,16 @@ hal.executable @mma_fused_fp16 { // CHECK: nvvm.cp.async.wait.group 3 // CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) // CHECK-COUNT-2: nvvm.wmma.mma -// CHECK: nvvm.cp.async.wait.group 2 -// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) -// CHECK-COUNT-2: nvvm.wmma.mma -// CHECK: nvvm.cp.async.wait.group 1 -// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) -// CHECK-COUNT-2: nvvm.wmma.mma -// CHECK: nvvm.cp.async.wait.group 0 -// CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32) -// CHECK-COUNT-2: nvvm.wmma.mma -// CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr, f32, f32, f32, f32, f32, f32, f32, f32 +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr, !llvm.ptr, i32, i32) -> !llvm.void +// CHECK: nvvm.cp.async.commit.group +// CHECK: llvm.br +// CHECK-NOT: nvvm.wmma.mma +// CHECK-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr, f32, f32, f32, f32, f32, f32, f32, f32 +// CHECK: vvm.barrier0 +// CHECK: llvm.load {{.*}} : !llvm.ptr, 3> +// CHECK: llvm.store {{.*}} : !llvm.ptr> +// CHECK: llvm.load {{.*}} : !llvm.ptr, 3> +// CHECK: llvm.store {{.*}} : !llvm.ptr> // ----- @@ -1063,5 +1026,3 @@ hal.executable private @shared_mem_transpose { // CHECK: llvm.load %{{.*}} {alignment = 4 : i64} : !llvm.ptr> // CHECK: llvm.store %{{.*}}, %{{.*}} {alignment = 4 : i64} : !llvm.ptr, 3> // CHECK: nvvm.barrier0 - -// ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduce_bank_conflicts.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduce_bank_conflicts.mlir index 6d7bfe7b9ffb..8118fc932120 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduce_bank_conflicts.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduce_bank_conflicts.mlir @@ -1,22 +1,33 @@ -// RUN: iree-opt %s --iree-llvmgpu-reduce-bank-conflicts | FileCheck %s +// RUN: iree-opt %s --split-input-file --iree-llvmgpu-reduce-bank-conflicts | FileCheck %s #map = affine_map<(d0, d1, d2) -> (d0 * 2048 + d1 * 64 + d2)> -// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1, d2) -> (d0 * 2176 + d1 * 68 + d2)> // CHECK-LABEL: func.func @pad_alloc func.func @pad_alloc(%a: memref<1024x1024xf32>) { // CHECK: %[[A:.*]] = memref.alloc() : memref<4x32x68xf32, 3> %0 = memref.alloc() : memref<4x32x64xf32, 3> -// CHECK: %[[S1:.*]] = memref.subview %[[A]][0, 0, 0] [4, 32, 64] [1, 1, 1] : memref<4x32x68xf32, 3> to memref<4x32x64xf32, #[[$MAP]], 3> -// CHECK: %[[S2:.*]] = memref.subview %[[S1]][0, 0, 0] [1, 32, 64] [1, 1, 1] : memref<4x32x64xf32, #[[$MAP]], 3> to memref<1x32x64xf32, #[[$MAP]], 3> +// CHECK: %[[S1:.*]] = memref.subview %[[A]][0, 0, 0] [4, 32, 64] [1, 1, 1] : memref<4x32x68xf32, 3> to memref<4x32x64xf32, strided<[2176, 68, 1]>, 3> +// CHECK: %[[S2:.*]] = memref.subview %[[S1]][0, 0, 0] [1, 32, 64] [1, 1, 1] : memref<4x32x64xf32, strided<[2176, 68, 1]>, 3> to memref<1x32x64xf32, strided<[2176, 68, 1]>, 3> %1 = memref.subview %0[0, 0, 0] [1, 32, 64] [1, 1, 1] : memref<4x32x64xf32, 3> to memref<1x32x64xf32, #map, 3> %c0 = arith.constant 0 : index %cst_0 = arith.constant 0.000000e+00 : f32 %2 = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true]} : memref<1024x1024xf32>, vector<4xf32> -// CHECK: vector.transfer_write %{{.*}}, %[[S2]][%{{.*}}, %{{.*}}, %{{.*}}] {in_bounds = [true]} : vector<4xf32>, memref<1x32x64xf32, #[[$MAP]], 3> +// CHECK: vector.transfer_write %{{.*}}, %[[S2]][%{{.*}}, %{{.*}}, %{{.*}}] {in_bounds = [true]} : vector<4xf32>, memref<1x32x64xf32, strided<[2176, 68, 1]>, 3> vector.transfer_write %2, %1[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<1x32x64xf32, #map, 3> return } + +// ----- + +// CHECK-LABEL: func.func @pad_alloc_negative +func.func @pad_alloc_negative(%a: memref<1024x1024xf32>, %i: index, %v: vector<4xf32>) { + %c0 = arith.constant 0 : index +// CHECK: memref.alloc(%{{.*}}) : memref + %0 = memref.alloc(%i) : memref + vector.transfer_write %v, %0[%c0, %c0, %c0] {in_bounds = [true]} : + vector<4xf32>, memref + return +} diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tensor_alloc.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tensor_alloc.mlir index 87cdd970b690..5b7de94b501d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tensor_alloc.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tensor_alloc.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt %s -iree-llvmgpu-alloc | FileCheck %s +// RUN: iree-opt %s -allow-unregistered-dialect --split-input-file -iree-llvmgpu-alloc | FileCheck %s func.func @matmul_2048x512x1024() { %c0 = arith.constant 0 : index @@ -27,3 +27,58 @@ func.func @matmul_2048x512x1024() { // CHECK: %[[PB:.*]] = bufferization.alloc_tensor() copy(%[[B]]) {bufferization.escape = [false]} : tensor<32x128xf32> // CHECK: %[[M:.*]] = linalg.matmul {{.*}} ins(%[[PA]], %[[PB]] : tensor<32x32xf32>, tensor<32x128xf32>) outs(%{{.*}} : tensor<32x128xf32>) -> tensor<32x128xf32> // CHECK: scf.yield %[[M]] : tensor<32x128xf32> + +// ----- + +func.func @matmul_1x384x384() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 384], strides = [1, 1] : !flow.dispatch.tensor -> tensor<1x384xf32> + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %4 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x] + %5 = flow.dispatch.tensor.load %2, offsets = [0, %4], sizes = [1, 128], strides = [1, 1] : !flow.dispatch.tensor -> tensor<1x128xf32> + %6 = flow.dispatch.tensor.load %1, offsets = [0, %4], sizes = [384, 128], strides = [1, 1] : !flow.dispatch.tensor -> tensor<384x128xf32> + %7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%5 : tensor<1x128xf32>) -> tensor<1x128xf32> + %8 = linalg.matmul {lowering_config = #iree_codegen.lowering_config} ins(%3, %6 : tensor<1x384xf32>, tensor<384x128xf32>) outs(%7 : tensor<1x128xf32>) -> tensor<1x128xf32> + flow.dispatch.tensor.store %8, %2, offsets = [0, %4], sizes = [1, 128], strides = [1, 1] : tensor<1x128xf32> -> !flow.dispatch.tensor + return +} + +// CHECK-LABEL: func.func @matmul_1x384x384 +// CHECK-NOT: bufferization.alloc_tensor() +// CHECK: return + +// ----- + +func.func @matmul_multi_uses() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y] + %4 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x] + %5 = flow.dispatch.tensor.load %2, offsets = [%3, %4], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor -> tensor<32x128xf32> + %6 = flow.dispatch.tensor.load %0, offsets = [%3, 0], sizes = [32, 1024], strides = [1, 1] : !flow.dispatch.tensor -> tensor<32x1024xf32> + %7 = flow.dispatch.tensor.load %1, offsets = [0, %4], sizes = [1024, 128], strides = [1, 1] : !flow.dispatch.tensor -> tensor<1024x128xf32> + %8 = linalg.fill ins(%cst : f32) outs(%5 : tensor<32x128xf32>) -> tensor<32x128xf32> + %9 = linalg.matmul ins(%6, %7 : tensor<32x1024xf32>, tensor<1024x128xf32>) outs(%8 : tensor<32x128xf32>) -> tensor<32x128xf32> + "some_use"(%6) : (tensor<32x1024xf32>) -> () + flow.dispatch.tensor.store %9, %2, offsets = [%3, %4], sizes = [32, 128], strides = [1, 1] : tensor<32x128xf32> -> !flow.dispatch.tensor + return +} + +// test corner case where the promoted value has multiple uses. +// CHECK-LABEL: func.func @matmul_multi_uses +// CHECK: %[[C:.*]] = flow.dispatch.tensor.load +// CHECK: %[[A:.*]] = flow.dispatch.tensor.load +// CHECK: %[[B:.*]] = flow.dispatch.tensor.load +// CHECK: %[[PA:.*]] = bufferization.alloc_tensor() copy(%[[A]]) {bufferization.escape = [false]} : tensor<32x1024xf32> +// CHECK: %[[PB:.*]] = bufferization.alloc_tensor() copy(%[[B]]) {bufferization.escape = [false]} : tensor<1024x128xf32> +// CHECK: %[[M:.*]] = linalg.matmul {{.*}} ins(%[[PA]], %[[PB]] : tensor<32x1024xf32>, tensor<1024x128xf32>) outs(%{{.*}} : tensor<32x128xf32>) -> tensor<32x128xf32> +// CHECK: "some_use"(%[[A]]) : (tensor<32x1024xf32>) -> () diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tile_on_tensor.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tile_on_tensor.mlir index fdcf20683dc2..649b7e63a46d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tile_on_tensor.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tile_on_tensor.mlir @@ -37,7 +37,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 * 4)> // CHECK-LABEL: func.func @add_tensor -// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index // CHECK-DAG: %[[A:.*]] = hal.interface.binding.subspan set(0) binding(0) // CHECK-DAG: %[[B:.*]] = hal.interface.binding.subspan set(0) binding(1) // CHECK-DAG: %[[C:.*]] = hal.interface.binding.subspan set(0) binding(2) @@ -103,19 +103,19 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // First the scf.foreach for the linalg.fill. // CHECK: scf.foreach_thread // then the reduction case. -// CHECK: %[[T:.*]] = scf.foreach_thread (%[[ARG:.*]]) in (%[[C64]]) shared_outs(%[[O:.+]] = %{{.+}}) -> (tensor<64xf32>) { -// CHECK: %[[OUTSLICE:.*]] = tensor.extract_slice %{{.*}}[%[[ARG]], 0] [1, 384] [1, 1] : tensor<64x384xf32> to tensor<1x384xf32> -// CHECK: %[[A:.*]] = tensor.extract_slice %[[O]][%[[ARG]]] [1] [1] : tensor<64xf32> to tensor<1xf32> -// CHECK: %[[R:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C384]] step %[[C4]] iter_args(%[[ACC:.*]] = %[[A]]) -> (tensor<1xf32>) { -// CHECK: %[[E:.*]] = tensor.extract_slice %[[OUTSLICE]][0, %[[IV]]] [1, 4] [1, 1] : tensor<1x384xf32> to tensor<1x4xf32> -// CHECK: %[[L:.*]] = linalg.generic {{.*}} ins(%[[E]] : tensor<1x4xf32>) outs(%[[ACC]] : tensor<1xf32>) -// CHECK: arith.addf -// CHECK: linalg.yield %{{.*}} : f32 -// CHECK: } -> tensor<1xf32> -// CHECK: scf.yield %[[L]] : tensor<1xf32> -// CHECK: } -// CHECK: scf.foreach_thread.perform_concurrently { -// CHECK: tensor.parallel_insert_slice %[[R]] into %[[O]][%[[ARG]]] [1] [1] : tensor<1xf32> into tensor<64xf32> -// CHECK: } -// CHECK: } {thread_dim_mapping = [0, 1, 2]} +// CHECK: %[[T:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C384]] step %[[C4]] iter_args(%[[ACC:.*]] = %{{.*}}) -> (tensor<64xf32>) { +// CHECK: %[[OUTSLICE:.*]] = tensor.extract_slice %{{.*}}[0, %[[IV]]] [64, 4] [1, 1] : tensor<64x384xf32> to tensor<64x4xf32> +// CHECK: %[[F:.*]] = scf.foreach_thread (%[[ARG:.*]]) in (%[[C64]]) shared_outs(%[[O:.+]] = %[[ACC]]) -> (tensor<64xf32>) { +// CHECK: %[[E:.*]] = tensor.extract_slice %[[OUTSLICE]][%[[ARG]], 0] [1, 4] [1, 1] : tensor<64x4xf32> to tensor<1x4xf32> +// CHECK: %[[A:.*]] = tensor.extract_slice %[[O]][%[[ARG]]] [1] [1] : tensor<64xf32> to tensor<1xf32> +// CHECK: %[[L:.*]] = linalg.generic {{.*}} ins(%[[E]] : tensor<1x4xf32>) outs(%[[A]] : tensor<1xf32>) +// CHECK: arith.addf +// CHECK: linalg.yield %{{.*}} : f32 +// CHECK: } -> tensor<1xf32> +// CHECK: scf.foreach_thread.perform_concurrently { +// CHECK: tensor.parallel_insert_slice %[[L]] into %[[O]][%[[ARG]]] [1] [1] : tensor<1xf32> into tensor<64xf32> +// CHECK: } +// CHECK: } {thread_dim_mapping = [0, 1, 2]} +// CHECK: scf.yield %[[F]] : tensor<64xf32> +// CHECK: } // CHECK: flow.dispatch.tensor.store %[[T]], %{{.}}, offsets = [%{{.*}}], sizes = [64], strides = [1] : tensor<64xf32> -> !flow.dispatch.tensor diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir index 93a56af2f622..93091c3d9ab8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir @@ -11,5 +11,5 @@ transform.structured.canonicalized_sequence failures(propagate) { // Get the function to which to apply to. %2 = transform.structured.match ops{["linalg.matmul"]} in %variant_op_2 %func = transform.get_closest_isolated_parent %2 - transform.iree.foreach_thread_to_gpu_and_translation_info %func { workgroup_size = [10, 11]} + transform.iree.map_nested_foreach_thread_to_gpu_threads %func { workgroup_size = [10, 11]} } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir index e33ceb444bf8..4232bb365373 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' %s --fold-memref-alias-ops -canonicalize -cse | FileCheck %s #device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>], legacy_sync}> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> @@ -31,26 +31,35 @@ module attributes {hal.device.targets = [#device_target_cuda]} { } // CHECK-LABEL: hal.executable public @transpose_dispatch_0 -// CHECK: hal.executable.variant public @cuda -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[CST:.+]] = arith.constant 0 -// CHECK: %[[D3:.+]] = memref.alloc() : memref<32x33xf32, 3> -// CHECK: %[[D4:.+]] = memref.subview %[[D3]][0, 0] [32, 32] [1, 1] : memref<32x33xf32, 3> to memref<32x32xf32, #{{.*}}, 3> -// CHECK: %[[D9:.+]] = memref.subview %[[D6:.+]][%{{.*}}, %{{.*}}] [32, 32] [1, 1] : memref<4096x4096xf32> to memref<32x32xf32, #{{.*}}> -// CHECK: %[[D10:.+]] = memref.subview %[[D5:.+]][%{{.*}}, %{{.*}}] [32, 32] [1, 1] : memref<4096x4096xf32> to memref<32x32xf32, #{{.*}}> +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[D0:.*]] = gpu.thread_id x +// CHECK-DAG: %[[D1:.*]] = gpu.thread_id y +// CHECK-DAG: %[[D2:.*]] = gpu.thread_id z +// CHECK-DAG: %[[D3:.*]] = memref.alloc() : memref<32x33xf32, 3> +// CHECK: %[[D4:.*]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%[[C0]]) alignment(64) : memref<4096x4096xf32> +// CHECK: memref.assume_alignment %[[D4]], 64 : memref<4096x4096xf32> +// CHECK: %[[D5:.*]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%[[C0]]) alignment(64) : memref<4096x4096xf32> +// CHECK: memref.assume_alignment %[[D5]], 64 : memref<4096x4096xf32> // CHECK: gpu.barrier -// CHECK: %[[D13:.+]] = vector.transfer_read %[[D10]][%{{.*}}, %{{.*}}], %[[CST]] {in_bounds = [true]} : memref<32x32xf32, #{{.*}}>, vector<4xf32> -// CHECK: vector.transfer_write %[[D13]], %[[D4]][%{{.*}}, %{{.*}}] {in_bounds = [true]} : vector<4xf32>, memref<32x32xf32, #{{.*}}, 3> +// CHECK: %[[D6:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %[[D1]], %[[D2]], %{{.*}}] +// CHECK: %[[D7:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %{{.*}}] +// CHECK: %[[D8:.*]] = vector.transfer_read %[[D4]]{{\[}}%[[D6]], %[[D7]]], %[[CST]] {in_bounds = [true, true]} : memref<4096x4096xf32>, vector<1x4xf32> +// CHECK: %[[D9:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %[[D1]], %[[D2]]] +// CHECK: %[[D10:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]]] +// CHECK: vector.transfer_write %[[D8]], %[[D3]]{{\[}}%[[D9]], %[[D10]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<32x33xf32, 3> // CHECK: gpu.barrier -// CHECK: %[[D15:.+]] = memref.subview %[[D4]][%{{.*}}, %{{.*}}] [4, 1] [1, 1] : memref<32x32xf32, #{{.*}}, 3> to memref<4x1xf32, #{{.*}}, 3> -// CHECK: %[[D16:.+]] = memref.subview %[[D9]][%{{.*}}, %{{.*}}] [1, 4] [1, 1] : memref<32x32xf32, #{{.*}}> to memref<1x4xf32, #{{.*}}> -// CHECK: %[[D17:.+]] = vector.transfer_read %[[D15]][%{{.*}}, %{{.*}}], %[[CST]] {in_bounds = [true, true]} : memref<4x1xf32, #{{.*}}, 3>, vector<4x1xf32> -// CHECK: %[[D18:.+]] = vector.shape_cast %[[D17]] : vector<4x1xf32> to vector<1x4xf32> -// CHECK: %[[D19:.+]] = vector.extract %[[D18]][0] : vector<1x4xf32> -// CHECK: vector.transfer_write %[[D19]], %[[D16]][%{{.*}}, %{{.*}}] {in_bounds = [true]} : vector<4xf32>, memref<1x4xf32, #{{.*}}> +// CHECK: %[[D11:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]]] +// CHECK: %[[D12:.*]] = vector.transfer_read %[[D3]]{{\[}}%[[D11]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, 3>, vector<4x1xf32> +// CHECK: %[[D13:.*]] = vector.shape_cast %[[D12]] : vector<4x1xf32> to vector<1x4xf32> +// CHECK: %[[D14:.*]] = vector.extract %[[D13]][0] : vector<1x4xf32> +// CHECK: %[[D15:.*]] = affine.apply #{{.*}}(){{\[}}%{{.*}}, %[[D1]]] +// CHECK: %[[D16:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %{{.*}}] +// CHECK: vector.transfer_write %[[D14]], %[[D5]]{{\[}}%[[D15]], %[[D16]]] {in_bounds = [true]} : vector<4xf32>, memref<4096x4096xf32> // ----- + #device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>], legacy_sync}> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer>]>]> @@ -85,23 +94,270 @@ module attributes {hal.device.targets = [#device_target_cuda]} { } // CHECK-LABEL: hal.executable public @transpose_single_operand_dispatch_0_generic_768x2048 -// CHECK: hal.executable.variant public @cuda -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[CST:.+]] = arith.constant 0 -// CHECK: %[[D3:.+]] = memref.alloc() : memref<32x33xf32, 3> -// CHECK: %[[D4:.+]] = memref.subview %[[D3]][0, 0] [32, 32] [1, 1] : memref<32x33xf32, 3> to memref<32x32xf32, #{{.*}}, 3> -// CHECK: %[[D5:.+]] = hal.interface.binding.subspan -// CHECK: %[[D11:.+]] = memref.subview %[[D5:.+]][%{{.*}}, %{{.*}}] [32, 32] [1, 1] : memref<2048x768xf32> to memref<32x32xf32, #{{.*}}> +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[D0:.*]] = gpu.thread_id x +// CHECK: %[[D1:.*]] = gpu.thread_id y +// CHECK: %[[D2:.*]] = gpu.thread_id z +// CHECK: %[[D3:.*]] = memref.alloc() : memref<32x33xf32, 3> +// CHECK: %[[D4:.*]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%[[C0]]) alignment(64) : memref<2048x768xf32> +// CHECK: memref.assume_alignment %[[D4]], 64 : memref<2048x768xf32> +// CHECK: %[[D5:.*]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%[[C0]]) alignment(64) : memref<768x2048xf32> +// CHECK: memref.assume_alignment %[[D5]], 64 : memref<768x2048xf32> +// CHECK: %[[D6:.*]] = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%[[C0]]) alignment(64) : memref<768x2048xf32> +// CHECK: memref.assume_alignment %[[D6]], 64 : memref<768x2048xf32> // CHECK: gpu.barrier -// CHECK: %[[D15:.+]] = vector.transfer_read %[[D11]][%{{.*}}, %{{.*}}], %[[CST]] {in_bounds = [true]} : memref<32x32xf32, #{{.*}}>, vector<4xf32> -// CHECK: vector.transfer_write %[[D15]], %[[D4]][%{{.*}}, %{{.*}}] {in_bounds = [true]} : vector<4xf32>, memref<32x32xf32, #{{.*}}, 3> +// CHECK: %[[D7:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %[[D1]], %[[D2]], %{{.*}}] +// CHECK: %[[D8:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %{{.*}}] +// CHECK: %[[D9:.*]] = vector.transfer_read %[[D4]]{{\[}}%[[D7]], %[[D8]]], %[[CST]] {in_bounds = [true, true]} : memref<2048x768xf32>, vector<1x4xf32> +// CHECK: %[[D10:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %[[D1]], %[[D2]]] +// CHECK: %[[D11:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]]] +// CHECK: vector.transfer_write %[[D9]], %[[D3]]{{\[}}%[[D10]], %[[D11]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<32x33xf32, 3> // CHECK: gpu.barrier -// CHECK: %[[D17:.+]] = memref.subview %[[D4]][%{{.*}}, %{{.*}}] [4, 1] [1, 1] : memref<32x32xf32, #{{.*}}, 3> to memref<4x1xf32, #{{.*}}, 3> -// CHECK: %[[D18:.+]] = memref.subview %{{.*}}[%{{.*}}, %{{.*}}] [1, 4] [1, 1] : memref<32x32xf32, #{{.*}}> to memref<1x4xf32, #{{.*}}> -// CHECK: %[[D19:.+]] = memref.subview %{{.*}}[%{{.*}}, %{{.*}}] [1, 4] [1, 1] : memref<32x32xf32, #{{.*}}> to memref<1x4xf32, #{{.*}}> -// CHECK: %[[D20:.+]] = vector.transfer_read %[[D17]][%{{.*}}, %{{.*}}], %[[CST]] {in_bounds = [true, true]} : memref<4x1xf32, #{{.*}}, 3>, vector<4x1xf32> -// CHECK: %[[D21:.+]] = vector.shape_cast %[[D20]] : vector<4x1xf32> to vector<1x4xf32> -// CHECK: %[[D22:.+]] = vector.transfer_read %[[D18]][%{{.*}}, %{{.*}}], %[[CST]] {in_bounds = [true]} : memref<1x4xf32, #{{.*}}>, vector<4xf32> -// CHECK: %[[D23:.+]] = vector.extract %[[D21]][0] : vector<1x4xf32> -// CHECK: %[[D24:.+]] = arith.addf %[[D23]], %[[D22]] : vector<4xf32> -// CHECK: vector.transfer_write %[[D24]], %[[D19]][%{{.*}}, %{{.*}}] {in_bounds = [true]} : vector<4xf32>, memref<1x4xf32, #{{.*}}> +// CHECK: %[[D12:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]]] +// CHECK: %[[D13:.*]] = vector.transfer_read %[[D3]]{{\[}}%[[D12]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, 3>, vector<4x1xf32> +// CHECK: %[[D14:.*]] = vector.shape_cast %[[D13]] : vector<4x1xf32> to vector<1x4xf32> +// CHECK: %[[D15:.*]] = affine.apply #{{.*}}(){{\[}}%{{.*}}, %[[D1]]] +// CHECK: %[[D16:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %{{.*}}] +// CHECK: %[[D17:.*]] = vector.transfer_read %[[D5]]{{\[}}%[[D15]], %[[D16]]], %[[CST]] {in_bounds = [true]} : memref<768x2048xf32>, vector<4xf32> +// CHECK: %[[D18:.*]] = vector.extract %[[D14]][0] : vector<1x4xf32> +// CHECK: %[[D19:.*]] = arith.addf %[[D18]], %[[D17]] : vector<4xf32> +// CHECK: vector.transfer_write %[[D19]], %[[D6]]{{\[}}%[[D15]], %[[D16]]] {in_bounds = [true]} : vector<4xf32>, memref<768x2048xf32> + +// ----- + +#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>], legacy_sync}> +#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> +#pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> +module attributes {hal.device.targets = [#device_target_cuda]} { + hal.executable @transpose_3d_no_dispatch_0_generic_768x2048x1024 { + hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb { + hal.executable.export public @transpose_3d_no_dispatch_0_generic_768x2048x1024 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @transpose_3d_no_dispatch_0_generic_768x2048x1024() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2048, 768, 1024], strides = [1, 1, 1] : !flow.dispatch.tensor -> tensor<2048x768x1024xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [768, 2048, 1024], strides = [1, 1, 1] : !flow.dispatch.tensor -> tensor<768x2048x1024xf32> + %5 = linalg.init_tensor [768, 2048, 1024] : tensor<768x2048x1024xf32> + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %4 : tensor<2048x768x1024xf32>, tensor<768x2048x1024xf32>) outs(%5 : tensor<768x2048x1024xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %7 = arith.addf %arg0, %arg1 : f32 + linalg.yield %7 : f32 + } -> tensor<768x2048x1024xf32> + flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [768, 2048, 1024], strides = [1, 1, 1] : tensor<768x2048x1024xf32> -> !flow.dispatch.tensor + return + } + } + } + } +} + +// CHECK-LABEL: hal.executable public @transpose_3d_no_dispatch_0_generic_768x2048x1024 { +// CHECK-NOT: gpu.barrier +// CHECK-NOT: memref.alloc +// CHECK: return + +// ----- + +#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>], legacy_sync}> +#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> +#pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> +module attributes {hal.device.targets = [#device_target_cuda]} { + hal.executable @transpose_3d_yes_dispatch_0_generic_10x768x2048 { + hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb { + hal.executable.export public @transpose_3d_yes_dispatch_0_generic_10x768x2048 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @transpose_3d_yes_dispatch_0_generic_10x768x2048() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [10, 2048, 768], strides = [1, 1, 1] : !flow.dispatch.tensor -> tensor<10x2048x768xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [10, 768, 2048], strides = [1, 1, 1] : !flow.dispatch.tensor -> tensor<10x768x2048xf32> + %5 = linalg.init_tensor [10, 768, 2048] : tensor<10x768x2048xf32> + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %4 : tensor<10x2048x768xf32>, tensor<10x768x2048xf32>) outs(%5 : tensor<10x768x2048xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %7 = arith.addf %arg0, %arg1 : f32 + linalg.yield %7 : f32 + } -> tensor<10x768x2048xf32> + flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [10, 768, 2048], strides = [1, 1, 1] : tensor<10x768x2048xf32> -> !flow.dispatch.tensor + return + } + } + } + } +} + +// CHECK-LABEL: hal.executable public @transpose_3d_yes_dispatch_0_generic_10x768x2048 { +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[D0:.*]] = gpu.thread_id x +// CHECK: %[[D1:.*]] = gpu.thread_id y +// CHECK: %[[D2:.*]] = gpu.thread_id z +// CHECK: %[[D3:.*]] = memref.alloc() : memref<1x32x33xf32, 3> +// CHECK: %[[D4:.*]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%[[C0]]) alignment(64) : memref<10x2048x768xf32> +// CHECK: memref.assume_alignment %[[D4]], 64 : memref<10x2048x768xf32> +// CHECK: %[[D5:.*]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%[[C0]]) alignment(64) : memref<10x768x2048xf32> +// CHECK: memref.assume_alignment %[[D5]], 64 : memref<10x768x2048xf32> +// CHECK: %[[D6:.*]] = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%[[C0]]) alignment(64) : memref<10x768x2048xf32> +// CHECK: memref.assume_alignment %[[D6]], 64 : memref<10x768x2048xf32> +// CHECK: gpu.barrier +// CHECK: %[[D7:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %[[D1]], %[[D2]], %{{.*}}] +// CHECK: %[[D8:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %{{.*}}] +// CHECK: %[[D9:.*]] = vector.transfer_read %[[D4]]{{\[}}%{{.*}}, %[[D7]], %[[D8]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x2048x768xf32>, vector<1x1x4xf32> +// CHECK: %[[D10:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %[[D1]], %[[D2]]] +// CHECK: %[[D11:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]]] +// CHECK: vector.transfer_write %[[D9]], %[[D3]]{{\[}}%[[C0]], %[[D10]], %[[D11]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, 3> +// CHECK: gpu.barrier +// CHECK: %[[D12:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]]] +// CHECK: %[[D13:.*]] = vector.transfer_read %[[D3]]{{\[}}%[[C0]], %[[D12]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, 3>, vector<4x1xf32> +// CHECK: %[[D14:.*]] = vector.broadcast %[[D13]] : vector<4x1xf32> to vector<1x4x1xf32> +// CHECK: %[[D15:.*]] = vector.shape_cast %[[D14]] : vector<1x4x1xf32> to vector<1x1x4xf32> +// CHECK: %[[D16:.*]] = affine.apply #{{.*}}(){{\[}}%{{.*}}, %[[D1]]] +// CHECK: %[[D17:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %{{.*}}] +// CHECK: %[[D18:.*]] = vector.transfer_read %[[D5]]{{\[}}%{{.*}}, %[[D16]], %[[D17]]], %[[CST]] {in_bounds = [true]} : memref<10x768x2048xf32>, vector<4xf32> +// CHECK: %[[D19:.*]] = vector.extract %[[D15]][0, 0] : vector<1x1x4xf32> +// CHECK: %[[D20:.*]] = arith.addf %[[D19]], %[[D18]] : vector<4xf32> +// CHECK: vector.transfer_write %[[D20]], %[[D6]]{{\[}}%{{.*}}, %[[D16]], %[[D17]]] {in_bounds = [true]} : vector<4xf32>, memref<10x768x2048xf32> + +// ----- + +#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>], legacy_sync}> +#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> +#pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> +module attributes {hal.device.targets = [#device_target_cuda]} { + hal.executable @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 { + hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb { + hal.executable.export public @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @transpose_3d_trans_out_dispatch_0_generic_10x2048x768() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [10, 768, 2048], strides = [1, 1, 1] : !flow.dispatch.tensor -> tensor<10x768x2048xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [10, 768, 2048], strides = [1, 1, 1] : !flow.dispatch.tensor -> tensor<10x768x2048xf32> + %5 = linalg.init_tensor [10, 2048, 768] : tensor<10x2048x768xf32> + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %4 : tensor<10x768x2048xf32>, tensor<10x768x2048xf32>) outs(%5 : tensor<10x2048x768xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %7 = arith.addf %arg0, %arg1 : f32 + linalg.yield %7 : f32 + } -> tensor<10x2048x768xf32> + flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [10, 2048, 768], strides = [1, 1, 1] : tensor<10x2048x768xf32> -> !flow.dispatch.tensor + return + } + } + } + } +} + +// CHECK-LABEL: hal.executable public @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 { +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[D0:.*]] = gpu.thread_id x +// CHECK: %[[D1:.*]] = gpu.thread_id y +// CHECK: %[[D2:.*]] = gpu.thread_id z +// CHECK: %[[D3:.*]] = memref.alloc() : memref<1x32x33xf32, 3> +// CHECK: %[[D4:.*]] = memref.alloc() : memref<1x32x33xf32, 3> +// CHECK: %[[D5:.*]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%[[C0]]) alignment(64) : memref<10x768x2048xf32> +// CHECK: memref.assume_alignment %[[D5]], 64 : memref<10x768x2048xf32> +// CHECK: %[[D6:.*]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%[[C0]]) alignment(64) : memref<10x768x2048xf32> +// CHECK: memref.assume_alignment %[[D6]], 64 : memref<10x768x2048xf32> +// CHECK: %[[D7:.*]] = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%[[C0]]) alignment(64) : memref<10x2048x768xf32> +// CHECK: memref.assume_alignment %[[D7]], 64 : memref<10x2048x768xf32> +// CHECK: gpu.barrier +// CHECK: %[[D8:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %[[D1]], %[[D2]], %{{.*}}] +// CHECK: %[[D9:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %{{.*}}] +// CHECK: %[[D10:.*]] = vector.transfer_read %[[D5]]{{\[}}%{{.*}}, %[[D8]], %[[D9]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x768x2048xf32>, vector<1x1x4xf32> +// CHECK: %[[D11:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %[[D1]], %[[D2]]] +// CHECK: %[[D12:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]]] +// CHECK: vector.transfer_write %[[D10]], %[[D4]]{{\[}}%[[C0]], %[[D11]], %[[D12]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, 3> +// CHECK: %[[D13:.*]] = vector.transfer_read %[[D6]]{{\[}}%{{.*}}, %[[D8]], %[[D9]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x768x2048xf32>, vector<1x1x4xf32> +// CHECK: vector.transfer_write %[[D13]], %[[D3]]{{\[}}%[[C0]], %[[D11]], %[[D12]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, 3> +// CHECK: gpu.barrier +// CHECK: %[[D14:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]]] +// CHECK: %[[D15:.*]] = vector.transfer_read %[[D4]]{{\[}}%[[C0]], %[[D14]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, 3>, vector<4x1xf32> +// CHECK: %[[D16:.*]] = vector.transfer_read %[[D3]]{{\[}}%[[C0]], %[[D14]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, 3>, vector<4x1xf32> +// CHECK: %[[D17:.*]] = arith.addf %[[D15]], %[[D16]] : vector<4x1xf32> +// CHECK: %[[D18:.*]] = vector.broadcast %[[D17]] : vector<4x1xf32> to vector<1x4x1xf32> +// CHECK: %[[D19:.*]] = vector.shape_cast %[[D18]] : vector<1x4x1xf32> to vector<1x1x4xf32> +// CHECK: %[[D20:.*]] = vector.extract %[[D19]][0, 0] : vector<1x1x4xf32> +// CHECK: %[[D21:.*]] = affine.apply #{{.*}}(){{\[}}%{{.*}}, %[[D1]]] +// CHECK: %[[D22:.*]] = affine.apply #{{.*}}(){{\[}}%[[D0]], %{{.*}}] +// CHECK: vector.transfer_write %[[D20]], %[[D7]]{{\[}}%{{.*}}, %[[D21]], %[[D22]]] {in_bounds = [true]} : vector<4xf32>, memref<10x2048x768xf32> + +// ----- + +#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>], legacy_sync}> +#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> +#pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> +module attributes {hal.device.targets = [#device_target_cuda]} { + hal.executable @transpose_3d_diff_dispatch_0_generic_10x768x2048 { + hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb { + hal.executable.export public @transpose_3d_diff_dispatch_0_generic_10x768x2048 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @transpose_3d_diff_dispatch_0_generic_10x768x2048() { + %c256 = arith.constant 256 : index + %c10 = arith.constant 10 : index + %c768 = arith.constant 768 : index + %c2048 = arith.constant 2048 : index + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + scf.for %arg0 = %workgroup_id_z to %c10 step %workgroup_count_z { + scf.for %arg1 = %workgroup_id_y to %c768 step %workgroup_count_y { + %3 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_id_x] + %4 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_count_x] + scf.for %arg2 = %3 to %c2048 step %4 { + %5 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg2, %arg1], sizes = [1, %c256, 1], strides = [1, 1, 1] : !flow.dispatch.tensor -> tensor<1x?x1xf32> + %6 = flow.dispatch.tensor.load %1, offsets = [%arg2, %arg1, %arg0], sizes = [%c256, 1, 1], strides = [1, 1, 1] : !flow.dispatch.tensor -> tensor + %7 = linalg.init_tensor [1, 1, 256] : tensor<1x1x256xf32> + %8 = tensor.cast %5 : tensor<1x?x1xf32> to tensor<1x256x1xf32> + %9 = tensor.cast %6 : tensor to tensor<256x1x1xf32> + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d2, d1, d0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8, %9 : tensor<1x256x1xf32>, tensor<256x1x1xf32>) outs(%7 : tensor<1x1x256xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %12 = arith.addf %arg3, %arg4 : f32 + linalg.yield %12 : f32 + } -> tensor<1x1x256xf32> + %11 = tensor.cast %10 : tensor<1x1x256xf32> to tensor<1x1x?xf32> + flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 1, %c256], strides = [1, 1, 1] : tensor<1x1x?xf32> -> !flow.dispatch.tensor + } + } + } + return + } + } + } + } +} + +// CHECK-LABEL: hal.executable public @transpose_3d_diff_dispatch_0_generic_10x768x2048 { +// CHECK-NOT: gpu.barrier +// CHECK-NOT: memref.alloc +// CHECK: return diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_to_gpu.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_to_gpu.mlir index 67cab0238ab0..0b2e30f92572 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_to_gpu.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_to_gpu.mlir @@ -29,15 +29,14 @@ func.func @ksplitmatmul_basic(%a: memref<128x16x256xf32>) -> vector<16x1x8xf32> %0 = vector.transfer_read %a[%c2, %c3, %c4], %cst {in_bounds = [true, true, true]} : memref<128x16x256xf32>, vector<16x1x8xf32> return %0 : vector<16x1x8xf32> } -// CHECK-DAG:#[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 4096 + d1 + 8964)> // CHECK-LABEL: func.func @ksplitmatmul_basic // CHECK-DAG: %[[ID:.*]] = arith.constant 0 : index // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[M:.*]] = memref.subview // CHECK-SAME:[2, 3, 4] [16, 1, 8] [1, 1, 1] -// CHECK-SAME:memref<128x16x256xf32> to memref<16x8xf32, #[[$MAP]]> +// CHECK-SAME:memref<128x16x256xf32> to memref<16x8xf32, strided<[4096, 1], offset: 8964>> // CHECK: vector.transfer_read %[[M]][%[[ID]], %[[ID]]] -// CHECK-SAME: {in_bounds = [true, true]} : memref<16x8xf32, #[[$MAP]]>, vector<16x8xf32> +// CHECK-SAME: {in_bounds = [true, true]} : memref<16x8xf32, strided<[4096, 1], offset: 8964>>, vector<16x8xf32> // CHECK: vector.broadcast %{{.*}} : vector<16x8xf32> to vector<1x16x8xf32> // CHECK: vector.transpose %{{.*}} [1, 0, 2] : vector<1x16x8xf32> to vector<16x1x8xf32> // CHECK: return %{{.*}} : vector<16x1x8xf32> @@ -72,15 +71,14 @@ func.func @ksplitmatmul_4D(%a: memref<128x16x32x256xf32>) -> vector<16x1x1x8xf32 %0 = vector.transfer_read %a[%c2, %c3, %c4, %c5], %cst {in_bounds = [true, true, true, true]} : memref<128x16x32x256xf32>, vector<16x1x1x8xf32> return %0 : vector<16x1x1x8xf32> } -// CHECK-DAG:#[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 131072 + d1 + 287749)> // CHECK-LABEL: func.func @ksplitmatmul_4D // CHECK-DAG: %[[ID:.*]] = arith.constant 0 : index // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[M:.*]] = memref.subview // CHECK-SAME:[2, 3, 4, 5] [16, 1, 1, 8] [1, 1, 1, 1] -// CHECK-SAME: memref<128x16x32x256xf32> to memref<16x8xf32, #[[$MAP]]> +// CHECK-SAME: memref<128x16x32x256xf32> to memref<16x8xf32, strided<[131072, 1], offset: 287749>> // CHECK: vector.transfer_read %[[M]][%[[ID]], %[[ID]]] -// CHECK-SAME: {in_bounds = [true, true]} : memref<16x8xf32, #[[$MAP]]>, vector<16x8xf32> +// CHECK-SAME: {in_bounds = [true, true]} : memref<16x8xf32, strided<[131072, 1], offset: 287749>>, vector<16x8xf32> // CHECK: vector.broadcast %{{.*}} : vector<16x8xf32> to vector<1x1x16x8xf32> // CHECK: vector.transpose %{{.*}} [2, 0, 1, 3] : vector<1x1x16x8xf32> to vector<16x1x1x8xf32> // CHECK: return %{{.*}} : vector<16x1x1x8xf32> @@ -96,15 +94,14 @@ func.func @ksplitmatmul_4D_lower_rank_read(%a: memref<128x512x32x256xf32>) -> ve %0 = vector.transfer_read %a[%c2, %c3, %c4, %c5], %cst {in_bounds = [true, true, true]} : memref<128x512x32x256xf32>, vector<16x1x8xf32> return %0 : vector<16x1x8xf32> } -// CHECK-DAG:#[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 8192 + d1 + 8414213)> // CHECK-LABEL: func.func @ksplitmatmul_4D_lower_rank_read // CHECK-DAG: %[[ID:.*]] = arith.constant 0 : index // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[M:.*]] = memref.subview // CHECK-SAME:[2, 3, 4, 5] [1, 16, 1, 8] [1, 1, 1, 1] -// CHECK-SAME: memref<128x512x32x256xf32> to memref<16x8xf32, #[[$MAP]]> +// CHECK-SAME: memref<128x512x32x256xf32> to memref<16x8xf32, strided<[8192, 1], offset: 8414213>> // CHECK: vector.transfer_read %[[M]][%[[ID]], %[[ID]]] -// CHECK-SAME: {in_bounds = [true, true]} : memref<16x8xf32, #[[$MAP]]>, vector<16x8xf32> +// CHECK-SAME: {in_bounds = [true, true]} : memref<16x8xf32, strided<[8192, 1], offset: 8414213>>, vector<16x8xf32> // CHECK: vector.broadcast %{{.*}} : vector<16x8xf32> to vector<1x16x8xf32> // CHECK: vector.transpose %{{.*}} [1, 0, 2] : vector<1x16x8xf32> to vector<16x1x8xf32> // CHECK: return %{{.*}} : vector<16x1x8xf32> @@ -143,15 +140,14 @@ func.func @ksplitmatmul_4D_allone(%a: memref<128x16x32x256xf32>) -> vector<1x1x1 return %0 : vector<1x1x1x1xf32> } -// CHECK-DAG:#[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 131072 + d1 * 8192 + 287749)> // CHECK-LABEL: func.func @ksplitmatmul_4D_allone // CHECK-DAG: %[[ID:.*]] = arith.constant 0 : index // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[M:.*]] = memref.subview // CHECK-SAME:[2, 3, 4, 5] [1, 1, 1, 1] [1, 1, 1, 1] -// CHECK-SAME: memref<128x16x32x256xf32> to memref<1x1xf32, #[[$MAP]]> +// CHECK-SAME: memref<128x16x32x256xf32> to memref<1x1xf32, strided<[131072, 8192], offset: 287749>> // CHECK: vector.transfer_read %[[M]][%[[ID]], %[[ID]]] -// CHECK-SAME: {in_bounds = [true, true]} : memref<1x1xf32, #[[$MAP]]>, vector<1x1xf32> -// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x1x1xf32> +// CHECK-SAME: {in_bounds = [true]} : memref<1x1xf32, strided<[131072, 8192], offset: 287749>>, vector<1xf32> +// CHECK: vector.broadcast %{{.*}} : vector<1xf32> to vector<1x1x1x1xf32> // CHECK-NOT: vector.transpose // CHECK: return %{{.*}} : vector<1x1x1x1xf32> diff --git a/compiler/src/iree/compiler/Codegen/Passes.cpp b/compiler/src/iree/compiler/Codegen/Passes.cpp index f7eb1ea7c614..3a99195d5099 100644 --- a/compiler/src/iree/compiler/Codegen/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/Passes.cpp @@ -22,7 +22,7 @@ void registerCodegenPasses() { registerPasses(); registerSandboxPasses(); - static PassPipelineRegistration<> LinalgLLVMVPipeline( + static PassPipelineRegistration<> LinalgLLVMPipeline( "iree-codegen-linalg-to-llvm-pipeline", "Runs the progressive lowering pipeline from Linalg to LLVM", [](OpPassManager &passManager) { @@ -49,6 +49,20 @@ void registerCodegenPasses() { [](OpPassManager &passManager) { buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false); }); + + static PassPipelineRegistration<> LLVMCPULinkingPipeline( + "iree-codegen-llvmcpu-linking-pipeline", + "Runs the LLVMCPU HAL executable linking pipeline", + [](OpPassManager &passManager) { + buildLLVMCPULinkingPassPipeline(passManager); + }); + + static PassPipelineRegistration<> VMVXLinkingPipeline( + "iree-codegen-vmvx-linking-pipeline", + "Runs the VMVX HAL executable linking pipeline", + [](OpPassManager &passManager) { + buildVMVXLinkingPassPipeline(passManager); + }); } /// Hook to verify the lowering configuration and translation info for an diff --git a/compiler/src/iree/compiler/Codegen/Passes.h b/compiler/src/iree/compiler/Codegen/Passes.h index b257681a4983..b1eb706265d2 100644 --- a/compiler/src/iree/compiler/Codegen/Passes.h +++ b/compiler/src/iree/compiler/Codegen/Passes.h @@ -97,10 +97,6 @@ createRemoveSingleIterationLoopPass(); std::unique_ptr> createConvertToDestinationPassingStylePass(); -/// Creates a pass to vectorize a very specific form of linalg.conv ops. -std::unique_ptr> -createLinalgToVectorVectorizeConvPass(); - /// Creates a pass to vectorize a very specific form of tensor.pad ops with /// control flows. std::unique_ptr> createVectorizePadPass(); @@ -133,7 +129,7 @@ createGPUDistributeSharedMemoryCopy(); /// Apply software pipelining. std::unique_ptr> createGPUPipeliningPass( - unsigned depth = 1); + bool epiloguePeeling = true, unsigned depth = 1); /// Converts vector ops to gpu dialect. std::unique_ptr> createWorkGroupSwizzle( @@ -315,6 +311,17 @@ void addCPUAArchDoubleTilingExpertPassPipeline(OpPassManager &passManager); /// module within the IREE::HAL::ExecutableOp. void buildLLVMCPUCodegenPassPipeline(OpPassManager &passManager); +//----------------------------------------------------------------------------// +// LLVMCPU Linking Passes and Pipelines +//----------------------------------------------------------------------------// + +/// Links LLVMCPU HAL executables within the top-level program module. +std::unique_ptr> +createLLVMCPULinkExecutablesPass(); + +/// Populates passes needed to link HAL executables across LLVMCPU targets. +void buildLLVMCPULinkingPassPipeline(OpPassManager &passManager); + //------------------------------------------------------------------------------ // LLVMGPU //------------------------------------------------------------------------------ @@ -370,16 +377,16 @@ std::unique_ptr> createConvertToROCDLPass(); /// Perform tiling and distribution to threads. std::unique_ptr> createLLVMGPUTileAndDistribute( - bool distributeToWarp = false, - GPUPromoteSharedMemPattern promoteSharedMemPattern = - GPUPromoteSharedMemPattern::ContractionOpPattern); + bool distributeToWarp = false); std::unique_ptr> createLLVMGPUTileTensor( bool distributeToWarp = false); std::unique_ptr> createLLVMGPUDistribute(); -std::unique_ptr> createLLVMGPUTensorAlloc(); +std::unique_ptr> createLLVMGPUTensorAlloc( + GPUPromoteSharedMemPattern promoteSharedMemPattern = + GPUPromoteSharedMemPattern::ContractionOpPattern); /// Create pass calling the dynamic pipeline for LLVMGPU. std::unique_ptr> @@ -504,6 +511,16 @@ void buildSPIRVCodegenPassPipeline(OpPassManager &pm, bool enableFastMath); // at the bufferized linalg level. std::unique_ptr createVMVXLowerLinalgMicrokernelsPass(); +//----------------------------------------------------------------------------// +// VMVX Linking Passes and Pipelines +//----------------------------------------------------------------------------// + +/// Links VMVX HAL executables within the top-level program module. +std::unique_ptr> createVMVXLinkExecutablesPass(); + +/// Populates passes needed to link HAL executables across VMVX targets. +void buildVMVXLinkingPassPipeline(OpPassManager &passManager); + //------------------------------------------------------------------------------ // Test passes //------------------------------------------------------------------------------ diff --git a/compiler/src/iree/compiler/Codegen/Passes.td b/compiler/src/iree/compiler/Codegen/Passes.td index 80ac1da17563..8e47d8a950ad 100644 --- a/compiler/src/iree/compiler/Codegen/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Passes.td @@ -124,14 +124,6 @@ def RemoveSingleIterationLoop : let constructor = "mlir::iree_compiler::createRemoveSingleIterationLoopPass()"; } -// TODO: Rename argument to be fully qualified. -def LinalgToVectorVectorizeConv : - Pass<"iree-codegen-vectorize-linalg-conv", "func::FuncOp"> { - let summary = "Vectorize a very specific form of linalg.conv"; - let constructor = - "mlir::iree_compiler::createLinalgToVectorVectorizeConvPass()"; -} - def TensorToVectorVectorizePad : Pass<"iree-codegen-vectorize-tensor-pad", "func::FuncOp"> { let summary = "Vectorize a very specific form of tensor.pad with " @@ -163,6 +155,11 @@ def GPUDistributeSharedMemoryCopy : def GPUPipelining : Pass<"iree-gpu-pipelining", "func::FuncOp"> { let summary = "Pass to do software pipelining."; let constructor = "mlir::iree_compiler::createGPUPipeliningPass()"; + let options = [ + Option<"epiloguePeeling", "epilogue-peeling", "bool", + /*default=*/"true", + "Try to use un-peeling epilogue when false, peeled epilouge o.w.">, + ]; } def WorkGroupSwizzle : @@ -302,6 +299,12 @@ def VectorContractCustomKernels : ]; } +def LLVMCPULinkExecutables : + Pass<"iree-llvmcpu-link-executables", "mlir::ModuleOp"> { + let summary = "Links LLVMCPU HAL executables within the top-level program module."; + let constructor = "mlir::iree_compiler::createLLVMCPULinkExecutablesPass()"; +} + //------------------------------------------------------------------------------ // LLVMGPU //------------------------------------------------------------------------------ @@ -473,6 +476,12 @@ def VMVXLowerLinalgMicrokernels : ]; } +def VMVXLinkExecutables : + Pass<"iree-vmvx-link-executables", "mlir::ModuleOp"> { + let summary = "Links VMVX HAL executables within the top-level program module."; + let constructor = "mlir::iree_compiler::createVMVXLinkExecutablesPass()"; +} + //------------------------------------------------------------------------------ // Test Passes //------------------------------------------------------------------------------ diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp index f6d1b32e4b22..eeab8ef65db3 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp @@ -15,6 +15,7 @@ #include "iree/compiler/Codegen/Utils/Utils.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #define DEBUG_TYPE "iree-spirv-amd-config" @@ -25,7 +26,7 @@ namespace detail { static LogicalResult setAMDMatmulConfig(linalg::LinalgOp op, int subgroupSize) { const std::array workgroupXY = {subgroupSize / 2, 8}; - const std::array threadMNK = {8, 4, 32}; + const std::array threadMNK = {8, 4, 16}; return setMatmulOpConfig(op, subgroupSize, workgroupXY, threadMNK, /*useWorkgroupMemory=*/true); } @@ -55,6 +56,18 @@ LogicalResult setAMDCodeGenConfig(const spirv::TargetEnv &targetEnv, .Case([subgroupSize](auto op) { return setAMDMatmulConfig(op, subgroupSize); }) + .Case([subgroupSize](auto op) { + bool hasPaddedInput = + op.image().template getDefiningOp(); + int bestTilingFactor = hasPaddedInput ? 16 : 32; + return setConvOpConfig(op, subgroupSize, bestTilingFactor); + }) + .Case([subgroupSize](auto op) { + bool hasPaddedInput = + op.image().template getDefiningOp(); + int bestTilingFactor = hasPaddedInput ? 16 : 32; + return setConvOpConfig(op, subgroupSize, bestTilingFactor); + }) .Default([](Operation *) { return success(); }); } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index dc551c691ace..14f41e0a1937 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -10,6 +10,7 @@ #include #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/compiler/Codegen/Common/UserConfig.h" #include "iree/compiler/Codegen/Dialect/LoweringConfig.h" #include "iree/compiler/Codegen/SPIRV/Utils.h" #include "iree/compiler/Codegen/Transforms/Transforms.h" @@ -159,9 +160,11 @@ LogicalResult setConvOpConfig(linalg::LinalgOp linalgOp, tileSizes.push_back(invocationTileSizes); // Tiling along reduction dimensions if (isa(linalgOp)) { - tileSizes.push_back({0, 0, 0, 0, 1, 1, 4}); + tileSizes.push_back({0, 0, 0, 0, 1, 1, 4}); // (N, OH, OW, OC, FH, FW, IC) + tileSizes.push_back({0, 1, 0, 0}); } else if (isa(linalgOp)) { - tileSizes.push_back({0, 0, 0, 0, 1, 1}); + tileSizes.push_back({0, 0, 0, 0, 1, 1}); // (N, OH, OW, C, FH, FW) + tileSizes.push_back({0, 1, 0, 0}); } else { return success(); } @@ -211,7 +214,7 @@ LogicalResult setMatmulOpConfig(linalg::LinalgOp op, int64_t subgroupSize, int bIndex = -1, mIndex = -1, nIndex = -1, kIndex = -1; int lastParallelDim = -1; for (unsigned i = 0; i < op.getNumLoops(); ++i) { - if (isReductionIterator(op.getIteratorTypes()[i])) { + if (linalg::isReductionIterator(op.getIteratorTypes()[i])) { kIndex = i; continue; } @@ -663,8 +666,10 @@ static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits, for (const auto &it : llvm::enumerate(linalgOp.getIteratorTypes())) { auto i = it.index(); if (loopBounds[i] % 4 != 0) continue; - if (isReductionIterator(it.value()) || workgroupTileSizes[i] == 0) + if (linalg::isReductionIterator(it.value()) || + workgroupTileSizes[i] == 0) { loopTileSizes[it.index()] = 4; + } } if (llvm::any_of(loopTileSizes, [](int64_t s) { return s != 0; })) { tileSizes.push_back(loopTileSizes); @@ -682,7 +687,15 @@ static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits, /// Sets the CodeGen configuration as attributes to the given `rootOp` if it's a /// known Linalg matmul/convolution op with good configurations. static LogicalResult setSPIRVOpConfig(const spirv::TargetEnv &targetEnv, + func::FuncOp entryPointFn, Operation *rootOp) { + if (IREE::Codegen::CompilationInfoAttr compilationInfo = + getCompilationInfo(rootOp)) { + // If the op already has a lowering configuration specified from the + // original source by the user, then use it directly. + return setUserConfig(entryPointFn, rootOp, compilationInfo); + } + LogicalResult result = success(); // First try to find a proper CodeGen configuration to tile and vectorize for // the current target architecture. @@ -796,7 +809,8 @@ LogicalResult initSPIRVLaunchConfig(ModuleOp module) { // Try to find a configuration according to a matmul/convolution op and use // it as the root op. for (Operation *computeOp : computeOps) { - if (failed(setSPIRVOpConfig(targetEnv, computeOp))) return failure(); + if (failed(setSPIRVOpConfig(targetEnv, funcOp, computeOp))) + return failure(); // Check if the op configuration was set. if (!getLoweringConfig(computeOp)) continue; diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVCreateFastSlowPath.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVCreateFastSlowPath.cpp index ffe55be42f10..3a4da85a238c 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVCreateFastSlowPath.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVCreateFastSlowPath.cpp @@ -57,7 +57,7 @@ struct CreateFastSlowPath final : public OpRewritePattern { LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const override { // Flow tiled and distributed loops do not carry values. - if (!llvm::empty(forOp.getIterOpOperands())) return failure(); + if (!forOp.getIterOpOperands().empty()) return failure(); Block *forBody = forOp.getBody(0); // Find the anchor tensor.pad op, from which we get the conditions for diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp index 2b929e61327e..1a3d17908dd0 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp @@ -26,6 +26,8 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -186,7 +188,7 @@ class SPIRVTilePass final : public SPIRVTileBase { } { // Tile reduction dimensions. - RewritePatternSet tilingPatterns(&getContext()); + RewritePatternSet tilingPatterns(context); populateTilingReductionPatterns(tilingPatterns); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns)))) { @@ -214,6 +216,53 @@ class SPIRVTilePass final : public SPIRVTileBase { }); } + { // Tile convolution output window dimension by 1 to prepare downsizing. + SmallVector convOps; + funcOp.walk([&convOps](linalg::ConvolutionOpInterface convOp) { + convOps.push_back(convOp); + }); + for (linalg::ConvolutionOpInterface convOp : convOps) { + auto consumerOp = cast(*convOp); + OpBuilder builder(context); + SmallVector tileSizes = getTileSizes(consumerOp, 3); + auto identityLoopOrder = + llvm::to_vector<4>(llvm::seq(0, tileSizes.size())); + + FailureOr loopNest = + linalg::tileConsumerAndFuseProducers(builder, consumerOp, tileSizes, + identityLoopOrder, llvm::None); + if (failed(loopNest)) { + consumerOp.emitOpError("failed tiling and fusing producers"); + return signalPassFailure(); + } + + consumerOp->replaceAllUsesWith(loopNest->getRootOpReplacementResults()); + + // Fully unroll the generated loop. This allows us to remove the loop + // for parallel output window dimension, so it helps future vector + // transformations. + if (!loopNest->getLoopOps().empty()) { + assert(loopNest->getLoopOps().size() == 1); + scf::ForOp loopOp = loopNest->getLoopOps().front(); + IntegerAttr ub; + if (!matchPattern(loopOp.getUpperBound(), m_Constant(&ub))) { + loopOp.emitOpError("upper bound should be a constant"); + return signalPassFailure(); + } + if (failed(mlir::loopUnrollByFactor(loopOp, ub.getInt()))) { + loopOp.emitOpError("failed unrolling by factor 1"); + return signalPassFailure(); + } + } + + LLVM_DEBUG({ + llvm::dbgs() << "--- After tiling convolution output window ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + } + { RewritePatternSet patterns(context); populateConcretizePadResultShapePatterns(context, patterns); @@ -225,6 +274,26 @@ class SPIRVTilePass final : public SPIRVTileBase { llvm::dbgs() << "\n\n"; }); } + + { // Downsize n-D (n > 1) convolutions to 1-D. + RewritePatternSet patterns(context); + linalg::populateDecomposeConvolutionPatterns(patterns); + // Downsizing creates consecutive extract/insert slice ops. Merge them. + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + // Pull in patterns to fold constant insert/extract slice op parameters. + tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, context); + tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context); + // Pull in scf.for op canonicalization patterns to help hoisting across + // multiple loops and remove loop carried values unused in the body. + scf::ForOp::getCanonicalizationPatterns(patterns, context); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + + LLVM_DEBUG({ + llvm::dbgs() << "--- After Downsizing N-D convolution to 1-D ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } } }; } // namespace diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp index 6b51c7afed14..79b826f35b43 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp @@ -121,13 +121,13 @@ struct ConvertVectorContractOp final LogicalResult matchAndRewrite( vector::ContractionOp contractOp, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { - if (!llvm::empty(contractOp.getMasks())) return failure(); + if (!contractOp.getMasks().empty()) return failure(); // Check that this is a matmul operation. auto iterators = contractOp.getIteratorTypes().getValue(); - if (iterators.size() != 3 || !isParallelIterator(iterators[0]) || - !isParallelIterator(iterators[1]) || - !isReductionIterator(iterators[2])) { + if (iterators.size() != 3 || !vector::isParallelIterator(iterators[0]) || + !vector::isParallelIterator(iterators[1]) || + !vector::isReductionIterator(iterators[2])) { return failure(); } if (contractOp.getKind() != vector::CombiningKind::ADD) return failure(); @@ -233,19 +233,19 @@ struct SPIRVVectorToCooperativeOpsPass final // a result of performing cooperative matrix conversions earlier (it needs // to be done before FlattenMemRefSubspanPass because we need 2-D MemRefs) // and conversions spreading across upstream and IREE repos.. - typeConverter.addConversion( - [&typeConverter](MemRefType type) -> Optional { - if (!type.hasStaticShape()) return llvm::None; - // In IREE all MemRefs are originated from subspan ops, which should - // have identity layout. - if (!type.getLayout().isIdentity()) return llvm::None; - auto storage = spirv::mapMemorySpaceToVulkanStorageClass( - type.getMemorySpaceAsInt()); - auto flattenedType = MemRefType::get( - ShapedType::kDynamicSize, type.getElementType(), AffineMap(), - spirv::StorageClassAttr::get(type.getContext(), *storage)); - return typeConverter.convertType(flattenedType); - }); + typeConverter.addConversion([&typeConverter]( + MemRefType type) -> Optional { + if (!type.hasStaticShape()) return llvm::None; + // In IREE all MemRefs are originated from subspan ops, which should + // have identity layout. + if (!type.getLayout().isIdentity()) return llvm::None; + auto storage = + spirv::mapMemorySpaceToVulkanStorageClass(type.getMemorySpaceAsInt()); + auto flattenedType = MemRefType::get( + ShapedType::kDynamicSize, type.getElementType(), AffineMap(), + spirv::StorageClassAttr::get(type.getContext(), *storage)); + return typeConverter.convertType(flattenedType); + }); // Add unrealized conversion cast ops to bridge type conversions: we are // only converting the cooperative matrix subset; the rest needs to be done diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp index b3582127dbd8..5d6f41716fbc 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp @@ -86,7 +86,7 @@ Optional> getNativeVectorShape(Operation *op) { } else if (auto contractOp = dyn_cast(op)) { unsigned lastParallelDim = 0; for (const auto &it : llvm::enumerate(contractOp.getIteratorTypes())) { - if (isParallelIterator(it.value())) lastParallelDim = it.index(); + if (vector::isParallelIterator(it.value())) lastParallelDim = it.index(); } SmallVector nativeSize(contractOp.getIteratorTypes().size(), 1); SmallVector bounds; @@ -122,6 +122,7 @@ void populateVectorizationPatterns(RewritePatternSet &patterns) { linalg::LinalgTransformationFilter f; VectorizationPatterns::insert(patterns, opt, f); + linalg::populateConvolutionVectorizationPatterns(patterns); patterns.add( patterns.getContext(), f.addOpFilter(), opt); @@ -160,7 +161,6 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase { RewritePatternSet patterns(context); populateVectorizationPatterns(patterns); // Pull in additional vectorization patterns in IREE. - populateLinalgToVectorVectorizeConvPatterns(context, patterns); populateVectorizePadPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); @@ -192,6 +192,25 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase { llvm::dbgs() << "\n\n"; }); + // Fold tensor.extract_slice/insert_slice ops into transfer ops. This helps + // to remove those tensor slice ops so that we can enable further vector op + // transformations. + { + RewritePatternSet patterns(context); + vector::TransferReadOp::getCanonicalizationPatterns(patterns, context); + vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context); + + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + } + + LLVM_DEBUG({ + llvm::dbgs() << "--- After folding tensor extract/insert slice ops ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + // Lower vector.multi_dimension early if any operand is a transpose op. // The lowering itself generates transpose ops. This helps to cancel // transpose ops. vector.multi_reduction is arguably a higher level op and @@ -209,8 +228,7 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase { RewritePatternSet patterns(context); vector::populateVectorMultiReductionLoweringPatterns( patterns, vector::VectorMultiReductionLowering::InnerParallel); - FrozenRewritePatternSet frozenSet(std::move(patterns)); - applyOpPatternsAndFold(reductionOps, frozenSet, + applyOpPatternsAndFold(reductionOps, std::move(patterns), /*strict=*/false); } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD index 73d6065f1bc6..d6d807735f83 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD @@ -29,6 +29,7 @@ iree_lit_test_suite( "config_mali_reduction.mlir", "config_nvidia_matmul.mlir", "config_nvidia_matmul_cooperative_ops.mlir", + "config_user.mlir", "convert_to_spirv.mlir", "create_fast_slow_path.mlir", "distribute_to_invocations.mlir", diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt index a27c82532fbf..7227be83500f 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt @@ -25,6 +25,7 @@ iree_lit_test_suite( "config_mali_reduction.mlir" "config_nvidia_matmul.mlir" "config_nvidia_matmul_cooperative_ops.mlir" + "config_user.mlir" "convert_to_spirv.mlir" "create_fast_slow_path.mlir" "distribute_to_invocations.mlir" diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir index 306e64294869..b01a4d79c6f5 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir @@ -43,7 +43,7 @@ hal.executable @conv_112x112x512 { } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @conv_112x112x512 // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -97,7 +97,7 @@ hal.executable @conv_112x112x32 { } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @conv_112x112x32 // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -150,7 +150,7 @@ hal.executable @conv_16x16x16 { } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @conv_16x16x16 // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -205,7 +205,7 @@ hal.executable @dwconv_28x28x144 { } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @dwconv_28x28x144 // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -258,7 +258,7 @@ hal.executable @dwconv_4x4x8 { } } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @dwconv_4x4x8 // CHECK-SAME: translation_info = #[[TRANSLATION]] diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir index 9d30b4563d8b..abb1f17892a6 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir @@ -62,7 +62,7 @@ hal.executable private @conv_pointwise_112x112x32 { } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @conv_pointwise_112x112x32 // CHECK-SAME: translation_info = #[[TRANSLATION]] diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir index 94a9cdc9b75a..fd96fcee2dca 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir @@ -44,7 +44,7 @@ hal.executable @conv_112x112x512 { } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @conv_112x112x512 // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -98,7 +98,7 @@ hal.executable @conv_112x112x32 { } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @conv_112x112x32 // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -151,7 +151,7 @@ hal.executable @conv_16x16x16 { } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @conv_16x16x16 // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -205,7 +205,7 @@ hal.executable @dwconv_28x28x144 { } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @dwconv_28x28x144 // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -260,7 +260,7 @@ hal.executable @dwconv_1x2x8 { } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @dwconv_1x2x8 // CHECK-SAME: translation_info = #[[TRANSLATION]] diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir new file mode 100644 index 000000000000..a03da10df277 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir @@ -0,0 +1,55 @@ +// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass{test-lowering-configuration=true}))' %s | FileCheck %s + +#compilation = #iree_codegen.compilation_info< + lowering_config = , + translation_info = , + workgroup_size = [16, 8, 1]> +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> +hal.executable public @user_config { + hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", { + spv.target_env = #spv.target_env<#spv.vce, Unknown:IntegratedGPU, #spv.resource_limits< + max_compute_shared_memory_size = 16384, + max_compute_workgroup_invocations = 128, + max_compute_workgroup_size = [128, 128, 64], + subgroup_size = 32>> + }> { + hal.executable.export public @matmul_128x1024x256 layout(#pipeline_layout) + builtin.module { + func.func @matmul_128x1024x256() { + %cst = arith.constant 0.000000e+00 : f32 + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 256], strides = [1, 1] + : !flow.dispatch.tensor -> tensor<128x256xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1] + : !flow.dispatch.tensor -> tensor<256x1024xf32> + %15 = linalg.init_tensor [128, 1024] : tensor<128x1024xf32> + %16 = linalg.fill ins(%cst : f32) outs(%15 : tensor<128x1024xf32>) -> tensor<128x1024xf32> + %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup", compilation_info = #compilation} + ins(%3, %4 : tensor<128x256xf32>, tensor<256x1024xf32>) outs(%16 : tensor<128x1024xf32>) -> tensor<128x1024xf32> + flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [128, 1024], strides = [1, 1] : tensor<128x1024xf32> -> !flow.dispatch.tensor + return + } + } + } +} + +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK: hal.executable.export public @matmul_128x1024x256 +// CHECK-SAME: translation_info = #[[TRANSLATION]] +// CHECK-SAME: workgroup_size = [16 : index, 8 : index, 1 : index] +// CHECK: linalg.fill +// CHECK-SAME: lowering_config = #[[CONFIG]] +// CHECK: linalg.matmul +// CHECK-SAME: lowering_config = #[[CONFIG]] diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute.mlir index 263918f8f577..c7e9b4790b44 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute.mlir @@ -9,7 +9,7 @@ #map6 = affine_map<(d0, d1, d2) -> (d0, d1)> #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, @@ -79,7 +79,7 @@ hal.executable private @matmul { // ----- #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, @@ -159,7 +159,7 @@ hal.executable private @conv_1d { #map7 = affine_map<(d0)[s0] -> (32, -d0 + s0)> #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, @@ -274,7 +274,7 @@ hal.executable private @conv_2d { // ----- #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, @@ -344,7 +344,7 @@ hal.executable private @conv_3d { #map7 = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1092 + s0 + d1 * 78 + d2 * 6 + d3)> #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, @@ -411,7 +411,7 @@ module { // ----- #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, @@ -468,12 +468,12 @@ hal.executable @matvec { // CHECK: scf.for %[[IV:.+]] = %[[IDX]] to %[[UB]] step %[[DIMX]] // CHECK: %[[OUTPUT:.+]] = memref.subview %[[CVIEW]][%[[IV]]] [1] [1] // CHECK: linalg.fill -// CHECK-SAME: outs(%[[OUTPUT]] : memref<1xf32, #map3>) +// CHECK-SAME: outs(%[[OUTPUT]] : memref<1xf32, strided<[1], offset: ?>>) // CHECK: %[[IDX:.+]] = gpu.thread_id x // CHECK: %[[DIMX:.+]] = gpu.block_dim x // CHECK: scf.for %[[IV:.+]] = %[[IDX]] to %[[UB]] step %[[DIMX]] // CHECK: %[[INPUT:.+]] = memref.subview %[[AVIEW]][%[[IV]], 0] [1, 1024] [1, 1] // CHECK: %[[OUTPUT:.+]] = memref.subview %[[CVIEW]][%[[IV]]] [1] [1] // CHECK: linalg.matvec -// CHECK-SAME: ins(%[[INPUT]], %[[B]] : memref<1x1024xf32, #map2>, memref<1024xf32> -// CHECK-SAME: outs(%[[OUTPUT]] : memref<1xf32, #map3>) +// CHECK-SAME: ins(%[[INPUT]], %[[B]] : memref<1x1024xf32, strided<[1024, 1], offset: ?>>, memref<1024xf32> +// CHECK-SAME: outs(%[[OUTPUT]] : memref<1xf32, strided<[1], offset: ?>>) diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_scatter.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_scatter.mlir index b359ce5eb7f7..14a7d834530d 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_scatter.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_scatter.mlir @@ -1,7 +1,7 @@ // RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile-and-distribute))))' %s | FileCheck %s #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, @@ -66,10 +66,11 @@ hal.executable private @static_scatter_update_slice { // CHECK: scf.for %[[IV_X:.+]] = %[[TID_X]] to %{{.+}} step %[[DIM_X]] // CHECK: %[[T_UPDATE:.+]] = memref.subview %[[WG_UPDATE]][%[[IV_Y]], %[[IV_X]]] [1, 1] [1, 1] // CHECK: %[[T_UPDATE_CAST:.+]] = memref.cast %[[T_UPDATE]] -// CHECK: %[[T_INDEX:.+]] = memref.cast %[[WG_INDEX]] +// CHECK: %[[T_INDEX:.+]] = memref.subview %[[WG_INDEX]][%[[IV_Y]], 0] [1, 1] [1, 1] +// CHECK: %[[T_INDEX_CAST:.+]] = memref.cast %[[T_INDEX]] // CHECK: %[[T_TARGET:.+]] = memref.subview %[[WG_TARGET]][0, %[[IV_X]]] [100, 1] [1, 1] // CHECK: %[[T_TARGET_CAST:.+]] = memref.cast %[[T_TARGET]] // CHECK: iree_linalg_ext.scatter // CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[T_UPDATE_CAST]], %[[T_INDEX]] +// CHECK-SAME: ins(%[[T_UPDATE_CAST]], %[[T_INDEX_CAST]] // CHECK-SAME: outs(%[[T_TARGET_CAST]] diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir index ebb933e997f1..6b098272b3db 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir @@ -1,7 +1,7 @@ // RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile-and-distribute, cse))))' %s | FileCheck %s #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir index 42ae34ed6979..f3e38e5a0852 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir @@ -18,7 +18,7 @@ hal.executable @matmul_256x1024x128 { max_compute_workgroup_size = [65535, 65535, 65535], subgroup_size = 32>>}> { hal.executable.export public @matmul_256x1024x128 ordinal(0) layout(#pipeline_layout) attributes { - translation_info = #iree_codegen.translation_info, + translation_info = #iree_codegen.translation_info, workgroup_size = [32 : index, 8 : index, 1 : index] } builtin.module { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir index 22dadffad056..e9d7aee0cf3b 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir @@ -1,7 +1,7 @@ // RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,iree-spirv-vectorize))))' --cse %s | FileCheck %s #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir index b50440c98dd1..e60db39eb1d7 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir @@ -1,7 +1,7 @@ -// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-create-fast-slow-path,iree-spirv-tile,iree-spirv-vectorize))))' %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-create-fast-slow-path,iree-spirv-tile,canonicalize,cse,iree-spirv-vectorize))))' %s | FileCheck %s -#config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, @@ -17,13 +17,13 @@ hal.executable private @conv_static_shape_f32 { } builtin.module { func.func @conv_static_shape_f32() { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 %c112 = arith.constant 112 : index %c16 = arith.constant 16 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor %workgroup_id_x = hal.interface.workgroup.id[0] : index %workgroup_count_x = hal.interface.workgroup.count[0] : index %workgroup_id_y = hal.interface.workgroup.id[1] : index @@ -39,24 +39,16 @@ hal.executable private @conv_static_shape_f32 { %7 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x] %8 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x] scf.for %arg2 = %7 to %c16 step %8 { - %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) - %10 = affine.min affine_map<(d0) -> (9, d0 * -2 + 225)>(%arg0)[] + %9 = flow.dispatch.tensor.load %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, 4, 4, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor -> tensor<1x4x4x16xf32> + %10 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) - %12 = affine.min affine_map<(d0) -> (9, d0 * -2 + 225)>(%arg1)[] - %13 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, 0], sizes = [1, %10, %12, 8], strides = [1, 1, 1, 1] : !flow.dispatch.tensor -> tensor<1x?x?x8xf32> - %14 = affine.min affine_map<(d0) -> (16, -d0 + 16)>(%arg2)[] - %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 8, %14], strides = [1, 1, 1, 1] : !flow.dispatch.tensor -> tensor<3x3x8x?xf32> - %16 = affine.min affine_map<(d0) -> (4, -d0 + 112)>(%arg0)[] - %17 = affine.min affine_map<(d0) -> (4, -d0 + 112)>(%arg1)[] - %18 = affine.min affine_map<(d0) -> (-d0 + 112, 4)>(%arg0)[] - %19 = affine.min affine_map<(d0) -> (-d0 + 112, 4)>(%arg1)[] - %20 = affine.min affine_map<(d0) -> (-d0 + 16, 16)>(%arg2)[] - %21 = linalg.init_tensor [1, %18, %19, %20] : tensor<1x?x?x?xf32> - %22 = linalg.fill ins(%cst : f32) outs(%21 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32> - %23 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} - ins(%13, %15 : tensor<1x?x?x8xf32>, tensor<3x3x8x?xf32>) - outs(%22 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32> - flow.dispatch.tensor.store %23, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %16, %17, %14], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor + %12 = flow.dispatch.tensor.load %0, offsets = [0, %10, %11, 0], sizes = [1, 9, 9, 8], strides = [1, 1, 1, 1] : !flow.dispatch.tensor -> tensor<1x9x9x8xf32> + %13 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 8, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor -> tensor<3x3x8x16xf32> + %14 = linalg.fill ins(%cst : f32) outs(%9 : tensor<1x4x4x16xf32>) -> tensor<1x4x4x16xf32> + %15 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, lowering_config = #config, strides = dense<2> : tensor<2xi64>} + ins(%12, %13 : tensor<1x9x9x8xf32>, tensor<3x3x8x16xf32>) + outs(%14 : tensor<1x4x4x16xf32>) -> tensor<1x4x4x16xf32> + flow.dispatch.tensor.store %15, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, 4, 4, 16], strides = [1, 1, 1, 1] : tensor<1x4x4x16xf32> -> !flow.dispatch.tensor } } } @@ -88,8 +80,8 @@ hal.executable private @conv_static_shape_f32 { // ----- -#config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, @@ -105,46 +97,35 @@ hal.executable private @depthwise_conv_static_shape_f32 { } builtin.module { func.func @depthwise_conv_static_shape_f32() { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 %c56 = arith.constant 56 : index %c96 = arith.constant 96 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor %workgroup_id_x = hal.interface.workgroup.id[0] : index %workgroup_count_x = hal.interface.workgroup.count[0] : index %workgroup_id_y = hal.interface.workgroup.id[1] : index %workgroup_count_y = hal.interface.workgroup.count[1] : index %workgroup_id_z = hal.interface.workgroup.id[2] : index %workgroup_count_z = hal.interface.workgroup.count[2] : index - %3 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_z] - %4 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_z] - scf.for %arg0 = %3 to %c56 step %4 { - %5 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_y] - %6 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_y] - scf.for %arg1 = %5 to %c56 step %6 { - %7 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x] - %8 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x] - scf.for %arg2 = %7 to %c96 step %8 { - %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) - %10 = affine.min affine_map<(d0) -> (9, d0 * -2 + 113)>(%arg0)[] - %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) - %12 = affine.min affine_map<(d0) -> (9, d0 * -2 + 113)>(%arg1)[] - %13 = affine.min affine_map<(d0) -> (16, -d0 + 96)>(%arg2)[] - %14 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, %arg2], sizes = [1, %10, %12, %13], strides = [1, 1, 1, 1] : !flow.dispatch.tensor -> tensor<1x?x?x?xf32> - %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, %arg2], sizes = [3, 3, %13], strides = [1, 1, 1] : !flow.dispatch.tensor -> tensor<3x3x?xf32> - %16 = affine.min affine_map<(d0) -> (4, -d0 + 56)>(%arg0)[] - %17 = affine.min affine_map<(d0) -> (4, -d0 + 56)>(%arg1)[] - %18 = affine.min affine_map<(d0) -> (-d0 + 56, 4)>(%arg0)[] - %19 = affine.min affine_map<(d0) -> (-d0 + 56, 4)>(%arg1)[] - %20 = affine.min affine_map<(d0) -> (-d0 + 96, 16)>(%arg2)[] - %21 = linalg.init_tensor [1, %18, %19, %20] : tensor<1x?x?x?xf32> - %22 = linalg.fill ins(%cst : f32) outs(%21 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32> - %23 = linalg.depthwise_conv_2d_nhwc_hwc {lowering_config = #config, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} - ins(%14, %15 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) - outs(%22 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32> - flow.dispatch.tensor.store %23, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %16, %17, %13], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor + scf.for %arg0 = %workgroup_id_z to %c56 step %workgroup_count_z { + %3 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y] + %4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y] + scf.for %arg1 = %3 to %c56 step %4 { + %5 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x] + %6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x] + scf.for %arg2 = %5 to %c96 step %6 { + %7 = flow.dispatch.tensor.load %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, 1, 8, 32], strides = [1, 1, 1, 1] : !flow.dispatch.tensor -> tensor<1x1x8x32xf32> + %8 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) + %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) + %10 = flow.dispatch.tensor.load %0, offsets = [0, %8, %9, %arg2], sizes = [1, 3, 17, 32], strides = [1, 1, 1, 1] : !flow.dispatch.tensor -> tensor<1x3x17x32xf32> + %11 = flow.dispatch.tensor.load %1, offsets = [0, 0, %arg2], sizes = [3, 3, 32], strides = [1, 1, 1] : !flow.dispatch.tensor -> tensor<3x3x32xf32> + %12 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x1x8x32xf32>) -> tensor<1x1x8x32xf32> + %13 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, lowering_config = #config, strides = dense<2> : tensor<2xi64>} + ins(%10, %11 : tensor<1x3x17x32xf32>, tensor<3x3x32xf32>) outs(%12 : tensor<1x1x8x32xf32>) -> tensor<1x1x8x32xf32> + flow.dispatch.tensor.store %13, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, 1, 8, 32], strides = [1, 1, 1, 1] : tensor<1x1x8x32xf32> -> !flow.dispatch.tensor } } } @@ -161,21 +142,22 @@ hal.executable private @depthwise_conv_static_shape_f32 { // check tiling loop along filter height/width and input channel // CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<4xf32>) +// CHECK-SAME: -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) // CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<4xf32>) +// CHECK-SAME: -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) -// CHECK: vector.fma +// CHECK-COUNT-5: vector.transfer_read +// CHECK-COUNT-4: vector.fma // CHECK-COUNT-2: scf.yield // For linalg.depthwise_conv_2d_nhwc_hwc -// CHECK: vector.transfer_write +// CHECK-COUNT-4: vector.transfer_write // ----- -#config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, @@ -271,7 +253,11 @@ hal.executable private @low_padded_conv { // Loop nest for thread tiling and reduction tiling // CHECK-COUNT-4: scf.for // Vector code +// CHECK-COUNT-9: vector.transfer_read // CHECK-COUNT-6: vector.fma +// Fused elementwise ops +// CHECK-COUNT-2: vector.transfer_read +// CHECK-COUNT-2: arith.subf // CHECK: } else { @@ -279,15 +265,19 @@ hal.executable private @low_padded_conv { // Loop nest for thread tiling and reduction tiling // CHECK-COUNT-4: scf.for // CHECK: scf.if -// CHECK-NEXT: vector.transfer_read +// CHECK-COUNT-3: vector.transfer_read // CHECK: scf.if -// CHECK-NEXT: vector.transfer_read +// CHECK-COUNT-3: vector.transfer_read +// CHECK-COUNT-3: vector.transfer_read // CHECK-COUNT-6: vector.fma +// Fused elementwise ops +// CHECK-COUNT-2: vector.transfer_read +// CHECK-COUNT-2: arith.subf // ----- -#config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, @@ -392,10 +382,8 @@ hal.executable private @low_high_padded_depthwise_conv { // Loop nest for thread tiling and reduction tiling // CHECK-COUNT-4: scf.for // Vector code -// CHECK-COUNT-2: vector.transfer_read -// CHECK: vector.fma -// CHECK: vector.transfer_read -// CHECK: vector.fma +// CHECK-COUNT-3: vector.transfer_read +// CHECK-COUNT-2: vector.fma // CHECK: } else { @@ -403,8 +391,8 @@ hal.executable private @low_high_padded_depthwise_conv { // Loop nest for thread tiling and reduction tiling // CHECK-COUNT-4: scf.for // CHECK: scf.if -// CHECK-NEXT: vector.transfer_read +// CHECK: vector.transfer_read // CHECK: scf.if -// CHECK-NEXT: vector.transfer_read +// CHECK: vector.transfer_read // CHECK: vector.transfer_read // CHECK-COUNT-2: vector.fma diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir index 1135fef3d7ad..f86d27059418 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir @@ -1,7 +1,7 @@ // RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,iree-spirv-vectorize))))' %s | FileCheck %s #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, @@ -63,7 +63,7 @@ hal.executable private @matmul_static_shape_f16 { // ----- #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir index 0e99a8c22618..89abb9592c82 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir @@ -1,7 +1,7 @@ // RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile-and-vectorize-to-cooperative-ops))))' %s | FileCheck %s #config = #iree_codegen.lowering_config -#translation = #iree_codegen.translation_info +#translation = #iree_codegen.translation_info #pipeline_layout = #hal.pipeline.layout, diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vector_to_cooperative_matrix.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vector_to_cooperative_matrix.mlir index e7ee378d7ad8..8098200bab65 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vector_to_cooperative_matrix.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vector_to_cooperative_matrix.mlir @@ -53,8 +53,8 @@ hal.executable private @matmul_contract { hal.executable private @matmul_contract_licm { hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb", { spv.target_env = #spv.target_env<#spv.vce, + [Shader, CooperativeMatrixNV, Int8, StorageBuffer8BitAccess], + [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix, SPV_KHR_8bit_storage]>, #spv.resource_limits>}> { builtin.module { // CHECK-LABEL: func.func @matmul_contract_licm @@ -99,8 +99,8 @@ hal.executable private @matmul_contract_licm { hal.executable private @matmul_contract_vector_memref { hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb", { spv.target_env = #spv.target_env<#spv.vce, + [Shader, CooperativeMatrixNV, Int8, StorageBuffer8BitAccess], + [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix, SPV_KHR_8bit_storage]>, #spv.resource_limits>}> { builtin.module { // CHECK-LABEL: func.func @matmul_contract_vector_memref @@ -134,8 +134,8 @@ hal.executable private @matmul_contract_vector_memref { hal.executable private @const_elementwise_ops { hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb", { spv.target_env = #spv.target_env<#spv.vce, + [Shader, CooperativeMatrixNV, Float16], + [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spv.resource_limits>}> { builtin.module { // CHECK-LABEL: func.func @const_elementwise_ops diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir index 35b07a04be09..31613e821132 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir @@ -1,79 +1,38 @@ // RUN: iree-opt --split-input-file --iree-spirv-vectorize %s | FileCheck %s -func.func @matmul_2x128x4() { - %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c128 = arith.constant 128 : index - %c2 = arith.constant 2 : index - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor - %workgroup_id_x = hal.interface.workgroup.id[0] : index - %workgroup_count_x = hal.interface.workgroup.count[0] : index - %workgroup_id_y = hal.interface.workgroup.id[1] : index - %workgroup_count_y = hal.interface.workgroup.count[1] : index - %3 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%workgroup_id_y] - %4 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%workgroup_count_y] - scf.for %arg0 = %3 to %c2 step %4 { - %5 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x] - %6 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_count_x] - scf.for %arg1 = %5 to %c128 step %6 { - %7 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [2, 4], strides = [1, 1] : !flow.dispatch.tensor -> tensor<2x4xf32> - %8 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [4, 128], strides = [1, 1] : !flow.dispatch.tensor -> tensor<4x128xf32> - %9 = linalg.init_tensor [2, 128] : tensor<2x128xf32> - %10 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %9) -> (tensor<2x128xf32>) { - %11 = scf.for %arg4 = %c0 to %c128 step %c4 iter_args(%arg5 = %arg3) -> (tensor<2x128xf32>) { - %12 = tensor.extract_slice %arg5[%arg2, %arg4] [1, 4] [1, 1] : tensor<2x128xf32> to tensor<1x4xf32> - %13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<1x4xf32>) -> tensor<1x4xf32> - %14 = tensor.extract_slice %7[%arg2, 0] [1, 4] [1, 1] : tensor<2x4xf32> to tensor<1x4xf32> - %15 = tensor.extract_slice %8[0, %arg4] [4, 4] [1, 1] : tensor<4x128xf32> to tensor<4x4xf32> - %16 = linalg.matmul ins(%14, %15 : tensor<1x4xf32>, tensor<4x4xf32>) outs(%13 : tensor<1x4xf32>) -> tensor<1x4xf32> - %17 = tensor.insert_slice %16 into %arg5[%arg2, %arg4] [1, 4] [1, 1] : tensor<1x4xf32> into tensor<2x128xf32> - scf.yield %17 : tensor<2x128xf32> - } {iree.spirv.distribute_dim = 0 : index} - scf.yield %11 : tensor<2x128xf32> - } {iree.spirv.distribute_dim = 1 : index} - flow.dispatch.tensor.store %10, %2, offsets = [%arg0, %arg1], sizes = [%c2, %c128], strides = [1, 1] : tensor<2x128xf32> -> !flow.dispatch.tensor - } - } - return +func.func @matmul_1x4x4(%lhs: tensor<1x4xf32>, %rhs: tensor<4x4xf32>, %init: tensor<1x4xf32>) -> tensor<1x4xf32> { + %0 = linalg.matmul ins(%lhs, %rhs : tensor<1x4xf32>, tensor<4x4xf32>) outs(%init : tensor<1x4xf32>) -> tensor<1x4xf32> + return %0: tensor<1x4xf32> } -// CHECK-LABEL: func.func @matmul_2x128x4() +// CHECK-LABEL: func.func @matmul_1x4x4 +// CHECK-SAME: (%[[LHS:.+]]: tensor<1x4xf32>, %[[RHS:.+]]: tensor<4x4xf32>, %[[INIT:.+]]: tensor<1x4xf32>) -// CHECK-DAG: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> // CHECK-DAG: %[[PAD:.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index -// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index - -// CHECK: scf.for %[[IV_Y:.+]] = %[[C0]] to %[[C2]] step %[[C1]] -// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %{{.+}}[%[[IV_Y]], 0] [1, 4] -// CHECK: %[[LHS_VECTOR:.+]] = vector.transfer_read %[[LHS_TILE]][%[[C0]], %[[C0]]], %[[PAD]] -// CHECK: scf.for %[[IV_X:.+]] = %[[C0]] to %[[C128]] step %[[C4]] iter_args(%[[ACC_TILE:.+]] = -// CHECK: %[[RHS_TILE:.+]] = tensor.extract_slice %{{.+}}[0, %[[IV_X]]] [4, 4] -// CHECK: %[[RHS_0_VECTOR:.+]] = vector.transfer_read %[[RHS_TILE]][%[[C0]], %[[C0]]], %[[PAD]] -// CHECK: %[[RHS_1_VECTOR:.+]] = vector.transfer_read %[[RHS_TILE]][%[[C1]], %[[C0]]], %[[PAD]] -// CHECK: %[[RHS_2_VECTOR:.+]] = vector.transfer_read %[[RHS_TILE]][%[[C2]], %[[C0]]], %[[PAD]] -// CHECK: %[[RHS_3_VECTOR:.+]] = vector.transfer_read %[[RHS_TILE]][%[[C3]], %[[C0]]], %[[PAD]] -// CHECK: %[[LHS_0_SCALAR:.+]] = vector.extract %[[LHS_VECTOR]][0] -// CHECK: %[[LHS_0_VECTOR:.+]] = vector.splat %[[LHS_0_SCALAR]] : vector<4xf32> -// CHECK: %[[FMA_0:.+]] = vector.fma %[[LHS_0_VECTOR]], %[[RHS_0_VECTOR]], %[[ZERO]] : vector<4xf32> -// CHECK: %[[LHS_1_SCALAR:.+]] = vector.extract %[[LHS_VECTOR]][1] -// CHECK: %[[LHS_1_VECTOR:.+]] = vector.splat %[[LHS_1_SCALAR]] : vector<4xf32> -// CHECK: %[[FMA_1:.+]] = vector.fma %[[LHS_1_VECTOR]], %[[RHS_1_VECTOR]], %[[FMA_0]] : vector<4xf32> -// CHECK: %[[LHS_2_SCALAR:.+]] = vector.extract %[[LHS_VECTOR]][2] -// CHECK: %[[LHS_2_VECTOR:.+]] = vector.splat %[[LHS_2_SCALAR]] : vector<4xf32> -// CHECK: %[[FMA_2:.+]] = vector.fma %[[LHS_2_VECTOR]], %[[RHS_2_VECTOR]], %[[FMA_1]] : vector<4xf32> -// CHECK: %[[LHS_3_SCALAR:.+]] = vector.extract %[[LHS_VECTOR]][3] -// CHECK: %[[LHS_3_VECTOR:.+]] = vector.splat %[[LHS_3_SCALAR]] : vector<4xf32> -// CHECK: %[[FMA_3:.+]] = vector.fma %[[LHS_3_VECTOR]], %[[RHS_3_VECTOR]], %[[FMA_2]] : vector<4xf32> -// CHECK: vector.transfer_write %[[FMA_3]], %[[ACC_TILE]][%[[IV_Y]], %[[IV_X]]] + +// CHECK: %[[LHS_VECTOR:.+]] = vector.transfer_read %[[LHS]][%[[C0]], %[[C0]]], %[[PAD]] +// CHECK: %[[RHS_0_VECTOR:.+]] = vector.transfer_read %[[RHS]][%[[C0]], %[[C0]]], %[[PAD]] +// CHECK: %[[RHS_1_VECTOR:.+]] = vector.transfer_read %[[RHS]][%[[C1]], %[[C0]]], %[[PAD]] +// CHECK: %[[RHS_2_VECTOR:.+]] = vector.transfer_read %[[RHS]][%[[C2]], %[[C0]]], %[[PAD]] +// CHECK: %[[RHS_3_VECTOR:.+]] = vector.transfer_read %[[RHS]][%[[C3]], %[[C0]]], %[[PAD]] +// CHECK: %[[INIT_VECTOR:.+]] = vector.transfer_read %[[INIT]][%[[C0]], %[[C0]]], %[[PAD]] +// CHECK: %[[LHS_0_SCALAR:.+]] = vector.extract %[[LHS_VECTOR]][0] +// CHECK: %[[LHS_0_VECTOR:.+]] = vector.splat %[[LHS_0_SCALAR]] : vector<4xf32> +// CHECK: %[[FMA_0:.+]] = vector.fma %[[LHS_0_VECTOR]], %[[RHS_0_VECTOR]], %[[INIT_VECTOR]] : vector<4xf32> +// CHECK: %[[LHS_1_SCALAR:.+]] = vector.extract %[[LHS_VECTOR]][1] +// CHECK: %[[LHS_1_VECTOR:.+]] = vector.splat %[[LHS_1_SCALAR]] : vector<4xf32> +// CHECK: %[[FMA_1:.+]] = vector.fma %[[LHS_1_VECTOR]], %[[RHS_1_VECTOR]], %[[FMA_0]] : vector<4xf32> +// CHECK: %[[LHS_2_SCALAR:.+]] = vector.extract %[[LHS_VECTOR]][2] +// CHECK: %[[LHS_2_VECTOR:.+]] = vector.splat %[[LHS_2_SCALAR]] : vector<4xf32> +// CHECK: %[[FMA_2:.+]] = vector.fma %[[LHS_2_VECTOR]], %[[RHS_2_VECTOR]], %[[FMA_1]] : vector<4xf32> +// CHECK: %[[LHS_3_SCALAR:.+]] = vector.extract %[[LHS_VECTOR]][3] +// CHECK: %[[LHS_3_VECTOR:.+]] = vector.splat %[[LHS_3_SCALAR]] : vector<4xf32> +// CHECK: %[[FMA_3:.+]] = vector.fma %[[LHS_3_VECTOR]], %[[RHS_3_VECTOR]], %[[FMA_2]] : vector<4xf32> +// CHECK: vector.transfer_write %[[FMA_3]], %[[INIT]][%[[C0]], %[[C0]]] // ----- @@ -197,11 +156,11 @@ func.func @matmul_2x8x128_fp16(%a: tensor<2x128xf16>, %b: tensor<128x8xf16>, %x: } // CHECK-LABEL: func.func @matmul_2x8x128_fp16 -// CHECK-SAME: (%{{.+}}: tensor<2x128xf16>, %{{.+}}: tensor<128x8xf16>, %[[X:.+]]: tensor<2x8xf16>, %[[Y:.+]]: tensor<2x8xf16>) +// CHECK-SAME: (%[[LHS:.+]]: tensor<2x128xf16>, %[[RHS:.+]]: tensor<128x8xf16>, %[[X:.+]]: tensor<2x8xf16>, %[[Y:.+]]: tensor<2x8xf16>) // CHECK: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : vector<8xf16> // CHECK: %[[FOR:.+]]:2 = scf.for %arg4 = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%arg5 = %[[ZERO]], %arg6 = %[[ZERO]]) -// CHECK-COUNT-2: vector.transfer_read {{.+}} : tensor<2x8xf16>, vector<8xf16> -// CHECK-COUNT-8: vector.transfer_read {{.+}} : tensor<8x8xf16>, vector<8xf16> +// CHECK-COUNT-2: vector.transfer_read %[[LHS]]{{.+}} : tensor<2x128xf16>, vector<8xf16> +// CHECK-COUNT-8: vector.transfer_read %[[RHS]]{{.+}} : tensor<128x8xf16>, vector<8xf16> // CHECK-COUNT-32: vector.fma {{.+}} : vector<4xf16> // CHECK: %[[ISS0:.+]] = vector.insert_strided_slice %{{.+}}, %[[ZERO]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> // CHECK: %[[ISS1:.+]] = vector.insert_strided_slice %{{.+}}, %[[ISS0]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16> diff --git a/compiler/src/iree/compiler/Codegen/Sandbox/BUILD b/compiler/src/iree/compiler/Codegen/Sandbox/BUILD index d2306fd4ea19..65886f9e9230 100644 --- a/compiler/src/iree/compiler/Codegen/Sandbox/BUILD +++ b/compiler/src/iree/compiler/Codegen/Sandbox/BUILD @@ -55,6 +55,8 @@ iree_compiler_cc_library( ":PassesIncGen", "//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect", "//compiler/src/iree/compiler/Codegen/Utils", + "//llvm-external-projects/iree-dialects:IREELinalgExtPasses", + "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithmeticDialect", diff --git a/compiler/src/iree/compiler/Codegen/Sandbox/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Sandbox/CMakeLists.txt index 0fde1f57ed93..720357af6483 100644 --- a/compiler/src/iree/compiler/Codegen/Sandbox/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Sandbox/CMakeLists.txt @@ -45,6 +45,8 @@ iree_cc_library( DEPS ::PassHeaders ::PassesIncGen + IREELinalgExtPasses + IREELinalgExtTransforms LLVMSupport MLIRAffineDialect MLIRArithmeticDialect diff --git a/compiler/src/iree/compiler/Codegen/Sandbox/LinalgTensorCodegenDriver.cpp b/compiler/src/iree/compiler/Codegen/Sandbox/LinalgTensorCodegenDriver.cpp index 845f8aa15b3c..65e5aac6eb3d 100644 --- a/compiler/src/iree/compiler/Codegen/Sandbox/LinalgTensorCodegenDriver.cpp +++ b/compiler/src/iree/compiler/Codegen/Sandbox/LinalgTensorCodegenDriver.cpp @@ -4,6 +4,8 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" +#include "iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h" #include "iree/compiler/Codegen/Dialect/LoweringConfig.h" #include "iree/compiler/Codegen/Sandbox/PassDetail.h" #include "iree/compiler/Codegen/Sandbox/Passes.h" @@ -12,8 +14,6 @@ #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" @@ -26,7 +26,9 @@ #include "mlir/Transforms/Passes.h" using namespace mlir; -using namespace mlir::linalg; +// using namespace mlir::linalg; + +using mlir::iree_compiler::IREE::LinalgExt::CodegenStrategy; #define DEBUG_TYPE "iree-linalg-tensor-codegen-driver" @@ -53,8 +55,9 @@ static FailureOr getRootOp(func::FuncOp funcOp) { /// Default method to initialize the tiling options in IREE. These could be /// overriden by the command line options if specified. For now the sentinel /// -1 is used for avoiding querying the lowering config. -static bool getTilingOptionsFromConfig(func::FuncOp funcOp, int64_t tilingLevel, - LinalgTilingOptions &tilingOptions) { +static bool getTilingOptionsFromConfig( + func::FuncOp funcOp, int64_t tilingLevel, + linalg::LinalgTilingOptions &tilingOptions) { if (tilingLevel != -1) { FailureOr rootOp = getRootOp(funcOp); if (failed(rootOp)) { @@ -115,10 +118,10 @@ static LogicalResult getPaddingDims(func::FuncOp funcOp, /// Default method to initialize the tiling options for fusion in IREE. These /// could be ovveridden by the command line options if specified. -static FailureOr getTileAndFuseOptionsFromConfig( - func::FuncOp funcOp, int64_t tilingLevel) { +static FailureOr +getTileAndFuseOptionsFromConfig(func::FuncOp funcOp, int64_t tilingLevel) { if (tilingLevel == -1) { - return LinalgTilingAndFusionOptions(); + return linalg::LinalgTilingAndFusionOptions(); } FailureOr rootOp = getRootOp(funcOp); @@ -127,7 +130,7 @@ static FailureOr getTileAndFuseOptionsFromConfig( iree_compiler::IREE::Codegen::LoweringConfigAttr loweringConfig = iree_compiler::getLoweringConfig(rootOp.value()); - LinalgTilingAndFusionOptions options; + linalg::LinalgTilingAndFusionOptions options; options.tileSizes.assign(loweringConfig.getTileSizeVals(tilingLevel)); options.tileInterchange.assign( loweringConfig.getTileInterchangeVals(tilingLevel)); @@ -260,12 +263,13 @@ void LinalgFusePass::runOnOperation() { func::FuncOp funcOp = getOperation(); // Set up tiling and vectorization options. - FailureOr defaultTilingOptions = + FailureOr defaultTilingOptions = getTileAndFuseOptionsFromConfig(funcOp, tilingLevel); if (failed(defaultTilingOptions)) { return signalPassFailure(); } - LinalgTilingAndFusionOptions tilingOptions = defaultTilingOptions.value(); + linalg::LinalgTilingAndFusionOptions tilingOptions = + defaultTilingOptions.value(); bool doTiling = !tilingOptions.tileSizes.empty(); if (!tileSizes.empty()) { doTiling = true; @@ -336,7 +340,7 @@ void LinalgFusePass::runOnOperation() { transposePaddingVectors.push_back(transposeVector); } - LinalgPaddingOptions paddingOptions; + linalg::LinalgPaddingOptions paddingOptions; paddingOptions.setPaddingValues(paddingValueAttributes); paddingOptions.setPaddingDimensions( SmallVector{paddingDimensions.begin(), paddingDimensions.end()}); @@ -354,6 +358,8 @@ void LinalgFusePass::runOnOperation() { // Created a nested OpPassManager and run. OpPassManager dynamicPM(func::FuncOp::getOperationName()); strategy.configurePassPipeline(dynamicPM, funcOp.getContext()); + dynamicPM.addPass( + iree_compiler::IREE::LinalgExt::createLinalgStrategyEnablePass()); if (failed(runPipeline(dynamicPM, funcOp))) { return signalPassFailure(); @@ -364,7 +370,7 @@ void LinalgSingleTilingExpertPass::runOnOperation() { func::FuncOp funcOp = getOperation(); // Set up tiling and vectorization options. - LinalgTilingOptions tilingOptions; + linalg::LinalgTilingOptions tilingOptions; bool doTiling = getTilingOptionsFromConfig(funcOp, tilingLevel, tilingOptions); if (!tileSizes.empty()) { @@ -399,7 +405,7 @@ void LinalgSingleTilingExpertPass::runOnOperation() { transposePaddingVectors.push_back(transposeVector); } - LinalgPaddingOptions paddingOptions; + linalg::LinalgPaddingOptions paddingOptions; paddingOptions.setPaddingValues(paddingValueAttributes); paddingOptions.setPackPaddings( SmallVector{packPaddings.begin(), packPaddings.end()}); @@ -409,7 +415,7 @@ void LinalgSingleTilingExpertPass::runOnOperation() { // Gather tiled loops that aren't distribution loops from previous tiling // stages. - LinalgPeelOptions peelingOptions; + linalg::LinalgPeelOptions peelingOptions; peelingOptions.loopsToPeelComputationFunction = [](OpBuilder &builder, Operation *op, SmallVectorImpl &loopsToPeel) { @@ -432,7 +438,7 @@ void LinalgSingleTilingExpertPass::runOnOperation() { }; CodegenStrategy strategy; - StringRef genericOpName = GenericOp::getOperationName(); + StringRef genericOpName = linalg::GenericOp::getOperationName(); strategy.tileIf(doTiling, anchorOpName, tilingOptions) .padIf(pad, anchorOpName, paddingOptions) .decomposeIf(decomposeToLowerDimOp) @@ -443,6 +449,8 @@ void LinalgSingleTilingExpertPass::runOnOperation() { // Created a nested OpPassManager and run. OpPassManager dynamicPM(func::FuncOp::getOperationName()); strategy.configurePassPipeline(dynamicPM, funcOp.getContext()); + dynamicPM.addPass( + iree_compiler::IREE::LinalgExt::createLinalgStrategyEnablePass()); if (failed(runPipeline(dynamicPM, funcOp))) { return signalPassFailure(); } @@ -490,8 +498,8 @@ void LinalgVectorLoweringPass::runOnOperation() { .enableFullUnroll(unrollVectorTransfers) .enableLowerPermutationMaps(); - LinalgVectorLoweringOptions vectorLoweringOptions = - LinalgVectorLoweringOptions() + linalg::LinalgVectorLoweringOptions vectorLoweringOptions = + linalg::LinalgVectorLoweringOptions() // Lowering of vector contractions. .enableContractionLowering(vectorLoweringStage >= 0) // Lowering of vector multi_reduction. @@ -526,6 +534,8 @@ void LinalgVectorLoweringPass::runOnOperation() { OpPassManager dynamicPM(func::FuncOp::getOperationName()); func::FuncOp funcOp = getOperation(); strategy.configurePassPipeline(dynamicPM, funcOp.getContext()); + dynamicPM.addPass( + iree_compiler::IREE::LinalgExt::createLinalgStrategyEnablePass()); if (failed(runPipeline(dynamicPM, funcOp))) { return signalPassFailure(); } diff --git a/compiler/src/iree/compiler/Codegen/Sandbox/Passes.td b/compiler/src/iree/compiler/Codegen/Sandbox/Passes.td index f909854b1a73..40ef5eb86579 100644 --- a/compiler/src/iree/compiler/Codegen/Sandbox/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Sandbox/Passes.td @@ -126,7 +126,7 @@ def LinalgSingleTilingExpert Option<"vectorize", "vectorize", "bool", /*default=*/"false", "Rewrite the linalg op as a vector operation.">, Option<"vectorizePadding", "vectorize-padding", "bool", /*default=*/"false", - "Rewrite all linalg.pad_tensor ops in the function to vector form.">, + "Rewrite all tensor.pad ops in the function to vector form.">, // IREE specific options Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1", diff --git a/compiler/src/iree/compiler/Codegen/Utils/BUILD b/compiler/src/iree/compiler/Codegen/Utils/BUILD index cc66c369d6f7..bb6ea9f19666 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/BUILD +++ b/compiler/src/iree/compiler/Codegen/Utils/BUILD @@ -18,11 +18,13 @@ iree_compiler_cc_library( name = "Utils", srcs = [ "GPUUtils.cpp", + "LinkingUtils.cpp", "MarkerUtils.cpp", "Utils.cpp", ], hdrs = [ "GPUUtils.h", + "LinkingUtils.h", "MarkerUtils.h", "Utils.h", ], diff --git a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt index 69c7679ba25e..0f73750ec8cc 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt @@ -15,10 +15,12 @@ iree_cc_library( Utils HDRS "GPUUtils.h" + "LinkingUtils.h" "MarkerUtils.h" "Utils.h" SRCS "GPUUtils.cpp" + "LinkingUtils.cpp" "MarkerUtils.cpp" "Utils.cpp" DEPS diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp index 99d062b81b26..89eeef7c7865 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp @@ -84,7 +84,13 @@ bool canPerformVectorAccessUsingAllThreads(ArrayRef shape, int64_t numElementPerThread = dim.index() == 0 ? vectorSize : 1; int64_t numThreads = dim.value() / numElementPerThread; if (numThreads == 0) return false; - numThreads = std::min(numThreads, threadsAvailable); + if (numThreads > threadsAvailable) { + // If there are no enough remaining threads to distribute the current + // dimension, try to use all remaining threads. But we still need to make + // sure all work can be distributed to these threads evenly. + if (numThreads % threadsAvailable != 0) return false; + numThreads = threadsAvailable; + } if (threadsAvailable % numThreads != 0) return false; threadsAvailable = threadsAvailable / numThreads; if (threadsAvailable == 1) break; diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp new file mode 100644 index 000000000000..5b4912be24f0 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp @@ -0,0 +1,273 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Utils/LinkingUtils.h" + +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/SymbolTable.h" + +namespace mlir { +namespace iree_compiler { + +SetVector gatherExecutableTargets( + ArrayRef executableOps) { + SetVector result; + for (auto executableOp : executableOps) { + auto variantOps = llvm::to_vector<4>( + executableOp.getOps()); + for (auto variantOp : variantOps) { + result.insert(variantOp.getTarget()); + } + } + return result; +} + +// Renames |op| within |moduleOp| with a new name that is unique within both +// |moduleOp| and |optionalSymbolTable| (if one is provided). +static void renameWithDisambiguatedName( + Operation *op, Operation *moduleOp, + DenseMap &targetSymbolMap, + SymbolTable *optionalSymbolTable) { + StringRef originalName = SymbolTable::getSymbolName(op).getValue(); + + // Iteratively try suffixes until we find one that isn't used. + std::string disambiguatedName; + int uniqueingCounter = 0; + do { + disambiguatedName = + llvm::formatv("{0}_{1}", originalName, uniqueingCounter++).str(); + } while ( + targetSymbolMap.lookup(disambiguatedName) || + (optionalSymbolTable && optionalSymbolTable->lookup(disambiguatedName))); + + SymbolTableCollection symbolTable; + SymbolUserMap symbolUsers(symbolTable, moduleOp); + mlir::StringAttr nameAttr = + mlir::StringAttr::get(op->getContext(), disambiguatedName); + symbolUsers.replaceAllUsesWith(op, nameAttr); + SymbolTable::setSymbolName(op, disambiguatedName); +} + +// TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h version. +// Only difference is one has the symbol map that we don't even need. + +// Destructively merges |sourceModuleOp| into |targetModuleOp|. +// |targetSymbolMap| is updated with the new symbols. +// +// If a private symbol in |sourceModuleOp| conflicts with another symbol +// (public or private) tracked in |targetSymbolMap|, it will be renamed. +// +// Fails if a public symbol in |sourceModuleOp| conflicts with another public +// symbol tracked in |targetSymbolMap|. +static LogicalResult mergeModuleInto( + Operation *sourceModuleOp, Operation *targetModuleOp, + DenseMap &targetSymbolMap) { + auto &sourceBlock = sourceModuleOp->getRegion(0).front(); + auto &targetBlock = targetModuleOp->getRegion(0).front(); + SymbolTable sourceSymbolTable(sourceModuleOp); + auto allOps = llvm::to_vector<8>( + llvm::map_range(sourceBlock, [&](Operation &op) { return &op; })); + + for (auto &op : allOps) { + if (op->hasTrait()) continue; + if (auto symbolOp = dyn_cast(op)) { + auto symbolName = symbolOp.getName(); + + // Resolve symbol name conflicts. + if (auto targetOp = targetSymbolMap[symbolName]) { + if (symbolOp.getVisibility() == SymbolTable::Visibility::Private) { + // Private symbols can be safely folded into duplicates or renamed. + if (OperationEquivalence::isEquivalentTo( + targetOp, op, OperationEquivalence::exactValueMatch, + OperationEquivalence::exactValueMatch, + OperationEquivalence::Flags::IgnoreLocations)) { + // Optimization: skip over duplicate private symbols. + // We could let CSE do this later, but we may as well check here. + continue; + } else { + // Preserve the op but give it a unique name. + renameWithDisambiguatedName(op, sourceModuleOp, targetSymbolMap, + &sourceSymbolTable); + } + } else { + // The source symbol has 'nested' or 'public' visibility. + if (SymbolTable::getSymbolVisibility(targetOp) != + SymbolTable::Visibility::Private) { + // Oops! Both symbols are public and we can't safely rename either. + // If you hit this with ops that you think are safe to rename, mark + // them private. + // + // Note: we could also skip linking between executables with + // conflicting symbol names. We think such conflicts will be better + // fixed in other ways, so we'll emit an error until we find a case + // where that isn't true. + return op->emitError() + << "multiple public symbols with the name: " << symbolName; + } else { + // Keep the original name for our new op, rename the target op. + renameWithDisambiguatedName(targetOp, targetModuleOp, + targetSymbolMap, + /*optionalSymbolTable=*/nullptr); + } + } + } + targetSymbolMap[SymbolTable::getSymbolName(op).getValue()] = op; + } + if (!targetBlock.empty() && + targetBlock.back().hasTrait()) { + op->moveBefore(&targetBlock.back()); + } else { + op->moveBefore(&targetBlock, targetBlock.end()); + } + } + + // Now that we're done cloning its ops, delete the original target op. + sourceModuleOp->erase(); + + return success(); +} + +struct SymbolReplacements { + DenseMap executableRefs; + DenseMap variantRefs; + DenseMap exportRefs; +}; + +// Replaces each usage of an entry point with its original symbol name with a +// new symbol name. +// +// Due to replaceSubElements recursing into symbol refs we need to perform +// replacement in descending symbol ref length; otherwise replacing the +// executable name in `@old_executable::@old_export` would result in +// `@new_executable::@old_export` and an export update would then not match the +// new/old mismatched ref. This means we have to do three walks over the entire +// module in order to do the replacements; not great. +static void replaceEntryPointUses( + mlir::ModuleOp moduleOp, const SymbolReplacements &symbolReplacements) { + auto replaceSymbolRefs = [](Operation *rootOp, + const DenseMap &map) { + auto allUses = SymbolTable::getSymbolUses(rootOp); + if (!allUses) return; + for (auto use : *allUses) { + auto oldAttr = use.getSymbolRef(); + auto newAttr = map.lookup(oldAttr); + if (!newAttr) continue; + auto newDict = use.getUser()->getAttrDictionary().replaceSubElements( + [&](Attribute attr) -> std::pair { + if (attr == oldAttr) { + // Found old->new replacement. + return {newAttr, WalkResult::skip()}; + } else if (attr.isa()) { + // Don't recurse into symbol refs - we only want to match roots. + return {attr, WalkResult::skip()}; + } + // Non-symbol ref attr. + return {attr, WalkResult::advance()}; + }); + use.getUser()->setAttrs(newDict.cast()); + } + }; + replaceSymbolRefs(moduleOp, symbolReplacements.exportRefs); + replaceSymbolRefs(moduleOp, symbolReplacements.variantRefs); + replaceSymbolRefs(moduleOp, symbolReplacements.executableRefs); + for (auto funcLikeOp : moduleOp.getOps()) { + replaceSymbolRefs(funcLikeOp, symbolReplacements.exportRefs); + replaceSymbolRefs(funcLikeOp, symbolReplacements.variantRefs); + replaceSymbolRefs(funcLikeOp, symbolReplacements.executableRefs); + } +} + +LogicalResult linkExecutablesInto( + mlir::ModuleOp moduleOp, + ArrayRef sourceExecutableOps, + IREE::HAL::ExecutableOp linkedExecutableOp, + IREE::HAL::ExecutableVariantOp linkedTargetOp, + std::function getInnerModuleFn, + OpBuilder &builder) { + int nextEntryPointOrdinal = 0; + DenseMap targetSymbolMap; + SymbolReplacements symbolReplacements; + + auto linkedTargetBuilder = + OpBuilder::atBlockBegin(&linkedTargetOp.getBlock()); + auto linkedModuleOp = getInnerModuleFn(linkedTargetOp.getInnerModule()); + + // Iterate over all source executable ops, linking as many as we can. + for (auto sourceExecutableOp : sourceExecutableOps) { + // Remap root executable refs. + symbolReplacements.executableRefs[SymbolRefAttr::get(sourceExecutableOp)] = + SymbolRefAttr::get(linkedExecutableOp); + + auto variantOps = llvm::to_vector<4>( + sourceExecutableOp.getOps()); + for (auto variantOp : variantOps) { + // Only process compatible targets. + // TODO(benvanik): allow for grouping when multi-versioning is supported? + // We could, for example, link all aarch64 variants together and then + // use function multi-versioning to let LLVM insert runtime switches. + if (variantOp.getTarget() != linkedTargetOp.getTarget()) continue; + + // Remap variant refs. + auto oldVariantRefAttr = + SymbolRefAttr::get(builder.getContext(), sourceExecutableOp.getName(), + {SymbolRefAttr::get(variantOp)}); + auto newVariantRefAttr = + SymbolRefAttr::get(builder.getContext(), linkedExecutableOp.getName(), + {SymbolRefAttr::get(linkedTargetOp)}); + symbolReplacements.variantRefs[oldVariantRefAttr] = newVariantRefAttr; + + // Clone export ops and queue remapping ordinals and updating + // symbol refs. + for (auto exportOp : variantOp.getOps()) { + auto newExportOp = + linkedTargetBuilder.create( + exportOp.getLoc(), exportOp.getSymNameAttr(), + builder.getIndexAttr(nextEntryPointOrdinal++), + exportOp.getLayout(), ArrayAttr{}, IntegerAttr{}); + newExportOp->setDialectAttrs(exportOp->getDialectAttrs()); + + // Add to replacement table for fixing up dispatch calls referencing + // this export. + auto oldExportRefAttr = SymbolRefAttr::get( + builder.getContext(), sourceExecutableOp.getName(), + {SymbolRefAttr::get(variantOp), SymbolRefAttr::get(exportOp)}); + auto newExportRefAttr = SymbolRefAttr::get( + builder.getContext(), linkedExecutableOp.getName(), + {SymbolRefAttr::get(linkedTargetOp), + SymbolRefAttr::get(newExportOp)}); + symbolReplacements.exportRefs[oldExportRefAttr] = newExportRefAttr; + } + + // Merge the existing module into the new linked module op. + auto sourceModuleOp = getInnerModuleFn(variantOp.getInnerModule()); + if (failed(mergeModuleInto(sourceModuleOp, linkedModuleOp, + targetSymbolMap))) { + return failure(); + } + + variantOp.erase(); + } + + if (sourceExecutableOp.getOps().empty()) { + sourceExecutableOp.erase(); + } + } + + // Update references to @executable::@target::@entry symbols. + replaceEntryPointUses(moduleOp, symbolReplacements); + + // Remove if we didn't add anything. + if (linkedTargetOp.getOps().empty()) { + linkedTargetOp.erase(); + linkedExecutableOp.erase(); + } + + return success(); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h new file mode 100644 index 000000000000..e082544f41fb --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h @@ -0,0 +1,34 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_UTILS_LINKINGUTILS_H_ +#define IREE_COMPILER_CODEGEN_UTILS_LINKINGUTILS_H_ + +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" + +namespace mlir { +namespace iree_compiler { + +// Returns a uniqued set of all targets in |executableOps|. +SetVector gatherExecutableTargets( + ArrayRef executableOps); + +// Links all executables for the current target found in |moduleOp| into +// |linkedExecutableOp|. Functions will be cloned into |linkedModuleOp|. +LogicalResult linkExecutablesInto( + mlir::ModuleOp moduleOp, + ArrayRef sourceExecutableOps, + IREE::HAL::ExecutableOp linkedExecutableOp, + IREE::HAL::ExecutableVariantOp linkedTargetOp, + std::function getInnerModuleFn, + OpBuilder &builder); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_CODEGEN_UTILS_LINKINGUTILS_H_ diff --git a/compiler/src/iree/compiler/Codegen/VMVX/BUILD b/compiler/src/iree/compiler/Codegen/VMVX/BUILD index 7c90d0ce565d..5b5e8b181dd3 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/BUILD +++ b/compiler/src/iree/compiler/Codegen/VMVX/BUILD @@ -16,12 +16,17 @@ iree_compiler_cc_library( name = "VMVX", srcs = [ "LowerLinalgMicrokernels.cpp", + "Passes.cpp", + "VMVXLinkExecutables.cpp", ], deps = [ "//compiler/src/iree/compiler/Codegen:PassHeaders", + "//compiler/src/iree/compiler/Codegen/Utils", "//compiler/src/iree/compiler/Dialect/Util/IR", + "//compiler/src/iree/compiler/Dialect/VM/IR", "//compiler/src/iree/compiler/Dialect/VMVX/IR", "//compiler/src/iree/compiler/Dialect/VMVX/IR:VMVXDialect", + "//compiler/src/iree/compiler/Utils", "//runtime/src/iree/builtins/ukernel", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithmeticDialect", diff --git a/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt index 16216c5f33c3..c661c463ef84 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt @@ -15,6 +15,8 @@ iree_cc_library( VMVX SRCS "LowerLinalgMicrokernels.cpp" + "Passes.cpp" + "VMVXLinkExecutables.cpp" DEPS LLVMSupport MLIRArithmeticDialect @@ -27,9 +29,12 @@ iree_cc_library( MLIRTransforms iree::builtins::ukernel iree::compiler::Codegen::PassHeaders + iree::compiler::Codegen::Utils iree::compiler::Dialect::Util::IR + iree::compiler::Dialect::VM::IR iree::compiler::Dialect::VMVX::IR iree::compiler::Dialect::VMVX::IR::VMVXDialect + iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp new file mode 100644 index 000000000000..8e5c95ccf8d1 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp @@ -0,0 +1,23 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Passes.h" + +#include "iree/compiler/Codegen/PassDetail.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +namespace iree_compiler { + +// NOTE: this runs on the top-level program module containing all +// hal.executable ops. +void buildVMVXLinkingPassPipeline(OpPassManager &passManager) { + passManager.addPass(createVMVXLinkExecutablesPass()); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp new file mode 100644 index 000000000000..9e1cb676a874 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp @@ -0,0 +1,80 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/PassDetail.h" +#include "iree/compiler/Codegen/Passes.h" +#include "iree/compiler/Codegen/Utils/LinkingUtils.h" +#include "iree/compiler/Dialect/VM/IR/VMOps.h" +#include "iree/compiler/Utils/ModuleUtils.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +struct VMVXLinkExecutablesPass + : public VMVXLinkExecutablesBase { + VMVXLinkExecutablesPass() = default; + void runOnOperation() override { + auto moduleOp = getOperation(); + auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); + + auto sourceExecutableOps = + llvm::to_vector<8>(moduleOp.getOps()); + if (sourceExecutableOps.size() <= 1) return; + + // Guess a module name, if needed, to make the output files readable. + auto moduleName = guessModuleName(moduleOp, "vmvx_module"); + + // Create our new "linked" hal.executable. + std::string linkedExecutableName = + llvm::formatv("{0}_linked_{1}", moduleName, "vmvx"); + auto linkedExecutableOp = moduleBuilder.create( + moduleOp.getLoc(), linkedExecutableName); + linkedExecutableOp.setVisibility( + sourceExecutableOps.front().getVisibility()); + auto executableBuilder = + OpBuilder::atBlockBegin(&linkedExecutableOp.getBlock()); + + // Gather all unique executable targets - we may have multiple. + auto executableTargetAttrs = gatherExecutableTargets(sourceExecutableOps); + for (auto executableTargetAttr : executableTargetAttrs) { + // Add our VMVX hal.executable.variant with an empty module. + auto linkedTargetOp = + executableBuilder.create( + moduleOp.getLoc(), executableTargetAttr.getSymbolNameFragment(), + executableTargetAttr); + auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock()); + auto linkedModuleOp = targetBuilder.create(moduleOp.getLoc()); + + // Add an empty vm.module to that module as our vm.funcs must live in it. + auto nestedBuilder = OpBuilder::atBlockBegin(linkedModuleOp.getBody()); + nestedBuilder.create(moduleOp.getLoc(), + "linked_module"); + + // Try linking together all executable variants for this target. + if (failed(linkExecutablesInto( + moduleOp, sourceExecutableOps, linkedExecutableOp, linkedTargetOp, + [](mlir::ModuleOp moduleOp) { + return *moduleOp.getOps().begin(); + }, + nestedBuilder))) { + return signalPassFailure(); + } + } + } +}; + +} // namespace + +std::unique_ptr> createVMVXLinkExecutablesPass() { + return std::make_unique(); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/BUILD b/compiler/src/iree/compiler/Codegen/VMVX/test/BUILD index ecefbbb80fb5..cde2013d60c8 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/test/BUILD +++ b/compiler/src/iree/compiler/Codegen/VMVX/test/BUILD @@ -18,6 +18,7 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "link_executables.mlir", "lower_linalg_microkernels.mlir", ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/VMVX/test/CMakeLists.txt index a28e745b6bb4..76fa37db9f87 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/VMVX/test/CMakeLists.txt @@ -14,6 +14,7 @@ iree_lit_test_suite( NAME lit SRCS + "link_executables.mlir" "lower_linalg_microkernels.mlir" TOOLS FileCheck diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/linking.mlir b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir similarity index 87% rename from compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/linking.mlir rename to compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir index d8a1cc890c71..117e571cbd2e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/linking.mlir +++ b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --iree-hal-link-executables %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-vmvx-link-executables %s | FileCheck %s #vmvx_target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb"> #pipeline_layout = #hal.pipeline.layout target(@vmvx_linked::@vmvx_bytecode_fb::@dispatch_0) workgroups([%c1, %c1, %c1]) -// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@vmvx_linked::@vmvx_bytecode_fb::@dispatch_1) workgroups([%c1, %c1, %c1]) -// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@vmvx_linked::@vmvx_bytecode_fb::@dispatch_2) workgroups([%c1, %c1, %c1]) +// CHECK: testing.func.a = @link_executables_linked_vmvx +// CHECK-SAME: testing.func.b = @link_executables_linked_vmvx::@vmvx_bytecode_fb +// CHECK-SAME: testing.func.c = @link_executables_linked_vmvx::@vmvx_bytecode_fb::@dispatch_0 +// CHECK: testing.op.a = @link_executables_linked_vmvx +// CHECK-SAME: testing.op.b = @link_executables_linked_vmvx::@vmvx_bytecode_fb +// CHECK-SAME: testing.op.c = @link_executables_linked_vmvx::@vmvx_bytecode_fb::@dispatch_0 +// CHECK: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@link_executables_linked_vmvx::@vmvx_bytecode_fb::@dispatch_0) workgroups([%c1, %c1, %c1]) +// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@link_executables_linked_vmvx::@vmvx_bytecode_fb::@dispatch_1) workgroups([%c1, %c1, %c1]) +// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@link_executables_linked_vmvx::@vmvx_bytecode_fb::@dispatch_2) workgroups([%c1, %c1, %c1]) // // CHECK: util.initializer -// CHECK: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@vmvx_linked::@vmvx_bytecode_fb::@dispatch_0) workgroups([%c1, %c1, %c1]) -// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@vmvx_linked::@vmvx_bytecode_fb::@dispatch_1) workgroups([%c1, %c1, %c1]) -// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@vmvx_linked::@vmvx_bytecode_fb::@dispatch_2) workgroups([%c1, %c1, %c1]) +// CHECK: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@link_executables_linked_vmvx::@vmvx_bytecode_fb::@dispatch_0) workgroups([%c1, %c1, %c1]) +// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@link_executables_linked_vmvx::@vmvx_bytecode_fb::@dispatch_1) workgroups([%c1, %c1, %c1]) +// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@link_executables_linked_vmvx::@vmvx_bytecode_fb::@dispatch_2) workgroups([%c1, %c1, %c1]) // ----- @@ -207,7 +207,7 @@ hal.executable private @dispatch_1 { // // CHECK-NOT: hal.executable private @dispatch_0 // CHECK-NOT: hal.executable private @dispatch_1 -// CHECK: hal.executable private @vmvx_linked { +// CHECK: hal.executable private @link_executables_linked_vmvx { // CHECK: hal.executable.variant public @vmvx_bytecode_fb, target = #executable_target_vmvx_bytecode_fb { // CHECK: module { // CHECK-NEXT: vm.module public @linked_module { diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD index b4477f253b19..d0ba0f64548e 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD @@ -38,9 +38,11 @@ iree_compiler_cc_library( "ConvertConv2DToImg2Col.cpp", "ConvertLinalgMatmulToMmt4D.cpp", "ConvertRegionToWorkgroups.cpp", + "ConvertToFlow.cpp", "DeduplicateExecutables.cpp", "DetachElementwiseFromNamedOps.cpp", "DispatchLinalgOnTensors.cpp", + "DispatchLinalgOnTensorsViaRegionOps.cpp", "DispatchWithTransformDialect.cpp", "DumpDispatchGraph.cpp", "ExpandTensorShapes.cpp", @@ -55,13 +57,13 @@ iree_compiler_cc_library( "OptimizeNumerics.cpp", "OutlineDispatchRegions.cpp", "PadLinalgOps.cpp", - "PadTensorToTensorInsertSlice.cpp", "PassDetail.h", "Passes.cpp", "RegionOpUtils.cpp", "SplitReduction.cpp", "StripAndSplatConstantVariables.cpp", "StripSignedness.cpp", + "TensorPadToTensorInsertSlice.cpp", "VerifyInputLegality.cpp", ], hdrs = [ diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt index 3af5e8fb4fe5..203213850ed1 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt @@ -37,9 +37,11 @@ iree_cc_library( "ConvertConv2DToImg2Col.cpp" "ConvertLinalgMatmulToMmt4D.cpp" "ConvertRegionToWorkgroups.cpp" + "ConvertToFlow.cpp" "DeduplicateExecutables.cpp" "DetachElementwiseFromNamedOps.cpp" "DispatchLinalgOnTensors.cpp" + "DispatchLinalgOnTensorsViaRegionOps.cpp" "DispatchWithTransformDialect.cpp" "DumpDispatchGraph.cpp" "ExpandTensorShapes.cpp" @@ -54,13 +56,13 @@ iree_cc_library( "OptimizeNumerics.cpp" "OutlineDispatchRegions.cpp" "PadLinalgOps.cpp" - "PadTensorToTensorInsertSlice.cpp" "PassDetail.h" "Passes.cpp" "RegionOpUtils.cpp" "SplitReduction.cpp" "StripAndSplatConstantVariables.cpp" "StripSignedness.cpp" + "TensorPadToTensorInsertSlice.cpp" "VerifyInputLegality.cpp" DEPS ::PassesIncGen diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp new file mode 100644 index 000000000000..e0a1dd9eed2a --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp @@ -0,0 +1,47 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h" +#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h" +#include "iree/compiler/Dialect/Flow/Transforms/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::iree_compiler; +using namespace mlir::iree_compiler::IREE; + +namespace { +// Pass to test conversion to flow patterns. +struct ConvertToFlowPass : public Flow::ConvertToFlowBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet convertToFlowPatterns(context); + Flow::populateTensorToFlowConversionPatterns(context, + convertToFlowPatterns); + memref::populateResolveRankedShapeTypeResultDimsPatterns( + convertToFlowPatterns); + if (failed(applyPatternsAndFoldGreedily( + getOperation(), std::move(convertToFlowPatterns)))) { + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr Flow::createConvertToFlowPass() { + return std::make_unique(); +} diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp index e3c73b6f8fc4..f8db58ad5680 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp @@ -49,8 +49,12 @@ struct DetachElementwisePattern // we see multiple output ops. if (outputOperands.size() != 1) return failure(); Value outputOperand = outputOperands.front()->get(); - if (outputOperand.getDefiningOp()) return failure(); + auto outsDefiningOp = outputOperand.getDefiningOp(); + if (!outsDefiningOp || isa(outsDefiningOp.getOperation())) { + // If not linalg op, or is a fill op, do nothing. + return failure(); + } auto outputType = outputOperand.getType().cast(); if (!outputType.getElementType().isIntOrFloat()) return failure(); auto elementType = outputType.getElementType(); @@ -88,7 +92,7 @@ struct DetachElementwisePattern for (int i = 0, e = outputMap.getNumResults(); i < e; ++i) { int pos = outputMap.getResult(i).cast().getPosition(); auto attr = linalgOp.getIteratorTypes()[pos].cast(); - if (!isParallelIterator(attr)) return failure(); + if (!linalg::isParallelIterator(attr)) return failure(); iterators.push_back(attr.getValue()); } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp index 6c8a07c3530a..2ab384c2ff03 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp @@ -17,6 +17,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h" #include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" +#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CommandLine.h" @@ -30,6 +31,7 @@ #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Block.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" @@ -59,12 +61,6 @@ static llvm::cl::opt clInlineConstantByteLength( "dispatch region"), llvm::cl::init(256)); -static llvm::cl::opt clEnableMultiResultDispatches( - "iree-flow-enable-multi-result-dispatches", - llvm::cl::desc( - "Enable dispatch region formation to enable multi-result dispatches"), - llvm::cl::init(false)); - static const char kRootOpAttr[] = "__root_op__"; static const char kFusionGroupsAttr[] = "__fused_op__"; @@ -195,56 +191,9 @@ bool isClonableIntoDispatchOp(Operation *op) { // Methods for getting the workload information for dispatch region creation. //===----------------------------------------------------------------------===// -/// For a given operation returns the loop ranges needed to compute the op. -template -static SmallVector getLoopRanges(T operation, Location loc, - OpBuilder &builder); - -template <> -SmallVector getLoopRanges(TilingInterface tilableOp, - Location loc, - OpBuilder &builder) { - SmallVector loopRanges = tilableOp.getIterationDomain(builder); - Value one = builder.create(loc, 1); - for (auto iteratorType : llvm::enumerate(tilableOp.getLoopIteratorTypes())) { - if (iteratorType.value() == getReductionIteratorTypeName()) { - loopRanges[iteratorType.index()].size = one; - } - } - return loopRanges; -} - -template <> -SmallVector getLoopRanges( - tensor::InsertSliceOp insertSliceOp, Location loc, OpBuilder &builder) { - OpFoldResult zero = builder.getIndexAttr(0); - OpFoldResult one = builder.getIndexAttr(1); - Value source = insertSliceOp.getSource(); - SmallVector loopRanges(insertSliceOp.getSourceType().getRank(), - Range{zero, one, one}); - for (auto dim : llvm::seq(0, loopRanges.size())) { - loopRanges[dim].size = - builder.create(loc, source, dim).getResult(); - } - return loopRanges; -} - -template <> -SmallVector getLoopRanges( - tensor::ExtractSliceOp sliceOp, Location loc, OpBuilder &builder) { - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - ReifiedRankedShapedTypeDims resultDims; - (void)sliceOp.reifyResultShapes(builder, resultDims); - return llvm::to_vector(llvm::map_range(resultDims[0], [&](Value v) { - return Range{zero, v, one}; - })); -} - /// Compute the workload to use for the workgroup based on the root op. -template static SmallVector getWorkloadForRootOp(OpBuilder &builder, - OpTy rootOp) { + Operation *rootOp) { // Compute workgroup count to use for the dispatch op. These are the ranges // of the outermost parallel loops that can be distributed. Location loc = rootOp->getLoc(); @@ -256,7 +205,7 @@ static SmallVector getWorkloadForRootOp(OpBuilder &builder, Value offset = getValueOrCreateConstantIndexOp(builder, loc, r.offset); Value size = getValueOrCreateConstantIndexOp(builder, loc, r.size); Value stride = getValueOrCreateConstantIndexOp(builder, loc, r.stride); - return builder.create(rootOp.getLoc(), workload, + return builder.create(rootOp->getLoc(), workload, ValueRange{offset, size, stride}); })); } @@ -427,110 +376,28 @@ static SmallVector getOperationsToMoveIntoDispatch( int64_t groupNum = getRootNumber(rootOp); std::deque worklist; worklist.push_back(rootOp); - llvm::SmallDenseSet movedOps; - movedOps.insert(rootOp); + llvm::SmallDenseSet visitedOps; + visitedOps.insert(rootOp); while (!worklist.empty()) { Operation *currRoot = worklist.front(); worklist.pop_front(); for (auto operand : currRoot->getOperands()) { auto producer = operand.getDefiningOp(); - if (movedOps.count(producer)) continue; - if (!producer || !isInFusionGroup(producer, groupNum)) continue; - movedOps.insert(producer); + if (!producer || visitedOps.count(producer)) continue; + visitedOps.insert(producer); + if (!isInFusionGroup(producer, groupNum)) continue; worklist.push_back(producer); dispatchOps.push_back(producer); } } - return dispatchOps; + return llvm::to_vector(llvm::reverse(orderOperations(dispatchOps))); } //===---------------------------------------------------------------------===// // Methods to legalize a dispatch region op, i.e. make it isolated from above. //===---------------------------------------------------------------------===// -/// Reorders the operations in `ops` such that they could be inlined into the -/// dispatch region in that order to satisfy dependencies. -static SmallVector orderOperations(ArrayRef ops) { - LLVM_DEBUG({ - llvm::dbgs() << "Ops to be inlined :\n"; - for (auto op : ops) { - llvm::dbgs() << "\t"; - op->print(llvm::dbgs()); - llvm::dbgs() << "\n"; - } - }); - - llvm::SmallMapVector, 16> - insertAfterMap; - llvm::SetVector opSet(ops.begin(), ops.end()); - llvm::SetVector leafOps(ops.begin(), ops.end()); - // For each operation compute the list of operations in `ops` that use its - // results. Also compute the operations that form the leafs of the DAG of - // operations in `ops`. - for (auto op : ops) { - for (auto operand : op->getOperands()) { - auto definingOp = operand.getDefiningOp(); - if (!definingOp || !opSet.count(definingOp)) continue; - insertAfterMap[definingOp].push_back(op); - if (leafOps.count(op)) leafOps.remove(op); - } - } - - // The leaves are at the head of the ordered list. - SmallVector orderedOps(leafOps.begin(), leafOps.end()); - orderedOps.reserve(ops.size()); - llvm::SmallPtrSet processed; - processed.insert(leafOps.begin(), leafOps.end()); - - // `readyOps` contains the list of operations that have been just added to the - // `orderedOps` list. With these marked ready, they might make further - // operations in `ops` ready as well. - // The complexity of the algorithm is driven by these - // - Each operations is added to `readyOps` list at most once, and is removed - // after being processed - // - For every operation in `readyOps` every use of its results (within `ops`) - // is looked at once. - // - For every use, the operands of the user are processed. - // Assuming operands is O(1), i.e. constant order, the complexity is O(sum of - // number of uses of each operation). Given that the size of `ops` is at max - // O(10), and not O(100), this is assumed to be reasonable. - ArrayRef readyOps(orderedOps); - size_t startPos = 0; - while (!readyOps.empty()) { - auto op = readyOps.front(); - startPos++; - // Check all uses of `op` within `ops`. If all of the operations that define - // the operands of the user have been added to `orderedOps`, then the user - // is ready to be scheduled. - for (auto insertAfterOp : insertAfterMap[op]) { - if (processed.count(insertAfterOp)) continue; - if (llvm::all_of(insertAfterOp->getOperands(), [&](Value operand) { - Operation *operandDefiningOp = operand.getDefiningOp(); - return !operandDefiningOp || !opSet.count(operandDefiningOp) || - processed.count(operandDefiningOp); - })) { - // readyOps.push_back(insertAfterOp); - orderedOps.push_back(insertAfterOp); - processed.insert(insertAfterOp); - } - } - readyOps = ArrayRef(orderedOps).drop_front(startPos); - } - - LLVM_DEBUG({ - llvm::dbgs() << "Ops to be inlined (sorted) : \n"; - for (auto op : orderedOps) { - llvm::dbgs() << "\t"; - op->print(llvm::dbgs()); - llvm::dbgs() << "\n"; - } - }); - assert(orderedOps.size() == ops.size() && - "ordering of inlined operations failed"); - return orderedOps; -} - /// Checks if the `Value` has a use within the dispatch that is unfusable. static bool hasUnfusableUseInDispatch( Value v, IREE::Flow::DispatchWorkgroupsOp dispatchOp) { @@ -819,7 +686,7 @@ struct CreateDispatchRegionOp : Base { // Get the workload to use for the dispatch. FailureOr> workload = - getWorkloadForRootOp(rewriter, rootOp); + getWorkloadForRootOp(rewriter, rootOp.getOperation()); if (failed(workload)) { return failure(); } @@ -849,27 +716,91 @@ struct CreateDispatchRegionOp : Base { // Heuristics for fusing dispatchble ops with root ops using tile + fuse. //===----------------------------------------------------------------------===// -/// Checks if the producer and consumer LinalgOps can be fused. -static bool areFusableLinalgOps(OpOperand &use) { - return areLinalgOpsFusableUsingTileAndFuse(use); +/// Returns a bit vector of size number of loops of the `interfaceOp` with +/// the bits corresponding to outer parallel loops set to `true`. +static llvm::SmallBitVector getOuterParallelLoops(TilingInterface interfaceOp) { + SmallVector loopIteratorTypes = interfaceOp.getLoopIteratorTypes(); + llvm::SmallBitVector parallelLoops(loopIteratorTypes.size()); + for (auto iteratorType : llvm::enumerate(loopIteratorTypes)) { + if (iteratorType.value() != getParallelIteratorTypeName()) break; + parallelLoops.set(iteratorType.index()); + } + return parallelLoops; +} + +/// Returns true if `map` is an identity map with zeros, i.e. if you +/// drop the result exprs that are constant zeros, the `map` will become an +/// identity. +static bool isIdentityMapWithZeros(AffineMap map) { + if (map.getNumSymbols() != 0) return false; + unsigned dimsSeen = 0; + for (auto result : map.getResults()) { + bool isValidExpr = TypeSwitch(result) + .Case([&dimsSeen](auto dimExpr) { + if (dimExpr.getPosition() != dimsSeen) + return false; + dimsSeen++; + return true; + }) + .Case([](auto constExpr) { + return constExpr.getValue() == 0; + }) + .Default([](AffineExpr) { return false; }); + if (!isValidExpr) return false; + } + return dimsSeen == map.getNumDims(); } -/// Returns true if this is a fusable use. -static bool isFusableWithConsumer(OpOperand &use) { - // Check for linalg producer -> consumer fusion with tile + fuse. - return areFusableLinalgOps(use); +/// Method to check if two `linalg.generic` op with producer-consumer +/// relationship through `operand` have compatible outer-parallel loops. +static bool hasCompatibleOuterParallelLoops( + OpOperand &operand, bool allowConsumerParallelismPessimization) { + auto producer = operand.get().getDefiningOp(); + auto consumer = dyn_cast(operand.getOwner()); + if (!producer || !consumer) return false; + + llvm::SmallBitVector producerParallelLoops = + getOuterParallelLoops(cast(producer.getOperation())); + llvm::SmallBitVector consumerParallelLoops = + getOuterParallelLoops(cast(consumer.getOperation())); + + if (allowConsumerParallelismPessimization) { + if (producerParallelLoops.count() > consumerParallelLoops.count()) + return false; + } else if (producerParallelLoops.count() != consumerParallelLoops.count()) { + return false; + } + + auto producerIndexingMap = + producer.getTiedIndexingMapForResult(operand.get().cast()); + auto consumerIndexingMap = consumer.getTiedIndexingMap(&operand); + if (!producerIndexingMap.isProjectedPermutation() || + !consumerIndexingMap.isProjectedPermutation()) { + return false; + } + + /// Project out the non-parallel dimensions. + llvm::SmallBitVector producerProjectedDims(producerParallelLoops); + producerProjectedDims.flip(); + auto projectedProducerMap = + getProjectedMap(producerIndexingMap, producerProjectedDims); + + llvm::SmallBitVector consumerProjectedDims(producerParallelLoops); + consumerProjectedDims.flip(); + consumerProjectedDims.resize(consumer.getNumLoops(), true); + auto projectedConsumerMap = + getProjectedMap(consumerIndexingMap, consumerProjectedDims); + + return isIdentityMapWithZeros(projectedProducerMap) && + isIdentityMapWithZeros(projectedConsumerMap); } /// For all uses of an operation, finds the use that dominates all other uses. static Optional getFusableUse(Operation *op, - DominanceInfo const &dominanceInfo) { - if (!clEnableMultiResultDispatches) { - if (op->hasOneUse()) { - OpOperand &use = *(op->use_begin()); - return &use; - } - return llvm::None; - } + DominanceInfo const &dominanceInfo, + bool fuseMultiUse) { + if (!fuseMultiUse && !op->hasOneUse()) return llvm::None; + for (auto &use : op->getUses()) { Operation *user = use.getOwner(); if (llvm::all_of(op->getUsers(), [&](Operation *c) { @@ -881,11 +812,65 @@ static Optional getFusableUse(Operation *op, return llvm::None; } +/// Returns true if the operands are fusable under the aggressive fusion +/// heuristics. +static bool areOpsAggresiveFusable(Operation *producer, Operation *consumer, + bool allowConsumerParallelismPessimization) { + // Collect all the uses from producer to consumer. + SmallVector allUses; + for (OpOperand &producerUse : producer->getUses()) { + if (producerUse.getOwner() != consumer) continue; + allUses.push_back(&producerUse); + } + + // Check that the consumer and producer have compatible outer parallel loops. + if (!llvm::all_of(allUses, [&](OpOperand *operand) { + return hasCompatibleOuterParallelLoops( + *operand, allowConsumerParallelismPessimization); + })) { + return false; + } + + // Finally only fuse if the `ins` operand can be properly bufferized. + // TODO(#10498): Handle the multi-result case. + return llvm::all_of(allUses, [](OpOperand *operand) { + return isInsOperandBufferizable(operand, /*aggressiveFusion=*/true); + }); +} + +/// Returns true if this is a fusable use, while fusing a root with its +/// consumer. +static bool isFusableWithConsumer(OpOperand &fusedOperand, + bool aggressiveFusion) { + // Use the original fusion heuristics if aggressive fusion isn't enabled. + if (!aggressiveFusion) + return areLinalgOpsFusableUsingTileAndFuse(fusedOperand); + + // Logics with aggressive fusion heuristics. + Operation *producer = fusedOperand.get().getDefiningOp(); + Operation *consumer = fusedOperand.getOwner(); + + if (!isa(producer) || !isa(consumer)) + return false; + + auto consumerLinalgOp = cast(consumer); + + // Check that the consumer is all parallel. + if (consumerLinalgOp.getNumLoops() != + consumerLinalgOp.getNumParallelLoops()) { + return false; + } + + return areOpsAggresiveFusable(producer, consumer, + /*allowConsumerParallelismPessimization=*/true); +} + /// Fuses roots with its consumers. If a root is fused with its consumer, it is /// no more tagged as a root to aid with the dispatch region formation. static void fuseRootsWithConsumers(MLIRContext *context, ArrayRef roots, - DominanceInfo const &dominanceInfo) { + DominanceInfo const &dominanceInfo, + bool aggressiveFusion) { SmallVector workList(roots.begin(), roots.end()); // Fuse with consumers where possible. while (!workList.empty()) { @@ -902,7 +887,8 @@ static void fuseRootsWithConsumers(MLIRContext *context, appendToFusionGroup(currRoot, rootNumber); }; - Optional fusableUse = getFusableUse(currRoot, dominanceInfo); + Optional fusableUse = getFusableUse( + currRoot, dominanceInfo, /*fuseMultiUse=*/aggressiveFusion); if (!fusableUse) continue; // Analyse the use to see if it is fusable. @@ -912,7 +898,7 @@ static void fuseRootsWithConsumers(MLIRContext *context, continue; } - if (isFusableWithConsumer(*(fusableUse.value()))) { + if (isFusableWithConsumer(*(fusableUse.value()), aggressiveFusion)) { updateRootTo(consumerOp); workList.push_back(consumerOp); } @@ -920,19 +906,29 @@ static void fuseRootsWithConsumers(MLIRContext *context, } /// Method to check if the consumer of a use can be fused with its producer. -static bool isFusableWithProducer(OpOperand &operand) { +static bool isFusableWithProducer(OpOperand &operand, bool aggressiveFusion) { Operation *producer = operand.get().getDefiningOp(); Operation *consumer = operand.getOwner(); - if (isa(consumer) && isa(producer)) { - auto consumerLinalgOp = cast(consumer); - auto producerLinalgOp = cast(producer); - if (consumerLinalgOp.isOutputTensor(&operand) && - producerLinalgOp.getNumLoops() == - producerLinalgOp.getNumParallelLoops()) { - return true; - } + if (!isa(consumer) || !isa(producer)) + return false; + + auto consumerLinalgOp = cast(consumer); + auto producerLinalgOp = cast(producer); + if (consumerLinalgOp.isOutputTensor(&operand) && + producerLinalgOp.getNumLoops() == + producerLinalgOp.getNumParallelLoops()) { + return true; } + + // Only fuse on inputs if both are generic ops. + if (aggressiveFusion && consumerLinalgOp.isInputTensor(&operand) && + isa(consumer) && isa(producer)) { + return areOpsAggresiveFusable( + producer, consumer, + /*allowConsumerParallelismPessimization=*/false); + } + return false; } @@ -940,21 +936,28 @@ static bool isFusableWithProducer(OpOperand &operand) { /// in reverse to fuse with producers. static void fuseRootsWithProducers(MLIRContext *context, Operation *root, unsigned groupNum, - DominanceInfo const &dominanceInfo) { - // We probably want a worklist algorithm here, but for now just look at - // immediate producers. - for (OpOperand &operand : root->getOpOperands()) { - Operation *producer = operand.get().getDefiningOp(); - if (!producer) continue; - if (hasFusionGroupsAttribute(producer) || hasRootOpAttribute(producer)) { - continue; - } + DominanceInfo const &dominanceInfo, + bool aggressiveFusion) { + SmallVector worklist; + worklist.push_back(root); + + while (!worklist.empty()) { + Operation *candidate = worklist.pop_back_val(); + for (OpOperand &operand : candidate->getOpOperands()) { + Operation *producer = operand.get().getDefiningOp(); + if (!producer) continue; + if (hasFusionGroupsAttribute(producer) || hasRootOpAttribute(producer)) { + continue; + } + + Optional fusableUse = getFusableUse( + producer, dominanceInfo, /*fuseMultiUse=*/aggressiveFusion); + if (!fusableUse || fusableUse.value()->getOwner() != candidate) continue; - Optional fusableUse = getFusableUse(producer, dominanceInfo); - if (!fusableUse || fusableUse.value()->getOwner() != root) continue; + if (!isFusableWithProducer(operand, aggressiveFusion)) continue; - if (isFusableWithProducer(operand)) { appendToFusionGroup(producer, groupNum); + worklist.push_back(producer); } } } @@ -968,7 +971,8 @@ static void fuseRootsWithProducers(MLIRContext *context, Operation *root, /// very simple heuristic is used below, but the mechanism should be general /// enough to capture any heuristic. static unsigned decideFusableLinalgOps(FunctionOpInterface funcOp, - DominanceInfo const &dominanceInfo) { + DominanceInfo const &dominanceInfo, + bool aggressiveFusion) { unsigned numRootOps = 0; MLIRContext *context = funcOp->getContext(); OpBuilder builder(context); @@ -985,11 +989,12 @@ static unsigned decideFusableLinalgOps(FunctionOpInterface funcOp, unsigned newGroup = numRootOps++; setRootAttribute(context, &op, newGroup); - fuseRootsWithProducers(context, &op, newGroup, dominanceInfo); + fuseRootsWithProducers(context, &op, newGroup, dominanceInfo, + aggressiveFusion); roots.push_back(&op); } roots = llvm::to_vector(llvm::reverse(roots)); - fuseRootsWithConsumers(context, roots, dominanceInfo); + fuseRootsWithConsumers(context, roots, dominanceInfo, aggressiveFusion); } // Once all root linalg ops have been tagged, put all remaining generic ops @@ -1009,7 +1014,7 @@ static unsigned decideFusableLinalgOps(FunctionOpInterface funcOp, roots.push_back(&op); } roots = llvm::to_vector(llvm::reverse(roots)); - fuseRootsWithConsumers(context, roots, dominanceInfo); + fuseRootsWithConsumers(context, roots, dominanceInfo, aggressiveFusion); } return numRootOps; @@ -1024,36 +1029,17 @@ struct DispatchLinalgOnTensorsPass .insert(); } - DispatchLinalgOnTensorsPass() = default; - DispatchLinalgOnTensorsPass(const DispatchLinalgOnTensorsPass &pass) {} + DispatchLinalgOnTensorsPass(bool aggressiveFusion) { + this->aggressiveFusion = aggressiveFusion; + } + DispatchLinalgOnTensorsPass(const DispatchLinalgOnTensorsPass &pass) + : DispatchLinalgOnTensorsPass(pass.aggressiveFusion) {} void runOnOperation() override; private: Statistic numDispatches{this, "number of dispatches", "Number of Flow dispatches created"}; }; - -// Pass to test conversion to flow patterns. -struct ConvertToFlowPass : public ConvertToFlowBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - - void runOnOperation() override { - MLIRContext *context = &getContext(); - RewritePatternSet convertToFlowPatterns(context); - populateTensorToFlowConversionPatterns(context, convertToFlowPatterns); - memref::populateResolveRankedShapeTypeResultDimsPatterns( - convertToFlowPatterns); - if (failed(applyPatternsAndFoldGreedily( - getOperation(), std::move(convertToFlowPatterns)))) { - return signalPassFailure(); - } - } -}; - } // namespace /// For all ops within `funcOp` tagged as root ops, create dispatch regions. @@ -1130,7 +1116,7 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { auto funcOp = getOperation(); MLIRContext *context = &getContext(); DominanceInfo const &dominanceInfo = getAnalysis(); - decideFusableLinalgOps(funcOp, dominanceInfo); + decideFusableLinalgOps(funcOp, dominanceInfo, aggressiveFusion); LLVM_DEBUG({ llvm::dbgs() << "\n--- After annotating linalg op fusion scheme ---\n"; @@ -1213,12 +1199,8 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { } std::unique_ptr> -createDispatchLinalgOnTensorsPass() { - return std::make_unique(); -} - -std::unique_ptr createConvertToFlowPass() { - return std::make_unique(); +createDispatchLinalgOnTensorsPass(bool aggressiveFusion) { + return std::make_unique(aggressiveFusion); } } // namespace Flow diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensorsViaRegionOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensorsViaRegionOps.cpp new file mode 100644 index 000000000000..fc173cd76273 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensorsViaRegionOps.cpp @@ -0,0 +1,753 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// This is a variant of DispatchLinalgOnTensors.cpp. DispatchWorkgroupsOps are +// built from DispatchRegionOps. This file can eventually replace the original +// DispatchLinalgOnTensors.cpp +// +// Note: The heuristic part of the implementation is unchanged and copied from +// DispatchLinalgOnTensors.cpp. + +#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h" +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h" +#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h" +#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h" +#include "iree/compiler/Dialect/Flow/Transforms/Passes.h" +#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" + +using namespace mlir; +using namespace mlir::iree_compiler; +using namespace mlir::iree_compiler::IREE; + +#define DEBUG_TYPE "iree-flow-dispatch-linalg-on-tensors-via-region-ops" + +static const int kInlineConstantByteLength = 256; +static const bool kEnableMultiResultDispatches = false; +static const char kRootOpAttr[] = "__root_op__"; +static const char kFusionGroupsAttr[] = "__fused_op__"; + +//===----------------------------------------------------------------------===// +// Helpers for fusion group formation +//===----------------------------------------------------------------------===// + +namespace { +/// A rewriter that keeps track of all tensor::DimOps. +class TensorDimTrackingRewriter : public IRRewriter { + public: + /// Create a new rewriter: Scan the given op for tensor::DimOps. + TensorDimTrackingRewriter(Operation *op) : IRRewriter(op->getContext()) { + op->walk([&](tensor::DimOp dimOp) { dimOps.insert(dimOp.getOperation()); }); + } + + /// Return all tracked tensor::DimOps. + SmallVector getTensorDimOps() { + SmallVector result; + for (Operation *op : dimOps) result.push_back(cast(op)); + return result; + } + + protected: + void notifyOperationRemoved(Operation *op) override { + IRRewriter::notifyOperationRemoved(op); + if (isa(op)) dimOps.erase(op); + } + + void notifyOperationInserted(Operation *op) override { + IRRewriter::notifyOperationInserted(op); + if (isa(op)) dimOps.insert(op); + } + + private: + SmallPtrSet dimOps; +}; +} // namespace + +/// Simplfy the given tensor::DimOps as much as possible. +/// * Static dimensions are replaced by constant. +/// * Dynamic dim ops are pushed as much as possible to the top of the function, +/// i.e., if the dim of a value is known to be equal to the dim of a value on +/// the reverse SSA use-def chain, rewrite the value with a dim op of that +/// value. +static LogicalResult simplifyDimOps(RewriterBase &rewriter, + const SmallVector &dimOps) { + for (tensor::DimOp dimOp : dimOps) { + // Only DimOps with static indices are supported. + Optional idx = dimOp.getConstantIndex(); + if (!idx.has_value()) continue; + // Only DimOps with ranked tensors are supported. + auto tensorType = dimOp.getSource().getType().dyn_cast(); + if (!tensorType) continue; + + if (!tensorType.isDynamicDim(*idx)) { + // Rewrite static dimension with constant. + int64_t size = tensorType.getShape()[*idx]; + rewriter.replaceOpWithNewOp(dimOp, size); + continue; + } + + // Try to simplify dynamic dims. + SmallVector dynamicDims; + if (failed(Flow::reifyDynamicResultDims(rewriter, dimOp.getSource(), + dynamicDims))) + return failure(); + unsigned ctr = 0; + for (int64_t i = 0; i < *dimOp.getConstantIndex(); ++i) + if (tensorType.isDynamicDim(i)) ++ctr; + rewriter.replaceOp(dimOp, dynamicDims[ctr]); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// Root and fusion group attribute handling +//===----------------------------------------------------------------------===// + +/// Returns true if an op has a root operation. +static bool hasRootOpAttribute(Operation *op) { + return static_cast(op->getAttrOfType(kRootOpAttr)); +} + +/// Removes root attribute. Asserts if root attribute is not present. +static void removeRootOpAttribute(Operation *op) { + op->removeAttr(kRootOpAttr); +} + +/// Sets the root attribute for an operation. The root attribute needs a number +/// to identify the root. Asserts if root attribute is already set on an +/// operation. +static void setRootAttribute(MLIRContext *context, Operation *op, + int64_t rootNumber) { + assert(!op->hasAttr(kRootOpAttr) && + "invalid to update root attribute on an op"); + op->setAttr(kRootOpAttr, + IntegerAttr::get(IntegerType::get(context, 64), rootNumber)); +} + +/// Returns the number of the root. Asserts if the operation is not already set +/// as a root. +static int64_t getRootNumber(Operation *op) { + return op->getAttrOfType(kRootOpAttr).getInt(); +} + +/// Returns true if an op is part of a fusion group. +static bool hasFusionGroupsAttribute(Operation *op) { + return static_cast(op->getAttrOfType(kFusionGroupsAttr)); +} + +/// Returns the fusion groups for the given `op`. +static SmallVector getFusionGroups(Operation *op) { + SmallVector fusionGroups = {}; + if (auto fusionGroupsAttr = op->getAttrOfType(kFusionGroupsAttr)) { + fusionGroups = llvm::to_vector<1>(llvm::map_range( + fusionGroupsAttr, + [](Attribute attr) { return attr.cast().getInt(); })); + } + return fusionGroups; +} + +/// Appends the given `op` to the `newGroups` fusion groups. +static void appendToFusionGroup(Operation *op, ArrayRef newGroups) { + SmallVector fusionGroups = getFusionGroups(op); + fusionGroups.append(newGroups.begin(), newGroups.end()); + op->setAttr(kFusionGroupsAttr, Builder(op).getI64ArrayAttr(fusionGroups)); +} + +/// Returns true if the given `op` is in the `targetGroup` fusion group. +static bool isInFusionGroup(Operation *op, unsigned targetGroup) { + if (ArrayAttr opGroupAttr = op->getAttrOfType(kFusionGroupsAttr)) { + return llvm::any_of(opGroupAttr, [&targetGroup](Attribute attr) { + return attr.cast().getInt() == targetGroup; + }); + } + return false; +} + +/// Removes the fusion groups attribute. +static void removeFusionGroupsAttribute(Operation *op) { + op->removeAttr(kFusionGroupsAttr); +} + +//===----------------------------------------------------------------------===// +// Op property charecterizations +//===----------------------------------------------------------------------===// + +/// Operations that are treated as root operations for dispatch region +/// formation. +static bool isRootOp(Operation *op) { + if (op->getParentOfType() || + op->getParentOfType()) { + return false; + } + // Any Linalg named op or generic op with reduction iterator types is a root + // op. + if (auto linalgOp = dyn_cast(op)) { + if (isa(op)) { + return linalgOp.getNumReductionLoops() != 0; + } + return !isa(op); + } + return isa(op); +} + +/// Operations that are cloned into dispatch regions formed with other +/// operations as roots. +bool isClonableIntoDispatchOp(Operation *op) { + // TODO(#8637): `tensor.collapse_shape` and `tensor.expand_shape` are + // trivially clonable too, but they cause problems + // with bufferization. Make them clonable when fixed. + if (isa(op)) { + return true; + } + if (auto constantOp = dyn_cast(op)) { + auto constantValueAttr = constantOp.getValue(); + auto constantType = constantOp.getType(); + if (constantValueAttr.isa()) { + return true; + } else if (auto denseAttr = + constantValueAttr.dyn_cast()) { + auto shapedType = constantOp.getType().cast(); + uint64_t estimatedByteLength = + (shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) / + 8; + return denseAttr.isSplat() || + estimatedByteLength <= kInlineConstantByteLength; + } else if (constantType.isIntOrIndexOrFloat()) { + return true; + } + } + if (llvm::all_of(op->getOperands(), + [&](Value v) { return v.getType().isIntOrFloat(); }) && + llvm::all_of(op->getResults(), + [&](Value v) { return v.getType().isIntOrFloat(); })) { + return true; + } + return false; +} + +/// Checks if the `Value` has a use within the dispatch that is unfusable. +static bool hasUnfusableUseInDispatch(Value v, Operation *dispatchOp) { + for (OpOperand &use : v.getUses()) { + Operation *user = use.getOwner(); + Operation *ownerWorkgroups = + user->getParentOfType(); + Operation *ownerRegion = + user->getParentOfType(); + Operation *owner = ownerWorkgroups ? ownerWorkgroups : ownerRegion; + + // Ignore uses outside of dispatch workgroups op. + if (owner != dispatchOp) continue; + + // Cannot fuse producer of `dest` with `tensor.insert_slice`. + if (auto insertSliceUser = dyn_cast(user)) { + if (insertSliceUser.getDest() == v) return true; + } + } + return false; +} + +//===----------------------------------------------------------------------===// +// Methods for getting the workload information for dispatch region creation. +//===----------------------------------------------------------------------===// + +/// Compute the workload to use for the workgroup based on the root op. +static SmallVector getWorkloadForRootOp(OpBuilder &builder, + Operation *rootOp) { + // Compute workgroup count to use for the dispatch op. These are the ranges + // of the outermost parallel loops that can be distributed. + Location loc = rootOp->getLoc(); + SmallVector loopRanges = Flow::getLoopRanges(rootOp, loc, builder); + AffineExpr s0, s1, s2; + bindSymbols(builder.getContext(), s0, s1, s2); + AffineMap workload = AffineMap::get(0, 3, (s1 - s0).ceilDiv(s2)); + return llvm::to_vector(llvm::map_range(loopRanges, [&](Range r) -> Value { + Value offset = getValueOrCreateConstantIndexOp(builder, loc, r.offset); + Value size = getValueOrCreateConstantIndexOp(builder, loc, r.size); + Value stride = getValueOrCreateConstantIndexOp(builder, loc, r.stride); + return builder.create(rootOp->getLoc(), workload, + ValueRange{offset, size, stride}); + })); +} + +//===----------------------------------------------------------------------===// +// Heuristics for fusing dispatchble ops with root ops using tile + fuse. +//===----------------------------------------------------------------------===// + +/// Collect all ops that should be cloned into the given dispatch region op. +static SmallVector getCloneableOps( + Flow::DispatchRegionOp regionOp) { + // Find values that are used inside of the dispatch region but defined outside + // of the dispatch region. + llvm::SetVector valuesDefinedAbove; + mlir::getUsedValuesDefinedAbove(regionOp.getBody(), valuesDefinedAbove); + if (valuesDefinedAbove.empty()) return {}; + + // Traverse the defining ops of these values (and the ops on their reverse + // SSA use-def chain). + SmallVector result; + llvm::SetVector visited; + SmallVector worklist; + worklist.assign(valuesDefinedAbove.begin(), valuesDefinedAbove.end()); + while (!worklist.empty()) { + Value outsideValue = worklist.pop_back_val(); + // Skip values that were already visited. + if (visited.count(outsideValue)) continue; + visited.insert(outsideValue); + + Operation *definingOp = outsideValue.getDefiningOp(); + if (!definingOp || !(isClonableIntoDispatchOp(definingOp)) || + hasUnfusableUseInDispatch(outsideValue, regionOp)) { + valuesDefinedAbove.insert(outsideValue); + continue; + } + result.push_back(definingOp); + worklist.append(definingOp->operand_begin(), definingOp->operand_end()); + } + + return result; +} + +/// Checks if the producer and consumer LinalgOps can be fused. +static bool areFusableLinalgOps(OpOperand &use) { + return Flow::areLinalgOpsFusableUsingTileAndFuse(use); +} + +/// Returns true if this is a fusable use. +static bool isFusableWithConsumer(OpOperand &use) { + // Check for linalg producer -> consumer fusion with tile + fuse. + return areFusableLinalgOps(use); +} + +/// For all uses of an operation, finds the use that dominates all other uses. +static Optional getFusableUse(Operation *op, + DominanceInfo const &dominanceInfo) { + if (!kEnableMultiResultDispatches) { + if (op->hasOneUse()) { + OpOperand &use = *(op->use_begin()); + return &use; + } + return llvm::None; + } + for (auto &use : op->getUses()) { + Operation *user = use.getOwner(); + if (llvm::all_of(op->getUsers(), [&](Operation *c) { + return dominanceInfo.dominates(user, c); + })) { + return &use; + } + } + return llvm::None; +} + +/// Fuses roots with its consumers. If a root is fused with its consumer, it is +/// no more tagged as a root to aid with the dispatch region formation. +static void fuseRootsWithConsumers(MLIRContext *context, + ArrayRef roots, + DominanceInfo const &dominanceInfo) { + SmallVector workList(roots.begin(), roots.end()); + // Fuse with consumers where possible. + while (!workList.empty()) { + Operation *currRoot = workList.pop_back_val(); + assert(hasRootOpAttribute(currRoot) && + "unexpected non-root op in worklist"); + + // Helper function to make the consumer the root instead of the producer + // when they are to be fused. + auto updateRootTo = [&context, &currRoot](Operation *newRoot) { + int64_t rootNumber = getRootNumber(currRoot); + setRootAttribute(context, newRoot, rootNumber); + removeRootOpAttribute(currRoot); + appendToFusionGroup(currRoot, rootNumber); + }; + + Optional fusableUse = getFusableUse(currRoot, dominanceInfo); + if (!fusableUse) continue; + + // Analyse the use to see if it is fusable. + Operation *consumerOp = fusableUse.value()->getOwner(); + if (hasRootOpAttribute(consumerOp) || + hasFusionGroupsAttribute(consumerOp)) { + continue; + } + + if (isFusableWithConsumer(*(fusableUse.value()))) { + updateRootTo(consumerOp); + workList.push_back(consumerOp); + } + } +} + +/// Method to check if the consumer of a use can be fused with its producer. +static bool isFusableWithProducer(OpOperand &operand) { + Operation *producer = operand.get().getDefiningOp(); + Operation *consumer = operand.getOwner(); + + if (isa(consumer) && isa(producer)) { + auto consumerLinalgOp = cast(consumer); + auto producerLinalgOp = cast(producer); + if (consumerLinalgOp.isOutputTensor(&operand) && + producerLinalgOp.getNumLoops() == + producerLinalgOp.getNumParallelLoops()) { + return true; + } + } + return false; +} + +/// Starting from the `root` op, traverse the operand use-def chain +/// in reverse to fuse with producers. +static void fuseRootsWithProducers(MLIRContext *context, Operation *root, + unsigned groupNum, + DominanceInfo const &dominanceInfo) { + // We probably want a worklist algorithm here, but for now just look at + // immediate producers. + for (OpOperand &operand : root->getOpOperands()) { + Operation *producer = operand.get().getDefiningOp(); + if (!producer) continue; + if (hasFusionGroupsAttribute(producer) || hasRootOpAttribute(producer)) { + continue; + } + + Optional fusableUse = getFusableUse(producer, dominanceInfo); + if (!fusableUse || fusableUse.value()->getOwner() != root) continue; + + if (isFusableWithProducer(operand)) { + appendToFusionGroup(producer, groupNum); + } + } +} + +/// Some heuristic is needed to fuse a dispatchable op with root operations +/// using tile + fuse. Using some heuristic, each root operation is tagged with +/// an ID (using an IntegerAttr with name `kRootOpAttr`) and all dispatchable +/// ops to be fused with it is tagged with the same ID (using a list of +/// IntegerAttr with name `kFusionGroupsAttr`). Each dispatchable operation can +/// be marked to fuse with multiple root operations (i.e. replicated). For now a +/// very simple heuristic is used below, but the mechanism should be general +/// enough to capture any heuristic. +static unsigned decideFusableLinalgOps(FunctionOpInterface funcOp, + DominanceInfo const &dominanceInfo) { + unsigned numRootOps = 0; + MLIRContext *context = funcOp->getContext(); + OpBuilder builder(context); + for (Block &block : funcOp.getBody()) { + // Dispatch region formation works by first cloning the root into + // the dispatch region and then pulling operations in. + // So procedure here is to + // - First find the roots + // - To fuse with consumers make the consumer the root. + SmallVector roots; + for (Operation &op : llvm::reverse(block)) { + // Start with a root operation and fuse its producers. + if (hasFusionGroupsAttribute(&op) || !isRootOp(&op)) continue; + unsigned newGroup = numRootOps++; + setRootAttribute(context, &op, newGroup); + + fuseRootsWithProducers(context, &op, newGroup, dominanceInfo); + roots.push_back(&op); + } + roots = llvm::to_vector(llvm::reverse(roots)); + fuseRootsWithConsumers(context, roots, dominanceInfo); + } + + // Once all root linalg ops have been tagged, put all remaining generic ops + // into their own dispatches. + for (Block &block : funcOp.getBody()) { + SmallVector roots; + for (Operation &op : llvm::reverse(block)) { + // If it is part of a fusion group or root op, ignore it. + if (hasFusionGroupsAttribute(&op) || hasRootOpAttribute(&op)) continue; + // Only look for Linalg ops here. Avoid moving `linalg.fill` that aren't + // fused with anything else into their own dispatches since it is better + // to convert them to splats. + if (!isa(op) || isa(op)) continue; + + unsigned newGroup = numRootOps++; + setRootAttribute(context, &op, newGroup); + roots.push_back(&op); + } + roots = llvm::to_vector(llvm::reverse(roots)); + fuseRootsWithConsumers(context, roots, dominanceInfo); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\n--- After annotating linalg op fusion scheme ---\n"; + funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + + return numRootOps; +} + +//===----------------------------------------------------------------------===// +// Dispatch region formation +//===----------------------------------------------------------------------===// + +/// Clone producers into the dispatch region. +static LogicalResult cloneProducers(RewriterBase &rewriter, + Flow::DispatchRegionOp regionOp) { + SmallVector cloneableOps = getCloneableOps(regionOp); + SmallVector orderedProducers = + Flow::orderOperations(cloneableOps); + + for (Operation *producer : llvm::reverse(orderedProducers)) + if (failed( + clonePrecedingOpIntoDispatchRegion(rewriter, producer, regionOp))) + return failure(); + + return success(); +} + +/// Helper function that builds the workload region body. +static void buildWorkloadRegionBody(OpBuilder &builder, Location loc, + ArrayRef args) { + auto numWorkgroupsOp = + builder.create(loc, args); + builder.create(loc, numWorkgroupsOp.getResults()); +} + +/// Create Flow::DispatchGroupsOps based on a fusion heuristic. +static FailureOr> createFusionGroups( + TensorDimTrackingRewriter &rewriter, FunctionOpInterface funcOp, + DominanceInfo const &dominanceInfo, bool generateWorkloadRegion) { + // Decide fusion groups (heuristic). + unsigned numRoots = decideFusableLinalgOps(funcOp, dominanceInfo); + SmallVector roots(numRoots, nullptr); + DenseMap> producers; + + // TODO: Incrementally add ops to an empty DispatchGroupOp instead of + // annotating fusion group IDs via attributes. + funcOp.walk([&](Operation *op) { + if (hasRootOpAttribute(op)) roots[getRootNumber(op)] = op; + if (hasFusionGroupsAttribute(op)) { + assert(getFusionGroups(op).size() == 1 && "expected exactly one group"); + producers[getFusionGroups(op).front()].push_back(op); + } + }); + + // Create a DispatchRegionOp for every fusion group. + OpBuilder::InsertionGuard g(rewriter); + SmallVector regionOps; + DenseMap> workloads; + for (const auto &it : llvm::enumerate(roots)) { + // Compute workload. + SmallVector workload; + if (generateWorkloadRegion) { + rewriter.setInsertionPoint(it.value()); + FailureOr> maybeWorkload = + getWorkloadForRootOp(rewriter, it.value()); + if (failed(maybeWorkload)) return failure(); + workload = *maybeWorkload; + } + + // Simplify tensor::DimOps. + SmallVector dimOps = rewriter.getTensorDimOps(); + if (failed(simplifyDimOps(rewriter, dimOps))) return failure(); + + // Create fusion group. + Flow::DispatchRegionOp regionOp; + auto maybeRegionOp = Flow::wrapOpInDispatchRegion(rewriter, it.value()); + if (failed(maybeRegionOp)) return failure(); + regionOp = *maybeRegionOp; + workloads[regionOp] = workload; + + // Sort producers topologically. All producers must be in the same block as + // the root. + // TODO: Use mlir::computeTopologicalSorting. This is currently not possible + // because some of the producers are in different blocks. + SmallVector orderedProducers = + Flow::orderOperations(producers[it.index()]); + + // Move ops into the region. + for (Operation *producer : llvm::reverse(orderedProducers)) { + auto newRegionOp = + movePrecedingOpIntoDispatchRegion(rewriter, producer, regionOp); + if (failed(newRegionOp)) return failure(); + regionOp = *newRegionOp; + } + + regionOps.push_back(regionOp); + } + + // Clone additional producers and rewrite to DispatchWorkgroupsOp. + SmallVector result; + for (auto regionOp : regionOps) { + if (failed(cloneProducers(rewriter, regionOp))) return failure(); + auto maybeWorkgroupOp = + Flow::rewriteFlowDispatchRegionToFlowDispatchWorkgroups( + regionOp, rewriter, workloads[regionOp], + generateWorkloadRegion ? buildWorkloadRegionBody : nullptr); + if (failed(maybeWorkgroupOp)) return failure(); + + result.push_back(*maybeWorkgroupOp); + } + + return result; +} + +/// Wrap a single op in a DispatchWorkgroupsOp. +static FailureOr wrapInWorkgroupsOp( + TensorDimTrackingRewriter &rewriter, Operation *op, + bool generateWorkloadRegion) { + // Compute workload. + SmallVector workload; + if (generateWorkloadRegion) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); + FailureOr> maybeWorkload = + getWorkloadForRootOp(rewriter, op); + if (failed(maybeWorkload)) return failure(); + workload = *maybeWorkload; + } + + // Simplify tensor::DimOps. + SmallVector dimOps = rewriter.getTensorDimOps(); + if (failed(simplifyDimOps(rewriter, rewriter.getTensorDimOps()))) + return failure(); + + // Wrap operation. + auto regionOp = Flow::wrapOpInDispatchRegion(rewriter, op); + if (failed(regionOp)) return failure(); + if (failed(cloneProducers(rewriter, *regionOp))) return failure(); + auto workgroupsOp = Flow::rewriteFlowDispatchRegionToFlowDispatchWorkgroups( + *regionOp, rewriter, workload, + generateWorkloadRegion ? buildWorkloadRegionBody : nullptr); + if (failed(workgroupsOp)) return failure(); + return *workgroupsOp; +} + +/// Wrap all ops of the given type that are direct children of the given op in +/// a DispatchWorkgroupsOp. +template +static FailureOr> wrapInWorkgroupsOp( + TensorDimTrackingRewriter &rewriter, Operation *op, + bool generateWorkloadRegion) { + // Find ops of type OpTy. + SmallVector rootOps; + for (Region &r : op->getRegions()) + for (Block &b : r.getBlocks()) + for (auto op : b.getOps()) rootOps.push_back(op.getOperation()); + + // Wrap ops in DispatchWorkgroupsOps. + SmallVector result; + for (Operation *rootOp : rootOps) { + auto workgroupsOp = + wrapInWorkgroupsOp(rewriter, rootOp, generateWorkloadRegion); + if (failed(workgroupsOp)) return failure(); + result.push_back(*workgroupsOp); + } + return result; +} + +namespace { +/// Pass declaration. +struct DispatchLinalgOnTensorsViaRegionOpsPass + : public Flow::DispatchLinalgOnTensorsViaRegionOpsBase< + DispatchLinalgOnTensorsViaRegionOpsPass> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + DispatchLinalgOnTensorsViaRegionOpsPass(bool generateWorkloadRegion) { + this->generateWorkloadRegion = generateWorkloadRegion; + } + DispatchLinalgOnTensorsViaRegionOpsPass( + const DispatchLinalgOnTensorsViaRegionOpsPass &pass) { + this->generateWorkloadRegion = pass.generateWorkloadRegion; + } + void runOnOperation() override; + + private: + bool generateWorkloadRegion = true; +}; +} // namespace + +void DispatchLinalgOnTensorsViaRegionOpsPass::runOnOperation() { + auto funcOp = getOperation(); + MLIRContext *context = &getContext(); + + DominanceInfo const &dominanceInfo = getAnalysis(); + TensorDimTrackingRewriter rewriter(funcOp); + + // Step 1: Create a DispatchWorkgroupsOp for every fusion group. + auto maybeWorkgroupsOps = createFusionGroups(rewriter, funcOp, dominanceInfo, + generateWorkloadRegion); + if (failed(maybeWorkgroupsOps)) return signalPassFailure(); + SmallVector workgroupsOps = *maybeWorkgroupsOps; + + LLVM_DEBUG({ + llvm::dbgs() << "\n--- After first step of dispatch region formation ---\n"; + funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + + // Step 2: Create a DispatchWorkgroupsOp for every remaining InsertSliceOp. + FailureOr> newWorkgroupsOps = + wrapInWorkgroupsOp(rewriter, funcOp, + generateWorkloadRegion); + if (failed(newWorkgroupsOps)) return signalPassFailure(); + workgroupsOps.append(newWorkgroupsOps->begin(), newWorkgroupsOps->end()); + + // Step 3: Create a DispatchWorkgroupsOp for every remaining ExtractSliceOp. + newWorkgroupsOps = wrapInWorkgroupsOp( + rewriter, funcOp, generateWorkloadRegion); + if (failed(newWorkgroupsOps)) return signalPassFailure(); + workgroupsOps.append(newWorkgroupsOps->begin(), newWorkgroupsOps->end()); + + // A few extra canonicalizations/lowerings. + { + RewritePatternSet convertToFlowPatterns(context); + Flow::populateTensorToFlowConversionPatterns(context, + convertToFlowPatterns); + memref::populateResolveRankedShapeTypeResultDimsPatterns( + convertToFlowPatterns); + IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns( + convertToFlowPatterns, context); + if (failed(applyPatternsAndFoldGreedily(funcOp, + std::move(convertToFlowPatterns)))) + return signalPassFailure(); + + // Finally fold `tensor.insert_slice/extract_slice` operations with + // `flow.dispatch.tensor.load/store`. + RewritePatternSet foldExtractInsertSliceOps(context); + Flow::populateTensorSliceOpWithDispatchTensorOpFoldingPatterns( + foldExtractInsertSliceOps, context); + if (failed(applyPatternsAndFoldGreedily( + funcOp, std::move(foldExtractInsertSliceOps)))) + return signalPassFailure(); + } + + // Finally walk all the ops and remove the attributes + funcOp.walk([](Operation *op) { + removeFusionGroupsAttribute(op); + removeRootOpAttribute(op); + op->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker); + }); +} + +std::unique_ptr> +Flow::createDispatchLinalgOnTensorsViaRegionOpsPass( + bool generateWorkloadRegion) { + return std::make_unique( + generateWorkloadRegion); +} diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp index 6c530afd3687..b8b64ab2de66 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/IR/Dominance.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "iree-flow-fusion-of-tensor-ops" @@ -30,6 +31,34 @@ namespace iree_compiler { namespace IREE { namespace Flow { +/// Check if any of the use dominates all other uses of the operation. +static Optional getFusableUse(Operation *op, + DominanceInfo &dominanceInfo) { + auto uses = op->getUses(); + for (OpOperand &source : uses) { + Operation *sourceOp = source.getOwner(); + bool dominatesAllUsers = true; + for (OpOperand &target : uses) { + Operation *targetOp = target.getOwner(); + if (!dominanceInfo.dominates(sourceOp, targetOp)) { + dominatesAllUsers = false; + break; + } + } + if (dominatesAllUsers) { + // For now check that the `sourceOp` is only used once in the consumer. + // This can be generalized if needed + unsigned numUsesOfOp = 0; + for (OpOperand &operand : sourceOp->getOpOperands()) { + if (operand.get().getDefiningOp() == op) numUsesOfOp++; + } + if (numUsesOfOp != 1) return llvm::None; + return &source; + } + } + return llvm::None; +} + /// Check if the producer generic op is fusable with the consumer generic op. static bool areFusableOps(MLIRContext *context, Operation *producerOp, Operation *consumerOp) { @@ -57,6 +86,107 @@ static bool areFusableOps(MLIRContext *context, Operation *producerOp, namespace { +struct FuseElementwiseOpsWithMultipleUses + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static const char *getConsumerAttributeName() { + return "__fusable_conumer__"; + } + static const char *getProducerAttributeName() { + return "__fusable_producer__"; + } + + LogicalResult matchAndRewrite(linalg::GenericOp consumerOp, + PatternRewriter &rewriter) const override { + auto consumerMarker = + consumerOp->getAttrOfType(getConsumerAttributeName()); + if (!consumerMarker) return failure(); + + auto fusedOperandIt = + llvm::find_if(consumerOp->getOpOperands(), [&](OpOperand &operand) { + Operation *operandProducer = operand.get().getDefiningOp(); + if (!operandProducer) return false; + auto producerMarker = operandProducer->getAttrOfType( + getProducerAttributeName()); + if (!producerMarker) return false; + return consumerMarker.getValue() == producerMarker.getValue(); + }); + assert(fusedOperandIt != consumerOp->getOpOperands().end() && + "expected to find the fusable producer"); + OpOperand *fusedOperand = fusedOperandIt; + assert(linalg::areElementwiseOpsFusable(fusedOperand) && + "expected producer and consumer to be fusable"); + Operation *producerOp = fusedOperand->get().getDefiningOp(); + + // Cleanup the markers. + consumerOp->removeAttr(getConsumerAttributeName()); + producerOp->removeAttr(getProducerAttributeName()); + + FailureOr fusedOperation = + linalg::fuseElementwiseOps(rewriter, fusedOperand); + if (failed(fusedOperation)) { + return rewriter.notifyMatchFailure(consumerOp, + "failed to fuse with producer"); + } + assert(fusedOperation.value()->getNumResults() == + producerOp->getNumResults() + consumerOp->getNumResults()); + auto fusedResults = fusedOperation.value()->getResults(); + rewriter.replaceOp(producerOp, + fusedResults.take_front(producerOp->getNumResults())); + rewriter.replaceOp(consumerOp, + fusedResults.take_back(consumerOp->getNumResults())); + return success(); + } +}; + +static FailureOr fuseMultiUseProducers(Operation *funcOp, + MLIRContext *context, + DominanceInfo &dominanceInfo) { + // Try fusion of operations when producer has multiple uses. + // 1. Walk the function in pre-order. + // 2. Check if a `linalg.generic` op has a consumer `linalg.generic` op + // that dominates all uses of the producer op. Then fuse the producer + // consumer + unsigned numCandidates = 0; + OpBuilder builder(context); + funcOp->walk([&](linalg::GenericOp genericOp) { + auto consumerAttrName = + FuseElementwiseOpsWithMultipleUses::getConsumerAttributeName(); + auto producerAttrName = + FuseElementwiseOpsWithMultipleUses::getProducerAttributeName(); + if (genericOp->hasAttr(consumerAttrName) || + genericOp->hasAttr(producerAttrName)) { + return; + } + + Optional fusableUse = getFusableUse(genericOp, dominanceInfo); + if (!fusableUse) return; + if (!linalg::areElementwiseOpsFusable(fusableUse.value())) return; + + Operation *consumer = fusableUse.value()->getOwner(); + genericOp->setAttr(producerAttrName, + builder.getI64IntegerAttr(numCandidates)); + consumer->setAttr(consumerAttrName, + builder.getI64IntegerAttr(numCandidates)); + numCandidates++; + return; + }); + LLVM_DEBUG({ + llvm::dbgs() << "Num of multiuse fusable candidates : " << numCandidates + << "\n"; + funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + }); + RewritePatternSet fusionPatterns(context); + fusionPatterns.insert(context); + linalg::GenericOp::getCanonicalizationPatterns(fusionPatterns, context); + if (failed(applyPatternsAndFoldGreedily(funcOp->getRegions(), + std::move(fusionPatterns)))) { + return funcOp->emitOpError("multi use producer -> consumer fusion failed"); + } + return numCandidates; +} + /// Pass to fuse linalg on tensor operations as well as fusion of hal.interface* /// operations with linalg.tensor_reshape operation. struct FusionOfTensorOpsPass @@ -64,120 +194,157 @@ struct FusionOfTensorOpsPass void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } + FusionOfTensorOpsPass(bool fuseMultiUse, unsigned multiUseFusionIteration) { + this->fuseMultiUse = fuseMultiUse; + this->multiUseFusionIteration = multiUseFusionIteration; + } + FusionOfTensorOpsPass(const FusionOfTensorOpsPass &pass) + : FusionOfTensorOpsPass(pass.fuseMultiUse, pass.multiUseFusionIteration) { + } void runOnOperation() override { - RewritePatternSet fusionPatterns(&getContext()); - Operation *op = getOperation(); - MLIRContext *context = op->getContext(); - - // Only fuse operations where all uses of the producer are generic - // operations. If an operation is used in a named op, it will be computed - // anyway, so the consumers can just use that value. - linalg::ControlFusionFn fuseElementwiseOpsControlFn = - [&](OpOperand *fusedOperand) { - Operation *producer = fusedOperand->get().getDefiningOp(); - if (!producer) return false; - Operation *consumer = fusedOperand->getOwner(); - - // Limit the number of operands. We have hard limit (32) of bindings - // passing down to HAL. Set the number to be as same as the limit -- - // IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT. - constexpr int64_t kIreeMaxOperandCount = 32; - DenseSet operands; - operands.insert(producer->operand_begin(), producer->operand_end()); - operands.insert(consumer->operand_begin(), - std::next(consumer->operand_begin(), - fusedOperand->getOperandNumber())); - operands.insert(std::next(consumer->operand_begin(), - fusedOperand->getOperandNumber() + 1), - consumer->operand_end()); - if (operands.size() >= kIreeMaxOperandCount) return false; - - return areFusableOps(context, producer, consumer); - }; - linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, - fuseElementwiseOpsControlFn); - - // Always fold reshape by expansion. - linalg::ControlFusionFn fuseByExpansionControlFn = - [](OpOperand *fusedOperand) { - Operation *producer = fusedOperand->get().getDefiningOp(); - if (!producer) { - return false; - } - // Do not fuse producer generic op if it has more than one user. - if (auto producerGenericOp = dyn_cast(producer)) { - return producerGenericOp->hasOneUse(); - } - // Fuse in all other cases. - return true; - }; - linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, - fuseByExpansionControlFn); - - // Constant fold Linalg operations. - auto constantFoldControlFn = [](OpOperand *fusedOperand) { - auto producer = fusedOperand->get().getDefiningOp(); - return producer && producer->hasOneUse(); - }; - linalg::populateConstantFoldLinalgOperations(fusionPatterns, - constantFoldControlFn); - - AffineApplyOp::getCanonicalizationPatterns(fusionPatterns, context); - linalg::GenericOp::getCanonicalizationPatterns(fusionPatterns, context); - tensor::ExpandShapeOp::getCanonicalizationPatterns(fusionPatterns, context); - tensor::CollapseShapeOp::getCanonicalizationPatterns(fusionPatterns, + Operation *funcOp = getOperation(); + MLIRContext *context = funcOp->getContext(); + + { + RewritePatternSet fusionPatterns(&getContext()); + // Only fuse operations where all uses of the producer are generic + // operations. If an operation is used in a named op, it will be computed + // anyway, so the consumers can just use that value. + linalg::ControlFusionFn fuseElementwiseOpsControlFn = + [&](OpOperand *fusedOperand) { + Operation *producer = fusedOperand->get().getDefiningOp(); + if (!producer) return false; + Operation *consumer = fusedOperand->getOwner(); + + // Limit the number of operands. We have hard limit (32) of bindings + // passing down to HAL. Set the number to be as same as the limit -- + // IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT. + constexpr int64_t kIreeMaxOperandCount = 32; + DenseSet operands; + operands.insert(producer->operand_begin(), producer->operand_end()); + operands.insert(consumer->operand_begin(), + std::next(consumer->operand_begin(), + fusedOperand->getOperandNumber())); + operands.insert(std::next(consumer->operand_begin(), + fusedOperand->getOperandNumber() + 1), + consumer->operand_end()); + if (operands.size() >= kIreeMaxOperandCount) return false; + + return areFusableOps(context, producer, consumer); + }; + linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, + fuseElementwiseOpsControlFn); + + // Always fold reshape by expansion. + linalg::ControlFusionFn fuseByExpansionControlFn = + [](OpOperand *fusedOperand) { + Operation *producer = fusedOperand->get().getDefiningOp(); + if (!producer) { + return false; + } + // Do not fuse producer generic op if it has more than one user. + if (auto producerGenericOp = + dyn_cast(producer)) { + return producerGenericOp->hasOneUse(); + } + // Fuse in all other cases. + return true; + }; + linalg::populateFoldReshapeOpsByExpansionPatterns( + fusionPatterns, fuseByExpansionControlFn); + + // Constant fold Linalg operations. + auto constantFoldControlFn = [](OpOperand *fusedOperand) { + auto producer = fusedOperand->get().getDefiningOp(); + return producer && producer->hasOneUse(); + }; + linalg::populateConstantFoldLinalgOperations(fusionPatterns, + constantFoldControlFn); + + AffineApplyOp::getCanonicalizationPatterns(fusionPatterns, context); + linalg::GenericOp::getCanonicalizationPatterns(fusionPatterns, context); + tensor::ExpandShapeOp::getCanonicalizationPatterns(fusionPatterns, context); - context->getLoadedDialect() - ->getCanonicalizationPatterns(fusionPatterns); - memref::populateResolveRankedShapeTypeResultDimsPatterns(fusionPatterns); + tensor::CollapseShapeOp::getCanonicalizationPatterns(fusionPatterns, + context); + context->getLoadedDialect() + ->getCanonicalizationPatterns(fusionPatterns); + memref::populateResolveRankedShapeTypeResultDimsPatterns(fusionPatterns); + + if (failed(applyPatternsAndFoldGreedily(funcOp->getRegions(), + std::move(fusionPatterns)))) { + return signalPassFailure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\n--- After first fixed point ---\n"; + funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + + { + // For fusion by collapsing, do so if the reshape is blocking tile and + // fuse. + linalg::ControlFusionFn fuseByCollapsingControlFn = + [](OpOperand *fusedOperand) { + auto producer = fusedOperand->get().getDefiningOp(); + if (!producer) { + return false; + } + + auto reshapeOp = dyn_cast(producer); + if (!reshapeOp) return true; + + return reshapeOp.getSrc().getDefiningOp() != + nullptr; + }; + + RewritePatternSet collapsingReshapePatterns(&getContext()); + linalg::populateFoldReshapeOpsByCollapsingPatterns( + collapsingReshapePatterns, fuseByCollapsingControlFn); + tensor::CollapseShapeOp::getCanonicalizationPatterns( + collapsingReshapePatterns, context); + tensor::ExpandShapeOp::getCanonicalizationPatterns( + collapsingReshapePatterns, context); + memref::populateResolveRankedShapeTypeResultDimsPatterns( + collapsingReshapePatterns); + if (failed(applyPatternsAndFoldGreedily( + funcOp->getRegions(), std::move(collapsingReshapePatterns)))) { + return signalPassFailure(); + } - if (failed(applyPatternsAndFoldGreedily(op->getRegions(), - std::move(fusionPatterns)))) { - return signalPassFailure(); + LLVM_DEBUG({ + llvm::dbgs() << "\n--- After second fixed point ---\n"; + funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); } - LLVM_DEBUG({ - llvm::dbgs() << "\n--- After first fixed point ---\n"; - op->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n\n"; - }); - - // For fusion by collapsing, do so if the reshape is blocking tile and fuse. - linalg::ControlFusionFn fuseByCollapsingControlFn = - [](OpOperand *fusedOperand) { - auto producer = fusedOperand->get().getDefiningOp(); - if (!producer) { - return false; - } - - auto reshapeOp = dyn_cast(producer); - if (!reshapeOp) return true; - - return reshapeOp.getSrc().getDefiningOp() != - nullptr; - }; - - RewritePatternSet collapsingReshapePatterns(&getContext()); - linalg::populateFoldReshapeOpsByCollapsingPatterns( - collapsingReshapePatterns, fuseByCollapsingControlFn); - tensor::CollapseShapeOp::getCanonicalizationPatterns( - collapsingReshapePatterns, context); - tensor::ExpandShapeOp::getCanonicalizationPatterns( - collapsingReshapePatterns, context); - memref::populateResolveRankedShapeTypeResultDimsPatterns( - collapsingReshapePatterns); - if (failed(applyPatternsAndFoldGreedily( - op->getRegions(), std::move(collapsingReshapePatterns)))) { - return signalPassFailure(); + if (fuseMultiUse) { + // Run fusion of producer with consumer when producer has multiple uses. + // For now run this sequence a fixed times (2 by default). Ideally we + // would run it till no candidates exist. + for (auto i : llvm::seq(0, multiUseFusionIteration)) { + (void)i; + auto &dominanceInfo = getAnalysis(); + FailureOr numOfFusableCandidates = + fuseMultiUseProducers(funcOp, context, dominanceInfo); + if (failed(numOfFusableCandidates)) return signalPassFailure(); + if (numOfFusableCandidates.value() == 0) break; + } } } }; } // namespace -std::unique_ptr createFusionOfTensorOpsPass() { - return std::make_unique(); +std::unique_ptr> +createFusionOfTensorOpsPass(bool fuseMultiUse, + unsigned multiUseFusionIteration) { + return std::make_unique(fuseMultiUse, + multiUseFusionIteration); } } // namespace Flow diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp index 616b9642d83d..a88811056c76 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp @@ -38,16 +38,25 @@ static llvm::cl::opt clFuseReductionBroadcastElementwise( /// indexing map. // TODO: This restriction can go away if we can vectorize always, but that has // a long tail of tasks. -static bool canInsOperandTieWithOutsOperand(OpOperand *insOperand) { +bool isInsOperandBufferizable(OpOperand *insOperand, bool aggressiveFusion) { + // Ignore the check if in-place bufferization is not required. + if (!clEnsureInplaceableConsumer) return true; + auto linalgOp = dyn_cast(insOperand->getOwner()); if (!linalgOp) return false; AffineMap insOperandIndexingMap = linalgOp.getTiedIndexingMap(insOperand); auto canTieWithOutsOperand = [&](OpOperand *outsOperand) { - if (linalgOp.getTiedIndexingMap(outsOperand) != insOperandIndexingMap) { - return false; + AffineMap outsOperandIndexingMap = linalgOp.getTiedIndexingMap(outsOperand); + + if (outsOperandIndexingMap != insOperandIndexingMap) { + if (!aggressiveFusion) return false; + // If the operand is a projected permutation a small stack might be + // fine. + if (!insOperandIndexingMap.isProjectedPermutation()) return false; } + // TODO(#8411): Until ops are vectorized (always), we need // to check that the elementtype matches for the operands to be tied. // For now just doing this check for convolution ops since we expect @@ -219,7 +228,7 @@ bool areLinalgOpsFusableUsingTileAndFuse(OpOperand &use) { // 4. In-place bufferization requirements (for now) require that the use in // the consumer can re-use the buffer for a result. - return !clEnsureInplaceableConsumer || canInsOperandTieWithOutsOperand(&use); + return isInsOperandBufferizable(&use, /*aggressiveFusion=*/false); } } // namespace Flow diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h index 245ad754cb7a..2f1b5f40bf23 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h @@ -18,6 +18,10 @@ namespace iree_compiler { namespace IREE { namespace Flow { +/// Returns true if the `ins` operand can be properly bufferized after the +/// fusion. +bool isInsOperandBufferizable(OpOperand *insOperand, bool aggressiveFusion); + /// Returns true if the `use` is from a producer linalg op that can be fused /// with the consumer linalg op using tile + fuse. bool areLinalgOpsFusableUsingTileAndFuse(OpOperand &use); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeGenericOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeGenericOps.cpp index 3f0966d05a44..0d187ea2a572 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeGenericOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeGenericOps.cpp @@ -35,7 +35,7 @@ struct GenericOpInterchangePattern unsigned numParallelLoop = genericOp.getNumParallelLoops(); if (numParallelLoop == 0) return failure(); for (auto iter : llvm::enumerate(genericOp.iterator_types())) { - if (isParallelIterator(iter.value())) { + if (linalg::isParallelIterator(iter.value())) { interchange.push_back(iter.index()); if (iter.index() >= numParallelLoop) needInterchange = true; } @@ -43,7 +43,7 @@ struct GenericOpInterchangePattern // If all the parallel loops are outter loops skip the pattern. if (!needInterchange) return failure(); for (auto iter : llvm::enumerate(genericOp.iterator_types())) { - if (isReductionIterator(iter.value())) { + if (linalg::isReductionIterator(iter.value())) { interchange.push_back(iter.index()); } } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index 1492a729a19e..11ffd5e97361 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp @@ -66,9 +66,9 @@ static llvm::cl::opt clEnablePaddingLinalgOps( "flow-padding-size"), llvm::cl::init(false)); -static llvm::cl::opt clEnableFusePaddingIntoConsumerOps( - "iree-flow-enable-fuse-padding-into-consumer-ops", - llvm::cl::desc("Enable fusing linalg pad_tensor ops into consumer ops"), +static llvm::cl::opt clEnableFusePaddingIntoLinalgConsumerOps( + "iree-flow-enable-fuse-padding-into-linalg-consumer-ops", + llvm::cl::desc("Enable fusing tensor.pad ops into Linalg consumer ops"), llvm::cl::init(false)); static llvm::cl::opt clLinalgOpsPaddingSize( @@ -84,6 +84,13 @@ static llvm::cl::opt clEnableLinalgDetensorize( llvm::cl::desc("Enable detensorizing linalg ops to operate on primitives"), llvm::cl::init(true)); +static llvm::cl::opt clEnableAggressiveFusion( + "iree-flow-enable-aggressive-fusion", + llvm::cl::desc( + "Enable the aggressive fusion heuristic to fuse multiuse ops and ops " + "with reduction loops"), + llvm::cl::init(false)); + static llvm::cl::opt clMmt4dTargetOptions( "iree-flow-mmt4d-target-options", llvm::cl::desc("Convert linalg.matmul ops to MMT4D ops targetting the " @@ -110,6 +117,17 @@ static llvm::cl::opt clDispatchTransformFileName( "the transformations to apply to form dispatch regions."), llvm::cl::init("")); +static llvm::cl::opt clDispatchViaRegionOps( + "iree-flow-dispatch-via-region-ops", + llvm::cl::desc("Create dispatches via DispatchRegionOps"), + llvm::cl::init(false)); + +static llvm::cl::opt clDispatchViaRegionOpsGenerateWorkloadRegion( + "iree-flow-dispatch-via-region-ops-generate-workload-region", + llvm::cl::desc("Generate the workload region when running with " + "iree-flow-dispatch-via-region-ops"), + llvm::cl::init(true)); + namespace mlir { namespace iree_compiler { namespace IREE { @@ -187,8 +205,9 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager, .addPass(IREE::Flow::createConvertConv2D1x1ToMatmulPass) .addPredicatedPass(clEnableConvToImg2Col, IREE::Flow::createConvertConv2DToImg2ColPass) - .addPredicatedPass(clDispatchTransformFileName.empty(), - IREE::Flow::createDetachElementwiseFromNamedOpsPass) + .addPredicatedPass( + clDispatchTransformFileName.empty() && !clDispatchViaRegionOps, + IREE::Flow::createDetachElementwiseFromNamedOpsPass) // Input should now be legal. .addPass(IREE::Flow::createVerifyInputLegalityPass) // Catch matmul ops before we do anything else with them. @@ -212,11 +231,11 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager, passManager.addPass(IREE::Flow::createExpandTensorShapesPass()); buildGlobalOptimizationPassPipeline(passManager, transformOptions); - FunctionLikeNest(passManager) - // Pad tensors. - .addPredicatedPass((!clEnableFusePaddingIntoConsumerOps), - IREE::Flow::createPadTensorToTensorInsertSlicePass) + // Pad tensors. + passManager.addPass(IREE::Flow::createTensorPadToTensorInsertSlicePass( + /*skipSingleLinalgOpUses=*/clEnableFusePaddingIntoLinalgConsumerOps)); + FunctionLikeNest(passManager) // Preprocess the input to a form more amenable for fusion // - Convert all elementwise ops to Linalg // - Remove unit-extent dimensions. @@ -226,9 +245,10 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager, .addPass(memref::createResolveShapedTypeResultDimsPass) .addPass(mlir::createCanonicalizerPass) .addPass(mlir::createCSEPass) - // Elementwise fusion. - .addPass(createFusionOfTensorOpsPass) + .addPass([]() { + return createFusionOfTensorOpsPass(clEnableAggressiveFusion); + }) .addPredicatedPass(clEnableLinalgDetensorize, mlir::createLinalgDetensorizePass) .addPass(mlir::createCanonicalizerPass) @@ -252,8 +272,21 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager, clDispatchTransformFileName); }) // Only want use the transform dialect for some dispatch regions and let - // the DispatchLinalgOnTensorsPass unconditionally handle the rest. - .addPass(createDispatchLinalgOnTensorsPass) + // the DispatchLinalgOnTensorsPass handle the rest. + .addPredicatedPass( + !clDispatchViaRegionOps, + []() { + return createDispatchLinalgOnTensorsPass(clEnableAggressiveFusion); + }) + // DispatchLinalgOnTensorsViaRegionsPass is a variant of + // DispatchLinalgOnTensorsPass that lowers via DispatchRegionOps. This is + // on an opt-in basis until the pass is stable enough to replace + // DispatchLinalgOnTensorsPass. + .addPredicatedPass(clDispatchViaRegionOps, + [&]() { + return createDispatchLinalgOnTensorsViaRegionOpsPass( + clDispatchViaRegionOpsGenerateWorkloadRegion); + }) //////////////////////////////////////////////////////////////////////// .addPass(createCaptureDispatchDynamicDimsPass) .addPass(mlir::createCanonicalizerPass) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h index 6805810901a3..71a122688608 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h @@ -82,9 +82,10 @@ std::unique_ptr createConvertConv2DToImg2ColPass(); // Creates a pass to convert dispatch.region ops to dispatch.workgroups ops. std::unique_ptr createConvertRegionToWorkgroupsPass(); -// Pass to convert a linalg.pad_tensor operation into a linalg.fill + -// subtensor_insert. This allows lowering the operation into a single kernel. -std::unique_ptr createPadTensorToTensorInsertSlicePass(); +// Pass to convert a tensor.pad operation into a linalg.fill + +// tensor.insert_slice. +std::unique_ptr createTensorPadToTensorInsertSlicePass( + bool skipSingleLinalgOpUses = false); // Pass to convert a linalg.matmul into linalg.mmt4d given some target ISA // information currently passed as pass options. @@ -97,7 +98,9 @@ std::unique_ptr createConvertLinalgMatmulToMmt4DPass(StringRef options); std::unique_ptr createDetachElementwiseFromNamedOpsPass(); // Creates a pass to fuse Linalg operations on tensors. -std::unique_ptr createFusionOfTensorOpsPass(); +std::unique_ptr> +createFusionOfTensorOpsPass(bool fuseMultiUse = false, + unsigned multiUseFusionIteration = 2); // Infers and inserts util.numeric.optional_narrow ops at points that may be // beneficial. @@ -140,7 +143,14 @@ std::unique_ptr createVerifyInputLegalityPass(); // Pass to perform dispatch of Linalg on tensor ops by tiling and distribution. // A dispatch region is created for each tiled loop nest. std::unique_ptr> -createDispatchLinalgOnTensorsPass(); +createDispatchLinalgOnTensorsPass(bool aggressiveFusion = false); + +// Pass to perform dispatch of Linalg on tensor ops by tiling and distribution. +// A dispatch region is created for each tiled loop nest. (First create +// DispatchRegionOps, then DispatchWorkgroupsOps.) +std::unique_ptr> +createDispatchLinalgOnTensorsViaRegionOpsPass( + bool generateWorkloadRegion = true); // Pass to perform dispatch of Linalg on tensor ops by using the transform // dialect. Dispatch regions are created as specified by the transform module diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td index 593726cc0107..3c4f6f8b4d23 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td @@ -66,6 +66,16 @@ def DispatchLinalgOnTensors : InterfacePass<"iree-flow-dispatch-linalg-on-tensors-pass", "mlir::FunctionOpInterface"> { let summary = "Dispatch Linalg operations on tensors by using tile and distribute"; let constructor = "mlir::iree_compiler::IREE::Flow::createDispatchLinalgOnTensorsPass()"; + let options = [ + Option<"aggressiveFusion", "aggressive-fusion", "bool", + /*default=*/"false", "Fuse with aggressive heuristics">, + ]; +} + +def DispatchLinalgOnTensorsViaRegionOps : + InterfacePass<"iree-flow-dispatch-linalg-on-tensors-via-regionops-pass", "mlir::FunctionOpInterface"> { + let summary = "Dispatch Linalg operations on tensors by using tile and distribute (via DispatchRegionOps)"; + let constructor = "mlir::iree_compiler::IREE::Flow::createDispatchLinalgOnTensorsViaRegionOpsPass()"; } def DispatchWithTransformDialect : @@ -92,9 +102,15 @@ def ExportBenchmarkFuncs : } def FusionOfTensorOps : - Pass<"iree-flow-fusion-of-tensor-ops", ""> { + InterfacePass<"iree-flow-fusion-of-tensor-ops", "mlir::FunctionOpInterface"> { let summary = "Fuse operations on tensors"; let constructor = "mlir::iree_compiler::IREE::Flow::createFusionOfTensorOpsPass()"; + let options = [ + Option<"fuseMultiUse", "fuse-multi-use", "bool", + /*default=*/"false", "Fuse ops with multiuse">, + Option<"multiUseFusionIteration", "multi-use-fusion-iteration", "unsigned", + /*default=*/"2", "Number of iterations to fuse multiuse ops"> + ]; } def InferNumericNarrowing : @@ -161,10 +177,16 @@ def ConvertLinalgMatmulToMmt4D : ]; } -def PadTensorToTensorInsertSlice : - Pass<"iree-flow-pad-tensor-to-tensor-insert-slice", ""> { - let summary = "Convert linalg.pad_tensor into linalg.fill + tensor.insert_slice"; - let constructor = "mlir::iree_compiler::IREE::Flow::createPadTensorToTensorInsertSlicePass()"; +def TensorPadToTensorInsertSlice : + Pass<"iree-flow-tensor-pad-to-tensor-insert-slice", ""> { + let summary = "Convert tensor.pad into linalg.fill + tensor.insert_slice"; + let constructor = "mlir::iree_compiler::IREE::Flow::createTensorPadToTensorInsertSlicePass()"; + let options = [ + Option<"optionSkipSingleLinalgOpUses", "skip-one-linalg-use-case", "bool", + /*default=*/"false", + "Skip the op that has only one use which is used" + "by a Linalg op">, + ]; } def DumpDispatchGraph : Pass<"iree-flow-dump-dispatch-graph-pass"> { diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp index 6c06596c8567..81cf629a3662 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp @@ -7,14 +7,74 @@ #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dominance.h" -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace Flow { +using namespace mlir; +using namespace mlir::iree_compiler; +using namespace mlir::iree_compiler::IREE; + +#define DEBUG_TYPE "iree-flow-region-op-utils" + +static SmallVector getLoopRangesImpl(TilingInterface tilableOp, + Location loc, OpBuilder &builder) { + SmallVector loopRanges = tilableOp.getIterationDomain(builder); + Value one = builder.create(loc, 1); + for (auto iteratorType : llvm::enumerate(tilableOp.getLoopIteratorTypes())) { + if (iteratorType.value() == getReductionIteratorTypeName()) { + loopRanges[iteratorType.index()].size = one; + } + } + return loopRanges; +} + +static SmallVector getLoopRangesImpl(tensor::InsertSliceOp insertSliceOp, + Location loc, OpBuilder &builder) { + OpFoldResult zero = builder.getIndexAttr(0); + OpFoldResult one = builder.getIndexAttr(1); + Value source = insertSliceOp.getSource(); + SmallVector loopRanges(insertSliceOp.getSourceType().getRank(), + Range{zero, one, one}); + for (auto dim : llvm::seq(0, loopRanges.size())) { + loopRanges[dim].size = + builder.create(loc, source, dim).getResult(); + } + return loopRanges; +} + +static SmallVector getLoopRangesImpl(tensor::ExtractSliceOp sliceOp, + Location loc, OpBuilder &builder) { + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + ReifiedRankedShapedTypeDims resultDims; + LogicalResult status = sliceOp.reifyResultShapes(builder, resultDims); + (void)status; + assert(succeeded(status) && "reifyResultShapes failed"); + return llvm::to_vector(llvm::map_range(resultDims[0], [&](Value v) { + return Range{zero, v, one}; + })); +} + +/// For a given operation returns the loop ranges needed to compute the op. +SmallVector Flow::getLoopRanges(Operation *op, Location loc, + OpBuilder &builder) { + return llvm::TypeSwitch>(op) + .Case([&](TilingInterface op) { + return getLoopRangesImpl(op, loc, builder); + }) + .Case([&](tensor::InsertSliceOp op) { + return getLoopRangesImpl(op, loc, builder); + }) + .Case([&](tensor::ExtractSliceOp op) { + return getLoopRangesImpl(op, loc, builder); + }) + .Default([](Operation *op) -> SmallVector { + llvm_unreachable("op not supported"); + }); +} /// Return `true` if the given type is a ShapedType and has at least one /// dynamic dimension. @@ -25,8 +85,8 @@ static bool hasDynamicShape(Type t) { } /// Reify the dynamic dimensions of the given value. -static LogicalResult reifyDynamicResultDims(OpBuilder &b, Value value, - SmallVector &dynamicDims) { +LogicalResult Flow::reifyDynamicResultDims(OpBuilder &b, Value value, + SmallVector &dynamicDims) { OpBuilder::InsertionGuard guard(b); // Case 1: No dynamic result dims. @@ -85,7 +145,7 @@ static LogicalResult reifyDynamicResultDims(OpBuilder &b, Value value, // Append a result to the given DispatchRegionOp. The newly created // DispatchRegionOp is returned. -FailureOr appendDispatchRegionResult( +FailureOr Flow::appendDispatchRegionResult( RewriterBase &rewriter, Flow::DispatchRegionOp regionOp, Value result) { OpBuilder::InsertionGuard guard(rewriter); @@ -119,8 +179,8 @@ FailureOr appendDispatchRegionResult( return newRegionOp; } -Flow::DispatchRegionOp makeEmptyDispatchRegion(OpBuilder &builder, - Location loc) { +Flow::DispatchRegionOp Flow::makeEmptyDispatchRegion(OpBuilder &builder, + Location loc) { OpBuilder::InsertionGuard guard(builder); // Create RegionOp. @@ -135,7 +195,7 @@ Flow::DispatchRegionOp makeEmptyDispatchRegion(OpBuilder &builder, // Clone a `target` op that is preceding the given dispatch region op into the // dispatch region. -LogicalResult clonePrecedingOpIntoDispatchRegion( +LogicalResult Flow::clonePrecedingOpIntoDispatchRegion( RewriterBase &rewriter, Operation *target, Flow::DispatchRegionOp regionOp) { Block &body = regionOp.getBody().front(); @@ -165,7 +225,7 @@ LogicalResult clonePrecedingOpIntoDispatchRegion( // Move a `target` op that is preceding the given dispatch region op into the // dispatch region. -FailureOr movePrecedingOpIntoDispatchRegion( +FailureOr Flow::movePrecedingOpIntoDispatchRegion( RewriterBase &rewriter, Operation *target, Flow::DispatchRegionOp regionOp) { #ifndef NDEBUG @@ -214,8 +274,8 @@ FailureOr movePrecedingOpIntoDispatchRegion( return regionOp; } -FailureOr wrapOpInDispatchRegion(RewriterBase &rewriter, - Operation *op) { +FailureOr Flow::wrapOpInDispatchRegion( + RewriterBase &rewriter, Operation *op) { // Make an empty dispatch region right before the op. rewriter.setInsertionPointAfter(op); Flow::DispatchRegionOp regionOp = @@ -226,7 +286,84 @@ FailureOr wrapOpInDispatchRegion(RewriterBase &rewriter, return newRegionOp; } -} // namespace Flow -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir +/// Reorders the operations in `ops` such that they could be inlined into the +/// dispatch region in that order to satisfy dependencies. +SmallVector Flow::orderOperations(ArrayRef ops) { + LLVM_DEBUG({ + llvm::dbgs() << "Ops to be inlined :\n"; + for (auto op : ops) { + llvm::dbgs() << "\t"; + op->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + } + }); + + llvm::SmallMapVector, 16> + insertAfterMap; + llvm::SetVector opSet(ops.begin(), ops.end()); + llvm::SetVector leafOps(ops.begin(), ops.end()); + // For each operation compute the list of operations in `ops` that use its + // results. Also compute the operations that form the leafs of the DAG of + // operations in `ops`. + for (auto op : ops) { + for (auto operand : op->getOperands()) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp || !opSet.count(definingOp)) continue; + insertAfterMap[definingOp].push_back(op); + if (leafOps.count(op)) leafOps.remove(op); + } + } + + // The leaves are at the head of the ordered list. + SmallVector orderedOps(leafOps.begin(), leafOps.end()); + orderedOps.reserve(ops.size()); + llvm::SmallPtrSet processed; + processed.insert(leafOps.begin(), leafOps.end()); + + // `readyOps` contains the list of operations that have been just added to the + // `orderedOps` list. With these marked ready, they might make further + // operations in `ops` ready as well. + // The complexity of the algorithm is driven by these + // - Each operations is added to `readyOps` list at most once, and is removed + // after being processed + // - For every operation in `readyOps` every use of its results (within `ops`) + // is looked at once. + // - For every use, the operands of the user are processed. + // Assuming operands is O(1), i.e. constant order, the complexity is O(sum of + // number of uses of each operation). Given that the size of `ops` is at max + // O(10), and not O(100), this is assumed to be reasonable. + ArrayRef readyOps(orderedOps); + size_t startPos = 0; + while (!readyOps.empty()) { + auto op = readyOps.front(); + startPos++; + // Check all uses of `op` within `ops`. If all of the operations that define + // the operands of the user have been added to `orderedOps`, then the user + // is ready to be scheduled. + for (auto insertAfterOp : insertAfterMap[op]) { + if (processed.count(insertAfterOp)) continue; + if (llvm::all_of(insertAfterOp->getOperands(), [&](Value operand) { + Operation *operandDefiningOp = operand.getDefiningOp(); + return !operandDefiningOp || !opSet.count(operandDefiningOp) || + processed.count(operandDefiningOp); + })) { + // readyOps.push_back(insertAfterOp); + orderedOps.push_back(insertAfterOp); + processed.insert(insertAfterOp); + } + } + readyOps = ArrayRef(orderedOps).drop_front(startPos); + } + + LLVM_DEBUG({ + llvm::dbgs() << "Ops to be inlined (sorted) : \n"; + for (auto op : orderedOps) { + llvm::dbgs() << "\t"; + op->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + } + }); + assert(orderedOps.size() == ops.size() && + "ordering of inlined operations failed"); + return orderedOps; +} diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h index d75177525638..7d107b3ae5be 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h @@ -6,6 +6,8 @@ #ifndef IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_REGIONOPUTILS_H_ #define IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_REGIONOPUTILS_H_ +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Support/LogicalResult.h" namespace mlir { @@ -20,6 +22,14 @@ namespace IREE { namespace Flow { class DispatchRegionOp; +/// For a given operation returns the loop ranges needed to compute the op. +SmallVector getLoopRanges(Operation *op, Location loc, + OpBuilder &builder); + +/// Reify the dynamic dimensions of the given value. +LogicalResult reifyDynamicResultDims(OpBuilder &b, Value value, + SmallVector &dynamicDims); + /// Append a result to the given DispatchRegionOp. The newly created /// DispatchRegionOp is returned. FailureOr appendDispatchRegionResult( @@ -68,8 +78,22 @@ FailureOr movePrecedingOpIntoDispatchRegion( FailureOr wrapOpInDispatchRegion(RewriterBase &rewriter, Operation *op); +/// Sort the given ops topologically, so that they can be inlined into a +/// dispatch region without dominance violations. +/// +/// Example: +/// +/// %0 = "some_op"() +/// %1 = "another_op"(%1) +/// +/// In the above example, "some_op" is before "another_op" in the result. +// TODO: Improve mlir::sortTopologically. This function does currently not +// support ops from different blocks. +SmallVector orderOperations(ArrayRef ops); + } // namespace Flow } // namespace IREE } // namespace iree_compiler } // namespace mlir + #endif // IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_REGIONOPUTILS_H_ diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/PadTensorToTensorInsertSlice.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/TensorPadToTensorInsertSlice.cpp similarity index 66% rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/PadTensorToTensorInsertSlice.cpp rename to compiler/src/iree/compiler/Dialect/Flow/Transforms/TensorPadToTensorInsertSlice.cpp index df83519d1a6f..03f0cf0331cb 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/PadTensorToTensorInsertSlice.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/TensorPadToTensorInsertSlice.cpp @@ -4,11 +4,9 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -//===- PadTensorToInsertSlice.cpp ----- Pass to legalize linalg.pad_tensor-===// +//===- TensorPadToInsertSlice.cpp ----- Pass to legalize tensor.pad -------===// // -// Pass to convert linalg.pad_tensor to linalg.fill + tensor.insert_slice -// operations which is the only way Vulkan backend can lower it to a single -// kernel. +// Pass to convert tensor.pad to linalg.fill + tensor.insert_slice. // //===----------------------------------------------------------------------===// @@ -31,11 +29,14 @@ namespace IREE { namespace Flow { namespace { -/// Pattern to convert a linalg.pad_tensor operation into a fill + tensor -/// insert_slice. This is needed till pad_tensor op can be fused with its +/// Pattern to convert a tensor.tensor operation into a fill + +/// tensor.insert_slice. This is needed till tensor.pad op can be fused with its /// consumers. -struct PadTensorOpConversion : public OpRewritePattern { +struct TensorPadOpConversion : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; + TensorPadOpConversion(MLIRContext *context, bool skipSingleLinalgOpUses) + : OpRewritePattern(context, skipSingleLinalgOpUses), + skipSingleLinalgOpUses(skipSingleLinalgOpUses) {} LogicalResult matchAndRewrite(tensor::PadOp padTensorOp, PatternRewriter &rewriter) const override { @@ -51,6 +52,17 @@ struct PadTensorOpConversion : public OpRewritePattern { return failure(); } + if (skipSingleLinalgOpUses && padTensorOp->hasOneUse()) { + Operation *use = padTensorOp->use_begin()->getOwner(); + // TODO(#10312): Relax the condition to not check quantized ops. They + // are going to be deprecated. We don't expect them being IREE's input. + if (isa(use) && + !isa(use)) { + return failure(); + } + } + OpBuilder::InsertionGuard g(rewriter); Location loc = padTensorOp.getLoc(); auto lowPad = padTensorOp.getMixedLowPad(); @@ -101,32 +113,55 @@ struct PadTensorOpConversion : public OpRewritePattern { padTensorOp, source, fill, lowPad, sourceShape, strides); return success(); } + + private: + // Option to skip the pattern when tensor.pad op has one use and is used by + // a Linalg op. + bool skipSingleLinalgOpUses = false; }; -struct PadTensorToTensorInsertSlicePass - : public PadTensorToTensorInsertSliceBase< - PadTensorToTensorInsertSlicePass> { +struct TensorPadToTensorInsertSlicePass + : public TensorPadToTensorInsertSliceBase< + TensorPadToTensorInsertSlicePass> { + TensorPadToTensorInsertSlicePass(bool skipSingleLinalgOpUses) + : skipSingleLinalgOpUses(skipSingleLinalgOpUses) {} void getDependentDialects(DialectRegistry ®istry) const override { registry .insert(); } + LogicalResult initializeOptions(StringRef options) override { + if (failed(Pass::initializeOptions(options))) { + return failure(); + } + // `skipSingleLinalgOpUses` may have been set to `true` in the constructor + // already. The |= is so we preserve that rather than overwrite it with the + // default value `false` of `optionSkipSingleLinalgOpUses`. + skipSingleLinalgOpUses |= optionSkipSingleLinalgOpUses; + return success(); + } + void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - patterns.insert(context); + patterns.insert(context, skipSingleLinalgOpUses); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } + + private: + bool skipSingleLinalgOpUses; }; } // namespace -std::unique_ptr createPadTensorToTensorInsertSlicePass() { - return std::make_unique(); +std::unique_ptr createTensorPadToTensorInsertSlicePass( + bool skipSingleLinalgOpUses) { + return std::make_unique( + skipSingleLinalgOpUses); } } // namespace Flow diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD index cc939824e91e..abbe41635a29 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD @@ -24,12 +24,14 @@ iree_lit_test_suite( "deduplicate_executables.mlir", "detach_elementwise_from_named_ops.mlir", "dispatch_linalg_on_tensors.mlir", + "dispatch_linalg_on_tensors_aggressive_fusion.mlir", "dispatch_linalg_on_tensors_fusion.mlir", "dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir", "dispatch_linalg_on_tensors_fusion_with_transpose.mlir", "dispatch_linalg_transform_dialect.mlir", "expand_tensor_shapes.mlir", "export_benchmark_funcs.mlir", + "fusion_of_tensor_ops.mlir", "infer_numeric_narrowing.mlir", "initialize_empty_tensor.mlir", "inject_dispatch_tracing.mlir", @@ -39,7 +41,7 @@ iree_lit_test_suite( "optimize_numerics.mlir", "outline_dispatch_regions.mlir", "pad_linalg_ops.mlir", - "pad_tensor_to_tensor.mlir", + "tensor_pad_to_tensor_insert_slice.mlir", "region_to_workgroups.mlir", "strip_and_splat_constant_variables.mlir", "strip_signedness.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt index a5c9cc9bd0f3..e5cce098c2a7 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt @@ -22,12 +22,14 @@ iree_lit_test_suite( "deduplicate_executables.mlir" "detach_elementwise_from_named_ops.mlir" "dispatch_linalg_on_tensors.mlir" + "dispatch_linalg_on_tensors_aggressive_fusion.mlir" "dispatch_linalg_on_tensors_fusion.mlir" "dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir" "dispatch_linalg_on_tensors_fusion_with_transpose.mlir" "dispatch_linalg_transform_dialect.mlir" "expand_tensor_shapes.mlir" "export_benchmark_funcs.mlir" + "fusion_of_tensor_ops.mlir" "infer_numeric_narrowing.mlir" "initialize_empty_tensor.mlir" "inject_dispatch_tracing.mlir" @@ -37,10 +39,10 @@ iree_lit_test_suite( "optimize_numerics.mlir" "outline_dispatch_regions.mlir" "pad_linalg_ops.mlir" - "pad_tensor_to_tensor.mlir" "region_to_workgroups.mlir" "strip_and_splat_constant_variables.mlir" "strip_signedness.mlir" + "tensor_pad_to_tensor_insert_slice.mlir" "transform_dispatch_region_formation.mlir" "transformation_pipeline.mlir" "verify_input_ir.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/detach_elementwise_from_named_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/detach_elementwise_from_named_ops.mlir index aedca264a4a5..acd5fa865509 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/detach_elementwise_from_named_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/detach_elementwise_from_named_ops.mlir @@ -1,17 +1,26 @@ // RUN: iree-opt --split-input-file --iree-flow-detach-elementwise-from-named-ops --mlir-print-local-scope %s | FileCheck %s func.func @matmul(%a: tensor, %b: tensor<64x?xf32>, %c: tensor) -> tensor { - %0 = linalg.matmul ins(%a, %b : tensor, tensor<64x?xf32>) outs(%c : tensor) -> tensor - return %0 : tensor + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%c : tensor) outs(%c : tensor) { + ^bb0(%b0 : f32, %b1 : f32): + %1 = arith.addf %b0, %b0 : f32 + linalg.yield %1 : f32 + } -> tensor + %1 = linalg.matmul ins(%a, %b : tensor, tensor<64x?xf32>) outs(%0 : tensor) -> tensor + return %1 : tensor } // CHECK-LABEL: func @matmul -// CHECK-SAME: (%[[A:.+]]: tensor, %[[B:.+]]: tensor<64x?xf32>, %[[C:.+]]: tensor) +// CHECK-SAME: (%[[A:.+]]: tensor, %[[B:.+]]: tensor<64x?xf32>, %[[ARG2:.+]]: tensor) // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 - +// CHECK: %[[C:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG2]] : // CHECK: %[[DIM0:.+]] = tensor.dim %[[C]], %[[C0]] // CHECK: %[[DIM1:.+]] = tensor.dim %[[C]], %[[C1]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]]] @@ -32,16 +41,25 @@ func.func @matmul(%a: tensor, %b: tensor<64x?xf32>, %c: tensor, %b: tensor, %c: tensor) -> tensor { - %0 = linalg.batch_matmul ins(%a, %b : tensor, tensor) outs(%c : tensor) -> tensor - return %0 : tensor + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%c : tensor) outs(%c : tensor) { + ^bb0(%b0 : i32, %b1 : i32): + %1 = arith.addi %b0, %b0 : i32 + linalg.yield %1 : i32 + } -> tensor + %1 = linalg.batch_matmul ins(%a, %b : tensor, tensor) outs(%0 : tensor) -> tensor + return %1 : tensor } // CHECK-LABEL: func @batch_matmul -// CHECK-SAME: (%[[A:.+]]: tensor, %[[B:.+]]: tensor, %[[C:.+]]: tensor) +// CHECK-SAME: (%[[A:.+]]: tensor, %[[B:.+]]: tensor, %[[ARG2:.+]]: tensor) // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[I0:.+]] = arith.constant 0 : i32 - +// CHECK: %[[C:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG2]] : // CHECK: %[[DIM0:.+]] = tensor.dim %[[C]], %[[C0]] : tensor // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 8, 16] : tensor // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[I0]] : i32) outs(%[[INIT]] : tensor) -> tensor @@ -57,14 +75,24 @@ func.func @batch_matmul(%a: tensor, %b: tensor, %c: tenso // ----- -func.func @conv(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>, %init: tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> { - %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} - ins(%input, %filter : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>) outs(%init : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> - return %0 : tensor<1x112x112x32xf32> +func.func @conv(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>, %init: tensor<32xf32>) -> tensor<1x112x112x32xf32> { + %init0 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32> + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%init : tensor<32xf32>) outs(%init0 : tensor<1x112x112x32xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + linalg.yield %b0 : f32 + } -> tensor<1x112x112x32xf32> + %1 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins(%input, %filter : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>) outs(%0 : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> + return %1 : tensor<1x112x112x32xf32> } // CHECK-LABEL: func @conv -// CHECK-SAME: (%{{.+}}: tensor<1x225x225x3xf32>, %{{.+}}: tensor<3x3x3x32xf32>, %[[INIT:.+]]: tensor<1x112x112x32xf32>) +// CHECK-SAME: (%{{.+}}: tensor<1x225x225x3xf32>, %{{.+}}: tensor<3x3x3x32xf32>, %[[BIAS:.+]]: tensor<32xf32>) +// CHECK: %[[INIT:.+]] = linalg.generic +// CHECK-SAME: ins(%[[BIAS]] : // CHECK: %[[FILL:.+]] = linalg.fill // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf // CHECK: linalg.generic @@ -73,6 +101,33 @@ func.func @conv(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>, // ----- +func.func @keep_fill(%arg0 : tensor, %arg1 : tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.0 : f32 + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg1, %c1 : tensor + %init = linalg.init_tensor [%d0, %d1] : tensor + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor + %gemm = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%fill : tensor) -> tensor + return %gemm : tensor +} +// CHECK-LABEL: func.func @keep_fill +// CHECK-NOT: linalg.generic + +// ----- + +func.func @keep_arg(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: func.func @keep_arg +// CHECK-NOT: linalg.generic + +// ----- + func.func @fft_cst_output(%arg0 : tensor<3x2190x1x512xf32>) -> (tensor<3x2190x1x512xf32>, tensor<3x2190x1x512xf32>) { %c1 = arith.constant 1 : index diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir index 6b53bfcc2ac0..89cd23e96e14 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --verify-diagnostics --iree-flow-enable-multi-result-dispatches --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass), cse, canonicalize, cse" %s | FileCheck %s +// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass{aggressive-fusion=true}), cse, canonicalize, cse" %s | FileCheck %s func.func @tile_matmul_alone(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { @@ -1066,7 +1066,7 @@ func.func @extract_slice(%arg0 : tensor, %arg1 : index, %arg2 : index, // ----- // TODO(ravishankarm): Enable after upstream pad op tiling issues are addressed. -// func.func @pad_tensor(%arg0 : tensor, %arg1 : index, %arg2 : index, +// func.func @tensor.pad(%arg0 : tensor, %arg1 : index, %arg2 : index, // %arg3 : index, %arg4 : index, %arg5 : f32) -> tensor { // %0 = tensor.pad %arg0 low[%arg1, %arg2] high[%arg3, %arg4] { // ^bb0(%arg6 : index, %arg7 : index): diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_aggressive_fusion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_aggressive_fusion.mlir new file mode 100644 index 000000000000..5ffa94cc371d --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_aggressive_fusion.mlir @@ -0,0 +1,107 @@ +// RUN: iree-opt --split-input-file --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass{aggressive-fusion=true})" %s | FileCheck %s + +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> +module { + func.func @softmax(%arg0: tensor<12x128x128xf32>) -> tensor<12x128x128xf32> { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %cst_1 = arith.constant -3.40282347E+38 : f32 + %0 = linalg.init_tensor [12, 128] : tensor<12x128xf32> + %1 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<12x128xf32>) -> tensor<12x128xf32> + %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<12x128x128xf32>) outs(%1 : tensor<12x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %7 = arith.maxf %arg1, %arg2 : f32 + linalg.yield %7 : f32 + } -> tensor<12x128xf32> + %3 = linalg.init_tensor [12, 128, 128] : tensor<12x128x128xf32> + %4 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<12x128xf32>) -> tensor<12x128xf32> + %5:2 = linalg.generic {indexing_maps = [#map0, #map1, #map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %2 : tensor<12x128x128xf32>, tensor<12x128xf32>) outs(%3, %4 : tensor<12x128x128xf32>, tensor<12x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32): + %7 = arith.subf %arg1, %arg2 : f32 + %8 = math.exp %7 : f32 + %9 = arith.addf %8, %arg4 : f32 + linalg.yield %8, %9 : f32, f32 + } -> (tensor<12x128x128xf32>, tensor<12x128xf32>) + %6 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5#0, %5#1 : tensor<12x128x128xf32>, tensor<12x128xf32>) outs(%3 : tensor<12x128x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %7 = arith.divf %cst, %arg2 : f32 + %8 = arith.mulf %arg1, %7 : f32 + linalg.yield %8 : f32 + } -> tensor<12x128x128xf32> + return %6 : tensor<12x128x128xf32> + } +} +// CHECK-LABEL: func @softmax( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x128x128xf32> +// CHECK: %[[DISPATCH:.+]] = flow.dispatch.workgroups +// CHECK-SAME: (%[[ARG0]]) +// CHECK-NEXT: %[[ARG1:.+]]: !flow.dispatch.tensor +// CHECK: %[[LOAD0:.+]] = flow.dispatch.tensor.load %[[ARG1]] +// CHECK: %[[FILL0:.+]] = linalg.fill +// CHECK: %[[FILL1:.+]] = linalg.fill +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK-SAME: ins(%[[LOAD0]] : +// CHECK: %[[GENERIC1:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[LOAD0]], %[[GENERIC0]] : +// CHECK: %[[GENERIC2:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC1]]#0, %[[GENERIC1]]#1 : +// CHECK: flow.dispatch.tensor.store %[[GENERIC2]] +// CHECK: flow.return +// CHECK: return %[[DISPATCH]] + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4, d0)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0)> +#map2 = affine_map<(d0) -> (d0)> +module { + func.func @batchnorm_training(%arg0: tensor<12xf32>, %arg1: tensor<12x12x12x12x12xf32>, %arg2: tensor<12xf32>) -> (tensor<12xf32>, tensor<12xf32>, tensor<12xf32>) { + %cst = arith.constant 1.420000e+00 : f32 + %cst_0 = arith.constant 1.450000e+00 : f32 + %cst_1 = arith.constant 1.300000e+00 : f32 + %cst_2 = arith.constant 0.000000e+00 : f32 + %0 = linalg.init_tensor [12] : tensor<12xf32> + %1 = linalg.fill ins(%cst_2 : f32) outs(%0 : tensor<12xf32>) -> tensor<12xf32> + %2 = linalg.generic {indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%arg1, %arg2 : tensor<12x12x12x12x12xf32>, tensor<12xf32>) outs(%1 : tensor<12xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %4 = arith.subf %arg3, %arg4 : f32 + %5 = arith.mulf %4, %4 : f32 + %6 = arith.addf %arg5, %5 : f32 + linalg.yield %6 : f32 + } -> tensor<12xf32> + %3:3 = linalg.generic {indexing_maps = [#map2, #map2, #map2, #map2, #map2], iterator_types = ["parallel"]} ins(%arg0, %2 : tensor<12xf32>, tensor<12xf32>) outs(%0, %0, %0 : tensor<12xf32>, tensor<12xf32>, tensor<12xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32): + %4 = arith.divf %arg4, %cst_0 : f32 + %5 = arith.addf %4, %cst_1 : f32 + %6 = math.sqrt %5 : f32 + %7 = arith.subf %arg3, %6 : f32 + %8 = arith.mulf %7, %cst : f32 + %9 = arith.subf %arg3, %8 : f32 + linalg.yield %5, %6, %9 : f32, f32, f32 + } -> (tensor<12xf32>, tensor<12xf32>, tensor<12xf32>) + return %3#0, %3#1, %3#2 : tensor<12xf32>, tensor<12xf32>, tensor<12xf32> + } +} +// CHECK-LABEL: func @batchnorm_training( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<12x12x12x12x12xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<12xf32> +// CHECK: %[[DISPATCH:.+]]:3 = flow.dispatch.workgroups +// CHECK-SAME: (%[[ARG1]], %[[ARG2]], %[[ARG0]]) +// CHECK-NEXT: %[[ARG3:.+]]: !flow.dispatch.tensor +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor +// CHECK-DAG: %[[LOAD0:.+]] = flow.dispatch.tensor.load %[[ARG3]] +// CHECK-DAG: %[[LOAD1:.+]] = flow.dispatch.tensor.load %[[ARG4]] +// CHECK-DAG: %[[LOAD2:.+]] = flow.dispatch.tensor.load %[[ARG5]] +// CHECK: %[[FILL:.+]] = linalg.fill +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK-SAME: ins(%[[LOAD0]], %[[LOAD1]] : +// CHECK: %[[GENERIC1:.+]]:3 = linalg.generic +// CHECK-SAME: ins(%[[LOAD2]], %[[GENERIC0]] : +// CHECK-DAG: flow.dispatch.tensor.store %[[GENERIC1]]#0 +// CHECK-DAG: flow.dispatch.tensor.store %[[GENERIC1]]#1 +// CHECK-DAG: flow.dispatch.tensor.store %[[GENERIC1]]#2 +// CHECK: flow.return +// CHECK: return %[[DISPATCH]]#0, %[[DISPATCH]]#1, %[[DISPATCH]]#2 diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion.mlir index ef08d5b0fd3f..e17cd08469b5 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --verify-diagnostics --iree-flow-enable-multi-result-dispatches --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass)" --canonicalize -cse %s | FileCheck %s +// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="func.func(iree-flow-dispatch-linalg-on-tensors-pass{aggressive-fusion=true})" --canonicalize -cse %s | FileCheck %s func.func @fuse_conv2d_elementwise(%input: tensor<1x225x225x16xf32>, %filter: tensor<3x3x16x32xf32>, %offset: tensor<32xf32>) -> tensor<1x112x112x32xf32> { %cst = arith.constant 0.000000e+00 : f32 diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir index a552edd51e9b..2095fe499dd3 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_fusion_reduction_broadcast_elementwise.mlir @@ -106,7 +106,7 @@ func.func @reduction_broadcast_elementwise_binary2(%a1: tensor<128x384xf32>, %a2 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -func.func @reduction_broadcast_elementwise_type_mismatch(%a: tensor<12x16x16xf32>, %b: tensor<12x16x32xf32>) -> tensor<12x16x32xf32> { +func.func @reduction_broadcast_elementwise_type_mismatch(%a: tensor<12x16x16xf32>, %b: tensor<12x16x32xf32>) -> tensor<12x16x32xi32> { %cst_47 = arith.constant 0.000000e+00 : f32 %37 = linalg.init_tensor [12, 16] : tensor<12x16xf32> %38 = linalg.fill ins(%cst_47 : f32) outs(%37 : tensor<12x16xf32>) -> tensor<12x16xf32> @@ -115,13 +115,14 @@ func.func @reduction_broadcast_elementwise_type_mismatch(%a: tensor<12x16x16xf32 %780 = arith.maxf %arg3, %arg4 : f32 linalg.yield %780 : f32 } -> tensor<12x16xf32> - %40 = linalg.init_tensor [12, 16, 32] : tensor<12x16x32xf32> - %42 = linalg.generic {indexing_maps = [#map2, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%b, %39 : tensor<12x16x32xf32>, tensor<12x16xf32>) outs(%40 : tensor<12x16x32xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %40 = linalg.init_tensor [12, 16, 32] : tensor<12x16x32xi32> + %42 = linalg.generic {indexing_maps = [#map2, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%b, %39 : tensor<12x16x32xf32>, tensor<12x16xf32>) outs(%40 : tensor<12x16x32xi32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: i32): %780 = arith.subf %arg3, %arg4 : f32 - linalg.yield %780 : f32 - } -> tensor<12x16x32xf32> - return %42 : tensor<12x16x32xf32> + %781 = arith.fptosi %780 : f32 to i32 + linalg.yield %781 : i32 + } -> tensor<12x16x32xi32> + return %42 : tensor<12x16x32xi32> } // Check that two generic ops are NOT dispatched together since the input type @@ -163,4 +164,3 @@ func.func @reduction_broadcast_elementwise_dynamic(%a: tensor<12x16x?xf32>, %b: // CHECK-LABEL: func.func @reduction_broadcast_elementwise_dynamic // CHECK: flow.dispatch.workgroups // CHECK: flow.dispatch.workgroups - diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir new file mode 100644 index 000000000000..bd280d5a873d --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir @@ -0,0 +1,162 @@ +// RUN: iree-opt --split-input-file --pass-pipeline="func.func(iree-flow-fusion-of-tensor-ops{fuse-multi-use=true})" %s | FileCheck %s + +func.func @softmax(%arg0 : tensor<12x128x128xf32>) -> tensor<12x128x128xf32> { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %cst_1 = arith.constant -3.40282347E+38 : f32 + %1 = linalg.init_tensor [12, 128] : tensor<12x128xf32> + %2 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<12x128xf32>) -> tensor<12x128xf32> + %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<12x128x128xf32>) outs(%2 : tensor<12x128xf32>) { + ^bb0(%b0: f32, %b1: f32): + %11 = arith.maxf %b0, %b1 : f32 + linalg.yield %11 : f32 + } -> tensor<12x128xf32> + %4 = linalg.init_tensor [12, 128, 128] : tensor<12x128x128xf32> + %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %3 : tensor<12x128x128xf32>, tensor<12x128xf32>) outs(%4 : tensor<12x128x128xf32>) { + ^bb0(%b0: f32, %b1: f32, %arg2: f32): + %11 = arith.subf %b0, %b1 : f32 + linalg.yield %11 : f32 + } -> tensor<12x128x128xf32> + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5 : tensor<12x128x128xf32>) outs(%4 : tensor<12x128x128xf32>) { + ^bb0(%b0: f32, %b1: f32): + %11 = math.exp %b0 : f32 + linalg.yield %11 : f32 + } -> tensor<12x128x128xf32> + %7 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<12x128xf32>) -> tensor<12x128xf32> + %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%6 : tensor<12x128x128xf32>) outs(%7 : tensor<12x128xf32>) { + ^bb0(%b0: f32, %b1: f32): + %11 = arith.addf %b0, %b1 : f32 + linalg.yield %11 : f32 + } -> tensor<12x128xf32> + %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%8 : tensor<12x128xf32>) outs(%1 : tensor<12x128xf32>) { + ^bb0(%b0: f32, %b1: f32): + %11 = arith.divf %cst, %b0 : f32 + linalg.yield %11 : f32 + } -> tensor<12x128xf32> + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6, %9 : tensor<12x128x128xf32>, tensor<12x128xf32>) outs(%4 : tensor<12x128x128xf32>) { + ^bb0(%b0: f32, %b1: f32, %arg2: f32): + %11 = arith.mulf %b0, %b1 : f32 + linalg.yield %11 : f32 + } -> tensor<12x128x128xf32> + return %10 : tensor<12x128x128xf32> +} +// CHECK-LABEL: func.func @softmax +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x128x128xf32> +// CHECK: %[[INIT0:.+]] = linalg.init_tensor [12, 128] +// CHECK: %[[FILL0:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0]] : +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]] : +// CHECK-SAME: outs(%[[FILL0]] : +// CHECK: %[[INIT1:.+]] = linalg.init_tensor [12, 128, 128] +// CHECK: %[[FILL1:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0]] : +// CHECK: %[[GENERIC1:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[ARG0]], %[[GENERIC0]] : +// CHECK-SAME: outs(%[[INIT1]], %[[FILL1]] : +// CHECK: %[[GENERIC2:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC1]]#0, %[[GENERIC1]]#1 : +// CHECK-SAME: outs(%[[INIT1]] : +// CHECK: return %[[GENERIC2]] + +// ----- + +func.func @batchnorm_training(%10 : tensor<12xf32>, %11 : tensor<12x12x12x12x12xf32>, %12 : tensor<12xf32>) -> (tensor<12xf32>, tensor<12xf32>, tensor<12xf32>) +{ + %cst = arith.constant 1.42 : f32 + %cst_1 = arith.constant 1.45 : f32 + %cst_0 = arith.constant 1.3 : f32 + %cst_2 = arith.constant 0.0 : f32 + %13 = linalg.init_tensor [12] : tensor<12xf32> + %14 = linalg.fill ins(%cst_2 : f32) outs(%13 : tensor<12xf32>) -> tensor<12xf32> + %15 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4, d0)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0)>], + iterator_types = ["parallel", "reduction", "reduction", "reduction", "reduction"]} + ins(%11, %12 : tensor<12x12x12x12x12xf32>, tensor<12xf32>) outs(%14 : tensor<12xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %19 = arith.subf %arg1, %arg2 : f32 + %20 = arith.mulf %19, %19 : f32 + %21 = arith.addf %arg3, %20 : f32 + linalg.yield %21 : f32 + } -> tensor<12xf32> + %16 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} + ins(%15: tensor<12xf32>) outs(%13 : tensor<12xf32>) { + ^bb0(%arg1: f32, %arg2 : f32): + %19 = arith.divf %arg1, %cst_1 : f32 + %20 = arith.addf %19, %cst_0 : f32 + linalg.yield %20 : f32 + } -> tensor<12xf32> + %17 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%16 : tensor<12xf32>) outs(%13 : tensor<12xf32>) { + ^bb0(%arg1: f32, %arg2 : f32): + %19 = math.sqrt %arg1 : f32 + linalg.yield %19 : f32 + } -> tensor<12xf32> + %18 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} + {__internal_linalg_transform__ = "tensor_fuse_err"} + ins(%10, %17 : tensor<12xf32>, tensor<12xf32>) outs(%13 : tensor<12xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3 : f32): + %19 = arith.subf %arg1, %arg2 : f32 + %20 = arith.mulf %19, %cst : f32 + %21 = arith.subf %arg1, %20 : f32 + linalg.yield %21 : f32 + } -> tensor<12xf32> + return %16, %17, %18 : tensor<12xf32>, tensor<12xf32>, tensor<12xf32> +} +// CHECK-LABEL: func @batchnorm_training( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<12x12x12x12x12xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<12xf32> +// CHECK: %[[INIT:.+]] = linalg.init_tensor [12] : tensor<12xf32> +// CHECK: %[[FILL:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT]] : +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]] : +// CHECK-SAME: outs(%[[FILL]] : +// CHECK: %[[GENERIC1:.+]]:3 = linalg.generic +// CHECK-SAME: ins(%[[ARG0]], %[[GENERIC0]] : +// CHECK-SAME: outs(%[[INIT]], %[[INIT]], %[[INIT]] : +// CHECK: return %[[GENERIC1]]#0, %[[GENERIC1]]#1, %[[GENERIC1]]#2 + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_only_with_same_marker(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<5x5xf32>, tensor<5x5xf32>, tensor<5x5xf32>) { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 2.000000e+00 : f32 + %cst_1 = arith.constant 3.000000e+00 : f32 + %0 = linalg.init_tensor [5, 5] : tensor<5x5xf32> + %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32> + %2 = linalg.init_tensor [5, 5] : tensor<5x5xf32> + %3 = linalg.init_tensor [5, 5] : tensor<5x5xf32> + %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%0 : tensor<5x5xf32>) { + ^bb0(%arg2: f32, %arg3: f32): + %8 = arith.addf %arg2, %cst : f32 + linalg.yield %8 : f32 + } -> tensor<5x5xf32> + %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<5x5xf32>) outs(%1 : tensor<5x5xf32>) { + ^bb0(%arg2: f32, %arg3: f32): + %8 = arith.subf %arg2, %cst_0 : f32 + linalg.yield %8 : f32 + } -> tensor<5x5xf32> + %6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%2 : tensor<5x5xf32>) { + ^bb0(%arg2: f32, %arg3: f32): + %8 = arith.addf %arg2, %cst_1 : f32 + linalg.yield %8 : f32 + } -> tensor<5x5xf32> + %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%4, %5 : tensor<5x5xf32>, tensor<5x5xf32>) outs(%3 : tensor<5x5xf32>) { + ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): + %8 = arith.subf %arg2, %arg3 : f32 + linalg.yield %8 : f32 + } -> tensor<5x5xf32> + return %4, %5, %6, %7 : tensor<5x5xf32>, tensor<5x5xf32>, tensor<5x5xf32>, tensor<5x5xf32> + } +} +// CHECK-LABEL: func.func @fuse_only_with_same_marke +// CHECK: linalg.generic +// CHECK-NOT: linalg.generic diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pad_tensor_to_tensor.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/tensor_pad_to_tensor_insert_slice.mlir similarity index 64% rename from compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pad_tensor_to_tensor.mlir rename to compiler/src/iree/compiler/Dialect/Flow/Transforms/test/tensor_pad_to_tensor_insert_slice.mlir index b47cc20ddf7a..0c8c1a1a5106 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pad_tensor_to_tensor.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/tensor_pad_to_tensor_insert_slice.mlir @@ -1,7 +1,8 @@ -// RUN: iree-opt --split-input-file --iree-flow-pad-tensor-to-tensor-insert-slice --canonicalize %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-flow-tensor-pad-to-tensor-insert-slice --canonicalize %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-flow-tensor-pad-to-tensor-insert-slice=skip-one-linalg-use-case --canonicalize %s | FileCheck %s --check-prefix=SKIP module { - func.func @pad_tensor(%arg0 : tensor, %arg1 : tensor, %arg2 : index, %arg3 : index) -> tensor { + func.func @tensor_pad(%arg0 : tensor, %arg1 : tensor, %arg2 : index, %arg3 : index) -> tensor { %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %c3 = arith.constant 3 : index @@ -15,7 +16,7 @@ module { } // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 4)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 3)> -// CHECK: func.func @pad_tensor +// CHECK: func.func @tensor_pad // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index @@ -37,7 +38,7 @@ module { // ----- module { - func.func @pad_tensor_static(%arg0: tensor<12x4xf32>, %arg1: tensor) -> tensor<18x12xf32> { + func.func @tensor_pad_static(%arg0: tensor<12x4xf32>, %arg1: tensor) -> tensor<18x12xf32> { %c4 = arith.constant 4 : index %c2 = arith.constant 2 : index %c5 = arith.constant 5 : index @@ -50,7 +51,7 @@ module { return %1 : tensor<18x12xf32> } } -// CHECK-LABEL: func.func @pad_tensor_static +// CHECK-LABEL: func.func @tensor_pad_static // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<12x4xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-DAG: %[[VAL:.+]] = tensor.extract %[[ARG1]] @@ -60,3 +61,20 @@ module { // CHECK-SAME: outs(%[[INIT]] : // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][4, 5] [12, 4] [1, 1] // CHECK: return %[[RESULT]] + +// ----- + +func.func @_main(%arg0: tensor<1x33x33x480xf32>, %arg1: tensor<3x3x480x1xf32>) -> tensor<1x33x33x480xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.pad %arg0 low[0, 4, 4, 0] high[0, 4, 4, 0] { + ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index): + tensor.yield %cst : f32 + } : tensor<1x33x33x480xf32> to tensor<1x41x41x480xf32> + %1 = linalg.init_tensor [1, 33, 33, 480] : tensor<1x33x33x480xf32> + %2 = tensor.collapse_shape %arg1 [[0], [1], [2, 3]] : tensor<3x3x480x1xf32> into tensor<3x3x480xf32> + %3 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x33x33x480xf32>) -> tensor<1x33x33x480xf32> + %4 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<4> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%0, %2 : tensor<1x41x41x480xf32>, tensor<3x3x480xf32>) outs(%3 : tensor<1x33x33x480xf32>) -> tensor<1x33x33x480xf32> + return %4 : tensor<1x33x33x480xf32> +} +// CHECK-NOT: tensor.pad +// SKIP: tensor.pad diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD index 59d873799405..a638133904cf 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD @@ -37,6 +37,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM", "//compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM", "//compiler/src/iree/compiler/Dialect/VM/IR", + "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:FuncDialect", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt index 29261e7045d0..17c8d8ff0434 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt @@ -40,6 +40,7 @@ iree_cc_library( iree::compiler::Dialect::VM::Conversion::StandardToVM iree::compiler::Dialect::VM::Conversion::UtilToVM iree::compiler::Dialect::VM::IR + iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp index bd29e66f18f4..6de450705252 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp @@ -12,6 +12,7 @@ #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" #include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" #include "iree/compiler/Dialect/VM/IR/VMOps.h" +#include "iree/compiler/Utils/StringUtils.h" #include "llvm/ADT/DenseMap.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" @@ -75,9 +76,8 @@ IREE::VM::RodataOp createExecutableBinaryRodata( auto insertPoint = builder.saveInsertionPoint(); builder.setInsertionPoint(builder.getInsertionBlock()->getParentOp()); - std::string rodataName = - (executableOp.getName() + "_" + binaryOp.getName()).str(); - std::replace(rodataName.begin(), rodataName.end(), '-', '_'); + std::string rodataName = sanitizeSymbolName( + (executableOp.getName() + "_" + binaryOp.getName()).str()); auto rodataOp = builder.create( binaryOp.getLoc(), rodataName, binaryOp.getData()); rodataOp.setPrivate(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/buffer_ops.mlir index 9afc06b82b6b..55c7f4cdec0d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/buffer_ops.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --iree-convert-hal-to-vm --cse %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-convert-hal-to-vm --cse --iree-vm-target-index-bits=32 %s | FileCheck %s // CHECK-LABEL: @buffer_subspan // CHECK-SAME: (%[[BUFFER:.+]]: !vm.ref) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/buffer_view_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/buffer_view_ops.mlir index 1b7e322435b0..512302d78cba 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/buffer_view_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/buffer_view_ops.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --iree-convert-hal-to-vm %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-convert-hal-to-vm --iree-vm-target-index-bits=32 %s | FileCheck %s // CHECK-LABEL: vm.func private @buffer_view_dims // CHECK-SAME: %[[VIEW:.+]]: !vm.ref diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir index 894a4fbcd852..934c0cb601d6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --iree-convert-hal-to-vm --canonicalize %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-convert-hal-to-vm --canonicalize --iree-vm-target-index-bits=32 %s | FileCheck %s // CHECK-LABEL: @command_buffer_create func.func @command_buffer_create(%arg0: !hal.device) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir index 643e7bb40a3e..6ffca9b6ec43 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --iree-convert-hal-to-vm --canonicalize %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-convert-hal-to-vm --canonicalize --iree-vm-target-index-bits=32 %s | FileCheck %s // CHECK-LABEL: @device_allocator // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir index 507e76a90ac3..d6a20e7c00c6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir @@ -117,7 +117,7 @@ func.func @cmdExecute(%arg0: !stream.resource, %arg1: index, %arg2: ! hal.executable private @ex { hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ { hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes { - translation_info = #iree_codegen.translation_info + translation_info = #iree_codegen.translation_info } { ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors %c1 = arith.constant 1 : index diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD index 63c63cacccdb..0d6802922bf7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD @@ -72,6 +72,7 @@ iree_compiler_cc_library( ":HALOpsGen", ":HALTypesGen", "//compiler/src/iree/compiler/Dialect/Util/IR", + "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:FuncDialect", diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt index 2ca1befe2f61..cc94221af240 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt @@ -50,6 +50,7 @@ iree_cc_library( MLIRTransformUtils MLIRViewLikeInterface iree::compiler::Dialect::Util::IR + iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp index 6d35c8a3602f..f19a658f4ff9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp @@ -9,6 +9,7 @@ #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Utils/StringUtils.h" #include "llvm/ADT/StringExtras.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -198,9 +199,7 @@ void DeviceTargetAttr::print(AsmPrinter &p) const { } std::string DeviceTargetAttr::getSymbolNameFragment() { - auto deviceName = getDeviceID().getValue().lower(); - std::replace(deviceName.begin(), deviceName.end(), '-', '_'); - return deviceName; + return sanitizeSymbolName(getDeviceID().getValue().lower()); } Attribute DeviceTargetAttr::getMatchExpression() { @@ -307,9 +306,7 @@ void ExecutableTargetAttr::print(AsmPrinter &p) const { } std::string ExecutableTargetAttr::getSymbolNameFragment() { - auto format = getFormat().getValue().lower(); - std::replace(format.begin(), format.end(), '-', '_'); - return format; + return sanitizeSymbolName(getFormat().getValue().lower()); } Attribute ExecutableTargetAttr::getMatchExpression() { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp index 1b67e382ae19..e613599add85 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp @@ -12,6 +12,7 @@ #include "iree/compiler/Dialect/HAL/Target/CUDA/cuda_libdevice.h" #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" #include "iree/compiler/Utils/FlatbufferUtils.h" +#include "iree/compiler/Utils/StringUtils.h" #include "iree/schemas/cuda_executable_def_builder.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/IR/Constants.h" @@ -144,14 +145,6 @@ static void linkAndOptimize(llvm::Module &module, MPM.run(module); } -/// Sanitize the function name as CUDA driver doesn't allow function names with -/// '.' character. -static std::string sanitizeNameForCuda(llvm::StringRef name) { - std::string sanitizedName(name); - std::replace(sanitizedName.begin(), sanitizedName.end(), '.', '_'); - return sanitizedName; -} - class CUDATargetBackend final : public TargetBackend { public: CUDATargetBackend() = default; @@ -231,7 +224,7 @@ class CUDATargetBackend final : public TargetBackend { auto *llvmFunc = llvmModule->getFunction(func.getName()); if (llvmFunc->isDeclaration()) continue; // setName will make sure the function name is unique. - llvmFunc->setName(sanitizeNameForCuda(func.getName())); + llvmFunc->setName(sanitizeSymbolName(func.getName())); entryPointNames.emplace_back(llvmFunc->getName()); std::array workgroupSize; uint32_t workgroupLocalMemory = 0; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/BUILD index 50d8e3b5f2f5..552f37b6481b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/BUILD +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/BUILD @@ -42,6 +42,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Utils", "//compiler/src/iree/compiler/Dialect/HAL/Target", "//compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/Builtins", + "//compiler/src/iree/compiler/Utils", "//llvm-external-projects/iree-dialects:IREELinalgExtDialect", "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:AArch64AsmParser", @@ -113,6 +114,7 @@ iree_compiler_cc_library( srcs = ["LinkerTool.cpp"], hdrs = ["LinkerTool.h"], deps = platform_trampoline_deps("LinkerTools", "compiler/src/iree/compiler/Dialect/HAL/Target/LLVM") + [ + "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", ], ) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt index 01dcf8f39145..df969100c8e5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt @@ -51,6 +51,7 @@ iree_cc_library( iree::compiler::Codegen::Utils iree::compiler::Dialect::HAL::Target iree::compiler::Dialect::HAL::Target::LLVM::Builtins + iree::compiler::Utils PUBLIC ) @@ -99,6 +100,7 @@ iree_cc_library( DEPS LLVMSupport iree::compiler::Dialect::HAL::Target::LLVM::internal::LinkerTools_internal + iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMCPUTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMCPUTarget.cpp index 04706fc7e8a1..233337558e63 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMCPUTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMCPUTarget.cpp @@ -19,13 +19,13 @@ #include "iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.h" #include "iree/compiler/Dialect/HAL/Target/LLVM/StaticLibraryGenerator.h" #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" +#include "iree/compiler/Utils/ModuleUtils.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Linker/Linker.h" -#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/TargetSelect.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -45,29 +45,6 @@ namespace HAL { static constexpr char kQueryFunctionName[] = "iree_hal_executable_library_query"; -static llvm::Optional findFirstFileLoc(Location baseLoc) { - if (auto loc = baseLoc.dyn_cast()) { - for (auto &childLoc : loc.getLocations()) { - auto childResult = findFirstFileLoc(childLoc); - if (childResult) return childResult; - } - } else if (auto loc = baseLoc.dyn_cast()) { - return loc; - } - return llvm::None; -} - -static std::string guessModuleName(mlir::ModuleOp moduleOp) { - std::string moduleName = moduleOp.getName().value_or("").str(); - if (!moduleName.empty()) return moduleName; - auto loc = findFirstFileLoc(moduleOp.getLoc()); - if (loc.has_value()) { - return llvm::sys::path::stem(loc.value().getFilename()).str(); - } else { - return "llvm_module"; - } -} - // Appends the |debugDatabase| to the end of |baseFile| and writes the footer // so the runtime can find it. static LogicalResult appendDebugDatabase(std::vector &baseFile, @@ -180,39 +157,8 @@ class LLVMCPUTargetBackend final : public TargetBackend { buildLLVMCPUCodegenPassPipeline(passManager); } - LogicalResult linkExecutables(mlir::ModuleOp moduleOp) override { - OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody()); - - auto sourceExecutableOps = - llvm::to_vector<8>(moduleOp.getOps()); - if (sourceExecutableOps.size() <= 1) return success(); - - // TODO(benvanik): rework linking to support multiple formats. - auto sharedTargetAttr = getExecutableTarget(builder.getContext()); - - // Guess a module name, if needed, to make the output files readable. - auto moduleName = guessModuleName(moduleOp); - - // Create our new "linked" hal.executable. - std::string linkedExecutableName = - llvm::formatv("{0}_linked_{1}", moduleName, "llvm_cpu"); - auto linkedExecutableOp = builder.create( - moduleOp.getLoc(), linkedExecutableName); - linkedExecutableOp.setVisibility( - sourceExecutableOps.front().getVisibility()); - - // Add our hal.executable.variant with an empty module. - builder.setInsertionPointToStart(&linkedExecutableOp.getBlock()); - auto linkedTargetOp = builder.create( - moduleOp.getLoc(), sharedTargetAttr.getSymbolNameFragment(), - sharedTargetAttr); - builder.setInsertionPoint(&linkedTargetOp.getBlock().back()); - builder.create(moduleOp.getLoc()); - - // Try linking together all executables in moduleOp. - return linkExecutablesInto( - moduleOp, sourceExecutableOps, linkedExecutableOp, linkedTargetOp, - [](mlir::ModuleOp moduleOp) { return moduleOp; }, builder); + void buildLinkingPassPipeline(OpPassManager &passManager) override { + buildLLVMCPULinkingPassPipeline(passManager); } LogicalResult serializeExecutable(const SerializationOptions &options, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.cpp index 0afc173eeb39..aba0a6fec1c0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.h" +#include "iree/compiler/Utils/StringUtils.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/Process.h" @@ -16,30 +17,14 @@ namespace iree_compiler { namespace IREE { namespace HAL { -// Sanitizes potentially user provided portions of a file name by replacing -// all but a small set of alpha numeric and safe punctuation characters with -// '_'. This is intended for components of temporary files that are uniqued -// independently, where the input is meant to aid debugability but does not -// need to be retained verbatim. -static void sanitizeFilePart(llvm::SmallVectorImpl &part) { - for (char &c : part) { - if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || - (c >= '0' && c <= '9') || c == '_' || c == '-' || c == '.') - continue; - c = '_'; - } -} - // static Artifact Artifact::createTemporary(StringRef prefix, StringRef suffix) { - llvm::SmallString<8> prefixCopy(prefix); - llvm::SmallString<8> suffixCopy(suffix); - sanitizeFilePart(prefixCopy); - sanitizeFilePart(suffixCopy); + auto sanitizedPrefix = sanitizeFileName(prefix); + auto sanitizedSuffix = sanitizeFileName(suffix); llvm::SmallString<32> filePath; if (std::error_code error = llvm::sys::fs::createTemporaryFile( - prefixCopy, suffixCopy, filePath)) { + sanitizedPrefix, sanitizedSuffix, filePath)) { llvm::errs() << "failed to generate temporary file: " << error.message(); return {}; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp index 5ccb495d546d..025bb822a4e8 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp @@ -10,7 +10,6 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Path.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/IR/Dialect.h" @@ -64,246 +63,6 @@ void TargetOptions::bindOptions(OptionsBinder &binder) { llvm::cl::cat(halTargetOptionsCategory)); } -// Renames |op| within |moduleOp| with a new name that is unique within both -// |moduleOp| and |optionalSymbolTable| (if one is provided). -static void renameWithDisambiguatedName( - Operation *op, Operation *moduleOp, - DenseMap &targetSymbolMap, - SymbolTable *optionalSymbolTable) { - StringRef originalName = SymbolTable::getSymbolName(op).getValue(); - - // Iteratively try suffixes until we find one that isn't used. - std::string disambiguatedName; - int uniqueingCounter = 0; - do { - disambiguatedName = - llvm::formatv("{0}_{1}", originalName, uniqueingCounter++).str(); - } while ( - targetSymbolMap.lookup(disambiguatedName) || - (optionalSymbolTable && optionalSymbolTable->lookup(disambiguatedName))); - - SymbolTableCollection symbolTable; - SymbolUserMap symbolUsers(symbolTable, moduleOp); - mlir::StringAttr nameAttr = - mlir::StringAttr::get(op->getContext(), disambiguatedName); - symbolUsers.replaceAllUsesWith(op, nameAttr); - SymbolTable::setSymbolName(op, disambiguatedName); -} - -// TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h version. -// Only difference is one has the symbol map that we don't even need. - -// Destructively merges |sourceModuleOp| into |targetModuleOp|. -// |targetSymbolMap| is updated with the new symbols. -// -// If a private symbol in |sourceModuleOp| conflicts with another symbol -// (public or private) tracked in |targetSymbolMap|, it will be renamed. -// -// Fails if a public symbol in |sourceModuleOp| conflicts with another public -// symbol tracked in |targetSymbolMap|. -static LogicalResult mergeModuleInto( - Operation *sourceModuleOp, Operation *targetModuleOp, - DenseMap &targetSymbolMap) { - auto &sourceBlock = sourceModuleOp->getRegion(0).front(); - auto &targetBlock = targetModuleOp->getRegion(0).front(); - SymbolTable sourceSymbolTable(sourceModuleOp); - auto allOps = llvm::to_vector<8>( - llvm::map_range(sourceBlock, [&](Operation &op) { return &op; })); - - for (auto &op : allOps) { - if (op->hasTrait()) continue; - if (auto symbolOp = dyn_cast(op)) { - auto symbolName = symbolOp.getName(); - - // Resolve symbol name conflicts. - if (auto targetOp = targetSymbolMap[symbolName]) { - if (symbolOp.getVisibility() == SymbolTable::Visibility::Private) { - // Private symbols can be safely folded into duplicates or renamed. - if (OperationEquivalence::isEquivalentTo( - targetOp, op, OperationEquivalence::exactValueMatch, - OperationEquivalence::exactValueMatch, - OperationEquivalence::Flags::IgnoreLocations)) { - // Optimization: skip over duplicate private symbols. - // We could let CSE do this later, but we may as well check here. - continue; - } else { - // Preserve the op but give it a unique name. - renameWithDisambiguatedName(op, sourceModuleOp, targetSymbolMap, - &sourceSymbolTable); - } - } else { - // The source symbol has 'nested' or 'public' visibility. - if (SymbolTable::getSymbolVisibility(targetOp) != - SymbolTable::Visibility::Private) { - // Oops! Both symbols are public and we can't safely rename either. - // If you hit this with ops that you think are safe to rename, mark - // them private. - // - // Note: we could also skip linking between executables with - // conflicting symbol names. We think such conflicts will be better - // fixed in other ways, so we'll emit an error until we find a case - // where that isn't true. - return op->emitError() - << "multiple public symbols with the name: " << symbolName; - } else { - // Keep the original name for our new op, rename the target op. - renameWithDisambiguatedName(targetOp, targetModuleOp, - targetSymbolMap, - /*optionalSymbolTable=*/nullptr); - } - } - } - targetSymbolMap[SymbolTable::getSymbolName(op).getValue()] = op; - } - if (!targetBlock.empty() && - targetBlock.back().hasTrait()) { - op->moveBefore(&targetBlock.back()); - } else { - op->moveBefore(&targetBlock, targetBlock.end()); - } - } - - // Now that we're done cloning its ops, delete the original target op. - sourceModuleOp->erase(); - - return success(); -} - -struct SymbolReplacements { - DenseMap executableRefs; - DenseMap variantRefs; - DenseMap exportRefs; -}; - -// Replaces each usage of an entry point with its original symbol name with a -// new symbol name. -// -// Due to replaceSubElements recursing into symbol refs we need to perform -// replacement in descending symbol ref length; otherwise replacing the -// executable name in `@old_executable::@old_export` would result in -// `@new_executable::@old_export` and an export update would then not match the -// new/old mismatched ref. This means we have to do three walks over the entire -// module in order to do the replacements; not great. -static void replaceEntryPointUses( - mlir::ModuleOp moduleOp, const SymbolReplacements &symbolReplacements) { - auto replaceSymbolRefs = [](Operation *rootOp, - const DenseMap &map) { - auto allUses = SymbolTable::getSymbolUses(rootOp); - if (!allUses) return; - for (auto use : *allUses) { - auto oldAttr = use.getSymbolRef(); - auto newAttr = map.lookup(oldAttr); - if (!newAttr) continue; - auto newDict = use.getUser()->getAttrDictionary().replaceSubElements( - [&](Attribute attr) -> std::pair { - if (attr == oldAttr) { - // Found old->new replacement. - return {newAttr, WalkResult::skip()}; - } else if (attr.isa()) { - // Don't recurse into symbol refs - we only want to match roots. - return {attr, WalkResult::skip()}; - } - // Non-symbol ref attr. - return {attr, WalkResult::advance()}; - }); - use.getUser()->setAttrs(newDict.cast()); - } - }; - replaceSymbolRefs(moduleOp, symbolReplacements.exportRefs); - replaceSymbolRefs(moduleOp, symbolReplacements.variantRefs); - replaceSymbolRefs(moduleOp, symbolReplacements.executableRefs); - for (auto funcLikeOp : moduleOp.getOps()) { - replaceSymbolRefs(funcLikeOp, symbolReplacements.exportRefs); - replaceSymbolRefs(funcLikeOp, symbolReplacements.variantRefs); - replaceSymbolRefs(funcLikeOp, symbolReplacements.executableRefs); - } -} - -LogicalResult TargetBackend::linkExecutablesInto( - mlir::ModuleOp moduleOp, - ArrayRef sourceExecutableOps, - IREE::HAL::ExecutableOp linkedExecutableOp, - IREE::HAL::ExecutableVariantOp linkedTargetOp, - std::function getInnerModuleFn, - OpBuilder &builder) { - int nextEntryPointOrdinal = 0; - DenseMap targetSymbolMap; - SymbolReplacements symbolReplacements; - - auto linkedTargetBuilder = - OpBuilder::atBlockBegin(&linkedTargetOp.getBlock()); - auto linkedModuleOp = getInnerModuleFn(linkedTargetOp.getInnerModule()); - - // Iterate over all source executable ops, linking as many as we can. - for (auto sourceExecutableOp : sourceExecutableOps) { - // Remap root executable refs. - symbolReplacements.executableRefs[SymbolRefAttr::get(sourceExecutableOp)] = - SymbolRefAttr::get(linkedExecutableOp); - - auto variantOps = llvm::to_vector<4>( - sourceExecutableOp.getOps()); - for (auto variantOp : variantOps) { - // Only process targets matching our pattern. - if (variantOp.getTarget().getBackend().getValue() != name()) continue; - - // Remap variant refs. - auto oldVariantRefAttr = - SymbolRefAttr::get(builder.getContext(), sourceExecutableOp.getName(), - {SymbolRefAttr::get(variantOp)}); - auto newVariantRefAttr = - SymbolRefAttr::get(builder.getContext(), linkedExecutableOp.getName(), - {SymbolRefAttr::get(linkedTargetOp)}); - symbolReplacements.variantRefs[oldVariantRefAttr] = newVariantRefAttr; - - // Clone export ops and queue remapping ordinals and updating - // symbol refs. - for (auto exportOp : variantOp.getOps()) { - auto newExportOp = - linkedTargetBuilder.create( - exportOp.getLoc(), exportOp.getSymNameAttr(), - builder.getIndexAttr(nextEntryPointOrdinal++), - exportOp.getLayout(), ArrayAttr{}, IntegerAttr{}); - newExportOp->setDialectAttrs(exportOp->getDialectAttrs()); - - // Add to replacement table for fixing up dispatch calls referencing - // this export. - auto oldExportRefAttr = SymbolRefAttr::get( - builder.getContext(), sourceExecutableOp.getName(), - {SymbolRefAttr::get(variantOp), SymbolRefAttr::get(exportOp)}); - auto newExportRefAttr = SymbolRefAttr::get( - builder.getContext(), linkedExecutableOp.getName(), - {SymbolRefAttr::get(linkedTargetOp), - SymbolRefAttr::get(newExportOp)}); - symbolReplacements.exportRefs[oldExportRefAttr] = newExportRefAttr; - } - - // Merge the existing module into the new linked module op. - auto sourceModuleOp = getInnerModuleFn(variantOp.getInnerModule()); - if (failed(mergeModuleInto(sourceModuleOp, linkedModuleOp, - targetSymbolMap))) { - return failure(); - } - - variantOp.erase(); - } - - if (sourceExecutableOp.getOps().empty()) { - sourceExecutableOp.erase(); - } - } - - // Update references to @executable::@target::@entry symbols. - replaceEntryPointUses(moduleOp, symbolReplacements); - - // Remove if we didn't add anything. - if (linkedTargetOp.getOps().empty()) { - linkedTargetOp.erase(); - linkedExecutableOp.erase(); - } - - return success(); -} - void dumpDataToPath(StringRef path, StringRef baseName, StringRef suffix, StringRef extension, StringRef data) { auto fileName = (llvm::join_items("_", baseName, suffix) + extension).str(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h index 866d2446fe76..297c0973b592 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h @@ -168,47 +168,55 @@ class TargetBackend { // } virtual void buildTranslationPassPipeline(OpPassManager &passManager) = 0; - // Links compatible executables within the provided |moduleOp| together into - // zero or more new linked executables. Implementations should move - // executable contents (including interfaces, entry points, and functions) - // into new executables and update any relevant references as they do so. + // Inserts passes used to link `hal.executable.variant` ops together. + // The pass manager will be nested on the parent module of `hal.executable` + // ops and the pipeline will need to find relevant variant ops itself. // - // Which executables to link together and how many new executables to produce - // are left to implementations to determine. For example, an implementation - // may choose to link all executables (even with different interfaces) into - // a single combined executable, or it could choose to limit the number linked - // together in order to shard binary size across multiple executables. + // Implementations should clone executable contents (including interfaces, + // entry points, and functions) into new executables and update any relevant + // references as they do so. // - // The input |moduleOp| may contain executables containing multiple targets, - // so implementations should check target backend filters against their own - // `name()` prior to modifying them. + // Which executable variants to link together and how many new executables to + // produce are left to implementations to determine. For example, an + // implementation may choose to link all executables (even with different + // interfaces) into a single combined executable, or it could choose to limit + // the number linked together in order to shard binary size across multiple + // executables. // - // Sample output structure: - // hal.executable @linked_executable { - // hal.interface @io_0 { ... } - // hal.interface @io_1 { ... } - // hal.executable.variant @target, target="target-backend" { - // hal.executable.export @main_dispatch_0 attributes { ... } - // hal.executable.export @main_dispatch_1 attributes { ... } - // hal.executable.export @main_dispatch_2 attributes { ... } - // module { - // func.func @main_0(...) { ... } - // func.func @main_1(...) { ... } - // func.func @main_2(...) { ... } - // } + // For example, as input: + // hal.executable @some_executable_0 { + // hal.interface... + // hal.executable.variant @target_a, target="target-backend" { + // module { ... } // } // } - // // Other targets within executables are not modified - // hal.executable @main_dispatch_0 { - // hal.interface @io { ... } - // hal.executable.variant @other, target="other" { - // hal.executable.export @main_dispatch_0 attributes { ... } + // hal.executable @some_executable_1 { + // hal.interface... + // hal.executable.variant @target_b, target="target-backend" { + // module { ... } + // } + // hal.executable.variant @target_c, target="other-backend" { // module { ... } // } // } - virtual LogicalResult linkExecutables(mlir::ModuleOp moduleOp) { - return success(); - } + // + // As output: + // hal.executable @some_executable_1 { // untouched, not relevant + // hal.interface... + // hal.executable.variant @target_c, target="other-backend" { + // module { ... } + // } + // } + // hal.executable @some_executable_linked { + // hal.interface... + // hal.executable.variant @target_a, target="target-backend" { + // module { ... } + // } + // hal.executable.variant @target_b, target="target-backend" { + // module { ... } + // } + // } + virtual void buildLinkingPassPipeline(OpPassManager &passManager) {} struct SerializationOptions { // Debug level for serialization (0-3). @@ -236,17 +244,6 @@ class TargetBackend { assert(false && "unimplemented serializeExecutable"); return failure(); } - - protected: - // Links all executables for the current target found in |moduleOp| into - // |linkedExecutableOp|. Functions will be cloned into |linkedModuleOp|. - LogicalResult linkExecutablesInto( - mlir::ModuleOp moduleOp, - ArrayRef sourceExecutableOps, - IREE::HAL::ExecutableOp linkedExecutableOp, - IREE::HAL::ExecutableVariantOp linkedTargetOp, - std::function getInnerModuleFn, - OpBuilder &builder); }; // Dumps binary data to a file formed by joining the given path components: diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.cpp index 3b3404378b62..f028ca60bce0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.h" #include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h" +#include "iree/compiler/Codegen/Passes.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" #include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h" @@ -17,7 +18,6 @@ #include "iree/compiler/Dialect/VMVX/Transforms/Passes.h" #include "iree/compiler/Utils/FlatbufferUtils.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OperationSupport.h" @@ -69,42 +69,8 @@ class VMVXTargetBackend final : public TargetBackend { IREE::VM::buildVMTransformPassPipeline(nestedModulePM, vmOptions); } - LogicalResult linkExecutables(mlir::ModuleOp moduleOp) override { - OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody()); - - auto sourceExecutableOps = - llvm::to_vector<8>(moduleOp.getOps()); - if (sourceExecutableOps.size() <= 1) return success(); - - // TODO(benvanik): rework linking to support multiple formats. - auto sharedTargetAttr = getExecutableTarget(builder.getContext()); - - // Create our new "linked" hal.executable. - std::string linkedExecutableName = llvm::formatv("{0}_linked", name()); - auto linkedExecutableOp = builder.create( - moduleOp.getLoc(), linkedExecutableName); - linkedExecutableOp.setVisibility( - sourceExecutableOps.front().getVisibility()); - - // Add our VMVX hal.executable.variant with an empty module. - builder.setInsertionPointToStart(&linkedExecutableOp.getBlock()); - auto linkedTargetOp = builder.create( - moduleOp.getLoc(), sharedTargetAttr.getSymbolNameFragment(), - sharedTargetAttr); - builder.setInsertionPoint(&linkedTargetOp.getBlock().back()); - auto linkedModuleOp = builder.create(moduleOp.getLoc()); - - // Add an empty vm.module to that module (as our vm.funcs must live in it). - builder.setInsertionPointToStart(linkedModuleOp.getBody()); - builder.create(moduleOp.getLoc(), "linked_module"); - - // Try linking together all executables in moduleOp. - return linkExecutablesInto( - moduleOp, sourceExecutableOps, linkedExecutableOp, linkedTargetOp, - [](mlir::ModuleOp moduleOp) { - return *moduleOp.getOps().begin(); - }, - builder); + void buildLinkingPassPipeline(OpPassManager &passManager) override { + buildVMVXLinkingPassPipeline(passManager); } LogicalResult serializeExecutable(const SerializationOptions &options, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/BUILD index 7ccff4577c91..04494b966911 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/BUILD +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/BUILD @@ -16,7 +16,6 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ - "linking.mlir", "smoketest.mlir", ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/CMakeLists.txt index 9105b0e50206..62958ba0b730 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/CMakeLists.txt @@ -14,7 +14,6 @@ iree_lit_test_suite( NAME lit SRCS - "linking.mlir" "smoketest.mlir" TOOLS FileCheck diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/smoketest.mlir b/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/smoketest.mlir index b57becb678dc..5e420395a602 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/smoketest.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/test/smoketest.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --pass-pipeline='iree-hal-transformation-pipeline{serialize-executables=false},canonicalize' --mlir-print-local-scope %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline='iree-hal-transformation-pipeline{serialize-executables=false},canonicalize' --iree-vm-target-index-bits=64 --mlir-print-local-scope %s | FileCheck %s module attributes { hal.device.targets = [ @@ -50,27 +50,27 @@ stream.executable public @add_dispatch_0 { // CHECK-NEXT: vm.func private @add_dispatch_0( // CHECK-SAME: %[[SCRATCHPAD:.+]]: !vm.buffer, %[[CONSTANTS:.+]]: !vm.buffer, // CHECK-SAME: %[[BINDINGS:.+]]: !vm.list -// CHECK-DAG: %zero = vm.const.i32.zero -// CHECK-DAG: %c1 = vm.const.i32 1 -// CHECK-DAG: %c2 = vm.const.i32 2 -// CHECK-DAG: %c4 = vm.const.i32 4 -// CHECK-NEXT: %[[LHS_BUF:.+]] = vm.list.get.ref %[[BINDINGS]], %zero : (!vm.list, i32) -> !vm.buffer -// CHECK-NEXT: %[[RHS_BUF:.+]] = vm.list.get.ref %[[BINDINGS]], %c1 : (!vm.list, i32) -> !vm.buffer -// CHECK-NEXT: %[[RET_BUF:.+]] = vm.list.get.ref %[[BINDINGS]], %c2 : (!vm.list, i32) -> !vm.buffer -// CHECK: vm.br ^bb1(%zero : i32) -// CHECK-NEXT: ^bb1(%[[IDX:.+]]: i32): // 2 preds: ^bb0, ^bb2 -// CHECK-NEXT: %slt = vm.cmp.lt.i32.s %[[IDX]], %{{.+}} : i32 +// CHECK-DAG: %[[C0_I32:.+]] = vm.const.i32.zero +// CHECK-DAG: %[[C0_I64:.+]] = vm.const.i64.zero +// CHECK-DAG: %[[C1_I32:.+]] = vm.const.i32 1 +// CHECK-DAG: %[[C1_I64:.+]] = vm.const.i64 1 +// CHECK-DAG: %[[C2_I32:.+]] = vm.const.i32 2 +// CHECK-NEXT: %[[LHS_BUF:.+]] = vm.list.get.ref %[[BINDINGS]], %[[C0_I32]] : (!vm.list, i32) -> !vm.buffer +// CHECK-NEXT: %[[RHS_BUF:.+]] = vm.list.get.ref %[[BINDINGS]], %[[C1_I32]] : (!vm.list, i32) -> !vm.buffer +// CHECK-NEXT: %[[RET_BUF:.+]] = vm.list.get.ref %[[BINDINGS]], %[[C2_I32]] : (!vm.list, i32) -> !vm.buffer +// CHECK: vm.br ^bb1(%[[C0_I64]] : i64) +// CHECK-NEXT: ^bb1(%[[IDX:.+]]: i64): +// CHECK-NEXT: %slt = vm.cmp.lt.i64.s %[[IDX]], %{{.+}} : i64 // CHECK-NEXT: vm.cond_br %slt, ^bb2, ^bb3 -// CHECK-NEXT: ^bb2: // pred: ^bb1 -// CHECK: %[[BYTE_OFFSET_32:.+]] = vm.mul.i32 %{{.+}}, %c4 : i32 -// CHECK-NEXT: %[[BYTE_OFFSET:.+]] = vm.ext.i32.i64.u %[[BYTE_OFFSET_32]] -// CHECK-NEXT: %[[LHS:.+]] = vm.buffer.load.f32 %[[LHS_BUF]][%[[BYTE_OFFSET]]] : !vm.buffer -> f32 -// CHECK-NEXT: %[[RHS:.+]] = vm.buffer.load.f32 %[[RHS_BUF]][%[[BYTE_OFFSET]]] : !vm.buffer -> f32 +// CHECK-NEXT: ^bb2: +// CHECK-NEXT: %[[ELEMENT_OFFSET:.+]] = vm.add.i64 %[[IDX]], %{{.+}} +// CHECK-NEXT: %[[LHS:.+]] = vm.buffer.load.f32 %[[LHS_BUF]][%[[ELEMENT_OFFSET]]] : !vm.buffer -> f32 +// CHECK-NEXT: %[[RHS:.+]] = vm.buffer.load.f32 %[[RHS_BUF]][%[[ELEMENT_OFFSET]]] : !vm.buffer -> f32 // CHECK-NEXT: %[[RET:.+]] = vm.add.f32 %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: vm.buffer.store.f32 %[[RET]], %[[RET_BUF]][%[[BYTE_OFFSET]]] : f32 -> !vm.buffer -// CHECK-NEXT: %[[NEXT_IDX:.+]] = vm.add.i32 %[[IDX]], %c1 : i32 -// CHECK-NEXT: vm.br ^bb1(%[[NEXT_IDX]] : i32) -// CHECK-NEXT: ^bb3: // pred: ^bb1 +// CHECK-NEXT: vm.buffer.store.f32 %[[RET]], %[[RET_BUF]][%[[ELEMENT_OFFSET]]] : f32 -> !vm.buffer +// CHECK-NEXT: %[[NEXT_IDX:.+]] = vm.add.i64 %[[IDX]], %[[C1_I64]] : i64 +// CHECK-NEXT: vm.br ^bb1(%[[NEXT_IDX]] : i64) +// CHECK-NEXT: ^bb3: // CHECK-NEXT: vm.return // CHECK-NEXT: } // CHECK-NEXT: vm.export @add_dispatch_0 diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp index fffc2d529fb4..3c5bbb9568c3 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp @@ -37,32 +37,6 @@ namespace iree_compiler { namespace IREE { namespace HAL { -namespace { -llvm::Optional findFirstFileLoc(Location baseLoc) { - if (auto loc = baseLoc.dyn_cast()) { - for (auto &childLoc : loc.getLocations()) { - auto childResult = findFirstFileLoc(childLoc); - if (childResult) return childResult; - } - } else if (auto loc = baseLoc.dyn_cast()) { - return loc; - } - return llvm::None; -} - -std::string guessModuleName(ModuleOp moduleOp) { - std::string moduleName = - moduleOp.getName().has_value() ? moduleOp.getName().value().str() : ""; - if (!moduleName.empty()) return moduleName; - auto loc = findFirstFileLoc(moduleOp.getLoc()); - if (loc.has_value()) { - return llvm::sys::path::stem(loc.value().getFilename()).str(); - } else { - return "spirv_module"; - } -} -} // namespace - VulkanSPIRVTargetOptions getVulkanSPIRVTargetOptionsFromFlags() { // TODO(antiagainst): Enable option categories once the following bug is // fixed: https://bugs.llvm.org/show_bug.cgi?id=44223 static @@ -142,116 +116,6 @@ class VulkanSPIRVTargetBackend : public TargetBackend { buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false); } - // TODO(antiagainst): Re-enable SPIR-V linking once the tensorflow integration - // crash is fixed. -#if 0 - LogicalResult linkExecutables(mlir::ModuleOp moduleOp) override { - // Note: Vulkan flavored SPIR-V does not have linking in the conventional - // sense. For example, there is no cross-module symbol reference and symbol - // resolution and such. It's more just combining all SPIR-V modules into the - // one, with multiple entry points. - - // 1. Create source executable groups according to their executable - // interface. We only combine executables in the same group. - - // Map from an executable interface's hash to all source executables having - // that interface. - llvm::DenseMap> - sourceExecutableOpGroups; - - int numExecutables = 0; - for (auto op : moduleOp.getOps()) { - llvm::hash_code hash = interfaceOps.front().getInterfaceHash(); - sourceExecutableOpGroups[hash].push_back(op); - - ++numExecutables; - } - if (numExecutables <= 1) return success(); - - SymbolTable symbolTable(moduleOp); - - auto sharedTargetsAttr = getExecutableTargets(moduleOp.getContext()); - if (llvm::size(sharedTargetsAttr) != 1) { - return moduleOp.emitError("only one executable target is supported now"); - } - - auto sharedTargetAttr = sharedTargetsAttr.getValue() - .front() - .cast(); - - // Guess a module name, if needed, to make the output files readable. - auto moduleName = guessModuleName(moduleOp); - - // 2. Create "linked" executables for each source executable group. - // This just pulls in spv.module ops that should be combined into the same - // hal.executable.variant inner module. - - SmallVector innerModuleOps; - innerModuleOps.reserve(sourceExecutableOpGroups.size()); - for (const auto &hashExecutablePair : sourceExecutableOpGroups) { - llvm::hash_code hash = hashExecutablePair.first; - const auto &sourceExecutableOps = hashExecutablePair.second; - - // Just one executable for this group. No need to link. - if (sourceExecutableOps.size() == 1) continue; - - OpBuilder builder(moduleOp.getContext()); - - // Create a new "linked" hal.executable for collecting all source - // executables in this group. - std::string linkedExecutableName = - llvm::formatv("{0}_linked_{1}", moduleName, name()); - auto linkedExecutableOp = builder.create( - moduleOp.getLoc(), linkedExecutableName); - symbolTable.insert(linkedExecutableOp, moduleOp.getBody()->begin()); - - // Add our hal.executable.variant with an empty module. - builder.setInsertionPointToStart(linkedExecutableOp.getBody()); - auto linkedTargetOp = builder.create( - moduleOp.getLoc(), sharedTargetAttr.getSymbolNameFragment(), - sharedTargetAttr); - builder.setInsertionPoint(&linkedTargetOp.getBlock().back()); - innerModuleOps.push_back( - builder.create(moduleOp.getLoc())); - - // Try linking together all executables in moduleOp. - if (failed(linkExecutablesInto( - moduleOp, sourceExecutableOps, linkedExecutableOp, linkedTargetOp, - [](mlir::ModuleOp moduleOp) { return moduleOp; }, builder))) - return failure(); - } - - // 3. Now we can have multiple spv.module ops in the same - // hal.executable.variant inner module. Combining them into one. - - auto symbolRenameListener = [](spirv::ModuleOp symbolTable, - StringRef oldSymbol, StringRef newSymbol) { - // We don't care about global variable renaming. There should not exist - // duplicated functions. But double check that. - if (Operation *op = SymbolTable::lookupSymbolIn(symbolTable, oldSymbol)) { - assert(!isa(op) && - "found duplicated spv.func names when linking!"); - } - }; - - for (mlir::ModuleOp innerModule : innerModuleOps) { - auto spvModules = - llvm::to_vector<4>(innerModule.getBody()->getOps()); - if (spvModules.size() <= 1) continue; - - OpBuilder builder(innerModule); - auto newModule = builder.create(innerModule.getLoc()); - - // Create the combined spv.module op and erase the old inner module. - builder.setInsertionPointToStart(newModule.getBody()); - spirv::combine(spvModules, builder, symbolRenameListener).release(); - innerModule.erase(); - } - - return success(); - } -#endif - LogicalResult serializeExecutable(const SerializationOptions &options, IREE::HAL::ExecutableVariantOp variantOp, OpBuilder &executableBuilder) override { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp index 214e61523d20..583ec2406cf9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp @@ -12,6 +12,8 @@ #include "iree/compiler/Dialect/HAL/Target/TargetBackend.h" #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -74,8 +76,16 @@ class AssignTargetDevicesPass for (const auto &targetName : targets) { auto targetBackend = getTargetBackend(targetName); if (!targetBackend) { + std::string backends; + llvm::raw_string_ostream os(backends); + llvm::interleaveComma( + getTargetBackends(getRegisteredTargetBackends()), os, + [&os](const std::shared_ptr< + mlir::iree_compiler::IREE::HAL::TargetBackend> + b) { os << b->name(); }); moduleOp.emitError() - << "target backend '" << targetName << "' not registered"; + << "target backend '" << targetName + << "' not registered; registered backends: " << os.str(); signalPassFailure(); return; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/LinkExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/LinkExecutables.cpp index 81701a7d81b8..dbf4ee126d09 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/LinkExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/LinkExecutables.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" namespace mlir { namespace iree_compiler { @@ -31,14 +32,6 @@ class LinkTargetExecutablesPass LinkTargetExecutablesPass(const LinkTargetExecutablesPass &pass) {} LinkTargetExecutablesPass(StringRef target) { this->target = target.str(); } - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - auto targetBackend = getTargetBackend(target); - if (targetBackend) { - targetBackend->getDependentDialects(registry); - } - } - StringRef getArgument() const override { return "iree-hal-link-target-executables"; } @@ -47,6 +40,14 @@ class LinkTargetExecutablesPass return "Links together hal.executables for the specified target."; } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + auto targetBackend = getTargetBackend(target); + if (targetBackend) { + targetBackend->getDependentDialects(registry); + } + } + void runOnOperation() override { auto moduleOp = getOperation(); auto targetBackend = getTargetBackend(target); @@ -55,23 +56,14 @@ class LinkTargetExecutablesPass return signalPassFailure(); } - // Ask the target backend to link all executables it wants. - if (failed(targetBackend->linkExecutables(moduleOp))) { - moduleOp.emitError() << "failed to link executables for target backend " - << target; + OpPassManager passManager(moduleOp.getOperationName()); + targetBackend->buildLinkingPassPipeline(passManager); + if (failed(runPipeline(passManager, moduleOp))) { + moduleOp.emitError() + << "failed to run linking of executable variants for backend " + << target; return signalPassFailure(); } - - // Backends may move target ops from executables into linked executables. - // If an executable ends up with no targets, remove it. - auto executableOps = - llvm::to_vector<4>(moduleOp.getOps()); - for (auto executableOp : executableOps) { - auto targetOps = executableOp.getOps(); - if (targetOps.empty()) { - executableOp.erase(); - } - } } private: @@ -103,10 +95,19 @@ class LinkExecutablesPass void runOnOperation() override { auto moduleOp = getOperation(); + + // Add pipelines for each target backend used in the module. + // These will create/rearrange executables. OpPassManager passManager(moduleOp.getOperationName()); for (const auto &targetName : gatherExecutableTargetNames(moduleOp)) { passManager.addPass(createLinkTargetExecutablesPass(targetName)); } + + // Cleanup any remaining empty executables after each pipeline has run. + // We do this to aid debugging as then the pipelines can (mostly) be run in + // any order and not radically change the IR. + passManager.addPass(mlir::createSymbolDCEPass()); + if (failed(runPipeline(passManager, moduleOp))) { moduleOp.emitError() << "failed to link executables"; return signalPassFailure(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir index d3c7fb390d19..4c5513fb2c5c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir @@ -21,7 +21,7 @@ module attributes {hal.device.targets = [#device_target_cpu]} { hal.executable private @ex { hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ { hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes { - translation_info = #iree_codegen.translation_info + translation_info = #iree_codegen.translation_info } { ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors %c1 = arith.constant 1 : index diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir index 8a396b53c45d..914346e9cf7f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir @@ -28,7 +28,7 @@ module attributes {hal.device.targets = [#device_target_cpu]} { hal.executable private @ex0 { hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ { hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout_0) attributes { - translation_info = #iree_codegen.translation_info + translation_info = #iree_codegen.translation_info } { ^bb0(%device: !hal.device, %arg0: index): %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 @@ -41,7 +41,7 @@ module attributes {hal.device.targets = [#device_target_cpu]} { } hal.executable.export public @dispatch1 ordinal(1) layout(#pipeline_layout_1) attributes { - translation_info = #iree_codegen.translation_info + translation_info = #iree_codegen.translation_info } { ^bb0(%device: !hal.device, %arg0: index, %arg1: index): %c1 = arith.constant 1 : index diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_sources.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_sources.mlir index 80f00677d2ef..455295debed8 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_sources.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_sources.mlir @@ -23,7 +23,7 @@ module attributes {hal.device.targets = [#device_target_cpu]} { // CHECK-NEXT: hal.executable.variant {{.+}}, target = <"llvm-cpu" hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ { hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout) attributes { - translation_info = #iree_codegen.translation_info + translation_info = #iree_codegen.translation_info } { ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors %c1 = arith.constant 1 : index @@ -42,7 +42,7 @@ module attributes {hal.device.targets = [#device_target_cpu]} { hal.executable private @ex1 { hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ { hal.executable.export public @dispatch1 ordinal(0) layout(#pipeline_layout) attributes { - translation_info = #iree_codegen.translation_info + translation_info = #iree_codegen.translation_info } { ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors %c1 = arith.constant 1 : index diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp index 904a99ea146d..083560a2e813 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp @@ -86,7 +86,7 @@ class AbstractResourceUsage static_assert(BEST_STATE == BaseType::getBestState(), "unexpected BEST_STATE value"); - static bool isValidState(uint16_t bits) { + static bool isValidStateBits(uint16_t bits) { // bool isIndirect = (bits & NOT_INDIRECT) != NOT_INDIRECT; // bool isExternal = (bits & NOT_EXTERNAL) != NOT_EXTERNAL; bool isMutated = (bits & NOT_MUTATED) != NOT_MUTATED; @@ -117,6 +117,11 @@ class AbstractResourceUsage return true; } + bool isValidState() const override { + return this->getAssumed() != BaseType::getWorstState() && + isValidStateBits(this->getAssumed()); + } + ResourceUsageBitfield convertBitsToResourceUsage(uint16_t bits) const { return static_cast(~bits & BEST_STATE); } @@ -131,6 +136,7 @@ class AbstractResourceUsage const std::string getAsStr(AsmState &asmState) const override { std::string str; + if (!isValidState()) return "*"; auto append = [&](const char *part) { if (!str.empty()) str += '|'; str += part; @@ -363,13 +369,24 @@ class ValueResourceUsage : public AbstractResourceUsage { if (kFavorTransients && isSourceExternal && isTargetInternal) { LLVM_DEBUG({ llvm::dbgs() << "[ValueResourceUsage] skipping forward prop of " - "external into internal:"; + "external into internal: "; op.print(llvm::dbgs(), solver.getAsmState()); llvm::dbgs() << "\n"; }); return; } - getState() ^= sourceUsage.getState(); + auto newState = getState(); + newState ^= sourceUsage.getState(); + if (!newState.isValidState()) { + LLVM_DEBUG({ + llvm::dbgs() << "[ValueResourceUsage] skipping update from " + "producer as it would create an invalid state: "; + op.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + return; + } + getState() = newState; }) .Case([&](IREE::Stream::AsyncStoreOp op) { removeAssumedBits(NOT_STAGING_WRITE); @@ -531,14 +548,26 @@ class ValueResourceUsage : public AbstractResourceUsage { bool isTargetExternal = !resultUsage.isAssumed(NOT_EXTERNAL); if (kFavorTransients && isSourceInternal && isTargetExternal) { LLVM_DEBUG({ - llvm::dbgs() << "[ValueResourceUsage] skipping back prop of " - "external into internal due to kFavorTransients:"; + llvm::dbgs() + << "[ValueResourceUsage] skipping back prop of external into " + "internal due to kFavorTransients: "; op.print(llvm::dbgs(), solver.getAsmState()); llvm::dbgs() << "\n"; }); return; } - getState() ^= resultUsage.getState(); + auto newState = getState(); + newState ^= resultUsage.getState(); + if (!newState.isValidState()) { + LLVM_DEBUG({ + llvm::dbgs() << "[ValueResourceUsage] skipping update from use " + "as it would create an invalid state: "; + op.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + return; + } + getState() = newState; }) .Case([&](IREE::Stream::AsyncLoadOp op) { removeAssumedBits(NOT_STAGING_READ); @@ -589,12 +618,6 @@ class ValueResourceUsage : public AbstractResourceUsage { removeAssumedBits(NOT_EXTERNAL); } - // Filter out impossible states by marking the state invalid. - // The fixpoint framework will try again. - if (!isValidState(assumedBits)) { - return indicatePessimisticFixpoint(); - } - return assumedBits == getAssumed() ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp index c3e66c025fc6..f529decdf6ee 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp @@ -340,7 +340,9 @@ class PackDispatchOperandsPass auto exportOp = symbolTable.lookupNearestSymbolFrom( dispatchOp, dispatchOp.getEntryPoint()); - updateDispatchOp(dispatchOp, exportOp); + if (exportOp) { + updateDispatchOp(dispatchOp, exportOp); + } return WalkResult::advance(); }); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp index d3368e58578c..6b23643e040a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp @@ -161,6 +161,7 @@ struct UsageRefinementPattern : public OpRewritePattern { return true; } else { // Directly overwrite the existing lifetime. + assert(result.getType() != newType); result.setType(newType); return true; } @@ -275,6 +276,7 @@ struct ApplyGenericOp : public UsageRefinementPattern { PatternRewriter &rewriter) const override { bool didChange = this->applyRegionTransitions(op, rewriter); rewriter.startRootUpdate(op); + rewriter.setInsertionPointAfter(op); for (unsigned i = 0; i < op->getNumResults(); ++i) { auto result = op->getResult(i); if (result.getType().template isa()) { @@ -304,9 +306,10 @@ struct ApplyStreamableOp : public UsageRefinementPattern { auto affinityAttr = getOpAffinity(op); rewriter.startRootUpdate(op); + rewriter.setInsertionPointAfter(op); auto sizeAwareOp = - dyn_cast(op.getOperation()); + cast(op.getOperation()); for (unsigned i = 0; i < op->getNumResults(); ++i) { auto result = op->getResult(i); if (!result.getType().template isa()) { @@ -331,6 +334,7 @@ struct ApplyStreamableOp : public UsageRefinementPattern { static void insertUsageRefinementPatterns(MLIRContext *context, ResourceUsageAnalysis &analysis, RewritePatternSet &patterns) { + // NOTE: only ops that return values or contain regions need to be handled. patterns.insert(context, analysis); patterns.insert, ApplyGenericOp, @@ -386,7 +390,10 @@ class RefineUsagePass : public RefineUsageBase { RewritePatternSet patterns(&getContext()); insertUsageRefinementPatterns(&getContext(), analysis, patterns); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(moduleOp, frozenPatterns))) { + GreedyRewriteConfig rewriteConfig; + rewriteConfig.useTopDownTraversal = true; + if (failed(applyPatternsAndFoldGreedily(moduleOp, frozenPatterns, + rewriteConfig))) { return signalPassFailure(); } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp index baeecbb4b7dd..e97f98e81077 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp @@ -1189,6 +1189,9 @@ static LogicalResult allocateExecutionRegion( auto resultAllocation = reserveResultAllocation(resultReservations); for (auto &reservationSet : resultAllocation.reservationSets) { // Allocate and tie an operand to the result. + // TODO(benvanik): change this to an alloca. We may need a higher-level + // analysis to decide when to deallocate, or just leave it to be deallocated + // as part of garbage collection. auto allocOp = externalBuilder.create( externalBuilder.getFusedLoc(reservationSet.reservationLocs), reservationSet.reservationTypes, reservationSet.reservationSizes, diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td index 5c61a057933c..a5246fa303f4 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td @@ -65,9 +65,7 @@ def Util_TiedOpStorageAttr : let constBuilderCall = "$_builder.getI64ArrayAttr($0)"; } -def Util_GlobalRefAttr : ConfinedAttr, -]>; +defvar Util_GlobalRefAttr = FlatSymbolRefAttr; def Util_AnySerializableAttr : Attr()">, diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp index 0ec857ab7947..1f8fbf4c7fc4 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp @@ -157,6 +157,14 @@ LogicalResult detail::verifyGlobalOp(IREE::Util::GlobalOpInterface globalOp) { LogicalResult detail::verifyGlobalAddressOp( GlobalAddressOpInterface addressOp, SymbolTableCollection &symbolTable) { + if (!isa_and_nonnull( + symbolTable.lookupNearestSymbolFrom(addressOp.getOperation(), + addressOp.getGlobalAttr()))) { + return addressOp->emitOpError( + "attribute 'global' failed to satisfy constraint: flat symbol " + "reference attribute referencing to a 'IREE::Util::GlobalOpInterface' " + "symbol"); + } auto globalOp = lookupGlobalOp(addressOp, addressOp.getGlobalAttr(), symbolTable); if (!globalOp) { diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp index 1f1a332bf659..5541f91d9636 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp @@ -804,6 +804,9 @@ class ZeroExtendIOpConversion : public OpConversionPattern { } else if (srcType.isInteger(8) && dstType.isInteger(32)) { rewriter.replaceOpWithNewOp(srcOp, dstType, adaptor.getIn()); + } else if (srcType.isInteger(8) && dstType.isInteger(64)) { + rewriter.replaceOpWithNewOp(srcOp, dstType, + adaptor.getIn()); } else if (srcType.isInteger(16) && dstType.isInteger(32)) { rewriter.replaceOpWithNewOp(srcOp, dstType, adaptor.getIn()); @@ -811,8 +814,6 @@ class ZeroExtendIOpConversion : public OpConversionPattern { rewriter.replaceOpWithNewOp(srcOp, dstType, adaptor.getIn()); } else { - // TODO(benvanik): we should be building a sequence of extensions for - // things like i8 -> i64. return rewriter.notifyMatchFailure(srcOp, "unsupported zero extension"); } return success(); diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/assignment_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/assignment_ops.mlir index 1e78bfb39c03..18ab6e32c167 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/assignment_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/assignment_ops.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --pass-pipeline="test-iree-convert-std-to-vm" %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline="test-iree-convert-std-to-vm" --iree-vm-target-index-bits=32 %s | FileCheck %s // ----- // CHECK-LABEL: @t001_cmp_select diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/func_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/func_ops.mlir index 6e72a2e1693a..a202b2445bfb 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/func_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/func_ops.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --allow-unregistered-dialect --pass-pipeline="test-iree-convert-std-to-vm" %s | FileCheck %s +// RUN: iree-opt --split-input-file --allow-unregistered-dialect --pass-pipeline="test-iree-convert-std-to-vm" --iree-vm-target-index-bits=32 %s | FileCheck %s // ----- // CHECK-LABEL: @t001_iree_reflection diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/TargetOptions.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/TargetOptions.h index eb761dac8629..f6f9f6c6133b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/TargetOptions.h +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/TargetOptions.h @@ -18,7 +18,7 @@ namespace VM { // Controls VM translation targets. struct TargetOptions { // Target size of `index` when converted to an integer in bits. - int indexBits = 32; + int indexBits = 64; // Whether the f32 extension is enabled in the target VM. bool f32Extension = true; diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp index 8b21f7901920..8c8ec48975f4 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp @@ -200,6 +200,14 @@ struct BufferFillOpConversion } }; +static Value unscaleOffset(Location loc, Value offset, int64_t scale, + OpBuilder &builder) { + if (scale == 1) return offset; + return builder.createOrFold( + loc, offset.getType(), offset, + builder.create(loc, scale)); +} + struct BufferLoadOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -224,27 +232,33 @@ struct BufferLoadOpConversion } else if (integerType.isInteger(16)) { if (integerType.isSigned() || integerType.isSignless()) { rewriter.replaceOpWithNewOp( - loadOp, newType, adaptor.getSource(), byteOffset); + loadOp, newType, adaptor.getSource(), + unscaleOffset(loadOp.getLoc(), byteOffset, 2, rewriter)); } else { rewriter.replaceOpWithNewOp( - loadOp, newType, adaptor.getSource(), byteOffset); + loadOp, newType, adaptor.getSource(), + unscaleOffset(loadOp.getLoc(), byteOffset, 2, rewriter)); } } else if (integerType.isInteger(32)) { rewriter.replaceOpWithNewOp( - loadOp, newType, adaptor.getSource(), byteOffset); + loadOp, newType, adaptor.getSource(), + unscaleOffset(loadOp.getLoc(), byteOffset, 4, rewriter)); } else if (integerType.isInteger(64)) { rewriter.replaceOpWithNewOp( - loadOp, newType, adaptor.getSource(), byteOffset); + loadOp, newType, adaptor.getSource(), + unscaleOffset(loadOp.getLoc(), byteOffset, 8, rewriter)); } else { return rewriter.notifyMatchFailure( loadOp, "invalid integer buffer element type"); } } else if (oldType.isF32()) { rewriter.replaceOpWithNewOp( - loadOp, newType, adaptor.getSource(), byteOffset); + loadOp, newType, adaptor.getSource(), + unscaleOffset(loadOp.getLoc(), byteOffset, 4, rewriter)); } else if (oldType.isF64()) { rewriter.replaceOpWithNewOp( - loadOp, newType, adaptor.getSource(), byteOffset); + loadOp, newType, adaptor.getSource(), + unscaleOffset(loadOp.getLoc(), byteOffset, 8, rewriter)); } else { return rewriter.notifyMatchFailure(loadOp, "invalid float buffer element type"); @@ -270,19 +284,29 @@ struct BufferStoreOpConversion storeOp, adaptor.getTarget(), byteOffset, adaptor.getSource()); } else if (oldType.isInteger(16)) { rewriter.replaceOpWithNewOp( - storeOp, adaptor.getTarget(), byteOffset, adaptor.getSource()); + storeOp, adaptor.getTarget(), + unscaleOffset(storeOp.getLoc(), byteOffset, 2, rewriter), + adaptor.getSource()); } else if (oldType.isInteger(32)) { rewriter.replaceOpWithNewOp( - storeOp, adaptor.getTarget(), byteOffset, adaptor.getSource()); + storeOp, adaptor.getTarget(), + unscaleOffset(storeOp.getLoc(), byteOffset, 4, rewriter), + adaptor.getSource()); } else if (oldType.isInteger(64)) { rewriter.replaceOpWithNewOp( - storeOp, adaptor.getTarget(), byteOffset, adaptor.getSource()); + storeOp, adaptor.getTarget(), + unscaleOffset(storeOp.getLoc(), byteOffset, 8, rewriter), + adaptor.getSource()); } else if (oldType.isF32()) { rewriter.replaceOpWithNewOp( - storeOp, adaptor.getTarget(), byteOffset, adaptor.getSource()); + storeOp, adaptor.getTarget(), + unscaleOffset(storeOp.getLoc(), byteOffset, 4, rewriter), + adaptor.getSource()); } else if (oldType.isF64()) { rewriter.replaceOpWithNewOp( - storeOp, adaptor.getTarget(), byteOffset, adaptor.getSource()); + storeOp, adaptor.getTarget(), + unscaleOffset(storeOp.getLoc(), byteOffset, 8, rewriter), + adaptor.getSource()); } else { return rewriter.notifyMatchFailure(storeOp, "invalid buffer element type"); diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/alignment_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/alignment_ops.mlir index d1579671463a..1d4dd0ee1c40 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/alignment_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/alignment_ops.mlir @@ -1,15 +1,14 @@ -// RUN: iree-opt --split-input-file --iree-vm-conversion --cse %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-vm-conversion --cse --iree-vm-target-index-bits=32 %s | FileCheck %s // CHECK-LABEL: @utilAlign func.func @utilAlign(%arg0 : index, %arg1: index) -> (index) { %result = util.align %arg0, %arg1 : index - //CHECK-DAG: %c1 = vm.const.i32 1 - //CHECK-DAG: %0 = vm.sub.i32 %arg1, %c1 : i32 - //CHECK-DAG: %1 = vm.add.i32 %arg0, %0 : i32 - //CHECK-DAG: %2 = vm.not.i32 %0 : i32 - //CHECK-DAG: %3 = vm.and.i32 %1, %2 : i32 - - //CHECK-DAG: vm.return %3 : i32 + // CHECK-DAG: %c1 = vm.const.i32 1 + // CHECK-DAG: %0 = vm.sub.i32 %arg1, %c1 : i32 + // CHECK-DAG: %1 = vm.add.i32 %arg0, %0 : i32 + // CHECK-DAG: %2 = vm.not.i32 %0 : i32 + // CHECK-DAG: %3 = vm.and.i32 %1, %2 : i32 + // CHECK: vm.return %3 : i32 return %result : index } @@ -18,13 +17,12 @@ func.func @utilAlign(%arg0 : index, %arg1: index) -> (index) { // CHECK-LABEL: @utilAlignInt32 func.func @utilAlignInt32(%arg0 : i32, %arg1: i32) -> (i32) { %result = util.align %arg0, %arg1 : i32 - //CHECK-DAG: %c1 = vm.const.i32 1 - //CHECK-DAG: %0 = vm.sub.i32 %arg1, %c1 : i32 - //CHECK-DAG: %1 = vm.add.i32 %arg0, %0 : i32 - //CHECK-DAG: %2 = vm.not.i32 %0 : i32 - //CHECK-DAG: %3 = vm.and.i32 %1, %2 : i32 - - //CHECK-DAG: vm.return %3 : i32 + // CHECK-DAG: %c1 = vm.const.i32 1 + // CHECK-DAG: %0 = vm.sub.i32 %arg1, %c1 : i32 + // CHECK-DAG: %1 = vm.add.i32 %arg0, %0 : i32 + // CHECK-DAG: %2 = vm.not.i32 %0 : i32 + // CHECK-DAG: %3 = vm.and.i32 %1, %2 : i32 + // CHECK: vm.return %3 : i32 return %result : i32 } @@ -33,13 +31,12 @@ func.func @utilAlignInt32(%arg0 : i32, %arg1: i32) -> (i32) { // CHECK-LABEL: @utilAlignInt64 func.func @utilAlignInt64(%arg0 : i64, %arg1: i64) -> (i64) { %result = util.align %arg0, %arg1 : i64 - //CHECK-DAG: %c1 = vm.const.i64 1 - //CHECK-DAG: %0 = vm.sub.i64 %arg1, %c1 : i64 - //CHECK-DAG: %1 = vm.add.i64 %arg0, %0 : i64 - //CHECK-DAG: %2 = vm.not.i64 %0 : i64 - //CHECK-DAG: %3 = vm.and.i64 %1, %2 : i64 - - //CHECK-DAG: vm.return %3 : i64 + // CHECK-DAG: %c1 = vm.const.i64 1 + // CHECK-DAG: %0 = vm.sub.i64 %arg1, %c1 : i64 + // CHECK-DAG: %1 = vm.add.i64 %arg0, %0 : i64 + // CHECK-DAG: %2 = vm.not.i64 %0 : i64 + // CHECK-DAG: %3 = vm.and.i64 %1, %2 : i64 + // CHECK: vm.return %3 : i64 return %result : i64 } diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir index 0f4542e62fa8..f20947a1f5e6 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir @@ -177,11 +177,12 @@ func.func @buffer_fill_index(%arg0: !util.buffer, %arg1: index, %arg2: index) { // CHECK-LABEL: @buffer_load_i1 func.func @buffer_load_i32(%arg0: !util.buffer, %arg1: index) -> i1 { - %c100 = arith.constant 100 : index - // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100 - // CHECK-32: %[[VALUE:.+]] = vm.buffer.load.i8.s %arg0[%[[C100]]] : !vm.buffer -> i32 - // CHECK-64: %[[VALUE:.+]] = vm.buffer.load.i8.s %arg0[%c100] : !vm.buffer -> i32 - %0 = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> i1 + %byte_offset = arith.constant 128 : index + // CHECK-32-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 128 + // CHECK-32: %[[VALUE:.+]] = vm.buffer.load.i8.s %arg0[%[[ELEMENT_OFFSET]]] : !vm.buffer -> i32 + // CHECK-64-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 128 + // CHECK-64: %[[VALUE:.+]] = vm.buffer.load.i8.s %arg0[%[[ELEMENT_OFFSET]]] : !vm.buffer -> i32 + %0 = util.buffer.load %arg0[%byte_offset] : !util.buffer{%arg1} -> i1 // CHECK: return %[[VALUE]] return %0 : i1 } @@ -190,11 +191,12 @@ func.func @buffer_load_i32(%arg0: !util.buffer, %arg1: index) -> i1 { // CHECK-LABEL: @buffer_load_i32 func.func @buffer_load_i32(%arg0: !util.buffer, %arg1: index) -> i32 { - %c100 = arith.constant 100 : index - // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100 - // CHECK-32: %[[VALUE:.+]] = vm.buffer.load.i32 %arg0[%[[C100]]] : !vm.buffer -> i32 - // CHECK-64: %[[VALUE:.+]] = vm.buffer.load.i32 %arg0[%c100] : !vm.buffer -> i32 - %0 = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> i32 + %byte_offset = arith.constant 128 : index + // CHECK-32-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 32 + // CHECK-32: %[[VALUE:.+]] = vm.buffer.load.i32 %arg0[%[[ELEMENT_OFFSET]]] : !vm.buffer -> i32 + // CHECK-64-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 32 + // CHECK-64: %[[VALUE:.+]] = vm.buffer.load.i32 %arg0[%[[ELEMENT_OFFSET]]] : !vm.buffer -> i32 + %0 = util.buffer.load %arg0[%byte_offset] : !util.buffer{%arg1} -> i32 // CHECK: return %[[VALUE]] return %0 : i32 } @@ -203,11 +205,12 @@ func.func @buffer_load_i32(%arg0: !util.buffer, %arg1: index) -> i32 { // CHECK-LABEL: @buffer_load_i64 func.func @buffer_load_i64(%arg0: !util.buffer, %arg1: index) -> i64 { - %c100 = arith.constant 100 : index - // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100 - // CHECK-32: %[[VALUE:.+]] = vm.buffer.load.i64 %arg0[%[[C100]]] : !vm.buffer -> i64 - // CHECK-64: %[[VALUE:.+]] = vm.buffer.load.i64 %arg0[%c100] : !vm.buffer -> i64 - %0 = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> i64 + %byte_offset = arith.constant 128 : index + // CHECK-32-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 16 + // CHECK-32: %[[VALUE:.+]] = vm.buffer.load.i64 %arg0[%[[ELEMENT_OFFSET]]] : !vm.buffer -> i64 + // CHECK-64-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 16 + // CHECK-64: %[[VALUE:.+]] = vm.buffer.load.i64 %arg0[%[[ELEMENT_OFFSET]]] : !vm.buffer -> i64 + %0 = util.buffer.load %arg0[%byte_offset] : !util.buffer{%arg1} -> i64 // CHECK: return %[[VALUE]] return %0 : i64 } @@ -216,10 +219,10 @@ func.func @buffer_load_i64(%arg0: !util.buffer, %arg1: index) -> i64 { // CHECK-LABEL: @buffer_load_index func.func @buffer_load_index(%arg0: !util.buffer, %arg1: index) -> index { - %c100 = arith.constant 100 : index + %byte_offset = arith.constant 100 : index // CHECK-32: vm.buffer.load.i32 // CHECK-64: vm.buffer.load.i64 - %0 = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> index + %0 = util.buffer.load %arg0[%byte_offset] : !util.buffer{%arg1} -> index return %0 : index } @@ -227,11 +230,12 @@ func.func @buffer_load_index(%arg0: !util.buffer, %arg1: index) -> index { // CHECK-LABEL: @buffer_store_i1 func.func @buffer_store_i1(%arg0: !util.buffer, %arg1: index, %arg2: i1) { - %c100 = arith.constant 100 : index - // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100 - // CHECK-32: vm.buffer.store.i8 %arg2, %arg0[%[[C100]]] : i32 -> !vm.buffer - // CHECK-64: vm.buffer.store.i8 %arg2, %arg0[%c100] : i32 -> !vm.buffer - util.buffer.store %arg2, %arg0[%c100] : i1 -> !util.buffer{%arg1} + %byte_offset = arith.constant 128 : index + // CHECK-32-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 128 + // CHECK-32: vm.buffer.store.i8 %arg2, %arg0[%[[ELEMENT_OFFSET]]] : i32 -> !vm.buffer + // CHECK-64-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 128 + // CHECK-64: vm.buffer.store.i8 %arg2, %arg0[%[[ELEMENT_OFFSET]]] : i32 -> !vm.buffer + util.buffer.store %arg2, %arg0[%byte_offset] : i1 -> !util.buffer{%arg1} return } @@ -239,11 +243,12 @@ func.func @buffer_store_i1(%arg0: !util.buffer, %arg1: index, %arg2: i1) { // CHECK-LABEL: @buffer_store_i32 func.func @buffer_store_i32(%arg0: !util.buffer, %arg1: index, %arg2: i32) { - %c100 = arith.constant 100 : index - // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100 - // CHECK-32: vm.buffer.store.i32 %arg2, %arg0[%[[C100]]] : i32 -> !vm.buffer - // CHECK-64: vm.buffer.store.i32 %arg2, %arg0[%c100] : i32 -> !vm.buffer - util.buffer.store %arg2, %arg0[%c100] : i32 -> !util.buffer{%arg1} + %byte_offset = arith.constant 128 : index + // CHECK-32-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 32 + // CHECK-32: vm.buffer.store.i32 %arg2, %arg0[%[[ELEMENT_OFFSET]]] : i32 -> !vm.buffer + // CHECK-64-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 32 + // CHECK-64: vm.buffer.store.i32 %arg2, %arg0[%[[ELEMENT_OFFSET]]] : i32 -> !vm.buffer + util.buffer.store %arg2, %arg0[%byte_offset] : i32 -> !util.buffer{%arg1} return } @@ -251,11 +256,12 @@ func.func @buffer_store_i32(%arg0: !util.buffer, %arg1: index, %arg2: i32) { // CHECK-LABEL: @buffer_store_i64 func.func @buffer_store_i64(%arg0: !util.buffer, %arg1: index, %arg2: i64) { - %c100 = arith.constant 100 : index - // CHECK-32-DAG: %[[C100:.+]] = vm.const.i64 100 - // CHECK-32: vm.buffer.store.i64 %arg2, %arg0[%[[C100]]] : i64 -> !vm.buffer - // CHECK-64: vm.buffer.store.i64 %arg2, %arg0[%c100] : i64 -> !vm.buffer - util.buffer.store %arg2, %arg0[%c100] : i64 -> !util.buffer{%arg1} + %byte_offset = arith.constant 128 : index + // CHECK-32-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 16 + // CHECK-32: vm.buffer.store.i64 %arg2, %arg0[%[[ELEMENT_OFFSET]]] : i64 -> !vm.buffer + // CHECK-64-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 16 + // CHECK-64: vm.buffer.store.i64 %arg2, %arg0[%[[ELEMENT_OFFSET]]] : i64 -> !vm.buffer + util.buffer.store %arg2, %arg0[%byte_offset] : i64 -> !util.buffer{%arg1} return } @@ -263,9 +269,9 @@ func.func @buffer_store_i64(%arg0: !util.buffer, %arg1: index, %arg2: i64) { // CHECK-LABEL: @buffer_store_index func.func @buffer_store_index(%arg0: !util.buffer, %arg1: index, %arg2: index) { - %c100 = arith.constant 100 : index + %byte_offset = arith.constant 100 : index // CHECK-32: vm.buffer.store.i32 // CHECK-64: vm.buffer.store.i64 - util.buffer.store %arg2, %arg0[%c100] : index -> !util.buffer{%arg1} + util.buffer.store %arg2, %arg0[%byte_offset] : index -> !util.buffer{%arg1} return } diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/list_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/list_ops.mlir index ab404a6fc9d4..ed2322f21b76 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/list_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/list_ops.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --iree-vm-conversion %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-vm-conversion --iree-vm-target-index-bits=32 %s | FileCheck %s // CHECK-LABEL: @list_ops module @list_ops { module { diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index f75aa429ab64..24e1f11677d2 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -14,11 +14,9 @@ #include "iree/compiler/Dialect/VM/IR/VMOps.h" #include "iree/compiler/Dialect/VM/Utils/CallingConvention.h" #include "llvm/ADT/TypeSwitch.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" @@ -2694,12 +2692,10 @@ class CompareRefOpConversion : public OpConversionPattern { return cmpOp.emitError() << "parent func op not found in cache."; } - bool moveLhs = vmAnalysis.value().get().isLastValueUse( - cmpOp.getLhs(), cmpOp.getOperation()) && - false; - bool moveRhs = vmAnalysis.value().get().isLastValueUse( - cmpOp.getRhs(), cmpOp.getOperation()) && - false; + bool moveLhs = + vmAnalysis.value().get().isMove(cmpOp.getLhs(), cmpOp.getOperation()); + bool moveRhs = + vmAnalysis.value().get().isMove(cmpOp.getRhs(), cmpOp.getOperation()); Optional refLhs = typeConverter->materializeRef(cmpOp.getLhs()); @@ -2758,9 +2754,8 @@ class CompareRefNotZeroOpConversion return cmpOp.emitError() << "parent func op not found in cache."; } - bool move = vmAnalysis.value().get().isLastValueUse(cmpOp.getOperand(), - cmpOp.getOperation()) && - false; + bool move = vmAnalysis.value().get().isMove(cmpOp.getOperand(), + cmpOp.getOperation()); Optional ref = typeConverter->materializeRef(cmpOp.getOperand()); @@ -2809,15 +2804,8 @@ class ConstZeroOpConversion : public OpConversionPattern { ConstZeroOpTy constZeroOp, Adaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto type = constZeroOp.getType(); - Attribute value; - if (type.template isa()) { - value = rewriter.getIntegerAttr(type, 0); - } else if (type.template isa()) { - value = rewriter.getFloatAttr(type, 0.0); - } else { - return failure(); - } + Attribute value = rewriter.getZeroAttr(type); rewriter.replaceOpWithNewOp(constZeroOp, type, value); return success(); @@ -3494,8 +3482,7 @@ class GlobalLoadStoreRefOpConversion Value srcRef = isLoad ? stateRef : localRef.value(); Value destRef = isLoad ? localRef.value() : stateRef; - bool move = - vmAnalysis.value().get().isLastValueUse(localValue, op) && false; + bool move = vmAnalysis.value().get().isMove(localValue, op); returnIfError( /*rewriter=*/rewriter, @@ -4093,9 +4080,8 @@ class ListSetRefOpConversion if (failed(vmAnalysis)) { return setOp.emitError() << "parent func op not found in cache."; } - bool move = vmAnalysis.value().get().isLastValueUse(setOp.value(), - setOp.getOperation()) && - false; + bool move = + vmAnalysis.value().get().isMove(setOp.value(), setOp.getOperation()); StringRef callee = move ? "iree_vm_list_set_ref_move" : "iree_vm_list_set_ref_retain"; @@ -4481,8 +4467,7 @@ class ConvertVMToEmitCPass void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + mlir::func::FuncDialect, IREE::Util::UtilDialect>(); } StringRef getArgument() const override { return "iree-convert-vm-to-emitc"; } @@ -4538,10 +4523,9 @@ class ConvertVMToEmitCPass RewritePatternSet patterns(&getContext()); populateVMToEmitCPatterns(target, typeConverter, patterns); - target.addLegalDialect< - emitc::EmitCDialect, mlir::BuiltinDialect, mlir::cf::ControlFlowDialect, - mlir::func::FuncDialect, mlir::arith::ArithmeticDialect, - mlir::math::MathDialect>(); + target.addLegalDialect(); target.addDynamicallyLegalOp( [&](mlir::func::FuncOp op) { diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h index 89f76e9a3ff1..14349d01d146 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h @@ -51,9 +51,10 @@ struct VMAnalysis { return registerAllocation.mapToRegister(ref).ordinal(); } - bool isLastValueUse(Value ref, Operation *op) { + bool isMove(Value ref, Operation *op) { assert(ref.getType().isa()); - return valueLiveness.isLastValueUse(ref, op); + bool lastUse = valueLiveness.isLastValueUse(ref, op); + return lastUse && false; } void cacheLocalRef(int64_t ordinal, emitc::ApplyOp &applyOp) { diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/BUILD b/compiler/src/iree/compiler/Dialect/VM/IR/BUILD index 53c154b44925..a4f836d6afd4 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/BUILD +++ b/compiler/src/iree/compiler/Dialect/VM/IR/BUILD @@ -73,6 +73,7 @@ iree_compiler_cc_library( ":VMOpInterfacesGen", ":VMOpsGen", "//compiler/src/iree/compiler/Dialect/Util/IR", + "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:ControlFlowInterfaces", diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/IR/CMakeLists.txt index a92fe5a7d190..5b4350bb27b7 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/VM/IR/CMakeLists.txt @@ -49,6 +49,7 @@ iree_cc_library( MLIRSupport MLIRTransformUtils iree::compiler::Dialect::Util::IR + iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp index d218bd994a81..11e4b2f896ad 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp @@ -727,8 +727,8 @@ void MulI64Op::getCanonicalizationPatterns(RewritePatternSet &results, context); } -template -static OpFoldResult foldDivSOp(T op, ArrayRef operands) { +template +static OpFoldResult foldDivSOp(DivOpT op, ArrayRef operands) { if (matchPattern(op.getRhs(), m_Zero())) { // x / 0 = death op.emitOpError() << "is a divide by constant zero"; @@ -739,21 +739,31 @@ static OpFoldResult foldDivSOp(T op, ArrayRef operands) { } else if (matchPattern(op.getRhs(), m_One())) { // x / 1 = x return op.getLhs(); + } else if (auto mulOp = + dyn_cast_or_null(op.getLhs().getDefiningOp())) { + // Only applies to signed divides (matches LLVM behavior). + if (mulOp.getRhs() == op.getRhs()) { + // c = mul a, b + // d = div c, b + // -> + // d = a + return mulOp.getLhs(); + } } return constFoldBinaryOp( operands, [](const APInt &a, const APInt &b) { return a.sdiv(b); }); } OpFoldResult DivI32SOp::fold(ArrayRef operands) { - return foldDivSOp(*this, operands); + return foldDivSOp(*this, operands); } OpFoldResult DivI64SOp::fold(ArrayRef operands) { - return foldDivSOp(*this, operands); + return foldDivSOp(*this, operands); } -template -static OpFoldResult foldDivUOp(T op, ArrayRef operands) { +template +static OpFoldResult foldDivUOp(DivOpT op, ArrayRef operands) { if (matchPattern(op.getRhs(), m_Zero())) { // x / 0 = death op.emitOpError() << "is a divide by constant zero"; diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp index eeb8280618f3..0acbedf9a69b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp @@ -8,6 +8,7 @@ #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "iree/compiler/Utils/StringUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "mlir/IR/Attributes.h" @@ -1297,8 +1298,7 @@ void CondFailOp::print(OpAsmPrinter &p) { } void ImportResolvedOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { - std::string name = ("has_" + getImport()).str(); - std::replace(name.begin(), name.end(), '.', '_'); + std::string name = sanitizeSymbolName(("has_" + getImport()).str()); setResultName(setNameFn, getResult(), name); } diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir b/compiler/src/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir index 23d1a029c26d..67834b7f3844 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir @@ -5,7 +5,7 @@ // CHECK-LABEL: @add_i32_folds vm.module @add_i32_folds { // CHECK-LABEL: @add_i32_0_y - vm.func @add_i32_0_y(%arg0 : i32) -> i32 { + vm.func @add_i32_0_y(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %zero = vm.const.i32.zero %0 = vm.add.i32 %zero, %arg0 : i32 @@ -13,7 +13,7 @@ vm.module @add_i32_folds { } // CHECK-LABEL: @add_i32_x_0 - vm.func @add_i32_x_0(%arg0 : i32) -> i32 { + vm.func @add_i32_x_0(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %zero = vm.const.i32.zero %0 = vm.add.i32 %arg0, %zero : i32 @@ -36,7 +36,7 @@ vm.module @add_i32_folds { // CHECK-LABEL: @sub_i32_folds vm.module @sub_i32_folds { // CHECK-LABEL: @sub_i32_x_0 - vm.func @sub_i32_x_0(%arg0 : i32) -> i32 { + vm.func @sub_i32_x_0(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %zero = vm.const.i32.zero %0 = vm.sub.i32 %arg0, %zero : i32 @@ -59,14 +59,14 @@ vm.module @sub_i32_folds { // CHECK-LABEL: @add_sub_i32_folds vm.module @add_sub_i32_folds { // CHECK-LABEL: @add_sub_x - vm.func @add_sub_x(%arg0 : i32, %arg1 : i32) -> i32 { + vm.func @add_sub_x(%arg0: i32, %arg1: i32) -> i32 { // CHECK-NEXT: vm.return %arg0 %0 = vm.add.i32 %arg0, %arg1 : i32 %1 = vm.sub.i32 %0, %arg1 : i32 vm.return %1 : i32 } // CHECK-LABEL: @add_sub_x_rev - vm.func @add_sub_x_rev(%arg0 : i32, %arg1 : i32) -> i32 { + vm.func @add_sub_x_rev(%arg0: i32, %arg1: i32) -> i32 { // CHECK-NEXT: vm.return %arg0 %0 = vm.add.i32 %arg1, %arg0 : i32 %1 = vm.sub.i32 %arg1, %0 : i32 @@ -74,14 +74,14 @@ vm.module @add_sub_i32_folds { } // CHECK-LABEL: @sub_add_x - vm.func @sub_add_x(%arg0 : i32, %arg1 : i32) -> i32 { + vm.func @sub_add_x(%arg0: i32, %arg1: i32) -> i32 { // CHECK-NEXT: vm.return %arg0 %0 = vm.sub.i32 %arg0, %arg1 : i32 %1 = vm.add.i32 %0, %arg1 : i32 vm.return %1 : i32 } // CHECK-LABEL: @sub_add_x_rev - vm.func @sub_add_x_rev(%arg0 : i32, %arg1 : i32) -> i32 { + vm.func @sub_add_x_rev(%arg0: i32, %arg1: i32) -> i32 { // CHECK-NEXT: vm.return %arg0 %0 = vm.sub.i32 %arg0, %arg1 : i32 %1 = vm.add.i32 %arg1, %0 : i32 @@ -94,7 +94,7 @@ vm.module @add_sub_i32_folds { // CHECK-LABEL: @mul_i32_folds vm.module @mul_i32_folds { // CHECK-LABEL: @mul_i32_by_0 - vm.func @mul_i32_by_0(%arg0 : i32) -> i32 { + vm.func @mul_i32_by_0(%arg0: i32) -> i32 { // CHECK: %zero = vm.const.i32.zero // CHECK-NEXT: vm.return %zero : i32 %zero = vm.const.i32.zero @@ -103,7 +103,7 @@ vm.module @mul_i32_folds { } // CHECK-LABEL: @mul_i32_1_y - vm.func @mul_i32_1_y(%arg0 : i32) -> i32 { + vm.func @mul_i32_1_y(%arg0: i32) -> i32 { // CHECK-NEXT: vm.return %arg0 : i32 %c1 = vm.const.i32 1 %0 = vm.mul.i32 %c1, %arg0 : i32 @@ -111,7 +111,7 @@ vm.module @mul_i32_folds { } // CHECK-LABEL: @mul_i32_x_1 - vm.func @mul_i32_x_1(%arg0 : i32) -> i32 { + vm.func @mul_i32_x_1(%arg0: i32) -> i32 { // CHECK-NEXT: vm.return %arg0 : i32 %c1 = vm.const.i32 1 %0 = vm.mul.i32 %arg0, %c1 : i32 @@ -134,7 +134,7 @@ vm.module @mul_i32_folds { // CHECK-LABEL: @mul_mul_i32_folds vm.module @mul_mul_i32_folds { // CHECK-LABEL: @mul_mul_i32_const - vm.func @mul_mul_i32_const(%arg0 : i32) -> i32 { + vm.func @mul_mul_i32_const(%arg0: i32) -> i32 { // CHECK: %c40 = vm.const.i32 40 %c4 = vm.const.i32 4 %c10 = vm.const.i32 10 @@ -151,7 +151,7 @@ vm.module @mul_mul_i32_folds { // CHECK-LABEL: @div_i32_folds vm.module @div_i32_folds { // CHECK-LABEL: @div_i32_0_y - vm.func @div_i32_0_y(%arg0 : i32) -> i32 { + vm.func @div_i32_0_y(%arg0: i32) -> i32 { // CHECK: %zero = vm.const.i32.zero // CHECK-NEXT: vm.return %zero : i32 %zero = vm.const.i32.zero @@ -160,7 +160,7 @@ vm.module @div_i32_folds { } // CHECK-LABEL: @div_i32_x_1 - vm.func @div_i32_x_1(%arg0 : i32) -> i32 { + vm.func @div_i32_x_1(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %c1 = vm.const.i32 1 %0 = vm.div.i32.s %arg0, %c1 : i32 @@ -176,6 +176,14 @@ vm.module @div_i32_folds { %0 = vm.div.i32.s %c15, %c5 : i32 vm.return %0 : i32 } + + // CHECK-LABEL: @mul_div_i32 + vm.func @mul_div_i32(%arg0: i32, %arg1: i32) -> i32 { + %0 = vm.mul.i32 %arg0, %arg1 : i32 + %1 = vm.div.i32.s %0, %arg1 : i32 + // CHECK-NEXT: vm.return %arg0 + vm.return %1 : i32 + } } // ----- @@ -183,7 +191,7 @@ vm.module @div_i32_folds { // CHECK-LABEL: @rem_i32_folds vm.module @rem_i32_folds { // CHECK-LABEL: @rem_i32_x_1 - vm.func @rem_i32_x_1(%arg0 : i32) -> i32 { + vm.func @rem_i32_x_1(%arg0: i32) -> i32 { // CHECK: %zero = vm.const.i32.zero // CHECK-NEXT: vm.return %zero : i32 %c1 = vm.const.i32 1 @@ -192,7 +200,7 @@ vm.module @rem_i32_folds { } // CHECK-LABEL: @rem_i32_0_y - vm.func @rem_i32_0_y(%arg0 : i32) -> i32 { + vm.func @rem_i32_0_y(%arg0: i32) -> i32 { // CHECK: %zero = vm.const.i32.zero // CHECK-NEXT: vm.return %zero : i32 %zero = vm.const.i32.zero @@ -244,7 +252,7 @@ vm.module @not_i32_folds { // CHECK-LABEL: @and_i32_folds vm.module @and_i32_folds { // CHECK-LABEL: @and_i32_zero - vm.func @and_i32_zero(%arg0 : i32) -> i32 { + vm.func @and_i32_zero(%arg0: i32) -> i32 { // CHECK: %zero = vm.const.i32.zero // CHECK-NEXT: vm.return %zero : i32 %zero = vm.const.i32.zero @@ -253,7 +261,7 @@ vm.module @and_i32_folds { } // CHECK-LABEL: @and_i32_eq - vm.func @and_i32_eq(%arg0 : i32) -> i32 { + vm.func @and_i32_eq(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %0 = vm.and.i32 %arg0, %arg0 : i32 vm.return %0 : i32 @@ -275,7 +283,7 @@ vm.module @and_i32_folds { // CHECK-LABEL: @or_i32_folds vm.module @or_i32_folds { // CHECK-LABEL: @or_i32_0_y - vm.func @or_i32_0_y(%arg0 : i32) -> i32 { + vm.func @or_i32_0_y(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %zero = vm.const.i32.zero %0 = vm.or.i32 %zero, %arg0 : i32 @@ -283,7 +291,7 @@ vm.module @or_i32_folds { } // CHECK-LABEL: @or_i32_x_0 - vm.func @or_i32_x_0(%arg0 : i32) -> i32 { + vm.func @or_i32_x_0(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %zero = vm.const.i32.zero %0 = vm.or.i32 %arg0, %zero : i32 @@ -291,7 +299,7 @@ vm.module @or_i32_folds { } // CHECK-LABEL: @or_i32_x_x - vm.func @or_i32_x_x(%arg0 : i32) -> i32 { + vm.func @or_i32_x_x(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %0 = vm.or.i32 %arg0, %arg0 : i32 vm.return %0 : i32 @@ -313,7 +321,7 @@ vm.module @or_i32_folds { // CHECK-LABEL: @xor_i32_folds vm.module @xor_i32_folds { // CHECK-LABEL: @xor_i32_0_y - vm.func @xor_i32_0_y(%arg0 : i32) -> i32 { + vm.func @xor_i32_0_y(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %zero = vm.const.i32.zero %0 = vm.xor.i32 %zero, %arg0 : i32 @@ -321,7 +329,7 @@ vm.module @xor_i32_folds { } // CHECK-LABEL: @xor_i32_x_0 - vm.func @xor_i32_x_0(%arg0 : i32) -> i32 { + vm.func @xor_i32_x_0(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %zero = vm.const.i32.zero %0 = vm.xor.i32 %arg0, %zero : i32 @@ -329,7 +337,7 @@ vm.module @xor_i32_folds { } // CHECK-LABEL: @xor_i32_x_x - vm.func @xor_i32_x_x(%arg0 : i32) -> i32 { + vm.func @xor_i32_x_x(%arg0: i32) -> i32 { // CHECK: %zero = vm.const.i32.zero // CHECK-NEXT: vm.return %zero : i32 %0 = vm.xor.i32 %arg0, %arg0 : i32 @@ -392,7 +400,7 @@ vm.module @shl_i32_folds { } // CHECK-LABEL: @shl_i32_x_by_0 - vm.func @shl_i32_x_by_0(%arg0 : i32) -> i32 { + vm.func @shl_i32_x_by_0(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %c0 = vm.const.i32 0 %0 = vm.shl.i32 %arg0, %c0 : i32 @@ -425,7 +433,7 @@ vm.module @shr_i32_s_folds { } // CHECK-LABEL: @shr_i32_s_x_by_0 - vm.func @shr_i32_s_x_by_0(%arg0 : i32) -> i32 { + vm.func @shr_i32_s_x_by_0(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %c0 = vm.const.i32.zero %0 = vm.shr.i32.s %arg0, %c0 : i32 @@ -458,7 +466,7 @@ vm.module @shr_i32_u_folds { } // CHECK-LABEL: @shr_i32_u_x_by_0 - vm.func @shr_i32_u_x_by_0(%arg0 : i32) -> i32 { + vm.func @shr_i32_u_x_by_0(%arg0: i32) -> i32 { // CHECK: vm.return %arg0 : i32 %c0 = vm.const.i32 0 %0 = vm.shr.i32.u %arg0, %c0 : i32 diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h index f58100e7e2cd..f1e7eeb5b3f9 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h @@ -32,7 +32,7 @@ struct EncodedBytecodeFunction { class BytecodeEncoder : public VMFuncEncoder { public: // Matches IREE_VM_BYTECODE_VERSION_MAJOR. - static constexpr uint32_t kVersionMajor = 12; + static constexpr uint32_t kVersionMajor = 13; // Matches IREE_VM_BYTECODE_VERSION_MINOR. static constexpr uint32_t kVersionMinor = 0; static constexpr uint32_t kVersion = (kVersionMajor << 16) | kVersionMinor; diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp index b7c15f449403..55ca07e436c8 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp +++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp @@ -33,6 +33,7 @@ struct FromMemRefSubView : public OpRewritePattern { auto sourceType = source.getType().cast(); int sourceRank = sourceType.getRank(); int subRank = subType.getRank(); + (void)subRank; // Create a descriptor for the source. IndexType indexType = rewriter.getIndexType(); @@ -44,20 +45,28 @@ struct FromMemRefSubView : public OpRewritePattern { loc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes, sizeStrideTypes, subview.getSource()); - // For sizes, we just use the new ones, discarding the source. + // For sizes, we just use the new ones. + llvm::SmallBitVector droppedDims = subview.getDroppedDims(); + unsigned insertedDims = 0; SmallVector newSizes; - for (int i = 0; i < subRank; ++i) { + for (int i = 0; i < sourceRank; ++i) { + // Skip the sizes that don't show up in the final type. + if (droppedDims.test(i)) continue; + if (subview.isDynamicSize(i)) { newSizes.push_back(subview.getDynamicSize(i)); } else { newSizes.push_back(indexSet.get(subview.getStaticSize(i))); } - op.getSizes()[i].replaceAllUsesWith(newSizes.back()); + op.getSizes()[insertedDims++].replaceAllUsesWith(newSizes.back()); } + assert(insertedDims == subRank && + "Should have populated all the non-reduced sizes"); // Apply stride multipliers. SmallVector strides; - for (int i = 0; i < subRank; ++i) { + insertedDims = 0; + for (int i = 0; i < sourceRank; ++i) { Value currentStride; if (subview.isDynamicStride(i)) { currentStride = subview.getDynamicStride(i); @@ -67,12 +76,20 @@ struct FromMemRefSubView : public OpRewritePattern { currentStride = rewriter.createOrFold( loc, sourceDesc.getStrides()[i], currentStride); strides.push_back(currentStride); - op.getStrides()[i].replaceAllUsesWith(currentStride); + + // Don't replace the value of dropped dimensions. + // Although the new stride will be used in the computation of the final + // offset, there's no value to replace. + if (droppedDims.test(i)) continue; + + op.getStrides()[insertedDims++].replaceAllUsesWith(currentStride); } + assert(insertedDims == subRank && + "Should have populated all the non-reduced strides"); // Offsets. Value offset = sourceDesc.getOffset(); - for (int i = 0; i < subRank; ++i) { + for (int i = 0; i < sourceRank; ++i) { Value logicalOffset; if (subview.isDynamicOffset(i)) { logicalOffset = subview.getDynamicOffset(i); @@ -167,8 +184,8 @@ struct FromHalInterfaceBindingSubspan } }; -// Allocations (and anything else the returns a non-offset identity memref) -// are matched by this pattern. +// Allocations always return a non-offset memref and are matched by this +// pattern. struct FromAllocation : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GetBufferDescriptorOp op, @@ -235,6 +252,70 @@ struct FromAllocation : public OpRewritePattern { } }; +// MemRef globals are always static shaped and reference a non-offset +// buffer. +struct FromGlobal : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(GetBufferDescriptorOp op, + PatternRewriter &rewriter) const override { + auto global = op.getSource().getDefiningOp(); + if (!global) return failure(); + auto memRefType = global.getResult().getType().cast(); + if (!memRefType.getLayout().isIdentity()) { + return rewriter.notifyMatchFailure(op, "not identity allocation"); + } + + auto loc = op.getLoc(); + IndexSet indexSet(loc, rewriter); + + // Replace the op with values: + // base_buffer: The subspan result + // offset: byte offset from subspan divided by element type size + // sizes: static and dynamic sizes from the subspan + // strides: identity strides + int rank = memRefType.getRank(); + + // Compute sizes. + SmallVector sizes; + for (int i = 0; i < rank; ++i) { + assert(!memRefType.isDynamicDim(i) && + "memref.get_global does not support dynamic dims"); + sizes.push_back(rewriter.create( + loc, memRefType.getDimSize(i))); + + // Replace as we go. + op.getSizes()[i].replaceAllUsesWith(sizes.back()); + } + + // Strides (just creates identity strides). + if (rank > 0) { + SmallVector strides; + strides.resize(rank); + strides[rank - 1] = indexSet.get(1); + for (int i = rank - 2; i >= 0; --i) { + strides[i] = rewriter.createOrFold(loc, strides[i + 1], + sizes[i + 1]); + } + for (int i = 0; i < rank; ++i) { + op.getStrides()[i].replaceAllUsesWith(strides[i]); + } + } + + // Offset. + op.getOffset().replaceAllUsesWith(indexSet.get(0)); + + // Base buffer. + op.getBaseBuffer().replaceAllUsesWith( + rewriter + .create( + loc, op.getBaseBuffer().getType(), global.getResult()) + .getResult(0)); + + rewriter.eraseOp(op); + return success(); + } +}; + class ResolveBufferDescriptorsPass : public ResolveBufferDescriptorsBase { public: @@ -246,7 +327,7 @@ class ResolveBufferDescriptorsPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); if (failed(applyPatternsAndFoldGreedily(getOperation(), diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir index d2d38aa14e65..96db626c3bec 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir +++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir @@ -28,7 +28,9 @@ func.func @resolve_subview_rankreducing(%arg0: memref<384x128xf32>, %arg1 : inde // CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index // CHECK-DAG: %[[I0:.*]] = arith.muli %arg1, %[[BASE_STRIDES]]#0 : index // CHECK: %[[I1:.*]] = arith.addi %[[BASE_OFFSET]], %[[I0]] : index - // CHECK: return %[[BASE_BUFFER]], %[[I1]], %[[C64]], %[[BASE_STRIDES]]#0 + // CHECK: %[[I2:.*]] = arith.muli %arg2, %[[BASE_STRIDES]]#1 : index + // CHECK: %[[I3:.*]] = arith.addi %[[I1]], %[[I2]] : index + // CHECK: return %[[BASE_BUFFER]], %[[I3]], %[[C64]], %[[BASE_STRIDES]]#0 %0 = memref.subview %arg0[%arg1, %arg2] [64, 1] [1, 1] : memref<384x128xf32> to memref<64xf32, #map0> %base_buffer, %offset, %size, %stride = vmvx.get_buffer_descriptor %0 : memref<64xf32, #map0> -> !util.buffer, index, index, index return %base_buffer, %offset, %size, %stride : !util.buffer, index, index, index @@ -36,6 +38,37 @@ func.func @resolve_subview_rankreducing(%arg0: memref<384x128xf32>, %arg1 : inde // ----- +// Check that we properly resolve subview with rankreducing when the dropped +// rank is not the last one. +// Orig strides: [%strides#0, %strides#1, %strides#2] +// Sub strides: [1, 1, 1] +// => New strides: [%strides#0, %strides#1, %strides#2] +// Final strides == filterOutReducedDim(new strides, 0) == [%strides#1 , %strides#2] +// +// Orig offset: %offset +// Sub offsets: [%arg1, %arg2, 0] +// => Final offset: %arg1 * %strides#0 + %arg2 * %strides#1 + 0 * %strides#2 + %offset +// +// Final sizes == filterOutReducedDim(subview sizes, 0) == [6, 3] +// +// CHECK-LABEL: @resolve_subview_rankreducing_not_at_the_end +func.func @resolve_subview_rankreducing_not_at_the_end(%arg0: memref<8x16x4xf32>, %arg1 : index, %arg2 : index) -> (!util.buffer, index, index, index, index, index) { + // CHECK-DAG: %[[BASE_BUFFER:.*]], %[[BASE_OFFSET:.*]], %[[BASE_SIZES:.*]]:3, %[[BASE_STRIDES:.*]]:3 = vmvx.get_buffer_descriptor %arg0 + // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index + // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index + // CHECK-DAG: %[[I0:.*]] = arith.muli %arg1, %[[BASE_STRIDES]]#0 : index + // CHECK: %[[I1:.*]] = arith.addi %[[BASE_OFFSET]], %[[I0]] : index + // CHECK: %[[I2:.*]] = arith.muli %arg2, %[[BASE_STRIDES]]#1 : index + // CHECK: %[[I3:.*]] = arith.addi %[[I1]], %[[I2]] : index + // CHECK: return %[[BASE_BUFFER]], %[[I3]], %[[C6]], %[[C3]], %[[BASE_STRIDES]]#1, %[[BASE_STRIDES]]#2 + + %0 = memref.subview %arg0[%arg1, %arg2, 0] [1, 6, 3] [1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, strided<[4,1], offset : ?>> + %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<6x3xf32, strided<[4,1], offset : ?>> -> !util.buffer, index, index, index, index, index + return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index +} + +// ----- + // CHECK-LABEL: @resolve_binding_subspan_zero_offset func.func @resolve_binding_subspan_zero_offset() -> (!util.buffer, index, index, index, index, index) { // CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index @@ -78,3 +111,49 @@ func.func @resolve_binding_subspan_dyn_dims(%arg0 : index, %arg1 : index) -> (!u %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref -> !util.buffer, index, index, index, index, index return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index } + +// ----- + +// CHECK-LABEL: @resolve_alloca_static +func.func @resolve_alloca_static() -> (!util.buffer, index, index, index, index, index) { + // CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index + // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast + // CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]] + %0 = memref.alloca() : memref<512x384xf32> + %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xf32> -> !util.buffer, index, index, index, index, index + return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index +} + +// ----- + +// CHECK-LABEL: @resolve_alloca_dynamic +func.func @resolve_alloca_dynamic(%arg0 : index) -> (!util.buffer, index, index, index, index, index) { + // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast + // CHECK: return %[[CAST]], %[[C0]], %arg0, %[[C384]], %[[C384]], %[[C1]] + %0 = memref.alloca(%arg0) : memref + %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref -> !util.buffer, index, index, index, index, index + return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index +} + +// ----- + +// CHECK-LABEL: @resolve_global +memref.global "private" constant @__constant_2xi32 : memref<512x384xf32> = dense<0.0> + +func.func @resolve_global() -> (!util.buffer, index, index, index, index, index) { + // CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index + // CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast + // CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]] + %0 = memref.get_global @__constant_2xi32 : memref<512x384xf32> + %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<512x384xf32> -> !util.buffer, index, index, index, index, index + return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index +} diff --git a/compiler/src/iree/compiler/InputConversion/Common/BUILD b/compiler/src/iree/compiler/InputConversion/Common/BUILD index 6c66b4d7a2cf..d63466b0f698 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/BUILD +++ b/compiler/src/iree/compiler/InputConversion/Common/BUILD @@ -61,6 +61,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", + "//compiler/src/iree/compiler/Utils", "//llvm-external-projects/iree-dialects:IREEInputDialect", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:FuncDialect", diff --git a/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt index 903b1d2c6c7f..ac7be9af5f0d 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt @@ -65,6 +65,7 @@ iree_cc_library( iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::Util::IR + iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/InputConversion/Common/SanitizeModuleNames.cpp b/compiler/src/iree/compiler/InputConversion/Common/SanitizeModuleNames.cpp index 331d2b61e000..fec390512731 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/SanitizeModuleNames.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/SanitizeModuleNames.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/InputConversion/Common/PassDetail.h" #include "iree/compiler/InputConversion/Common/Passes.h" +#include "iree/compiler/Utils/StringUtils.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" @@ -29,11 +30,8 @@ struct SanitizeModuleNamesPass auto optionalName = moduleOp.getName(); if (!optionalName.has_value()) return; auto name = optionalName.value(); - if (!name.contains('.')) return; - std::string sanitizedName(name); - std::replace(sanitizedName.begin(), sanitizedName.end(), '.', '_'); - moduleOp.setName(sanitizedName); + moduleOp.setName(sanitizeSymbolName(name)); } }; diff --git a/compiler/src/iree/compiler/InputConversion/Common/test/sanitize_module_names.mlir b/compiler/src/iree/compiler/InputConversion/Common/test/sanitize_module_names.mlir index fba2497427dd..3d0db288908e 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/test/sanitize_module_names.mlir +++ b/compiler/src/iree/compiler/InputConversion/Common/test/sanitize_module_names.mlir @@ -15,8 +15,8 @@ builtin.module @u_n_d_e_r_s_c_o_r_e_s {} // ----- -// CHECK-LABEL: module @dollar$$signs$$are$$okay -builtin.module @dollar$$signs$$are$$okay {} +// CHECK-LABEL: module @dollar__signs +builtin.module @dollar$$signs {} // ----- diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp index de602457ab11..a853594af2d2 100644 --- a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp +++ b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp @@ -295,7 +295,7 @@ struct ScatterUpdateConversion : public OpConversionPattern { /*outputs=*/adaptor.operands()[0], indexingMaps, mhlo::getNParallelLoopsAttrs(nloops), [](OpBuilder &b, Location loc, ValueRange args) {}, - mhlo::pruneAttributeList(op)); + linalg::getPrunedAttributeList(op)); // Transform the scatter update computation region // update = a bunch of computation diff --git a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h index c044362ac9c4..25623723a21f 100644 --- a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h +++ b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h @@ -23,6 +23,7 @@ #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Quant/QuantOps.h" @@ -48,6 +49,7 @@ inline void registerMlirDialects(DialectRegistry ®istry) { cf::ControlFlowDialect, bufferization::BufferizationDialect, gpu::GPUDialect, + nvgpu::NVGPUDialect, LLVM::LLVMDialect, linalg::LinalgDialect, math::MathDialect, diff --git a/compiler/src/iree/compiler/Utils/ModuleUtils.cpp b/compiler/src/iree/compiler/Utils/ModuleUtils.cpp index 2b71be276e2c..d867dfcadca1 100644 --- a/compiler/src/iree/compiler/Utils/ModuleUtils.cpp +++ b/compiler/src/iree/compiler/Utils/ModuleUtils.cpp @@ -6,7 +6,9 @@ #include "iree/compiler/Utils/ModuleUtils.h" +#include "iree/compiler/Utils/StringUtils.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Path.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/Parser/Parser.h" @@ -15,6 +17,30 @@ namespace mlir { namespace iree_compiler { +static llvm::Optional findFirstFileLoc(Location baseLoc) { + if (auto loc = baseLoc.dyn_cast()) { + for (auto &childLoc : loc.getLocations()) { + auto childResult = findFirstFileLoc(childLoc); + if (childResult) return childResult; + } + } else if (auto loc = baseLoc.dyn_cast()) { + return loc; + } + return llvm::None; +} + +std::string guessModuleName(mlir::ModuleOp moduleOp, StringRef defaultName) { + std::string moduleName = moduleOp.getName().value_or("").str(); + if (!moduleName.empty()) return moduleName; + auto loc = findFirstFileLoc(moduleOp.getLoc()); + if (loc.has_value()) { + return sanitizeSymbolName( + llvm::sys::path::stem(loc.value().getFilename()).str()); + } else { + return defaultName.str(); + } +} + // Renames |op| within |moduleOp| with a new name that is unique within both // |moduleOp| and |symbolTable|. static void renameWithDisambiguatedName(Operation *op, Operation *moduleOp, diff --git a/compiler/src/iree/compiler/Utils/ModuleUtils.h b/compiler/src/iree/compiler/Utils/ModuleUtils.h index 687a17cb20cd..7a63ffdc8016 100644 --- a/compiler/src/iree/compiler/Utils/ModuleUtils.h +++ b/compiler/src/iree/compiler/Utils/ModuleUtils.h @@ -9,12 +9,18 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" namespace mlir { namespace iree_compiler { +// Guesses the name of the module from the source locations attached unless a +// name is already specified. If no source locations are found then +// |defaultName| is returned. +std::string guessModuleName(mlir::ModuleOp moduleOp, StringRef defaultName); + // Destructively merges |sourceOp| into |targetOp| using |targetBuilder|. // // If a private symbol in |sourceOp| conflicts with another symbol diff --git a/compiler/src/iree/compiler/Utils/StringUtils.cpp b/compiler/src/iree/compiler/Utils/StringUtils.cpp index 94a727a1e33e..bfdbbee179b6 100644 --- a/compiler/src/iree/compiler/Utils/StringUtils.cpp +++ b/compiler/src/iree/compiler/Utils/StringUtils.cpp @@ -6,6 +6,8 @@ #include "iree/compiler/Utils/StringUtils.h" +#include "llvm/ADT/StringRef.h" + namespace mlir { namespace iree_compiler { @@ -25,5 +27,33 @@ std::string replaceAllSubstrs(const std::string &str, const std::string &match, return copy; } +std::string sanitizeSymbolName(StringRef name) { + std::string result; + result.reserve(name.size()); + for (size_t i = 0; i < name.size(); ++i) { + char c = name[i]; + if (!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || c == '_')) { + c = '_'; + } + result.push_back(c); + } + return result; +} + +std::string sanitizeFileName(StringRef name) { + std::string result; + result.reserve(name.size()); + for (size_t i = 0; i < name.size(); ++i) { + char c = name[i]; + if (!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || c == '_' || c == '-' || c == '.')) { + c = '_'; + } + result.push_back(c); + } + return result; +} + } // namespace iree_compiler } // namespace mlir diff --git a/compiler/src/iree/compiler/Utils/StringUtils.h b/compiler/src/iree/compiler/Utils/StringUtils.h index dc378ce7cba6..8f7eddfa1331 100644 --- a/compiler/src/iree/compiler/Utils/StringUtils.h +++ b/compiler/src/iree/compiler/Utils/StringUtils.h @@ -9,6 +9,8 @@ #include +#include "mlir/Support/LLVM.h" + namespace mlir { namespace iree_compiler { @@ -22,6 +24,29 @@ void replaceAllSubstrsInPlace(std::string &str, const std::string &match, std::string replaceAllSubstrs(const std::string &str, const std::string &match, const std::string &substitute); +// Sanitizes a symbol name for compatibility with common targets (C, file +// systems, debug databases, etc). +// +// MLIR identifiers must match this regex: +// (letter|[_]) (letter|digit|[_$.])* +// https://mlir.llvm.org/docs/LangRef/#identifiers-and-keywords +// This is a superset of the names other targets support and as inputs are only +// expected to match the above any place exporting symbol names must use this. +// +// Examples: +// `abc` -> `abc` +// `a.b` -> `a_b` +// `a$-æb` -> `a___b` +std::string sanitizeSymbolName(StringRef name); + +// Sanitizes a file name for compatibility with common file systems. +// +// Examples: +// `abc` -> `abc` +// `a.b` -> `a.b` +// `a$-æb` -> `a_-_b` +std::string sanitizeFileName(StringRef name); + } // namespace iree_compiler } // namespace mlir diff --git a/integrations/tensorflow/WORKSPACE b/integrations/tensorflow/WORKSPACE index 244f5b390b44..18907ffaaafa 100644 --- a/integrations/tensorflow/WORKSPACE +++ b/integrations/tensorflow/WORKSPACE @@ -7,7 +7,7 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") -TENSORFLOW_COMMIT = "60ada7df3b0e5c877049a6b8f02054ce715c7288" +TENSORFLOW_COMMIT = "00fd752b4932b73c8d04eb6e2543f6c7b2ac94e9" git_repository( name = "org_tensorflow", diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td index 0daa0eb68e2c..d3847e1fe5ec 100644 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td +++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td @@ -88,7 +88,7 @@ def IREEInput_GlobalOp : IREEInput_Op<"global", [ (`mutable` $is_mutable^)? $sym_name attr-dict - (`initializer` `(` $initializer^ `)`):(``)? + (`initializer` `(` $initializer^ `)`)? custom($type, $initial_value) }]; diff --git a/integrations/tensorflow/iree_tf_compiler/BUILD b/integrations/tensorflow/iree_tf_compiler/BUILD index 96c8d428f32e..c26ed7ae007f 100644 --- a/integrations/tensorflow/iree_tf_compiler/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/BUILD @@ -46,7 +46,7 @@ cc_binary( "@org_tensorflow//tensorflow/compiler/mlir/tosa:tfl_passes", "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_tf", "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo", - "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo/stablehlo:chlo_ops", + "@stablehlo//:chlo_ops", ], ) diff --git a/integrations/tensorflow/iree_tf_compiler/TF/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/BUILD index adefe547e09b..cb8f1a1a5296 100644 --- a/integrations/tensorflow/iree_tf_compiler/TF/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/TF/BUILD @@ -62,6 +62,6 @@ cc_library( "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:all_passes", "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:chlo_legalize_to_hlo", "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:mhlo_to_mhlo_lowering_patterns", - "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo/stablehlo:chlo_ops", + "@stablehlo//:chlo_ops", ], ) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/layers/layers_test.py b/integrations/tensorflow/test/python/iree_tf_tests/layers/layers_test.py index 4820bb070636..7e31ce0bbeea 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/layers/layers_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/layers/layers_test.py @@ -83,6 +83,14 @@ "ConvLSTM2D": ["strides", "dilation_rate"], } +# Some layers cannot operate on a fully dynamic shape, only dynamic batch size. +LAYERS_WITH_BATCH_ONLY_DYNAMIC_SHAPE = [ + "AveragePooling1D", # uses tf.nn.avg_pool2d with reshape, which is illegal + "BatchNormalization", # failed to materialize conversion + "Conv3D", # tf.Conv3D op illegal + "MaxPool1D", # uses tf.nn.max_pool2d with reshape, which is illegal +] + def get_default_kwargs_values(layer: str) -> Dict[str, Any]: """Gets the default kwargs for a tf.keras.layers layer.""" @@ -501,8 +509,16 @@ def create_layer_unit_test( dynamic_signature = static_signature if FLAGS.dynamic_dims: - dynamic_signature = tf_utils.apply_function(dynamic_signature, - tf_utils.make_dims_dynamic) + + def make_batch_size_dynamic(tensor_spec: tf.TensorSpec) -> tf.TensorSpec: + return tf.TensorSpec([None] + tensor_spec.shape[1:], tensor_spec.dtype) + + if FLAGS.layer in LAYERS_WITH_BATCH_ONLY_DYNAMIC_SHAPE: + dynamic_signature = tf_utils.apply_function(dynamic_signature, + make_batch_size_dynamic) + else: + dynamic_signature = tf_utils.apply_function(dynamic_signature, + tf_utils.make_dims_dynamic) if len(static_signature) > 1: static_signature = [static_signature] diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD index c98945e97787..8e9c22f28a3e 100644 --- a/llvm-external-projects/iree-dialects/BUILD +++ b/llvm-external-projects/iree-dialects/BUILD @@ -375,6 +375,7 @@ cc_library( ":IREELinalgExtPasses", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineUtils", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:ArithmeticUtils", "@llvm-project//mlir:AsyncDialect", @@ -391,12 +392,14 @@ cc_library( "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorUtils", "@llvm-project//mlir:TilingInterface", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorTransforms", ], ) diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td index 0daa0eb68e2c..d3847e1fe5ec 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td @@ -88,7 +88,7 @@ def IREEInput_GlobalOp : IREEInput_Op<"global", [ (`mutable` $is_mutable^)? $sym_name attr-dict - (`initializer` `(` $initializer^ `)`):(``)? + (`initializer` `(` $initializer^ `)`)? custom($type, $initial_value) }]; diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h index 417a582b4054..bac5ee890a29 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h @@ -8,6 +8,8 @@ #define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_ #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -16,6 +18,7 @@ namespace IREE { namespace LinalgExt { #define GEN_PASS_CLASSES + #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h.inc" // IWYU pragma: keep } // namespace LinalgExt diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h index 503e2101a9d1..f261bff99762 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h @@ -43,6 +43,70 @@ std::unique_ptr> createTopkSplitReductionPass(); // Marker used as attribute the depth of the split reduction transformations. const StringLiteral kSplitReductionDepthMarker = "__split_reduction_depth__"; +//===---------------------------------------------------------------------===// +// Codegen Strategy passes that are moved into IREE. +//===---------------------------------------------------------------------===// +/// Create a LinalgStrategyTileAndFusePass. +std::unique_ptr> +createLinalgStrategyTileAndFusePass( + StringRef opName = "", const linalg::LinalgTilingAndFusionOptions &opt = {}, + const linalg::LinalgTransformationFilter &filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyTilePass. +std::unique_ptr> createLinalgStrategyTilePass( + StringRef opName = "", + const linalg::LinalgTilingOptions &opt = linalg::LinalgTilingOptions(), + const linalg::LinalgTransformationFilter &filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyPadPass. +std::unique_ptr> createLinalgStrategyPadPass( + StringRef opName = "", + const linalg::LinalgPaddingOptions &opt = linalg::LinalgPaddingOptions(), + const linalg::LinalgTransformationFilter &filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyDecomposePass. +// TODO: if/when we need finer control add an `opName` parameter. +std::unique_ptr> createLinalgStrategyDecomposePass( + const linalg::LinalgTransformationFilter &filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyPeelPass. +std::unique_ptr> createLinalgStrategyPeelPass( + StringRef opName = "", + const linalg::LinalgPeelOptions &opt = linalg::LinalgPeelOptions(), + const linalg::LinalgTransformationFilter &filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyVectorizePass. +std::unique_ptr> createLinalgStrategyVectorizePass( + StringRef opName = "", + linalg::LinalgVectorizationOptions opt = + linalg::LinalgVectorizationOptions(), + const linalg::LinalgTransformationFilter &filter = + linalg::LinalgTransformationFilter(), + bool padVectorize = false); + +/// Create a LinalgStrategyEnablePass. +std::unique_ptr> createLinalgStrategyEnablePass( + linalg::LinalgEnablingOptions opt = linalg::LinalgEnablingOptions(), + const linalg::LinalgTransformationFilter &filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyLowerVectorsPass. +std::unique_ptr> +createLinalgStrategyLowerVectorsPass( + linalg::LinalgVectorLoweringOptions opt = + linalg::LinalgVectorLoweringOptions(), + const linalg::LinalgTransformationFilter &filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyRemoveMarkersPass. +std::unique_ptr> +createLinalgStrategyRemoveMarkersPass(); + void registerPasses(); } // namespace LinalgExt diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td index db01f82dcd40..aba76b8f5703 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td @@ -59,4 +59,124 @@ def TopkSplitReduction: ]; } +//===---------------------------------------------------------------------====// +// Codegen Strategy passes moved into IREE +// TODO: Deprecate all this. +//===---------------------------------------------------------------------====// + +def LinalgStrategyTileAndFusePass + : Pass<"iree-linalg-strategy-tile-and-fuse-pass", "func::FuncOp"> { + let summary = "Configurable pass to apply pattern-based tiling and fusion."; + let constructor = "createLinalgStrategyTileAndFusePass()"; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + "Which linalg op within the func is the anchor to latch on.">, + ]; +} + +def LinalgStrategyTilePass + : Pass<"iree-linalg-strategy-tile-pass", "func::FuncOp"> { + let summary = "Configurable pass to apply pattern-based linalg tiling."; + let constructor = "createLinalgStrategyTilePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + "Which linalg op within the func is the anchor to latch on.">, + ]; +} + +def LinalgStrategyPadPass + : Pass<"iree-linalg-strategy-pad-pass", "func::FuncOp"> { + let summary = "Configurable pass to apply padding and hoisting."; + let constructor = "createLinalgStrategyPadPass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + "Which linalg op within the func is the anchor to latch on.">, + ]; +} + +// TODO: if/when we need finer control add an anchorOp option. +def LinalgStrategyDecomposePass + : Pass<"iree-linalg-strategy-decompose-pass", "func::FuncOp"> { + let summary = "Configurable pass to apply pattern-based generalization."; + let constructor = "createLinalgStrategyDecomposePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + ]; +} + +def LinalgStrategyPeelPass + : Pass<"iree-linalg-strategy-peel-pass", "func::FuncOp"> { + let summary = "Configurable pass to apply pattern-based linalg peeling."; + let constructor = "createLinalgStrategyPeelPass()"; + let dependentDialects = [ + "linalg::LinalgDialect", + "scf::SCFDialect" + ]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + "Which linalg op within the func is the anchor to latch on.">, + ]; +} + +def LinalgStrategyVectorizePass + : Pass<"iree-linalg-strategy-vectorize-pass", "func::FuncOp"> { + let summary = "Configurable pass to apply pattern-based linalg vectorization."; + let constructor = "createLinalgStrategyVectorizePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + "Which linalg op within the func is the anchor to latch on.">, + Option<"vectorizePadding", "vectorize-padding", "bool", "false", + "Enable vectorization of padding ops.">, + ]; +} + +def LinalgStrategyEnablePass + : Pass<"iree-linalg-strategy-enable-pass", "func::FuncOp"> { + let summary = "Configurable pass to enable the application of other " + "pattern-based linalg passes."; + let constructor = "createLinalgStrategyEnablePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + ]; +} + +def LinalgStrategyLowerVectorsPass + : Pass<"iree-linalg-strategy-lower-vectors-pass", "func::FuncOp"> { + let summary = "Configurable pass to lower vector operations."; + let constructor = "createLinalgStrategyLowerVectorsPass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + ]; +} + +def LinalgStrategyRemoveMarkersPass + : Pass<"iree-linalg-strategy-remove-markers-pass", "func::FuncOp"> { + let summary = "Cleanup pass that drops markers."; + let constructor = "createLinalgStrategyRemoveMarkersPass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + ]; +} + #endif // IREE_DIALECT_LINALGEXT_PASSES diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h new file mode 100644 index 000000000000..d7dcfc9c03f5 --- /dev/null +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h @@ -0,0 +1,287 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_CODEGENSTRATEGY_H_ +#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_CODEGENSTRATEGY_H_ + +#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" +#include "mlir/Pass/PassManager.h" + +#include + +//===----------------------------------------------------------------------===// +// Strategies moved from upstream MLIR as IREE still heavily relies on patterns +// that compose through filters. +// TODO: Deprecate everything below. +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace LinalgExt { + +/// Abstract Transformation class applied in a sequence that also handles state +/// through markers. +struct Transformation { + explicit Transformation(linalg::LinalgTransformationFilter::FilterFunction f) + : filter(std::move(f)) {} + virtual ~Transformation() = default; + virtual void + addToPassPipeline(OpPassManager &pm, + linalg::LinalgTransformationFilter m) const = 0; + linalg::LinalgTransformationFilter::FilterFunction filter = nullptr; +}; + +/// Represent one application of LinalgStrategyTileAndFusePass. +struct TileAndFuse : public Transformation { + TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr) + : Transformation(std::move(f)), opName(name), + options(std::move(options)) {} + + void addToPassPipeline(OpPassManager &pm, + linalg::LinalgTransformationFilter m) const override { + pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, m)); + } + +private: + std::string opName; + linalg::LinalgTilingAndFusionOptions options; +}; + +/// Represent one application of LinalgStrategyTilePass. +struct Tile : public Transformation { + Tile(StringRef name, linalg::LinalgTilingOptions options, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr) + : Transformation(std::move(f)), opName(name), + options(std::move(options)) {} + + void addToPassPipeline(OpPassManager &pm, + linalg::LinalgTransformationFilter m) const override { + pm.addPass(createLinalgStrategyTilePass(opName, options, m)); + } + +private: + std::string opName; + linalg::LinalgTilingOptions options; +}; + +/// Represent one application of LinalgStrategyPadPass. +struct Pad : public Transformation { + Pad(StringRef name, linalg::LinalgPaddingOptions options, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr) + : Transformation(std::move(f)), opName(name), + options(std::move(options)) {} + + void addToPassPipeline(OpPassManager &pm, + linalg::LinalgTransformationFilter m) const override { + pm.addPass(createLinalgStrategyPadPass(opName, options, m)); + } + +private: + std::string opName; + linalg::LinalgPaddingOptions options; +}; + +/// Represent one application of createLinalgStrategyDecomposePass. +struct Decompose : public Transformation { + explicit Decompose( + linalg::LinalgTransformationFilter::FilterFunction f = nullptr) + : Transformation(std::move(f)) {} + + void addToPassPipeline(OpPassManager &pm, + linalg::LinalgTransformationFilter m) const override { + pm.addPass(createLinalgStrategyDecomposePass(m)); + } +}; + +/// Represent one application of createLinalgStrategyPeelPass. +struct Peel : public Transformation { + explicit Peel(linalg::LinalgPeelOptions options, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr) + : Transformation(std::move(f)), options(options) {} + + Peel(StringRef name, linalg::LinalgPeelOptions options, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr) + : Transformation(std::move(f)), opName(name), options(options) {} + + void addToPassPipeline(OpPassManager &pm, + linalg::LinalgTransformationFilter m) const override { + pm.addPass(createLinalgStrategyPeelPass(opName, options, m)); + } + +private: + std::string opName; + linalg::LinalgPeelOptions options; +}; + +/// Represent one application of createLinalgStrategyVectorizePass. +struct Vectorize : public Transformation { + explicit Vectorize( + linalg::LinalgVectorizationOptions options, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr, + bool padVectorize = false) + : Transformation(std::move(f)), options(options), + vectorizePadding(padVectorize) {} + + Vectorize(StringRef name, linalg::LinalgVectorizationOptions options, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr, + bool padVectorize = false) + : Transformation(std::move(f)), opName(name), options(options), + vectorizePadding(padVectorize) {} + + void addToPassPipeline(OpPassManager &pm, + linalg::LinalgTransformationFilter m) const override { + pm.addPass(createLinalgStrategyVectorizePass(opName, options, m, + vectorizePadding)); + } + +private: + std::string opName; + linalg::LinalgVectorizationOptions options; + bool vectorizePadding; +}; + +/// Represent one application of createLinalgStrategyLowerVectorsPass. +struct VectorLowering : public Transformation { + explicit VectorLowering( + linalg::LinalgVectorLoweringOptions options, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr) + : Transformation(std::move(f)), options(options) {} + + void addToPassPipeline(OpPassManager &pm, + linalg::LinalgTransformationFilter m) const override { + pm.addPass(createLinalgStrategyLowerVectorsPass(options, m)); + } + +private: + linalg::LinalgVectorLoweringOptions options; +}; + +/// Codegen strategy controls how a Linalg op is progressively lowered. +struct CodegenStrategy { + /// Append a pattern to tile the Op `opName` and fuse its producers with + /// tiling and fusion `options`. + CodegenStrategy &tileAndFuse( + StringRef opName, const linalg::LinalgTilingAndFusionOptions &options, + const linalg::LinalgTransformationFilter::FilterFunction &f = nullptr) { + transformationSequence.emplace_back( + std::make_unique(opName, options, f)); + return *this; + } + /// Conditionally append a pattern to tile the Op `opName` and fuse its + /// producers with tiling and fusion `options`. + CodegenStrategy &tileAndFuseIf( + bool b, StringRef opName, linalg::LinalgTilingAndFusionOptions options, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr) { + return b ? tileAndFuse(opName, std::move(options), std::move(f)) : *this; + } + /// Append a pattern to add a level of tiling for Op `opName` with tiling + /// `options`. + CodegenStrategy & + tile(StringRef opName, const linalg::LinalgTilingOptions &options, + const linalg::LinalgTransformationFilter::FilterFunction &f = nullptr) { + transformationSequence.emplace_back( + std::make_unique(opName, options, f)); + return *this; + } + /// Conditionally append a pattern to add a level of tiling for + /// `LinalgOpType` with tiling `options`. + CodegenStrategy & + tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr) { + return b ? tile(opName, std::move(options), std::move(f)) : *this; + } + /// Append a pattern to pad and hoist the operands of Op `opName` with padding + /// `options`. + CodegenStrategy & + pad(StringRef opName, const linalg::LinalgPaddingOptions &options, + const linalg::LinalgTransformationFilter::FilterFunction &f = nullptr) { + transformationSequence.emplace_back( + std::make_unique(opName, options, f)); + return *this; + } + /// Conditionally append a pattern to pad and hoist the operands of Op + /// `opName` with padding `options`. + CodegenStrategy & + padIf(bool b, StringRef opName, linalg::LinalgPaddingOptions options, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr) { + return b ? pad(opName, std::move(options), std::move(f)) : *this; + } + /// Append patterns to decompose convolutions. + CodegenStrategy &decompose( + const linalg::LinalgTransformationFilter::FilterFunction &f = nullptr) { + transformationSequence.emplace_back(std::make_unique(f)); + return *this; + } + /// Conditionally append patterns to decompose convolutions. + CodegenStrategy & + decomposeIf(bool b, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr) { + return b ? decompose(std::move(f)) : *this; + } + /// Append a pattern to peel 'LinalgOpType'. + CodegenStrategy & + peel(StringRef opName, const linalg::LinalgPeelOptions &options, + const linalg::LinalgTransformationFilter::FilterFunction &f = nullptr) { + transformationSequence.emplace_back( + std::make_unique(opName, options, f)); + return *this; + } + /// Conditionally append a pattern to peel 'LinalgOpType'. + CodegenStrategy & + peelIf(bool b, StringRef opName, const linalg::LinalgPeelOptions &options, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr) { + return b ? peel(opName, options, std::move(f)) : *this; + } + /// Append a pattern to rewrite `LinalgOpType` as a vector operation. + CodegenStrategy &vectorize( + StringRef opName, + const linalg::LinalgTransformationFilter::FilterFunction &f = nullptr, + bool vectorizePadding = false) { + transformationSequence.emplace_back(std::make_unique( + opName, linalg::LinalgVectorizationOptions(), f, vectorizePadding)); + return *this; + } + /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector + /// operation. + CodegenStrategy & + vectorizeIf(bool b, StringRef opName, + linalg::LinalgTransformationFilter::FilterFunction f = nullptr, + bool vectorizePadding = false) { + return b ? vectorize(opName, std::move(f), vectorizePadding) : *this; + } + /// Append a pattern to lower all vector operations. + CodegenStrategy &vectorLowering(linalg::LinalgVectorLoweringOptions options) { + transformationSequence.emplace_back( + std::make_unique(options)); + return *this; + } + /// Configure the post staged-patterns global enabling passes options. + CodegenStrategy & + setVectorTransferToSCFOptions(linalg::LinalgEnablingOptions options) { + linalgEnablingOptions = options; + return *this; + } + + /// Apply the transformation patterns in sequence with cleanup + /// transformations interleaved. + void configurePassPipeline(OpPassManager &pm, MLIRContext *context, + bool addEnablePass = true) const; + +private: + LogicalResult postPatternTransforms(Operation *func) const; + + linalg::LinalgEnablingOptions linalgEnablingOptions; + SmallVector, 4> transformationSequence; +}; + +} // namespace LinalgExt +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_CODEGENSTRATEGY_H_ diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt index 8544b402d3eb..85f46a5654e9 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(IREELinalgExtTransforms + CodegenStrategy.cpp ForeachThreadToAsync.cpp ForeachThreadToSequentialFor.cpp Fusion.cpp diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CodegenStrategy.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CodegenStrategy.cpp new file mode 100644 index 000000000000..8ef22236be1a --- /dev/null +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CodegenStrategy.cpp @@ -0,0 +1,46 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h" +#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" +#include "mlir/Pass/PassManager.h" + +using namespace mlir; + +#define DEBUG_TYPE "linalg-codegen-strategy" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace LinalgExt { + +void CodegenStrategy::configurePassPipeline(OpPassManager &pm, + MLIRContext *context, + bool addEnablePass) const { + for (unsigned stepCount = 0, e = transformationSequence.size(); stepCount < e; + ++stepCount) { + const std::unique_ptr &t = + transformationSequence[stepCount]; + std::string currentStr = std::to_string(stepCount); + auto currentState = StringAttr::get(context, currentStr); + std::string nextStr = std::to_string(stepCount + 1); + auto nextState = StringAttr::get(context, nextStr); + auto filter = (currentState.str() == std::to_string(0)) + ? linalg::LinalgTransformationFilter( + t->filter, ArrayRef{}, nextState) + : linalg::LinalgTransformationFilter( + t->filter, currentState, nextState); + t->addToPassPipeline(pm, filter); + if (addEnablePass) + pm.addPass(createLinalgStrategyEnablePass(linalgEnablingOptions)); + } + pm.addPass(createLinalgStrategyRemoveMarkersPass()); +} + +} // namespace LinalgExt +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp index 5e42e3b942d8..9bc3b10d951e 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp @@ -7,8 +7,18 @@ #include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h" #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" +#include "mlir/Transforms/Passes.h" using namespace mlir; @@ -77,6 +87,426 @@ LinalgVectorizationPattern::matchAndRewrite(linalg::LinalgOp linalgOp, return vectorize(rewriter, linalgOp); } +namespace { + +/// Configurable pass to apply pattern-based tiling and fusion. +struct LinalgStrategyTileAndFusePass + : public LinalgStrategyTileAndFusePassBase { + + LinalgStrategyTileAndFusePass() = default; + + LinalgStrategyTileAndFusePass(StringRef opName, + linalg::LinalgTilingAndFusionOptions opt, + linalg::LinalgTransformationFilter filt) + : options(std::move(opt)), filter(std::move(filt)) { + this->anchorOpName.setValue(opName.str()); + } + + void runOnOperation() override { + auto funcOp = getOperation(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + + RewritePatternSet tilingAndFusionPattern(funcOp.getContext()); + if (!anchorOpName.empty()) { + tilingAndFusionPattern.add( + anchorOpName, funcOp.getContext(), options, filter); + } else { + tilingAndFusionPattern.add( + funcOp.getContext(), options, filter); + } + // Search the root operation using bottom up traversal. + GreedyRewriteConfig config; + config.useTopDownTraversal = false; + (void)applyPatternsAndFoldGreedily( + funcOp, std::move(tilingAndFusionPattern), config); + } + + linalg::LinalgTilingAndFusionOptions options; + linalg::LinalgTransformationFilter filter; +}; + +/// Configurable pass to apply pattern-based linalg tiling. +struct LinalgStrategyTilePass + : public LinalgStrategyTilePassBase { + + LinalgStrategyTilePass() = default; + + LinalgStrategyTilePass(StringRef opName, linalg::LinalgTilingOptions opt, + linalg::LinalgTransformationFilter filt) + : options(std::move(opt)), filter(std::move(filt)) { + this->anchorOpName.setValue(opName.str()); + } + + void runOnOperation() override { + auto funcOp = getOperation(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet tilingPattern(ctx); + if (!anchorOpName.empty()) + tilingPattern.add(anchorOpName, ctx, options, + filter); + else + tilingPattern.add(ctx, options, filter); + if (anchorOpName == tensor::PadOp::getOperationName()) + populatePadTensorTilingPatterns(tilingPattern, options); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); + } + + linalg::LinalgTilingOptions options; + linalg::LinalgTransformationFilter filter; +}; + +/// Configurable pass to apply hoisting and padding. +struct LinalgStrategyPadPass + : public LinalgStrategyPadPassBase { + + LinalgStrategyPadPass() = default; + + LinalgStrategyPadPass(StringRef opName, linalg::LinalgPaddingOptions opt, + linalg::LinalgTransformationFilter filt) + : options(std::move(opt)), filter(std::move(filt)) { + this->anchorOpName.setValue(opName.str()); + } + + void runOnOperation() override { + auto funcOp = getOperation(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + + RewritePatternSet paddingPattern(funcOp.getContext()); + if (!anchorOpName.empty()) { + paddingPattern.add( + anchorOpName, funcOp.getContext(), options, filter); + } else { + paddingPattern.add(funcOp.getContext(), + options, filter); + } + (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern)); + } + + linalg::LinalgPaddingOptions options; + linalg::LinalgTransformationFilter filter; +}; + +/// Configurable pass to apply lowering of coarser-grained named linalg ops into +/// finer-grained named versions. +struct LinalgStrategyDecomposePass + : public LinalgStrategyDecomposePassBase { + + LinalgStrategyDecomposePass() = default; + + LinalgStrategyDecomposePass(linalg::LinalgTransformationFilter filter) + : filter(std::move(filter)) {} + + void runOnOperation() override { + auto funcOp = getOperation(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + RewritePatternSet decompositionPattern(funcOp.getContext()); + populateDecomposeConvolutionPatterns(decompositionPattern, filter); + if (failed(applyPatternsAndFoldGreedily(funcOp, + std::move(decompositionPattern)))) + signalPassFailure(); + } + + linalg::LinalgTransformationFilter filter; +}; + +/// Configurable pass to apply pattern-based linalg peeling. +struct LinalgStrategyPeelPass + : public LinalgStrategyPeelPassBase { + + LinalgStrategyPeelPass() = default; + + LinalgStrategyPeelPass(StringRef opName, linalg::LinalgPeelOptions opt, + linalg::LinalgTransformationFilter filt) + : options(std::move(opt)), filter(std::move(filt)) { + this->anchorOpName.setValue(opName.str()); + } + + void runOnOperation() override { + auto funcOp = getOperation(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + + RewritePatternSet peelingPatterns(funcOp.getContext()); + if (!anchorOpName.empty()) { + peelingPatterns.add( + anchorOpName, funcOp.getContext(), options, filter); + } else { + peelingPatterns.add(funcOp.getContext(), + filter, options); + } + if (failed( + applyPatternsAndFoldGreedily(funcOp, std::move(peelingPatterns)))) + return signalPassFailure(); + } + + linalg::LinalgPeelOptions options; + linalg::LinalgTransformationFilter filter; +}; + +/// Configurable pass to apply pattern-based linalg vectorization. +struct LinalgStrategyVectorizePass + : public LinalgStrategyVectorizePassBase { + + LinalgStrategyVectorizePass() = default; + + LinalgStrategyVectorizePass(StringRef opName, + linalg::LinalgVectorizationOptions opt, + linalg::LinalgTransformationFilter filt, + bool padVectorize = false) + : options(opt), filter(std::move(filt)) { + this->anchorOpName.setValue(opName.str()); + this->vectorizePadding.setValue(padVectorize); + } + + void runOnOperation() override { + auto funcOp = getOperation(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + + RewritePatternSet vectorizationPatterns(funcOp.getContext()); + if (!anchorOpName.empty()) { + vectorizationPatterns.add( + anchorOpName, funcOp.getContext(), options, filter); + } else { + vectorizationPatterns.add(funcOp.getContext(), + filter, options); + } + vector::populateVectorTransferPermutationMapLoweringPatterns( + vectorizationPatterns); + vector::populateVectorReductionToContractPatterns(vectorizationPatterns); + vectorizationPatterns.add( + funcOp.getContext(), /*benefit=*/2); + vector::TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns, + funcOp.getContext()); + vector::TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns, + funcOp.getContext()); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(vectorizationPatterns)); + + // Apply the pad tensor op vectorization separately to avoid running the + // GenericPadOpVectorizationPattern too early. + // TODO: Improve once we have better infrastructure to control pattern + // application. + if (vectorizePadding) { + RewritePatternSet patterns(funcOp.getContext()); + linalg::populatePadOpVectorizationPatterns(patterns); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + } + + linalg::LinalgVectorizationOptions options; + linalg::LinalgTransformationFilter filter; +}; + +/// Configurable pass to enable the application of other pattern-based linalg +/// passes. +struct LinalgStrategyEnablePass + : public LinalgStrategyEnablePassBase { + + LinalgStrategyEnablePass(linalg::LinalgEnablingOptions opt, + linalg::LinalgTransformationFilter filt) + : options(opt), filter(std::move(filt)) {} + + void runOnOperation() override { + auto funcOp = getOperation(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + + MLIRContext *context = funcOp.getContext(); + RewritePatternSet patterns = + linalg::getLinalgTilingCanonicalizationPatterns(context); + scf::populateSCFForLoopCanonicalizationPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) + return signalPassFailure(); + + if (options.licm) { + funcOp->walk([&](LoopLikeOpInterface loopLike) { + moveLoopInvariantCode(loopLike); + }); + } + + // Gathers all innermost loops through a post order pruned walk. + funcOp.walk([](Operation *op) { + if (auto forOp = dyn_cast(op)) + (void)promoteIfSingleIteration(forOp); + else if (auto forOp = dyn_cast(op)) + (void)promoteIfSingleIteration(forOp); + }); + if (options.hoistRedundantVectorTransfers) + linalg::hoistRedundantVectorTransfers(funcOp); + + if (options.hoistRedundantVectorTransfersOnTensor) + linalg::hoistRedundantVectorTransfersOnTensor(funcOp); + + // Run CSE to cleanup after canonicalization. + OpPassManager dynamicPM("func.func"); + dynamicPM.addPass(createCSEPass()); + if (failed(runPipeline(dynamicPM, funcOp))) + return signalPassFailure(); + } + + linalg::LinalgEnablingOptions options; + linalg::LinalgTransformationFilter filter; +}; + +/// Configurable pass to lower vector operations. +struct LinalgStrategyLowerVectorsPass + : public LinalgStrategyLowerVectorsPassBase< + LinalgStrategyLowerVectorsPass> { + + LinalgStrategyLowerVectorsPass(linalg::LinalgVectorLoweringOptions opt, + linalg::LinalgTransformationFilter filt) + : options(opt), filter(std::move(filt)) {} + + void runOnOperation() override { + auto funcOp = getOperation(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + + MLIRContext *context = funcOp.getContext(); + RewritePatternSet patterns(context); + vector::populateVectorToVectorCanonicalizationPatterns(patterns); + // In a progressive lowering of vectors, this would be the 1st step. + if (options.contractionLowering) { + patterns.add( + options.vectorTransformOptions, context); + vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + } + // In a progressive lowering of vectors, this would be the 2nd step. + if (options.multiReductionLowering) { + vector::populateVectorMultiReductionLoweringPatterns( + patterns, + options.vectorTransformOptions.vectorMultiReductionLowering); + } + // In a progressive lowering of vectors, this would be the 3rd step. + if (options.transferPartialRewrite) { + patterns.add( + context, options.vectorTransformOptions); + } + // In a progressive lowering of vectors, this would be the 4th step. + if (options.transferLowering) { + vector::populateVectorTransferLoweringPatterns(patterns, + options.maxTransferRank); + } + // In a progressive lowering of vectors, this would be the 5th step. + if (options.transferToSCFConversion) { + populateVectorToSCFConversionPatterns( + patterns, options.vectorTransferToSCFOptions.setTargetRank( + options.maxTransferRank)); + } + // In a progressive lowering of vectors, this would be the 6th step. + if (options.shapeCastLowering) { + vector::populateVectorShapeCastLoweringPatterns(patterns); + } + // In a progressive lowering of vectors, this would be the 7th step. + if (options.transposeLowering) { + vector::populateVectorTransposeLoweringPatterns( + patterns, options.vectorTransformOptions); + if (options.avx2Lowering) + x86vector::avx2::populateSpecializedTransposeLoweringPatterns( + patterns, options.avx2LoweringOptions, /*benefit=*/10); + } + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + + linalg::LinalgVectorLoweringOptions options; + linalg::LinalgTransformationFilter filter; +}; + +/// Configurable pass to lower vector operations. +struct LinalgStrategyRemoveMarkersPass + : public LinalgStrategyRemoveMarkersPassBase< + LinalgStrategyRemoveMarkersPass> { + + void runOnOperation() override { + auto funcOp = getOperation(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + funcOp.walk([](linalg::LinalgOp op) { + op->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker); + }); + } +}; +} // namespace + +/// Create a LinalgStrategyTileAndFusePass. +std::unique_ptr> +createLinalgStrategyTileAndFusePass( + StringRef opName, const linalg::LinalgTilingAndFusionOptions &options, + const linalg::LinalgTransformationFilter &filter) { + return std::make_unique(opName, options, + filter); +} + +/// Create a LinalgStrategyTilePass. +std::unique_ptr> +createLinalgStrategyTilePass(StringRef opName, + const linalg::LinalgTilingOptions &opt, + const linalg::LinalgTransformationFilter &filter) { + return std::make_unique(opName, opt, filter); +} + +/// Create a LinalgStrategyPadPass. +std::unique_ptr> +createLinalgStrategyPadPass(StringRef opName, + const linalg::LinalgPaddingOptions &opt, + const linalg::LinalgTransformationFilter &filter) { + return std::make_unique(opName, opt, filter); +} + +/// Create a LinalgStrategyDecomposePass. +// TODO: if/when we need finer control add an `opName` parameter. +std::unique_ptr> createLinalgStrategyDecomposePass( + const linalg::LinalgTransformationFilter &filter) { + return std::make_unique(filter); +} + +/// Create a LinalgStrategyPeelPass. +std::unique_ptr> +createLinalgStrategyPeelPass(StringRef opName, + const linalg::LinalgPeelOptions &opt, + const linalg::LinalgTransformationFilter &filter) { + return std::make_unique(opName, opt, filter); +} + +/// Create a LinalgStrategyVectorizePass. +std::unique_ptr> createLinalgStrategyVectorizePass( + StringRef opName, linalg::LinalgVectorizationOptions opt, + const linalg::LinalgTransformationFilter &filter, bool padVectorize) { + return std::make_unique(opName, opt, filter, + padVectorize); +} + +/// Create a LinalgStrategyEnablePass. +std::unique_ptr> createLinalgStrategyEnablePass( + linalg::LinalgEnablingOptions opt, + const linalg::LinalgTransformationFilter &filter) { + return std::make_unique(opt, filter); +} + +/// Create a LinalgStrategyLowerVectorsPass. +std::unique_ptr> +createLinalgStrategyLowerVectorsPass( + linalg::LinalgVectorLoweringOptions opt, + const linalg::LinalgTransformationFilter &filter) { + return std::make_unique(opt, filter); +} + +/// Create a LinalgStrategyRemoveMarkersPass. +std::unique_ptr> +createLinalgStrategyRemoveMarkersPass() { + return std::make_unique(); +} + } // namespace LinalgExt } // namespace IREE } // namespace iree_compiler diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/convert_to_loops.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/convert_to_loops.mlir index c5062cb282ba..4ddad7ce56fd 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/convert_to_loops.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/convert_to_loops.mlir @@ -371,8 +371,7 @@ func.func @fft_1D(%real: memref<16xf32>, %imag: memref<16xf32>) { outs(%real, %imag: memref<16xf32>, memref<16xf32>) return } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0) -> (d0)> // CHECK: func.func @fft_1D // CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]] // CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]] @@ -429,8 +428,7 @@ func.func @fft_2D(%real: memref, %imag: memref) { outs(%real, %imag: memref, memref) return } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK: func.func @fft_2D( // CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]] // CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]] @@ -464,7 +462,6 @@ func.func @fft_2D_coef_buf(%real: memref, %imag: memref, outs(%real, %imag: memref, memref) return } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK: func.func @fft_2D_coef_buf diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir index 8db1a3ed04f0..2cf6bf7235c7 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir @@ -597,8 +597,7 @@ func.func @fft_1d_stage_5_memref(%arg0: memref<1024xf32>, %arg1: memref<1024xf32 outs(%arg0, %arg1: memref<1024xf32>, memref<1024xf32>) return } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> // CHECK: func.func @fft_1d_stage_5_memref( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] @@ -610,12 +609,12 @@ func.func @fft_1d_stage_5_memref(%arg0: memref<1024xf32>, %arg1: memref<1024xf32 // CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index // CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C1024]] step %[[C32]] { // CHECK: %[[SZ:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C32]], %[[C1024]]] -// CHECK: %[[SUB1:.+]] = memref.subview %[[ARG0]][%[[I]]] [%[[SZ]]] [1] : memref<1024xf32> to memref -// CHECK: %[[SUB2:.+]] = memref.subview %[[ARG1]][%[[I]]] [%[[SZ]]] [1] : memref<1024xf32> to memref +// CHECK: %[[SUB1:.+]] = memref.subview %[[ARG0]][%[[I]]] [%[[SZ]]] [1] : memref<1024xf32> to memref> +// CHECK: %[[SUB2:.+]] = memref.subview %[[ARG1]][%[[I]]] [%[[SZ]]] [1] : memref<1024xf32> to memref> // CHECK: iree_linalg_ext.fft // CHECK-SAME: {__internal_linalg_transform__ = "tiling_1d_stage5_fft_output"} // CHECK-SAME: ins(%[[C5]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, memref<16xf32>, memref<16xf32>) -// CHECK-SAME: outs(%[[SUB1]], %[[SUB2]] : memref, memref) +// CHECK-SAME: outs(%[[SUB1]], %[[SUB2]] : memref>, memref>) // ----- @@ -628,7 +627,6 @@ func.func @reverse_memref(%arg0: memref, %arg1: memref) { return } // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)> // CHECK: func.func @reverse_memref( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] @@ -778,8 +776,7 @@ func.func @scan_2d_memref(%0: memref<16x32xi32>, %1: memref<16x32xi32>) { } return } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)> +// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> // CHECK: func.func @scan_2d_memref( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] @@ -858,9 +855,7 @@ func.func @topk_tile_memref(%input_values: memref, %input_indices: memr return } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 3 + s0 + d1)> +// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> // CHECK-LABEL: func.func @topk_tile_memref // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir index a191979ef9ee..3dc056c0e3a4 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir @@ -30,7 +30,7 @@ iree_linalg_transform.sequence { // EXPAND-NOT: expert apply // EXPAND: %[[OP:.*]] = match @pdl_target // EXPAND: %[[HANDLE:.*]], %{{.*}}:3 = tile %[[OP]] {sizes = [4, 4, 4]} - // EXPAND: %[[HANDLE2:.*]] = vectorize %[[HANDLE]] {vectorize_padding = true} + // EXPAND: %[[HANDLE2:.*]] = vectorize %[[HANDLE]] vectorize_padding // EXPAND: bufferize // EXPAND: lower_vectors {multireduction_lowering = "innerreduce"} // EXPAND: lower_to_llvm @@ -114,7 +114,8 @@ iree_linalg_transform.sequence { // EXPAND: %[[OP:.*]] = match @pdl_target2 // EXPAND: %[[HANDLE:.*]], %{{.*}}:3 = tile %[[OP]] {sizes = [32, 8, 8]} // EXPAND: %[[HANDLE2:.*]], %{{.*}}:3 = tile %[[HANDLE]] {sizes = [4, 4, 4]} - // EXPAND: %[[HANDLE3:.*]] = vectorize %[[HANDLE2]] {vectorize_padding = false} + // EXPAND: %[[HANDLE3:.*]] = vectorize %[[HANDLE2]] + // EXPAND-NOT: vectorize_padding // EXPAND: bufferize // EXPAND: lower_vectors {multireduction_lowering = "innerparallel"} // EXPAND: lower_to_llvm diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir index 547ab700d3be..54b6dfdece33 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir @@ -11,8 +11,8 @@ transform.structured.canonicalized_sequence failures(propagate) { %2, %loops2:3 = transform.structured.tile %1 [2, 2, 2] // CHECK: %[[PADDED:.*]] = transform.structured.pad %[[TILED2]] {hoist_paddings = [], pack_paddings = [1, 1, 0], padding_dimensions = [], padding_values = [], transpose_paddings = []} %3 = transform.structured.pad %2 {pack_paddings = [1, 1, 0]} - // CHECK: %{{.*}} = transform.structured.vectorize %[[PADDED]] {vectorize_padding = true} - %4 = transform.structured.vectorize %3 {vectorize_padding = true} + // CHECK: %{{.*}} = transform.structured.vectorize %[[PADDED]] {vectorize_padding} + %4 = transform.structured.vectorize %3 { vectorize_padding } // CHECK: %[[OPS2:.*]] = pdl_match @{{.*}} %5 = pdl_match @match2 in %arg0 // CHECK: transform.structured.vectorize %[[OPS2]] diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir index 8e54e628bb89..7b93cbed81ee 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir @@ -18,7 +18,7 @@ transform.structured.canonicalized_sequence failures(propagate) { %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op %1, %loops:3 = transform.structured.tile %0 [4, 4, 4] %2 = get_closest_isolated_parent %1 - transform.structured.vectorize %2 {vectorize_padding = true} + transform.structured.vectorize %2 { vectorize_padding } bufferize lower_vectors { multireduction_lowering = "innerreduce"} lower_to_llvm diff --git a/runtime/src/iree/base/config.h b/runtime/src/iree/base/config.h index ae29e5ebe1e1..dacfaaf26afc 100644 --- a/runtime/src/iree/base/config.h +++ b/runtime/src/iree/base/config.h @@ -164,6 +164,13 @@ typedef IREE_DEVICE_SIZE_T iree_device_size_t; // Enables optional HAL features. Each of these may add several KB to the final // binary when linked dynamically. +// To use an import provider in the built-in CPU drivers define a function like: +// iree_hal_executable_import_provider_t my_provider(void) { ... } +// And define it: +// -DIREE_HAL_EXECUTABLE_IMPORT_PROVIDER_DEFAULT_FN=my_provider +// This will only work for default drivers and otherwise users can explicitly +// specify the provider when creating the executable loaders themselves. + #if !defined(IREE_HAL_HEAP_BUFFER_ALIGNMENT) // Power of two byte alignment required on all host heap buffers. // Executables are compiled with alignment expectations and the runtime diff --git a/runtime/src/iree/base/internal/cpu.c b/runtime/src/iree/base/internal/cpu.c index a975c83ca740..4bee83d55fe4 100644 --- a/runtime/src/iree/base/internal/cpu.c +++ b/runtime/src/iree/base/internal/cpu.c @@ -60,6 +60,30 @@ static void iree_cpu_initialize_from_platform(iree_allocator_t temp_allocator, iree_cpu_query_data_arch_hwcaps(hwcap, hwcap2, out_fields); } +#elif defined(IREE_PLATFORM_MACOS) || defined(IREE_PLATFORM_IOS) + +#include +#include + +#define IREE_QUERY_SYSCTL(key, field_value, field_bit) \ + do { \ + int64_t result = 0; \ + size_t result_size = sizeof result; \ + if (0 == sysctlbyname(key, &result, &result_size, NULL, 0)) { \ + if (result) field_value |= field_bit; \ + } \ + } while (0) + +static void iree_cpu_initialize_from_platform(iree_allocator_t temp_allocator, + uint64_t* out_fields) { +#if defined(IREE_ARCH_ARM_64) + IREE_QUERY_SYSCTL("hw.optional.arm.FEAT_DotProd", out_fields[0], + IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_DOTPROD); + IREE_QUERY_SYSCTL("hw.optional.arm.FEAT_I8MM", out_fields[0], + IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_I8MM); +#endif +} + #else static void iree_cpu_initialize_from_platform(iree_allocator_t temp_allocator, diff --git a/runtime/src/iree/base/internal/synchronization.c b/runtime/src/iree/base/internal/synchronization.c index ac0a9d93e83d..ccb42a4bf975 100644 --- a/runtime/src/iree/base/internal/synchronization.c +++ b/runtime/src/iree/base/internal/synchronization.c @@ -178,7 +178,7 @@ static inline iree_status_code_t iree_futex_wait(void* address, int rc = syscall( SYS_futex, address, FUTEX_WAIT | FUTEX_PRIVATE_FLAG, expected_value, timeout_ms == IREE_INFINITE_TIMEOUT_MS ? NULL : &timeout, NULL, 0); - if (IREE_LIKELY(rc == 0) || errno == EAGAIN) { + if (IREE_LIKELY(rc == 0) || errno == EAGAIN || errno == EINTR) { return IREE_STATUS_OK; } else if (errno == ETIMEDOUT) { return IREE_STATUS_DEADLINE_EXCEEDED; diff --git a/runtime/src/iree/base/status_cc.cc b/runtime/src/iree/base/status_cc.cc index c5e6e60b9d73..8ca5fb077fba 100644 --- a/runtime/src/iree/base/status_cc.cc +++ b/runtime/src/iree/base/status_cc.cc @@ -8,17 +8,11 @@ #include #include -#include #include "iree/base/attributes.h" namespace iree { -std::ostream& operator<<(std::ostream& os, const StatusCode& x) { - os << StatusCodeToString(x); - return os; -} - // static IREE_MUST_USE_RESULT std::string Status::ToString(iree_status_t status) { if (iree_status_is_ok(status)) { @@ -38,11 +32,6 @@ IREE_MUST_USE_RESULT std::string Status::ToString(iree_status_t status) { return result; } -std::ostream& operator<<(std::ostream& os, const Status& x) { - os << x.ToString(); - return os; -} - namespace status_impl { void Helper::HandleInvalidStatusCtorArg(Status* status) { diff --git a/runtime/src/iree/base/status_cc.h b/runtime/src/iree/base/status_cc.h index 3170d0c6e7b9..cdfb8ad7ce7b 100644 --- a/runtime/src/iree/base/status_cc.h +++ b/runtime/src/iree/base/status_cc.h @@ -94,9 +94,6 @@ static inline const char* StatusCodeToString(StatusCode code) { return iree_status_code_string(static_cast(code)); } -// Prints a human-readable representation of `x` to `os`. -std::ostream& operator<<(std::ostream& os, const StatusCode& x); - //===----------------------------------------------------------------------===// // Status //===----------------------------------------------------------------------===// @@ -230,9 +227,6 @@ class Status final { // Returns an OK status, equivalent to a default constructed instance. IREE_MUST_USE_RESULT static inline Status OkStatus() { return Status(); } -// Prints a human-readable representation of `x` to `os`. -std::ostream& operator<<(std::ostream& os, const Status& x); - IREE_MUST_USE_RESULT static inline bool IsOk(const Status& status) { return status.code() == StatusCode::kOk; } diff --git a/runtime/src/iree/base/status_test.cc b/runtime/src/iree/base/status_test.cc index c035e1ee1d50..a9cf5d9d0f8a 100644 --- a/runtime/src/iree/base/status_test.cc +++ b/runtime/src/iree/base/status_test.cc @@ -17,6 +17,11 @@ namespace iree { namespace { +std::ostream& operator<<(std::ostream& os, const Status& x) { + os << x.ToString(); + return os; +} + using ::iree::testing::status::StatusIs; using ::testing::HasSubstr; diff --git a/runtime/src/iree/base/string_builder.c b/runtime/src/iree/base/string_builder.c index 590e1d54c7c5..2aa6994cc9bc 100644 --- a/runtime/src/iree/base/string_builder.c +++ b/runtime/src/iree/base/string_builder.c @@ -33,6 +33,11 @@ IREE_API_EXPORT void iree_string_builder_deinitialize( memset(builder, 0, sizeof(*builder)); } +static bool iree_string_builder_is_calculating_size( + const iree_string_builder_t* builder) { + return iree_allocator_is_null(builder->allocator) && builder->buffer == NULL; +} + IREE_API_EXPORT const char* iree_string_builder_buffer( const iree_string_builder_t* builder) { return builder->buffer; @@ -74,13 +79,23 @@ IREE_API_EXPORT char* iree_string_builder_take_storage( IREE_API_EXPORT iree_status_t iree_string_builder_reserve( iree_string_builder_t* builder, iree_host_size_t minimum_capacity) { - if (iree_allocator_is_null(builder->allocator)) return iree_ok_status(); iree_host_size_t new_capacity = builder->capacity; if (builder->capacity < minimum_capacity) { new_capacity = iree_host_align(minimum_capacity, IREE_STRING_BUILDER_ALIGNMENT); } - if (builder->capacity >= new_capacity) return iree_ok_status(); + if (builder->capacity >= new_capacity) { + // Already at/above the requested minimum capacity. + return iree_ok_status(); + } else if (iree_allocator_is_null(builder->allocator)) { + // No allocator provided and the builder cannot grow. + return iree_make_status( + IREE_STATUS_RESOURCE_EXHAUSTED, + "non-growable builder capacity exceeded (capacity=%" PRIhsz + "; requested=%" PRIhsz ", adjusted=%" PRIhsz ")", + builder->capacity, minimum_capacity, new_capacity); + } + IREE_RETURN_IF_ERROR(iree_allocator_realloc(builder->allocator, new_capacity, (void**)&builder->buffer)); builder->buffer[builder->size] = 0; @@ -88,12 +103,24 @@ IREE_API_EXPORT iree_status_t iree_string_builder_reserve( return iree_ok_status(); } +IREE_API_EXPORT iree_status_t iree_string_builder_append_inline( + iree_string_builder_t* builder, iree_host_size_t count, char** out_head) { + *out_head = NULL; + if (!iree_string_builder_is_calculating_size(builder)) { + IREE_RETURN_IF_ERROR(iree_string_builder_reserve( + builder, builder->size + count + /*NUL=*/1)); + *out_head = &builder->buffer[builder->size]; + } + builder->size += count; + return iree_ok_status(); +} + IREE_API_EXPORT iree_status_t iree_string_builder_append_string( iree_string_builder_t* builder, iree_string_view_t value) { // Ensure capacity for the value + NUL terminator. - IREE_RETURN_IF_ERROR( - iree_string_builder_reserve(builder, builder->size + value.size + 1)); - if (builder->buffer != NULL) { + if (!iree_string_builder_is_calculating_size(builder)) { + IREE_RETURN_IF_ERROR( + iree_string_builder_reserve(builder, builder->size + value.size + 1)); // Only copy the bytes if we are not doing a size calculation. memcpy(builder->buffer + builder->size, value.data, value.size); builder->buffer[builder->size + value.size] = 0; // NUL @@ -125,14 +152,17 @@ static iree_status_t iree_string_builder_append_format_impl( return iree_ok_status(); } - // Reserve new minimum capacity. - IREE_RETURN_IF_ERROR(iree_string_builder_reserve( - builder, iree_string_builder_size(builder) + n + /*NUL*/ 1)); + if (!iree_string_builder_is_calculating_size(builder)) { + // Reserve new minimum capacity. + IREE_RETURN_IF_ERROR(iree_string_builder_reserve( + builder, iree_string_builder_size(builder) + n + /*NUL*/ 1)); + + // Try printing again. + vsnprintf(builder->buffer ? builder->buffer + builder->size : NULL, + builder->buffer ? builder->capacity - builder->size : 0, format, + varargs_1); + } - // Try printing again. - vsnprintf(builder->buffer ? builder->buffer + builder->size : NULL, - builder->buffer ? builder->capacity - builder->size : 0, format, - varargs_1); builder->size += n; return iree_ok_status(); } diff --git a/runtime/src/iree/base/string_builder.h b/runtime/src/iree/base/string_builder.h index ff6eeba6709a..007f87cb1d61 100644 --- a/runtime/src/iree/base/string_builder.h +++ b/runtime/src/iree/base/string_builder.h @@ -106,6 +106,15 @@ IREE_API_EXPORT IREE_MUST_USE_RESULT char* iree_string_builder_take_storage( IREE_API_EXPORT iree_status_t iree_string_builder_reserve( iree_string_builder_t* builder, iree_host_size_t minimum_capacity); +// Reserves storage for |count| characters (including NUL) and returns a mutable +// pointer in |out_head| for the caller to write the characters. +// The pointer is only valid so long as the string builder is initialized and +// unmodified. No NUL terminator is added by this call. +// |out_head| will be NULL if the string builder is operating in size +// calculation mode. +IREE_API_EXPORT iree_status_t iree_string_builder_append_inline( + iree_string_builder_t* builder, iree_host_size_t count, char** out_head); + // Appends a string to the builder. IREE_API_EXPORT iree_status_t iree_string_builder_append_string( iree_string_builder_t* builder, iree_string_view_t value); diff --git a/runtime/src/iree/base/string_builder_test.cc b/runtime/src/iree/base/string_builder_test.cc index fad7034ab83b..35987fe1048a 100644 --- a/runtime/src/iree/base/string_builder_test.cc +++ b/runtime/src/iree/base/string_builder_test.cc @@ -12,6 +12,10 @@ namespace { +using iree::Status; +using iree::StatusCode; +using iree::testing::status::StatusIs; + struct StringBuilder { static StringBuilder MakeSystem() { iree_string_builder_t builder; @@ -37,6 +41,18 @@ struct StringBuilder { } iree_string_builder_t builder; + + protected: + StringBuilder() = default; +}; + +template +struct InlineStringBuilder : public StringBuilder { + InlineStringBuilder() { + iree_string_builder_initialize_with_storage(storage, sizeof(storage), + &builder); + } + char storage[Capacity] = {0}; }; TEST(StringBuilderTest, QueryEmpty) { @@ -161,4 +177,44 @@ TEST(StringBuilderTest, Format) { std::string("abcabc") + std::string(1023, ' ') + std::string("x")); } +TEST(StringBuilderTest, InlineStorage) { + InlineStringBuilder<8> builder; + EXPECT_EQ(iree_string_builder_size(builder), 0); + EXPECT_GE(iree_string_builder_capacity(builder), 8); + EXPECT_TRUE(iree_string_view_is_empty(iree_string_builder_view(builder))); + + // Should be able to reserve up to capacity. + IREE_EXPECT_OK(iree_string_builder_reserve(builder, 4)); + IREE_EXPECT_OK(iree_string_builder_reserve(builder, 8)); + + // Should fail to reserve more than storage size. + EXPECT_THAT(Status(iree_string_builder_reserve(builder, 9)), + StatusIs(StatusCode::kResourceExhausted)); +} + +TEST(StringBuilderTest, SizeCalculation) { + auto builder = StringBuilder::MakeEmpty(); + EXPECT_EQ(iree_string_builder_size(builder), 0); + EXPECT_GE(iree_string_builder_capacity(builder), 0); + EXPECT_TRUE(iree_string_view_is_empty(iree_string_builder_view(builder))); + + IREE_EXPECT_OK(iree_string_builder_append_cstring(builder, "abc")); + EXPECT_EQ(iree_string_builder_size(builder), 3); + EXPECT_GE(iree_string_builder_capacity(builder), 0); + + IREE_EXPECT_OK(iree_string_builder_append_format(builder, "def")); + EXPECT_EQ(iree_string_builder_size(builder), 6); + EXPECT_GE(iree_string_builder_capacity(builder), 0); + + char* head = NULL; + IREE_EXPECT_OK(iree_string_builder_append_inline(builder, 3, &head)); + EXPECT_TRUE(head == NULL); + EXPECT_EQ(iree_string_builder_size(builder), 9); + EXPECT_GE(iree_string_builder_capacity(builder), 0); + + // Reservation should fail because there's no allocator. + EXPECT_THAT(Status(iree_string_builder_reserve(builder, 4)), + StatusIs(StatusCode::kResourceExhausted)); +} + } // namespace diff --git a/runtime/src/iree/builtins/ukernel/BUILD b/runtime/src/iree/builtins/ukernel/BUILD index 616dd2a86b04..d055b30730fa 100644 --- a/runtime/src/iree/builtins/ukernel/BUILD +++ b/runtime/src/iree/builtins/ukernel/BUILD @@ -12,31 +12,65 @@ package( licenses = ["notice"], # Apache 2.0 ) +# :types is the type declarations used by both the entry points and the +# internal implementation functions. iree_runtime_cc_library( - name = "ukernel", + name = "types", + hdrs = [ + "common.h", + "mmt4d_types.h", + ], + deps = [ + "//runtime/src/iree/base:core_headers", + "//runtime/src/iree/builtins/ukernel/arch:config", + ], +) + +# :generic contains non-architecture-specific implementations. +iree_runtime_cc_library( + name = "generic", + srcs = [ + "mmt4d_select_tile_generic.c", + ], + hdrs = [ + "mmt4d_select_tile_generic.h", + ], + deps = [ + ":types", + ], +) + +# elementwise code is structured differently from other kernels. In fact it's +# profoundly different: it carries its own custom shims. For now, we keep it +# separate from the rest. +iree_runtime_cc_library( + name = "elementwise", srcs = [ "elementwise_generic.c", "elementwise_impl.c.inc", - "mmt4d.c", - "mmt4d_arm_64.c", - "mmt4d_generic.c", ], hdrs = [ - "common.h", "elementwise.h", - "mmt4d.h", - "mmt4d_arm_64.h", - "mmt4d_generic.h", ], - copts = [ - # Placeholder for a real flag. - "-DIREE_UKERNEL_PLATFORM_EXAMPLE_FLAG=1", + deps = [ + ":types", ], - defines = [ - "IREE_HAVE_UKERNEL_BUILTINS=1", +) + +# Entry points. +iree_runtime_cc_library( + name = "ukernel", + srcs = [ + "mmt4d.c", + ], + hdrs = [ + "elementwise.h", + "mmt4d.h", ], deps = [ - "//runtime/src/iree/base:core_headers", - "//runtime/src/iree/schemas:cpu_data", + ":elementwise", + ":generic", + ":types", + "//runtime/src/iree/builtins/ukernel/arch:ukernel_arch", ], ) diff --git a/runtime/src/iree/builtins/ukernel/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/CMakeLists.txt index f514b2ffe0e4..79814886ba40 100644 --- a/runtime/src/iree/builtins/ukernel/CMakeLists.txt +++ b/runtime/src/iree/builtins/ukernel/CMakeLists.txt @@ -12,26 +12,54 @@ iree_add_all_subdirs() iree_cc_library( NAME - ukernel - COPTS - "-DIREE_UKERNEL_PLATFORM_EXAMPLE_FLAG=1" + types HDRS "common.h" + "mmt4d_types.h" + DEPS + iree::base::core_headers + iree::builtins::ukernel::arch::config + PUBLIC +) + +iree_cc_library( + NAME + generic + HDRS + "mmt4d_select_tile_generic.h" + SRCS + "mmt4d_select_tile_generic.c" + DEPS + ::types + PUBLIC +) + +iree_cc_library( + NAME + elementwise + HDRS "elementwise.h" - "mmt4d.h" - "mmt4d_arm_64.h" - "mmt4d_generic.h" SRCS "elementwise_generic.c" "elementwise_impl.c.inc" + DEPS + ::types + PUBLIC +) + +iree_cc_library( + NAME + ukernel + HDRS + "elementwise.h" + "mmt4d.h" + SRCS "mmt4d.c" - "mmt4d_arm_64.c" - "mmt4d_generic.c" DEPS - iree::base::core_headers - iree::schemas::cpu_data - DEFINES - "IREE_HAVE_UKERNEL_BUILTINS=1" + ::elementwise + ::generic + ::types + iree::builtins::ukernel::arch::ukernel_arch PUBLIC ) diff --git a/runtime/src/iree/builtins/ukernel/arch/BUILD b/runtime/src/iree/builtins/ukernel/arch/BUILD new file mode 100644 index 000000000000..d0a9212a28d7 --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/BUILD @@ -0,0 +1,40 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library") +load("@bazel_skylib//rules:copy_file.bzl", "copy_file") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +copy_file( + name = "gen_config", + src = "config.h.bazel-generic", + out = "config.h", +) + +iree_runtime_cc_library( + name = "config", + hdrs = ["config.h"], +) + +# :types is the type declarations used by both the entry points and the +# internal implementation functions. +iree_runtime_cc_library( + name = "ukernel_arch", + srcs = [ + "mmt4d_select_tile_arch.c", + ], + hdrs = [ + "mmt4d_select_tile_arch.h", + ], + deps = [ + "//runtime/src/iree/builtins/ukernel:types", + ], +) diff --git a/runtime/src/iree/builtins/ukernel/arch/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/arch/CMakeLists.txt new file mode 100644 index 000000000000..9ac064f5d02f --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/CMakeLists.txt @@ -0,0 +1,56 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +option(IREE_UKERNEL_FORCE_DISABLE_ARCH_SPECIFIC_CODE "Disable all architecture-specific code in builtin kernels" OFF) + +if(IREE_UKERNEL_FORCE_DISABLE_ARCH_SPECIFIC_CODE) + set(IREE_UKERNEL_ENABLE_ARCH_SPECIFIC_CODE FALSE) +else() + set(IREE_UKERNEL_ENABLE_ARCH_SPECIFIC_CODE TRUE) +endif() + +# This block is borrowed from boringssl's CMake code here: +# https://boringssl.googlesource.com/boringssl/+/c5f0e58e653d2d9afa8facc090ce09f8aaa3fa0d/CMakeLists.txt#43 +if(CMAKE_OSX_ARCHITECTURES) + list(LENGTH CMAKE_OSX_ARCHITECTURES NUM_ARCHES) + if(NOT ${NUM_ARCHES} EQUAL 1) + message(WARNING "Performance advisory: architecture-specific code paths disabled because this is a multi-architecture build.") + set(IREE_UKERNEL_ENABLE_ARCH_SPECIFIC_CODE FALSE) + endif() + list(GET CMAKE_OSX_ARCHITECTURES 0 CMAKE_SYSTEM_PROCESSOR) +endif() + +if(IREE_UKERNEL_ENABLE_ARCH_SPECIFIC_CODE) + if((CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) OR (CMAKE_SYSTEM_PROCESSOR STREQUAL arm64)) + set(IREE_UKERNEL_ARCH_ARM_64 TRUE) + add_subdirectory(arm_64) + list(APPEND IREE_UKERNEL_ARCH_DEPS "iree::builtins::ukernel::arch::arm_64::mmt4d_select_tile_arm_64") + endif() +endif() # IREE_UKERNEL_ENABLE_ARCH_SPECIFIC_CODE + +set(IREE_UKERNEL_POINTER_SIZE "${CMAKE_SIZEOF_VOID_P}") + +iree_cc_library( + NAME + config + HDRS + ${CMAKE_CURRENT_BINARY_DIR}/config.h +) + +iree_cc_library( + NAME + ukernel_arch + HDRS + "mmt4d_select_tile_arch.h" + SRCS + "mmt4d_select_tile_arch.c" + DEPS + iree::builtins::ukernel::types + ${IREE_UKERNEL_ARCH_DEPS} + PUBLIC +) + +configure_file(config.h.in config.h) diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/BUILD b/runtime/src/iree/builtins/ukernel/arch/arm_64/BUILD new file mode 100644 index 000000000000..cf55b8d77a1b --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/BUILD @@ -0,0 +1,20 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_runtime_cc_library( + name = "mmt4d_select_tile_arm_64", + hdrs = [ + "mmt4d_select_tile_arm_64.h", + ], +) diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/arch/arm_64/CMakeLists.txt new file mode 100644 index 000000000000..12ce6d5b3811 --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/CMakeLists.txt @@ -0,0 +1,50 @@ +iree_cc_library( + NAME + mmt4d_tile_arm_64 + SRCS + "mmt4d_tile_arm_64.S" +) +list(APPEND IREE_UKERNEL_MMT4D_TILE_ARM_64_DEPS "iree::builtins::ukernel::arch::arm_64::mmt4d_tile_arm_64") + +check_cxx_compiler_flag("-march=armv8.2-a+dotprod" IREE_UKERNEL_BUILD_ARM_64_DOTPROD) +if(IREE_UKERNEL_BUILD_ARM_64_DOTPROD) + iree_cc_library( + NAME + mmt4d_tile_arm_64_dotprod + SRCS + "mmt4d_tile_arm_64_dotprod.S" + COPTS + "-march=armv8.2-a+dotprod" + ) + list(APPEND IREE_UKERNEL_MMT4D_TILE_ARM_64_DEPS "iree::builtins::ukernel::arch::arm_64::mmt4d_tile_arm_64_dotprod") +endif() + +check_cxx_compiler_flag("-march=armv8.2-a+i8mm" IREE_UKERNEL_BUILD_ARM_64_I8MM) +if(IREE_UKERNEL_BUILD_ARM_64_I8MM) + iree_cc_library( + NAME + mmt4d_tile_arm_64_i8mm + SRCS + "mmt4d_tile_arm_64_i8mm.S" + COPTS + "-march=armv8.2-a+i8mm" + ) + list(APPEND IREE_UKERNEL_MMT4D_TILE_ARM_64_DEPS "iree::builtins::ukernel::arch::arm_64::mmt4d_tile_arm_64_i8mm") +endif() + +configure_file(config.h.in config.h) + +iree_cc_library( + NAME + mmt4d_select_tile_arm_64 + HDRS + "mmt4d_select_tile_arm_64.h" + SRCS + "mmt4d_select_tile_arm_64.c" + DEPS + iree::base::core_headers + iree::schemas::cpu_data + iree::builtins::ukernel::types + ${IREE_UKERNEL_MMT4D_TILE_ARM_64_DEPS} + PUBLIC +) diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/assembly.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/assembly.h new file mode 100644 index 000000000000..03537f8856cb --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/assembly.h @@ -0,0 +1,49 @@ +// Borrowed from XNNPACK's assembly.h (thanks!) +// clang-format off +#ifdef __wasm__ + .macro BEGIN_FUNCTION name + .text + .section .text.\name,"",@ + .hidden \name + .globl \name + .type \name,@function + \name: + .endm + + .macro END_FUNCTION name + end_function + .endm +#elif defined(__ELF__) + .macro BEGIN_FUNCTION name + .text + .p2align 4 + .global \name + .hidden \name + .type \name, %function + \name: + .endm + + .macro END_FUNCTION name + .size \name, .-\name + .endm +#elif defined(__MACH__) + .macro BEGIN_FUNCTION name + .text + .p2align 4 + .global _\name + .private_extern _\name + _\name: + .endm + + .macro END_FUNCTION name + .endm +#endif + +#ifdef __ELF__ + .macro ALLOW_NON_EXECUTABLE_STACK + .section ".note.GNU-stack","",%progbits + .endm +#else + .macro ALLOW_NON_EXECUTABLE_STACK + .endm +#endif diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/config.h.in b/runtime/src/iree/builtins/ukernel/arch/arm_64/config.h.in new file mode 100644 index 000000000000..02701bd0490a --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/config.h.in @@ -0,0 +1,2 @@ +#cmakedefine IREE_UKERNEL_BUILD_ARM_64_DOTPROD +#cmakedefine IREE_UKERNEL_BUILD_ARM_64_I8MM diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_select_tile_arm_64.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_select_tile_arm_64.c new file mode 100644 index 000000000000..5e3473ab083d --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_select_tile_arm_64.c @@ -0,0 +1,94 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/builtins/ukernel/arch/arm_64/mmt4d_select_tile_arm_64.h" + +#include "iree/builtins/ukernel/arch/arm_64/config.h" +#include "iree/schemas/cpu_data.h" + +IREE_UKERNEL_MMT4D_TILE_FUNC_DECL( + iree_ukernel_mmt4d_f32f32f32_tile_8x8x1_arm_64) +IREE_UKERNEL_MMT4D_TILE_FUNC_DECL(iree_ukernel_mmt4d_i8i8i32_tile_8x8x1_arm_64) +IREE_UKERNEL_MMT4D_TILE_FUNC_DECL( + iree_ukernel_mmt4d_i8i8i32_tile_8x8x4_arm_64_dotprod) +IREE_UKERNEL_MMT4D_TILE_FUNC_DECL( + iree_ukernel_mmt4d_i8i8i32_tile_8x8x8_arm_64_i8mm) + +static iree_ukernel_mmt4d_tile_func_t +iree_ukernel_mmt4d_select_tile_func_arm_64_f32f32f32_8x8x1( + const iree_ukernel_mmt4d_params_t* params) { + (void)params; + return iree_ukernel_mmt4d_f32f32f32_tile_8x8x1_arm_64; +} + +static iree_ukernel_mmt4d_tile_func_t +iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32_8x8x1( + const iree_ukernel_mmt4d_params_t* params) { + (void)params; + return iree_ukernel_mmt4d_i8i8i32_tile_8x8x1_arm_64; +} + +static iree_ukernel_mmt4d_tile_func_t +iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32_8x8x8( + const iree_ukernel_mmt4d_params_t* params) { +#ifdef IREE_UKERNEL_BUILD_ARM_64_I8MM + if (params->cpu_data[0] & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_I8MM) { + return iree_ukernel_mmt4d_i8i8i32_tile_8x8x8_arm_64_i8mm; + } +#else + (void)params; +#endif + return 0; +} + +static iree_ukernel_mmt4d_tile_func_t +iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32_8x8x4( + const iree_ukernel_mmt4d_params_t* params) { +#ifdef IREE_UKERNEL_BUILD_ARM_64_DOTPROD + if (params->cpu_data[0] & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_DOTPROD) { + return iree_ukernel_mmt4d_i8i8i32_tile_8x8x4_arm_64_dotprod; + } +#else + (void)params; +#endif + return 0; +} + +static iree_ukernel_mmt4d_tile_func_t +iree_ukernel_mmt4d_select_tile_func_arm_64_f32f32f32( + const iree_ukernel_mmt4d_params_t* params) { + if (params->M0 == 8 && params->N0 == 8 && params->K0 == 1) { + return iree_ukernel_mmt4d_select_tile_func_arm_64_f32f32f32_8x8x1(params); + } + return 0; +} + +static iree_ukernel_mmt4d_tile_func_t +iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32( + const iree_ukernel_mmt4d_params_t* params) { + if (params->M0 == 8 && params->N0 == 8 && params->K0 == 1) { + return iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32_8x8x1(params); + } + if (params->M0 == 8 && params->N0 == 8 && params->K0 == 4) { + return iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32_8x8x4(params); + } + if (params->M0 == 8 && params->N0 == 8 && params->K0 == 8) { + return iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32_8x8x8(params); + } + return 0; +} + +iree_ukernel_mmt4d_tile_func_t iree_ukernel_mmt4d_select_tile_func_arm_64( + const iree_ukernel_mmt4d_params_t* params) { + switch (params->type) { + case iree_ukernel_mmt4d_type_f32f32f32: + return iree_ukernel_mmt4d_select_tile_func_arm_64_f32f32f32(params); + case iree_ukernel_mmt4d_type_i8i8i32: + return iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32(params); + default: + return 0; + } +} diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_select_tile_arm_64.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_select_tile_arm_64.h new file mode 100644 index 000000000000..4fa8105676a0 --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_select_tile_arm_64.h @@ -0,0 +1,18 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_BUILTINS_UKERNEL_ARM_64_MMT4D_SELECT_TILE_ARM_64_H_ +#define IREE_BUILTINS_UKERNEL_ARM_64_MMT4D_SELECT_TILE_ARM_64_H_ + +#include "iree/builtins/ukernel/mmt4d_types.h" + +// Returns the arm64 tile function to use for the mmt4d with given params, or +// NULL if no suitable arm64 tile function exists for these params, in which +// case the caller may fall back to a generic tile function. +iree_ukernel_mmt4d_tile_func_t iree_ukernel_mmt4d_select_tile_func_arm_64( + const iree_ukernel_mmt4d_params_t* params); + +#endif // IREE_BUILTINS_UKERNEL_ARM_64_MMT4D_SELECT_TILE_ARM_64_H_ diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64.S b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64.S new file mode 100644 index 000000000000..b10c41997cc4 --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64.S @@ -0,0 +1,182 @@ +#include "assembly.h" + +// TODO: share these bits with C/C++. +.equ ACCUMULATE_FLAG_BIT_POS,0 + +// Parameters: +// x0: int32_t* out_tile +// x1: const int8_t* lhs_panel +// x2: const int8_t* rhs_panel +// w3: int32_t K. Note: K>=1, as the K==0 case was handled as an early-return. +// w4: uint32_t flags +// x5: (UNUSED) params - relevant params K and flags already passed above. + +BEGIN_FUNCTION iree_ukernel_mmt4d_f32f32f32_tile_8x8x1_arm_64 + + // Do we accumulate into or clear the accumulator tile? + tbnz w4, ACCUMULATE_FLAG_BIT_POS, 1f + + 0: + // No-accumulate case. Clear the 8x8 accumulator tile. + movi v16.16b, 0 + movi v17.16b, 0 + movi v18.16b, 0 + movi v19.16b, 0 + movi v20.16b, 0 + movi v21.16b, 0 + movi v22.16b, 0 + movi v23.16b, 0 + movi v24.16b, 0 + movi v25.16b, 0 + movi v26.16b, 0 + movi v27.16b, 0 + movi v28.16b, 0 + movi v29.16b, 0 + movi v30.16b, 0 + movi v31.16b, 0 + b 2f + + 1: + // Accumulate case. Load the 8x8 accumulator tile from row-major + // out_tile, into temporary registers v16--v31. + ldp q16, q17, [x0, 0] + ldp q18, q19, [x0, 32] + ldp q20, q21, [x0, 64] + ldp q22, q23, [x0, 96] + ldp q24, q25, [x0, 128] + ldp q26, q27, [x0, 160] + ldp q28, q29, [x0, 192] + ldp q30, q31, [x0, 224] + + 2: + // Loop body. Decrement the loop counter K. + subs w3, w3, 1 + // Load 8x1 LHS tile + ldp q0, q1, [x1, 0] + add x1, x1, 32 + // Load 8x1 RHS tile + ldp q4, q5, [x2, 0] + add x2, x2, 32 + // Multiply-accumulate, row 0. + fmla v16.4s, v4.4s, v0.s[0] + fmla v17.4s, v5.4s, v0.s[0] + // Multiply-accumulate, row 1. + fmla v18.4s, v4.4s, v0.s[1] + fmla v19.4s, v5.4s, v0.s[1] + // Multiply-accumulate, row 2. + fmla v20.4s, v4.4s, v0.s[2] + fmla v21.4s, v5.4s, v0.s[2] + // Multiply-accumulate, row 3. + fmla v22.4s, v4.4s, v0.s[3] + fmla v23.4s, v5.4s, v0.s[3] + // Multiply-accumulate, row 4. + fmla v24.4s, v4.4s, v1.s[0] + fmla v25.4s, v5.4s, v1.s[0] + // Multiply-accumulate, row 5. + fmla v26.4s, v4.4s, v1.s[1] + fmla v27.4s, v5.4s, v1.s[1] + // Multiply-accumulate, row 6. + fmla v28.4s, v4.4s, v1.s[2] + fmla v29.4s, v5.4s, v1.s[2] + // Multiply-accumulate, row 7. + fmla v30.4s, v4.4s, v1.s[3] + fmla v31.4s, v5.4s, v1.s[3] + // Loop if K != 0. + b.ne 2b + + 3: + // Store the accumulator tile to the destination. + stp q16, q17, [x0, 0] + stp q18, q19, [x0, 32] + stp q20, q21, [x0, 64] + stp q22, q23, [x0, 96] + stp q24, q25, [x0, 128] + stp q26, q27, [x0, 160] + stp q28, q29, [x0, 192] + stp q30, q31, [x0, 224] + ret + +END_FUNCTION iree_ukernel_mmt4d_f32f32f32_tile_8x8x1_arm_64 + +BEGIN_FUNCTION iree_ukernel_mmt4d_i8i8i32_tile_8x8x1_arm_64 + + // Do we accumulate into or clear the accumulator tile? + tbnz w4, ACCUMULATE_FLAG_BIT_POS, 1f + + 0: + // No-accumulate case. Clear the 8x8 accumulator tile. + movi v16.16b, 0 + movi v17.16b, 0 + movi v18.16b, 0 + movi v19.16b, 0 + movi v20.16b, 0 + movi v21.16b, 0 + movi v22.16b, 0 + movi v23.16b, 0 + movi v24.16b, 0 + movi v25.16b, 0 + movi v26.16b, 0 + movi v27.16b, 0 + movi v28.16b, 0 + movi v29.16b, 0 + movi v30.16b, 0 + movi v31.16b, 0 + b 2f + + 1: + // Accumulate case. Load the 8x8 accumulator tile from row-major + // out_tile, into temporary registers v16--v31. + ldp q16, q17, [x0, 0] + ldp q18, q19, [x0, 32] + ldp q20, q21, [x0, 64] + ldp q22, q23, [x0, 96] + ldp q24, q25, [x0, 128] + ldp q26, q27, [x0, 160] + ldp q28, q29, [x0, 192] + ldp q30, q31, [x0, 224] + + 2: + // Loop body. Decrement the loop counter K. + subs w3, w3, 1 + // Load 8x1 LHS tile + ldr d0, [x1, 0] + add x1, x1, 8 + // Load 8x4 RHS tile + ldr d4, [x2, 0] + add x2, x2, 8 + sxtl v1.8h, v0.8b + sxtl v5.8h, v4.8b + smlal v16.4s, v5.4h, v1.h[0] + smlal2 v17.4s, v5.8h, v1.h[0] + smlal v18.4s, v5.4h, v1.h[1] + smlal2 v19.4s, v5.8h, v1.h[1] + smlal v20.4s, v5.4h, v1.h[2] + smlal2 v21.4s, v5.8h, v1.h[2] + smlal v22.4s, v5.4h, v1.h[3] + smlal2 v23.4s, v5.8h, v1.h[3] + smlal v24.4s, v5.4h, v1.h[4] + smlal2 v25.4s, v5.8h, v1.h[4] + smlal v26.4s, v5.4h, v1.h[5] + smlal2 v27.4s, v5.8h, v1.h[5] + smlal v28.4s, v5.4h, v1.h[6] + smlal2 v29.4s, v5.8h, v1.h[6] + smlal v30.4s, v5.4h, v1.h[7] + smlal2 v31.4s, v5.8h, v1.h[7] + // Loop if K != 0. + b.ne 2b + + 3: + // Store the accumulator tile to the destination. + stp q16, q17, [x0, 0] + stp q18, q19, [x0, 32] + stp q20, q21, [x0, 64] + stp q22, q23, [x0, 96] + stp q24, q25, [x0, 128] + stp q26, q27, [x0, 160] + stp q28, q29, [x0, 192] + stp q30, q31, [x0, 224] + ret + +END_FUNCTION iree_ukernel_mmt4d_i8i8i32_tile_8x8x1_arm_64 + +ALLOW_NON_EXECUTABLE_STACK diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_dotprod.S b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_dotprod.S new file mode 100644 index 000000000000..a540688bd66e --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_dotprod.S @@ -0,0 +1,101 @@ +#include "assembly.h" + +// TODO: share these bits with C/C++. +.equ ACCUMULATE_FLAG_BIT_POS,0 + +// Parameters: +// x0: int32_t* out_tile +// x1: const int8_t* lhs_panel +// x2: const int8_t* rhs_panel +// w3: int32_t K. Note: K>=1, as the K==0 case was handled as an early-return. +// w4: uint32_t flags +// x5: (UNUSED) params - relevant params K and flags already passed above. + +BEGIN_FUNCTION iree_ukernel_mmt4d_i8i8i32_tile_8x8x4_arm_64_dotprod + + // Do we accumulate into or clear the accumulator tile? + tbnz w4, ACCUMULATE_FLAG_BIT_POS, 1f + + 0: + // No-accumulate case. Clear the 8x8 accumulator tile. + movi v16.16b, 0 + movi v17.16b, 0 + movi v18.16b, 0 + movi v19.16b, 0 + movi v20.16b, 0 + movi v21.16b, 0 + movi v22.16b, 0 + movi v23.16b, 0 + movi v24.16b, 0 + movi v25.16b, 0 + movi v26.16b, 0 + movi v27.16b, 0 + movi v28.16b, 0 + movi v29.16b, 0 + movi v30.16b, 0 + movi v31.16b, 0 + b 2f + + 1: + // Accumulate case. Load the 8x8 accumulator tile from row-major + // out_tile, into temporary registers v16--v31. + ldp q16, q17, [x0, 0] + ldp q18, q19, [x0, 32] + ldp q20, q21, [x0, 64] + ldp q22, q23, [x0, 96] + ldp q24, q25, [x0, 128] + ldp q26, q27, [x0, 160] + ldp q28, q29, [x0, 192] + ldp q30, q31, [x0, 224] + + 2: + // Loop body. Decrement the loop counter K. + subs w3, w3, 1 + // Load 8x4 LHS tile + ldp q0, q1, [x1, 0] + add x1, x1, 32 + // Load 8x4 RHS tile + ldp q4, q5, [x2, 0] + add x2, x2, 32 + // Multiply-accumulate, row 0. + sdot v16.4s, v4.16b, v0.4b[0] + sdot v17.4s, v5.16b, v0.4b[0] + // Multiply-accumulate, row 1. + sdot v18.4s, v4.16b, v0.4b[1] + sdot v19.4s, v5.16b, v0.4b[1] + // Multiply-accumulate, row 2. + sdot v20.4s, v4.16b, v0.4b[2] + sdot v21.4s, v5.16b, v0.4b[2] + // Multiply-accumulate, row 3. + sdot v22.4s, v4.16b, v0.4b[3] + sdot v23.4s, v5.16b, v0.4b[3] + // Multiply-accumulate, row 4. + sdot v24.4s, v4.16b, v1.4b[0] + sdot v25.4s, v5.16b, v1.4b[0] + // Multiply-accumulate, row 5. + sdot v26.4s, v4.16b, v1.4b[1] + sdot v27.4s, v5.16b, v1.4b[1] + // Multiply-accumulate, row 6. + sdot v28.4s, v4.16b, v1.4b[2] + sdot v29.4s, v5.16b, v1.4b[2] + // Multiply-accumulate, row 7. + sdot v30.4s, v4.16b, v1.4b[3] + sdot v31.4s, v5.16b, v1.4b[3] + // Loop if K != 0. + b.ne 2b + + 3: + // Store the accumulator tile to the destination. + stp q16, q17, [x0, 0] + stp q18, q19, [x0, 32] + stp q20, q21, [x0, 64] + stp q22, q23, [x0, 96] + stp q24, q25, [x0, 128] + stp q26, q27, [x0, 160] + stp q28, q29, [x0, 192] + stp q30, q31, [x0, 224] + ret + +END_FUNCTION iree_ukernel_mmt4d_i8i8i32_tile_8x8x4_arm_64_dotprod + +ALLOW_NON_EXECUTABLE_STACK diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_i8mm.S b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_i8mm.S new file mode 100644 index 000000000000..827ce63c0f26 --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_i8mm.S @@ -0,0 +1,293 @@ +#include "assembly.h" + +// TODO: share these bits with C/C++. +.equ ACCUMULATE_FLAG_BIT_POS,0 + +// Parameters: +// x0: int32_t* out_tile +// x1: const int8_t* lhs_panel +// x2: const int8_t* rhs_panel +// w3: int32_t K. Note: K>=1, as the K==0 case was handled as an early-return. +// w4: uint32_t flags +// x5: (UNUSED) params - relevant params K and flags already passed above. + +BEGIN_FUNCTION iree_ukernel_mmt4d_i8i8i32_tile_8x8x8_arm_64_i8mm + + // Save callee-saved NEON registers + stp d8, d9, [sp, -64]! + stp d10, d11, [sp, 16] + stp d12, d13, [sp, 32] + stp d14, d15, [sp, 48] + + // Do we accumulate into or clear the accumulator tile? + tbnz w4, ACCUMULATE_FLAG_BIT_POS, 1f + + 0: + // No-accumulate case. Clear the 8x8 accumulator tile. + movi v16.16b, 0 + movi v17.16b, 0 + movi v18.16b, 0 + movi v19.16b, 0 + movi v20.16b, 0 + movi v21.16b, 0 + movi v22.16b, 0 + movi v23.16b, 0 + movi v24.16b, 0 + movi v25.16b, 0 + movi v26.16b, 0 + movi v27.16b, 0 + movi v28.16b, 0 + movi v29.16b, 0 + movi v30.16b, 0 + movi v31.16b, 0 + b 2f + + 1: + // Accumulate case. Load the 8x8 accumulator tile from row-major + // out_tile and swizzle it into 2x2 tiled layout. + // + // Load rows 0--3. + ldp q0, q1, [x0, 0] + ldp q2, q3, [x0, 32] + ldp q4, q5, [x0, 64] + ldp q6, q7, [x0, 96] + // Load rows 4--7. + ldp q8, q9, [x0, 128] + ldp q10, q11, [x0, 160] + ldp q12, q13, [x0, 192] + ldp q14, q15, [x0, 224] + // Swizzle in 2x2 tiles for smmla, rows 0--1. + zip1 v16.2d, v0.2d, v2.2d + zip2 v17.2d, v0.2d, v2.2d + zip1 v18.2d, v1.2d, v3.2d + zip2 v19.2d, v1.2d, v3.2d + // Swizzle in 2x2 tiles for smmla, rows 2--3. + zip1 v20.2d, v4.2d, v6.2d + zip2 v21.2d, v4.2d, v6.2d + zip1 v22.2d, v5.2d, v7.2d + zip2 v23.2d, v5.2d, v7.2d + // Swizzle in 2x2 tiles for smmla, rows 4--5. + zip1 v24.2d, v8.2d, v10.2d + zip2 v25.2d, v8.2d, v10.2d + zip1 v26.2d, v9.2d, v11.2d + zip2 v27.2d, v9.2d, v11.2d + // Swizzle in 2x2 tiles for smmla, rows 6--7. + zip1 v28.2d, v12.2d, v14.2d + zip2 v29.2d, v12.2d, v14.2d + zip1 v30.2d, v13.2d, v15.2d + zip2 v31.2d, v13.2d, v15.2d + + 2: + + // Start of math work. If K==1, jump over the whole main loop. + subs w3, w3, 1 + b.eq 6f + + 3: + // Prologue of main loop, 2x partially unrolled, for when K>=2. + // + // Decrement the loop counter K. + subs w3, w3, 2 + // Pre-load data for first loop iteration + // + // Load 8x8 LHS tile + ldp q0, q1, [x1], 32 + ldp q2, q3, [x1], 32 + // Load 8x8 RHS tile + ldp q4, q5, [x2], 32 + ldp q6, q7, [x2], 32 + // Load 8x8 LHS tile + ldp q8, q9, [x1], 32 + ldp q10, q11, [x1], 32 + // Load 8x8 RHS tile... + ldp q12, q13, [x2], 32 + // ...second half loads is kept inside the loop below. + // + // Multiply-accumulate, rows 0--1. + smmla v16.4s, v0.16b, v4.16b + smmla v17.4s, v0.16b, v5.16b + smmla v18.4s, v0.16b, v6.16b + smmla v19.4s, v0.16b, v7.16b + + // If K==2, jump to the epilogue. + b.le 5f + + 4: + // Body of main loop, 2x partially unrolled, for when K>2. + // + // Multiply-accumulate, rows 2--3. + smmla v20.4s, v1.16b, v4.16b + smmla v21.4s, v1.16b, v5.16b + smmla v22.4s, v1.16b, v6.16b + smmla v23.4s, v1.16b, v7.16b + ldp q14, q15, [x2], 32 + // Multiply-accumulate, rows 4--5. + smmla v24.4s, v2.16b, v4.16b + smmla v25.4s, v2.16b, v5.16b + smmla v26.4s, v2.16b, v6.16b + smmla v27.4s, v2.16b, v7.16b + ldp q0, q1, [x1], 32 + // Multiply-accumulate, rows 6--7. + smmla v28.4s, v3.16b, v4.16b + smmla v29.4s, v3.16b, v5.16b + smmla v30.4s, v3.16b, v6.16b + smmla v31.4s, v3.16b, v7.16b + ldp q2, q3, [x1], 32 + // Multiply-accumulate, rows 0--1. + smmla v16.4s, v8.16b, v12.16b + smmla v17.4s, v8.16b, v13.16b + smmla v18.4s, v8.16b, v14.16b + smmla v19.4s, v8.16b, v15.16b + ldp q4, q5, [x2], 32 + // Multiply-accumulate, rows 2--3. + smmla v20.4s, v9.16b, v12.16b + smmla v21.4s, v9.16b, v13.16b + smmla v22.4s, v9.16b, v14.16b + smmla v23.4s, v9.16b, v15.16b + ldp q6, q7, [x2], 32 + // Multiply-accumulate, rows 4--5. + smmla v24.4s, v10.16b, v12.16b + smmla v25.4s, v10.16b, v13.16b + smmla v26.4s, v10.16b, v14.16b + smmla v27.4s, v10.16b, v15.16b + ldp q8, q9, [x1], 32 + // Multiply-accumulate, rows 6--7. + smmla v28.4s, v11.16b, v12.16b + smmla v29.4s, v11.16b, v13.16b + smmla v30.4s, v11.16b, v14.16b + smmla v31.4s, v11.16b, v15.16b + ldp q10, q11, [x1], 32 + // Multiply-accumulate, rows 0--1. + smmla v16.4s, v0.16b, v4.16b + smmla v17.4s, v0.16b, v5.16b + ldp q12, q13, [x2], 32 + smmla v18.4s, v0.16b, v6.16b + subs w3, w3, 2 + smmla v19.4s, v0.16b, v7.16b + b.gt 4b + + 5: + // Epilogue of main loop, 2x partially unrolled, for when K>2. + // + // Load last chunk of last RHS tile. + ldp q14, q15, [x2], 32 + + // Multiply-accumulate, rows 2--3. + smmla v20.4s, v1.16b, v4.16b + smmla v21.4s, v1.16b, v5.16b + smmla v22.4s, v1.16b, v6.16b + smmla v23.4s, v1.16b, v7.16b + // Multiply-accumulate, rows 4--5. + smmla v24.4s, v2.16b, v4.16b + smmla v25.4s, v2.16b, v5.16b + smmla v26.4s, v2.16b, v6.16b + smmla v27.4s, v2.16b, v7.16b + // Multiply-accumulate, rows 6--7. + smmla v28.4s, v3.16b, v4.16b + smmla v29.4s, v3.16b, v5.16b + smmla v30.4s, v3.16b, v6.16b + smmla v31.4s, v3.16b, v7.16b + + // Multiply-accumulate, rows 0--1. + smmla v16.4s, v8.16b, v12.16b + smmla v17.4s, v8.16b, v13.16b + smmla v18.4s, v8.16b, v14.16b + smmla v19.4s, v8.16b, v15.16b + // Multiply-accumulate, rows 2--3. + smmla v20.4s, v9.16b, v12.16b + smmla v21.4s, v9.16b, v13.16b + smmla v22.4s, v9.16b, v14.16b + smmla v23.4s, v9.16b, v15.16b + // Multiply-accumulate, rows 4--5. + smmla v24.4s, v10.16b, v12.16b + smmla v25.4s, v10.16b, v13.16b + smmla v26.4s, v10.16b, v14.16b + smmla v27.4s, v10.16b, v15.16b + // Multiply-accumulate, rows 6--7. + smmla v28.4s, v11.16b, v12.16b + smmla v29.4s, v11.16b, v13.16b + smmla v30.4s, v11.16b, v14.16b + smmla v31.4s, v11.16b, v15.16b + + // Finished accumulating? Then jump to final store. + b.lt 7f + // Fall through. + + 6: + // Accumulate for a single K-value - used for either the K==1 case or + // final value of K for odd K>1. + + // Load 8x8 LHS tile + ldp q0, q1, [x1, 0] + ldp q2, q3, [x1, 32] + add x1, x1, 64 + // Load 8x8 RHS tile + ldp q4, q5, [x2, 0] + ldp q6, q7, [x2, 32] + add x2, x2, 64 + // Multiply-accumulate, rows 0--1. + smmla v16.4s, v0.16b, v4.16b + smmla v17.4s, v0.16b, v5.16b + smmla v18.4s, v0.16b, v6.16b + smmla v19.4s, v0.16b, v7.16b + // Multiply-accumulate, rows 2--3. + smmla v20.4s, v1.16b, v4.16b + smmla v21.4s, v1.16b, v5.16b + smmla v22.4s, v1.16b, v6.16b + smmla v23.4s, v1.16b, v7.16b + // Multiply-accumulate, rows 4--5. + smmla v24.4s, v2.16b, v4.16b + smmla v25.4s, v2.16b, v5.16b + smmla v26.4s, v2.16b, v6.16b + smmla v27.4s, v2.16b, v7.16b + // Multiply-accumulate, rows 6--7. + smmla v28.4s, v3.16b, v4.16b + smmla v29.4s, v3.16b, v5.16b + smmla v30.4s, v3.16b, v6.16b + smmla v31.4s, v3.16b, v7.16b + + 7: + // Done accumulating. + // + // Swizzle back to row-major and store to destination. + // + // Swizzle back to row-major, rows 0--1. + uzp1 v0.2d, v16.2d, v17.2d + uzp1 v1.2d, v18.2d, v19.2d + uzp2 v2.2d, v16.2d, v17.2d + uzp2 v3.2d, v18.2d, v19.2d + // Swizzle back to row-major, rows 2--3. + uzp1 v4.2d, v20.2d, v21.2d + uzp1 v5.2d, v22.2d, v23.2d + uzp2 v6.2d, v20.2d, v21.2d + uzp2 v7.2d, v22.2d, v23.2d + // Swizzle back to row-major, rows 4--5. + uzp1 v8.2d, v24.2d, v25.2d + uzp1 v9.2d, v26.2d, v27.2d + uzp2 v10.2d, v24.2d, v25.2d + uzp2 v11.2d, v26.2d, v27.2d + // Swizzle back to row-major, rows 6--7. + uzp1 v12.2d, v28.2d, v29.2d + uzp1 v13.2d, v30.2d, v31.2d + uzp2 v14.2d, v28.2d, v29.2d + uzp2 v15.2d, v30.2d, v31.2d + // Store rows 0--3 to destination. + stp q0, q1, [x0, 0] + stp q2, q3, [x0, 32] + stp q4, q5, [x0, 64] + stp q6, q7, [x0, 96] + stp q8, q9, [x0, 128] + stp q10, q11, [x0, 160] + stp q12, q13, [x0, 192] + stp q14, q15, [x0, 224] + + // Restore callee-saved NEON registers + ldp d14, d15, [sp, 48] + ldp d12, d13, [sp, 32] + ldp d10, d11, [sp, 16] + ldp d8, d9, [sp], 64 + ret + +END_FUNCTION iree_ukernel_mmt4d_i8i8i32_tile_8x8x8_arm_64_i8mm + +ALLOW_NON_EXECUTABLE_STACK diff --git a/runtime/src/iree/builtins/ukernel/arch/config.h.bazel-generic b/runtime/src/iree/builtins/ukernel/arch/config.h.bazel-generic new file mode 100644 index 000000000000..ade3064530a4 --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/config.h.bazel-generic @@ -0,0 +1 @@ +#define IREE_UKERNEL_POINTER_SIZE 8 diff --git a/runtime/src/iree/builtins/ukernel/arch/config.h.in b/runtime/src/iree/builtins/ukernel/arch/config.h.in new file mode 100644 index 000000000000..a9641452987c --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/config.h.in @@ -0,0 +1,2 @@ +#cmakedefine IREE_UKERNEL_POINTER_SIZE ${IREE_UKERNEL_POINTER_SIZE} +#cmakedefine IREE_UKERNEL_ARCH_ARM_64 diff --git a/runtime/src/iree/builtins/ukernel/arch/mmt4d_select_tile_arch.c b/runtime/src/iree/builtins/ukernel/arch/mmt4d_select_tile_arch.c new file mode 100644 index 000000000000..5d452a7dfb65 --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/mmt4d_select_tile_arch.c @@ -0,0 +1,19 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/builtins/ukernel/arch/mmt4d_select_tile_arch.h" + +#if defined(IREE_UKERNEL_ARCH_ARM_64) +#include "iree/builtins/ukernel/arch/arm_64/mmt4d_select_tile_arm_64.h" +#endif + +iree_ukernel_mmt4d_tile_func_t iree_ukernel_mmt4d_select_tile_func_arch( + const iree_ukernel_mmt4d_params_t* params) { +#if defined(IREE_UKERNEL_ARCH_ARM_64) + return iree_ukernel_mmt4d_select_tile_func_arm_64(params); +#endif + return 0; +} diff --git a/runtime/src/iree/builtins/ukernel/arch/mmt4d_select_tile_arch.h b/runtime/src/iree/builtins/ukernel/arch/mmt4d_select_tile_arch.h new file mode 100644 index 000000000000..3ebc4bf8393d --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/mmt4d_select_tile_arch.h @@ -0,0 +1,19 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_BUILTINS_UKERNEL_ARCH_MMT4D_SELECT_TILE_ARCH_H_ +#define IREE_BUILTINS_UKERNEL_ARCH_MMT4D_SELECT_TILE_ARCH_H_ + +#include "iree/builtins/ukernel/mmt4d_types.h" + +// Returns the architecture-specific tile function to use for the mmt4d with +// given params, or NULL if no suitable architecture-specific tile function +// exists for these params, in which case the caller may fall back to a generic +// tile function. +iree_ukernel_mmt4d_tile_func_t iree_ukernel_mmt4d_select_tile_func_arch( + const iree_ukernel_mmt4d_params_t* params); + +#endif // IREE_BUILTINS_UKERNEL_ARCH_MMT4D_SELECT_TILE_ARCH_H_ diff --git a/runtime/src/iree/builtins/ukernel/common.h b/runtime/src/iree/builtins/ukernel/common.h index 0ed9455f7bb7..f97c7a4000cf 100644 --- a/runtime/src/iree/builtins/ukernel/common.h +++ b/runtime/src/iree/builtins/ukernel/common.h @@ -34,37 +34,24 @@ // that can be substituted by the IREE compiler when producing the final // target-specific module. +// Include the build-system-generated configured header and use it as the only +// source of information about the target we're compiling against, as opposed to +// including iree/base/target_platform.h. +// +// For example, using IREE_UKERNEL_ARCH_ARM_64 (from arch/config.h) rather than +// IREE_ARCH_ARM_64 (from target_platform.h) means that we can control from a +// single place in the build system whether we enable ARM_64-specific code paths +// or stick to generic code. +#include "iree/builtins/ukernel/arch/config.h" + // We require that this header compile on bare-metal targets with no stdlib. -// These two headers are clean and do not include any other headers: +// These headers are clean: #include "iree/base/attributes.h" -#include "iree/base/target_platform.h" -#include "iree/schemas/cpu_data.h" #ifdef __cplusplus extern "C" { #endif // __cplusplus -//===----------------------------------------------------------------------===// -// Target architecture selection -//===----------------------------------------------------------------------===// -// The "generic" target is used if no other target is specified. All platforms -// should support the generic path and then optionally provide their own -// specializations as needed. The generic path can be forced by passing -// -DIREE_MMT4D_ARCH_GENERIC_32=1 or -DIREE_MMT4D_ARCH_GENERIC_64=1 to the -// compiler in addition to -DIREE_PLATFORM_GENERIC=1. - -#if defined(IREE_UKERNEL_ARCH_GENERIC_32) -#define IREE_UKERNEL_SIZE_TYPE int32_t -#elif defined(IREE_UKERNEL_ARCH_GENERIC_64) -#define IREE_UKERNEL_SIZE_TYPE int64_t -#elif defined(IREE_ARCH_ARM_64) -#define IREE_UKERNEL_ARCH_ARM_64 1 -#define IREE_UKERNEL_SIZE_TYPE int64_t -#else -#define IREE_UKERNEL_ARCH_GENERIC_64 1 -#define IREE_UKERNEL_SIZE_TYPE int64_t -#endif // IREE_ARCH_* - //===----------------------------------------------------------------------===// // Attributes and metadata //===----------------------------------------------------------------------===// @@ -116,11 +103,17 @@ typedef unsigned long long uint64_t; #endif // !INT8_MIN -// Use iree_mmt4d_size_t for all sizes that may need pointer width. +// Use iree_ukernel_ssize_t for all sizes that may need pointer width. // For any argument that is known to fit in a specific size prefer that to // ensure this code operates well on systems with small/weird widths (x32/ilp32, // etc). -typedef IREE_UKERNEL_SIZE_TYPE iree_ukernel_size_t; +#if IREE_UKERNEL_POINTER_SIZE == 4 +typedef int32_t iree_ukernel_ssize_t; +#elif IREE_UKERNEL_POINTER_SIZE == 8 +typedef int64_t iree_ukernel_ssize_t; +#else +#error Unexpected pointer size +#endif #ifdef __cplusplus } // extern "C" diff --git a/runtime/src/iree/builtins/ukernel/elementwise.h b/runtime/src/iree/builtins/ukernel/elementwise.h index 5d2b35b58df5..f4818d344e82 100644 --- a/runtime/src/iree/builtins/ukernel/elementwise.h +++ b/runtime/src/iree/builtins/ukernel/elementwise.h @@ -21,26 +21,26 @@ extern "C" { // It takes lhs, rhs, out buffers and size, returning 0 on success and !0 on // error. typedef int (*iree_ukernel_x32b_2d_func_t)( - const uint32_t* lhs, iree_ukernel_size_t lhs_offset, - iree_ukernel_size_t lhs_stride0, iree_ukernel_size_t lhs_stride1, - const uint32_t* rhs, iree_ukernel_size_t rhs_offset, - iree_ukernel_size_t rhs_stride0, iree_ukernel_size_t rhs_stride1, - uint32_t* out, iree_ukernel_size_t out_offset, - iree_ukernel_size_t out_stride0, iree_ukernel_size_t out_stride1, - iree_ukernel_size_t size0, iree_ukernel_size_t size1); + const uint32_t* lhs, iree_ukernel_ssize_t lhs_offset, + iree_ukernel_ssize_t lhs_stride0, iree_ukernel_ssize_t lhs_stride1, + const uint32_t* rhs, iree_ukernel_ssize_t rhs_offset, + iree_ukernel_ssize_t rhs_stride0, iree_ukernel_ssize_t rhs_stride1, + uint32_t* out, iree_ukernel_ssize_t out_offset, + iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1, + iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1); // Declares a binary 2d microkernel with the following signature: // int iree_ukernel_{category}_{opcode}_2d(...) // of function type iree_ukernel_{category}_2d_func_t. -#define DECLARE_UKERNEL_BINARY_2D(opcode, dtype, category) \ - IREE_UKERNEL_EXPORT int iree_ukernel_##category##_##opcode##_2d( \ - const dtype* lhs, iree_ukernel_size_t lhs_offset, \ - iree_ukernel_size_t lhs_stride0, iree_ukernel_size_t lhs_stride1, \ - const dtype* rhs, iree_ukernel_size_t rhs_offset, \ - iree_ukernel_size_t rhs_stride0, iree_ukernel_size_t rhs_stride1, \ - dtype* IREE_RESTRICT out, iree_ukernel_size_t out_offset, \ - iree_ukernel_size_t out_stride0, iree_ukernel_size_t out_stride1, \ - iree_ukernel_size_t size0, iree_ukernel_size_t size1) +#define DECLARE_UKERNEL_BINARY_2D(opcode, dtype, category) \ + IREE_UKERNEL_EXPORT int iree_ukernel_##category##_##opcode##_2d( \ + const dtype* lhs, iree_ukernel_ssize_t lhs_offset, \ + iree_ukernel_ssize_t lhs_stride0, iree_ukernel_ssize_t lhs_stride1, \ + const dtype* rhs, iree_ukernel_ssize_t rhs_offset, \ + iree_ukernel_ssize_t rhs_stride0, iree_ukernel_ssize_t rhs_stride1, \ + dtype* IREE_RESTRICT out, iree_ukernel_ssize_t out_offset, \ + iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1, \ + iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1) DECLARE_UKERNEL_BINARY_2D(addf, uint32_t, x32b); DECLARE_UKERNEL_BINARY_2D(addi, uint32_t, x32b); @@ -66,23 +66,23 @@ DECLARE_UKERNEL_BINARY_2D(xori, uint32_t, x32b); // It takes in, out buffers and size, returning 0 on success and !0 on // error. typedef int (*iree_ukernel_x32u_2d_func_t)( - const uint32_t* in, iree_ukernel_size_t in_offset, - iree_ukernel_size_t in_stride0, iree_ukernel_size_t in_stride1, - uint32_t* out, iree_ukernel_size_t out_offset, - iree_ukernel_size_t out_stride0, iree_ukernel_size_t out_stride1, - iree_ukernel_size_t size0, iree_ukernel_size_t size1); + const uint32_t* in, iree_ukernel_ssize_t in_offset, + iree_ukernel_ssize_t in_stride0, iree_ukernel_ssize_t in_stride1, + uint32_t* out, iree_ukernel_ssize_t out_offset, + iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1, + iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1); // Declares a binary 2d microkernel with the following signature: // int iree_ukernel_{category}_{opcode}_2d(...) // It takes lhs, rhs, out buffers and size, returning 0 on success and !0 on // error. -#define DECLARE_UKERNEL_UNARY_2D(opcode, dtype, category) \ - IREE_UKERNEL_EXPORT int iree_ukernel_##category##_##opcode##_2d( \ - const dtype* in, iree_ukernel_size_t in_offset, \ - iree_ukernel_size_t in_stride0, iree_ukernel_size_t in_stride1, \ - dtype* IREE_RESTRICT out, iree_ukernel_size_t out_offset, \ - iree_ukernel_size_t out_stride0, iree_ukernel_size_t out_stride1, \ - iree_ukernel_size_t size0, iree_ukernel_size_t size1) +#define DECLARE_UKERNEL_UNARY_2D(opcode, dtype, category) \ + IREE_UKERNEL_EXPORT int iree_ukernel_##category##_##opcode##_2d( \ + const dtype* in, iree_ukernel_ssize_t in_offset, \ + iree_ukernel_ssize_t in_stride0, iree_ukernel_ssize_t in_stride1, \ + dtype* IREE_RESTRICT out, iree_ukernel_ssize_t out_offset, \ + iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1, \ + iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1) DECLARE_UKERNEL_UNARY_2D(absf, uint32_t, x32u); DECLARE_UKERNEL_UNARY_2D(ceilf, uint32_t, x32u); diff --git a/runtime/src/iree/builtins/ukernel/elementwise_impl.c.inc b/runtime/src/iree/builtins/ukernel/elementwise_impl.c.inc index 338919d16775..3f4ee4024dc2 100644 --- a/runtime/src/iree/builtins/ukernel/elementwise_impl.c.inc +++ b/runtime/src/iree/builtins/ukernel/elementwise_impl.c.inc @@ -111,13 +111,13 @@ static inline int iree_ukernel_count_leading_zeros_u32(const uint32_t n) { // Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D. #define DISPATCH_UKERNEL_BINARY_2D(opcode, opcode_t, dtype, category) \ IREE_UKERNEL_EXPORT int iree_ukernel_##category##_##opcode##_2d( \ - const dtype* lhs, iree_ukernel_size_t lhs_offset, \ - iree_ukernel_size_t lhs_stride0, iree_ukernel_size_t lhs_stride1, \ - const dtype* rhs, iree_ukernel_size_t rhs_offset, \ - iree_ukernel_size_t rhs_stride0, iree_ukernel_size_t rhs_stride1, \ - dtype* IREE_RESTRICT out, iree_ukernel_size_t out_offset, \ - iree_ukernel_size_t out_stride0, iree_ukernel_size_t out_stride1, \ - iree_ukernel_size_t size0, iree_ukernel_size_t size1) { \ + const dtype* lhs, iree_ukernel_ssize_t lhs_offset, \ + iree_ukernel_ssize_t lhs_stride0, iree_ukernel_ssize_t lhs_stride1, \ + const dtype* rhs, iree_ukernel_ssize_t rhs_offset, \ + iree_ukernel_ssize_t rhs_stride0, iree_ukernel_ssize_t rhs_stride1, \ + dtype* IREE_RESTRICT out, iree_ukernel_ssize_t out_offset, \ + iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1, \ + iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1) { \ return iree_ukernel_generic_##category##_2d( \ opcode_t, lhs, lhs_offset, lhs_stride0, lhs_stride1, rhs, rhs_offset, \ rhs_stride0, rhs_stride1, out, out_offset, out_stride0, out_stride1, \ @@ -129,11 +129,11 @@ static inline int iree_ukernel_count_leading_zeros_u32(const uint32_t n) { // Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D. #define DISPATCH_UKERNEL_UNARY_2D(opcode, opcode_t, dtype, category) \ IREE_UKERNEL_EXPORT int iree_ukernel_##category##_##opcode##_2d( \ - const dtype* in, iree_ukernel_size_t in_offset, \ - iree_ukernel_size_t in_stride0, iree_ukernel_size_t in_stride1, \ - dtype* IREE_RESTRICT out, iree_ukernel_size_t out_offset, \ - iree_ukernel_size_t out_stride0, iree_ukernel_size_t out_stride1, \ - iree_ukernel_size_t size0, iree_ukernel_size_t size1) { \ + const dtype* in, iree_ukernel_ssize_t in_offset, \ + iree_ukernel_ssize_t in_stride0, iree_ukernel_ssize_t in_stride1, \ + dtype* IREE_RESTRICT out, iree_ukernel_ssize_t out_offset, \ + iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1, \ + iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1) { \ return iree_ukernel_generic_##category##_2d( \ opcode_t, in, in_offset, in_stride0, in_stride1, out, out_offset, \ out_stride0, out_stride1, size0, size1); \ @@ -242,20 +242,20 @@ static void iree_ukernel_generic_x32u_op(iree_ukernel_x32u_opcode_t opcode, static int iree_ukernel_generic_x32b_2d( iree_ukernel_x32b_opcode_t opcode, // LHS. - const uint32_t* lhs, iree_ukernel_size_t lhs_offset, - iree_ukernel_size_t lhs_stride0, iree_ukernel_size_t lhs_stride1, + const uint32_t* lhs, iree_ukernel_ssize_t lhs_offset, + iree_ukernel_ssize_t lhs_stride0, iree_ukernel_ssize_t lhs_stride1, // RHS - const uint32_t* rhs, iree_ukernel_size_t rhs_offset, - iree_ukernel_size_t rhs_stride0, iree_ukernel_size_t rhs_stride1, + const uint32_t* rhs, iree_ukernel_ssize_t rhs_offset, + iree_ukernel_ssize_t rhs_stride0, iree_ukernel_ssize_t rhs_stride1, // OUT. - uint32_t* IREE_RESTRICT out, iree_ukernel_size_t out_offset, - iree_ukernel_size_t out_stride0, iree_ukernel_size_t out_stride1, + uint32_t* IREE_RESTRICT out, iree_ukernel_ssize_t out_offset, + iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1, // Sizes. - iree_ukernel_size_t size0, iree_ukernel_size_t size1) { + iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1) { int result_code = 0; // TODO: Manually unroll to x4 to trigger vectorization. - for (iree_ukernel_size_t i = 0; i < size0; ++i) { - for (iree_ukernel_size_t j = 0; j < size1; ++j) { + for (iree_ukernel_ssize_t i = 0; i < size0; ++i) { + for (iree_ukernel_ssize_t j = 0; j < size1; ++j) { iree_ukernel_generic_x32b_op(opcode, &result_code, &lhs[i * lhs_stride0 + j * lhs_stride1], &rhs[i * rhs_stride0 + j * rhs_stride1], @@ -269,17 +269,17 @@ static int iree_ukernel_generic_x32b_2d( static int iree_ukernel_generic_x32u_2d( iree_ukernel_x32u_opcode_t opcode, // IN. - const uint32_t* in, iree_ukernel_size_t in_offset, - iree_ukernel_size_t in_stride0, iree_ukernel_size_t in_stride1, + const uint32_t* in, iree_ukernel_ssize_t in_offset, + iree_ukernel_ssize_t in_stride0, iree_ukernel_ssize_t in_stride1, // OUT. - uint32_t* IREE_RESTRICT out, iree_ukernel_size_t out_offset, - iree_ukernel_size_t out_stride0, iree_ukernel_size_t out_stride1, + uint32_t* IREE_RESTRICT out, iree_ukernel_ssize_t out_offset, + iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1, // Sizes. - iree_ukernel_size_t size0, iree_ukernel_size_t size1) { + iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1) { int result_code = 0; // TODO: Manually unroll to x4 to trigger vectorization. - for (iree_ukernel_size_t i = 0; i < size0; ++i) { - for (iree_ukernel_size_t j = 0; j < size1; ++j) { + for (iree_ukernel_ssize_t i = 0; i < size0; ++i) { + for (iree_ukernel_ssize_t j = 0; j < size1; ++j) { iree_ukernel_generic_x32u_op(opcode, &result_code, &in[i * in_stride0 + j * in_stride1], &out[i * out_stride0 + j * out_stride1]); diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.c b/runtime/src/iree/builtins/ukernel/mmt4d.c index 7ea816d66045..04f20a439e9a 100644 --- a/runtime/src/iree/builtins/ukernel/mmt4d.c +++ b/runtime/src/iree/builtins/ukernel/mmt4d.c @@ -6,57 +6,178 @@ #include "iree/builtins/ukernel/mmt4d.h" -#if defined(IREE_UKERNEL_ARCH_ARM_64) -#include "iree/builtins/ukernel/mmt4d_arm_64.h" -#endif +#include -#if defined(IREE_UKERNEL_ARCH_GENERIC_32) || \ - defined(IREE_UKERNEL_ARCH_GENERIC_64) -#include "iree/builtins/ukernel/mmt4d_generic.h" -#endif +#include "iree/builtins/ukernel/arch/mmt4d_select_tile_arch.h" +#include "iree/builtins/ukernel/mmt4d_select_tile_generic.h" -IREE_UKERNEL_EXPORT int iree_ukernel_mmt4d_f32f32f32( - const iree_ukernel_mmt4d_f32f32f32_params_t* params) { +#define OUTSIDE_UINT_RANGE(value, bits) (((value) < 0) || ((value) >> (bits))) + +static iree_ukernel_mmt4d_status_t iree_ukernel_mmt4d_validate( + const iree_ukernel_mmt4d_params_t* params) { if (params->flags & ~IREE_VMVX_MATMUL_FLAG_ACCUMULATE) { - return IREE_UKERNEL_MMT4D_ERROR_BAD_FLAGS; + return iree_ukernel_mmt4d_status_bad_flags; + } + switch (params->type) { + case iree_ukernel_mmt4d_type_f32f32f32: + case iree_ukernel_mmt4d_type_i8i8i32: + break; + default: + return iree_ukernel_mmt4d_status_bad_type; } + // Some implementations may wish to avoid supporting absurdly wide types. For + // instance, K is the innermost (i.e. hottest) loop bound, so some 32bit + // targets may benefit from K being int32, not int64. We still let K be of + // type int64 to be future-proof, as types are hard to change later. But we + // enforce a narrower range here, as we can always relax that later as needed. + if (OUTSIDE_UINT_RANGE(params->M, 31) || OUTSIDE_UINT_RANGE(params->M, 31) || + OUTSIDE_UINT_RANGE(params->K, 31) || OUTSIDE_UINT_RANGE(params->M0, 15) || + OUTSIDE_UINT_RANGE(params->N0, 15) || + OUTSIDE_UINT_RANGE(params->K0, 15)) { + return iree_ukernel_mmt4d_status_unsupported_huge_or_negative_dimension; + } + return iree_ukernel_mmt4d_status_ok; +} -#if defined(IREE_UKERNEL_ARCH_ARM_64) - return iree_ukernel_mmt4d_f32f32f32_arm_64(params); -#endif +// On success, *out_tile_func is the tile function to use to perform the mmt4d +// with the given *params. +static iree_ukernel_mmt4d_status_t iree_ukernel_mmt4d_select_tile_func( + const iree_ukernel_mmt4d_params_t* params, + iree_ukernel_mmt4d_tile_func_t* out_tile_func) { + iree_ukernel_mmt4d_tile_func_t arch_tile_func = + iree_ukernel_mmt4d_select_tile_func_arch(params); + if (arch_tile_func) { + *out_tile_func = arch_tile_func; + return iree_ukernel_mmt4d_status_ok; + } + return iree_ukernel_mmt4d_select_tile_func_generic(params, out_tile_func); +} -#if defined(IREE_UKERNEL_ARCH_GENERIC_32) || \ - defined(IREE_UKERNEL_ARCH_GENERIC_64) - return iree_ukernel_mmt4d_f32f32f32_generic(params); -#endif +// General mmt4d implementation, shared among all cases. The idea is that the +// only really performance-critical part is the inner-most loop, and that's +// handled by the tile_func passed as argument here. Sharing the outer loops +// across all cases is a roughly 2x code shrink compared to if we were +// emitting the whole loop nest for each case. +static void iree_ukernel_mmt4d_using_tile_func( + const iree_ukernel_mmt4d_params_t* params, + iree_ukernel_mmt4d_tile_func_t tile_func) { + const int32_t M = params->M; + const int32_t N = params->N; + const int32_t K = params->K; + const int16_t M0 = params->M0; + const int16_t N0 = params->N0; + const int16_t lhs_elem_size_log2 = + iree_ukernel_mmt4d_lhs_elem_size_log2(params->type); + const int16_t rhs_elem_size_log2 = + iree_ukernel_mmt4d_rhs_elem_size_log2(params->type); + const int16_t out_elem_size_log2 = + iree_ukernel_mmt4d_out_elem_size_log2(params->type); + char* out_tile_row = params->out_buffer; + const char* lhs_panel = params->lhs_buffer; + int32_t out_tile_size = (M0 * N0) << out_elem_size_log2; + iree_ukernel_ssize_t lhs_panel_stride = params->lhs_stride + << lhs_elem_size_log2; + iree_ukernel_ssize_t rhs_panel_stride = params->rhs_stride + << rhs_elem_size_log2; + iree_ukernel_ssize_t out_stride = params->out_stride << out_elem_size_log2; + for (int32_t i = 0; i < M; ++i) { + char* out_tile = out_tile_row; + const char* rhs_panel = params->rhs_buffer; + for (int32_t j = 0; j < N; ++j) { + tile_func(out_tile, lhs_panel, rhs_panel, K, params->flags, params); + out_tile += out_tile_size; + rhs_panel += rhs_panel_stride; + } + out_tile_row += out_stride; + lhs_panel += lhs_panel_stride; + } +} - return IREE_UKERNEL_MMT4D_ERROR_UNIMPLEMENTED; +// A memset implementation that we can use here, as we can't #include +// as that brings in . Special-cased for byte value 0. +void iree_ukernel_memset_zero(void* buf, iree_ukernel_ssize_t n) { + // No need for memset builtins: this naive loop is already transformed into a + // memset by both clang and gcc on ARM64. As __builtin_memset_inline requires + // a compile-time-constant size, it would require writing more complex code, + // which could actually prevent the optimization matching it as a single + // memset! + for (iree_ukernel_ssize_t i = 0; i < n; ++i) ((char*)buf)[i] = 0; } -IREE_UKERNEL_EXPORT int iree_ukernel_mmt4d_i8i8i32( - const iree_ukernel_mmt4d_i8i8i32_params_t* params) { - if (params->flags & ~IREE_VMVX_MATMUL_FLAG_ACCUMULATE) { - return IREE_UKERNEL_MMT4D_ERROR_BAD_FLAGS; +// Helper for early-return path when K==0 and we just need to clear the output. +static void iree_ukernel_mmt4d_zero_out( + const iree_ukernel_mmt4d_params_t* params) { + iree_ukernel_ssize_t contiguous_size = + params->N * params->M0 * params->N0 + << iree_ukernel_mmt4d_out_elem_size_log2(params->type); + iree_ukernel_ssize_t stride = + params->out_stride << iree_ukernel_mmt4d_out_elem_size_log2(params->type); + char* out_ptr = params->out_buffer; + for (iree_ukernel_ssize_t i = 0; i < params->M; ++i) { + iree_ukernel_memset_zero(out_ptr, contiguous_size); + out_ptr += stride; + } +} + +// Early-return code paths, including trivial or near-trivial cases (when one +// of the dimensions is 0) and in the future, hardware ports that specialize +// the entire loop nest. +// The value |true| is written to the out-param |*done| if an early-return path +// was taken and the mmt4d work is already done. +static iree_ukernel_mmt4d_status_t iree_ukernel_mmt4d_early( + const iree_ukernel_mmt4d_params_t* params, bool* done) { + // Trivial cases + if (params->M == 0 || params->N == 0) { + *done = true; + return iree_ukernel_mmt4d_status_ok; } + if (params->K == 0) { + if (params->flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE) { + // Nothing to do! + } else { + iree_ukernel_mmt4d_zero_out(params); + } + *done = true; + return iree_ukernel_mmt4d_status_ok; + } + + // Targets that want to specialize the entire loop nest can do so here. + + return iree_ukernel_mmt4d_status_ok; +} -#if defined(IREE_UKERNEL_ARCH_ARM_64) - return iree_ukernel_mmt4d_i8i8i32_arm_64(params); -#endif +IREE_UKERNEL_EXPORT iree_ukernel_mmt4d_status_t +iree_ukernel_mmt4d(const iree_ukernel_mmt4d_params_t* params) { + // Validate params. + IREE_UKERNEL_MMT4D_RETURN_IF_ERROR(iree_ukernel_mmt4d_validate(params)); -#if defined(IREE_UKERNEL_ARCH_GENERIC_32) || \ - defined(IREE_UKERNEL_ARCH_GENERIC_64) - return iree_ukernel_mmt4d_i8i8i32_generic(params); -#endif + // Maybe handle this mmt4d "early", without needing to select a tile_func. + // Typical cases include trivial cases (e.g. when params->K == 0) and hardware + // targets that want to handle the entire loop nest in target-specific code. + bool done = false; + IREE_UKERNEL_MMT4D_RETURN_IF_ERROR(iree_ukernel_mmt4d_early(params, &done)); + if (done) return iree_ukernel_mmt4d_status_ok; - return IREE_UKERNEL_MMT4D_ERROR_UNIMPLEMENTED; + // Select a target-specific tile_func (inner loop on K, computing one M0xN0 + // tile) and use that with generic outer loops. + iree_ukernel_mmt4d_tile_func_t tile_func; + IREE_UKERNEL_MMT4D_RETURN_IF_ERROR( + iree_ukernel_mmt4d_select_tile_func(params, &tile_func)); + iree_ukernel_mmt4d_using_tile_func(params, tile_func); + return iree_ukernel_mmt4d_status_ok; } -const char* iree_ukernel_mmt4d_error_message(int retcode) { - switch (retcode) { - case IREE_UKERNEL_MMT4D_ERROR_UNIMPLEMENTED: - return "hit unimplemented code path in mmt4d"; - case IREE_UKERNEL_MMT4D_ERROR_BAD_FLAGS: +IREE_UKERNEL_EXPORT const char* iree_ukernel_mmt4d_status_message( + iree_ukernel_mmt4d_status_t status) { + switch (status) { + case iree_ukernel_mmt4d_status_bad_flags: return "bad mmt4d flags"; + case iree_ukernel_mmt4d_status_bad_type: + return "bad mmt4d type enum"; + case iree_ukernel_mmt4d_status_unsupported_huge_or_negative_dimension: + return "unsupported huge or negative size in mmt4d"; + case iree_ukernel_mmt4d_status_unsupported_generic_tile_size: + return "tile size too large for the generic tile implementation"; default: return "unknown"; } diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.h b/runtime/src/iree/builtins/ukernel/mmt4d.h index 407374a88a95..51e512a09451 100644 --- a/runtime/src/iree/builtins/ukernel/mmt4d.h +++ b/runtime/src/iree/builtins/ukernel/mmt4d.h @@ -7,61 +7,19 @@ #ifndef IREE_BUILTINS_UKERNEL_MMT4D_H_ #define IREE_BUILTINS_UKERNEL_MMT4D_H_ -#include "iree/builtins/ukernel/common.h" +#include "iree/builtins/ukernel/mmt4d_types.h" #ifdef __cplusplus extern "C" { #endif // __cplusplus -struct iree_ukernel_mmt4d_f32f32f32_params_t { - const float* lhs_buffer; - const float* rhs_buffer; - float* out_buffer; - iree_ukernel_size_t lhs_stride; - iree_ukernel_size_t rhs_stride; - iree_ukernel_size_t out_stride; - iree_ukernel_size_t M; - iree_ukernel_size_t N; - iree_ukernel_size_t K; - int32_t M0; - int32_t N0; - int32_t K0; - uint32_t flags; -}; +// Main entry point. +IREE_UKERNEL_EXPORT iree_ukernel_mmt4d_status_t +iree_ukernel_mmt4d(const iree_ukernel_mmt4d_params_t* params); -struct iree_ukernel_mmt4d_i8i8i32_params_t { - const int8_t* lhs_buffer; - const int8_t* rhs_buffer; - int32_t* out_buffer; - iree_ukernel_size_t lhs_stride; - iree_ukernel_size_t rhs_stride; - iree_ukernel_size_t out_stride; - iree_ukernel_size_t M; - iree_ukernel_size_t N; - iree_ukernel_size_t K; - int32_t M0; - int32_t N0; - int32_t K0; - uint32_t flags; -}; - -typedef struct iree_ukernel_mmt4d_f32f32f32_params_t - iree_ukernel_mmt4d_f32f32f32_params_t; -typedef struct iree_ukernel_mmt4d_i8i8i32_params_t - iree_ukernel_mmt4d_i8i8i32_params_t; - -#define IREE_UKERNEL_MMT4D_ERROR_UNIMPLEMENTED 1 -#define IREE_UKERNEL_MMT4D_ERROR_BAD_FLAGS 2 - -// TODO: move these flags to a header file shared with compiler/. -#define IREE_VMVX_MATMUL_FLAG_ACCUMULATE 1 - -IREE_UKERNEL_EXPORT int iree_ukernel_mmt4d_f32f32f32( - const iree_ukernel_mmt4d_f32f32f32_params_t* params); -IREE_UKERNEL_EXPORT int iree_ukernel_mmt4d_i8i8i32( - const iree_ukernel_mmt4d_i8i8i32_params_t* params); - -IREE_UKERNEL_EXPORT const char* iree_ukernel_mmt4d_error_message(int retcode); +// Convert a status code to a human-readable string. +IREE_UKERNEL_EXPORT const char* iree_ukernel_mmt4d_status_message( + iree_ukernel_mmt4d_status_t status); #ifdef __cplusplus } // extern "C" diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_arm_64.c b/runtime/src/iree/builtins/ukernel/mmt4d_arm_64.c deleted file mode 100644 index ef077b6f67df..000000000000 --- a/runtime/src/iree/builtins/ukernel/mmt4d_arm_64.c +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/builtins/ukernel/mmt4d_arm_64.h" - -// TODO: once actual ARM64 code is implemented, we shouldn't need this anymore -#include "iree/builtins/ukernel/mmt4d_generic.h" - -#if defined(IREE_UKERNEL_ARCH_ARM_64) - -int iree_ukernel_mmt4d_f32f32f32_arm_64( - const iree_ukernel_mmt4d_f32f32f32_params_t* params) { - // TODO: implement actual arm assembly kernels instead of calling _generic. - return iree_ukernel_mmt4d_f32f32f32_generic(params); -} - -int iree_ukernel_mmt4d_i8i8i32_arm_64( - const iree_ukernel_mmt4d_i8i8i32_params_t* params) { - // TODO: implement actual arm assembly kernels instead of calling _generic. - return iree_ukernel_mmt4d_i8i8i32_generic(params); -} - -#endif // IREE_UKERNEL_ARCH_ARM_64 diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_arm_64.h b/runtime/src/iree/builtins/ukernel/mmt4d_arm_64.h deleted file mode 100644 index 1ea08fa575d3..000000000000 --- a/runtime/src/iree/builtins/ukernel/mmt4d_arm_64.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_BUILTINS_UKERNEL_MMT4D_ARM_64_H_ -#define IREE_BUILTINS_UKERNEL_MMT4D_ARM_64_H_ - -#include "iree/builtins/ukernel/mmt4d.h" - -#if defined(IREE_UKERNEL_ARCH_ARM_64) - -int iree_ukernel_mmt4d_f32f32f32_arm_64( - const iree_ukernel_mmt4d_f32f32f32_params_t* params); -int iree_ukernel_mmt4d_i8i8i32_arm_64( - const iree_ukernel_mmt4d_i8i8i32_params_t* params); - -#endif // IREE_UKERNEL_ARCH_ARM_64 - -#endif // IREE_BUILTINS_UKERNEL_MMT4D_ARM_64_H_ diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_generic.c b/runtime/src/iree/builtins/ukernel/mmt4d_generic.c deleted file mode 100644 index cc9eeb40e8b8..000000000000 --- a/runtime/src/iree/builtins/ukernel/mmt4d_generic.c +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/builtins/ukernel/mmt4d_generic.h" - -#include - -int iree_ukernel_mmt4d_f32f32f32_generic( - const iree_ukernel_mmt4d_f32f32f32_params_t* params) { - bool accumulate = params->flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE; - iree_ukernel_size_t lhs_tile_size = params->M0 * params->K0; - iree_ukernel_size_t rhs_tile_size = params->N0 * params->K0; - iree_ukernel_size_t out_tile_size = params->M0 * params->N0; - for (iree_ukernel_size_t i = 0; i < params->M; ++i) { - for (iree_ukernel_size_t j = 0; j < params->N; ++j) { - float* out_tile_ptr = - params->out_buffer + i * params->out_stride + j * out_tile_size; - const float* lhs_panel_ptr = params->lhs_buffer + i * params->lhs_stride; - const float* rhs_panel_ptr = params->rhs_buffer + j * params->rhs_stride; - for (iree_ukernel_size_t i0 = 0; i0 < params->M0; ++i0) { - for (iree_ukernel_size_t j0 = 0; j0 < params->N0; ++j0) { - const float* lhs_tile_ptr = lhs_panel_ptr; - const float* rhs_tile_ptr = rhs_panel_ptr; - float* out_ptr = out_tile_ptr + i0 * params->N0 + j0; - float acc = accumulate ? *out_ptr : 0.f; - for (iree_ukernel_size_t k = 0; k < params->K; ++k) { - for (iree_ukernel_size_t k0 = 0; k0 < params->K0; ++k0) { - float lhs_val = lhs_tile_ptr[i0 * params->K0 + k0]; - float rhs_val = rhs_tile_ptr[j0 * params->K0 + k0]; - acc += lhs_val * rhs_val; - } - lhs_tile_ptr += lhs_tile_size; - rhs_tile_ptr += rhs_tile_size; - } - *out_ptr = acc; - } - } - } - } - return 0; -} - -int iree_ukernel_mmt4d_i8i8i32_generic( - const iree_ukernel_mmt4d_i8i8i32_params_t* params) { - bool accumulate = params->flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE; - iree_ukernel_size_t lhs_tile_size = params->M0 * params->K0; - iree_ukernel_size_t rhs_tile_size = params->N0 * params->K0; - iree_ukernel_size_t out_tile_size = params->M0 * params->N0; - for (iree_ukernel_size_t i = 0; i < params->M; ++i) { - for (iree_ukernel_size_t j = 0; j < params->N; ++j) { - int32_t* out_tile_ptr = - params->out_buffer + i * params->out_stride + j * out_tile_size; - const int8_t* lhs_panel_ptr = params->lhs_buffer + i * params->lhs_stride; - const int8_t* rhs_panel_ptr = params->rhs_buffer + j * params->rhs_stride; - for (iree_ukernel_size_t i0 = 0; i0 < params->M0; ++i0) { - for (iree_ukernel_size_t j0 = 0; j0 < params->N0; ++j0) { - const int8_t* lhs_tile_ptr = lhs_panel_ptr; - const int8_t* rhs_tile_ptr = rhs_panel_ptr; - int32_t* out_ptr = out_tile_ptr + i0 * params->N0 + j0; - int32_t acc = accumulate ? *out_ptr : 0; - for (iree_ukernel_size_t k = 0; k < params->K; ++k) { - for (iree_ukernel_size_t k0 = 0; k0 < params->K0; ++k0) { - // C's implicit promotion to int saves skin, but let's be explicit - int32_t lhs_val_int32 = lhs_tile_ptr[i0 * params->K0 + k0]; - int32_t rhs_val_int32 = rhs_tile_ptr[j0 * params->K0 + k0]; - acc += lhs_val_int32 * rhs_val_int32; - } - lhs_tile_ptr += lhs_tile_size; - rhs_tile_ptr += rhs_tile_size; - } - *out_ptr = acc; - } - } - } - } - return 0; -} diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_generic.h b/runtime/src/iree/builtins/ukernel/mmt4d_generic.h deleted file mode 100644 index 5dc0b5de5dab..000000000000 --- a/runtime/src/iree/builtins/ukernel/mmt4d_generic.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_BUILTINS_UKERNEL_MMT4D_GENERIC_H_ -#define IREE_BUILTINS_UKERNEL_MMT4D_GENERIC_H_ - -#include "iree/builtins/ukernel/mmt4d.h" - -int iree_ukernel_mmt4d_f32f32f32_generic( - const iree_ukernel_mmt4d_f32f32f32_params_t* params); -int iree_ukernel_mmt4d_i8i8i32_generic( - const iree_ukernel_mmt4d_i8i8i32_params_t* params); - -#endif // IREE_BUILTINS_UKERNEL_MMT4D_GENERIC_H_ diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_select_tile_generic.c b/runtime/src/iree/builtins/ukernel/mmt4d_select_tile_generic.c new file mode 100644 index 000000000000..56ec4d85a97f --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/mmt4d_select_tile_generic.c @@ -0,0 +1,120 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/builtins/ukernel/mmt4d_select_tile_generic.h" + +// In order to be helpful as a reference for future architecture-specific +// kernels, the generic kernels here are structured like an actual optimized +// kernel, using an "accumulator tile" that in this case is a stack array +// (which would become a group of SIMD registers in an actual optimized kernel). +// The downside of this approach is that we have to set a fixed max size for +// the accumulator tile, but for now all known cases are comfortably far below +// where trouble would happen. For reference: +// - On ARM NEON, the entire register space is 512 bytes, so the accumulator +// tile is less than that, typically 256 to 384 bytes. +// - On ARM SME, we will be working with an accumulator tile as large as 4096 +// bytes (IIUC). +// - The smallest stack frame size limit that we know we may have to deal with +// on certain targets is 16 kilobytes. +// The size or architecture-specific tiles is relevant here because this +// generic code is what will be run as a fallback if the device is found not to +// support the CPU feature that the tile sizes were picked to target. +enum { iree_ukernel_mmt4d_tile_generic_max_bytes = 4096 }; + +// Generic implementation of matmul tile, i8*i8->i32 case. +static void iree_ukernel_mmt4d_tile_i8i8i32_generic( + void* out_tile_untyped, const void* lhs_panel_untyped, + const void* rhs_panel_untyped, int32_t K, uint32_t flags, + const iree_ukernel_mmt4d_params_t* params) { + int32_t* out_tile = out_tile_untyped; + const int8_t* lhs_panel = lhs_panel_untyped; + const int8_t* rhs_panel = rhs_panel_untyped; + int16_t M0 = params->M0; + int16_t N0 = params->N0; + int16_t K0 = params->K0; + // Initialize the local accumulator tile. + int32_t acc[iree_ukernel_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)]; + if (flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE) { + for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i]; + } else { + for (int i = 0; i < M0 * N0; ++i) acc[i] = 0; + } + // Accumulation loop. + for (iree_ukernel_ssize_t k = 0; k < K; ++k) { + for (iree_ukernel_ssize_t i0 = 0; i0 < M0; ++i0) { + for (iree_ukernel_ssize_t j0 = 0; j0 < N0; ++j0) { + for (iree_ukernel_ssize_t k0 = 0; k0 < K0; ++k0) { + int32_t lhs_val_int32 = lhs_panel[i0 * K0 + k0]; + int32_t rhs_val_int32 = rhs_panel[j0 * K0 + k0]; + acc[i0 * N0 + j0] += lhs_val_int32 * rhs_val_int32; + } + } + } + lhs_panel += M0 * K0; + rhs_panel += N0 * K0; + } + // Store the local accumulator tile to the destination. + for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i]; +} + +// Generic implementation of matmul tile, f32*f32->f32 case. +static void iree_ukernel_mmt4d_tile_f32f32f32_generic( + void* out_tile_untyped, const void* lhs_panel_untyped, + const void* rhs_panel_untyped, int32_t K, uint32_t flags, + const iree_ukernel_mmt4d_params_t* params) { + float* out_tile = out_tile_untyped; + const float* lhs_panel = lhs_panel_untyped; + const float* rhs_panel = rhs_panel_untyped; + int16_t M0 = params->M0; + int16_t N0 = params->N0; + int16_t K0 = params->K0; + // Initialize the local accumulator tile. + float acc[iree_ukernel_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)]; + if (flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE) { + for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i]; + } else { + for (int i = 0; i < M0 * N0; ++i) acc[i] = 0; + } + // Accumulation loop. + for (iree_ukernel_ssize_t k = 0; k < K; ++k) { + for (iree_ukernel_ssize_t i0 = 0; i0 < M0; ++i0) { + for (iree_ukernel_ssize_t j0 = 0; j0 < N0; ++j0) { + for (iree_ukernel_ssize_t k0 = 0; k0 < K0; ++k0) { + float lhs_val = lhs_panel[i0 * K0 + k0]; + float rhs_val = rhs_panel[j0 * K0 + k0]; + acc[i0 * N0 + j0] += lhs_val * rhs_val; + } + } + } + lhs_panel += M0 * K0; + rhs_panel += N0 * K0; + } + // Store the local accumulator tile to the destination. + for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i]; +} + +// Generic implementation of matmul tile +iree_ukernel_mmt4d_status_t iree_ukernel_mmt4d_select_tile_func_generic( + const iree_ukernel_mmt4d_params_t* params, + iree_ukernel_mmt4d_tile_func_t* out_tile_func) { + int tile_elems = params->M0 * params->N0; + int tile_bytes = tile_elems + << iree_ukernel_mmt4d_out_elem_size_log2(params->type); + if (tile_bytes > iree_ukernel_mmt4d_tile_generic_max_bytes) { + return iree_ukernel_mmt4d_status_unsupported_generic_tile_size; + } + switch (params->type) { + case iree_ukernel_mmt4d_type_f32f32f32: + *out_tile_func = iree_ukernel_mmt4d_tile_f32f32f32_generic; + return iree_ukernel_mmt4d_status_ok; + case iree_ukernel_mmt4d_type_i8i8i32: + *out_tile_func = iree_ukernel_mmt4d_tile_i8i8i32_generic; + return iree_ukernel_mmt4d_status_ok; + default: + // shouldn't happen, validated earlier. + return iree_ukernel_mmt4d_status_bad_type; + } +} diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_select_tile_generic.h b/runtime/src/iree/builtins/ukernel/mmt4d_select_tile_generic.h new file mode 100644 index 000000000000..685d2b7d77e1 --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/mmt4d_select_tile_generic.h @@ -0,0 +1,19 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_BUILTINS_UKERNEL_MMT4D_GENERIC_H_ +#define IREE_BUILTINS_UKERNEL_MMT4D_GENERIC_H_ + +#include "iree/builtins/ukernel/mmt4d_types.h" + +// On success, *out_tile_func is the generic tile function to use to perform the +// mmt4d with the given *params. The caller may want to first try to get an +// optimized architecture-specific tile function before falling back on this. +iree_ukernel_mmt4d_status_t iree_ukernel_mmt4d_select_tile_func_generic( + const iree_ukernel_mmt4d_params_t* params, + iree_ukernel_mmt4d_tile_func_t* out_tile_func); + +#endif // IREE_BUILTINS_UKERNEL_MMT4D_GENERIC_H_ diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_types.h b/runtime/src/iree/builtins/ukernel/mmt4d_types.h new file mode 100644 index 000000000000..4ebb779a4bc5 --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/mmt4d_types.h @@ -0,0 +1,131 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_BUILTINS_UKERNEL_MMT4D_TYPES_H_ +#define IREE_BUILTINS_UKERNEL_MMT4D_TYPES_H_ + +#include "iree/builtins/ukernel/common.h" + +// Supported combinations of data types (order: LHS, RHS, OUT). +enum iree_ukernel_mmt4d_type_t { + iree_ukernel_mmt4d_type_none = 0, + iree_ukernel_mmt4d_type_f32f32f32, + iree_ukernel_mmt4d_type_i8i8i32, +}; + +typedef enum iree_ukernel_mmt4d_type_t iree_ukernel_mmt4d_type_t; + +// Parameters for a mmt4d operation. +struct iree_ukernel_mmt4d_params_t { + iree_ukernel_mmt4d_type_t type; + uint32_t flags; + const void* lhs_buffer; + const void* rhs_buffer; + void* out_buffer; + iree_ukernel_ssize_t lhs_stride; + iree_ukernel_ssize_t rhs_stride; + iree_ukernel_ssize_t out_stride; + iree_ukernel_ssize_t M; + iree_ukernel_ssize_t N; + iree_ukernel_ssize_t K; + int32_t M0; + int32_t N0; + int32_t K0; + const uint64_t* cpu_data; +}; + +typedef struct iree_ukernel_mmt4d_params_t iree_ukernel_mmt4d_params_t; + +// Status codes returned by a mmt4d operation. +enum iree_ukernel_mmt4d_status_t { + iree_ukernel_mmt4d_status_ok = 0, + iree_ukernel_mmt4d_status_bad_type, + iree_ukernel_mmt4d_status_bad_flags, + iree_ukernel_mmt4d_status_unsupported_huge_or_negative_dimension, + iree_ukernel_mmt4d_status_unsupported_generic_tile_size, +}; + +typedef enum iree_ukernel_mmt4d_status_t iree_ukernel_mmt4d_status_t; + +// TODO: move these flags to a header file shared with compiler/. +#define IREE_VMVX_MATMUL_FLAG_ACCUMULATE 1 + +#define IREE_UKERNEL_MMT4D_RETURN_IF_ERROR(X) \ + do { \ + iree_ukernel_mmt4d_status_t status = (X); \ + if (status != iree_ukernel_mmt4d_status_ok) { \ + return status; \ + } \ + } while (0) + +// Function pointer type for tile functions, i.e. typically architecture +// specific functions computing one M0xN0 tile of the output matrix, i.e. +// the inner-most loop of the matmul, i.e. the thing that we should actually +// be calling "micro kernel" except that the name is already taken by the +// higher-level builtin name. +// +// The 'params' argument is only used by generic kernels. Actual optimized +// kernels are already specialized for a given tile shape (M0xN0xK0), so the +// five first arguments here are the only information that they need. Not having +// to address 'params' struct fields in the middle of assembly kernels is +// good, because it's hard to get the struct field offsets right in assembly +// and keep that in sync with future struct changes. +typedef void (*iree_ukernel_mmt4d_tile_func_t)( + void* /*out_tile*/, const void* /*lhs_panel*/, const void* /*rhs_panel*/, + int32_t /*K*/, uint32_t /*flags*/, + const iree_ukernel_mmt4d_params_t* /*params*/); + +// Tile kernel declarations. Prototype matches iree_ukernel_mmt4d_tile_func_t. +#define IREE_UKERNEL_MMT4D_TILE_FUNC_DECL(NAME) \ + void NAME(void* out_tile, const void* lhs_panel, const void* rhs_panel, \ + int32_t K, uint32_t flags, \ + const iree_ukernel_mmt4d_params_t* params); + +// Log2 of size of LHS matrix element type, e.g. f32 --> size=4 --> log2=2 +static inline int iree_ukernel_mmt4d_lhs_elem_size_log2( + iree_ukernel_mmt4d_type_t type) { + switch (type) { + case iree_ukernel_mmt4d_type_f32f32f32: + return 2; + default: + return 0; + } +} + +static inline int iree_ukernel_mmt4d_lhs_elem_size( + iree_ukernel_mmt4d_type_t type) { + return 1 << iree_ukernel_mmt4d_lhs_elem_size_log2(type); +} + +// Log2 of size of RHS matrix element type, e.g. f32 --> size=4 --> log2=2 +static inline int iree_ukernel_mmt4d_rhs_elem_size_log2( + iree_ukernel_mmt4d_type_t type) { + return iree_ukernel_mmt4d_lhs_elem_size_log2(type); // for now it's the same +} + +static inline int iree_ukernel_mmt4d_rhs_elem_size( + iree_ukernel_mmt4d_type_t type) { + return 1 << iree_ukernel_mmt4d_rhs_elem_size_log2(type); +} + +// Log2 of size of OUT matrix element type, e.g. f32 --> size=4 --> log2=2 +static inline int iree_ukernel_mmt4d_out_elem_size_log2( + iree_ukernel_mmt4d_type_t type) { + switch (type) { + case iree_ukernel_mmt4d_type_f32f32f32: + case iree_ukernel_mmt4d_type_i8i8i32: + return 2; + default: + return 0; + } +} + +static inline int iree_ukernel_mmt4d_out_elem_size( + iree_ukernel_mmt4d_type_t type) { + return 1 << iree_ukernel_mmt4d_out_elem_size_log2(type); +} + +#endif // IREE_BUILTINS_UKERNEL_MMT4D_TYPES_H_ diff --git a/runtime/src/iree/builtins/ukernel/tools/BUILD b/runtime/src/iree/builtins/ukernel/tools/BUILD index ac1940a63d64..21c4c8733ad2 100644 --- a/runtime/src/iree/builtins/ukernel/tools/BUILD +++ b/runtime/src/iree/builtins/ukernel/tools/BUILD @@ -13,11 +13,24 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "mmt4d_test_utils", + srcs = ["mmt4d_test_utils.cc"], + hdrs = ["mmt4d_test_utils.h"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/builtins/ukernel:types", + "//runtime/src/iree/schemas:cpu_data", + ], +) + cc_binary_benchmark( name = "mmt4d_benchmark", srcs = ["mmt4d_benchmark.c"], deps = [ + ":mmt4d_test_utils", "//runtime/src/iree/base", + "//runtime/src/iree/base/internal:cpu", "//runtime/src/iree/base/internal:flags", "//runtime/src/iree/builtins/ukernel", "//runtime/src/iree/testing:benchmark", @@ -28,10 +41,11 @@ iree_runtime_cc_test( name = "mmt4d_test", srcs = ["mmt4d_test.cc"], deps = [ + ":mmt4d_test_utils", "//runtime/src/iree/base", + "//runtime/src/iree/base/internal:cpu", "//runtime/src/iree/base/internal:flags", "//runtime/src/iree/builtins/ukernel", "//runtime/src/iree/testing:gtest", - "//runtime/src/iree/testing:gtest_main", ], ) diff --git a/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt index 3e8f6d4f9961..4b4e4558a0bf 100644 --- a/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt +++ b/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt @@ -10,13 +10,29 @@ iree_add_all_subdirs() +iree_cc_library( + NAME + mmt4d_test_utils + HDRS + "mmt4d_test_utils.h" + SRCS + "mmt4d_test_utils.cc" + DEPS + iree::base + iree::builtins::ukernel::types + iree::schemas::cpu_data + PUBLIC +) + iree_cc_binary_benchmark( NAME mmt4d_benchmark SRCS "mmt4d_benchmark.c" DEPS + ::mmt4d_test_utils iree::base + iree::base::internal::cpu iree::base::internal::flags iree::builtins::ukernel iree::testing::benchmark @@ -29,11 +45,12 @@ iree_cc_test( SRCS "mmt4d_test.cc" DEPS + ::mmt4d_test_utils iree::base + iree::base::internal::cpu iree::base::internal::flags iree::builtins::ukernel iree::testing::gtest - iree::testing::gtest_main ) ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c index ee16afd31f1a..b9b6d4e9e517 100644 --- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c +++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c @@ -4,39 +4,164 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// THIS IS STILL JUST A PLACEHOLDER - NOT AN ACTUAL TEST YET. +// clang-format off +#include // include before ukernel/common.h to keep standard types +// clang-format on -#include #include #include #include "iree/base/api.h" +#include "iree/base/internal/cpu.h" #include "iree/base/internal/flags.h" #include "iree/builtins/ukernel/mmt4d.h" +#include "iree/builtins/ukernel/tools/mmt4d_test_utils.h" #include "iree/testing/benchmark.h" -// Example flag; not really useful: -IREE_FLAG(int32_t, batch_count, 64, "Ops to run per benchmark iteration."); +IREE_FLAG(int32_t, batch_count, 1000, "Ops to run per benchmark iteration."); +IREE_FLAG(int32_t, m_size, 1, + "M-dimension of mmt4d ops. The overall number of rows of the " + "accumulator is that times the M0 tile size."); +IREE_FLAG(int32_t, n_size, 1, + "N-dimension of mmt4d ops. The overall number of columns of the " + "accumulator is that times the N0 tile size."); +IREE_FLAG( + int32_t, k_size, 256, + "K-dimension of mmt4d ops. That's the number of iterations of the inner " + "loop. The overall accumulation depth is that times the K0 tile size."); +IREE_FLAG(bool, accumulate, false, + "Whether the kernel should accumulate into the existing accumulator " + "tile values, or zero the accumulator tile."); -static iree_status_t iree_mmt4d_example_matmul_f32_benchmark( +struct iree_mmt4d_benchmark_user_data_t { + iree_ukernel_mmt4d_type_t type; + int M0; + int N0; + int K0; + const uint64_t* cpu_data; +}; + +typedef struct iree_mmt4d_benchmark_user_data_t + iree_mmt4d_benchmark_user_data_t; + +static iree_status_t iree_mmt4d_benchmark( const iree_benchmark_def_t* benchmark_def, iree_benchmark_state_t* benchmark_state) { + const iree_mmt4d_benchmark_user_data_t* user_data = benchmark_def->user_data; + iree_ukernel_mmt4d_params_t params; + memset(¶ms, 0, sizeof params); + params.type = user_data->type; + params.flags = FLAG_accumulate ? IREE_VMVX_MATMUL_FLAG_ACCUMULATE : 0; + params.M = FLAG_m_size; + params.N = FLAG_n_size; + params.K = FLAG_k_size; + params.M0 = user_data->M0; + params.N0 = user_data->N0; + params.K0 = user_data->K0; + params.cpu_data = user_data->cpu_data; + params.lhs_stride = params.K * params.M0 * params.K0; + params.rhs_stride = params.K * params.N0 * params.K0; + params.out_stride = params.N * params.M0 * params.N0; + iree_ukernel_ssize_t lhs_buffer_size = + iree_ukernel_mmt4d_lhs_buffer_size(¶ms); + iree_ukernel_ssize_t rhs_buffer_size = + iree_ukernel_mmt4d_rhs_buffer_size(¶ms); + iree_ukernel_ssize_t out_buffer_size = + iree_ukernel_mmt4d_out_buffer_size(¶ms); + void* lhs_buffer = malloc(lhs_buffer_size); + void* rhs_buffer = malloc(lhs_buffer_size); + void* out_buffer = malloc(lhs_buffer_size); + iree_mmt4d_scalar_type_t lhs_type = iree_ukernel_mmt4d_lhs_type(¶ms); + iree_mmt4d_scalar_type_t rhs_type = iree_ukernel_mmt4d_rhs_type(¶ms); + iree_mmt4d_scalar_type_t out_type = iree_ukernel_mmt4d_out_type(¶ms); + iree_mmt4d_test_random_engine_t* engine = + iree_mmt4d_test_random_engine_create(); + // It's just about plausible that on some platform, for some number type, + // performance might be different on zero buffers vs random buffers. But it + // shouldn't matter that we recreate the random engine every time, getting + // the same random values again. + write_random_buffer(lhs_buffer, lhs_buffer_size, lhs_type, engine); + write_random_buffer(rhs_buffer, rhs_buffer_size, rhs_type, engine); + write_random_buffer(out_buffer, out_buffer_size, out_type, engine); + iree_mmt4d_test_random_engine_destroy(engine); + params.lhs_buffer = lhs_buffer; + params.rhs_buffer = rhs_buffer; + params.out_buffer = out_buffer; + int64_t total_iterations = 0; while (iree_benchmark_keep_running(benchmark_state, /*batch_count=*/FLAG_batch_count)) { for (int i = 0; i < FLAG_batch_count; ++i) { - iree_ukernel_mmt4d_f32f32f32_params_t params; - memset(¶ms, 0, sizeof params); - int ukernel_retcode = iree_ukernel_mmt4d_f32f32f32(¶ms); - if (0 != iree_ukernel_mmt4d_f32f32f32(¶ms)) { - fprintf(stderr, "FATAL: iree_ukernel_mmt4d_f32f32f32 failed: %s\n", - iree_ukernel_mmt4d_error_message(ukernel_retcode)); + iree_ukernel_mmt4d_status_t status = iree_ukernel_mmt4d(¶ms); + if (status != iree_ukernel_mmt4d_status_ok) { + fprintf(stderr, "FATAL: iree_ukernel_mmt4d failed: %s\n", + iree_ukernel_mmt4d_status_message(status)); abort(); } } + total_iterations += FLAG_batch_count; } + iree_benchmark_set_items_processed( + benchmark_state, total_iterations * 2 * params.M * params.N * params.K * + params.M0 * params.N0 * params.K0); + free(lhs_buffer); + free(rhs_buffer); + free(out_buffer); return iree_ok_status(); } +static void iree_mmt4d_benchmark_register( + const iree_mmt4d_benchmark_user_data_t* user_data, const char* name) { + // Does this benchmark require an optional CPU feature? + if (user_data->cpu_data[0]) { + if ((iree_cpu_data_field(0) & user_data->cpu_data[0]) != + user_data->cpu_data[0]) { + // The CPU does not meet this benchmark's requirements. The builtin + // would crash. + return; + } + } + + // benchmark_def does not need to be static, it will be cloned. + const iree_benchmark_def_t benchmark_def = { + .flags = IREE_BENCHMARK_FLAG_USE_REAL_TIME, + .time_unit = IREE_BENCHMARK_UNIT_MICROSECOND, + .minimum_duration_ns = 0, + .iteration_count = 0, + .run = iree_mmt4d_benchmark, + .user_data = user_data, + }; + iree_benchmark_register(IREE_SV(name), &benchmark_def); +} + +#define MMT4D_BENCHMARK_REGISTER(_type, _m0, _n0, _k0, _cpu_data_field_0, \ + _label) \ + do { \ + static const uint64_t local_cpu_data[IREE_CPU_DATA_FIELD_COUNT] = { \ + _cpu_data_field_0}; \ + static const iree_mmt4d_benchmark_user_data_t user_data = { \ + .type = iree_ukernel_mmt4d_type_##_type, \ + .M0 = _m0, \ + .N0 = _n0, \ + .K0 = _k0, \ + .cpu_data = local_cpu_data, \ + }; \ + iree_mmt4d_benchmark_register(&user_data, \ + "iree_ukernel_mmt4d_" #_type "_" #_m0 \ + "x" #_n0 "x" #_k0 "_" #_label); \ + } while (0) + +#define MMT4D_BENCHMARK_REGISTER_GENERIC(_type, _m0, _n0, _k0) \ + MMT4D_BENCHMARK_REGISTER(_type, _m0, _n0, _k0, 0, GENERIC) + +#define MMT4D_BENCHMARK_REGISTER_ARM_64(_type, _m0, _n0, _k0) \ + MMT4D_BENCHMARK_REGISTER(_type, _m0, _n0, _k0, 0, arm_64) + +#define MMT4D_BENCHMARK_REGISTER_ARM_64_WITH_CPU_FEATURE(_type, _m0, _n0, _k0, \ + _cpu_feature) \ + MMT4D_BENCHMARK_REGISTER(_type, _m0, _n0, _k0, \ + IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_##_cpu_feature, \ + arm_64_##_cpu_feature) + int main(int argc, char** argv) { iree_flags_set_usage( "mmt4d_benchmark", @@ -45,22 +170,23 @@ int main(int argc, char** argv) { iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK, &argc, &argv); iree_benchmark_initialize(&argc, argv); + iree_cpu_initialize(iree_allocator_system()); - // TODO: always add _generic variants to have a baseline vs reference? - - { - static const iree_benchmark_def_t benchmark_def = { - .flags = IREE_BENCHMARK_FLAG_MEASURE_PROCESS_CPU_TIME | - IREE_BENCHMARK_FLAG_USE_REAL_TIME, - .time_unit = IREE_BENCHMARK_UNIT_NANOSECOND, - .minimum_duration_ns = 0, - .iteration_count = 0, - .run = iree_mmt4d_example_matmul_f32_benchmark, - .user_data = NULL, - }; - iree_benchmark_register(IREE_SV("iree_mmt4d_example_matmul_f32"), - &benchmark_def); - } + // Generic code paths, not actually used, but interesting to get a sense + // of how slow generic code goes vs decent SIMD kernels. Interesting also to + // compare generic float vs int arithmetic. + MMT4D_BENCHMARK_REGISTER_GENERIC(f32f32f32, 4, 4, 1); + MMT4D_BENCHMARK_REGISTER_GENERIC(i8i8i32, 4, 4, 1); + +// ARM_64 benchmarks. +#if defined(IREE_UKERNEL_ARCH_ARM_64) + + MMT4D_BENCHMARK_REGISTER_ARM_64(f32f32f32, 8, 8, 1); + MMT4D_BENCHMARK_REGISTER_ARM_64(i8i8i32, 8, 8, 1); + MMT4D_BENCHMARK_REGISTER_ARM_64_WITH_CPU_FEATURE(i8i8i32, 8, 8, 4, DOTPROD); + MMT4D_BENCHMARK_REGISTER_ARM_64_WITH_CPU_FEATURE(i8i8i32, 8, 8, 8, I8MM); + +#endif // defined(IREE_UKERNEL_ARCH_ARM_64) iree_benchmark_run_specified(); return 0; diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc index c3094e2ffd5d..a9b42dad09de 100644 --- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc +++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc @@ -4,22 +4,327 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// THIS IS STILL JUST A PLACEHOLDER - NOT AN ACTUAL TEST YET. +// Design rationale and code creep warning! +// +// Summary: +// +// The goal of this test is to provide 100% coverage across all +// internal kernel variants, which is not convenient to do in e2e tests. +// Resist the temptation to reimplement here all the niceties of the e2e test. +// Stick to guaranteeing that if the test succeeds, then the mmt4d builtin, +// with all its asm code path variants, is correct. In case of failure, the +// user is expected to be happy to jump into a debugger. +// +// Longer story: +// +// It is said by an ancient prophecy that all matrix multiplication tests grow +// to be thousands of lines of code. +// +// In fact, we already have one, it's the end-to-end matmul test under +// iree/tests/e2e/matmul. That one is needed anyway, and needs to be large +// anyway, being end-to-end and applying to all target backends, including those +// where device!=host. And so it makes sense for that one to have extra bells +// and whistles such as fuzzy comparisons, pretty-printing of numerical errors +// to aid debugging, and yet more special logic to make numerical errors easier +// to debug. +// +// Let's not duplicate all that here! Note also that, tempting as it would +// be to borrow the matrix-pretty-printing stuff from e2e/matmul, that applies +// to plain row-major 2D matrices, while here we are dealing with 4D arrays / +// tiled-layout matrices. Trying to bridge over that difference would bring yet +// more complexity. +// +// Instead, let us keep a sharp focus on why we need this separate micro test. +// The motivation is not the usual "because micro tests are easier to debug than +// e2e" but rather because it would be difficult to have 100% code coverage in +// e2e. There are many variants of mmt4d builtin ukernels for various CPU +// features and tuned for various CPU models. We have to iterate over all these +// variants. Trying to do so in e2e tests would require exposing knobs for +// things that we would otherwise prefer to keep internal in the mmt4d builtin +// implementation, and would make e2e/matmul tests even more expensive. -#include +// clang-format off +#include // include before ukernel/common.h to keep standard types +// clang-format on + +#include "iree/builtins/ukernel/mmt4d.h" -// Include in expected order with stdint and other system headers first. -// See the note in mmt4d.h about stdint.h. This won't be an issue in most uses -// but clang-format really likes to put the mmt4d.h above the system headers -// due to this _test.cc file naming. +#include #include "iree/base/api.h" -#include "iree/builtins/ukernel/mmt4d.h" +#include "iree/base/internal/cpu.h" +#include "iree/builtins/ukernel/tools/mmt4d_test_utils.h" #include "iree/testing/gtest.h" #include "iree/testing/status_matchers.h" -TEST(MMT4DTest, iree_mmt4d_example_matmul_f32) { - iree_ukernel_mmt4d_f32f32f32_params_t params; +template +static void iree_mmt4d_reference(const iree_ukernel_mmt4d_params_t& params) { + bool accumulate = params.flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE; + iree_ukernel_ssize_t lhs_tile_size = params.M0 * params.K0; + iree_ukernel_ssize_t rhs_tile_size = params.N0 * params.K0; + iree_ukernel_ssize_t out_tile_size = params.M0 * params.N0; + for (iree_ukernel_ssize_t i = 0; i < params.M; ++i) { + for (iree_ukernel_ssize_t j = 0; j < params.N; ++j) { + out_t* out_tile_ptr = ((out_t*)params.out_buffer) + + i * params.out_stride + j * out_tile_size; + const lhs_t* lhs_panel_ptr = + ((const lhs_t*)params.lhs_buffer) + i * params.lhs_stride; + const rhs_t* rhs_panel_ptr = + ((const rhs_t*)params.rhs_buffer) + j * params.rhs_stride; + for (iree_ukernel_ssize_t i0 = 0; i0 < params.M0; ++i0) { + for (iree_ukernel_ssize_t j0 = 0; j0 < params.N0; ++j0) { + const lhs_t* lhs_tile_ptr = lhs_panel_ptr; + const rhs_t* rhs_tile_ptr = rhs_panel_ptr; + out_t* out_ptr = out_tile_ptr + i0 * params.N0 + j0; + out_t acc = accumulate ? *out_ptr : 0.f; + for (iree_ukernel_ssize_t k = 0; k < params.K; ++k) { + for (iree_ukernel_ssize_t k0 = 0; k0 < params.K0; ++k0) { + out_t lhs_val = lhs_tile_ptr[i0 * params.K0 + k0]; + out_t rhs_val = rhs_tile_ptr[j0 * params.K0 + k0]; + acc += lhs_val * rhs_val; + } + lhs_tile_ptr += lhs_tile_size; + rhs_tile_ptr += rhs_tile_size; + } + *out_ptr = acc; + } + } + } + } +} + +static void iree_mmt4d_reference(const iree_ukernel_mmt4d_params_t& params) { + switch (params.type) { + case iree_ukernel_mmt4d_type_f32f32f32: + iree_mmt4d_reference(params); + break; + case iree_ukernel_mmt4d_type_i8i8i32: + iree_mmt4d_reference(params); + break; + default: + assert(false && "unknown type"); + } +} + +static void test_one_matmul_using_given_lhs_rhs( + const iree_ukernel_mmt4d_params_t& shared_params, + iree_mmt4d_test_random_engine_t* engine) { + assert(!shared_params.out_buffer); + + iree_ukernel_mmt4d_params_t reference_params; + memcpy(&reference_params, &shared_params, sizeof shared_params); + iree_ukernel_ssize_t out_buffer_size = + iree_ukernel_mmt4d_out_buffer_size(&shared_params); + reference_params.out_buffer = malloc(out_buffer_size); + iree_mmt4d_scalar_type_t out_type = + iree_ukernel_mmt4d_out_type(&shared_params); + write_random_buffer(reference_params.out_buffer, out_buffer_size, out_type, + engine); + + iree_ukernel_mmt4d_params_t actual_params; + memcpy(&actual_params, &shared_params, sizeof shared_params); + actual_params.out_buffer = malloc(out_buffer_size); + memcpy(actual_params.out_buffer, reference_params.out_buffer, + out_buffer_size); + + iree_mmt4d_reference(reference_params); + iree_ukernel_mmt4d_status_t status = iree_ukernel_mmt4d(&actual_params); + if (status != iree_ukernel_mmt4d_status_ok) { + fprintf(stderr, "FATAL: iree_ukernel_mmt4d failed: %s\n", + iree_ukernel_mmt4d_status_message(status)); + abort(); + } + + // For now we use exact comparisons, even for float, even though the reference + // code accumulates in a different order compared to the actual code. This + // relies on picking input test matrix elements so that all intermediate + // values are exactly representable - i.e. small integer numerators. This + // become problematic when we do float16. See the comment at the top of this + // file explaining how we refrain from letting this grow into a 1000-line-long + // fully-featured test. + if (memcmp(actual_params.out_buffer, reference_params.out_buffer, + out_buffer_size)) { + const auto& p = actual_params; + fprintf(stderr, "mmt4d test failure with the following params:\n"); + fprintf(stderr, " type=%s\n", get_mmt4d_type_str(&p)); + fprintf(stderr, " flags: accumulate=%d\n", + (int)(p.flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE)); + fprintf(stderr, " M=%d, N=%d, K=%d\n", (int)p.M, (int)p.N, (int)p.K); + fprintf(stderr, " M0=%d, N0=%d, K0=%d\n", (int)p.M0, (int)p.N0, (int)p.K0); + fprintf(stderr, " lhs_stride=%zu, rhs_stride=%zu, out_stride=%zu\n", + (size_t)p.lhs_stride, (size_t)p.rhs_stride, (size_t)p.out_stride); + fprintf(stderr, " cpu features: %s\n", get_cpu_features_str(&p)); + // Don't even try to pretty-print matrices. See the comment at the top of + // this file. Don't try to use GTest primitives to show expected vs actual + // since that would require dispatching to type-specific code paths. + // Also, at this point it's easy for the user to rerun this test + // in a debugger and manually inspect values. + // + // We want fatal here - that is what the user running this in a debugger + // wants us to do, so they can inspect values while they exist in memory. + // What's the GTest-sanctioned fatal error? GTEST_FAIL() has a comment that + // says that it's fatal, but that's a lie at least here on Android. + abort(); + } + + free(reference_params.out_buffer); + free(actual_params.out_buffer); +} + +static void test_one_matmul_creating_lhs_rhs_for_given_shape( + const iree_ukernel_mmt4d_params_t& shared_params, + iree_mmt4d_test_random_engine_t* engine) { + iree_ukernel_mmt4d_params_t params; + memcpy(¶ms, &shared_params, sizeof params); + assert(!params.lhs_buffer); + assert(!params.rhs_buffer); + assert(!params.out_buffer); + assert(!params.lhs_stride); + assert(!params.rhs_stride); + assert(!params.out_stride); + // Populate strides first - they are read by the get_*_buffer_size helper. + // Randomly make strides either tight or not to exercise all cases. + params.lhs_stride = params.K * params.M0 * params.K0 + + iree_mmt4d_test_random_engine_get_0_or_1(engine); + params.rhs_stride = params.K * params.N0 * params.K0 + + iree_mmt4d_test_random_engine_get_0_or_1(engine); + params.out_stride = params.N * params.M0 * params.N0 + + iree_mmt4d_test_random_engine_get_0_or_1(engine); + iree_ukernel_ssize_t lhs_buffer_size = + iree_ukernel_mmt4d_lhs_buffer_size(¶ms); + iree_ukernel_ssize_t rhs_buffer_size = + iree_ukernel_mmt4d_rhs_buffer_size(¶ms); + iree_mmt4d_scalar_type_t lhs_type = iree_ukernel_mmt4d_lhs_type(¶ms); + iree_mmt4d_scalar_type_t rhs_type = iree_ukernel_mmt4d_rhs_type(¶ms); + void* lhs_buffer = malloc(lhs_buffer_size); + void* rhs_buffer = malloc(rhs_buffer_size); + write_random_buffer(lhs_buffer, lhs_buffer_size, lhs_type, engine); + write_random_buffer(rhs_buffer, rhs_buffer_size, rhs_type, engine); + params.lhs_buffer = lhs_buffer; + params.rhs_buffer = rhs_buffer; + test_one_matmul_using_given_lhs_rhs(params, engine); + free(lhs_buffer); + free(rhs_buffer); +} + +static void test_matmuls_for_various_MNK_shapes_and_flags( + const iree_ukernel_mmt4d_params_t& shared_params, + iree_mmt4d_test_random_engine_t* engine) { + iree_ukernel_mmt4d_params_t params; + memcpy(¶ms, &shared_params, sizeof params); + assert(params.M == 0); + assert(params.N == 0); + assert(params.K == 0); + assert(params.flags == 0); + struct shape_mnk_t { + int m, n, k; + }; + std::vector shapes{ + // Degenerate case M==0. Vacuous. + {0, 1, 1}, + {0, 5, 7}, + // Degenerate case N==0. Vacuous. + {1, 0, 1}, + {5, 0, 7}, + // Degenerate case K==0. Vacuous if flags have ACCUMULATE. Zeroing the + // output buffer otherwise. + {1, 1, 0}, + {5, 7, 0}, + // Non-degenerate cases. + {1, 1, 1}, + {1, 1, 2}, + {1, 1, 10}, + {1, 1, 1000}, + {2, 1, 1}, + {1, 2, 1}, + {2, 2, 2}, + {5, 7, 13}, + }; + for (shape_mnk_t shape : shapes) { + params.M = shape.m; + params.N = shape.n; + params.K = shape.k; + for (bool accumulate : {false, true}) { + params.flags = accumulate ? IREE_VMVX_MATMUL_FLAG_ACCUMULATE : 0; + test_one_matmul_creating_lhs_rhs_for_given_shape(params, engine); + } + } +} + +// Tests mmt4d with the specific data type and specific M0xN0xK0 tile format. +// If cpu_data_field_0_bit is nonzero, it must then be a single bit (power of 2) +// and if the CPU supports the corresponding feature, the mmt4d tests are run a +// second time with that CPU feature enabled. +static void mmt4d_test(iree_ukernel_mmt4d_type_t type, int M0, int N0, int K0, + uint64_t cpu_data_field_0_bit) { + // Letting each test create its own engine makes them independent: a testcase + // succeeds or fails the same way if we isolate it or reorder it. The + // potential downside of repeating the same pseudorandom sequence is OK + // because any pseudorandom sequence should be equally good at coverage, and + // different testcases tend to use different tile shapes anyway. + iree_mmt4d_test_random_engine_t* engine = + iree_mmt4d_test_random_engine_create(); + iree_ukernel_mmt4d_params_t params; memset(¶ms, 0, sizeof params); - EXPECT_EQ(0, iree_ukernel_mmt4d_f32f32f32(¶ms)); + params.type = type; + params.M0 = M0; + params.N0 = N0; + params.K0 = K0; + const uint64_t local_cpu_data_default[IREE_CPU_DATA_FIELD_COUNT] = {0}; + params.cpu_data = local_cpu_data_default; + // First try without any optional CPU feature. This matters even when the + // feature is supported by the CPU because we want to test the fallback to + // architecture-default or generic code. + test_matmuls_for_various_MNK_shapes_and_flags(params, engine); + // If this is nonzero, we are asked to test again with this CPU feature. + if (cpu_data_field_0_bit) { + const uint64_t local_cpu_data_with_bit[IREE_CPU_DATA_FIELD_COUNT] = { + cpu_data_field_0_bit}; + params.cpu_data = local_cpu_data_with_bit; + // Check if the CPU supports the feature (otherwise, we crash). + bool supported = iree_cpu_data_field(0) & params.cpu_data[0]; + if (supported) { + // Run with the optional CPU feature. + printf("Device supports CPU feature: %s\n", + get_cpu_features_str(¶ms)); + test_matmuls_for_various_MNK_shapes_and_flags(params, engine); + } else { + printf("Skipped: device does not support CPU feature: %s\n", + get_cpu_features_str(¶ms)); + } + } + iree_mmt4d_test_random_engine_destroy(engine); +} + +#define MMT4D_TEST(type, M0, N0, K0, test_suffix, feature_bit) \ + TEST(Mmt4dTest, type##_tile_##M0##x##N0##x##K0##_##test_suffix) { \ + mmt4d_test(iree_ukernel_mmt4d_type_##type, M0, N0, K0, feature_bit); \ + } + +// Generic tests, not matching any particular CPU feature. This is the place to +// test weird M0, N0, K0 to ensure e.g. that we haven't unwittingly baked in a +// power-of-two assumption +MMT4D_TEST(f32f32f32, 3, 5, 7, generic, 0) +MMT4D_TEST(i8i8i32, 9, 6, 3, generic, 0) + +// ARM_64 tests. +#if defined(IREE_UKERNEL_ARCH_ARM_64) + +#define MMT4D_ARM_64_TEST(type, M0, N0, K0) \ + MMT4D_TEST(type, M0, N0, K0, arm_64, 0) + +#define MMT4D_ARM_64_TEST_WITH_CPU_FEATURE(type, M0, N0, K0, FEATURE) \ + MMT4D_TEST(type, M0, N0, K0, arm_64_##FEATURE, \ + IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_##FEATURE) + +MMT4D_ARM_64_TEST(f32f32f32, 8, 8, 1) +MMT4D_ARM_64_TEST(i8i8i32, 8, 8, 1) +MMT4D_ARM_64_TEST_WITH_CPU_FEATURE(i8i8i32, 8, 8, 4, DOTPROD) +MMT4D_ARM_64_TEST_WITH_CPU_FEATURE(i8i8i32, 8, 8, 8, I8MM) +#endif // defined(IREE_UKERNEL_ARCH_ARM_64) + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + iree_cpu_initialize(iree_allocator_system()); + return RUN_ALL_TESTS(); } diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc new file mode 100644 index 000000000000..878b9969d2fb --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc @@ -0,0 +1,162 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/builtins/ukernel/tools/mmt4d_test_utils.h" + +#include +#include + +#include "iree/schemas/cpu_data.h" + +iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_lhs_type( + const iree_ukernel_mmt4d_params_t* params) { + switch (params->type) { + case iree_ukernel_mmt4d_type_f32f32f32: + return iree_mmt4d_scalar_type_f32; + case iree_ukernel_mmt4d_type_i8i8i32: + return iree_mmt4d_scalar_type_i8; + default: + assert(false && "unknown type"); + return iree_mmt4d_scalar_type_unknown; + } +} + +iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_rhs_type( + const iree_ukernel_mmt4d_params_t* params) { + // same for now + return iree_ukernel_mmt4d_lhs_type(params); +} + +iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_out_type( + const iree_ukernel_mmt4d_params_t* params) { + switch (params->type) { + case iree_ukernel_mmt4d_type_f32f32f32: + return iree_mmt4d_scalar_type_f32; + case iree_ukernel_mmt4d_type_i8i8i32: + return iree_mmt4d_scalar_type_i32; + default: + assert(false && "unknown type"); + return iree_mmt4d_scalar_type_unknown; + } +} + +iree_ukernel_ssize_t iree_ukernel_mmt4d_lhs_buffer_size( + const iree_ukernel_mmt4d_params_t* params) { + return params->M * params->lhs_stride * + iree_ukernel_mmt4d_lhs_elem_size(params->type); +} + +iree_ukernel_ssize_t iree_ukernel_mmt4d_rhs_buffer_size( + const iree_ukernel_mmt4d_params_t* params) { + return params->N * params->rhs_stride * + iree_ukernel_mmt4d_rhs_elem_size(params->type); +} + +iree_ukernel_ssize_t iree_ukernel_mmt4d_out_buffer_size( + const iree_ukernel_mmt4d_params_t* params) { + return params->M * params->out_stride * + iree_ukernel_mmt4d_out_elem_size(params->type); +} + +struct iree_mmt4d_test_random_engine_t { + std::minstd_rand cpp_random_engine; +}; + +iree_mmt4d_test_random_engine_t* iree_mmt4d_test_random_engine_create() { + return new iree_mmt4d_test_random_engine_t; +} + +void iree_mmt4d_test_random_engine_destroy(iree_mmt4d_test_random_engine_t* e) { + delete e; +} + +static int iree_mmt4d_test_random_engine_get_in_uint16_range( + iree_mmt4d_test_random_engine_t* e) { + uint32_t v = e->cpp_random_engine(); + // return the second-least-signicant out of the 4 bytes of state. It avoids + // some mild issues with the least-significant and most-significant bytes. + return (v >> 8) & 0xffff; +} + +int iree_mmt4d_test_random_engine_get_0_or_1( + iree_mmt4d_test_random_engine_t* e) { + int v = iree_mmt4d_test_random_engine_get_in_uint16_range(e); + return v & 1; +} + +int iree_mmt4d_test_random_engine_get_between_minus16_and_plus15( + iree_mmt4d_test_random_engine_t* e) { + int v = iree_mmt4d_test_random_engine_get_in_uint16_range(e); + return (v % 32) - 16; +} + +template +static void write_random_buffer(T* buffer, iree_ukernel_ssize_t size_in_bytes, + iree_mmt4d_test_random_engine_t* engine) { + iree_ukernel_ssize_t size_in_elems = size_in_bytes / sizeof(T); + assert(size_in_elems * sizeof(T) == size_in_bytes && "bad size"); + for (iree_ukernel_ssize_t i = 0; i < size_in_elems; ++i) { + // Small integers, should work for now for all the types we currently have + // and enable exact float arithmetic, allowing to keep tests simpler for + // now. Watch out for when we'll do float16! + T random_val = + iree_mmt4d_test_random_engine_get_between_minus16_and_plus15(engine); + buffer[i] = random_val; + } +} + +void write_random_buffer(void* buffer, iree_ukernel_ssize_t size_in_bytes, + iree_mmt4d_scalar_type_t type, + iree_mmt4d_test_random_engine_t* engine) { + switch (type) { + case iree_mmt4d_scalar_type_f32: + write_random_buffer(static_cast(buffer), size_in_bytes, engine); + return; + case iree_mmt4d_scalar_type_i32: + write_random_buffer(static_cast(buffer), size_in_bytes, engine); + return; + case iree_mmt4d_scalar_type_i8: + write_random_buffer(static_cast(buffer), size_in_bytes, engine); + return; + default: + assert(false && "unknown type"); + } +} + +const char* get_mmt4d_type_str(const iree_ukernel_mmt4d_params_t* params) { + switch (params->type) { +#define GET_MMT4D_TYPE_STR_CASE(x) \ + case x: \ + return #x; + GET_MMT4D_TYPE_STR_CASE(iree_ukernel_mmt4d_type_f32f32f32); + GET_MMT4D_TYPE_STR_CASE(iree_ukernel_mmt4d_type_i8i8i32); + default: + assert(false && "unknown type"); + return "unknown type"; + } +} + +const char* get_cpu_features_str(const iree_ukernel_mmt4d_params_t* params) { + // We set only one feature bit at a time in this test --- not an actual + // detected cpu data field. This might have to change in the future if some + // code path relies on the combination of two features. + // For now, asserting only one bit set, and taking advantage of that to work + // with plain string literals. + assert(0 == (params->cpu_data[0] & (params->cpu_data[0] - 1))); + if (params->cpu_data[0] == 0) { + return "(none)"; + } +#if defined(IREE_UKERNEL_ARCH_ARM_64) + if (params->cpu_data[0] & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_I8MM) { + return "i8mm"; + } + if (params->cpu_data[0] & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_DOTPROD) { + return "dotprod"; + } +#endif // defined(IREE_UKERNEL_ARCH_ARM_64) + assert(false && "unknown CPU feature"); + return "unknown CPU feature"; +} diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.h b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.h new file mode 100644 index 000000000000..c415567df1f4 --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.h @@ -0,0 +1,63 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_BUILTINS_UKERNEL_TOOLS_MMT4D_TEST_UTILS_H_ +#define IREE_BUILTINS_UKERNEL_TOOLS_MMT4D_TEST_UTILS_H_ + +// clang-format off +#include // include before ukernel/common.h to keep standard types +// clang-format on + +#include "iree/builtins/ukernel/mmt4d_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +enum iree_mmt4d_scalar_type_t { + iree_mmt4d_scalar_type_unknown, + iree_mmt4d_scalar_type_i8, + iree_mmt4d_scalar_type_i32, + iree_mmt4d_scalar_type_f32, +}; + +typedef enum iree_mmt4d_scalar_type_t iree_mmt4d_scalar_type_t; + +iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_lhs_type( + const iree_ukernel_mmt4d_params_t* params); +iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_rhs_type( + const iree_ukernel_mmt4d_params_t* params); +iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_out_type( + const iree_ukernel_mmt4d_params_t* params); + +iree_ukernel_ssize_t iree_ukernel_mmt4d_lhs_buffer_size( + const iree_ukernel_mmt4d_params_t* params); +iree_ukernel_ssize_t iree_ukernel_mmt4d_rhs_buffer_size( + const iree_ukernel_mmt4d_params_t* params); +iree_ukernel_ssize_t iree_ukernel_mmt4d_out_buffer_size( + const iree_ukernel_mmt4d_params_t* params); + +struct iree_mmt4d_test_random_engine_t; +typedef struct iree_mmt4d_test_random_engine_t iree_mmt4d_test_random_engine_t; +iree_mmt4d_test_random_engine_t* iree_mmt4d_test_random_engine_create(); +void iree_mmt4d_test_random_engine_destroy(iree_mmt4d_test_random_engine_t* e); +int iree_mmt4d_test_random_engine_get_0_or_1( + iree_mmt4d_test_random_engine_t* e); +int iree_mmt4d_test_random_engine_get_between_minus16_and_plus15( + iree_mmt4d_test_random_engine_t* e); + +void write_random_buffer(void* buffer, iree_ukernel_ssize_t size_in_bytes, + iree_mmt4d_scalar_type_t type, + iree_mmt4d_test_random_engine_t* engine); + +const char* get_mmt4d_type_str(const iree_ukernel_mmt4d_params_t* params); +const char* get_cpu_features_str(const iree_ukernel_mmt4d_params_t* params); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // IREE_BUILTINS_UKERNEL_TOOLS_MMT4D_TEST_UTILS_H_ diff --git a/runtime/src/iree/hal/drivers/local_sync/registration/driver_module.c b/runtime/src/iree/hal/drivers/local_sync/registration/driver_module.c index 19058db8d2ac..22421d52f20d 100644 --- a/runtime/src/iree/hal/drivers/local_sync/registration/driver_module.c +++ b/runtime/src/iree/hal/drivers/local_sync/registration/driver_module.c @@ -41,7 +41,8 @@ static iree_status_t iree_hal_local_sync_driver_factory_try_create( iree_hal_executable_loader_t* loaders[8] = {NULL}; iree_host_size_t loader_count = 0; iree_status_t status = iree_hal_create_all_available_executable_loaders( - IREE_ARRAYSIZE(loaders), &loader_count, loaders, host_allocator); + iree_hal_executable_import_provider_default(), IREE_ARRAYSIZE(loaders), + &loader_count, loaders, host_allocator); iree_hal_allocator_t* device_allocator = NULL; if (iree_status_is_ok(status)) { diff --git a/runtime/src/iree/hal/drivers/local_task/registration/driver_module.c b/runtime/src/iree/hal/drivers/local_task/registration/driver_module.c index 3e28d1034c29..dd0eab8331aa 100644 --- a/runtime/src/iree/hal/drivers/local_task/registration/driver_module.c +++ b/runtime/src/iree/hal/drivers/local_task/registration/driver_module.c @@ -44,7 +44,8 @@ static iree_status_t iree_hal_local_task_driver_factory_try_create( iree_hal_executable_loader_t* loaders[8] = {NULL}; iree_host_size_t loader_count = 0; iree_status_t status = iree_hal_create_all_available_executable_loaders( - IREE_ARRAYSIZE(loaders), &loader_count, loaders, host_allocator); + iree_hal_executable_import_provider_default(), IREE_ARRAYSIZE(loaders), + &loader_count, loaders, host_allocator); iree_task_executor_t* executor = NULL; if (iree_status_is_ok(status)) { diff --git a/runtime/src/iree/hal/drivers/vulkan/builtin_executables.cc b/runtime/src/iree/hal/drivers/vulkan/builtin_executables.cc index dfbed2414703..7c434d758b3b 100644 --- a/runtime/src/iree/hal/drivers/vulkan/builtin_executables.cc +++ b/runtime/src/iree/hal/drivers/vulkan/builtin_executables.cc @@ -54,17 +54,23 @@ iree_status_t BuiltinExecutables::InitializeExecutables() { IREE_TRACE_SCOPE(); // Create descriptor set layouts for our compute pipeline. - // Even though we're just using one set, we still need to create layout - // bindings for those preceding it. + // Even though we're just using one set, we still need to create dummy set + // layout (without any bindings) for those preceding this set. for (size_t i = 0; i < IREE_HAL_VULKAN_BUILTIN_DESCRIPTOR_SET_COUNT; ++i) { iree_hal_descriptor_set_layout_t* layout = NULL; - iree_hal_descriptor_set_layout_binding_t layout_binding; - layout_binding.binding = 0; - layout_binding.type = IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER; - layout_binding.flags = IREE_HAL_DESCRIPTOR_FLAG_NONE; - IREE_RETURN_IF_ERROR(iree_hal_vulkan_native_descriptor_set_layout_create( - logical_device_, IREE_HAL_DESCRIPTOR_SET_LAYOUT_FLAG_NONE, - /*binding_count=*/1, &layout_binding, &layout)); + if (i == IREE_HAL_VULKAN_BUILTIN_DESCRIPTOR_SET) { + iree_hal_descriptor_set_layout_binding_t layout_binding; + layout_binding.binding = 0; + layout_binding.type = IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER; + layout_binding.flags = IREE_HAL_DESCRIPTOR_FLAG_NONE; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_native_descriptor_set_layout_create( + logical_device_, IREE_HAL_DESCRIPTOR_SET_LAYOUT_FLAG_NONE, + /*binding_count=*/1, &layout_binding, &layout)); + } else { + IREE_RETURN_IF_ERROR(iree_hal_vulkan_native_descriptor_set_layout_create( + logical_device_, IREE_HAL_DESCRIPTOR_SET_LAYOUT_FLAG_NONE, + /*binding_count=*/0, /*bindings=*/nullptr, &layout)); + } descriptor_set_layouts_[i] = layout; } diff --git a/runtime/src/iree/hal/drivers/vulkan/native_pipeline_layout.cc b/runtime/src/iree/hal/drivers/vulkan/native_pipeline_layout.cc index d8aeabb595f4..13ca650e1094 100644 --- a/runtime/src/iree/hal/drivers/vulkan/native_pipeline_layout.cc +++ b/runtime/src/iree/hal/drivers/vulkan/native_pipeline_layout.cc @@ -52,16 +52,20 @@ static iree_status_t iree_hal_vulkan_create_descriptor_set_layout( create_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; create_info.pNext = NULL; create_info.flags = 0; - if (logical_device->enabled_extensions().push_descriptors) { - // Note that we can *only* use push descriptor sets if we set this create - // flag. If push descriptors aren't supported we emulate them with normal - // descriptors so it's fine to have kPushOnly without support. - create_info.flags |= - VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; - } VkDescriptorSetLayoutBinding* native_bindings = NULL; if (binding_count > 0) { + if (logical_device->enabled_extensions().push_descriptors) { + // Note that we can *only* use push descriptor sets if we set this create + // flag. If push descriptors aren't supported we emulate them with normal + // descriptors so it's fine to have kPushOnly without support. + // Also we only enable this when there are at least one binding in it. + // (We can have dummy descriptor sets without any bindings for builtin + // executables.) + create_info.flags |= + VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; + } + // TODO(benvanik): avoid this allocation if possible (inline_array). IREE_RETURN_IF_ERROR(iree_allocator_malloc( logical_device->host_allocator(), diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_driver.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_driver.cc index 9b03af2aff25..4e1433bb31b0 100644 --- a/runtime/src/iree/hal/drivers/vulkan/vulkan_driver.cc +++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_driver.cc @@ -549,6 +549,7 @@ static iree_status_t iree_hal_vulkan_driver_find_device_by_index( *found_physical_device = physical_device; + (void)visible_physical_devices; // unused var if IREE_STATUS_MODE=0 IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } diff --git a/runtime/src/iree/hal/local/executable_library_benchmark.c b/runtime/src/iree/hal/local/executable_library_benchmark.c index f08dfa204669..5ec7080015b0 100644 --- a/runtime/src/iree/hal/local/executable_library_benchmark.c +++ b/runtime/src/iree/hal/local/executable_library_benchmark.c @@ -141,7 +141,8 @@ static iree_status_t iree_hal_executable_library_run( // Register the loader used to load (or find) the executable. iree_hal_executable_loader_t* executable_loader = NULL; IREE_RETURN_IF_ERROR(iree_hal_create_executable_loader_by_name( - iree_make_cstring_view(FLAG_executable_format), host_allocator, + iree_make_cstring_view(FLAG_executable_format), + iree_hal_executable_import_provider_default(), host_allocator, &executable_loader)); // Setup the specification used to perform the executable load. diff --git a/runtime/src/iree/hal/local/executable_loader.c b/runtime/src/iree/hal/local/executable_loader.c index 955572b3370b..f0b8c78dbf15 100644 --- a/runtime/src/iree/hal/local/executable_loader.c +++ b/runtime/src/iree/hal/local/executable_loader.c @@ -6,6 +6,24 @@ #include "iree/hal/local/executable_loader.h" +#if defined(IREE_HAL_EXECUTABLE_IMPORT_PROVIDER_DEFAULT_FN) + +// Defined by the user and linked in to the binary: +extern iree_hal_executable_import_provider_t +IREE_HAL_EXECUTABLE_IMPORT_PROVIDER_DEFAULT_FN(void); + +iree_hal_executable_import_provider_t +iree_hal_executable_import_provider_default(void) { + return IREE_HAL_EXECUTABLE_IMPORT_PROVIDER_DEFAULT_FN(); +} + +#else +iree_hal_executable_import_provider_t +iree_hal_executable_import_provider_default(void) { + return iree_hal_executable_import_provider_null(); +} +#endif // IREE_HAL_EXECUTABLE_IMPORT_PROVIDER_DEFAULT_FN + iree_status_t iree_hal_executable_import_provider_resolve( const iree_hal_executable_import_provider_t import_provider, iree_string_view_t symbol_name, void** out_fn_ptr) { diff --git a/runtime/src/iree/hal/local/executable_loader.h b/runtime/src/iree/hal/local/executable_loader.h index 426d6a4c9178..c7a5967b96c7 100644 --- a/runtime/src/iree/hal/local/executable_loader.h +++ b/runtime/src/iree/hal/local/executable_loader.h @@ -41,11 +41,21 @@ typedef struct iree_hal_executable_import_provider_t { } iree_hal_executable_import_provider_t; static inline iree_hal_executable_import_provider_t -iree_hal_executable_import_provider_null() { +iree_hal_executable_import_provider_null(void) { iree_hal_executable_import_provider_t provider = {NULL, NULL}; return provider; } +// Returns the import provider specified by +// IREE_HAL_EXECUTABLE_IMPORT_PROVIDER_DEFAULT_FN or null. +// +// To use define a function like: +// iree_hal_executable_import_provider_t my_provider(void) { ... } +// And define it: +// -DIREE_HAL_EXECUTABLE_IMPORT_PROVIDER_DEFAULT_FN=my_provider +iree_hal_executable_import_provider_t +iree_hal_executable_import_provider_default(void); + // Resolves an import symbol with the given |symbol_name| and stores a pointer // to the function (or its context) in |out_fn_ptr|. // diff --git a/runtime/src/iree/hal/local/loaders/registration/init.c b/runtime/src/iree/hal/local/loaders/registration/init.c index faff2ea735ed..f0e67f521b87 100644 --- a/runtime/src/iree/hal/local/loaders/registration/init.c +++ b/runtime/src/iree/hal/local/loaders/registration/init.c @@ -24,6 +24,7 @@ #endif // IREE_HAVE_HAL_EXECUTABLE_LOADER_VMVX_MODULE IREE_API_EXPORT iree_status_t iree_hal_create_all_available_executable_loaders( + iree_hal_executable_import_provider_t import_provider, iree_host_size_t capacity, iree_host_size_t* out_count, iree_hal_executable_loader_t** loaders, iree_allocator_t host_allocator) { IREE_ASSERT_ARGUMENT(out_count); @@ -52,16 +53,14 @@ IREE_API_EXPORT iree_status_t iree_hal_create_all_available_executable_loaders( #if defined(IREE_HAVE_HAL_EXECUTABLE_LOADER_SYSTEM_LIBRARY) if (iree_status_is_ok(status)) { status = iree_hal_system_library_loader_create( - iree_hal_executable_import_provider_null(), host_allocator, - &loaders[count++]); + import_provider, host_allocator, &loaders[count++]); } #endif // IREE_HAVE_HAL_EXECUTABLE_LOADER_SYSTEM_LIBRARY #if defined(IREE_HAVE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF) if (iree_status_is_ok(status)) { status = iree_hal_embedded_elf_loader_create( - iree_hal_executable_import_provider_null(), host_allocator, - &loaders[count++]); + import_provider, host_allocator, &loaders[count++]); } #endif // IREE_HAVE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF @@ -84,21 +83,21 @@ IREE_API_EXPORT iree_status_t iree_hal_create_all_available_executable_loaders( } IREE_API_EXPORT iree_status_t iree_hal_create_executable_loader_by_name( - iree_string_view_t name, iree_allocator_t host_allocator, + iree_string_view_t name, + iree_hal_executable_import_provider_t import_provider, + iree_allocator_t host_allocator, iree_hal_executable_loader_t** out_executable_loader) { #if defined(IREE_HAVE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF) if (iree_string_view_starts_with(name, IREE_SV("embedded-elf"))) { - return iree_hal_embedded_elf_loader_create( - iree_hal_executable_import_provider_null(), host_allocator, - out_executable_loader); + return iree_hal_embedded_elf_loader_create(import_provider, host_allocator, + out_executable_loader); } #endif // IREE_HAVE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF #if defined(IREE_HAVE_HAL_EXECUTABLE_LOADER_SYSTEM_LIBRARY) if (iree_string_view_starts_with(name, IREE_SV("system-library"))) { return iree_hal_system_library_loader_create( - iree_hal_executable_import_provider_null(), host_allocator, - out_executable_loader); + import_provider, host_allocator, out_executable_loader); } #endif // IREE_HAVE_HAL_EXECUTABLE_LOADER_SYSTEM_LIBRARY diff --git a/runtime/src/iree/hal/local/loaders/registration/init.h b/runtime/src/iree/hal/local/loaders/registration/init.h index 8518ffb39183..bce4037c0cf0 100644 --- a/runtime/src/iree/hal/local/loaders/registration/init.h +++ b/runtime/src/iree/hal/local/loaders/registration/init.h @@ -29,7 +29,9 @@ extern "C" { // iree_host_size_t count = 0; // iree_hal_executable_loader_t* loaders[8] = {NULL}; // IREE_RETURN_IF_ERROR(iree_hal_create_all_available_executable_loaders( -// IREE_ARRAYSIZE(loaders), &count, loaders, host_allocator)); +// import_provider, +// IREE_ARRAYSIZE(loaders), &count, loaders, +// host_allocator)); // ... // // use up to count loaders // ... @@ -37,13 +39,16 @@ extern "C" { // iree_hal_executable_loader_release(loaders[i]); // } IREE_API_EXPORT iree_status_t iree_hal_create_all_available_executable_loaders( + iree_hal_executable_import_provider_t import_provider, iree_host_size_t capacity, iree_host_size_t* out_count, iree_hal_executable_loader_t** loaders, iree_allocator_t host_allocator); // Creates an executable loader with the given |name|. // |out_executable_loader| must be released by the caller. IREE_API_EXPORT iree_status_t iree_hal_create_executable_loader_by_name( - iree_string_view_t name, iree_allocator_t host_allocator, + iree_string_view_t name, + iree_hal_executable_import_provider_t import_provider, + iree_allocator_t host_allocator, iree_hal_executable_loader_t** out_executable_loader); #ifdef __cplusplus diff --git a/runtime/src/iree/hal/string_util.c b/runtime/src/iree/hal/string_util.c index b6e046dd1239..160ed73d7e13 100644 --- a/runtime/src/iree/hal/string_util.c +++ b/runtime/src/iree/hal/string_util.c @@ -106,6 +106,17 @@ iree_hal_format_shape(iree_host_size_t shape_rank, const iree_hal_dim_t* shape, : iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); } +IREE_API_EXPORT iree_status_t iree_hal_append_shape_string( + iree_host_size_t shape_rank, const iree_hal_dim_t* shape, + iree_string_builder_t* string_builder) { + for (iree_host_size_t i = 0; i < shape_rank; ++i) { + IREE_RETURN_IF_ERROR(iree_string_builder_append_format( + string_builder, (i < shape_rank - 1) ? "%" PRIdim "x" : "%" PRIdim, + shape[i])); + } + return iree_ok_status(); +} + IREE_API_EXPORT iree_status_t iree_hal_parse_element_type( iree_string_view_t value, iree_hal_element_type_t* out_element_type) { IREE_ASSERT_ARGUMENT(out_element_type); @@ -196,6 +207,17 @@ IREE_API_EXPORT iree_status_t iree_hal_format_element_type( : iree_ok_status(); } +IREE_API_EXPORT iree_status_t +iree_hal_append_element_type_string(iree_hal_element_type_t element_type, + iree_string_builder_t* string_builder) { + char temp[8]; + iree_host_size_t length = 0; + IREE_RETURN_IF_ERROR( + iree_hal_format_element_type(element_type, sizeof(temp), temp, &length)); + return iree_string_builder_append_string(string_builder, + iree_make_string_view(temp, length)); +} + IREE_API_EXPORT iree_status_t iree_hal_parse_shape_and_element_type( iree_string_view_t value, iree_host_size_t shape_capacity, iree_host_size_t* out_shape_rank, iree_hal_dim_t* out_shape, @@ -244,6 +266,19 @@ IREE_API_EXPORT iree_status_t iree_hal_parse_shape_and_element_type( return iree_ok_status(); } +IREE_API_EXPORT iree_status_t iree_hal_append_shape_and_element_type_string( + iree_host_size_t shape_rank, const iree_hal_dim_t* shape, + iree_hal_element_type_t element_type, + iree_string_builder_t* string_builder) { + if (shape_rank > 0) { + IREE_RETURN_IF_ERROR( + iree_hal_append_shape_string(shape_rank, shape, string_builder)); + IREE_RETURN_IF_ERROR( + iree_string_builder_append_string(string_builder, IREE_SV("x"))); + } + return iree_hal_append_element_type_string(element_type, string_builder); +} + // Parses a string of two character pairs representing hex numbers into bytes. static void iree_hal_hex_string_to_bytes(const char* from, uint8_t* to, ptrdiff_t num) { diff --git a/runtime/src/iree/hal/string_util.h b/runtime/src/iree/hal/string_util.h index 3a531165218d..2c14e2f2d9fb 100644 --- a/runtime/src/iree/hal/string_util.h +++ b/runtime/src/iree/hal/string_util.h @@ -32,6 +32,11 @@ iree_hal_format_shape(iree_host_size_t shape_rank, const iree_hal_dim_t* shape, iree_host_size_t buffer_capacity, char* buffer, iree_host_size_t* out_buffer_length); +// Appends shape dimensions to |string_builder| in a `4x5x6` format. +IREE_API_EXPORT iree_status_t iree_hal_append_shape_string( + iree_host_size_t shape_rank, const iree_hal_dim_t* shape, + iree_string_builder_t* string_builder); + // Parses a serialized iree_hal_element_type_t and sets |out_element_type| if // it is valid. The format is the same as produced by // iree_hal_format_element_type. @@ -48,14 +53,25 @@ IREE_API_EXPORT iree_status_t iree_hal_format_element_type( iree_hal_element_type_t element_type, iree_host_size_t buffer_capacity, char* buffer, iree_host_size_t* out_buffer_length); +// Appends an element type to |string_builder| such as `f16`. +IREE_API_EXPORT iree_status_t +iree_hal_append_element_type_string(iree_hal_element_type_t element_type, + iree_string_builder_t* string_builder); + // Parses a shape and type from a `[shape]x[type]` string |value|. // Behaves the same as calling iree_hal_parse_shape and -// iree_hal_parse_element_type. Ignores any training `=`. +// iree_hal_parse_element_type. Ignores any trailing `=`. IREE_API_EXPORT iree_status_t iree_hal_parse_shape_and_element_type( iree_string_view_t value, iree_host_size_t shape_capacity, iree_host_size_t* out_shape_rank, iree_hal_dim_t* out_shape, iree_hal_element_type_t* out_element_type); +// Appends shape dimensions and element type to |string_builder| as `4x5xf32`. +IREE_API_EXPORT iree_status_t iree_hal_append_shape_and_element_type_string( + iree_host_size_t shape_rank, const iree_hal_dim_t* shape, + iree_hal_element_type_t element_type, + iree_string_builder_t* string_builder); + // Parses a serialized element of |element_type| to its in-memory form. // |data_ptr| must be at least large enough to contain the bytes of the element. // For example, "1.2" of type IREE_HAL_ELEMENT_TYPE_FLOAT32 will write the 4 diff --git a/runtime/src/iree/modules/vmvx/BUILD b/runtime/src/iree/modules/vmvx/BUILD index 05ba13a6a7fa..c4f44c5ea7f7 100644 --- a/runtime/src/iree/modules/vmvx/BUILD +++ b/runtime/src/iree/modules/vmvx/BUILD @@ -29,6 +29,7 @@ iree_runtime_cc_library( deps = [ "//runtime/src/iree/base", "//runtime/src/iree/base:tracing", + "//runtime/src/iree/base/internal:cpu", "//runtime/src/iree/builtins/ukernel", "//runtime/src/iree/vm", ], diff --git a/runtime/src/iree/modules/vmvx/CMakeLists.txt b/runtime/src/iree/modules/vmvx/CMakeLists.txt index 97d5239d64af..82cdb8e46882 100644 --- a/runtime/src/iree/modules/vmvx/CMakeLists.txt +++ b/runtime/src/iree/modules/vmvx/CMakeLists.txt @@ -24,6 +24,7 @@ iree_cc_library( iree::base iree::base::tracing iree::builtins::ukernel + iree::base::internal::cpu iree::vm ${_VMVX_OPTIONAL_DEPS} PUBLIC diff --git a/runtime/src/iree/modules/vmvx/module.c b/runtime/src/iree/modules/vmvx/module.c index 9ae4b6f5040c..da1db6b880f4 100644 --- a/runtime/src/iree/modules/vmvx/module.c +++ b/runtime/src/iree/modules/vmvx/module.c @@ -17,6 +17,7 @@ // Include the ukernel support library so that we can use its implementations // as fixed-function components of the runtime. +#include "iree/base/internal/cpu.h" #include "iree/builtins/ukernel/elementwise.h" #include "iree/builtins/ukernel/mmt4d.h" @@ -104,54 +105,53 @@ static iree_host_size_t iree_vmvx_cast_host_size(int64_t value, return (iree_host_size_t)value; } -#define BUFFER_2D_DECLS(name, dtype, offset, stride0, stride1, size0, size1) \ - uint64_t name##_overflow = 0; \ - iree_host_size_t name##_size0 = \ - iree_vmvx_cast_host_size(size0, &name##_overflow); \ - iree_host_size_t name##_size1 = \ - iree_vmvx_cast_host_size(size1, &name##_overflow); \ - iree_host_size_t name##_stride0 = \ - iree_vmvx_cast_host_size(stride0, &name##_overflow); \ - iree_host_size_t name##_stride1 = \ - iree_vmvx_cast_host_size(stride1, &name##_overflow); \ - iree_host_size_t name##_length_bound = iree_vmvx_2d_length_bound( \ - sizeof(dtype), name##_size0, name##_size1, name##_stride0, \ - name##_stride1, &name##_overflow); \ - iree_host_size_t name##_offset = \ - sizeof(dtype) * iree_vmvx_cast_host_size(offset, &name##_overflow); \ - if (name##_overflow) { \ - IREE_TRACE_ZONE_END(z0); \ - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, \ - "buffer overflow for " #name); \ +#define BUFFER_2D_DECLS(name, dtype_size, offset, stride0, stride1, size0, \ + size1) \ + uint64_t name##_overflow = 0; \ + iree_host_size_t name##_size0 = \ + iree_vmvx_cast_host_size(size0, &name##_overflow); \ + iree_host_size_t name##_size1 = \ + iree_vmvx_cast_host_size(size1, &name##_overflow); \ + iree_host_size_t name##_stride0 = \ + iree_vmvx_cast_host_size(stride0, &name##_overflow); \ + iree_host_size_t name##_stride1 = \ + iree_vmvx_cast_host_size(stride1, &name##_overflow); \ + iree_host_size_t name##_length_bound = iree_vmvx_2d_length_bound( \ + dtype_size, name##_size0, name##_size1, name##_stride0, name##_stride1, \ + &name##_overflow); \ + iree_host_size_t name##_offset = \ + dtype_size * iree_vmvx_cast_host_size(offset, &name##_overflow); \ + if (name##_overflow) { \ + IREE_TRACE_ZONE_END(z0); \ + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, \ + "buffer overflow for " #name); \ } -#define MAP_BUFFER_2D_RO(name, dtype, buffer_ref, offset, stride0, stride1, \ - size0, size1) \ - iree_vm_buffer_t* name##_buffer; \ - iree_const_byte_span_t name##_span; \ - BUFFER_2D_DECLS(name, dtype, offset, stride0, stride1, size0, size1); \ - IREE_RETURN_AND_END_ZONE_IF_ERROR( \ - z0, iree_vm_buffer_check_deref(buffer_ref, &name##_buffer)) \ - IREE_RETURN_AND_END_ZONE_IF_ERROR( \ - z0, iree_vm_buffer_map_ro(name##_buffer, /*offset=*/ \ - name##_offset, /*length=*/ \ - name##_length_bound, /*alignment=*/ \ - sizeof(dtype), &name##_span)); \ - const dtype* name = (dtype*)name##_span.data - -#define MAP_BUFFER_2D_RW(name, dtype, buffer_ref, offset, stride0, stride1, \ - size0, size1) \ - iree_vm_buffer_t* name##_buffer; \ - iree_byte_span_t name##_span; \ - BUFFER_2D_DECLS(name, dtype, offset, stride0, stride1, size0, size1); \ - IREE_RETURN_AND_END_ZONE_IF_ERROR( \ - z0, iree_vm_buffer_check_deref(buffer_ref, &name##_buffer)); \ - IREE_RETURN_AND_END_ZONE_IF_ERROR( \ - z0, iree_vm_buffer_map_rw(name##_buffer, /*offset=*/ \ - name##_offset, /*length=*/ \ - name##_length_bound, \ - /*alignment=*/sizeof(dtype), &name##_span)); \ - dtype* name = (dtype*)name##_span.data +#define MAP_BUFFER_2D_IMPL(mode, ptr_type, span_type, name, dtype_size, \ + buffer_ref, offset, stride0, stride1, size0, size1) \ + iree_vm_buffer_t* name##_buffer; \ + span_type name##_span; \ + BUFFER_2D_DECLS(name, dtype_size, offset, stride0, stride1, size0, size1); \ + IREE_RETURN_AND_END_ZONE_IF_ERROR( \ + z0, iree_vm_buffer_check_deref(buffer_ref, &name##_buffer)) \ + IREE_RETURN_AND_END_ZONE_IF_ERROR( \ + z0, iree_vm_buffer_map_##mode(name##_buffer, /*offset=*/ \ + name##_offset, /*length=*/ \ + name##_length_bound, /*alignment=*/ \ + dtype_size, &name##_span)); \ + ptr_type name = (ptr_type)name##_span.data + +#define MAP_BUFFER_2D_UNTYPED_RO(name, dtype_size, ...) \ + MAP_BUFFER_2D_IMPL(ro, const void*, iree_const_byte_span_t, name, \ + dtype_size, __VA_ARGS__) +#define MAP_BUFFER_2D_UNTYPED_RW(name, dtype_size, ...) \ + MAP_BUFFER_2D_IMPL(rw, void*, iree_byte_span_t, name, dtype_size, __VA_ARGS__) +#define MAP_BUFFER_2D_RO(name, dtype, ...) \ + MAP_BUFFER_2D_IMPL(ro, const dtype*, iree_const_byte_span_t, name, \ + sizeof(dtype), __VA_ARGS__) +#define MAP_BUFFER_2D_RW(name, dtype, ...) \ + MAP_BUFFER_2D_IMPL(rw, dtype*, iree_byte_span_t, name, sizeof(dtype), \ + __VA_ARGS__) //===----------------------------------------------------------------------===// // Shared argument shims @@ -636,7 +636,8 @@ IREE_VMVX_ABI_FIXED_STRUCT(mmt4d, rIIrIIrIIIIIiiii, { }); IREE_VMVX_ABI_DEFINE_SHIM(mmt4d, v); -IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d_f32f32f32, mmt4d, v) { +static iree_status_t iree_vmvx_mmt4d(iree_ukernel_mmt4d_type_t type, + const iree_vm_abi_mmt4d_t* args) { IREE_TRACE_ZONE_BEGIN(z0); iree_host_size_t M = (iree_host_size_t)args->m; iree_host_size_t N = (iree_host_size_t)args->n; @@ -647,33 +648,39 @@ IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d_f32f32f32, mmt4d, v) { iree_host_size_t lhs_tile_size = M0 * K0; iree_host_size_t rhs_tile_size = N0 * K0; iree_host_size_t out_tile_size = M0 * N0; + int lhs_elem_size = iree_ukernel_mmt4d_lhs_elem_size(type); + int rhs_elem_size = iree_ukernel_mmt4d_rhs_elem_size(type); + int out_elem_size = iree_ukernel_mmt4d_out_elem_size(type); // Here are abusing the 2D-specific macros MAP_BUFFER_2D_* to query 4D arrays. // Thanks to the requirement that all dimensions but the outer-most one are // contiguous row-major, the outer-most stride is the only nontrivial stride, // we can correctly coalesce the inner 3 dimensions without changing the // mapped span. - MAP_BUFFER_2D_RO(lhs, float, - /*buffer_ref=*/args->lhs_ref, - /*offset=*/args->lhs_offset, - /*stride0=*/args->lhs_row_stride, - /*stride1=*/1, - /*size0=*/M, - /*size1=*/K* lhs_tile_size); - MAP_BUFFER_2D_RO(rhs, float, - /*buffer_ref=*/args->rhs_ref, - /*offset=*/args->rhs_offset, - /*stride0=*/args->rhs_row_stride, - /*stride1=*/1, - /*size0=*/N, - /*size1=*/K* rhs_tile_size); - MAP_BUFFER_2D_RW(out, float, - /*buffer_ref=*/args->out_ref, - /*offset=*/args->out_offset, - /*stride0=*/args->out_row_stride, - /*stride1=*/1, - /*size0=*/M, - /*size1=*/N* out_tile_size); - iree_ukernel_mmt4d_f32f32f32_params_t ukernel_params = { + MAP_BUFFER_2D_UNTYPED_RO(lhs, + /*dtype_size=*/lhs_elem_size, + /*buffer_ref=*/args->lhs_ref, + /*offset=*/args->lhs_offset, + /*stride0=*/args->lhs_row_stride, + /*stride1=*/1, + /*size0=*/M, + /*size1=*/K * lhs_tile_size); + MAP_BUFFER_2D_UNTYPED_RO(rhs, /*dtype_size=*/rhs_elem_size, + /*buffer_ref=*/args->rhs_ref, + /*offset=*/args->rhs_offset, + /*stride0=*/args->rhs_row_stride, + /*stride1=*/1, + /*size0=*/N, + /*size1=*/K * rhs_tile_size); + MAP_BUFFER_2D_UNTYPED_RW(out, /*dtype_size=*/out_elem_size, + /*buffer_ref=*/args->out_ref, + /*offset=*/args->out_offset, + /*stride0=*/args->out_row_stride, + /*stride1=*/1, + /*size0=*/M, + /*size1=*/N * out_tile_size); + iree_ukernel_mmt4d_params_t ukernel_params = { + .type = type, + .flags = args->flags, .lhs_buffer = lhs, .rhs_buffer = rhs, .out_buffer = out, @@ -686,76 +693,23 @@ IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d_f32f32f32, mmt4d, v) { .M0 = M0, .N0 = N0, .K0 = K0, - .flags = args->flags, + .cpu_data = iree_cpu_data_fields(), }; - int ukernel_retcode = iree_ukernel_mmt4d_f32f32f32(&ukernel_params); + iree_ukernel_mmt4d_status_t status = iree_ukernel_mmt4d(&ukernel_params); IREE_TRACE_ZONE_END(z0); - if (ukernel_retcode) { + if (status != iree_ukernel_mmt4d_status_ok) { return iree_make_status(IREE_STATUS_INTERNAL, - iree_ukernel_mmt4d_error_message(ukernel_retcode)); + iree_ukernel_mmt4d_status_message(status)); } return iree_ok_status(); } +IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d_f32f32f32, mmt4d, v) { + return iree_vmvx_mmt4d(iree_ukernel_mmt4d_type_f32f32f32, args); +} + IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d_i8i8i32, mmt4d, v) { - IREE_TRACE_ZONE_BEGIN(z0); - iree_host_size_t M = (iree_host_size_t)args->m; - iree_host_size_t N = (iree_host_size_t)args->n; - iree_host_size_t K = (iree_host_size_t)args->k; - iree_host_size_t M0 = (iree_host_size_t)args->m0; - iree_host_size_t N0 = (iree_host_size_t)args->n0; - iree_host_size_t K0 = (iree_host_size_t)args->k0; - iree_host_size_t lhs_tile_size = M0 * K0; - iree_host_size_t rhs_tile_size = N0 * K0; - iree_host_size_t out_tile_size = M0 * N0; - // Here are abusing the 2D-specific macros MAP_BUFFER_2D_* to query 4D arrays. - // Thanks to the requirement that all dimensions but the outer-most one are - // contiguous row-major, the outer-most stride is the only nontrivial stride, - // we can correctly coalesce the inner 3 dimensions without changing the - // mapped span. - MAP_BUFFER_2D_RO(lhs, int8_t, - /*buffer_ref=*/args->lhs_ref, - /*offset=*/args->lhs_offset, - /*stride0=*/args->lhs_row_stride, - /*stride1=*/1, - /*size0=*/M, - /*size1=*/K * lhs_tile_size); - MAP_BUFFER_2D_RO(rhs, int8_t, - /*buffer_ref=*/args->rhs_ref, - /*offset=*/args->rhs_offset, - /*stride0=*/args->rhs_row_stride, - /*stride1=*/1, - /*size0=*/N, - /*size1=*/K * rhs_tile_size); - MAP_BUFFER_2D_RW(out, int32_t, - /*buffer_ref=*/args->out_ref, - /*offset=*/args->out_offset, - /*stride0=*/args->out_row_stride, - /*stride1=*/1, - /*size0=*/M, - /*size1=*/N * out_tile_size); - iree_ukernel_mmt4d_i8i8i32_params_t ukernel_params = { - .lhs_buffer = lhs, - .rhs_buffer = rhs, - .out_buffer = out, - .lhs_stride = lhs_stride0, - .rhs_stride = rhs_stride0, - .out_stride = out_stride0, - .M = M, - .N = N, - .K = K, - .M0 = M0, - .N0 = N0, - .K0 = K0, - .flags = args->flags, - }; - int ukernel_retcode = iree_ukernel_mmt4d_i8i8i32(&ukernel_params); - IREE_TRACE_ZONE_END(z0); - if (ukernel_retcode) { - return iree_make_status(IREE_STATUS_INTERNAL, - iree_ukernel_mmt4d_error_message(ukernel_retcode)); - } - return iree_ok_status(); + return iree_vmvx_mmt4d(iree_ukernel_mmt4d_type_i8i8i32, args); } //===----------------------------------------------------------------------===// diff --git a/runtime/src/iree/testing/benchmark_full.cc b/runtime/src/iree/testing/benchmark_full.cc index c01abf0b6073..69185abf7574 100644 --- a/runtime/src/iree/testing/benchmark_full.cc +++ b/runtime/src/iree/testing/benchmark_full.cc @@ -135,7 +135,7 @@ void iree_benchmark_register(iree_string_view_t name, } if (benchmark_def->minimum_duration_ns != 0) { - instance->MinTime((double)benchmark_def->minimum_duration_ns / 1e-9); + instance->MinTime((double)benchmark_def->minimum_duration_ns * 1e-9); } else if (benchmark_def->iteration_count != 0) { instance->Iterations(benchmark_def->iteration_count); } diff --git a/runtime/src/iree/testing/status_matchers.h b/runtime/src/iree/testing/status_matchers.h index 1697e4cd9cc7..76509fd481b4 100644 --- a/runtime/src/iree/testing/status_matchers.h +++ b/runtime/src/iree/testing/status_matchers.h @@ -358,7 +358,7 @@ inline internal::IsOkMatcherGenerator IsOk() { template void PrintTo(const StatusOr &statusor, std::ostream *os) { if (!statusor.ok()) { - *os << statusor.status(); + *os << statusor.status().ToString(); } else { *os << "OK: " << ::testing::PrintToString(statusor.value()); } diff --git a/runtime/src/iree/tooling/BUILD b/runtime/src/iree/tooling/BUILD index e793590da9be..b4d71c6b8860 100644 --- a/runtime/src/iree/tooling/BUILD +++ b/runtime/src/iree/tooling/BUILD @@ -12,6 +12,67 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "buffer_view_matchers", + srcs = ["buffer_view_matchers.c"], + hdrs = ["buffer_view_matchers.h"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/base:tracing", + "//runtime/src/iree/base/internal", + "//runtime/src/iree/hal", + ], +) + +iree_runtime_cc_test( + name = "buffer_view_matchers_test", + srcs = ["buffer_view_matchers_test.cc"], + deps = [ + ":buffer_view_matchers", + "//runtime/src/iree/base", + "//runtime/src/iree/base/internal", + "//runtime/src/iree/base/internal:span", + "//runtime/src/iree/hal", + "//runtime/src/iree/testing:gtest", + "//runtime/src/iree/testing:gtest_main", + ], +) + +cc_library( + name = "comparison", + srcs = ["comparison.cc"], + hdrs = ["comparison.h"], + deps = [ + ":buffer_view_matchers", + ":vm_util", + "//runtime/src/iree/base", + "//runtime/src/iree/base:cc", + "//runtime/src/iree/base:tracing", + "//runtime/src/iree/base/internal:flags", + "//runtime/src/iree/base/internal:span", + "//runtime/src/iree/hal", + "//runtime/src/iree/modules/hal", + "//runtime/src/iree/vm", + "//runtime/src/iree/vm:cc", + ], +) + +cc_test( + name = "comparison_test", + srcs = ["comparison_test.cc"], + deps = [ + ":comparison", + ":vm_util", + "//runtime/src/iree/base", + "//runtime/src/iree/hal", + "//runtime/src/iree/modules/hal", + "//runtime/src/iree/testing:gtest", + "//runtime/src/iree/testing:gtest_main", + "//runtime/src/iree/vm", + "//runtime/src/iree/vm:cc", + ], +) + cc_library( name = "context_util", srcs = ["context_util.c"], diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt index 22a901b140d3..c20fade9fe72 100644 --- a/runtime/src/iree/tooling/CMakeLists.txt +++ b/runtime/src/iree/tooling/CMakeLists.txt @@ -10,6 +10,75 @@ iree_add_all_subdirs() +iree_cc_library( + NAME + buffer_view_matchers + HDRS + "buffer_view_matchers.h" + SRCS + "buffer_view_matchers.c" + DEPS + iree::base + iree::base::internal + iree::base::tracing + iree::hal + PUBLIC +) + +iree_cc_test( + NAME + buffer_view_matchers_test + SRCS + "buffer_view_matchers_test.cc" + DEPS + ::buffer_view_matchers + iree::base + iree::base::internal + iree::base::internal::span + iree::hal + iree::testing::gtest + iree::testing::gtest_main +) + +iree_cc_library( + NAME + comparison + HDRS + "comparison.h" + SRCS + "comparison.cc" + DEPS + ::buffer_view_matchers + ::vm_util + iree::base + iree::base::cc + iree::base::internal::flags + iree::base::internal::span + iree::base::tracing + iree::hal + iree::modules::hal + iree::vm + iree::vm::cc + PUBLIC +) + +iree_cc_test( + NAME + comparison_test + SRCS + "comparison_test.cc" + DEPS + ::comparison + ::vm_util + iree::base + iree::hal + iree::modules::hal + iree::testing::gtest + iree::testing::gtest_main + iree::vm + iree::vm::cc +) + iree_cc_library( NAME context_util diff --git a/runtime/src/iree/tooling/buffer_view_matchers.c b/runtime/src/iree/tooling/buffer_view_matchers.c new file mode 100644 index 000000000000..9112ec2d612d --- /dev/null +++ b/runtime/src/iree/tooling/buffer_view_matchers.c @@ -0,0 +1,663 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/tooling/buffer_view_matchers.h" + +#include + +#include "iree/base/internal/math.h" +#include "iree/base/tracing.h" + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_equality_t +//===----------------------------------------------------------------------===// + +static iree_hal_buffer_element_t iree_hal_buffer_element_at( + iree_hal_element_type_t element_type, iree_const_byte_span_t elements, + iree_host_size_t index) { + iree_host_size_t element_size = + iree_hal_element_dense_byte_count(element_type); + iree_const_byte_span_t element_data = iree_make_const_byte_span( + elements.data + index * element_size, element_size); + iree_hal_buffer_element_t element = { + .type = element_type, + }; + memcpy(element.storage, element_data.data, element_size); + return element; +} + +static iree_status_t iree_hal_append_element_string( + iree_hal_buffer_element_t value, iree_string_builder_t* builder) { + char temp[64]; + iree_host_size_t temp_length = 0; + IREE_RETURN_IF_ERROR(iree_hal_format_element( + iree_make_const_byte_span(value.storage, + iree_hal_element_dense_byte_count(value.type)), + value.type, sizeof(temp), temp, &temp_length)); + return iree_string_builder_append_string( + builder, iree_make_string_view(temp, temp_length)); +} + +static iree_status_t iree_hal_append_element_mismatch_string( + iree_host_size_t index, iree_hal_buffer_element_t expected_element, + iree_hal_buffer_element_t actual_element, iree_string_builder_t* builder) { + IREE_RETURN_IF_ERROR(iree_string_builder_append_format( + builder, "element at index %" PRIhsz " (", index)); + IREE_RETURN_IF_ERROR(iree_hal_append_element_string(actual_element, builder)); + IREE_RETURN_IF_ERROR(iree_string_builder_append_string( + builder, IREE_SV(") does not match the expected ("))); + IREE_RETURN_IF_ERROR( + iree_hal_append_element_string(expected_element, builder)); + IREE_RETURN_IF_ERROR( + iree_string_builder_append_string(builder, IREE_SV(")"))); + return iree_ok_status(); +} + +static bool iree_hal_compare_strided_elements_exact( + iree_hal_element_type_t element_type, iree_host_size_t element_count, + iree_const_byte_span_t expected_elements, iree_host_size_t expected_stride, + iree_const_byte_span_t actual_elements, iree_host_size_t actual_stride, + iree_host_size_t* out_index) { + const iree_host_size_t element_size = + iree_hal_element_dense_byte_count(element_type); + const uint8_t* expected_ptr = expected_elements.data; + const uint8_t* actual_ptr = actual_elements.data; + for (iree_host_size_t i = 0; i < element_count; ++i) { + int cmp = memcmp(expected_ptr, actual_ptr, element_size); + if (cmp != 0) { + *out_index = i; + return false; + } + expected_ptr += expected_stride * element_size; + actual_ptr += actual_stride * element_size; + } + return true; +} + +static bool iree_hal_compare_strided_elements_approximate_absolute_f16( + iree_hal_buffer_equality_t equality, iree_host_size_t element_count, + const uint16_t* expected_ptr, iree_host_size_t expected_stride, + const uint16_t* actual_ptr, iree_host_size_t actual_stride, + iree_host_size_t* out_index) { + for (iree_host_size_t i = 0; i < element_count; ++i) { + if (fabsf(iree_math_f16_to_f32(*expected_ptr) - + iree_math_f16_to_f32(*actual_ptr)) > equality.f16_threshold) { + *out_index = i; + return false; + } + expected_ptr += expected_stride; + actual_ptr += actual_stride; + } + return true; +} + +static bool iree_hal_compare_strided_elements_approximate_absolute_f32( + iree_hal_buffer_equality_t equality, iree_host_size_t element_count, + const float* expected_ptr, iree_host_size_t expected_stride, + const float* actual_ptr, iree_host_size_t actual_stride, + iree_host_size_t* out_index) { + for (iree_host_size_t i = 0; i < element_count; ++i) { + if (fabsf(*expected_ptr - *actual_ptr) > equality.f32_threshold) { + *out_index = i; + return false; + } + expected_ptr += expected_stride; + actual_ptr += actual_stride; + } + return true; +} + +static bool iree_hal_compare_strided_elements_approximate_absolute_f64( + iree_hal_buffer_equality_t equality, iree_host_size_t element_count, + const double* expected_ptr, iree_host_size_t expected_stride, + const double* actual_ptr, iree_host_size_t actual_stride, + iree_host_size_t* out_index) { + for (iree_host_size_t i = 0; i < element_count; ++i) { + if (fabs(*expected_ptr - *actual_ptr) > equality.f64_threshold) { + *out_index = i; + return false; + } + expected_ptr += expected_stride; + actual_ptr += actual_stride; + } + return true; +} + +static bool iree_hal_compare_strided_elements_approximate_absolute( + iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type, + iree_host_size_t element_count, iree_const_byte_span_t expected_elements, + iree_host_size_t expected_stride, iree_const_byte_span_t actual_elements, + iree_host_size_t actual_stride, iree_host_size_t* out_index) { + switch (element_type) { + case IREE_HAL_ELEMENT_TYPE_FLOAT_16: + return iree_hal_compare_strided_elements_approximate_absolute_f16( + equality, element_count, (const uint16_t*)expected_elements.data, + expected_stride, (const uint16_t*)actual_elements.data, actual_stride, + out_index); + case IREE_HAL_ELEMENT_TYPE_FLOAT_32: + return iree_hal_compare_strided_elements_approximate_absolute_f32( + equality, element_count, (const float*)expected_elements.data, + expected_stride, (const float*)actual_elements.data, actual_stride, + out_index); + case IREE_HAL_ELEMENT_TYPE_FLOAT_64: + return iree_hal_compare_strided_elements_approximate_absolute_f64( + equality, element_count, (const double*)expected_elements.data, + expected_stride, (const double*)actual_elements.data, actual_stride, + out_index); + default: + return iree_hal_compare_strided_elements_exact( + element_type, element_count, expected_elements, expected_stride, + actual_elements, actual_stride, out_index); + } +} + +// Compares two buffers element by element. +// The provided strides (in elements) are applied up to |element_count|. +static bool iree_hal_compare_strided_elements( + iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type, + iree_host_size_t element_count, iree_const_byte_span_t expected_elements, + iree_host_size_t expected_stride, iree_const_byte_span_t actual_elements, + iree_host_size_t actual_stride, iree_host_size_t* out_index) { + switch (equality.mode) { + case IREE_HAL_BUFFER_EQUALITY_EXACT: + return iree_hal_compare_strided_elements_exact( + element_type, element_count, expected_elements, expected_stride, + actual_elements, actual_stride, out_index); + case IREE_HAL_BUFFER_EQUALITY_APPROXIMATE_ABSOLUTE: + return iree_hal_compare_strided_elements_approximate_absolute( + equality, element_type, element_count, expected_elements, + expected_stride, actual_elements, actual_stride, out_index); + default: + IREE_ASSERT(false && "unhandled equality mode"); + return false; + } +} + +bool iree_hal_compare_buffer_elements_broadcast( + iree_hal_buffer_equality_t equality, + iree_hal_buffer_element_t expected_element, iree_host_size_t element_count, + iree_const_byte_span_t actual_elements, iree_host_size_t* out_index) { + return iree_hal_compare_strided_elements( + equality, expected_element.type, element_count, + iree_make_const_byte_span( + expected_element.storage, + iree_hal_element_dense_byte_count(expected_element.type)), + /*expected_stride=*/0, actual_elements, /*actual_stride=*/1, out_index); +} + +bool iree_hal_compare_buffer_elements_elementwise( + iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type, + iree_host_size_t element_count, iree_const_byte_span_t expected_elements, + iree_const_byte_span_t actual_elements, iree_host_size_t* out_index) { + return iree_hal_compare_strided_elements( + equality, element_type, element_count, expected_elements, + /*expected_stride=*/1, actual_elements, /*actual_stride=*/1, out_index); +} + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_metadata_matcher_t +//===----------------------------------------------------------------------===// + +iree_status_t iree_hal_buffer_view_metadata_matcher_initialize( + iree_host_size_t shape_rank, const iree_hal_dim_t* shape, + iree_hal_element_type_t element_type, + iree_hal_encoding_type_t encoding_type, + iree_hal_buffer_view_metadata_matcher_t* out_matcher) { + IREE_ASSERT_ARGUMENT(!shape_rank || shape); + memset(out_matcher, 0, sizeof(*out_matcher)); + if (shape_rank > IREE_ARRAYSIZE(out_matcher->shape)) { + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "maximum shape rank exceeded"); + } + out_matcher->shape_rank = shape_rank; + memcpy(out_matcher->shape, shape, shape_rank * sizeof(*shape)); + out_matcher->element_type = element_type; + out_matcher->encoding_type = encoding_type; + return iree_ok_status(); +} + +void iree_hal_buffer_view_metadata_matcher_deinitialize( + iree_hal_buffer_view_metadata_matcher_t* matcher) { + IREE_ASSERT_ARGUMENT(matcher); + memset(matcher, 0, sizeof(*matcher)); +} + +iree_status_t iree_hal_buffer_view_metadata_matcher_describe( + iree_hal_buffer_view_metadata_matcher_t* matcher, + iree_string_builder_t* builder) { + IREE_ASSERT_ARGUMENT(matcher); + IREE_ASSERT_ARGUMENT(builder); + IREE_RETURN_IF_ERROR( + iree_string_builder_append_string(builder, IREE_SV("matches "))); + IREE_RETURN_IF_ERROR(iree_hal_append_shape_and_element_type_string( + matcher->shape_rank, matcher->shape, matcher->element_type, builder)); + return iree_ok_status(); +} + +static bool iree_hal_buffer_view_shape_matches( + iree_host_size_t shape_rank, const iree_hal_dim_t* shape, + iree_hal_buffer_view_t* matchee) { + if (shape_rank != iree_hal_buffer_view_shape_rank(matchee)) return false; + for (iree_host_size_t i = 0; i < shape_rank; ++i) { + if (shape[i] != iree_hal_buffer_view_shape_dim(matchee, i)) return false; + } + return true; +} + +iree_status_t iree_hal_buffer_view_metadata_matcher_match( + iree_hal_buffer_view_metadata_matcher_t* matcher, + iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder, + bool* out_matched) { + IREE_ASSERT_ARGUMENT(matcher); + IREE_ASSERT_ARGUMENT(matchee); + IREE_ASSERT_ARGUMENT(builder); + IREE_ASSERT_ARGUMENT(out_matched); + *out_matched = false; + + const bool shape_match = iree_hal_buffer_view_shape_matches( + matcher->shape_rank, matcher->shape, matchee); + const bool element_type_match = + matcher->element_type == IREE_HAL_ELEMENT_TYPE_NONE || + matcher->element_type == iree_hal_buffer_view_element_type(matchee); + const bool encoding_type_match = + matcher->encoding_type == IREE_HAL_ENCODING_TYPE_OPAQUE || + matcher->encoding_type == iree_hal_buffer_view_encoding_type(matchee); + if (shape_match && element_type_match && encoding_type_match) { + *out_matched = true; + return iree_ok_status(); + } + + IREE_RETURN_IF_ERROR( + iree_string_builder_append_string(builder, IREE_SV("metadata is "))); + IREE_RETURN_IF_ERROR(iree_hal_append_shape_and_element_type_string( + iree_hal_buffer_view_shape_rank(matchee), + iree_hal_buffer_view_shape_dims(matchee), + iree_hal_buffer_view_element_type(matchee), builder)); + + *out_matched = false; + return iree_ok_status(); +} + +iree_status_t iree_hal_buffer_view_match_metadata( + iree_host_size_t shape_rank, const iree_hal_dim_t* shape, + iree_hal_element_type_t element_type, + iree_hal_encoding_type_t encoding_type, iree_hal_buffer_view_t* matchee, + iree_string_builder_t* builder, bool* out_matched) { + iree_hal_buffer_view_metadata_matcher_t matcher; + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_metadata_matcher_initialize( + shape_rank, shape, element_type, encoding_type, &matcher)); + iree_status_t status = iree_hal_buffer_view_metadata_matcher_match( + &matcher, matchee, builder, out_matched); + if (iree_status_is_ok(status) && !*out_matched) { + IREE_RETURN_IF_ERROR(iree_string_builder_append_string( + builder, IREE_SV("; expected that the view "))); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_metadata_matcher_describe(&matcher, builder)); + } + iree_hal_buffer_view_metadata_matcher_deinitialize(&matcher); + return status; +} + +iree_status_t iree_hal_buffer_view_match_metadata_like( + iree_hal_buffer_view_t* expected, iree_hal_buffer_view_t* matchee, + iree_string_builder_t* builder, bool* out_matched) { + IREE_ASSERT_ARGUMENT(expected); + return iree_hal_buffer_view_match_metadata( + iree_hal_buffer_view_shape_rank(expected), + iree_hal_buffer_view_shape_dims(expected), + iree_hal_buffer_view_element_type(expected), + iree_hal_buffer_view_encoding_type(expected), matchee, builder, + out_matched); +} + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_element_matcher_t +//===----------------------------------------------------------------------===// + +iree_status_t iree_hal_buffer_view_element_matcher_initialize( + iree_hal_buffer_equality_t equality, iree_hal_buffer_element_t value, + iree_hal_buffer_view_element_matcher_t* out_matcher) { + memset(out_matcher, 0, sizeof(*out_matcher)); + out_matcher->equality = equality; + out_matcher->value = value; + return iree_ok_status(); +} + +void iree_hal_buffer_view_element_matcher_deinitialize( + iree_hal_buffer_view_element_matcher_t* matcher) { + IREE_ASSERT_ARGUMENT(matcher); + memset(matcher, 0, sizeof(*matcher)); +} + +iree_status_t iree_hal_buffer_view_element_matcher_describe( + iree_hal_buffer_view_element_matcher_t* matcher, + iree_string_builder_t* builder) { + IREE_ASSERT_ARGUMENT(matcher); + IREE_ASSERT_ARGUMENT(builder); + IREE_RETURN_IF_ERROR(iree_string_builder_append_string( + builder, IREE_SV("has all elements match "))); + IREE_RETURN_IF_ERROR( + iree_hal_append_element_type_string(matcher->value.type, builder)); + IREE_RETURN_IF_ERROR( + iree_string_builder_append_string(builder, IREE_SV("="))); + IREE_RETURN_IF_ERROR(iree_hal_append_element_string(matcher->value, builder)); + return iree_ok_status(); +} + +iree_status_t iree_hal_buffer_view_element_matcher_match( + iree_hal_buffer_view_element_matcher_t* matcher, + iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder, + bool* out_matched) { + IREE_ASSERT_ARGUMENT(matcher); + IREE_ASSERT_ARGUMENT(matchee); + IREE_ASSERT_ARGUMENT(builder); + IREE_ASSERT_ARGUMENT(out_matched); + *out_matched = false; + + if (iree_hal_buffer_view_encoding_type(matchee) != + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "non-dense encodings not supported for matching"); + } else if (iree_hal_buffer_view_element_type(matchee) != + matcher->value.type) { + IREE_RETURN_IF_ERROR(iree_string_builder_append_string( + builder, IREE_SV("whose element type ("))); + IREE_RETURN_IF_ERROR(iree_hal_append_element_type_string( + iree_hal_buffer_view_element_type(matchee), builder)); + IREE_RETURN_IF_ERROR(iree_string_builder_append_string( + builder, IREE_SV(") does not match expected ("))); + IREE_RETURN_IF_ERROR( + iree_hal_append_element_type_string(matcher->value.type, builder)); + IREE_RETURN_IF_ERROR( + iree_string_builder_append_string(builder, IREE_SV(")"))); + *out_matched = false; + return iree_ok_status(); + } + + iree_hal_buffer_mapping_t actual_mapping; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + iree_hal_buffer_view_buffer(matchee), IREE_HAL_MAPPING_MODE_SCOPED, + IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &actual_mapping)); + iree_const_byte_span_t actual_contents = iree_make_const_byte_span( + actual_mapping.contents.data, actual_mapping.contents.data_length); + + iree_host_size_t i = 0; + const bool all_match = iree_hal_compare_buffer_elements_broadcast( + matcher->equality, matcher->value, + iree_hal_buffer_view_element_count(matchee), actual_contents, &i); + iree_hal_buffer_element_t actual_element = iree_hal_buffer_element_at( + iree_hal_buffer_view_element_type(matchee), actual_contents, i); + + IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&actual_mapping)); + + if (!all_match) { + IREE_RETURN_IF_ERROR(iree_hal_append_element_mismatch_string( + i, matcher->value, actual_element, builder)); + } + + *out_matched = all_match; + return iree_ok_status(); +} + +iree_status_t iree_hal_buffer_view_match_elements( + iree_hal_buffer_equality_t equality, iree_hal_buffer_element_t value, + iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder, + bool* out_matched) { + iree_hal_buffer_view_element_matcher_t matcher; + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_element_matcher_initialize( + equality, value, &matcher)); + iree_status_t status = iree_hal_buffer_view_element_matcher_match( + &matcher, matchee, builder, out_matched); + if (iree_status_is_ok(status) && !*out_matched) { + IREE_RETURN_IF_ERROR(iree_string_builder_append_string( + builder, IREE_SV("; expected that the view "))); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_element_matcher_describe(&matcher, builder)); + } + iree_hal_buffer_view_element_matcher_deinitialize(&matcher); + return status; +} + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_array_matcher_t +//===----------------------------------------------------------------------===// + +iree_status_t iree_hal_buffer_view_array_matcher_initialize( + iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type, + iree_host_size_t element_count, iree_const_byte_span_t elements, + iree_hal_buffer_view_array_matcher_t* out_matcher) { + IREE_ASSERT_ARGUMENT(!element_count || + !iree_const_byte_span_is_empty(elements)); + memset(out_matcher, 0, sizeof(*out_matcher)); + out_matcher->equality = equality; + out_matcher->element_type = element_type; + out_matcher->element_count = element_count; + out_matcher->elements = elements; + return iree_ok_status(); +} + +void iree_hal_buffer_view_array_matcher_deinitialize( + iree_hal_buffer_view_array_matcher_t* matcher) { + IREE_ASSERT_ARGUMENT(matcher); + memset(matcher, 0, sizeof(*matcher)); +} + +iree_status_t iree_hal_buffer_view_array_matcher_describe( + iree_hal_buffer_view_array_matcher_t* matcher, + iree_string_builder_t* builder) { + IREE_ASSERT_ARGUMENT(matcher); + IREE_ASSERT_ARGUMENT(builder); + IREE_RETURN_IF_ERROR(iree_string_builder_append_format( + builder, + "has all elements match those in %" PRIhsz " element of ", + matcher->element_count)); + IREE_RETURN_IF_ERROR( + iree_hal_append_element_type_string(matcher->element_type, builder)); + // TODO(benvanik): format array contents (elided)? make caller do? + return iree_ok_status(); +} + +iree_status_t iree_hal_buffer_view_array_matcher_match( + iree_hal_buffer_view_array_matcher_t* matcher, + iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder, + bool* out_matched) { + IREE_ASSERT_ARGUMENT(matcher); + IREE_ASSERT_ARGUMENT(matchee); + IREE_ASSERT_ARGUMENT(builder); + IREE_ASSERT_ARGUMENT(out_matched); + *out_matched = false; + + if (iree_hal_buffer_view_encoding_type(matchee) != + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "non-dense encodings not supported for matching"); + } else if (iree_hal_buffer_view_element_type(matchee) != + matcher->element_type) { + IREE_RETURN_IF_ERROR(iree_string_builder_append_string( + builder, IREE_SV("whose element type ("))); + IREE_RETURN_IF_ERROR(iree_hal_append_element_type_string( + iree_hal_buffer_view_element_type(matchee), builder)); + IREE_RETURN_IF_ERROR(iree_string_builder_append_string( + builder, IREE_SV(") does not match expected ("))); + IREE_RETURN_IF_ERROR( + iree_hal_append_element_type_string(matcher->element_type, builder)); + IREE_RETURN_IF_ERROR( + iree_string_builder_append_string(builder, IREE_SV(")"))); + *out_matched = false; + return iree_ok_status(); + } else if (iree_hal_buffer_view_element_count(matchee) != + matcher->element_count) { + IREE_RETURN_IF_ERROR(iree_string_builder_append_format( + builder, + "whose element count (%" PRIhsz ") does not match expected (%" PRIhsz + ")", + iree_hal_buffer_view_element_count(matchee), matcher->element_count)); + *out_matched = false; + return iree_ok_status(); + } + + iree_hal_buffer_mapping_t actual_mapping; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + iree_hal_buffer_view_buffer(matchee), IREE_HAL_MAPPING_MODE_SCOPED, + IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &actual_mapping)); + iree_const_byte_span_t actual_contents = iree_make_const_byte_span( + actual_mapping.contents.data, actual_mapping.contents.data_length); + + iree_host_size_t i = 0; + const bool all_match = iree_hal_compare_buffer_elements_elementwise( + matcher->equality, iree_hal_buffer_view_element_type(matchee), + iree_hal_buffer_view_element_count(matchee), matcher->elements, + actual_contents, &i); + iree_hal_buffer_element_t actual_element = iree_hal_buffer_element_at( + iree_hal_buffer_view_element_type(matchee), actual_contents, i); + iree_hal_buffer_element_t expected_element = iree_hal_buffer_element_at( + iree_hal_buffer_view_element_type(matchee), matcher->elements, i); + + IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&actual_mapping)); + + if (!all_match) { + IREE_RETURN_IF_ERROR(iree_hal_append_element_mismatch_string( + i, expected_element, actual_element, builder)); + } + + *out_matched = all_match; + return iree_ok_status(); +} + +iree_status_t iree_hal_buffer_view_match_array( + iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type, + iree_host_size_t element_count, iree_const_byte_span_t elements, + iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder, + bool* out_matched) { + iree_hal_buffer_view_array_matcher_t matcher; + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_array_matcher_initialize( + equality, element_type, element_count, elements, &matcher)); + iree_status_t status = iree_hal_buffer_view_array_matcher_match( + &matcher, matchee, builder, out_matched); + if (iree_status_is_ok(status) && !*out_matched) { + IREE_RETURN_IF_ERROR(iree_string_builder_append_string( + builder, IREE_SV("; expected that the view "))); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_array_matcher_describe(&matcher, builder)); + } + iree_hal_buffer_view_array_matcher_deinitialize(&matcher); + return status; +} + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_matcher_t +//===----------------------------------------------------------------------===// + +iree_status_t iree_hal_buffer_view_matcher_initialize( + iree_hal_buffer_equality_t equality, iree_hal_buffer_view_t* expected, + iree_hal_buffer_view_matcher_t* out_matcher) { + IREE_ASSERT_ARGUMENT(expected); + memset(out_matcher, 0, sizeof(*out_matcher)); + out_matcher->equality = equality; + out_matcher->expected = expected; + iree_hal_buffer_view_retain(expected); + return iree_ok_status(); +} + +void iree_hal_buffer_view_matcher_deinitialize( + iree_hal_buffer_view_matcher_t* matcher) { + IREE_ASSERT_ARGUMENT(matcher); + iree_hal_buffer_view_release(matcher->expected); + memset(matcher, 0, sizeof(*matcher)); +} + +iree_status_t iree_hal_buffer_view_matcher_describe( + iree_hal_buffer_view_matcher_t* matcher, iree_string_builder_t* builder) { + IREE_ASSERT_ARGUMENT(matcher); + IREE_ASSERT_ARGUMENT(builder); + IREE_RETURN_IF_ERROR(iree_string_builder_append_string( + builder, IREE_SV("is equal to contents of a view of "))); + IREE_RETURN_IF_ERROR(iree_hal_append_shape_and_element_type_string( + iree_hal_buffer_view_shape_rank(matcher->expected), + iree_hal_buffer_view_shape_dims(matcher->expected), + iree_hal_buffer_view_element_type(matcher->expected), builder)); + // TODO(benvanik): format buffer view contents (elided)? make caller do? + return iree_ok_status(); +} + +iree_status_t iree_hal_buffer_view_matcher_match( + iree_hal_buffer_view_matcher_t* matcher, iree_hal_buffer_view_t* matchee, + iree_string_builder_t* builder, bool* out_matched) { + IREE_ASSERT_ARGUMENT(matcher); + IREE_ASSERT_ARGUMENT(matchee); + IREE_ASSERT_ARGUMENT(builder); + IREE_ASSERT_ARGUMENT(out_matched); + *out_matched = false; + + if (iree_hal_buffer_view_encoding_type(matchee) != + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "non-dense encodings not supported for matching"); + } + + // Reuse metadata matching to ensure the buffer views are the same shape/type. + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_match_metadata_like( + matcher->expected, matchee, builder, out_matched)); + if (!*out_matched) return iree_ok_status(); + + iree_hal_buffer_mapping_t actual_mapping; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + iree_hal_buffer_view_buffer(matchee), IREE_HAL_MAPPING_MODE_SCOPED, + IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &actual_mapping)); + iree_hal_buffer_mapping_t expected_mapping; + iree_status_t status = iree_hal_buffer_map_range( + iree_hal_buffer_view_buffer(matcher->expected), + IREE_HAL_MAPPING_MODE_SCOPED, IREE_HAL_MEMORY_ACCESS_READ, 0, + IREE_WHOLE_BUFFER, &expected_mapping); + if (!iree_status_is_ok(status)) { + iree_hal_buffer_unmap_range(&actual_mapping); + return status; + } + iree_const_byte_span_t actual_contents = iree_make_const_byte_span( + actual_mapping.contents.data, actual_mapping.contents.data_length); + iree_const_byte_span_t expected_contents = iree_make_const_byte_span( + expected_mapping.contents.data, expected_mapping.contents.data_length); + + iree_host_size_t i = 0; + const bool all_match = iree_hal_compare_buffer_elements_elementwise( + matcher->equality, iree_hal_buffer_view_element_type(matchee), + iree_hal_buffer_view_element_count(matchee), expected_contents, + actual_contents, &i); + iree_hal_buffer_element_t actual_element = iree_hal_buffer_element_at( + iree_hal_buffer_view_element_type(matchee), actual_contents, i); + iree_hal_buffer_element_t expected_element = iree_hal_buffer_element_at( + iree_hal_buffer_view_element_type(matchee), expected_contents, i); + + IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&actual_mapping)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&expected_mapping)); + + if (!all_match) { + IREE_RETURN_IF_ERROR(iree_hal_append_element_mismatch_string( + i, expected_element, actual_element, builder)); + } + + *out_matched = all_match; + return iree_ok_status(); +} + +iree_status_t iree_hal_buffer_view_match_equal( + iree_hal_buffer_equality_t equality, iree_hal_buffer_view_t* expected, + iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder, + bool* out_matched) { + iree_hal_buffer_view_matcher_t matcher; + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_matcher_initialize(equality, expected, &matcher)); + iree_status_t status = iree_hal_buffer_view_matcher_match( + &matcher, matchee, builder, out_matched); + if (iree_status_is_ok(status) && !*out_matched) { + IREE_RETURN_IF_ERROR(iree_string_builder_append_string( + builder, IREE_SV("; expected that the view "))); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_matcher_describe(&matcher, builder)); + } + iree_hal_buffer_view_matcher_deinitialize(&matcher); + return status; +} diff --git a/runtime/src/iree/tooling/buffer_view_matchers.h b/runtime/src/iree/tooling/buffer_view_matchers.h new file mode 100644 index 000000000000..f6a6b078957b --- /dev/null +++ b/runtime/src/iree/tooling/buffer_view_matchers.h @@ -0,0 +1,273 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// Buffer view matchers +//===----------------------------------------------------------------------===// +// +// Provides a set of gtest-like buffer views matchers that can either be +// wrapped in C++ and exposed directly to gtest or used programmatically to +// perform buffer view comparisons. +// +// Each matcher has a simple method that returns whether the match was +// successful. Most code should prefer those. +// +// Support for rare element types and encodings are added as-needed and will +// generally return match failure or a status error when unimplemented. +// +// TODO(benvanik): add C++ wrappers in iree/testing/. + +#ifndef IREE_TOOLING_BUFFER_VIEW_MATCHERS_H_ +#define IREE_TOOLING_BUFFER_VIEW_MATCHERS_H_ + +#include "iree/base/api.h" +#include "iree/base/internal/math.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_equality_t +//===----------------------------------------------------------------------===// + +typedef enum { + // a == b + IREE_HAL_BUFFER_EQUALITY_EXACT = 0, + // abs(a - b) <= threshold + IREE_HAL_BUFFER_EQUALITY_APPROXIMATE_ABSOLUTE, +} iree_hal_buffer_equality_mode_t; + +// TODO(benvanik): initializers/configuration for equality comparisons. +typedef struct { + iree_hal_buffer_equality_mode_t mode; + // TODO(benvanik): allow override in approximate modes (ULP, abs/rel diff). + // For now we just have some hardcoded types that are used in place of + // compile-time constants. Consider these provisional. + float f16_threshold; + float f32_threshold; + double f64_threshold; +} iree_hal_buffer_equality_t; + +// Variant type storing known HAL buffer elements. +typedef struct { + iree_hal_element_type_t type; + union { + int8_t i8; + int16_t i16; + int32_t i32; + int64_t i64; + float f32; + double f64; + uint8_t storage[8]; // max size of all value types + }; +} iree_hal_buffer_element_t; + +static inline iree_hal_buffer_element_t iree_hal_make_buffer_element_i8( + int8_t value) { + iree_hal_buffer_element_t element; + element.type = IREE_HAL_ELEMENT_TYPE_INT_8; + element.i8 = value; + return element; +} + +static inline iree_hal_buffer_element_t iree_hal_make_buffer_element_i16( + int16_t value) { + iree_hal_buffer_element_t element; + element.type = IREE_HAL_ELEMENT_TYPE_INT_16; + element.i16 = value; + return element; +} + +static inline iree_hal_buffer_element_t iree_hal_make_buffer_element_i32( + int32_t value) { + iree_hal_buffer_element_t element; + element.type = IREE_HAL_ELEMENT_TYPE_INT_32; + element.i32 = value; + return element; +} + +static inline iree_hal_buffer_element_t iree_hal_make_buffer_element_i64( + int64_t value) { + iree_hal_buffer_element_t element; + element.type = IREE_HAL_ELEMENT_TYPE_INT_64; + element.i64 = value; + return element; +} + +static inline iree_hal_buffer_element_t iree_hal_make_buffer_element_f16( + float value) { + iree_hal_buffer_element_t element; + element.type = IREE_HAL_ELEMENT_TYPE_FLOAT_16; + element.i16 = iree_math_f32_to_f16(value); + return element; +} + +static inline iree_hal_buffer_element_t iree_hal_make_buffer_element_f32( + float value) { + iree_hal_buffer_element_t element; + element.type = IREE_HAL_ELEMENT_TYPE_FLOAT_32; + element.f32 = value; + return element; +} + +static inline iree_hal_buffer_element_t iree_hal_make_buffer_element_f64( + double value) { + iree_hal_buffer_element_t element; + element.type = IREE_HAL_ELEMENT_TYPE_FLOAT_64; + element.f64 = value; + return element; +} + +// Returns true if all elements match the uniform value based on |equality|. +// |out_index| will contain the first index that does not match. +bool iree_hal_compare_buffer_elements_broadcast( + iree_hal_buffer_equality_t equality, + iree_hal_buffer_element_t expected_element, iree_host_size_t element_count, + iree_const_byte_span_t actual_elements, iree_host_size_t* out_index); + +// Returns true if all elements match based on |equality|. +// |out_index| will contain the first index that does not match. +bool iree_hal_compare_buffer_elements_elementwise( + iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type, + iree_host_size_t element_count, iree_const_byte_span_t expected_elements, + iree_const_byte_span_t actual_elements, iree_host_size_t* out_index); + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_metadata_matcher_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_host_size_t shape_rank; + iree_hal_dim_t shape[128]; + iree_hal_element_type_t element_type; + iree_hal_encoding_type_t encoding_type; +} iree_hal_buffer_view_metadata_matcher_t; + +iree_status_t iree_hal_buffer_view_metadata_matcher_initialize( + iree_host_size_t shape_rank, const iree_hal_dim_t* shape, + iree_hal_element_type_t element_type, + iree_hal_encoding_type_t encoding_type, + iree_hal_buffer_view_metadata_matcher_t* out_matcher); +void iree_hal_buffer_view_metadata_matcher_deinitialize( + iree_hal_buffer_view_metadata_matcher_t* matcher); +iree_status_t iree_hal_buffer_view_metadata_matcher_describe( + iree_hal_buffer_view_metadata_matcher_t* matcher, + iree_string_builder_t* builder); +iree_status_t iree_hal_buffer_view_metadata_matcher_match( + iree_hal_buffer_view_metadata_matcher_t* matcher, + iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder, + bool* out_matched); + +// Matches |matchee| against the given metadata. +// Use IREE_HAL_ELEMENT_TYPE_NONE to ignore |element_type| and +// use IREE_HAL_ENCODING_TYPE_OPAQUE to ignore |encoding_type|. +iree_status_t iree_hal_buffer_view_match_metadata( + iree_host_size_t shape_rank, const iree_hal_dim_t* shape, + iree_hal_element_type_t element_type, + iree_hal_encoding_type_t encoding_type, iree_hal_buffer_view_t* matchee, + iree_string_builder_t* builder, bool* out_matched); + +// Matches |matchee| against |expected| if all metadata (shape, encoding, etc) +// is equivalent. +iree_status_t iree_hal_buffer_view_match_metadata_like( + iree_hal_buffer_view_t* expected, iree_hal_buffer_view_t* matchee, + iree_string_builder_t* builder, bool* out_matched); + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_element_matcher_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_buffer_equality_t equality; + iree_hal_buffer_element_t value; +} iree_hal_buffer_view_element_matcher_t; + +iree_status_t iree_hal_buffer_view_element_matcher_initialize( + iree_hal_buffer_equality_t equality, iree_hal_buffer_element_t value, + iree_hal_buffer_view_element_matcher_t* out_matcher); +void iree_hal_buffer_view_element_matcher_deinitialize( + iree_hal_buffer_view_element_matcher_t* matcher); +iree_status_t iree_hal_buffer_view_element_matcher_describe( + iree_hal_buffer_view_element_matcher_t* matcher, + iree_string_builder_t* builder); +iree_status_t iree_hal_buffer_view_element_matcher_match( + iree_hal_buffer_view_element_matcher_t* matcher, + iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder, + bool* out_matched); + +// Matches all elements of |matchee| against |value|. +iree_status_t iree_hal_buffer_view_match_elements( + iree_hal_buffer_equality_t equality, iree_hal_buffer_element_t value, + iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder, + bool* out_matched); + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_array_matcher_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_buffer_equality_t equality; + iree_hal_element_type_t element_type; + iree_host_size_t element_count; + // TODO(benvanik): copy in? would make easier to take from std::vector. + iree_const_byte_span_t elements; // unowned +} iree_hal_buffer_view_array_matcher_t; + +iree_status_t iree_hal_buffer_view_array_matcher_initialize( + iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type, + iree_host_size_t element_count, iree_const_byte_span_t elements, + iree_hal_buffer_view_array_matcher_t* out_matcher); +void iree_hal_buffer_view_array_matcher_deinitialize( + iree_hal_buffer_view_array_matcher_t* matcher); +iree_status_t iree_hal_buffer_view_array_matcher_describe( + iree_hal_buffer_view_array_matcher_t* matcher, + iree_string_builder_t* builder); +iree_status_t iree_hal_buffer_view_array_matcher_match( + iree_hal_buffer_view_array_matcher_t* matcher, + iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder, + bool* out_matched); + +// Matches |matchee| against all |element_count| elements in |elements|. +// The element count of |matchee| must be equal to |element_count|. +iree_status_t iree_hal_buffer_view_match_array( + iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type, + iree_host_size_t element_count, iree_const_byte_span_t elements, + iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder, + bool* out_matched); + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_matcher_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_buffer_equality_t equality; + iree_hal_buffer_view_t* expected; +} iree_hal_buffer_view_matcher_t; + +iree_status_t iree_hal_buffer_view_matcher_initialize( + iree_hal_buffer_equality_t equality, iree_hal_buffer_view_t* expected, + iree_hal_buffer_view_matcher_t* out_matcher); +void iree_hal_buffer_view_matcher_deinitialize( + iree_hal_buffer_view_matcher_t* matcher); +iree_status_t iree_hal_buffer_view_matcher_describe( + iree_hal_buffer_view_matcher_t* matcher, iree_string_builder_t* builder); +iree_status_t iree_hal_buffer_view_matcher_match( + iree_hal_buffer_view_matcher_t* matcher, iree_hal_buffer_view_t* matchee, + iree_string_builder_t* builder, bool* out_matched); + +// Matches |matchee| against |expected| for both metadata and elements. +iree_status_t iree_hal_buffer_view_match_equal( + iree_hal_buffer_equality_t equality, iree_hal_buffer_view_t* expected, + iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder, + bool* out_matched); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_TOOLING_BUFFER_VIEW_MATCHERS_H_ diff --git a/runtime/src/iree/tooling/buffer_view_matchers_test.cc b/runtime/src/iree/tooling/buffer_view_matchers_test.cc new file mode 100644 index 000000000000..a4857b787423 --- /dev/null +++ b/runtime/src/iree/tooling/buffer_view_matchers_test.cc @@ -0,0 +1,698 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/tooling/buffer_view_matchers.h" + +#include "iree/base/api.h" +#include "iree/base/internal/math.h" +#include "iree/base/internal/span.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" + +namespace iree { +namespace { + +using iree::testing::status::IsOk; +using iree::testing::status::StatusIs; +using ::testing::HasSubstr; + +// TODO(benvanik): move this handle type to a base cc helper. +struct StringBuilder { + static StringBuilder MakeSystem() { + iree_string_builder_t builder; + iree_string_builder_initialize(iree_allocator_system(), &builder); + return StringBuilder(builder); + } + static StringBuilder MakeEmpty() { + iree_string_builder_t builder; + iree_string_builder_initialize(iree_allocator_null(), &builder); + return StringBuilder(builder); + } + explicit StringBuilder(iree_string_builder_t builder) + : builder(std::move(builder)) {} + ~StringBuilder() { iree_string_builder_deinitialize(&builder); } + operator iree_string_builder_t*() { return &builder; } + std::string ToString() const { + return std::string(builder.buffer, builder.size); + } + iree_string_builder_t builder; +}; + +// TODO(benvanik): move this handle type to a hal cc helper. + +// C API iree_*_retain/iree_*_release function pointer. +template +using HandleRefFn = void(IREE_API_PTR*)(T*); + +// C++ RAII wrapper for an IREE C reference object. +// Behaves the same as a thread-safe intrusive pointer. +template retain_fn, HandleRefFn release_fn> +class Handle { + public: + using handle_type = Handle; + + static Handle Wrap(T* value) noexcept { return Handle(value, false); } + + Handle() noexcept = default; + Handle(std::nullptr_t) noexcept {} + Handle(T* value) noexcept : value_(value) { retain_fn(value_); } + + ~Handle() noexcept { + if (value_) release_fn(value_); + } + + Handle(const Handle& rhs) noexcept : value_(rhs.value_) { + if (value_) retain_fn(value_); + } + Handle& operator=(const Handle& rhs) noexcept { + if (value_ != rhs.value_) { + if (value_) release_fn(value_); + value_ = rhs.get(); + if (value_) retain_fn(value_); + } + return *this; + } + + Handle(Handle&& rhs) noexcept : value_(rhs.release()) {} + Handle& operator=(Handle&& rhs) noexcept { + if (value_ != rhs.value_) { + if (value_) release_fn(value_); + value_ = rhs.release(); + } + return *this; + } + + // Gets the pointer referenced by this instance. + constexpr T* get() const noexcept { return value_; } + constexpr operator T*() const noexcept { return value_; } + + // Resets the object to nullptr and decrements the reference count, possibly + // deleting it. + void reset() noexcept { + if (value_) { + release_fn(value_); + value_ = nullptr; + } + } + + // Returns the current pointer held by this object without having its + // reference count decremented and resets the handle to empty. Returns + // nullptr if the handle holds no value. To re-wrap in a handle use either + // ctor(value) or assign(). + T* release() noexcept { + auto* p = value_; + value_ = nullptr; + return p; + } + + // Assigns a pointer. + // The pointer will be accepted by the handle and its reference count will + // not be incremented. + void assign(T* value) noexcept { + reset(); + value_ = value; + } + + // Returns a pointer to the inner pointer storage. + // This allows passing a pointer to the handle as an output argument to + // C-style creation functions. + constexpr T** operator&() noexcept { return &value_; } + + // Support boolean expression evaluation ala unique_ptr/shared_ptr: + // https://en.cppreference.com/w/cpp/memory/shared_ptr/operator_bool + typedef T* Handle::*unspecified_bool_type; + constexpr operator unspecified_bool_type() const noexcept { + return value_ ? &Handle::value_ : nullptr; + } + + // Supports unary expression evaluation. + constexpr bool operator!() const noexcept { return !value_; } + + // Swap support. + void swap(Handle& rhs) noexcept { std::swap(value_, rhs.value_); } + + protected: + Handle(T* value, bool) noexcept : value_(value) {} + + private: + T* value_ = nullptr; +}; + +// C++ wrapper for iree_hal_buffer_view_t. +struct BufferView final + : public Handle { + using handle_type::handle_type; +}; + +static const iree_hal_buffer_equality_t kExactEquality = ([]() { + iree_hal_buffer_equality_t equality; + equality.mode = IREE_HAL_BUFFER_EQUALITY_EXACT; + return equality; +})(); + +static const iree_hal_buffer_equality_t kApproximateEquality = ([]() { + iree_hal_buffer_equality_t equality; + equality.mode = IREE_HAL_BUFFER_EQUALITY_APPROXIMATE_ABSOLUTE; + equality.f16_threshold = 0.001f; + equality.f32_threshold = 0.0001f; + equality.f64_threshold = 0.0001; + return equality; +})(); + +class BufferViewMatchersTest : public ::testing::Test { + protected: + iree_hal_allocator_t* device_allocator_ = nullptr; + virtual void SetUp() { + IREE_CHECK_OK(iree_hal_allocator_create_heap( + IREE_SV("heap"), iree_allocator_system(), iree_allocator_system(), + &device_allocator_)); + } + virtual void TearDown() { iree_hal_allocator_release(device_allocator_); } + + template + StatusOr CreateBufferView(iree::span shape, + iree_hal_element_type_t element_type, + const T* contents) { + iree_hal_dim_t num_elements = 1; + for (iree_hal_dim_t dim : shape) num_elements *= dim; + iree_hal_buffer_params_t params = {0}; + params.type = + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + params.usage = + IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING; + BufferView buffer_view; + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer( + device_allocator_, shape.size(), shape.data(), element_type, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, params, + iree_make_const_byte_span(contents, num_elements * sizeof(T)), + &buffer_view)); + return std::move(buffer_view); + } +}; + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_equality_t +//===----------------------------------------------------------------------===// + +TEST_F(BufferViewMatchersTest, CompareBroadcastI8EQ) { + const int8_t lhs = 1; + const int8_t rhs[] = {1, 1, 1}; + iree_host_size_t index = 0; + EXPECT_TRUE(iree_hal_compare_buffer_elements_broadcast( + kApproximateEquality, iree_hal_make_buffer_element_i8(lhs), + IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)), + &index)); +} + +TEST_F(BufferViewMatchersTest, CompareBroadcastI8NE) { + const int8_t lhs = 1; + const int8_t rhs[] = {1, 2, 3}; + iree_host_size_t index = 0; + EXPECT_FALSE(iree_hal_compare_buffer_elements_broadcast( + kApproximateEquality, iree_hal_make_buffer_element_i8(lhs), + IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)), + &index)); + EXPECT_EQ(index, 1); +} + +TEST_F(BufferViewMatchersTest, CompareBroadcastI64EQ) { + const int64_t lhs = 1; + const int64_t rhs[] = {1, 1, 1}; + iree_host_size_t index = 0; + EXPECT_TRUE(iree_hal_compare_buffer_elements_broadcast( + kApproximateEquality, iree_hal_make_buffer_element_i64(lhs), + IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)), + &index)); +} + +TEST_F(BufferViewMatchersTest, CompareBroadcastI64NE) { + const int64_t lhs = 1; + const int64_t rhs[] = {1, 2, 3}; + iree_host_size_t index = 0; + EXPECT_FALSE(iree_hal_compare_buffer_elements_broadcast( + kApproximateEquality, iree_hal_make_buffer_element_i64(lhs), + IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)), + &index)); + EXPECT_EQ(index, 1); +} + +TEST_F(BufferViewMatchersTest, CompareBroadcastF16EQ) { + const float lhs = 1.0f; + const uint16_t rhs[] = { + iree_math_f32_to_f16(1.0f), + iree_math_f32_to_f16(1.0f), + iree_math_f32_to_f16(1.0f), + }; + iree_host_size_t index = 0; + EXPECT_TRUE(iree_hal_compare_buffer_elements_broadcast( + kApproximateEquality, iree_hal_make_buffer_element_f16(lhs), + IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)), + &index)); +} + +TEST_F(BufferViewMatchersTest, CompareBroadcastF16NE) { + const float lhs = 1.0f; + const uint16_t rhs[] = { + iree_math_f32_to_f16(1.0f), + iree_math_f32_to_f16(3.0f), + iree_math_f32_to_f16(4.0f), + }; + iree_host_size_t index = 0; + EXPECT_FALSE(iree_hal_compare_buffer_elements_broadcast( + kApproximateEquality, iree_hal_make_buffer_element_f16(lhs), + IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)), + &index)); + EXPECT_EQ(index, 1); +} + +TEST_F(BufferViewMatchersTest, CompareBroadcastF32EQ) { + const float lhs = 1.0f; + const float rhs[] = {1.0f, 1.0f, 1.0f}; + iree_host_size_t index = 0; + EXPECT_TRUE(iree_hal_compare_buffer_elements_broadcast( + kApproximateEquality, iree_hal_make_buffer_element_f32(lhs), + IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)), + &index)); +} + +TEST_F(BufferViewMatchersTest, CompareBroadcastF32NE) { + const float lhs = 1.0f; + const float rhs[] = {1.0f, 3.0f, 4.0f}; + iree_host_size_t index = 0; + EXPECT_FALSE(iree_hal_compare_buffer_elements_broadcast( + kApproximateEquality, iree_hal_make_buffer_element_f32(lhs), + IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)), + &index)); + EXPECT_EQ(index, 1); +} + +TEST_F(BufferViewMatchersTest, CompareBroadcastF64EQ) { + const double lhs = 1.0; + const double rhs[] = {1.0, 1.0, 1.0}; + iree_host_size_t index = 0; + EXPECT_TRUE(iree_hal_compare_buffer_elements_broadcast( + kApproximateEquality, iree_hal_make_buffer_element_f64(lhs), + IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)), + &index)); +} + +TEST_F(BufferViewMatchersTest, CompareBroadcastF64NE) { + const double lhs = 1.0; + const double rhs[] = {1.0, 3.0, 4.0}; + iree_host_size_t index = 0; + EXPECT_FALSE(iree_hal_compare_buffer_elements_broadcast( + kApproximateEquality, iree_hal_make_buffer_element_f64(lhs), + IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)), + &index)); + EXPECT_EQ(index, 1); +} + +TEST_F(BufferViewMatchersTest, CompareElementwiseF16EQ) { + const uint16_t lhs[] = { + iree_math_f32_to_f16(1.0f), + iree_math_f32_to_f16(2.0f), + iree_math_f32_to_f16(3.0f), + }; + const uint16_t rhs[] = { + iree_math_f32_to_f16(1.0f), + iree_math_f32_to_f16(2.0f), + iree_math_f32_to_f16(3.0f), + }; + iree_host_size_t index = 0; + EXPECT_TRUE(iree_hal_compare_buffer_elements_elementwise( + kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_16, IREE_ARRAYSIZE(lhs), + iree_make_const_byte_span(lhs, sizeof(lhs)), + iree_make_const_byte_span(rhs, sizeof(rhs)), &index)); +} + +TEST_F(BufferViewMatchersTest, CompareElementwiseF16NearEQ) { + const uint16_t lhs[] = { + iree_math_f32_to_f16(1.0f), + iree_math_f32_to_f16(1.99999f), + iree_math_f32_to_f16(0.00001f), + iree_math_f32_to_f16(4.0f), + }; + const uint16_t rhs[] = { + iree_math_f32_to_f16(1.00001f), + iree_math_f32_to_f16(2.0f), + iree_math_f32_to_f16(0.0f), + iree_math_f32_to_f16(4.0f), + }; + iree_host_size_t index = 0; + EXPECT_TRUE(iree_hal_compare_buffer_elements_elementwise( + kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_16, IREE_ARRAYSIZE(lhs), + iree_make_const_byte_span(lhs, sizeof(lhs)), + iree_make_const_byte_span(rhs, sizeof(rhs)), &index)); +} + +TEST_F(BufferViewMatchersTest, CompareElementwiseF16NE) { + const uint16_t lhs[] = { + iree_math_f32_to_f16(1.0f), + iree_math_f32_to_f16(2.0f), + iree_math_f32_to_f16(4.0f), + }; + const uint16_t rhs[] = { + iree_math_f32_to_f16(1.0f), + iree_math_f32_to_f16(3.0f), + iree_math_f32_to_f16(4.0f), + }; + iree_host_size_t index = 0; + EXPECT_FALSE(iree_hal_compare_buffer_elements_elementwise( + kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_16, IREE_ARRAYSIZE(lhs), + iree_make_const_byte_span(lhs, sizeof(lhs)), + iree_make_const_byte_span(rhs, sizeof(rhs)), &index)); + EXPECT_EQ(index, 1); +} + +TEST_F(BufferViewMatchersTest, CompareElementwiseF32EQ) { + const float lhs[] = {1.0f, 2.0f, 3.0f}; + const float rhs[] = {1.0f, 2.0f, 3.0f}; + iree_host_size_t index = 0; + EXPECT_TRUE(iree_hal_compare_buffer_elements_elementwise( + kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_32, IREE_ARRAYSIZE(lhs), + iree_make_const_byte_span(lhs, sizeof(lhs)), + iree_make_const_byte_span(rhs, sizeof(rhs)), &index)); +} + +TEST_F(BufferViewMatchersTest, CompareElementwiseF32NE) { + const float lhs[] = {1.0f, 2.0f, 4.0f}; + const float rhs[] = {1.0f, 3.0f, 4.0f}; + iree_host_size_t index = 0; + EXPECT_FALSE(iree_hal_compare_buffer_elements_elementwise( + kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_32, IREE_ARRAYSIZE(lhs), + iree_make_const_byte_span(lhs, sizeof(lhs)), + iree_make_const_byte_span(rhs, sizeof(rhs)), &index)); + EXPECT_EQ(index, 1); +} + +TEST_F(BufferViewMatchersTest, CompareElementwiseF64EQ) { + const double lhs[] = {1.0, 2.0, 3.0}; + const double rhs[] = {1.0, 2.0, 3.0}; + iree_host_size_t index = 0; + EXPECT_TRUE(iree_hal_compare_buffer_elements_elementwise( + kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_64, IREE_ARRAYSIZE(lhs), + iree_make_const_byte_span(lhs, sizeof(lhs)), + iree_make_const_byte_span(rhs, sizeof(rhs)), &index)); +} + +TEST_F(BufferViewMatchersTest, CompareElementwiseF64NE) { + const double lhs[] = {1.0, 2.0, 4.0}; + const double rhs[] = {1.0, 3.0, 4.0}; + iree_host_size_t index = 0; + EXPECT_FALSE(iree_hal_compare_buffer_elements_elementwise( + kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_64, IREE_ARRAYSIZE(lhs), + iree_make_const_byte_span(lhs, sizeof(lhs)), + iree_make_const_byte_span(rhs, sizeof(rhs)), &index)); + EXPECT_EQ(index, 1); +} + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_metadata_matcher_t +//===----------------------------------------------------------------------===// + +TEST_F(BufferViewMatchersTest, MetadataEmpty) { + const float contents[1] = {0}; + iree_hal_dim_t shape[] = {0}; + IREE_ASSERT_OK_AND_ASSIGN( + auto lhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, contents)); + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK( + iree_hal_buffer_view_match_metadata_like(lhs, rhs, sb, &match)); + EXPECT_TRUE(match); +} + +TEST_F(BufferViewMatchersTest, MetadataShapesDiffer) { + const float lhs_contents[] = {1.0f}; + const float rhs_contents[] = {1.0f, 2.0f}; + iree_hal_dim_t lhs_shape[] = {1}; + iree_hal_dim_t rhs_shape[] = {1, 2}; + IREE_ASSERT_OK_AND_ASSIGN( + auto lhs, CreateBufferView(lhs_shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, + lhs_contents)); + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, CreateBufferView(rhs_shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, + rhs_contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK( + iree_hal_buffer_view_match_metadata_like(lhs, rhs, sb, &match)); + EXPECT_FALSE(match); + EXPECT_THAT(sb.ToString(), HasSubstr("is 1x2xf32")); + EXPECT_THAT(sb.ToString(), HasSubstr("matches 1xf32")); +} + +TEST_F(BufferViewMatchersTest, MetadataElementTypesDiffer) { + const float contents[] = {1.0f}; + iree_hal_dim_t shape[] = {1}; + IREE_ASSERT_OK_AND_ASSIGN( + auto lhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, contents)); + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK( + iree_hal_buffer_view_match_metadata_like(lhs, rhs, sb, &match)); + EXPECT_FALSE(match); + EXPECT_THAT(sb.ToString(), HasSubstr("is 1xf32")); + EXPECT_THAT(sb.ToString(), HasSubstr("matches 1xi32")); +} + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_element_matcher_t +//===----------------------------------------------------------------------===// + +TEST_F(BufferViewMatchersTest, ElementTypesDiffer) { + const float lhs_value = 1; + const int32_t rhs_contents[] = {1, 1, 1}; + const iree_hal_dim_t shape[] = {1, 3}; + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, rhs_contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK(iree_hal_buffer_view_match_elements( + kExactEquality, iree_hal_make_buffer_element_f32(lhs_value), rhs, sb, + &match)); + EXPECT_FALSE(match); + EXPECT_THAT(sb.ToString(), HasSubstr("type (i32)")); + EXPECT_THAT(sb.ToString(), HasSubstr("expected (f32)")); +} + +TEST_F(BufferViewMatchersTest, MatchElementContentsI32) { + const int32_t lhs_value = 1; + const int32_t rhs_contents[] = {1, 1, 1}; + const iree_hal_dim_t shape[] = {1, 3}; + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, rhs_contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK(iree_hal_buffer_view_match_elements( + kExactEquality, iree_hal_make_buffer_element_i32(lhs_value), rhs, sb, + &match)); + EXPECT_TRUE(match); + EXPECT_TRUE(sb.ToString().empty()); +} + +TEST_F(BufferViewMatchersTest, MismatchElementContentsI32) { + const int32_t lhs_value = 1; + const int32_t rhs_contents[] = {1, 2, 3}; + const iree_hal_dim_t shape[] = {1, 3}; + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, rhs_contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK(iree_hal_buffer_view_match_elements( + kExactEquality, iree_hal_make_buffer_element_i32(lhs_value), rhs, sb, + &match)); + EXPECT_FALSE(match); + EXPECT_THAT(sb.ToString(), HasSubstr("element at index 1")); +} + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_array_matcher_t +//===----------------------------------------------------------------------===// + +TEST_F(BufferViewMatchersTest, MatchArrayTypesDiffer) { + const float lhs_contents[] = {1, 1, 1}; + const int32_t rhs_contents[] = {1, 1, 1}; + const iree_hal_dim_t shape[] = {1, 3}; + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, rhs_contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK(iree_hal_buffer_view_match_array( + kExactEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_32, + IREE_ARRAYSIZE(lhs_contents), + iree_make_const_byte_span(lhs_contents, sizeof(lhs_contents)), rhs, sb, + &match)); + EXPECT_FALSE(match); + EXPECT_THAT(sb.ToString(), HasSubstr("type (i32)")); + EXPECT_THAT(sb.ToString(), HasSubstr("expected (f32)")); +} + +TEST_F(BufferViewMatchersTest, MatchArrayCountsDiffer) { + const int32_t lhs_contents[] = {1, 1}; + const int32_t rhs_contents[] = {1, 1, 1}; + const iree_hal_dim_t shape[] = {1, 3}; + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, rhs_contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK(iree_hal_buffer_view_match_array( + kExactEquality, IREE_HAL_ELEMENT_TYPE_INT_32, + IREE_ARRAYSIZE(lhs_contents), + iree_make_const_byte_span(lhs_contents, sizeof(lhs_contents)), rhs, sb, + &match)); + EXPECT_FALSE(match); + EXPECT_THAT(sb.ToString(), HasSubstr("count (3)")); + EXPECT_THAT(sb.ToString(), HasSubstr("expected (2)")); +} + +TEST_F(BufferViewMatchersTest, MatchArrayContentsI32) { + const int32_t lhs_contents[] = {1, 1, 1}; + const int32_t rhs_contents[] = {1, 1, 1}; + const iree_hal_dim_t shape[] = {1, 3}; + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, rhs_contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK(iree_hal_buffer_view_match_array( + kExactEquality, IREE_HAL_ELEMENT_TYPE_INT_32, + IREE_ARRAYSIZE(lhs_contents), + iree_make_const_byte_span(lhs_contents, sizeof(lhs_contents)), rhs, sb, + &match)); + EXPECT_TRUE(match); + EXPECT_TRUE(sb.ToString().empty()); +} + +TEST_F(BufferViewMatchersTest, MismatchArrayContentsI32) { + const int32_t lhs_contents[] = {1, 1, 1}; + const int32_t rhs_contents[] = {1, 2, 3}; + const iree_hal_dim_t shape[] = {1, 3}; + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, rhs_contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK(iree_hal_buffer_view_match_array( + kExactEquality, IREE_HAL_ELEMENT_TYPE_INT_32, + IREE_ARRAYSIZE(lhs_contents), + iree_make_const_byte_span(lhs_contents, sizeof(lhs_contents)), rhs, sb, + &match)); + EXPECT_FALSE(match); + EXPECT_THAT(sb.ToString(), HasSubstr("element at index 1")); +} + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_matcher_t +//===----------------------------------------------------------------------===// + +TEST_F(BufferViewMatchersTest, MatchEmpty) { + const float contents[1] = {0}; + iree_hal_dim_t shape[] = {0}; + IREE_ASSERT_OK_AND_ASSIGN( + auto lhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, contents)); + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK( + iree_hal_buffer_view_match_equal(kExactEquality, lhs, rhs, sb, &match)); + EXPECT_TRUE(match); + EXPECT_TRUE(sb.ToString().empty()); +} + +TEST_F(BufferViewMatchersTest, MatchShapesDiffer) { + const float lhs_contents[] = {1.0f}; + const float rhs_contents[] = {1.0f, 2.0f}; + iree_hal_dim_t lhs_shape[] = {1}; + iree_hal_dim_t rhs_shape[] = {1, 2}; + IREE_ASSERT_OK_AND_ASSIGN( + auto lhs, CreateBufferView(lhs_shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, + lhs_contents)); + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, CreateBufferView(rhs_shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, + rhs_contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK( + iree_hal_buffer_view_match_equal(kExactEquality, lhs, rhs, sb, &match)); + EXPECT_FALSE(match); + EXPECT_THAT(sb.ToString(), HasSubstr("is 1x2xf32")); + EXPECT_THAT(sb.ToString(), HasSubstr("matches 1xf32")); +} + +TEST_F(BufferViewMatchersTest, MatchElementTypesDiffer) { + const float contents[] = {1.0f}; + iree_hal_dim_t shape[] = {1}; + IREE_ASSERT_OK_AND_ASSIGN( + auto lhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, contents)); + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK( + iree_hal_buffer_view_match_equal(kExactEquality, lhs, rhs, sb, &match)); + EXPECT_FALSE(match); + EXPECT_THAT(sb.ToString(), HasSubstr("is 1xf32")); + EXPECT_THAT(sb.ToString(), HasSubstr("matches 1xi32")); +} + +TEST_F(BufferViewMatchersTest, MatchContentsF16) { + const uint16_t lhs_contents[] = {iree_math_f32_to_f16(2.0f)}; + const uint16_t rhs_contents[] = {iree_math_f32_to_f16(2.0f)}; + iree_hal_dim_t shape[] = {1}; + IREE_ASSERT_OK_AND_ASSIGN( + auto lhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_16, lhs_contents)); + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_16, rhs_contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK( + iree_hal_buffer_view_match_equal(kExactEquality, lhs, rhs, sb, &match)); + EXPECT_TRUE(match); + EXPECT_TRUE(sb.ToString().empty()); +} + +TEST_F(BufferViewMatchersTest, MismatchContentsF16) { + const uint16_t lhs_contents[] = {iree_math_f32_to_f16(1.0f)}; + const uint16_t rhs_contents[] = {iree_math_f32_to_f16(2.0f)}; + const iree_hal_dim_t shape[] = {1}; + IREE_ASSERT_OK_AND_ASSIGN( + auto lhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_16, lhs_contents)); + IREE_ASSERT_OK_AND_ASSIGN( + auto rhs, + CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_16, rhs_contents)); + auto sb = StringBuilder::MakeSystem(); + bool match = false; + IREE_ASSERT_OK( + iree_hal_buffer_view_match_equal(kExactEquality, lhs, rhs, sb, &match)); + EXPECT_FALSE(match); + EXPECT_THAT(sb.ToString(), HasSubstr("element at index 0")); +} + +} // namespace +} // namespace iree diff --git a/runtime/src/iree/tooling/comparison.cc b/runtime/src/iree/tooling/comparison.cc new file mode 100644 index 000000000000..914baccd7a7b --- /dev/null +++ b/runtime/src/iree/tooling/comparison.cc @@ -0,0 +1,294 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/tooling/comparison.h" + +#include +#include + +#include "iree/base/api.h" +#include "iree/base/internal/flags.h" +#include "iree/base/status_cc.h" +#include "iree/base/tracing.h" +#include "iree/hal/api.h" +#include "iree/modules/hal/module.h" +#include "iree/tooling/buffer_view_matchers.h" +#include "iree/tooling/vm_util.h" +#include "iree/vm/ref_cc.h" + +using namespace iree; + +IREE_FLAG(float, expected_f16_threshold, 0.001f, + "Threshold under which two f16 values are considered equal."); +IREE_FLAG(float, expected_f32_threshold, 0.0001f, + "Threshold under which two f32 values are considered equal."); +IREE_FLAG(double, expected_f64_threshold, 0.0001, + "Threshold under which two f64 values are considered equal."); + +static iree_hal_buffer_equality_t iree_tooling_equality_from_flags(void) { + iree_hal_buffer_equality_t equality; + equality.mode = IREE_HAL_BUFFER_EQUALITY_APPROXIMATE_ABSOLUTE; + equality.f16_threshold = FLAG_expected_f16_threshold; + equality.f32_threshold = FLAG_expected_f32_threshold; + equality.f64_threshold = FLAG_expected_f64_threshold; + return equality; +} + +// Prints a buffer view with contents without a trailing newline. +static iree_status_t iree_tooling_append_buffer_view_string( + iree_hal_buffer_view_t* buffer_view, iree_host_size_t max_element_count, + iree_string_builder_t* builder) { + // NOTE: we could see how many bytes are available in the builder (capacity - + // size) and then pass those in to the initial format - if there's enough + // space it'll fill what it needs. We'd need to adjust the string builder + // afterward somehow. + iree_host_size_t required_length = 0; + iree_status_t status = iree_hal_buffer_view_format( + buffer_view, max_element_count, /*buffer_capacity=*/0, /*buffer=*/NULL, + &required_length); + if (!iree_status_is_out_of_range(status)) return status; + char* buffer = NULL; + IREE_RETURN_IF_ERROR( + iree_string_builder_append_inline(builder, required_length, &buffer)); + if (!buffer) return iree_ok_status(); + return iree_hal_buffer_view_format(buffer_view, max_element_count, + required_length + /*NUL=*/1, buffer, + &required_length); +} + +static iree_status_t iree_vm_append_variant_type_string( + iree_vm_variant_t variant, iree_string_builder_t* builder) { + if (iree_vm_variant_is_empty(variant)) { + return iree_string_builder_append_string(builder, IREE_SV("empty")); + } else if (iree_vm_variant_is_value(variant)) { + const char* type = NULL; + switch (variant.type.value_type) { + case IREE_VM_VALUE_TYPE_I8: + type = "i8"; + break; + case IREE_VM_VALUE_TYPE_I16: + type = "i16"; + break; + case IREE_VM_VALUE_TYPE_I32: + type = "i32"; + break; + case IREE_VM_VALUE_TYPE_I64: + type = "i64"; + break; + case IREE_VM_VALUE_TYPE_F32: + type = "f32"; + break; + case IREE_VM_VALUE_TYPE_F64: + type = "f64"; + break; + default: + type = "?"; + break; + } + return iree_string_builder_append_cstring(builder, type); + } else if (iree_vm_variant_is_ref(variant)) { + return iree_string_builder_append_string( + builder, iree_vm_ref_type_name(variant.type.ref_type)); + } else { + return iree_string_builder_append_string(builder, IREE_SV("unknown")); + } +} + +static bool iree_tooling_compare_values(int result_index, + iree_vm_variant_t expected_variant, + iree_vm_variant_t actual_variant, + iree_string_builder_t* builder) { + IREE_ASSERT_EQ(expected_variant.type.value_type, + actual_variant.type.value_type); + switch (expected_variant.type.value_type) { + case IREE_VM_VALUE_TYPE_I8: + if (expected_variant.i8 != actual_variant.i8) { + IREE_CHECK_OK(iree_string_builder_append_format( + builder, + "[FAILED] result[%d]: i8 values differ\n expected: %" PRIi8 + "\n actual: %" PRIi8 "\n", + result_index, expected_variant.i8, actual_variant.i8)); + return false; + } + return true; + case IREE_VM_VALUE_TYPE_I16: + if (expected_variant.i16 != actual_variant.i16) { + IREE_CHECK_OK(iree_string_builder_append_format( + builder, + "[FAILED] result[%d]: i16 values differ\n expected: %" PRIi16 + "\n actual: %" PRIi16 "\n", + result_index, expected_variant.i16, actual_variant.i16)); + return false; + } + return true; + case IREE_VM_VALUE_TYPE_I32: + if (expected_variant.i32 != actual_variant.i32) { + IREE_CHECK_OK(iree_string_builder_append_format( + builder, + "[FAILED] result[%d]: i32 values differ\n expected: %" PRIi32 + "\n actual: %" PRIi32 "\n", + result_index, expected_variant.i32, actual_variant.i32)); + return false; + } + return true; + case IREE_VM_VALUE_TYPE_I64: + if (expected_variant.i64 != actual_variant.i64) { + IREE_CHECK_OK(iree_string_builder_append_format( + builder, + "[FAILED] result[%d]: i64 values differ\n expected: %" PRIi64 + "\n actual: %" PRIi64 "\n", + result_index, expected_variant.i64, actual_variant.i64)); + return false; + } + return true; + case IREE_VM_VALUE_TYPE_F32: + // TODO(benvanik): use tolerance flag. + if (expected_variant.f32 != actual_variant.f32) { + IREE_CHECK_OK(iree_string_builder_append_format( + builder, + "[FAILED] result[%d]: f32 values differ\n expected: %G\n actual: " + "%G\n", + result_index, expected_variant.f32, actual_variant.f32)); + return false; + } + return true; + case IREE_VM_VALUE_TYPE_F64: + // TODO(benvanik): use tolerance flag. + if (expected_variant.f64 != actual_variant.f64) { + IREE_CHECK_OK(iree_string_builder_append_format( + builder, + "[FAILED] result[%d]: f64 values differ\n expected: %G\n actual: " + "%G\n", + result_index, expected_variant.f64, actual_variant.f64)); + return false; + } + return true; + default: + IREE_CHECK_OK(iree_string_builder_append_format( + builder, "[FAILED] result[%d]: unknown value type, cannot match\n", + result_index)); + return false; + } +} + +static bool iree_tooling_compare_buffer_views( + int result_index, iree_hal_buffer_view_t* expected_view, + iree_hal_buffer_view_t* actual_view, iree_allocator_t host_allocator, + iree_host_size_t max_element_count, iree_string_builder_t* builder) { + iree_string_builder_t subbuilder; + iree_string_builder_initialize(host_allocator, &subbuilder); + + iree_hal_buffer_equality_t equality = iree_tooling_equality_from_flags(); + bool did_match = false; + IREE_CHECK_OK(iree_hal_buffer_view_match_equal( + equality, expected_view, actual_view, &subbuilder, &did_match)); + if (did_match) { + iree_string_builder_deinitialize(&subbuilder); + return true; + } + IREE_CHECK_OK(iree_string_builder_append_format( + builder, "[FAILED] result[%d]: ", result_index)); + IREE_CHECK_OK(iree_string_builder_append_string( + builder, iree_string_builder_view(&subbuilder))); + iree_string_builder_deinitialize(&subbuilder); + + IREE_CHECK_OK( + iree_string_builder_append_string(builder, IREE_SV("\n expected:\n"))); + IREE_CHECK_OK(iree_tooling_append_buffer_view_string( + expected_view, max_element_count, builder)); + IREE_CHECK_OK( + iree_string_builder_append_string(builder, IREE_SV("\n actual:\n"))); + IREE_CHECK_OK(iree_tooling_append_buffer_view_string( + actual_view, max_element_count, builder)); + IREE_CHECK_OK(iree_string_builder_append_string(builder, IREE_SV("\n"))); + + return false; +} + +static bool iree_tooling_compare_variants(int result_index, + iree_vm_variant_t expected_variant, + iree_vm_variant_t actual_variant, + iree_allocator_t host_allocator, + iree_host_size_t max_element_count, + iree_string_builder_t* builder) { + IREE_TRACE_SCOPE(); + + if (iree_vm_variant_is_empty(expected_variant)) { + return true; // expected empty is sentinel for (ignored) + } else if (iree_vm_variant_is_empty(actual_variant) && + iree_vm_variant_is_empty(expected_variant)) { + return true; // both empty + } else if (iree_vm_variant_is_value(actual_variant) && + iree_vm_variant_is_value(expected_variant)) { + if (expected_variant.type.value_type != actual_variant.type.value_type) { + return iree_tooling_compare_values(result_index, expected_variant, + actual_variant, builder); + } + } else if (iree_vm_variant_is_ref(actual_variant) && + iree_vm_variant_is_ref(expected_variant)) { + if (iree_hal_buffer_view_isa(actual_variant.ref) && + iree_hal_buffer_view_isa(expected_variant.ref)) { + return iree_tooling_compare_buffer_views( + result_index, iree_hal_buffer_view_deref(expected_variant.ref), + iree_hal_buffer_view_deref(actual_variant.ref), host_allocator, + max_element_count, builder); + } + } + + IREE_CHECK_OK(iree_string_builder_append_format( + builder, "[FAILED] result[%d]: ", result_index)); + IREE_CHECK_OK(iree_string_builder_append_string( + builder, IREE_SV("variant types mismatch; expected "))); + IREE_CHECK_OK(iree_vm_append_variant_type_string(expected_variant, builder)); + IREE_CHECK_OK( + iree_string_builder_append_string(builder, IREE_SV(" but got "))); + IREE_CHECK_OK(iree_vm_append_variant_type_string(actual_variant, builder)); + IREE_CHECK_OK(iree_string_builder_append_string(builder, IREE_SV("\n"))); + + return false; +} + +bool iree_tooling_compare_variant_lists_and_append( + iree_vm_list_t* expected_list, iree_vm_list_t* actual_list, + iree_allocator_t host_allocator, iree_string_builder_t* builder) { + IREE_TRACE_SCOPE(); + + if (iree_vm_list_size(expected_list) != iree_vm_list_size(actual_list)) { + IREE_CHECK_OK(iree_string_builder_append_format( + builder, "[FAILED] expected %zu list elements but %zu provided\n", + iree_vm_list_size(expected_list), iree_vm_list_size(actual_list))); + return false; + } + + bool all_match = true; + for (iree_host_size_t i = 0; i < iree_vm_list_size(expected_list); ++i) { + iree_vm_variant_t expected_variant = iree_vm_variant_empty(); + IREE_CHECK_OK( + iree_vm_list_get_variant(expected_list, i, &expected_variant)); + iree_vm_variant_t actual_variant = iree_vm_variant_empty(); + IREE_CHECK_OK(iree_vm_list_get_variant(actual_list, i, &actual_variant)); + bool did_match = iree_tooling_compare_variants( + (int)i, expected_variant, actual_variant, host_allocator, + /*max_element_count=*/1024, builder); + if (!did_match) all_match = false; + } + + return all_match; +} + +bool iree_tooling_compare_variant_lists(iree_vm_list_t* expected_list, + iree_vm_list_t* actual_list, + iree_allocator_t host_allocator, + FILE* file) { + iree_string_builder_t builder; + iree_string_builder_initialize(host_allocator, &builder); + bool all_match = iree_tooling_compare_variant_lists_and_append( + expected_list, actual_list, host_allocator, &builder); + fwrite(iree_string_builder_buffer(&builder), 1, + iree_string_builder_size(&builder), file); + iree_string_builder_deinitialize(&builder); + return all_match; +} diff --git a/runtime/src/iree/tooling/comparison.h b/runtime/src/iree/tooling/comparison.h new file mode 100644 index 000000000000..c7e5acfffe39 --- /dev/null +++ b/runtime/src/iree/tooling/comparison.h @@ -0,0 +1,30 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_TOOLING_COMPARISON_H_ +#define IREE_TOOLING_COMPARISON_H_ + +#include "iree/base/api.h" +#include "iree/vm/api.h" +#include "stdio.h" + +// Compares expected vs actual results and appends to |builder|. +// Returns true if all values match and false otherwise. +// Errors when performing comparison will abort the process. +// When all list elements match no output is written and otherwise +// newline-separated strings detailing the differing elements is appended. +bool iree_tooling_compare_variant_lists_and_append( + iree_vm_list_t* expected_list, iree_vm_list_t* actual_list, + iree_allocator_t host_allocator, iree_string_builder_t* builder); + +// Compares expected vs actual results and appends to |file|. +// Refer to iree_tooling_compare_variant_lists_and_append for details. +bool iree_tooling_compare_variant_lists(iree_vm_list_t* expected_list, + iree_vm_list_t* actual_list, + iree_allocator_t host_allocator, + FILE* file); + +#endif // IREE_TOOLING_COMPARISON_H_ diff --git a/runtime/src/iree/tooling/comparison_test.cc b/runtime/src/iree/tooling/comparison_test.cc new file mode 100644 index 000000000000..e5e7fd64e9f0 --- /dev/null +++ b/runtime/src/iree/tooling/comparison_test.cc @@ -0,0 +1,125 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/tooling/comparison.h" + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/modules/hal/module.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" +#include "iree/tooling/vm_util.h" +#include "iree/vm/api.h" +#include "iree/vm/ref_cc.h" + +namespace iree { +namespace { + +using ::testing::HasSubstr; + +class ComparisonTest : public ::testing::Test { + protected: + virtual void SetUp() { + IREE_ASSERT_OK(iree_vm_instance_create(host_allocator_, &instance_)); + IREE_ASSERT_OK(iree_hal_module_register_all_types(instance_)); + IREE_ASSERT_OK(iree_hal_allocator_create_heap( + IREE_SV("heap"), host_allocator_, host_allocator_, &device_allocator_)); + } + + virtual void TearDown() { + iree_hal_allocator_release(device_allocator_); + iree_vm_instance_release(instance_); + } + + bool ParseAndCompareVariantLists( + iree::span expected_strings, + iree::span actual_strings, std::string* out_string) { + vm::ref expected_list; + IREE_CHECK_OK(ParseToVariantList(device_allocator_, expected_strings, + host_allocator_, &expected_list)); + + vm::ref actual_list; + IREE_CHECK_OK(ParseToVariantList(device_allocator_, actual_strings, + host_allocator_, &actual_list)); + iree_string_builder_t builder; + iree_string_builder_initialize(host_allocator_, &builder); + bool all_match = iree_tooling_compare_variant_lists_and_append( + expected_list.get(), actual_list.get(), host_allocator_, &builder); + out_string->assign(iree_string_builder_buffer(&builder), + iree_string_builder_size(&builder)); + iree_string_builder_deinitialize(&builder); + + return all_match; + } + + iree_vm_instance_t* instance_ = nullptr; + iree_allocator_t host_allocator_ = iree_allocator_system(); + iree_hal_allocator_t* device_allocator_ = nullptr; +}; + +TEST_F(ComparisonTest, CompareEqualLists) { + std::string buf_string1 = "2x2xi32=[42 43][44 45]"; + std::string buf_string2 = "2x3xf64=[1 2 3][4 5 6]"; + auto buf_strings = std::vector{buf_string1, buf_string2}; + std::string result; + EXPECT_TRUE(ParseAndCompareVariantLists(buf_strings, buf_strings, &result)); + EXPECT_EQ(result, ""); +} + +TEST_F(ComparisonTest, CompareListsWithIgnored) { + std::string buf_string1 = "2x2xi32=[42 43][44 45]"; + std::string buf_string2 = "2x3xf64=[1 2 999][4 5 6]"; + std::string buf_string2_ignored = "(ignored)"; + auto actual_strings = std::vector{buf_string1, buf_string2}; + auto expected_strings = + std::vector{buf_string1, buf_string2_ignored}; + std::string result; + EXPECT_TRUE( + ParseAndCompareVariantLists(expected_strings, actual_strings, &result)); + EXPECT_EQ(result, ""); +} + +TEST_F(ComparisonTest, CompareTruncatedLists) { + std::string buf_string1 = "2x2xi32=[42 43][44 45]"; + std::string buf_string2 = "2x3xf64=[1 2 3][4 5 6]"; + auto actual_strings = std::vector{buf_string1, buf_string2}; + auto expected_strings = std::vector{buf_string1}; + std::string result; + EXPECT_FALSE( + ParseAndCompareVariantLists(expected_strings, actual_strings, &result)); + EXPECT_THAT(result, HasSubstr("expected 1 list elements but 2 provided")); +} + +TEST_F(ComparisonTest, CompareDifferingLists) { + std::string buf_string1 = "2x2xi32=[42 43][44 45]"; + std::string buf_string2 = "2x3xf64=[1 2 999][4 5 6]"; + std::string buf_string2_good = "2x3xf64=[1 2 3][4 5 6]"; + auto actual_strings = std::vector{buf_string1, buf_string2}; + auto expected_strings = + std::vector{buf_string1, buf_string2_good}; + std::string result; + EXPECT_FALSE( + ParseAndCompareVariantLists(expected_strings, actual_strings, &result)); + EXPECT_THAT( + result, + HasSubstr("element at index 2 (999) does not match the expected (3)")); +} + +TEST_F(ComparisonTest, CompareListsWithDifferingTypes) { + std::string buf_string1 = "2x2xi32=[42 43][44 45]"; + std::string buf_string2 = "123"; + std::string buf_string2_good = "2x3xf64=[1 2 3][4 5 6]"; + auto actual_strings = std::vector{buf_string1, buf_string2}; + auto expected_strings = + std::vector{buf_string1, buf_string2_good}; + std::string result; + EXPECT_FALSE( + ParseAndCompareVariantLists(expected_strings, actual_strings, &result)); + EXPECT_THAT(result, HasSubstr("variant types mismatch")); +} + +} // namespace +} // namespace iree diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c index 988bf3b0a567..9f1ef9cbac07 100644 --- a/runtime/src/iree/tooling/context_util.c +++ b/runtime/src/iree/tooling/context_util.c @@ -210,7 +210,8 @@ static iree_status_t iree_tooling_load_hal_loader_module( iree_host_size_t loader_count = 0; iree_hal_executable_loader_t* loaders[16]; iree_status_t status = iree_hal_create_all_available_executable_loaders( - IREE_ARRAYSIZE(loaders), &loader_count, loaders, host_allocator); + iree_hal_executable_import_provider_default(), IREE_ARRAYSIZE(loaders), + &loader_count, loaders, host_allocator); // Create the module; it retains the loaders for its lifetime. iree_vm_module_t* module = NULL; diff --git a/runtime/src/iree/tooling/numpy_io_test.cc b/runtime/src/iree/tooling/numpy_io_test.cc index 342abe03aa76..ef5443d67f38 100644 --- a/runtime/src/iree/tooling/numpy_io_test.cc +++ b/runtime/src/iree/tooling/numpy_io_test.cc @@ -19,6 +19,11 @@ using iree::testing::status::IsOk; using iree::testing::status::StatusIs; using ::testing::ElementsAreArray; +std::ostream& operator<<(std::ostream& os, const Status& x) { + os << x.ToString(); + return os; +} + class NumpyIOTest : public ::testing::Test { protected: virtual void SetUp() { diff --git a/runtime/src/iree/tooling/vm_util.cc b/runtime/src/iree/tooling/vm_util.cc index c6474d7adeeb..f1b288a1ed45 100644 --- a/runtime/src/iree/tooling/vm_util.cc +++ b/runtime/src/iree/tooling/vm_util.cc @@ -9,7 +9,6 @@ #include #include #include -#include #include #include @@ -23,42 +22,44 @@ // TODO(benvanik): drop use of stdio and make an iree_io_stream_t. #if defined(IREE_PLATFORM_WINDOWS) -static uint64_t GetFileLength(FILE* file) { +static uint64_t iree_file_query_length(FILE* file) { _fseeki64(file, 0, SEEK_END); uint64_t file_length = _ftelli64(file); _fseeki64(file, 0, SEEK_SET); return file_length; } -static bool IsEOF(FILE* file, uint64_t file_length) { +static bool iree_file_is_eof(FILE* file, uint64_t file_length) { return _ftelli64(file) == file_length; } #else -static uint64_t GetFileLength(FILE* file) { +static uint64_t iree_file_query_length(FILE* file) { fseeko(file, 0, SEEK_END); uint64_t file_length = ftello(file); fseeko(file, 0, SEEK_SET); return file_length; } -static bool IsEOF(FILE* file, uint64_t file_length) { +static bool iree_file_is_eof(FILE* file, uint64_t file_length) { return ftello(file) == file_length; } #endif // IREE_PLATFORM_* +using namespace iree; + namespace iree { -static iree_status_t LoadNdarraysFromFile( +static iree_status_t iree_tooling_load_ndarrays_from_file( iree_string_view_t file_path, iree_hal_allocator_t* device_allocator, iree_vm_list_t* variant_list) { // Open the file for reading. std::string file_path_str(file_path.data, file_path.size); - FILE* file = std::fopen(file_path_str.c_str(), "rb"); + FILE* file = fopen(file_path_str.c_str(), "rb"); if (!file) { return iree_make_status(iree_status_code_from_errno(errno), "failed to open file '%.*s'", (int)file_path.size, file_path.data); } - uint64_t file_length = GetFileLength(file); + uint64_t file_length = iree_file_query_length(file); iree_hal_buffer_params_t buffer_params = {}; buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT; @@ -66,7 +67,7 @@ static iree_status_t LoadNdarraysFromFile( buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL; iree_status_t status = iree_ok_status(); - while (iree_status_is_ok(status) && !IsEOF(file, file_length)) { + while (iree_status_is_ok(status) && !iree_file_is_eof(file, file_length)) { iree_hal_buffer_view_t* buffer_view = NULL; status = iree_numpy_npy_load_ndarray( file, IREE_NUMPY_NPY_LOAD_OPTION_DEFAULT, buffer_params, @@ -78,7 +79,7 @@ static iree_status_t LoadNdarraysFromFile( iree_hal_buffer_view_release(buffer_view); } - std::fclose(file); + fclose(file); return status; } @@ -117,7 +118,7 @@ static iree_status_t CreateBufferViewFromFile( // Open the file for reading. std::string file_path_str(file_path.data, file_path.size); - FILE* file = std::fopen(file_path_str.c_str(), "rb"); + FILE* file = fopen(file_path_str.c_str(), "rb"); if (!file) { return iree_make_status(iree_status_code_from_errno(errno), "failed to open file '%.*s'", (int)file_path.size, @@ -138,8 +139,8 @@ static iree_status_t CreateBufferViewFromFile( +[](iree_hal_buffer_mapping_t* mapping, void* user_data) { auto* read_params = reinterpret_cast(user_data); size_t bytes_read = - std::fread(mapping->contents.data, 1, mapping->contents.data_length, - read_params->file); + fread(mapping->contents.data, 1, mapping->contents.data_length, + read_params->file); if (bytes_read != mapping->contents.data_length) { return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "file contents truncated; expected %zu bytes " @@ -150,7 +151,7 @@ static iree_status_t CreateBufferViewFromFile( }, &read_params, out_buffer_view); - std::fclose(file); + fclose(file); return status; } @@ -159,6 +160,8 @@ Status ParseToVariantList(iree_hal_allocator_t* device_allocator, iree::span input_strings, iree_allocator_t host_allocator, iree_vm_list_t** out_list) { + IREE_TRACE_SCOPE(); + *out_list = NULL; vm::ref variant_list; IREE_RETURN_IF_ERROR(iree_vm_list_create( @@ -168,10 +171,11 @@ Status ParseToVariantList(iree_hal_allocator_t* device_allocator, iree_string_view_t input_view = iree_string_view_trim(iree_make_string_view( input_strings[i].data(), input_strings[i].size())); if (iree_string_view_consume_prefix(&input_view, IREE_SV("@"))) { - IREE_RETURN_IF_ERROR(LoadNdarraysFromFile(input_view, device_allocator, - variant_list.get())); + IREE_RETURN_IF_ERROR(iree_tooling_load_ndarrays_from_file( + input_view, device_allocator, variant_list.get())); continue; - } else if (iree_string_view_equal(input_view, IREE_SV("(null)"))) { + } else if (iree_string_view_equal(input_view, IREE_SV("(null)")) || + iree_string_view_equal(input_view, IREE_SV("(ignored)"))) { iree_vm_ref_t null_ref = iree_vm_ref_null(); IREE_RETURN_IF_ERROR( iree_vm_list_push_ref_retain(variant_list.get(), &null_ref)); @@ -245,67 +249,110 @@ Status ParseToVariantList(iree_hal_allocator_t* device_allocator, return OkStatus(); } +// Prints a buffer view with contents without a trailing newline. +static iree_status_t PrintBufferView(iree_hal_buffer_view_t* buffer_view, + iree_host_size_t max_element_count, + iree_string_builder_t* builder) { + std::string result_str(4096, '\0'); + iree_status_t status; + do { + iree_host_size_t actual_length = 0; + status = iree_hal_buffer_view_format(buffer_view, max_element_count, + result_str.size() + 1, &result_str[0], + &actual_length); + result_str.resize(actual_length); + } while (iree_status_is_out_of_range(status)); + IREE_RETURN_IF_ERROR(status); + iree_string_builder_append_string( + builder, iree_make_string_view(result_str.data(), result_str.size())); + return iree_ok_status(); +} + +#define IREE_PRINTVARIANT_CASE_I(SIZE, B, V, STATUS) \ + case IREE_VM_VALUE_TYPE_I##SIZE: \ + STATUS = iree_string_builder_append_format( \ + B, "i" #SIZE "=%" PRIi##SIZE "\n", (V).i##SIZE); \ + break; + +#define IREE_PRINTVARIANT_CASE_F(SIZE, B, V, STATUS) \ + case IREE_VM_VALUE_TYPE_F##SIZE: \ + STATUS = \ + iree_string_builder_append_format(B, "f" #SIZE "=%g\n", (V).f##SIZE); \ + break; + +// Prints variant description including a trailing newline. +static Status PrintVariant(iree_vm_variant_t variant, size_t max_element_count, + iree_string_builder_t* builder) { + if (iree_vm_variant_is_empty(variant)) { + iree_string_builder_append_string(builder, IREE_SVL("(null)\n")); + } else if (iree_vm_variant_is_value(variant)) { + iree_status_t status = iree_ok_status(); + switch (variant.type.value_type) { + IREE_PRINTVARIANT_CASE_I(8, builder, variant, status) + IREE_PRINTVARIANT_CASE_I(16, builder, variant, status) + IREE_PRINTVARIANT_CASE_I(32, builder, variant, status) + IREE_PRINTVARIANT_CASE_I(64, builder, variant, status) + IREE_PRINTVARIANT_CASE_F(32, builder, variant, status) + IREE_PRINTVARIANT_CASE_F(64, builder, variant, status) + default: + status = iree_string_builder_append_string(builder, IREE_SVL("?\n")); + break; + } + IREE_RETURN_IF_ERROR(status); + } else if (iree_vm_variant_is_ref(variant)) { + iree_string_view_t type_name = iree_vm_ref_type_name(variant.type.ref_type); + iree_string_builder_append_string(builder, type_name); + iree_string_builder_append_string(builder, IREE_SVL("\n")); + if (iree_hal_buffer_view_isa(variant.ref)) { + auto* buffer_view = iree_hal_buffer_view_deref(variant.ref); + IREE_RETURN_IF_ERROR( + PrintBufferView(buffer_view, max_element_count, builder)); + iree_string_builder_append_string(builder, IREE_SVL("\n")); + } else { + // TODO(benvanik): a way for ref types to describe themselves. + iree_string_builder_append_string(builder, IREE_SVL("(no printer)\n")); + } + } else { + iree_string_builder_append_string(builder, IREE_SVL("(null)\n")); + } + return OkStatus(); +} + Status PrintVariantList(iree_vm_list_t* variant_list, size_t max_element_count, - std::ostream* os) { + iree_string_builder_t* builder) { + IREE_TRACE_SCOPE(); for (iree_host_size_t i = 0; i < iree_vm_list_size(variant_list); ++i) { iree_vm_variant_t variant = iree_vm_variant_empty(); IREE_RETURN_IF_ERROR(iree_vm_list_get_variant(variant_list, i, &variant), "variant %zu not present", i); - - *os << "result[" << i << "]: "; - if (iree_vm_variant_is_empty(variant)) { - *os << "(null)\n"; - } else if (iree_vm_variant_is_value(variant)) { - switch (variant.type.value_type) { - case IREE_VM_VALUE_TYPE_I8: - *os << "i8=" << variant.i8 << "\n"; - break; - case IREE_VM_VALUE_TYPE_I16: - *os << "i16=" << variant.i16 << "\n"; - break; - case IREE_VM_VALUE_TYPE_I32: - *os << "i32=" << variant.i32 << "\n"; - break; - case IREE_VM_VALUE_TYPE_I64: - *os << "i64=" << variant.i64 << "\n"; - break; - case IREE_VM_VALUE_TYPE_F32: - *os << "f32=" << variant.f32 << "\n"; - break; - case IREE_VM_VALUE_TYPE_F64: - *os << "f64=" << variant.f64 << "\n"; - break; - default: - *os << "?\n"; - break; - } - } else if (iree_vm_variant_is_ref(variant)) { - iree_string_view_t type_name = - iree_vm_ref_type_name(variant.type.ref_type); - *os << std::string(type_name.data, type_name.size) << "\n"; - if (iree_hal_buffer_view_isa(variant.ref)) { - auto* buffer_view = iree_hal_buffer_view_deref(variant.ref); - std::string result_str(4096, '\0'); - iree_status_t status; - do { - iree_host_size_t actual_length = 0; - status = iree_hal_buffer_view_format(buffer_view, max_element_count, - result_str.size() + 1, - &result_str[0], &actual_length); - result_str.resize(actual_length); - } while (iree_status_is_out_of_range(status)); - IREE_RETURN_IF_ERROR(status); - *os << result_str << "\n"; - } else { - // TODO(benvanik): a way for ref types to describe themselves. - *os << "(no printer)\n"; - } - } else { - *os << "(null)\n"; - } + iree_string_builder_append_format(builder, "result[%zu]: ", i); + IREE_RETURN_IF_ERROR(PrintVariant(variant, max_element_count, builder)); } - return OkStatus(); } +Status PrintVariantList(iree_vm_list_t* variant_list, size_t max_element_count, + std::string* out_string) { + iree_string_builder_t builder; + iree_string_builder_initialize(iree_allocator_system(), &builder); + IREE_RETURN_IF_ERROR( + PrintVariantList(variant_list, max_element_count, &builder)); + out_string->assign(iree_string_builder_buffer(&builder), + iree_string_builder_size(&builder)); + iree_string_builder_deinitialize(&builder); + return iree_ok_status(); +} + +Status PrintVariantList(iree_vm_list_t* variant_list, + size_t max_element_count) { + iree_string_builder_t builder; + iree_string_builder_initialize(iree_allocator_system(), &builder); + IREE_RETURN_IF_ERROR( + PrintVariantList(variant_list, max_element_count, &builder)); + printf("%.*s", (int)iree_string_builder_size(&builder), + iree_string_builder_buffer(&builder)); + iree_string_builder_deinitialize(&builder); + return iree_ok_status(); +} + } // namespace iree diff --git a/runtime/src/iree/tooling/vm_util.h b/runtime/src/iree/tooling/vm_util.h index b61d449afcca..76c9df3a853f 100644 --- a/runtime/src/iree/tooling/vm_util.h +++ b/runtime/src/iree/tooling/vm_util.h @@ -7,13 +7,12 @@ #ifndef IREE_TOOLING_VM_UTIL_H_ #define IREE_TOOLING_VM_UTIL_H_ -#include -#include #include #include #include "iree/base/internal/span.h" #include "iree/base/status_cc.h" +#include "iree/base/string_builder.h" #include "iree/hal/api.h" #include "iree/vm/api.h" #include "iree/vm/ref_cc.h" @@ -36,23 +35,32 @@ Status ParseToVariantList(iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator, iree_vm_list_t** out_list); -// Prints a variant list of VM scalars and buffers to |os|. +// Appends a variant list of VM scalars and buffers to |builder|. // Prints scalars in the format: // value // Prints buffers in the IREE standard shaped buffer format: // [shape]xtype=[value] // described in // https://github.com/iree-org/iree/tree/main/iree/hal/api.h -Status PrintVariantList(iree_vm_list_t* variant_list, size_t max_element_count, - std::ostream* os); -inline Status PrintVariantList(iree_vm_list_t* variant_list, std::ostream* os) { - return PrintVariantList(variant_list, 1024, os); +Status AppendVariantList(iree_vm_list_t* variant_list, size_t max_element_count, + iree_string_builder_t* builder); +inline Status AppendVariantList(iree_vm_list_t* variant_list, + iree_string_builder_t* builder) { + return AppendVariantList(variant_list, 1024, builder); } +Status PrintVariantList(iree_vm_list_t* variant_list, size_t max_element_count, + std::string* out_string); + +// Prints a variant list to |out_string|. inline Status PrintVariantList(iree_vm_list_t* variant_list, - size_t max_element_count = 1024) { - return PrintVariantList(variant_list, max_element_count, &std::cout); + std::string* out_string) { + return PrintVariantList(variant_list, 1024, out_string); } +// Prints a variant list to stdout. +Status PrintVariantList(iree_vm_list_t* variant_list, + size_t max_element_count = 1024); + } // namespace iree #endif // IREE_TOOLING_VM_UTIL_H_ diff --git a/runtime/src/iree/tooling/vm_util_test.cc b/runtime/src/iree/tooling/vm_util_test.cc index 5bfe891879c2..0992b103e481 100644 --- a/runtime/src/iree/tooling/vm_util_test.cc +++ b/runtime/src/iree/tooling/vm_util_test.cc @@ -52,10 +52,9 @@ TEST_F(VmUtilTest, ParsePrintBuffer) { IREE_ASSERT_OK( ParseToVariantList(allocator_, std::vector{buf_string}, iree_vm_instance_allocator(instance_), &variant_list)); - std::stringstream os; - IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &os)); - // TODO(benvanik): add a !hal.buffer printer. - EXPECT_EQ(os.str(), + std::string result; + IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); + EXPECT_EQ(result, std::string("result[0]: hal.buffer\n") + "(no printer)" + "\n"); } @@ -65,9 +64,9 @@ TEST_F(VmUtilTest, ParsePrintBufferView) { IREE_ASSERT_OK( ParseToVariantList(allocator_, std::vector{buf_string}, iree_vm_instance_allocator(instance_), &variant_list)); - std::stringstream os; - IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &os)); - EXPECT_EQ(os.str(), + std::string result; + IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); + EXPECT_EQ(result, std::string("result[0]: hal.buffer_view\n") + buf_string + "\n"); } @@ -77,9 +76,9 @@ TEST_F(VmUtilTest, ParsePrintScalar) { IREE_ASSERT_OK( ParseToVariantList(allocator_, std::vector{input_string}, iree_vm_instance_allocator(instance_), &variant_list)); - std::stringstream os; - IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &os)); - EXPECT_EQ(os.str(), std::string("result[0]: i32=") + input_string + "\n"); + std::string result; + IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); + EXPECT_EQ(result, std::string("result[0]: i32=") + input_string + "\n"); } TEST_F(VmUtilTest, ParsePrintRank0BufferView) { @@ -88,9 +87,9 @@ TEST_F(VmUtilTest, ParsePrintRank0BufferView) { IREE_ASSERT_OK( ParseToVariantList(allocator_, std::vector{buf_string}, iree_vm_instance_allocator(instance_), &variant_list)); - std::stringstream os; - IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &os)); - EXPECT_EQ(os.str(), + std::string result; + IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); + EXPECT_EQ(result, std::string("result[0]: hal.buffer_view\n") + buf_string + "\n"); } @@ -101,11 +100,10 @@ TEST_F(VmUtilTest, ParsePrintMultipleBufferViews) { IREE_ASSERT_OK(ParseToVariantList( allocator_, std::vector{buf_string1, buf_string2}, iree_vm_instance_allocator(instance_), &variant_list)); - std::stringstream os; - IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &os)); - EXPECT_EQ(os.str(), std::string("result[0]: hal.buffer_view\n") + - buf_string1 + "\nresult[1]: hal.buffer_view\n" + - buf_string2 + "\n"); + std::string result; + IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); + EXPECT_EQ(result, std::string("result[0]: hal.buffer_view\n") + buf_string1 + + "\nresult[1]: hal.buffer_view\n" + buf_string2 + "\n"); } } // namespace diff --git a/runtime/src/iree/vm/bytecode_dispatch.c b/runtime/src/iree/vm/bytecode_dispatch.c index db38d9e90c48..6cef93b9f43c 100644 --- a/runtime/src/iree/vm/bytecode_dispatch.c +++ b/runtime/src/iree/vm/bytecode_dispatch.c @@ -1158,8 +1158,9 @@ static iree_status_t iree_vm_bytecode_dispatch( iree_host_size_t offset = VM_DecOperandRegI64HostSize("source_offset"); uint32_t* result_ptr = VM_DecResultRegI32("result"); uint16_t result_x16 = 0; - IREE_RETURN_IF_ERROR(iree_vm_buffer_read_elements( - buffer, offset, &result_x16, 1, sizeof(result_x16))); + IREE_RETURN_IF_ERROR( + iree_vm_buffer_read_elements(buffer, offset * sizeof(result_x16), + &result_x16, 1, sizeof(result_x16))); *result_ptr = vm_ext_i16i32u(result_x16); }); DISPATCH_OP(CORE, BufferLoadI16S, { @@ -1174,8 +1175,9 @@ static iree_status_t iree_vm_bytecode_dispatch( iree_host_size_t offset = VM_DecOperandRegI64HostSize("source_offset"); uint32_t* result_ptr = VM_DecResultRegI32("result"); int16_t result_x16 = 0; - IREE_RETURN_IF_ERROR(iree_vm_buffer_read_elements( - buffer, offset, &result_x16, 1, sizeof(result_x16))); + IREE_RETURN_IF_ERROR( + iree_vm_buffer_read_elements(buffer, offset * sizeof(result_x16), + &result_x16, 1, sizeof(result_x16))); *result_ptr = vm_ext_i16i32s(result_x16); }); DISPATCH_OP(CORE, BufferLoadI32, { @@ -1189,8 +1191,8 @@ static iree_status_t iree_vm_bytecode_dispatch( } iree_host_size_t offset = VM_DecOperandRegI64HostSize("source_offset"); uint32_t* result = VM_DecResultRegI32("result"); - IREE_RETURN_IF_ERROR(iree_vm_buffer_read_elements(buffer, offset, result, - 1, sizeof(*result))); + IREE_RETURN_IF_ERROR(iree_vm_buffer_read_elements( + buffer, offset * sizeof(*result), result, 1, sizeof(*result))); }); DISPATCH_OP(CORE, BufferLoadI64, { bool buffer_is_move; @@ -1203,8 +1205,8 @@ static iree_status_t iree_vm_bytecode_dispatch( } iree_host_size_t offset = VM_DecOperandRegI64HostSize("source_offset"); uint64_t* result = VM_DecResultRegI64("result"); - IREE_RETURN_IF_ERROR(iree_vm_buffer_read_elements(buffer, offset, result, - 1, sizeof(*result))); + IREE_RETURN_IF_ERROR(iree_vm_buffer_read_elements( + buffer, offset * sizeof(*result), result, 1, sizeof(*result))); }); // TODO(benvanik): rework dispatch so that the StoreI* ops can share the @@ -1235,8 +1237,8 @@ static iree_status_t iree_vm_bytecode_dispatch( } iree_host_size_t offset = VM_DecOperandRegI64HostSize("target_offset"); uint16_t value = (uint16_t)VM_DecOperandRegI32("value"); - IREE_RETURN_IF_ERROR(iree_vm_buffer_write_elements(&value, buffer, offset, - 1, sizeof(uint16_t))); + IREE_RETURN_IF_ERROR(iree_vm_buffer_write_elements( + &value, buffer, offset * sizeof(value), 1, sizeof(value))); }); DISPATCH_OP(CORE, BufferStoreI32, { bool buffer_is_move; @@ -1249,8 +1251,8 @@ static iree_status_t iree_vm_bytecode_dispatch( } iree_host_size_t offset = VM_DecOperandRegI64HostSize("target_offset"); uint32_t value = VM_DecOperandRegI32("value"); - IREE_RETURN_IF_ERROR(iree_vm_buffer_write_elements(&value, buffer, offset, - 1, sizeof(uint32_t))); + IREE_RETURN_IF_ERROR(iree_vm_buffer_write_elements( + &value, buffer, offset * sizeof(value), 1, sizeof(value))); }); DISPATCH_OP(CORE, BufferStoreI64, { bool buffer_is_move; @@ -1263,8 +1265,8 @@ static iree_status_t iree_vm_bytecode_dispatch( } iree_host_size_t offset = VM_DecOperandRegI64HostSize("target_offset"); uint64_t value = (uint64_t)VM_DecOperandRegI64("value"); - IREE_RETURN_IF_ERROR(iree_vm_buffer_write_elements(&value, buffer, offset, - 1, sizeof(uint64_t))); + IREE_RETURN_IF_ERROR(iree_vm_buffer_write_elements( + &value, buffer, offset * sizeof(value), 1, sizeof(value))); }); //===------------------------------------------------------------------===// @@ -2124,7 +2126,7 @@ static iree_status_t iree_vm_bytecode_dispatch( iree_host_size_t offset = VM_DecOperandRegI64HostSize("source_offset"); float* result = VM_DecResultRegF32("result"); IREE_RETURN_IF_ERROR(iree_vm_buffer_read_elements( - buffer, offset, result, 1, sizeof(*result))); + buffer, offset * sizeof(*result), result, 1, sizeof(*result))); }); DISPATCH_OP(EXT_F32, BufferStoreF32, { @@ -2139,7 +2141,7 @@ static iree_status_t iree_vm_bytecode_dispatch( iree_host_size_t offset = VM_DecOperandRegI64HostSize("target_offset"); float value = VM_DecOperandRegF32("value"); IREE_RETURN_IF_ERROR(iree_vm_buffer_write_elements( - &value, buffer, offset, 1, sizeof(float))); + &value, buffer, offset * sizeof(value), 1, sizeof(value))); }); } END_DISPATCH_PREFIX(); diff --git a/runtime/src/iree/vm/bytecode_dispatch_test.cc b/runtime/src/iree/vm/bytecode_dispatch_test.cc index 58ea3821f829..9de8ff044cc9 100644 --- a/runtime/src/iree/vm/bytecode_dispatch_test.cc +++ b/runtime/src/iree/vm/bytecode_dispatch_test.cc @@ -129,7 +129,7 @@ TEST_P(VMBytecodeDispatchTest, Check) { GTEST_SUCCEED(); } else { GTEST_FAIL() << "Function expected success but failed with error: " - << iree::Status(std::move(status)); + << iree::Status(std::move(status)).ToString(); } } } diff --git a/runtime/src/iree/vm/bytecode_module_benchmark.mlir b/runtime/src/iree/vm/bytecode_module_benchmark.mlir index 5ceed1ef4de4..1459bd9e8f45 100644 --- a/runtime/src/iree/vm/bytecode_module_benchmark.mlir +++ b/runtime/src/iree/vm/bytecode_module_benchmark.mlir @@ -82,20 +82,20 @@ vm.module @bytecode_module_benchmark { vm.func @buffer_reduce(%count : i32) -> i32 { %c0 = vm.const.i64.zero %c0_i32 = vm.const.i32.zero - %c1 = vm.const.i32 1 - %c4 = vm.const.i32 4 - %max = vm.mul.i32 %count, %c4 : i32 - %max_i64 = vm.ext.i32.i64.u %max : i32 -> i64 - %buf = vm.buffer.alloc %max_i64 : !vm.buffer - vm.buffer.fill.i32 %buf, %c0, %max_i64, %c1 : i32 -> !vm.buffer - vm.br ^loop(%c0_i32, %c0_i32 : i32, i32) - ^loop(%i : i32, %sum : i32): - %i_i64 = vm.ext.i32.i64.u %i : i32 -> i64 - %element = vm.buffer.load.i32 %buf[%i_i64] : !vm.buffer -> i32 + %pattern = vm.const.i32 1 + %c1 = vm.const.i64 1 + %c4 = vm.const.i64 4 + %count_i64 = vm.ext.i32.i64.u %count : i32 -> i64 + %max = vm.mul.i64 %count_i64, %c4 : i64 + %buf = vm.buffer.alloc %max : !vm.buffer + vm.buffer.fill.i32 %buf, %c0, %max, %pattern : i32 -> !vm.buffer + vm.br ^loop(%c0, %c0_i32 : i64, i32) + ^loop(%i : i64, %sum : i32): + %element = vm.buffer.load.i32 %buf[%i] : !vm.buffer -> i32 %new_sum = vm.add.i32 %sum, %element : i32 - %ip4 = vm.add.i32 %i, %c4 : i32 - %cmp = vm.cmp.lt.i32.s %ip4, %max : i32 - vm.cond_br %cmp, ^loop(%ip4, %new_sum : i32, i32), ^loop_exit(%new_sum : i32) + %ip1 = vm.add.i64 %i, %c1 : i64 + %cmp = vm.cmp.lt.i64.s %ip1, %count_i64 : i64 + vm.cond_br %cmp, ^loop(%ip1, %new_sum : i64, i32), ^loop_exit(%new_sum : i32) ^loop_exit(%result : i32): vm.return %result : i32 } @@ -105,30 +105,31 @@ vm.module @bytecode_module_benchmark { vm.export @buffer_reduce_unrolled vm.func @buffer_reduce_unrolled(%count : i32) -> i32 { %c0 = vm.const.i64.zero - %c1 = vm.const.i32 1 + %pattern = vm.const.i32 1 + %c1 = vm.const.i64 1 %c4 = vm.const.i64 4 %count_i64 = vm.ext.i32.i64.u %count : i32 -> i64 %max = vm.mul.i64 %count_i64, %c4 : i64 %buf = vm.buffer.alloc %max : !vm.buffer - vm.buffer.fill.i32 %buf, %c0, %max, %c1 : i32 -> !vm.buffer + vm.buffer.fill.i32 %buf, %c0, %max, %pattern : i32 -> !vm.buffer %sum_init = vm.const.i32.zero vm.br ^loop(%c0, %sum_init : i64, i32) ^loop(%i0 : i64, %sum : i32): // TODO(#5544): add addressing modes to load/store. %e0 = vm.buffer.load.i32 %buf[%i0] : !vm.buffer -> i32 - %i1 = vm.add.i64 %i0, %c4 : i64 + %i1 = vm.add.i64 %i0, %c1 : i64 %e1 = vm.buffer.load.i32 %buf[%i1] : !vm.buffer -> i32 - %i2 = vm.add.i64 %i1, %c4 : i64 + %i2 = vm.add.i64 %i1, %c1 : i64 %e2 = vm.buffer.load.i32 %buf[%i2] : !vm.buffer -> i32 - %i3 = vm.add.i64 %i2, %c4 : i64 + %i3 = vm.add.i64 %i2, %c1 : i64 %e3 = vm.buffer.load.i32 %buf[%i3] : !vm.buffer -> i32 - %i4 = vm.add.i64 %i3, %c4 : i64 + %i4 = vm.add.i64 %i3, %c1 : i64 %e4 = vm.buffer.load.i32 %buf[%i4] : !vm.buffer -> i32 - %i5 = vm.add.i64 %i4, %c4 : i64 + %i5 = vm.add.i64 %i4, %c1 : i64 %e5 = vm.buffer.load.i32 %buf[%i5] : !vm.buffer -> i32 - %i6 = vm.add.i64 %i5, %c4 : i64 + %i6 = vm.add.i64 %i5, %c1 : i64 %e6 = vm.buffer.load.i32 %buf[%i6] : !vm.buffer -> i32 - %i7 = vm.add.i64 %i6, %c4 : i64 + %i7 = vm.add.i64 %i6, %c1 : i64 %e7 = vm.buffer.load.i32 %buf[%i7] : !vm.buffer -> i32 // If we do reductions like this we could add a horizontal-add op. %new_sum0 = vm.add.i32 %sum, %e0 : i32 @@ -139,8 +140,8 @@ vm.module @bytecode_module_benchmark { %new_sum5 = vm.add.i32 %new_sum4, %e5 : i32 %new_sum6 = vm.add.i32 %new_sum5, %e6 : i32 %new_sum7 = vm.add.i32 %new_sum6, %e7 : i32 - %next_i = vm.add.i64 %i7, %c4 : i64 - %cmp = vm.cmp.lt.i64.s %next_i, %max : i64 + %next_i = vm.add.i64 %i7, %c1 : i64 + %cmp = vm.cmp.lt.i64.s %next_i, %count_i64 : i64 vm.cond_br %cmp, ^loop(%next_i, %new_sum7 : i64, i32), ^loop_exit(%new_sum7 : i32) ^loop_exit(%result : i32): vm.return %result : i32 diff --git a/runtime/src/iree/vm/bytecode_module_impl.h b/runtime/src/iree/vm/bytecode_module_impl.h index 9916bb9ff281..6090f94643e3 100644 --- a/runtime/src/iree/vm/bytecode_module_impl.h +++ b/runtime/src/iree/vm/bytecode_module_impl.h @@ -33,7 +33,7 @@ extern "C" { // Major bytecode version; mismatches on this will fail in either direction. // This allows coarse versioning of completely incompatible versions. // Matches BytecodeEncoder::kVersionMajor in the compiler. -#define IREE_VM_BYTECODE_VERSION_MAJOR 12 +#define IREE_VM_BYTECODE_VERSION_MAJOR 13 // Minor bytecode version; lower versions are allowed to enable newer runtimes // to load older serialized files when there are backwards-compatible changes. // Higher versions are disallowed as they occur when new ops are added that diff --git a/runtime/src/iree/vm/test/buffer_ops.mlir b/runtime/src/iree/vm/test/buffer_ops.mlir index dd9044ef13c7..ba6aab1883d1 100644 --- a/runtime/src/iree/vm/test/buffer_ops.mlir +++ b/runtime/src/iree/vm/test/buffer_ops.mlir @@ -417,24 +417,24 @@ vm.module @buffer_ops { vm.export @test_load_i16u attributes {emitc.exclude} vm.func private @test_load_i16u() { %c0 = vm.const.i64 0 + %c1 = vm.const.i64 1 %c2 = vm.const.i64 2 + %c3 = vm.const.i64 3 %c4 = vm.const.i64 4 - %c6 = vm.const.i64 6 - %c8 = vm.const.i64 8 %rodata = vm.const.ref.rodata @test_load_i16_data : !vm.buffer %v0 = vm.buffer.load.i16.u %rodata[%c0] : !vm.buffer -> i32 %e0 = vm.const.i32 0 vm.check.eq %v0, %e0, "0" : i32 - %v1 = vm.buffer.load.i16.u %rodata[%c2] : !vm.buffer -> i32 + %v1 = vm.buffer.load.i16.u %rodata[%c1] : !vm.buffer -> i32 %e1 = vm.const.i32 1 vm.check.eq %v1, %e1, "1" : i32 - %v2 = vm.buffer.load.i16.u %rodata[%c4] : !vm.buffer -> i32 + %v2 = vm.buffer.load.i16.u %rodata[%c2] : !vm.buffer -> i32 %e2 = vm.const.i32 0x7FFF vm.check.eq %v2, %e2, "0x7FFF" : i32 - %v3 = vm.buffer.load.i16.u %rodata[%c6] : !vm.buffer -> i32 + %v3 = vm.buffer.load.i16.u %rodata[%c3] : !vm.buffer -> i32 %e3 = vm.const.i32 0x8000 vm.check.eq %v3, %e3, "0x8000" : i32 - %v4 = vm.buffer.load.i16.u %rodata[%c8] : !vm.buffer -> i32 + %v4 = vm.buffer.load.i16.u %rodata[%c4] : !vm.buffer -> i32 %e4 = vm.const.i32 0xFFFF vm.check.eq %v4, %e4, "0xFFFF" : i32 vm.return @@ -443,24 +443,24 @@ vm.module @buffer_ops { vm.export @test_load_i16s attributes {emitc.exclude} vm.func private @test_load_i16s() { %c0 = vm.const.i64 0 + %c1 = vm.const.i64 1 %c2 = vm.const.i64 2 + %c3 = vm.const.i64 3 %c4 = vm.const.i64 4 - %c6 = vm.const.i64 6 - %c8 = vm.const.i64 8 %rodata = vm.const.ref.rodata @test_load_i16_data : !vm.buffer %v0 = vm.buffer.load.i16.s %rodata[%c0] : !vm.buffer -> i32 %e0 = vm.const.i32 0 vm.check.eq %v0, %e0, "0" : i32 - %v1 = vm.buffer.load.i16.s %rodata[%c2] : !vm.buffer -> i32 + %v1 = vm.buffer.load.i16.s %rodata[%c1] : !vm.buffer -> i32 %e1 = vm.const.i32 1 vm.check.eq %v1, %e1, "1" : i32 - %v2 = vm.buffer.load.i16.s %rodata[%c4] : !vm.buffer -> i32 + %v2 = vm.buffer.load.i16.s %rodata[%c2] : !vm.buffer -> i32 %e2 = vm.const.i32 0x7FFF vm.check.eq %v2, %e2, "0x7FFF" : i32 - %v3 = vm.buffer.load.i16.s %rodata[%c6] : !vm.buffer -> i32 + %v3 = vm.buffer.load.i16.s %rodata[%c3] : !vm.buffer -> i32 %e3 = vm.const.i32 -32768 vm.check.eq %v3, %e3, "-32768" : i32 - %v4 = vm.buffer.load.i16.s %rodata[%c8] : !vm.buffer -> i32 + %v4 = vm.buffer.load.i16.s %rodata[%c4] : !vm.buffer -> i32 %e4 = vm.const.i32 -1 vm.check.eq %v4, %e4, "-1" : i32 vm.return @@ -471,45 +471,29 @@ vm.module @buffer_ops { vm.export @test_load_i32 attributes {emitc.exclude} vm.func private @test_load_i32() { %c0 = vm.const.i64 0 + %c1 = vm.const.i64 1 + %c2 = vm.const.i64 2 + %c3 = vm.const.i64 3 %c4 = vm.const.i64 4 - %c8 = vm.const.i64 8 - %c12 = vm.const.i64 12 - %c16 = vm.const.i64 16 %rodata = vm.const.ref.rodata @test_load_i32_data : !vm.buffer %v0 = vm.buffer.load.i32 %rodata[%c0] : !vm.buffer -> i32 %e0 = vm.const.i32 0 vm.check.eq %v0, %e0, "0" : i32 - %v1 = vm.buffer.load.i32 %rodata[%c4] : !vm.buffer -> i32 + %v1 = vm.buffer.load.i32 %rodata[%c1] : !vm.buffer -> i32 %e1 = vm.const.i32 1 vm.check.eq %v1, %e1, "1" : i32 - %v2 = vm.buffer.load.i32 %rodata[%c8] : !vm.buffer -> i32 + %v2 = vm.buffer.load.i32 %rodata[%c2] : !vm.buffer -> i32 %e2 = vm.const.i32 0x7FFFFFFF vm.check.eq %v2, %e2, "0x7FFFFFFF" : i32 - %v3 = vm.buffer.load.i32 %rodata[%c12] : !vm.buffer -> i32 + %v3 = vm.buffer.load.i32 %rodata[%c3] : !vm.buffer -> i32 %e3 = vm.const.i32 0x80000000 vm.check.eq %v3, %e3, "0x80000000" : i32 - %v4 = vm.buffer.load.i32 %rodata[%c16] : !vm.buffer -> i32 + %v4 = vm.buffer.load.i32 %rodata[%c4] : !vm.buffer -> i32 %e4 = vm.const.i32 0xFFFFFFFF vm.check.eq %v4, %e4, "0xFFFFFFFF" : i32 vm.return } - vm.rodata private @test_load_i32_unaligned_data dense<[0x00112233, 0x44556677, 0x8899AABB, 0xCCDDEEFF]> : tensor<4xui32> - - // Unaligned loads are not supported and offsets will be rounded down. - vm.export @test_load_i32_unaligned attributes {emitc.exclude} - vm.func private @test_load_i32_unaligned() { - %rodata = vm.const.ref.rodata @test_load_i32_unaligned_data : !vm.buffer - - // Byte offset 5 rounded to byte offset 4 (element 1). - %c5 = vm.const.i64 5 - %v1 = vm.buffer.load.i32 %rodata[%c5] : !vm.buffer -> i32 - %e1 = vm.const.i32 0x44556677 - vm.check.eq %v1, %e1, "0x44556677" : i32 - - vm.return - } - //===--------------------------------------------------------------------===// // Store //===--------------------------------------------------------------------===// @@ -561,18 +545,18 @@ vm.module @buffer_ops { %c0 = vm.const.i64 0 %e0 = vm.const.i32 0 vm.buffer.store.i16 %e0, %buf_dno[%c0] : i32 -> !vm.buffer - %c2 = vm.const.i64 2 + %c1 = vm.const.i64 1 %e1 = vm.const.i32 1 - vm.buffer.store.i16 %e1, %buf_dno[%c2] : i32 -> !vm.buffer - %c4 = vm.const.i64 4 + vm.buffer.store.i16 %e1, %buf_dno[%c1] : i32 -> !vm.buffer + %c2 = vm.const.i64 2 %e2 = vm.const.i32 0x7FFF - vm.buffer.store.i16 %e2, %buf_dno[%c4] : i32 -> !vm.buffer - %c6 = vm.const.i64 6 + vm.buffer.store.i16 %e2, %buf_dno[%c2] : i32 -> !vm.buffer + %c3 = vm.const.i64 3 %e3 = vm.const.i32 0x8000 - vm.buffer.store.i16 %e3, %buf_dno[%c6] : i32 -> !vm.buffer - %c8 = vm.const.i64 8 + vm.buffer.store.i16 %e3, %buf_dno[%c3] : i32 -> !vm.buffer + %c4 = vm.const.i64 4 %e4 = vm.const.i32 0xFFFF - vm.buffer.store.i16 %e4, %buf_dno[%c8] : i32 -> !vm.buffer + vm.buffer.store.i16 %e4, %buf_dno[%c4] : i32 -> !vm.buffer %cmp = vm.buffer.compare %ref_dno, %c0, %buf_dno, %c0, %ref_length : !vm.buffer, !vm.buffer vm.check.nz %cmp, "source and target match" : i32 @@ -594,18 +578,18 @@ vm.module @buffer_ops { %c0 = vm.const.i64 0 %e0 = vm.const.i32 0 vm.buffer.store.i32 %e0, %buf_dno[%c0] : i32 -> !vm.buffer - %c4 = vm.const.i64 4 + %c1 = vm.const.i64 1 %e1 = vm.const.i32 1 - vm.buffer.store.i32 %e1, %buf_dno[%c4] : i32 -> !vm.buffer - %c8 = vm.const.i64 8 + vm.buffer.store.i32 %e1, %buf_dno[%c1] : i32 -> !vm.buffer + %c2 = vm.const.i64 2 %e2 = vm.const.i32 0x7FFFFFFF - vm.buffer.store.i32 %e2, %buf_dno[%c8] : i32 -> !vm.buffer - %c12 = vm.const.i64 12 + vm.buffer.store.i32 %e2, %buf_dno[%c2] : i32 -> !vm.buffer + %c3 = vm.const.i64 3 %e3 = vm.const.i32 0x80000000 - vm.buffer.store.i32 %e3, %buf_dno[%c12] : i32 -> !vm.buffer - %c16 = vm.const.i64 16 + vm.buffer.store.i32 %e3, %buf_dno[%c3] : i32 -> !vm.buffer + %c4 = vm.const.i64 4 %e4 = vm.const.i32 0xFFFFFFFF - vm.buffer.store.i32 %e4, %buf_dno[%c16] : i32 -> !vm.buffer + vm.buffer.store.i32 %e4, %buf_dno[%c4] : i32 -> !vm.buffer %cmp = vm.buffer.compare %ref_dno, %c0, %buf_dno, %c0, %ref_length : !vm.buffer, !vm.buffer vm.check.nz %cmp, "source and target match" : i32 @@ -613,24 +597,4 @@ vm.module @buffer_ops { vm.return } - // Unaligned stores are not supported and offsets will be rounded down. - vm.export @test_store_i32_unaligned attributes {emitc.exclude} - vm.func private @test_store_i32_unaligned() { - %c12 = vm.const.i64 12 - %buf = vm.buffer.alloc %c12 : !vm.buffer - %buf_dno = util.do_not_optimize(%buf) : !vm.buffer - - // Byte offset 5 rounded to byte offset 4 (element 1). - %c5 = vm.const.i64 5 - %e1 = vm.const.i32 0x44556677 - vm.buffer.store.i32 %e1, %buf_dno[%c5] : i32 -> !vm.buffer - - // Read back at offset 4 (where the data should be). - %c4 = vm.const.i64 4 - %a1 = vm.buffer.load.i32 %buf_dno[%c4] : !vm.buffer -> i32 - vm.check.eq %a1, %e1, "0x44556677" : i32 - - vm.return - } - } diff --git a/samples/colab/edge_detection.ipynb b/samples/colab/edge_detection.ipynb index 10fa776d4182..6a0f9a2817a3 100644 --- a/samples/colab/edge_detection.ipynb +++ b/samples/colab/edge_detection.ipynb @@ -330,7 +330,9 @@ "# application, we would probably want to freeze the version of IREE used and\n", "# compile as completely as possible ahead of time, then use some other scheme\n", "# to load the module into the application at runtime.\n", - "compiler_module = tfc.compile_module(EdgeDetectionModule(), import_only=True)\n", + "compiler_module = tfc.compile_module(\n", + " EdgeDetectionModule(), import_only=True,\n", + " import_extra_args=[\"--output-format=mlir-ir\"])\n", "print(\"Edge Detection MLIR: \", compiler_module.decode('utf-8'))\n", "\n", "edge_detection_mlir_path = os.path.join(ARTIFACTS_DIR, \"edge_detection.mlir\")\n", diff --git a/samples/colab/tflite_text_classification.ipynb b/samples/colab/tflite_text_classification.ipynb index b4a94d258b31..3d5dfb39f775 100644 --- a/samples/colab/tflite_text_classification.ipynb +++ b/samples/colab/tflite_text_classification.ipynb @@ -64,7 +64,7 @@ "import tflite_runtime.interpreter as tflite\n", "\n", "from iree import runtime as iree_rt\n", - "from iree.compiler import compile_str\n", + "from iree.compiler import compile_file, compile_str\n", "from iree.tools import tflite as iree_tflite\n", "\n", "ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n", @@ -320,14 +320,12 @@ }, "outputs": [], "source": [ - "# Convert TFLite model to TOSA MLIR with IREE's import tool.\n", + "# Convert TFLite model to TOSA MLIR (bytecode) with IREE's import tool.\n", "IREE_TFLITE_TOOL = iree_tflite.get_tool('iree-import-tflite')\n", - "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite --o={ARTIFACTS_DIR}/text_classification.mlir\n", + "tosa_mlirbc_file = ARTIFACTS_DIR.joinpath(\"text_classification.mlirbc\")\n", + "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite --o={tosa_mlirbc_file}\n", "\n", - "with open(ARTIFACTS_DIR.joinpath(\"text_classification.mlir\")) as mlir_file:\n", - " tosa_mlir = mlir_file.read()\n", - "\n", - "# The generated .mlir file could now be saved and used outside of Python, with\n", + "# The generated .mlirbc file could now be saved and used outside of Python, with\n", "# IREE native tools or in apps, etc." ] }, @@ -348,16 +346,16 @@ "text": [ "module {\n", " func.func @main(%arg0: tensor<1x256xi32> {iree.identifier = \"input_5\"}) -> (tensor<1x2xf32> {iree.identifier = \"Identity\"}) {\n", - " %0 = \"tosa.const\"() {value = dense<3.906250e-03> : tensor<1x1xf32>} : () -> tensor<1x1xf32>\n", - " %1 = \"tosa.const\"() {value = opaque<\"elided_large_const\", \"0xDEADBEEF\"> : tensor<1x10003x16xf32>} : () -> tensor<1x10003x16xf32>\n", - " %2 = \"tosa.const\"() {value = opaque<\"elided_large_const\", \"0xDEADBEEF\"> : tensor<16x16xf32>} : () -> tensor<16x16xf32>\n", + " %0 = \"tosa.const\"() {value = dense_resource<__elided__> : tensor<1x10003x16xf32>} : () -> tensor<1x10003x16xf32>\n", + " %1 = \"tosa.const\"() {value = dense<3.906250e-03> : tensor<1x1xf32>} : () -> tensor<1x1xf32>\n", + " %2 = \"tosa.const\"() {value = dense_resource<__elided__> : tensor<16x16xf32>} : () -> tensor<16x16xf32>\n", " %3 = \"tosa.const\"() {value = dense<[-0.00698487554, 0.0294856895, 0.0699710473, 0.130019352, -0.0490558445, 0.0987673401, 0.0744077861, 0.0948959812, -0.010937131, 0.0931261852, 0.0711835548, -0.0385615043, 9.962780e-03, 0.00283221388, 0.112116851, 0.0134318024]> : tensor<16xf32>} : () -> tensor<16xf32>\n", " %4 = \"tosa.const\"() {value = dense<[[0.091361463, -1.23269629, 1.33242488, 0.92142266, -0.445623249, 0.849273681, -1.27237022, 1.28574562, 0.436188251, -0.963210225, 0.745473146, -0.255745709, -1.4491415, -1.4687326, 0.900665163, -1.36293614], [-0.0968776941, 0.771379471, -1.36363328, -1.1110599, -0.304591209, -1.05579722, 0.795746565, -1.3122592, 0.352218777, 1.04682362, -1.18796027, -0.0409261398, 1.05883229, 1.48620188, -1.13325548, 1.03072512]]> : tensor<2x16xf32>} : () -> tensor<2x16xf32>\n", " %5 = \"tosa.const\"() {value = dense<[0.043447677, -0.0434476472]> : tensor<2xf32>} : () -> tensor<2xf32>\n", - " %6 = \"tosa.gather\"(%1, %arg0) : (tensor<1x10003x16xf32>, tensor<1x256xi32>) -> tensor<1x256x16xf32>\n", + " %6 = \"tosa.gather\"(%0, %arg0) : (tensor<1x10003x16xf32>, tensor<1x256xi32>) -> tensor<1x256x16xf32>\n", " %7 = \"tosa.reduce_sum\"(%6) {axis = 1 : i64} : (tensor<1x256x16xf32>) -> tensor<1x1x16xf32>\n", " %8 = \"tosa.reshape\"(%7) {new_shape = [1, 16]} : (tensor<1x1x16xf32>) -> tensor<1x16xf32>\n", - " %9 = \"tosa.mul\"(%8, %0) {shift = 0 : i32} : (tensor<1x16xf32>, tensor<1x1xf32>) -> tensor<1x16xf32>\n", + " %9 = \"tosa.mul\"(%8, %1) {shift = 0 : i32} : (tensor<1x16xf32>, tensor<1x1xf32>) -> tensor<1x16xf32>\n", " %10 = \"tosa.fully_connected\"(%9, %2, %3) : (tensor<1x16xf32>, tensor<16x16xf32>, tensor<16xf32>) -> tensor<1x16xf32>\n", " %11 = \"tosa.clamp\"(%10) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x16xf32>) -> tensor<1x16xf32>\n", " %12 = \"tosa.fully_connected\"(%11, %4, %5) : (tensor<1x16xf32>, tensor<2x16xf32>, tensor<2xf32>) -> tensor<1x2xf32>\n", @@ -374,7 +372,7 @@ ], "source": [ "# The model contains very large constants, so recompile a truncated version to print.\n", - "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite --o={ARTIFACTS_DIR}/text_classification_truncated.mlir --mlir-elide-elementsattrs-if-larger=50\n", + "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite --o={ARTIFACTS_DIR}/text_classification_truncated.mlir --output-format=mlir-ir --mlir-elide-elementsattrs-if-larger=50\n", "\n", "with open(ARTIFACTS_DIR.joinpath(\"text_classification_truncated.mlir\")) as truncated_mlir_file:\n", " truncated_tosa_mlir = truncated_mlir_file.read()\n", @@ -390,7 +388,7 @@ "outputs": [], "source": [ "# Compile the TOSA MLIR into a VM module.\n", - "compiled_flatbuffer = compile_str(tosa_mlir, input_type=\"tosa\", target_backends=[\"vmvx\"])\n", + "compiled_flatbuffer = compile_file(tosa_mlirbc_file, input_type=\"tosa\", target_backends=[\"vmvx\"])\n", "\n", "# Register the module with a runtime context.\n", "config = iree_rt.Config(\"local-task\")\n", diff --git a/samples/dynamic_shapes/dynamic_shapes.ipynb b/samples/dynamic_shapes/dynamic_shapes.ipynb index b2d04e4016cf..257713e8814e 100644 --- a/samples/dynamic_shapes/dynamic_shapes.ipynb +++ b/samples/dynamic_shapes/dynamic_shapes.ipynb @@ -164,7 +164,8 @@ "\n", "compiler_module = tfc.compile_module(\n", " DynamicShapesModule(), import_only=True, \n", - " output_mlir_debuginfo=False)\n", + " output_mlir_debuginfo=False,\n", + " import_extra_args=[\"--output-format=mlir-ir\"])\n", "clear_output() # Skip over TensorFlow's output.\n", "\n", "# Print the imported MLIR to see how the compiler views this program.\n", @@ -184,45 +185,33 @@ "text": [ "Dynamic Shapes MLIR:\n", "```\n", - "\"builtin.module\"() ({\n", - " \"func.func\"() ({\n", - " ^bb0(%arg0: !iree_input.buffer_view):\n", - " %0 = \"iree_input.cast.buffer_view_to_tensor\"(%arg0) : (!iree_input.buffer_view) -> tensor\n", - " %1 = \"func.call\"(%0) {callee = @__inference_add_one_70} : (tensor) -> tensor\n", - " %2 = \"iree_input.cast.tensor_to_buffer_view\"(%1) : (tensor) -> !iree_input.buffer_view\n", - " \"func.return\"(%2) : (!iree_input.buffer_view) -> ()\n", - " }) {function_type = (!iree_input.buffer_view) -> !iree_input.buffer_view, iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22v\\22:1}\", sym_name = \"add_one\"} : () -> ()\n", - " \"func.func\"() ({\n", - " ^bb0(%arg0: tensor):\n", - " %0 = \"mhlo.constant\"() {value = dense<1> : tensor} : () -> tensor\n", - " %1 = \"chlo.broadcast_add\"(%arg0, %0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor\n", - " \"func.return\"(%1) : (tensor) -> ()\n", - " }) {arg_attrs = [{tf._user_specified_name = \"values\"}], function_type = (tensor) -> tensor, sym_name = \"__inference_add_one_70\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape]} : () -> ()\n", - " \"func.func\"() ({\n", - " ^bb0(%arg0: !iree_input.buffer_view):\n", - " %0 = \"mhlo.constant\"() {value = dense<0> : tensor} : () -> tensor\n", - " %1 = \"iree_input.cast.buffer_view_to_tensor\"(%arg0) : (!iree_input.buffer_view) -> tensor\n", - " %2 = \"mhlo.reduce\"(%1, %0) ({\n", - " ^bb0(%arg1: tensor, %arg2: tensor):\n", - " %4 = \"mhlo.add\"(%arg1, %arg2) : (tensor, tensor) -> tensor\n", - " \"mhlo.return\"(%4) : (tensor) -> ()\n", - " }) {dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor) -> tensor\n", - " %3 = \"iree_input.cast.tensor_to_buffer_view\"(%2) : (tensor) -> !iree_input.buffer_view\n", - " \"func.return\"(%3) : (!iree_input.buffer_view) -> ()\n", - " }) {function_type = (!iree_input.buffer_view) -> !iree_input.buffer_view, iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\", sym_name = \"reduce_sum_1d\"} : () -> ()\n", - " \"func.func\"() ({\n", - " ^bb0(%arg0: !iree_input.buffer_view):\n", - " %0 = \"mhlo.constant\"() {value = dense<0> : tensor} : () -> tensor\n", - " %1 = \"iree_input.cast.buffer_view_to_tensor\"(%arg0) : (!iree_input.buffer_view) -> tensor\n", - " %2 = \"mhlo.reduce\"(%1, %0) ({\n", - " ^bb0(%arg1: tensor, %arg2: tensor):\n", - " %4 = \"mhlo.add\"(%arg1, %arg2) : (tensor, tensor) -> tensor\n", - " \"mhlo.return\"(%4) : (tensor) -> ()\n", - " }) {dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor) -> tensor<3xi32>\n", - " %3 = \"iree_input.cast.tensor_to_buffer_view\"(%2) : (tensor<3xi32>) -> !iree_input.buffer_view\n", - " \"func.return\"(%3) : (!iree_input.buffer_view) -> ()\n", - " }) {function_type = (!iree_input.buffer_view) -> !iree_input.buffer_view, iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,2,null,3]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,1,3]],\\22v\\22:1}\", sym_name = \"reduce_sum_2d\"} : () -> ()\n", - "}) : () -> ()\n", + "module {\n", + " func.func @add_one(%arg0: !iree_input.buffer_view) -> !iree_input.buffer_view attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22v\\22:1}\"} {\n", + " %0 = iree_input.cast.buffer_view_to_tensor %arg0 : !iree_input.buffer_view -> tensor\n", + " %1 = call @__inference_add_one_70(%0) : (tensor) -> tensor\n", + " %2 = iree_input.cast.tensor_to_buffer_view %1 : tensor -> !iree_input.buffer_view\n", + " return %2 : !iree_input.buffer_view\n", + " }\n", + " func.func private @__inference_add_one_70(%arg0: tensor {tf._user_specified_name = \"values\"}) -> tensor attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape]} {\n", + " %0 = mhlo.constant dense<1> : tensor\n", + " %1 = chlo.broadcast_add %arg0, %0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor\n", + " return %1 : tensor\n", + " }\n", + " func.func @reduce_sum_1d(%arg0: !iree_input.buffer_view) -> !iree_input.buffer_view attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\"} {\n", + " %0 = mhlo.constant dense<0> : tensor\n", + " %1 = iree_input.cast.buffer_view_to_tensor %arg0 : !iree_input.buffer_view -> tensor\n", + " %2 = mhlo.reduce(%1 init: %0) applies mhlo.add across dimensions = [0] : (tensor, tensor) -> tensor\n", + " %3 = iree_input.cast.tensor_to_buffer_view %2 : tensor -> !iree_input.buffer_view\n", + " return %3 : !iree_input.buffer_view\n", + " }\n", + " func.func @reduce_sum_2d(%arg0: !iree_input.buffer_view) -> !iree_input.buffer_view attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,2,null,3]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,1,3]],\\22v\\22:1}\"} {\n", + " %0 = mhlo.constant dense<0> : tensor\n", + " %1 = iree_input.cast.buffer_view_to_tensor %arg0 : !iree_input.buffer_view -> tensor\n", + " %2 = mhlo.reduce(%1 init: %0) applies mhlo.add across dimensions = [0] : (tensor, tensor) -> tensor<3xi32>\n", + " %3 = iree_input.cast.tensor_to_buffer_view %2 : tensor<3xi32> -> !iree_input.buffer_view\n", + " return %3 : !iree_input.buffer_view\n", + " }\n", + "}\n", "\n", "```\n", "\n", diff --git a/samples/static_library/CMakeLists.txt b/samples/static_library/CMakeLists.txt index 5a226f33c179..17a173dd8edb 100644 --- a/samples/static_library/CMakeLists.txt +++ b/samples/static_library/CMakeLists.txt @@ -20,6 +20,7 @@ list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu") list(APPEND _COMPILE_ARGS "--iree-llvm-link-embedded=false") list(APPEND _COMPILE_ARGS "--iree-llvm-link-static") list(APPEND _COMPILE_ARGS "--iree-llvm-static-library-output-path=simple_mul.o") +list(APPEND _COMPILE_ARGS "--iree-vm-target-index-bits=32") list(APPEND _COMPILE_ARGS "${CMAKE_CURRENT_SOURCE_DIR}/simple_mul.mlir") list(APPEND _COMPILE_ARGS "-o") list(APPEND _COMPILE_ARGS "simple_mul.vmfb") @@ -106,6 +107,7 @@ list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu") list(APPEND _COMPILE_ARGS "--iree-llvm-link-embedded=false") list(APPEND _COMPILE_ARGS "--iree-llvm-link-static") list(APPEND _COMPILE_ARGS "--iree-llvm-static-library-output-path=simple_mul_c_module.o") +list(APPEND _COMPILE_ARGS "--iree-vm-target-index-bits=32") list(APPEND _COMPILE_ARGS "${CMAKE_CURRENT_SOURCE_DIR}/simple_mul.mlir") list(APPEND _COMPILE_ARGS "-o") list(APPEND _COMPILE_ARGS "simple_mul_emitc.h") diff --git a/samples/variables_and_state/variables_and_state.ipynb b/samples/variables_and_state/variables_and_state.ipynb index 5c0ca86e79d7..ed4f70b6914f 100644 --- a/samples/variables_and_state/variables_and_state.ipynb +++ b/samples/variables_and_state/variables_and_state.ipynb @@ -168,7 +168,9 @@ "from iree.compiler import tf as tfc\n", "\n", "compiler_module = tfc.compile_module(\n", - " CounterModule(), import_only=True, output_mlir_debuginfo=False)\n", + " CounterModule(), import_only=True,\n", + " output_mlir_debuginfo=False,\n", + " import_extra_args=[\"--output-format=mlir-ir\"])\n", "clear_output() # Skip over TensorFlow's output.\n", "\n", "# Print the imported MLIR to see how the compiler views this TensorFlow program.\n", @@ -189,55 +191,47 @@ "text": [ "Counter MLIR:\n", "```\n", - "\"builtin.module\"() ({\n", - " \"iree_input.global\"() {initial_value = dense<0> : tensor, is_mutable, sym_name = \"counter\", sym_visibility = \"private\", type = tensor} : () -> ()\n", - " \"func.func\"() ({\n", - " ^bb0(%arg0: !iree_input.buffer_view):\n", - " %0 = \"iree_input.cast.buffer_view_to_tensor\"(%arg0) : (!iree_input.buffer_view) -> tensor\n", - " \"func.call\"(%0) {callee = @__inference_add_to_value_100} : (tensor) -> ()\n", - " \"func.return\"() : () -> ()\n", - " }) {function_type = (!iree_input.buffer_view) -> (), iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\", sym_name = \"add_to_value\"} : () -> ()\n", - " \"func.func\"() ({\n", - " ^bb0(%arg0: tensor):\n", - " %0 = \"iree_input.global.address\"() {global = @counter} : () -> !iree_input.ptr>\n", - " %1 = \"iree_input.global.load.indirect\"(%0) : (!iree_input.ptr>) -> tensor\n", - " %2 = \"chlo.broadcast_add\"(%1, %arg0) : (tensor, tensor) -> tensor\n", - " \"iree_input.global.store.indirect\"(%2, %0) : (tensor, !iree_input.ptr>) -> ()\n", - " \"func.return\"() : () -> ()\n", - " }) {arg_attrs = [{tf._user_specified_name = \"x\"}], function_type = (tensor) -> (), sym_name = \"__inference_add_to_value_100\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} : () -> ()\n", - " \"func.func\"() ({\n", - " %0 = \"func.call\"() {callee = @__inference_get_value_160} : () -> tensor\n", - " %1 = \"iree_input.cast.tensor_to_buffer_view\"(%0) : (tensor) -> !iree_input.buffer_view\n", - " \"func.return\"(%1) : (!iree_input.buffer_view) -> ()\n", - " }) {function_type = () -> !iree_input.buffer_view, iree.abi = \"{\\22a\\22:[],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\", sym_name = \"get_value\"} : () -> ()\n", - " \"func.func\"() ({\n", - " %0 = \"iree_input.global.address\"() {global = @counter} : () -> !iree_input.ptr>\n", - " %1 = \"iree_input.global.load.indirect\"(%0) : (!iree_input.ptr>) -> tensor\n", - " \"func.return\"(%1) : (tensor) -> ()\n", - " }) {function_type = () -> tensor, sym_name = \"__inference_get_value_160\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>], tf.signature.is_stateful} : () -> ()\n", - " \"func.func\"() ({\n", - " \"func.call\"() {callee = @__inference_reset_value_270} : () -> ()\n", - " \"func.return\"() : () -> ()\n", - " }) {function_type = () -> (), iree.abi = \"{\\22a\\22:[],\\22r\\22:[],\\22v\\22:1}\", sym_name = \"reset_value\"} : () -> ()\n", - " \"func.func\"() ({\n", - " %0 = \"mhlo.constant\"() {value = dense<0> : tensor} : () -> tensor\n", - " %1 = \"iree_input.global.address\"() {global = @counter} : () -> !iree_input.ptr>\n", - " \"iree_input.global.store.indirect\"(%0, %1) : (tensor, !iree_input.ptr>) -> ()\n", - " \"func.return\"() : () -> ()\n", - " }) {function_type = () -> (), sym_name = \"__inference_reset_value_270\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>], tf.signature.is_stateful} : () -> ()\n", - " \"func.func\"() ({\n", - " ^bb0(%arg0: !iree_input.buffer_view):\n", - " %0 = \"iree_input.cast.buffer_view_to_tensor\"(%arg0) : (!iree_input.buffer_view) -> tensor\n", - " \"func.call\"(%0) {callee = @__sm_exported___inference_set_value_230} : (tensor) -> ()\n", - " \"func.return\"() : () -> ()\n", - " }) {function_type = (!iree_input.buffer_view) -> (), iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\", sym_name = \"set_value\"} : () -> ()\n", - " \"func.func\"() ({\n", - " ^bb0(%arg0: tensor):\n", - " %0 = \"iree_input.global.address\"() {global = @counter} : () -> !iree_input.ptr>\n", - " \"iree_input.global.store.indirect\"(%arg0, %0) : (tensor, !iree_input.ptr>) -> ()\n", - " \"func.return\"() : () -> ()\n", - " }) {arg_attrs = [{tf._user_specified_name = \"new_value\"}], function_type = (tensor) -> (), sym_name = \"__sm_exported___inference_set_value_230\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} : () -> ()\n", - "}) : () -> ()\n", + "module {\n", + " ml_program.global private mutable @counter(dense<0> : tensor) : tensor\n", + " func.func @add_to_value(%arg0: !iree_input.buffer_view) attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\"} {\n", + " %0 = iree_input.cast.buffer_view_to_tensor %arg0 : !iree_input.buffer_view -> tensor\n", + " call @__inference_add_to_value_100(%0) : (tensor) -> ()\n", + " return\n", + " }\n", + " func.func private @__inference_add_to_value_100(%arg0: tensor {tf._user_specified_name = \"x\"}) attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} {\n", + " %0 = ml_program.global_load @counter : tensor\n", + " %1 = chlo.broadcast_add %0, %arg0 : (tensor, tensor) -> tensor\n", + " ml_program.global_store @counter = %1 : tensor\n", + " return\n", + " }\n", + " func.func @get_value() -> !iree_input.buffer_view attributes {iree.abi = \"{\\22a\\22:[],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\"} {\n", + " %0 = call @__inference_get_value_160() : () -> tensor\n", + " %1 = iree_input.cast.tensor_to_buffer_view %0 : tensor -> !iree_input.buffer_view\n", + " return %1 : !iree_input.buffer_view\n", + " }\n", + " func.func private @__inference_get_value_160() -> tensor attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>], tf.signature.is_stateful} {\n", + " %0 = ml_program.global_load @counter : tensor\n", + " return %0 : tensor\n", + " }\n", + " func.func @reset_value() attributes {iree.abi = \"{\\22a\\22:[],\\22r\\22:[],\\22v\\22:1}\"} {\n", + " call @__inference_reset_value_270() : () -> ()\n", + " return\n", + " }\n", + " func.func private @__inference_reset_value_270() attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>], tf.signature.is_stateful} {\n", + " %0 = mhlo.constant dense<0> : tensor\n", + " ml_program.global_store @counter = %0 : tensor\n", + " return\n", + " }\n", + " func.func @set_value(%arg0: !iree_input.buffer_view) attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\"} {\n", + " %0 = iree_input.cast.buffer_view_to_tensor %arg0 : !iree_input.buffer_view -> tensor\n", + " call @__sm_exported___inference_set_value_230(%0) : (tensor) -> ()\n", + " return\n", + " }\n", + " func.func private @__sm_exported___inference_set_value_230(%arg0: tensor {tf._user_specified_name = \"new_value\"}) attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} {\n", + " ml_program.global_store @counter = %arg0 : tensor\n", + " return\n", + " }\n", + "}\n", "\n", "```\n", "\n", diff --git a/tests/e2e/linalg/BUILD b/tests/e2e/linalg/BUILD new file mode 100644 index 000000000000..a7cf36635e8d --- /dev/null +++ b/tests/e2e/linalg/BUILD @@ -0,0 +1,75 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Tests of end-to-end IREE support for individual ops in the TOSA dialect. +# Each test file should have a name matching the corresponding TOSA op and test only the +# functionality of that op (though may make use of other ops where necessary). Tests should be +# written using the IREE Check framework. +# See https://github.com/iree-org/iree/blob/main/docs/developers/developing_iree/testing_guide.md#iree-core-end-to-end-tests. + +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") +load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") + +package( + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +LLVM_SRCS = enforce_glob( + [ + "conv2d.mlir", + ], + include = ["*.mlir"], +) + +iree_check_single_backend_test_suite( + name = "check_llvm-cpu_local-task", + srcs = LLVM_SRCS, + compiler_flags = ["--iree-input-type=none"], + driver = "local-task", + target_backend = "llvm-cpu", +) + +VMVX_SRCS = enforce_glob( + [ + "conv2d.mlir", + ], + include = ["*.mlir"], +) + +iree_check_single_backend_test_suite( + name = "check_vmvx_local-task", + srcs = VMVX_SRCS, + compiler_flags = [ + "--iree-input-type=none", + ], + driver = "local-task", + target_backend = "vmvx", +) + +VULKAN_SRCS = enforce_glob( + [ + "conv2d.mlir", + ], + include = ["*.mlir"], +) + +iree_check_single_backend_test_suite( + name = "check_vulkan-spirv_vulkan", + srcs = VULKAN_SRCS, + compiler_flags = ["--iree-input-type=none"], + driver = "vulkan", + target_backend = "vulkan-spirv", +) + +test_suite( + name = "check", + tests = [ + ":check_llvm-cpu_local-task", + ":check_vmvx_local-task", + ":check_vulkan-spirv_vulkan", + ], +) diff --git a/tests/e2e/linalg/CMakeLists.txt b/tests/e2e/linalg/CMakeLists.txt new file mode 100644 index 000000000000..c946a5fb0a8e --- /dev/null +++ b/tests/e2e/linalg/CMakeLists.txt @@ -0,0 +1,52 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# tests/e2e/linalg/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_check_single_backend_test_suite( + NAME + check_llvm-cpu_local-task + SRCS + "conv2d.mlir" + TARGET_BACKEND + "llvm-cpu" + DRIVER + "local-task" + COMPILER_FLAGS + "--iree-input-type=none" +) + +iree_check_single_backend_test_suite( + NAME + check_vmvx_local-task + SRCS + "conv2d.mlir" + TARGET_BACKEND + "vmvx" + DRIVER + "local-task" + COMPILER_FLAGS + "--iree-input-type=none" +) + +iree_check_single_backend_test_suite( + NAME + check_vulkan-spirv_vulkan + SRCS + "conv2d.mlir" + TARGET_BACKEND + "vulkan-spirv" + DRIVER + "vulkan" + COMPILER_FLAGS + "--iree-input-type=none" +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/tests/e2e/linalg/conv2d.mlir b/tests/e2e/linalg/conv2d.mlir new file mode 100644 index 000000000000..1edf2f357748 --- /dev/null +++ b/tests/e2e/linalg/conv2d.mlir @@ -0,0 +1,27 @@ +func.func @conv2d_nopadding() { + %inputs = util.unfoldable_constant dense<[[[ + [1.0, 3.0, 5.0, 7.0], + [11.0, 13.0, 15.0, 17.0], + [21.0, 23.0, 25.0, 27.0], + [31.0, 33.0, 35.0, 37.0]], + [[2.0, 4.0, 6.0, 8.0], + [12.0, 14.0, 16.0, 18.0], + [22.0, 24.0, 26.0, 28.0], + [32.0, 34.0, 36.0, 38.0]]]]> : tensor<1x2x4x4xf32> + %weights = util.unfoldable_constant dense<[[ + [[1.0, 3.0], + [5.0, 7.0], + [9.0, 11.0]], + [[2.0, 4.0], + [6.0, 8.0], + [10.0, 12.0]]]]> : tensor<1x2x3x2xf32> + %cst = arith.constant 0.000000e+00 : f32 + %fill = linalg.init_tensor [1, 1, 2, 3] : tensor<1x1x2x3xf32> + %out = linalg.fill ins(%cst : f32) outs(%fill : tensor<1x1x2x3xf32>) -> tensor<1x1x2x3xf32> + %result = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%inputs, %weights : tensor<1x2x4x4xf32>, tensor<1x2x3x2xf32>) outs(%out : tensor<1x1x2x3xf32>) -> tensor<1x1x2x3xf32> + check.expect_almost_eq_const(%result, dense<[[ + [[1310.0, 1466.0, 1622.0], + [2090.0, 2246.0, 2402.0]] + ]]> : tensor<1x1x2x3xf32>) : tensor<1x1x2x3xf32> + return +} diff --git a/tests/e2e/matmul/BUILD b/tests/e2e/matmul/BUILD index f2e8e733be02..7c10b4521c42 100644 --- a/tests/e2e/matmul/BUILD +++ b/tests/e2e/matmul/BUILD @@ -240,15 +240,15 @@ py_binary( ]] [iree_generated_trace_runner_test( - name = "e2e_matmul_direct_f32_gpu_large_%s" % vulkan_target, + name = "e2e_matmul_direct_f32_gpu_large_%s" % vulkan_target_and_pipeline[0], compiler_flags = [ - "--iree-vulkan-target-triple=%s" % vulkan_target, + "--iree-vulkan-target-triple=%s" % vulkan_target_and_pipeline[0], ], generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f32", "--shapes=gpu_large", - "--compilation_info=SPIRVVectorize", + "--compilation_info=%s" % vulkan_target_and_pipeline[1], ], tags = [ "requires-gpu-nvidia", @@ -257,9 +257,9 @@ py_binary( ("vulkan-spirv", "vulkan"), ], trace_runner = "//tools:iree-e2e-matmul-test", -) for vulkan_target in [ - "valhall-unknown-android31", - "ampere-unknown-linux", +) for vulkan_target_and_pipeline in [ + ("valhall-unknown-android31", "SPIRVVectorizeMali"), + ("ampere-unknown-linux", "SPIRVVectorizeNVIDIA"), ]] # Check tests diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt index e7799245a532..ca50afbc3b33 100644 --- a/tests/e2e/matmul/CMakeLists.txt +++ b/tests/e2e/matmul/CMakeLists.txt @@ -332,7 +332,7 @@ iree_generated_trace_runner_test( GENERATOR_ARGS "--lhs_rhs_type=f32" "--shapes=gpu_large" - "--compilation_info=SPIRVVectorize" + "--compilation_info=SPIRVVectorizeMali" TRACE_RUNNER iree-e2e-matmul-test TARGET_BACKENDS @@ -353,7 +353,7 @@ iree_generated_trace_runner_test( GENERATOR_ARGS "--lhs_rhs_type=f32" "--shapes=gpu_large" - "--compilation_info=SPIRVVectorize" + "--compilation_info=SPIRVVectorizeNVIDIA" TRACE_RUNNER iree-e2e-matmul-test TARGET_BACKENDS diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py index 186e759ce2a5..21cfd9d94ff4 100644 --- a/tests/e2e/matmul/generate_e2e_matmul_tests.py +++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py @@ -14,6 +14,7 @@ import enum import dataclasses import typing +import itertools # Data type of matrix entries. The string values must match MLIR data types. @@ -42,7 +43,8 @@ class CompilationInfoId(enum.Enum): NONE = "" LLVMGPUMatmulSimt = "LLVMGPUMatmulSimt" LLVMGPUMatmulTensorCore = "LLVMGPUMatmulTensorCore" - SPIRVVectorize = "SPIRVVectorize" + SPIRVVectorizeMali = "SPIRVVectorizeMali" + SPIRVVectorizeNVIDIA = "SPIRVVectorizeNVIDIA" # Enumerates ways to construct MLIR tensor types. @@ -73,7 +75,7 @@ class TestShape: @dataclasses.dataclass class CompilationInfo: # Lowering Config - tile_sizes: typing.List[int] + tile_sizes: typing.List[typing.List[int]] # Translation Info dispatch_lowering_pass_pipeline: str workload_per_wg: typing.List[int] @@ -155,30 +157,61 @@ def get_dynamicities(shapes_id: ShapesId): @dataclasses.dataclass class TileWorkgroupSizePair: - tile_size: typing.List[int] + tile_size: typing.List[typing.List[int]] workgroup_size: typing.List[int] +# Constructs a TileWorkgroupSizePair for SPIRV Targets enforcing the constraints between +# the workgroup_size and tile size +def get_spirv_tile_workgroup_size_pair(workgroup_size, + t_tile_k, + t_tile_m=4, + t_tile_n=4): + x, y, z = workgroup_size + wg_tile_m = y * t_tile_m + wg_tile_n = x * t_tile_n + return TileWorkgroupSizePair( + [[wg_tile_m, wg_tile_n], [t_tile_m, t_tile_n], [0, 0, t_tile_k]], + workgroup_size) + + +# Returns all the TileWorkgroupSizePairs for a given SPIRV Target +def get_all_spirv_tile_workgroup_size_pairs(t_tile_k): + tile_workgroup_size_pairs = [ + get_spirv_tile_workgroup_size_pair([32, 8, 1], t_tile_k), + get_spirv_tile_workgroup_size_pair([16, 8, 1], t_tile_k), + get_spirv_tile_workgroup_size_pair([64, 2, 1], t_tile_k), + get_spirv_tile_workgroup_size_pair([8, 8, 1], t_tile_k), + get_spirv_tile_workgroup_size_pair([32, 1, 1], t_tile_k), + get_spirv_tile_workgroup_size_pair([16, 2, 1], t_tile_k), + get_spirv_tile_workgroup_size_pair([32, 1, 1], t_tile_k), + ] + return tile_workgroup_size_pairs + + # Returns the list of CompilationInfo's to use for the CompilationInfoId. def get_test_compilation_infos( compilation_info_id: CompilationInfoId ) -> typing.List[typing.Optional[CompilationInfo]]: if compilation_info_id == CompilationInfoId.NONE: return [None] - if (compilation_info_id == CompilationInfoId.LLVMGPUMatmulSimt or - compilation_info_id == CompilationInfoId.SPIRVVectorize): + if compilation_info_id == CompilationInfoId.LLVMGPUMatmulSimt: tile_workgroup_size_pairs = [ - TileWorkgroupSizePair([32, 128, 32], [32, 8, 1]), - TileWorkgroupSizePair([128, 64, 8], [16, 8, 1]), - TileWorkgroupSizePair([16, 256, 32], [64, 2, 1]), - TileWorkgroupSizePair([8, 32, 32], [8, 8, 1]), - TileWorkgroupSizePair([8, 128, 4], [32, 1, 1]), - TileWorkgroupSizePair([16, 64, 4], [16, 2, 1]), - TileWorkgroupSizePair([1, 128, 8], [32, 1, 1]), + TileWorkgroupSizePair([[32, 128, 32]], [32, 8, 1]), + TileWorkgroupSizePair([[128, 64, 8]], [16, 8, 1]), + TileWorkgroupSizePair([[16, 256, 32]], [64, 2, 1]), + TileWorkgroupSizePair([[8, 32, 32]], [8, 8, 1]), + TileWorkgroupSizePair([[8, 128, 4]], [32, 1, 1]), + TileWorkgroupSizePair([[16, 64, 4]], [16, 2, 1]), + TileWorkgroupSizePair([[1, 128, 8]], [32, 1, 1]), ] + elif compilation_info_id == CompilationInfoId.SPIRVVectorizeNVIDIA: + tile_workgroup_size_pairs = get_all_spirv_tile_workgroup_size_pairs(32) + elif compilation_info_id == CompilationInfoId.SPIRVVectorizeMali: + tile_workgroup_size_pairs = get_all_spirv_tile_workgroup_size_pairs(4) elif compilation_info_id == CompilationInfoId.LLVMGPUMatmulTensorCore: tile_workgroup_size_pairs = [ - TileWorkgroupSizePair([32, 32, 16], [64, 2, 1]), + TileWorkgroupSizePair([[32, 32, 16]], [64, 2, 1]), ] compilation_infos = [] @@ -313,8 +346,9 @@ def generate_function_name( info = "" if compilation_info: + tile_sizes = list(itertools.chain(*compilation_info.tile_sizes)) tile_workgroup_key = "_".join([ - str(a) for a in compilation_info.tile_sizes + str(a) for a in tile_sizes ]) + "_" + "_".join([str(a) for a in compilation_info.workgroup_size]) info = f"_for_{compilation_info.dispatch_lowering_pass_pipeline}_{tile_workgroup_key}" @@ -354,10 +388,13 @@ def generate_function( func_definition = "" compilation_info_attr = "" if compilation_info: + dispatch_lowering_pass_pipeline = compilation_info.dispatch_lowering_pass_pipeline + if "SPIRV" in compilation_info.dispatch_lowering_pass_pipeline: + dispatch_lowering_pass_pipeline = "SPIRVVectorize" compilation_info_string = ( f"#compilation{generate_function.compilation_index} = #iree_codegen.compilation_info<\n" - f" lowering_config = ,\n" - f" translation_info = <{compilation_info.dispatch_lowering_pass_pipeline}\n" + f" lowering_config = ,\n" + f" translation_info = <{dispatch_lowering_pass_pipeline}\n" f" pipeline_depth = {compilation_info.software_pipeline_depth}>,\n" f" workgroup_size = {compilation_info.workgroup_size_str()}>\n") compilation_info_attr = f"{{compilation_info = #compilation{generate_function.compilation_index}}} " @@ -375,7 +412,7 @@ def generate_function( ) -# Counter for producing unique complation info attrs +# Counter for producing unique compilation info attrs generate_function.compilation_index = 0 # Intentionally fixed seed! We want full reproducibility here, both across runs diff --git a/tests/e2e/matmul/large_linalg_matmul.mlir b/tests/e2e/matmul/large_linalg_matmul.mlir index 69ea6c73dd36..ccd8d45d8c31 100644 --- a/tests/e2e/matmul/large_linalg_matmul.mlir +++ b/tests/e2e/matmul/large_linalg_matmul.mlir @@ -1,6 +1,10 @@ // Test large aligned linalg matmul to make sure we go through the optimized // path for GPUs. -func.func @large_aligned() { + +// Problem size : 2048x512x1024 +// Input type : F32 +// Accumulation type : F32 +func.func @matmul_2048x512x1024_f32_f32() { %lhs = util.unfoldable_constant dense<1.0> : tensor<2048x1024xf32> %rhs = util.unfoldable_constant dense<0.4> : tensor<1024x512xf32> %c0 = arith.constant 0.0 : f32 @@ -10,4 +14,19 @@ func.func @large_aligned() { outs(%CC: tensor<2048x512xf32>) -> tensor<2048x512xf32> check.expect_almost_eq_const(%D, dense<409.596> : tensor<2048x512xf32>) : tensor<2048x512xf32> return +} + +// Problem size : 3456x1024x2048 +// Input type : F16 +// Accumulation type : F16 +func.func @matmul_3456x1024x2048_f16_f16() { + %lhs = util.unfoldable_constant dense<1.00> : tensor<3456x2048xf16> + %rhs = util.unfoldable_constant dense<0.01> : tensor<2048x1024xf16> + %c0 = arith.constant 0.0 : f16 + %init = linalg.init_tensor[3456, 1024] : tensor<3456x1024xf16> + %CC = linalg.fill ins(%c0 : f16) outs(%init : tensor<3456x1024xf16>) -> tensor<3456x1024xf16> + %D = linalg.matmul ins(%lhs, %rhs: tensor<3456x2048xf16>, tensor<2048x1024xf16>) + outs(%CC: tensor<3456x1024xf16>) -> tensor<3456x1024xf16> + check.expect_almost_eq_const(%D, dense<20.2812> : tensor<3456x1024xf16>) : tensor<3456x1024xf16> + return } \ No newline at end of file diff --git a/tests/e2e/models/CMakeLists.txt b/tests/e2e/models/CMakeLists.txt index 526df29a0033..24e0a2c06ded 100644 --- a/tests/e2e/models/CMakeLists.txt +++ b/tests/e2e/models/CMakeLists.txt @@ -114,6 +114,7 @@ iree_static_linker_test( "1x28x28x1xf32" COMPILER_FLAGS "--iree-input-type=mhlo" + "--iree-vm-target-index-bits=32" EMITC ) @@ -141,5 +142,7 @@ iree_static_linker_test( "simple_mul" FUNCTION_INPUTS "4xf32,4xf32" + COMPILER_FLAGS + "--iree-vm-target-index-bits=32" EMITC ) diff --git a/tests/e2e/models/fullyconnected.mlir b/tests/e2e/models/fullyconnected.mlir index d589c10f4235..89773551759a 100644 --- a/tests/e2e/models/fullyconnected.mlir +++ b/tests/e2e/models/fullyconnected.mlir @@ -1,4 +1,5 @@ // RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=llvm-cpu %s --function_input=1x5xf32=1,-2,-3,4,-5 --function_input=1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1 | FileCheck %s +// RUN: iree-run-mlir --iree-flow-dispatch-via-region-ops --iree-input-type=mhlo --iree-hal-target-backends=llvm-cpu %s --function_input=1x5xf32=1,-2,-3,4,-5 --function_input=1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1 | FileCheck %s // RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv %s --function_input=1x5xf32=1,-2,-3,4,-5 --function_input=1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1 | FileCheck %s) // CHECK-LABEL: EXEC @main diff --git a/tests/transform_dialect/cpu/BUILD b/tests/transform_dialect/cpu/BUILD index a6387adb65a7..88ac647e48c9 100644 --- a/tests/transform_dialect/cpu/BUILD +++ b/tests/transform_dialect/cpu/BUILD @@ -20,9 +20,8 @@ iree_lit_test_suite( # transform dialect spec files are MLIR files that specify a transformation, # they need to be included as data. data = [ - "matmul_codegen_spec.mlir", - "matmul_dispatch_spec.mlir", - "matmul_tiled_dispatch_spec.mlir", + "matmul_codegen_custom_dispatch_formation_spec.mlir", + "matmul_codegen_default_spec.mlir", ], tags = [ "noasan", diff --git a/tests/transform_dialect/cpu/CMakeLists.txt b/tests/transform_dialect/cpu/CMakeLists.txt index 130bd8e2cef6..ea778f500afb 100644 --- a/tests/transform_dialect/cpu/CMakeLists.txt +++ b/tests/transform_dialect/cpu/CMakeLists.txt @@ -22,9 +22,8 @@ iree_lit_test_suite( iree-opt iree-run-module DATA - matmul_codegen_spec.mlir - matmul_dispatch_spec.mlir - matmul_tiled_dispatch_spec.mlir + matmul_codegen_custom_dispatch_formation_spec.mlir + matmul_codegen_default_spec.mlir LABELS "noasan" "nomsan" diff --git a/tests/transform_dialect/cpu/matmul.mlir b/tests/transform_dialect/cpu/matmul.mlir index 641c7f73f40e..e7a26e15c043 100644 --- a/tests/transform_dialect/cpu/matmul.mlir +++ b/tests/transform_dialect/cpu/matmul.mlir @@ -10,85 +10,54 @@ func.func @matmul_static( return %0 : !C_size } +// Run with C++ dispatch region formation but transform dialect codegen // RUN: iree-opt %s --iree-hal-target-backends=llvm-cpu \ -// RUN: --iree-abi-transformation-pipeline \ -// RUN: --iree-flow-transformation-pipeline \ -// RUN: --iree-flow-dispatch-use-transform-dialect=%p/matmul_dispatch_spec.mlir | \ -// RUN: FileCheck %s --check-prefixes=DISPATCH - -// TODO: make this test drop transform dialect usage at the flow level and use: -// --iree-flow-transformation-pipeline --iree-flow-convert-region-to-workgroups -// Atm the 3rd flow.dispatch.tensor.load shows as readonly instead of readwrite. - -// DISPATCH: flow.executable private @matmul_static_dispatch_0 { -// DISPATCH: flow.executable.export public @matmul_static_dispatch_0_matmul_3x3x5 -// DISPATCH: builtin.module { -// DISPATCH: func.func @matmul_static_dispatch_0_matmul_3x3x5 -// DISPATCH: flow.dispatch.tensor.load {{.*}}, offsets = [0, 0], sizes = [3, 5], strides = [1, 1] : !flow.dispatch.tensor -> tensor<3x5xf32> -// DISPATCH: flow.dispatch.tensor.load {{.*}}, offsets = [0, 0], sizes = [5, 3], strides = [1, 1] : !flow.dispatch.tensor -> tensor<5x3xf32> -// DISPATCH: flow.dispatch.tensor.load {{.*}}, offsets = [0, 0], sizes = [3, 3], strides = [1, 1] : !flow.dispatch.tensor -> tensor<3x3xf32> -// DISPATCH: linalg.matmul ins({{.*}} : tensor<3x5xf32>, tensor<5x3xf32>) outs({{.*}} : tensor<3x3xf32>) -> tensor<3x3xf32> -// DISPATCH: flow.dispatch.tensor.store {{.*}} offsets = [0, 0], sizes = [3, 3], strides = [1, 1] : tensor<3x3xf32> -> !flow.dispatch.tensor -// DISPATCH: return +// RUN: --iree-abi-transformation-pipeline --iree-flow-transformation-pipeline \ +// RUN: --iree-flow-dispatch-via-region-ops \ +// RUN: --iree-flow-dispatch-via-region-ops-generate-workload-region=false \ +// RUN: --iree-stream-transformation-pipeline \ +// RUN: --iree-hal-configuration-pipeline | \ +// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' \ +// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_custom_dispatch_formation_spec.mlir | \ +// RUN: FileCheck %s --check-prefix=CODEGEN-CUSTOM-DISPATCH-FORMATION + +// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.executable private @matmul_static_dispatch_0 { +// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ { +// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.executable.export public @matmul_static_dispatch_0_matmul_3x3x5 ordinal(0) layout(#{{.*}}) attributes {translation_info = #translation} { +// CODEGEN-CUSTOM-DISPATCH-FORMATION: ^bb0(%{{.*}}: !hal.device): +// CODEGEN-CUSTOM-DISPATCH-FORMATION: %[[C2:.*]] = arith.constant 2 : index +// CODEGEN-CUSTOM-DISPATCH-FORMATION: %[[C1:.*]] = arith.constant 1 : index +// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.return %[[C2]], %[[C1]], %[[C1]] : index, index, index +// CODEGEN-CUSTOM-DISPATCH-FORMATION: } +// CODEGEN-CUSTOM-DISPATCH-FORMATION: builtin.module { +// CODEGEN-CUSTOM-DISPATCH-FORMATION: func.func @matmul_static_dispatch_0_matmul_3x3x5() { +// CODEGEN-CUSTOM-DISPATCH-FORMATION: arith.constant 0 : index +// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset({{.*}}) alignment(64) : memref<3x5xf32> +// CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.assume_alignment %{{.*}}, 64 : memref<3x5xf32> +// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset({{.*}}) alignment(64) : memref<5x3xf32> +// CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.assume_alignment %{{.*}}, 64 : memref<5x3xf32> +// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset({{.*}}) alignment(64) : memref<3x3xf32> +// CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.assume_alignment %{{.*}}, 64 : memref<3x3xf32> +// CODEGEN-CUSTOM-DISPATCH-FORMATION: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index +// CODEGEN-CUSTOM-DISPATCH-FORMATION: affine.apply {{.*}}()[%workgroup_id_x] +// CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 5] [1, 1] : memref<3x5xf32> to memref> +// CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 3] [1, 1] : memref<3x3xf32> to memref> +// CODEGEN-CUSTOM-DISPATCH-FORMATION: linalg.matmul ins(%{{.*}}, %{{.*}} : memref>, memref<5x3xf32>) outs(%{{.*}} : memref>) // RUN: iree-opt %s --iree-hal-target-backends=llvm-cpu \ // RUN: --iree-abi-transformation-pipeline \ // RUN: --iree-flow-transformation-pipeline \ -// RUN: --iree-flow-dispatch-use-transform-dialect=%p/matmul_dispatch_spec.mlir \ // RUN: --iree-stream-transformation-pipeline \ -// RUN: --iree-hal-configuration-pipeline | \ +// RUN: --iree-hal-configuration-pipeline | \ // RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' \ -// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_spec.mlir | \ -// RUN: FileCheck %s --check-prefixes=CODEGEN - -// CODEGEN: hal.executable private @matmul_static_dispatch_0 { -// CODEGEN: hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ { -// -// The signature of the hal.executable.export region is subject to conventions -// at the flow level. These conventions are materialized in IR e.g. into -// stream.cmd.dispatch before codegen gets invoked. -// As a consequence, the tile_size/num_threads/workgroup_count passed to -// transform.tile_to_foreach_thread needs to be aware of this convention. -// For now we use our own convention that sizes are static and no other bbArg -// than !hal.device is present. -// -// CODEGEN: hal.executable.export public @matmul_static_dispatch_0_matmul_3x3x5 ordinal(0) layout(#{{.*}}) attributes {translation_info = #translation} { -// CODEGEN: ^bb0(%{{.*}}: !hal.device): -// CODEGEN: arith.constant 2 : index -// CODEGEN: arith.constant 1 : index -// CODEGEN: hal.return %{{.*}}, %{{.*}}, %{{.*}} : index, index, index -// CODEGEN: } -// CODEGEN: builtin.module { -// CODEGEN: func.func @matmul_static_dispatch_0_matmul_3x3x5() { -// CODEGEN: arith.constant 0 : index -// CODEGEN: hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset({{.*}}) alignment(64) : memref<3x5xf32> -// CODEGEN: memref.assume_alignment %{{.*}}, 64 : memref<3x5xf32> -// CODEGEN: hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset({{.*}}) alignment(64) : memref<5x3xf32> -// CODEGEN: memref.assume_alignment %{{.*}}, 64 : memref<5x3xf32> -// CODEGEN: hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset({{.*}}) alignment(64) : memref<3x3xf32> -// CODEGEN: memref.assume_alignment %{{.*}}, 64 : memref<3x3xf32> -// CODEGEN: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index -// CODEGEN: affine.apply {{.*}}()[%workgroup_id_x] -// CODEGEN: memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 5] [1, 1] : memref<3x5xf32> to memref -// CODEGEN: memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 3] [1, 1] : memref<3x3xf32> to memref -// CODEGEN: linalg.matmul ins(%{{.*}}, %{{.*}} : memref, memref<5x3xf32>) outs(%{{.*}} : memref) +// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_default_spec.mlir | \ +// RUN: FileCheck %s --check-prefixes=CODEGEN-DEFAULT -// RUN: iree-compile %s --iree-hal-target-backends=llvm-cpu \ -// RUN: --iree-flow-dispatch-use-transform-dialect=%p/matmul_dispatch_spec.mlir \ -// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_spec.mlir | \ -// RUN: iree-run-module --entry_function=matmul_static \ -// RUN: --function_input="3x5xf32=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1" \ -// RUN: --function_input="5x3xf32=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1" \ -// RUN: --function_input="3x3xf32=0 0 0 0 0 0 0 0 0"| \ -// RUN: FileCheck %s --check-prefixes=EXEC +// CODEGEN-DEFAULT: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)> +// CODEGEN-DEFAULT: hal.executable.export public @matmul_static_dispatch_0_matmul_3x3x5 +// CODEGEN-DEFAULT: ^bb0(%[[DEVICE:[a-zA-Z0-9]+]]: !hal.device, %[[ARG0:[a-zA-Z0-9]+]]: index, +// CODEGEN-DEFAULT: %[[C1:.+]] = arith.constant 1 : index +// CODEGEN-DEFAULT: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]] +// CODEGEN-DEFAULT: hal.return %[[D0]], %[[C1]], %[[C1]] // EXEC: 3x3xf32=[5 5 5][5 5 5][5 5 5] - -// RUN: iree-compile --iree-hal-target-backends=llvm-cpu \ -// RUN: --iree-flow-dispatch-use-transform-dialect=%p/matmul_tiled_dispatch_spec.mlir \ -// RUN: --iree-flow-export-benchmark-funcs %s | \ -// RUN: iree-benchmark-module --device=local-task | \ -// RUN: FileCheck %s --check-prefixes=BENCHMARK-MODULE - -// When running iree-benchmark-module, we only check the existence of the func. -// BENCHMARK-MODULE: matmul_static diff --git a/tests/transform_dialect/cpu/matmul_codegen_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir similarity index 100% rename from tests/transform_dialect/cpu/matmul_codegen_spec.mlir rename to tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir diff --git a/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir new file mode 100644 index 000000000000..9d7b392ba3cf --- /dev/null +++ b/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir @@ -0,0 +1,13 @@ +// RUN: iree-opt %s + +transform.structured.canonicalized_sequence failures(propagate) { +^bb1(%variant_op: !pdl.operation): + %matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op + + %foreach_thread, %tiled_generic = + transform.iree.tile_to_foreach_thread_and_workgroup_count_region %matmul tile_sizes [2] + + %variant_op_2 = transform.iree.bufferize %variant_op + %func = transform.structured.match ops{["func.func"]} in %variant_op_2 + transform.iree.foreach_thread_to_workgroup %func +} diff --git a/tests/transform_dialect/cpu/matmul_dispatch_spec.mlir b/tests/transform_dialect/cpu/matmul_dispatch_spec.mlir deleted file mode 100644 index eba247b46391..000000000000 --- a/tests/transform_dialect/cpu/matmul_dispatch_spec.mlir +++ /dev/null @@ -1,9 +0,0 @@ -transform.with_pdl_patterns { -^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %region_op = transform.iree.wrap_in_dispatch_region %0 - transform.iree.region_to_workgroups %region_op - } -} diff --git a/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir b/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir deleted file mode 100644 index 41dc2dd46faf..000000000000 --- a/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir +++ /dev/null @@ -1,6 +0,0 @@ -transform.structured.canonicalized_sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %foreach_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] - %dispatch_op = transform.iree.foreach_thread_to_flow %foreach_op -} diff --git a/tests/transform_dialect/cuda/BUILD b/tests/transform_dialect/cuda/BUILD index d2a71acbc447..4aeb99b6a001 100644 --- a/tests/transform_dialect/cuda/BUILD +++ b/tests/transform_dialect/cuda/BUILD @@ -26,7 +26,7 @@ endif() iree_lit_test_suite( name = "lit", srcs = [ - # "reduction.mlir", // see #10398 + "reduction.mlir", "softmax.mlir", ], cfg = "//tests:lit.cfg.py", @@ -34,8 +34,9 @@ iree_lit_test_suite( # they need to be included as data. data = [ "reduction_codegen_spec.mlir", - "reduction_dispatch_spec.mlir", "softmax_codegen_spec.mlir", + # FIXME: This cannot be retired yet as there is some writeonly vs readwrite + # issue and we even end up emitting out of bounds accesses. "softmax_dispatch_spec.mlir", "softmax_fused_codegen_spec.mlir", ], diff --git a/tests/transform_dialect/cuda/CMakeLists.txt b/tests/transform_dialect/cuda/CMakeLists.txt index 59050262740f..139817fe1b6c 100644 --- a/tests/transform_dialect/cuda/CMakeLists.txt +++ b/tests/transform_dialect/cuda/CMakeLists.txt @@ -18,6 +18,7 @@ iree_lit_test_suite( NAME lit SRCS + "reduction.mlir" "softmax.mlir" TOOLS FileCheck @@ -26,7 +27,6 @@ iree_lit_test_suite( iree-run-module DATA reduction_codegen_spec.mlir - reduction_dispatch_spec.mlir softmax_codegen_spec.mlir softmax_dispatch_spec.mlir softmax_fused_codegen_spec.mlir diff --git a/tests/transform_dialect/cuda/reduction.mlir b/tests/transform_dialect/cuda/reduction.mlir index c4a81ba4ca00..4ea93002eaf9 100644 --- a/tests/transform_dialect/cuda/reduction.mlir +++ b/tests/transform_dialect/cuda/reduction.mlir @@ -1,117 +1,73 @@ -func.func @reduce() -> (tensor<8xf32>) { +!in_tensor_t = tensor<8x64xf32> +!out_tensor_t = tensor<8xf32> + +func.func @reduce() -> (!out_tensor_t) { %cst = arith.constant -0.000000e+00 : f32 // Note: arith.constant is good for our purposes here but it may be useful to use // util.unfoldable_constant. - %arg = arith.constant dense<1.0> : tensor<8x64xf32> - %0 = linalg.init_tensor [8] : tensor<8xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8xf32>) -> tensor<8xf32> + %arg = arith.constant dense<1.0> : !in_tensor_t + %0 = linalg.init_tensor [8] : !out_tensor_t + %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t %2 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} - ins(%arg : tensor<8x64xf32>) outs(%1 : tensor<8xf32>) { + ins(%arg : !in_tensor_t) outs(%1 : !out_tensor_t) { ^bb0(%arg3: f32, %arg4: f32): %3 = arith.addf %arg3, %arg4 : f32 linalg.yield %3 : f32 - } -> tensor<8xf32> - return %2 : tensor<8xf32> + } -> !out_tensor_t + return %2 : !out_tensor_t } -// RUN: iree-opt %s --iree-hal-target-backends=cuda \ -// RUN: --iree-abi-transformation-pipeline \ -// RUN: --iree-flow-transformation-pipeline \ -// RUN: --iree-flow-dispatch-use-transform-dialect=%p/reduction_dispatch_spec.mlir \ -// RUN: --iree-stream-transformation-pipeline \ -// RUN: --iree-hal-configuration-pipeline | \ -// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' \ -// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_codegen_spec.mlir | \ -// RUN: FileCheck %s --check-prefix=FLOW-AND-CG --check-prefix=CHECK - // RUN: iree-opt %s --iree-hal-target-backends=cuda \ // RUN: --iree-abi-transformation-pipeline \ // RUN: --iree-flow-transformation-pipeline \ // RUN: --iree-stream-transformation-pipeline \ // RUN: --iree-hal-configuration-pipeline | \ // RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' \ -// RUN: --iree-codegen-llvmgpu-workgroup-tile-sizes=4 \ -// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_codegen_spec.mlir | \ -// RUN: FileCheck %s --check-prefix=CG-ONLY --check-prefix=CHECK - -// RUN: iree-compile %s --iree-hal-target-backends=cuda \ -// RUN: --iree-flow-dispatch-use-transform-dialect=%p/reduction_dispatch_spec.mlir \ // RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_codegen_spec.mlir | \ -// RUN: iree-run-module --entry_function=reduce --device=cuda |\ -// RUN: FileCheck %s --check-prefix=EXEC +// RUN: FileCheck %s --check-prefix=CHECK // RUN: iree-compile %s --iree-hal-target-backends=cuda \ -// RUN: --iree-codegen-llvmgpu-workgroup-tile-sizes=4 \ // RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_codegen_spec.mlir | \ // RUN: iree-run-module --entry_function=reduce --device=cuda |\ // RUN: FileCheck %s --check-prefix=EXEC // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> - // CG-ONLY-DAG: %[[FMINUS0:.*]] = arith.constant dense<-0.000000e+00> : vector<1xf32> + // CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector + // CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index + // CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 128 : i64} : memref<1x2xf32, 3> // CHECK-DAG: %[[TIDX:.]] = gpu.thread_id x // CHECK-DAG: %[[TIDY:.]] = gpu.thread_id y // CHECK-DAG: %[[TIDZ:.]] = gpu.thread_id z - // When using IREE default flow path, fill op is being fused with the generic - // op, the rest of the IR generated is the same. - // CG-ONLY: %[[CMP0:.*]] = arith.cmpi ult, %[[TIDX]], %[[C1]] : index - // CG-ONLY: %[[CMP1:.*]] = arith.cmpi ult, %[[TIDY]], %[[C1]] : index - // CG-ONLY: %[[CONXANDYARE0:.*]] = arith.andi %[[CMP0]], %[[CMP1]] : i1 - // CG-ONLY: scf.if %[[CONXANDYARE0]] { - // CG-ONLY: %[[SUBVIEW:.*]] = memref.subview %{{.*}}[%[[TIDZ]]] [1] [1] : memref<4xf32, {{.*}}> to memref - // CG-ONLY: %[[EXPAND:.*]] = memref.expand_shape %[[SUBVIEW]] [] : memref into memref<1xf32, {{.*}}> - // CG-ONLY: vector.transfer_write %[[FMINUS0]], %[[EXPAND]][%[[C0]]] {in_bounds = [true]} : vector<1xf32>, memref<1xf32, {{.*}}> - // CG-ONLY: } - // CG-ONLY: gpu.barrier - // CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 128 : i64} : memref<4x2xf32, 3> - // CHECK: %[[SHMEM_VIEW:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]] - // CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.expand_shape %[[SHMEM_VIEW]] [] : memref into memref<1x1xf32, {{.*}}> - // CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index - - // Distributed fill to shared memory, only threadIdx.x == 0 writes. - // CHECK: scf.if %[[CONDXIS0]] - // CHECK: vector.transfer_write %[[F0]], %[[SHMEM_VIEW_EXPANDED]][%[[C0]], %[[C0]]] - // CHECK: gpu.barrier - - // Note some inefficiencies here: all threads read + addf but only threadIdx.x==0 commits. - // So only threadIdx.x == 0 could do that. - // Additionally, the value read is exactly the "Distributed fill to shared memory" from above - // and there is no interleaved read/write so we could fold this read into only - // %[[F0]] and only write back to shared memory. - // - // Note: This will probably happen once the fill is fused into the split op at the linalg level. - // CHECK: %[[NEUTRAL_VEC:.*]] = vector.transfer_read %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}vector<1xf32> + // CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}to memref // Distributed reduction: everyone loads then 5 xor + addf expected - // CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[TIDX]]] - // TODO: Some foldings are missing, no need to be in vector<1xf32>. - // CHECK: %[[NEUTRAL:.*]] = vector.extract %[[NEUTRAL_VEC]][0] : vector<1xf32> + // CHECK: vector.transfer_read %{{.*}}[%[[TIDX]]] // CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf - // CHECK: %[[RES:.*]] = arith.addf %{{.*}}, %[[NEUTRAL]] + // CHECK: %[[RES:.*]] = arith.addf %{{.*}} - // TODO: Some foldings are missing, no need to be in vector<1xf32>. - // CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<1xf32> + // CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector + // CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index // CHECK: scf.if %[[CONDXIS0]] - // CHECK: vector.transfer_write %[[RES_VEC]], %[[SHMEM_VIEW_EXPANDED]][%[[C0]], %[[C0]]] + // CHECK: vector.transfer_write %[[RES_VEC]], %[[SHMEM_VIEW_EXPANDED]][] // CHECK: gpu.barrier // Last part is not distributed atm and is only ran by threadIdx.x == 0 and threadIdx.y == 0. - // FLOW-AND-CG: %[[CONDYIS0:.*]] = arith.cmpi ult, %[[TIDY]], %[[C1]] : index + // CHECK: %[[CONDYIS0:.*]] = arith.cmpi ult, %[[TIDY]], %[[C1]] : index // TODO: cond eq 0 and cond ult 1 do not CSE atm. - // FLOW-AND-CG: %[[CONXANDYARE0:.*]] = arith.andi %{{.*}}, %[[CONDYIS0]] : i1 + // CHECK: %[[CONXANDYARE0:.*]] = arith.andi %{{.*}}, %[[CONDYIS0]] : i1 // CHECK: scf.if %[[CONXANDYARE0]] { - // CHECK-COUNT-2: vector.transfer_read + // CHECK: vector.transfer_read // CHECK: vector.reduction // CHECK: vector.transfer_write // CHECK: gpu.barrier - // CHECK: memref.dealloc %[[SHMEM_ALLOC]] : memref<4x2xf32, 3> + // CHECK: memref.dealloc %[[SHMEM_ALLOC]] : memref<1x2xf32, 3> // EXEC: result[0]: hal.buffer_view diff --git a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir index ca2ee11d261f..18db174cb889 100644 --- a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir @@ -2,37 +2,53 @@ transform.structured.canonicalized_sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): + %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op + + // Split the reduction by 2 to obtain a more meaty parallel op with + // parallelism across size(reduction) / 2 threads. %0 = transform.structured.match ops{["linalg.generic"]} in %variant_op - %fused_fill = transform.structured.match ops{["linalg.fill"]} in %variant_op - // Note: split by 32 to vector-distribute the tail combiner_op, but - // split by 2 to vector-distribute the meaty %more_parallel_op - %init_or_alloc_op, %fill_op, %more_parallel_op, %combiner_op = + %init_or_alloc_op, %more_parallel_fill_op, %more_parallel_op, %combiner_op = transform.structured.split_reduction %0 - { split_factor = 2, insert_split_dimension = 1, use_alloc } - - %1 = transform.structured.match ops{["linalg.generic"]} in %variant_op - %foreach_thread_1, %tiled_fill = - transform.structured.tile_to_foreach_thread_op %fill_op num_threads [4, 2] (mapped to dims [2, 1, 0]) - %foreach_thread_2, %tiled_more_parallel_op = - transform.structured.tile_to_foreach_thread_op %more_parallel_op num_threads [4, 2] (mapped to dims [2, 1, 0]) - %foreach_thread_3, %tiled_combiner_op = - transform.structured.tile_to_foreach_thread_op %combiner_op num_threads [4] (mapped to dims [2, 1, 0]) - %foreach_thread_4, %tiled_fused_fill_op = - transform.structured.tile_to_foreach_thread_op %fused_fill num_threads [4] (mapped to dims [2, 1, 0]) - - %isolated_handle_1 = transform.get_closest_isolated_parent %foreach_thread_2 - %isolated_handle_2 = transform.structured.vectorize %isolated_handle_1 - %isolated_handle_3 = transform.iree.apply_patterns %isolated_handle_2 { rank_reducing } + { split_factor = 2, insert_split_dimension = 1 } + + // First level of tiling + fusion parallelizes to blocks. + // The mapping to block ids can only happen after bufferization atm. + %foreach_thread_grid, %grid_combiner_op = + transform.iree.tile_to_foreach_thread_and_workgroup_count_region %combiner_op tile_sizes [1] + %not_combiner = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op + transform.structured.fuse_into_containing_op %not_combiner into %foreach_thread_grid + + // Second level of tiling + fusion parallelizes to threads. + // The mapping to thread ids can only happen after bufferization atm. + %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op + + %grid_more_parallel_op = transform.structured.match interface{LinalgOp} + attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op + %foreach_thread_block_more_parallel_op, %block_more_parallel_op = + transform.structured.tile_to_foreach_thread_op %grid_more_parallel_op tile_sizes [1, 1, 0] (mapped to dims [2, 1, 0]) + transform.structured.fuse_into_containing_op %fill_2d into %foreach_thread_block_more_parallel_op + + // Second level of tiling + fusion parallelizes to threads. + // The mapping to thread ids can only happen after bufferization atm. + %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op + %foreach_thread_block_combiner_op, %block_combiner_op = + transform.structured.tile_to_foreach_thread_op %grid_combiner_op tile_sizes [1, 0, 0] (mapped to dims [2, 1, 0]) + transform.structured.fuse_into_containing_op %fill_1d into %foreach_thread_block_combiner_op + + %func = transform.structured.match ops{["func.func"]} in %variant_op + %func_2 = transform.iree.apply_patterns %func { rank_reducing } + %func_3 = transform.structured.vectorize %func_2 %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op + %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_2 - %funcop = transform.structured.match ops{["func.func"]} in %variant_op_2 - %isolated_handle_4 = - transform.iree.foreach_thread_to_gpu_and_translation_info %funcop - { workgroup_size = [32, 2, 4] } + %func_5 = transform.iree.foreach_thread_to_workgroup %func_4 + %func_6 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_5 + { workgroup_size = [32, 2, 1] } // Vector distribution needs to happen on buffers. + %func_7 = transform.iree.apply_patterns %func_6 { rank_reducing } %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2 %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } - transform.iree.vector.warp_distribute %isolated_handle_4 + transform.iree.vector.warp_distribute %func_7 } diff --git a/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir b/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir deleted file mode 100644 index cfc4600ca08c..000000000000 --- a/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: iree-opt %s - -transform.structured.canonicalized_sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %foreach_thread, %tiled_generic = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] - transform.iree.foreach_thread_to_flow %foreach_thread -} diff --git a/tests/transform_dialect/cuda/softmax.mlir b/tests/transform_dialect/cuda/softmax.mlir index b6e4fed8145d..7fa6c2ce6388 100644 --- a/tests/transform_dialect/cuda/softmax.mlir +++ b/tests/transform_dialect/cuda/softmax.mlir @@ -2,7 +2,6 @@ // RUN: iree-opt %s --iree-hal-target-backends=cuda \ // RUN: --iree-abi-transformation-pipeline \ // RUN: --iree-flow-transformation-pipeline \ -// RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \ // RUN: --iree-stream-transformation-pipeline \ // RUN: --iree-hal-configuration-pipeline | \ // RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' \ @@ -10,7 +9,6 @@ // RUN: FileCheck %s --check-prefix=CHECK-SHUFFLE // RUN: iree-compile %s --iree-hal-target-backends=cuda \ -// RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \ // RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_codegen_spec.mlir | \ // RUN: iree-run-module --entry_function=max_sub_exp --device=cuda | \ // RUN: FileCheck %s @@ -18,6 +16,10 @@ // RUN: iree-opt %s --iree-hal-target-backends=cuda \ // RUN: --iree-abi-transformation-pipeline \ // RUN: --iree-flow-transformation-pipeline \ +/// +/// FIXME: This cannot be retired yet as there is some writeonly vs readwrite +/// issue and we even end up emitting out of bounds accesses. +/// // RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \ // RUN: --iree-stream-transformation-pipeline \ // RUN: --iree-hal-configuration-pipeline | \ @@ -26,6 +28,10 @@ // RUN: FileCheck %s --check-prefix=CHECK-SHUFFLE // RUN: iree-compile %s --iree-hal-target-backends=cuda \ +/// +/// FIXME: This cannot be retired yet as there is some writeonly vs readwrite +/// issue and we even end up emitting out of bounds accesses. +/// // RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \ // RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_fused_codegen_spec.mlir | \ // RUN: iree-run-module --entry_function=max_sub_exp --device=cuda | \ @@ -41,11 +47,13 @@ // CHECK-SHUFFLE: gpu.shuffle xor // Execution only checks that @max_sub_exp runs. -// CHECK: EXEC @max_sub_exp +// CHECK: EXEC @max_sub_exp +// CHECK: 16x128x128xf32=[ +// CHECK-SAME: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 -func.func @max_sub_exp() { +func.func @max_sub_exp() -> !out_tensor_t { %cst = arith.constant -3.40282347E+38 : f32 - %cst_0 = arith.constant dense<1.000000e+00> : !out_tensor_t + %cst_0 = arith.constant dense<1121212.000000e+00> : !out_tensor_t %cst_1 = arith.constant dense<5.000000e+00> : !out_tensor_t %0 = util.do_not_optimize(%cst_1) : !out_tensor_t @@ -73,6 +81,5 @@ func.func @max_sub_exp() { linalg.yield %7 : f32 } -> !out_tensor_t - check.expect_almost_eq(%5, %cst_0) : !out_tensor_t - return + return %5: !out_tensor_t } diff --git a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir index 84951141e515..4607570487ce 100644 --- a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir @@ -1,82 +1,77 @@ // RUN: iree-opt %s // Codegen -transform.with_pdl_patterns { -^bb0(%arg0: !pdl.operation): - transform.structured.canonicalized_sequence %arg0 failures(propagate) { - ^bb1(%variant_op: !pdl.operation): - // First level of tiling + fusion parallelizes to blocks. - // The mapping to block ids can only happen after bufferization atm - %root = transform.structured.match interface{LinalgOp} - attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op - %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op - %red = transform.structured.match interface{LinalgOp} - attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op - %not_root = merge_handles %fill, %red - %foreach_thread, %tiled_generic = - transform.structured.tile_to_foreach_thread_op %root tile_sizes [1, 4] - transform.structured.fuse_into_containing_op %not_root into %foreach_thread - - // Second level of tiling + fusion parallelizes to threads. - // Leaving the reduction untiled on threadIdx.x makes it sequential on - // threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is - // introduced and opportunities for distributing vector ops across warps - // appear. - %fill_linalg = transform.structured.match ops{["linalg.fill"]} in %variant_op - %reduction_linalg = transform.structured.match ops{["linalg.generic"]} - attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op - %parallel_linalg = transform.structured.match ops{["linalg.generic"]} - attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op - %foreach_thread_reduction, %tiled_reduction_generic = - transform.structured.tile_to_foreach_thread_op %reduction_linalg tile_sizes [1, 1] - (mapped to dims [2, 1, 0]) - // TODO: this fusion currently does not happen properly, this is related to the clone - // behavior when fusing into scf.foreach_thread. - // Once fixed we'll be able to fuse. - // Fusion will save us one roundtrip to memory. - // transform.structured.fuse_into_containing_op %fill_linalg into %foreach_thread_reduction - transform.structured.tile_to_foreach_thread_op %parallel_linalg num_threads [1, 4, 32] - (mapped to dims [2, 1, 0]) +transform.structured.canonicalized_sequence failures(propagate) { +^bb1(%variant_op: !pdl.operation): + // First level of tiling + fusion parallelizes to blocks. + // The mapping to block ids can only happen after bufferization atm + %root = transform.structured.match interface{LinalgOp} + attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op + %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op + %red = transform.structured.match interface{LinalgOp} + attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op + %not_root = merge_handles %fill, %red + %foreach_thread, %tiled_generic = + transform.iree.tile_to_foreach_thread_and_workgroup_count_region %root tile_sizes [1, 4] + transform.structured.fuse_into_containing_op %not_root into %foreach_thread + + // Second level of tiling + fusion parallelizes to threads. + // Leaving the reduction untiled on threadIdx.x makes it sequential on + // threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is + // introduced and opportunities for distributing vector ops across warps + // appear. + %fill_linalg = transform.structured.match ops{["linalg.fill"]} in %variant_op + %reduction_linalg = transform.structured.match ops{["linalg.generic"]} + attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op + %parallel_linalg = transform.structured.match ops{["linalg.generic"]} + attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op + %foreach_thread_reduction, %tiled_reduction_generic = + transform.structured.tile_to_foreach_thread_op %reduction_linalg tile_sizes [1, 1] + (mapped to dims [2, 1, 0]) + // TODO: this fusion currently does not happen properly, this is related to the clone + // behavior when fusing into scf.foreach_thread. + // Once fixed we'll be able to fuse. + // Fusion will save us one roundtrip to memory. + // transform.structured.fuse_into_containing_op %fill_linalg into %foreach_thread_reduction + transform.structured.tile_to_foreach_thread_op %parallel_linalg num_threads [1, 4, 32] + (mapped to dims [2, 1, 0]) - // Inability to tile reductions to scf.foreach_thread has 2 implications: - // 1. since no scf.foreach_thread is present, no gpu.barrier is added. - // This should be fixed independently: ops that are not nested in an scf.foreach_thread - // should have a gpu.barrier. Later needs to be complemented by a barrier - // removal pass. - // 2. Similarly, needs to be predicated under an if threadIx == 0 to avoid - // multiple threads updating the buffer inplace once bufferized. - // - // Instead, we can vectorize and go to vector SSA values that sidestep these - // issues. - // Everyone will race to the write while still computing the same value. - // - // That is still not good enough because we need to predicate this in order - // to enable the parallel reduction on warps. - %func = transform.structured.match ops{["func.func"]} in %variant_op - %funcx = transform.iree.apply_patterns %func { rank_reducing } - transform.structured.vectorize %funcx + // Inability to tile reductions to scf.foreach_thread has 2 implications: + // 1. since no scf.foreach_thread is present, no gpu.barrier is added. + // This should be fixed independently: ops that are not nested in an scf.foreach_thread + // should have a gpu.barrier. Later needs to be complemented by a barrier + // removal pass. + // 2. Similarly, needs to be predicated under an if threadIx == 0 to avoid + // multiple threads updating the buffer inplace once bufferized. + // + // Instead, we can vectorize and go to vector SSA values that sidestep these + // issues. + // Everyone will race to the write while still computing the same value. + // + // That is still not good enough because we need to predicate this in order + // to enable the parallel reduction on warps. + %func = transform.structured.match ops{["func.func"]} in %variant_op + %funcx = transform.iree.apply_patterns %func { rank_reducing } + transform.structured.vectorize %funcx - // Bufferization is necessary for: - // 1. lowering scf.foreach_thread to workgroup (block level parallelism) - // 2. lowering scf.foreach_thread to gpu (thread level parallelism) - // 3. introducing predication (due to 1. + 2.) which enables rewriting to - // warp_execute_on_lane_0 and later vector distribution. - %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op + // Bufferization is necessary for: + // 1. lowering scf.foreach_thread to workgroup (block level parallelism) + // 2. lowering scf.foreach_thread to gpu (thread level parallelism) + // 3. introducing predication (due to 1. + 2.) which enables rewriting to + // warp_execute_on_lane_0 and later vector distribution. + %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op - %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2 - %func_3 = transform.iree.foreach_thread_to_workgroup %func_2 - transform.iree.foreach_thread_to_gpu_and_translation_info %func_3 - { workgroup_size = [32, 4, 1] } + %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2 + %func_3 = transform.iree.foreach_thread_to_workgroup %func_2 + transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3 + { workgroup_size = [32, 4, 1] } - %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2 - %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing } + %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2 + %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing } - // Vector distribution needs to happen on buffers. - %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2 - %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } - transform.iree.vector.warp_distribute %end_func_2 - } + // Vector distribution needs to happen on buffers. + %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2 + %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } + transform.iree.vector.warp_distribute %end_func_2 } - - diff --git a/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir index 3a922447ebad..67d896ad8eac 100644 --- a/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir @@ -46,7 +46,7 @@ transform.structured.canonicalized_sequence failures(propagate) { %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op %func = transform.structured.match ops{["func.func"]} in %variant_op_2 %func_2 = transform.iree.foreach_thread_to_workgroup %func - transform.iree.foreach_thread_to_gpu_and_translation_info %func_2 + transform.iree.map_nested_foreach_thread_to_gpu_threads %func_2 { workgroup_size = [32, 1, 1] } // Vector distribution needs to happen on buffers. diff --git a/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken new file mode 100644 index 000000000000..68f891ed38f1 --- /dev/null +++ b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken @@ -0,0 +1,57 @@ +// RUN: iree-opt %s + +// Codegen +transform.structured.canonicalized_sequence failures(propagate) { +// transform.sequence %arg0 failures(propagate) { +^bb1(%variant_op: !pdl.operation): + // First level of tiling + fusion parallelizes to blocks. + // The mapping to block ids can only happen after bufferization atm + %root = transform.structured.match interface{LinalgOp} + attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op + %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op + %red = transform.structured.match interface{LinalgOp} + attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op + %not_root = merge_handles %fill, %red + %foreach_thread, %tiled_generic = + transform.iree.tile_to_foreach_thread_and_workgroup_count_region %root tile_sizes [1, 1] + (mapped to dims [0, 1, 2]) + transform.structured.fuse_into_containing_op %not_root into %foreach_thread + + // Second level of tiling + fusion parallelizes to threads. + // Leaving the reduction untiled on threadIdx.x makes it sequential on + // threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is + // introduced and opportunities for distributing vector ops across warps + // appear. + %fill_linalg = transform.structured.match ops{["linalg.fill"]} in %variant_op + %reduction_linalg = transform.structured.match ops{["linalg.generic"]} + attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op + %not_root_2 = merge_handles %fill_linalg, %reduction_linalg + %parallel_linalg = transform.structured.match ops{["linalg.generic"]} + attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op + %foreach_thread_2, %parallel_linalg_2 = + transform.structured.tile_to_foreach_thread_op %parallel_linalg tile_sizes [1, 1, 0] + (mapped to dims [2, 1, 0]) + transform.structured.fuse_into_containing_op %not_root_2 into %foreach_thread_2 + + // Rank-reduce and vectorize. + %func = transform.structured.match ops{["func.func"]} in %variant_op + %funcx = transform.iree.apply_patterns %func { rank_reducing } + transform.structured.vectorize %funcx + + // Bufferization is necessary for: + // 1. lowering scf.foreach_thread to workgroup (block level parallelism) + // 2. lowering scf.foreach_thread to gpu (thread level parallelism) + // 3. introducing predication (due to 1. + 2.) which enables rewriting to + // warp_execute_on_lane_0 and later vector distribution. + %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op + %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2 + %func_3 = transform.iree.foreach_thread_to_workgroup %func_2 + transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3 + { workgroup_size = [32, 1, 1] } + + // Vector distribution needs to happen on buffers. + %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2 + %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2 + %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } + transform.iree.vector.warp_distribute %end_func +} diff --git a/third_party/llvm-project b/third_party/llvm-project index 0b77a6734c29..5925ec76d49c 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 0b77a6734c294d60c37617c4010c803e92191b65 +Subproject commit 5925ec76d49c29ea238c88a734dde6dd9feea102 diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo index 4fbf18d6d70a..b9ad408df1e5 160000 --- a/third_party/mlir-hlo +++ b/third_party/mlir-hlo @@ -1 +1 @@ -Subproject commit 4fbf18d6d70ac21941d15572618407a33affc93a +Subproject commit b9ad408df1e514e6b232863499f34fdc4bbc2160 diff --git a/tools/BUILD b/tools/BUILD index fdaa9456220c..725cffc2a556 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -158,6 +158,8 @@ cc_binary( "//runtime/src/iree/base:tracing", "//runtime/src/iree/base/internal:flags", "//runtime/src/iree/hal", + "//runtime/src/iree/modules/hal:types", + "//runtime/src/iree/tooling:comparison", "//runtime/src/iree/tooling:context_util", "//runtime/src/iree/tooling:vm_util", "//runtime/src/iree/vm", diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index ef37edb9a594..0a99dbf4315e 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -126,6 +126,8 @@ iree_cc_binary( iree::base::internal::flags iree::base::tracing iree::hal + iree::modules::hal::types + iree::tooling::comparison iree::tooling::context_util iree::tooling::vm_util iree::vm @@ -266,17 +268,24 @@ if(IREE_BUILD_COMPILER) HOSTONLY ) - # Ensure FileCheck gets built. Tests don't have dependencies in CMake because - # they aren't targets. So until we fix that, we just force this to get built. + # Ensure FileCheck and associated binaries get built. Tests don't have + # dependencies in CMake because they aren't targets. So until we fix that, we + # just force this to get built. # Limiting this to when IREE_BUILD_TESTS is set prevents the installation # below, which we use for cross-platform testing. set_target_properties(FileCheck PROPERTIES EXCLUDE_FROM_ALL OFF) + set_target_properties(not PROPERTIES EXCLUDE_FROM_ALL OFF) - # Bundle the FileCheck binary from LLVM into our tests/bin directory so - # installed FileCheck tests are hermetic. + # Bundle the FileCheck and associated binaries from LLVM into our tests/bin + # directory so installed FileCheck tests are hermetic. install( TARGETS FileCheck DESTINATION "tests/bin" COMPONENT Tests ) + install( + TARGETS not + DESTINATION "tests/bin" + COMPONENT Tests + ) endif(IREE_BUILD_COMPILER) diff --git a/tools/android/run_module_app/src/main.cc b/tools/android/run_module_app/src/main.cc index f4047daf711a..f8feeee830c3 100644 --- a/tools/android/run_module_app/src/main.cc +++ b/tools/android/run_module_app/src/main.cc @@ -153,11 +153,11 @@ Status RunModule(const IreeModuleInvocation& invocation) { iree_allocator_system()), "invoking function '%s'", function_name.c_str()); - std::ostringstream oss; - IREE_RETURN_IF_ERROR(PrintVariantList(outputs.get(), &oss), + std::string result; + IREE_RETURN_IF_ERROR(PrintVariantList(outputs.get(), &result), "printing results"); LOGI("Execution Result:"); - LOGI("%s", oss.str().c_str()); + LOGI("%s", result.c_str()); inputs.reset(); outputs.reset(); diff --git a/tools/iree-benchmark-module-main.cc b/tools/iree-benchmark-module-main.cc index f028523bd052..73cc4c3b6f5a 100644 --- a/tools/iree-benchmark-module-main.cc +++ b/tools/iree-benchmark-module-main.cc @@ -54,7 +54,6 @@ #include #include -#include #include #include #include @@ -432,7 +431,7 @@ int main(int argc, char** argv) { iree_status_t status = iree_benchmark.Register(); if (!iree_status_is_ok(status)) { int ret = static_cast(iree_status_code(status)); - std::cout << iree::Status(std::move(status)) << std::endl; + printf("%s\n", iree::Status(std::move(status)).ToString().c_str()); return ret; } ::benchmark::RunSpecifiedBenchmarks(); diff --git a/tools/iree-check-module-main.cc b/tools/iree-check-module-main.cc index 8c4dd88befb0..f7bf455a427f 100644 --- a/tools/iree-check-module-main.cc +++ b/tools/iree-check-module-main.cc @@ -6,7 +6,6 @@ #include #include -#include #include #include #include @@ -90,7 +89,7 @@ iree_status_t Run(std::string module_file_path, iree_allocator_t host_allocator, // iree_tooling_load_module_from_flags. iree_file_contents_t* flatbuffer_contents = NULL; if (module_file_path == "-") { - std::cout << "Reading module contents from stdin...\n"; + printf("Reading module contents from stdin...\n"); IREE_RETURN_IF_ERROR( iree_stdin_read_contents(host_allocator, &flatbuffer_contents)); } else { @@ -173,8 +172,8 @@ extern "C" int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); if (argc < 2) { - std::cerr - << "A binary module file path to run (or - for stdin) must be passed"; + fprintf(stderr, + "A binary module file path to run (or - for stdin) must be passed"); return -1; } auto module_file_path = std::string(argv[1]); @@ -185,16 +184,15 @@ extern "C" int main(int argc, char** argv) { int ret = iree_status_is_ok(status) ? exit_code : 1; if (FLAG_expect_failure) { if (ret == 0) { - std::cout << "Test passed but expected failure\n"; + printf("Test passed but expected failure\n"); return 1; } - std::cout << "Test failed as expected\n"; + printf("Test failed as expected\n"); return 0; } if (ret != 0) { - std::cout << "Test failed\n"; - std::cout << Status(std::move(status)); + printf("Test failed\n%s\n", Status(std::move(status)).ToString().c_str()); } return ret; diff --git a/tools/iree-run-mlir-main.cc b/tools/iree-run-mlir-main.cc index 62e18857bd99..7c0b465901d8 100644 --- a/tools/iree-run-mlir-main.cc +++ b/tools/iree-run-mlir-main.cc @@ -28,9 +28,9 @@ // used to separate the compiler flags from the runtime flags, such as: // iree-run-mlir --iree-hal-target-backends=vulkan-spirv -- --logtostderr +#include #include #include -#include #include #include #include @@ -238,7 +238,7 @@ Status PrepareModule(std::string target_backend, } // Translate from MLIR to IREE bytecode. - std::cout << "Compiling for target backend '" << target_backend << "'...\n"; + printf("Compiling for target backend '%s'...\n", target_backend.c_str()); mlir::PassManager pass_manager(mlir_module->getContext()); pass_manager.enableVerifier(verify_passes_flag); mlir::applyPassManagerCLOptions(pass_manager); @@ -281,7 +281,7 @@ Status PrepareModule(std::string target_backend, "serialization to annotated MLIR (text) failed"); } text_output.flush(); - std::cerr << text_contents << std::endl; + fprintf(stderr, "%s\n", text_contents.c_str()); } if (print_flatbuffer_flag) { bytecode_options.outputFormat = @@ -295,7 +295,7 @@ Status PrepareModule(std::string target_backend, "serialization to flatbuffer bytecode (text) failed"); } text_output.flush(); - std::cerr << text_contents << std::endl; + fprintf(stderr, "%s\n", text_contents.c_str()); } if (!output_file_flag.empty()) { if (llvm::writeToOutput( @@ -320,8 +320,7 @@ Status EvaluateFunction(iree_vm_context_t* context, iree_string_view_t function_name) { IREE_TRACE_SCOPE(); - std::cout << "EXEC @" << std::string(function_name.data, function_name.size) - << std::endl; + printf("EXEC @%.*s\n", (int)function_name.size, function_name.data); // Parse input values from the flags. vm::ref inputs; @@ -499,8 +498,8 @@ Status RunFile(const std::string& mlir_filename, llvm::Twine(split_line)); auto sub_failure = EvaluateFile(std::move(sub_buffer), registry); if (!sub_failure.ok()) { - std::cerr << "Failure for split at line #" << split_line << ": " - << sub_failure << "\n"; + fprintf(stderr, "Failure for split at line #%u: %s\n", split_line, + sub_failure.ToString().c_str()); if (any_failure.ok()) { any_failure = std::move(sub_failure); } @@ -558,8 +557,8 @@ extern "C" int main(int argc_llvm, char** argv_llvm) { auto status = RunFile(input_file_flag, registry); if (!status.ok()) { - std::cerr << "ERROR running file (" << input_file_flag << "):\n" - << status << "\n"; + fprintf(stderr, "ERROR running file (%s):\n%s\n", input_file_flag.c_str(), + status.ToString().c_str()); return 1; } return 0; diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc index b4eda1db5c6e..884c43e4c86d 100644 --- a/tools/iree-run-module-main.cc +++ b/tools/iree-run-module-main.cc @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include @@ -18,6 +17,8 @@ #include "iree/base/status_cc.h" #include "iree/base/tracing.h" #include "iree/hal/api.h" +#include "iree/modules/hal/types.h" +#include "iree/tooling/comparison.h" #include "iree/tooling/context_util.h" #include "iree/tooling/vm_util.h" #include "iree/vm/api.h" @@ -35,15 +36,15 @@ IREE_FLAG(bool, print_statistics, false, "Prints runtime statistics to stderr on exit."); // TODO(benvanik): move --function_input= flag into a util. -static iree_status_t parse_function_input(iree_string_view_t flag_name, - void* storage, - iree_string_view_t value) { +static iree_status_t parse_function_io(iree_string_view_t flag_name, + void* storage, + iree_string_view_t value) { auto* list = (std::vector*)storage; list->push_back(std::string(value.data, value.size)); return iree_ok_status(); } -static void print_function_input(iree_string_view_t flag_name, void* storage, - FILE* file) { +static void print_function_io(iree_string_view_t flag_name, void* storage, + FILE* file) { auto* list = (std::vector*)storage; if (list->empty()) { fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data); @@ -56,8 +57,7 @@ static void print_function_input(iree_string_view_t flag_name, void* storage, } static std::vector FLAG_function_inputs; IREE_FLAG_CALLBACK( - parse_function_input, print_function_input, &FLAG_function_inputs, - function_input, + parse_function_io, print_function_io, &FLAG_function_inputs, function_input, "An input (a) value or (b) buffer of the format:\n" " (a) scalar value\n" " value\n" @@ -74,10 +74,19 @@ IREE_FLAG_CALLBACK( "Each occurrence of the flag indicates an input in the order they were\n" "specified on the command line."); +static std::vector FLAG_expected_outputs; +IREE_FLAG_CALLBACK(parse_function_io, print_function_io, &FLAG_expected_outputs, + expected_output, + "An expected function output following the same format as " + "--function_input. When present the results of the " + "invocation will be compared against these values and the " + "tool will return non-zero if any differ. If the value of a " + "particular output is not of interest provide `(ignored)`."); + namespace iree { namespace { -iree_status_t Run() { +iree_status_t Run(int* out_exit_code) { IREE_TRACE_SCOPE0("iree-run-module"); iree_allocator_t host_allocator = iree_allocator_system(); @@ -122,16 +131,36 @@ iree_status_t Run() { IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr, 16, host_allocator, &outputs)); - std::cout << "EXEC @" << function_name << "\n"; + printf("EXEC @%s\n", function_name.c_str()); IREE_RETURN_IF_ERROR( iree_vm_invoke(context, function, IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/nullptr, inputs.get(), outputs.get(), host_allocator), "invoking function '%s'", function_name.c_str()); - IREE_RETURN_IF_ERROR( - PrintVariantList(outputs.get(), (size_t)FLAG_print_max_element_count), - "printing results"); + if (FLAG_expected_outputs.empty()) { + IREE_RETURN_IF_ERROR( + PrintVariantList(outputs.get(), (size_t)FLAG_print_max_element_count), + "printing results"); + } else { + // Parse expected list into host-local memory that we can easily access. + // Note that we return a status here as this can fail on user inputs. + vm::ref heap_allocator; + IREE_RETURN_IF_ERROR(iree_hal_allocator_create_heap( + IREE_SV("heap"), host_allocator, host_allocator, &heap_allocator)); + vm::ref expected_list; + IREE_RETURN_IF_ERROR(ParseToVariantList(heap_allocator.get(), + FLAG_expected_outputs, + host_allocator, &expected_list)); + + // Compare expected vs actual lists and output diffs. + bool did_match = iree_tooling_compare_variant_lists( + expected_list.get(), outputs.get(), host_allocator, stdout); + if (did_match) { + printf("[SUCCESS] all function outputs matched their expected values.\n"); + } + *out_exit_code = did_match ? EXIT_SUCCESS : EXIT_FAILURE; + } // Release resources before gathering statistics. inputs.reset(); @@ -157,20 +186,22 @@ extern "C" int main(int argc, char** argv) { if (argc > 1) { // Avoid iree-run-module spinning endlessly on stdin if the user uses single // dashes for flags. - std::cout << "Error: unexpected positional argument (expected none)." - " Did you use pass a flag with a single dash ('-')?" - " Use '--' instead.\n"; + printf( + "[ERROR] unexpected positional argument (expected none)." + " Did you use pass a flag with a single dash ('-')?" + " Use '--' instead.\n"); return 1; } - iree_status_t status = Run(); + int exit_code = EXIT_SUCCESS; + iree_status_t status = Run(&exit_code); if (!iree_status_is_ok(status)) { iree_status_fprint(stderr, status); iree_status_free(status); return EXIT_FAILURE; } - return EXIT_SUCCESS; + return exit_code; } } // namespace iree diff --git a/tools/test/BUILD b/tools/test/BUILD index f6d88a82b792..c54e2b125253 100644 --- a/tools/test/BUILD +++ b/tools/test/BUILD @@ -22,6 +22,7 @@ iree_lit_test_suite( "iree-benchmark-module.mlir", "iree-run-mlir.mlir", "iree-run-module.mlir", + "iree-run-module-expected.mlir", "multiple_args.mlir", "multiple_exported_functions.mlir", "null_values.mlir", @@ -41,6 +42,7 @@ iree_lit_test_suite( "//tools:iree-run-module", "@llvm-project//lld", "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", ], ) diff --git a/tools/test/CMakeLists.txt b/tools/test/CMakeLists.txt index e4075f289ed6..3a1e4184aae1 100644 --- a/tools/test/CMakeLists.txt +++ b/tools/test/CMakeLists.txt @@ -17,6 +17,7 @@ iree_lit_test_suite( "executable_benchmarks.mlir" "iree-benchmark-module.mlir" "iree-run-mlir.mlir" + "iree-run-module-expected.mlir" "iree-run-module.mlir" "multiple_args.mlir" "multiple_exported_functions.mlir" @@ -30,6 +31,7 @@ iree_lit_test_suite( iree-compile iree-run-mlir iree-run-module + not LABELS "hostonly" ) diff --git a/tools/test/iree-run-module-expected.mlir b/tools/test/iree-run-module-expected.mlir new file mode 100644 index 000000000000..f7647c8bb64f --- /dev/null +++ b/tools/test/iree-run-module-expected.mlir @@ -0,0 +1,20 @@ +// RUN: (iree-compile --iree-hal-target-backends=vmvx %s | iree-run-module --device=local-task --entry_function=abs --function_input=f32=-2 --expected_output=f32=-2 --expected_output=f32=2.0) | FileCheck %s --check-prefix=SUCCESS-MATCHES +// RUN: (iree-compile --iree-hal-target-backends=vmvx %s | iree-run-module --device=local-task --entry_function=abs --function_input=f32=-2 --expected_output=f32=-2 --expected_output="(ignored)") | FileCheck %s --check-prefix=SUCCESS-IGNORED +// RUN: (iree-compile --iree-hal-target-backends=vmvx %s | iree-run-module --device=local-task --entry_function=abs --function_input=f32=-2 --expected_output=f32=-2 --expected_output=f32=2.1 --expected_f32_threshold=0.1) | FileCheck %s --check-prefix=SUCCESS-THRESHOLD +// RUN: (iree-compile --iree-hal-target-backends=vmvx %s | not iree-run-module --device=local-task --entry_function=abs --function_input=f32=-2 --expected_output=f32=123 --expected_output=f32=2.0) | FileCheck %s --check-prefix=FAILED-FIRST +// RUN: (iree-compile --iree-hal-target-backends=vmvx %s | not iree-run-module --device=local-task --entry_function=abs --function_input=f32=-2 --expected_output=f32=-2 --expected_output=f32=4.5) | FileCheck %s --check-prefix=FAILED-SECOND +// RUN: (iree-compile --iree-hal-target-backends=vmvx %s | not iree-run-module --device=local-task --entry_function=abs --function_input=f32=-2 --expected_output=f32=-2 --expected_output=4xf32=2.0) | FileCheck %s --check-prefix=FAILED-SHAPE +// RUN: (iree-compile --iree-hal-target-backends=vmvx %s | not iree-run-module --device=local-task --entry_function=abs --function_input=f32=-2 --expected_output=f32=-2 --expected_output=8) | FileCheck %s --check-prefix=FAILED-TYPE + +// SUCCESS-MATCHES: [SUCCESS] +// SUCCESS-THRESHOLD: [SUCCESS] +// SUCCESS-IGNORED: [SUCCESS] +// FAILED-FIRST: [FAILED] result[0]: element at index 0 (-2) does not match the expected (123) +// FAILED-SECOND: [FAILED] result[1]: element at index 0 (2) does not match the expected (4.5) +// FAILED-SHAPE: [FAILED] result[1]: metadata is f32; expected that the view matches 4xf32 +// FAILED-TYPE: [FAILED] result[1]: variant types mismatch + +func.func @abs(%input: tensor) -> (tensor, tensor) { + %result = math.absf %input : tensor + return %input, %result : tensor, tensor +}