diff --git a/python/tvm/micro/testing/aot_test_utils.py b/python/tvm/micro/testing/aot_test_utils.py index 06cd0f1c9ea4..991a3f0ddb8e 100644 --- a/python/tvm/micro/testing/aot_test_utils.py +++ b/python/tvm/micro/testing/aot_test_utils.py @@ -65,6 +65,16 @@ }, ) +AOT_APROFILE_AEM_RUNNER = AOTTestRunner( + makefile="aprofile_aem", + includes=[], + pass_config={ + "tir.usmp.enable": False, + # AOT test infra generates 'fake' tensor inputs which fails asserts + "tir.disable_assert": True, + }, +) + def parametrize_aot_options(test): """Parametrize over valid option combinations""" diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 2fc148c3effd..9974d2691d4b 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -21,7 +21,9 @@ # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import import re +import tvm from tvm import relay, topi, tir +from tvm.tir.schedule.analysis import has_block from ....auto_scheduler import is_auto_scheduler_enabled from ....meta_schedule import is_meta_schedule_enabled @@ -639,7 +641,7 @@ def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target): def schedule_dense_arm_cpu(attrs, inputs, out_type, target): """dense arm cpu strategy""" strategy = _op.OpStrategy() - data, _ = inputs + data, weight = inputs if target.features.has_dsp and data.dtype in ["int8", "int16"]: strategy.add_implementation( @@ -680,6 +682,23 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target): plevel=11, ) + if ( + target.features.has_sme + and data.dtype in ["float32"] + and weight.dtype in ["float32"] + and out_type.dtype in ["float32"] + # The schedule uses tensorization which does not work when the + # reduction axis has unit iters. See + # https://github.com/apache/tvm/issues/16566 + and data.shape[1] > 1 + ): + strategy.add_implementation( + wrap_compute_dense(topi.arm_cpu.compute_matmul_sme), + lambda: None, + name="matmul.arm_cpu.sme", + plevel=12, + ) + # Fallback to x86 schedules as there is currently no arm_cpu schedule for dense strategy.add_implementation( wrap_compute_dense(topi.x86.dense_nopack), @@ -697,6 +716,40 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target): return strategy +@matmul_strategy.register("arm_cpu") +def matmul_strategy_arm_cpu(attrs, inputs, out_type, target): + """matmul arm cpu strategy""" + strategy = _op.OpStrategy() + data, weight = inputs + + if ( + target.features.has_sme + and data.dtype in ["float32"] + and weight.dtype in ["float32"] + and out_type.dtype in ["float32"] + and not (attrs.transpose_a or attrs.transpose_b) + and len(data.shape) == 2 + # The schedule uses tensorization which does not work when the + # reduction axis has unit iters. See + # https://github.com/apache/tvm/issues/16566 + and data.shape[1] > 1 + ): + # Ideally we should check that weight is a Relay constant, but strategy functions + # don't have access to the data needed to check this. + strategy.add_implementation( + wrap_compute_matmul(topi.arm_cpu.compute_matmul_sme), + lambda: None, + name="matmul.arm_cpu.sme", + ) + return strategy + + logger.warning("matmul is not optimized for arm cpu.") + strategy.add_implementation( + wrap_compute_matmul(topi.nn.matmul), naive_schedule, name="matmul.generic" + ) + return strategy + + @conv1d_strategy.register("arm_cpu") def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target): """conv1d strategy""" @@ -737,3 +790,17 @@ def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target): f"Unsupported kernel layout {kernel_layout} for conv1d {layout} for arm cpu." ) return strategy + + +def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool: + """ + Strategy for arm_cpu STIR schedules. + """ + current_target = tvm.target.Target.current() + + if current_target.features.has_sme and has_block(sch, "matmul_sme_gemm"): + topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch) + return True + + # Fallback to TE schedule for operators we have not written a special TIR schedule for + return False diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index d0ceee4aa2a0..1190616737ce 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1023,6 +1023,19 @@ def _corstone300_compile_time_check(): parent_features="cmsisnn", ) + +def _aprofile_aem_fvp_compile_time_check(): + if shutil.which("FVP_Base_RevC-2xAEMvA") is None: + return "AProfile AEM is not available" + return True + + +requires_aprofile_aem_fvp = Feature( + "aprofile-aem-fvp", + "AProfile AEM FVP", + compile_time_check=_aprofile_aem_fvp_compile_time_check, +) + # Mark a test as requiring Vitis AI to run requires_vitis_ai = Feature("vitis_ai", "Vitis AI", cmake_flag="USE_VITIS_AI") @@ -1205,6 +1218,10 @@ def decorator(*args): return decorator +def skip_if_no_reference_system(func): + return skip_if_32bit(reason="Reference system unavailable in i386 container")(func) + + def requires_package(*packages): """Mark a test as requiring python packages to run. diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py index 7e5a26bdeb43..d127335e82a6 100644 --- a/python/tvm/tir/tensor_intrin/__init__.py +++ b/python/tvm/tir/tensor_intrin/__init__.py @@ -16,4 +16,3 @@ # under the License. # pylint: disable=unused-import """Intrinsics for tensorization.""" -from . import arm_cpu, cuda, rocm, x86, hexagon diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py index a5003d41a8d1..90af1e05b172 100644 --- a/python/tvm/tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -17,6 +17,10 @@ # pylint: disable=invalid-name,missing-function-docstring,unused-import """Intrinsics for ARM tensorization.""" from tvm.script import tir as T +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder.tir import prim_func as build_prim_func +from tvm.target.codegen import llvm_version_major + from .. import TensorIntrin from .dot_product_common import ( DP4A_S8S8S32_INTRIN, @@ -163,15 +167,367 @@ def dot_prod_impl(a: T.handle, b: T.handle, c: T.handle) -> None: return dot_prod_desc, dot_prod_impl +def get_sme_transpose_interleave_2svlx2svl_intrin(): + """ + Transpose a matrix of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length) using + the Scalable Matrix Extension (SME). + + This is completed by loading rows of the input matrix into the accumulator tile, + then storing the columns. The SME accumulator tile is divided into a series of sub-tiles + which must be loaded to / stored from independently. + + Note: currently only supports the fp32 datatype. + + Example + ------- + An example case for float32. In this instance the accumulator tile is divided into 4 + sub-tiles of size SVLxSVL numbered 0-3. We start by loading rows of A, each SVL in length, + into each of the sub-tiles. In the diagram below, each load for a sub-tile is sequenced by + a, b, ... till the tile is full. + + The columns of each sub-tile are then stored into A_t. Note that to perform a transpose, + the contents of sub-tile 1 and 2 are stored in opposite locations - see the diagram + below. + + A: Accumulator tile: A_t: + 2SVL 2SVL 2SVL + +----------------+ +-----------------+ +-------------------+ + | --0a-- --1a-- | | | | | | | | | + | --0b-- --1b-- | | 0 1 | | 0a 0b .. 2a 2b .. | + | ... ... | ld1w.horiz | | st1w.vert | | | | | | + 2SVL | --2a-- --3a-- | ====> 2SVL | | ====> 2SVL | | | | | | + | --2a-- --3b-- | | 2 3 | | 1a 1b .. 3a 3b .. | + | ... ... | | | | | | | | | + +----------------+ +-----------------+ +-------------------+ + + Returns + ------- + intrin : TensorIntrin + The SME TensorIntrin that can be used in tensorizing a schedule. + + """ + SVF = 4 * T.vscale() + SVF2 = 2 * SVF + + @T.prim_func + def desc(a: T.handle, a_t: T.handle) -> None: + A = T.match_buffer(a, (SVF2, SVF2), dtype="float32", offset_factor=1) + A_t = T.match_buffer(a_t, (SVF2, SVF2), dtype="float32", offset_factor=1) + with T.block("root"): + T.reads(A[0:SVF2, 0:SVF2]) + T.writes(A_t[0:SVF2, 0:SVF2]) + for k, m in T.grid(SVF2, SVF2): + with T.block("transpose"): + v_m, v_k = T.axis.remap("SS", [m, k]) + A_t[v_k, v_m] = A[v_m, v_k] + + def impl(): + # Accumulation sub-tile count. For fp32 it is 4 + sub_tile_count = 4 + + with IRBuilder() as ib: + with build_prim_func(): + a = T.arg("a", T.handle()) + a_t = T.arg("a_t", T.handle()) + + A = T.match_buffer( + a, (SVF2, SVF2), "float32", offset_factor=1, strides=[T.int32(), 1] + ) + A_t = T.match_buffer( + a_t, + (SVF2, SVF2), + "float32", + offset_factor=1, + strides=[T.int32(), 1], + ) + + # Disable predication + ptrue = T.broadcast(T.IntImm("int1", 1), T.vscale() * 4) + + with T.block("root"): + T.reads(A[0:SVF2, 0:SVF2]) + T.writes(A_t[0:SVF2, 0:SVF2]) + + # Load rows of the input matrix + with T.serial(0, SVF) as slice_idx: + for sub_tile_idx in range(0, sub_tile_count): + row_offset = SVF if sub_tile_idx >= (sub_tile_count // 2) else 0 + col_offset = SVF if sub_tile_idx % 2 else 0 + offset = (slice_idx + row_offset) * A.strides[0] + col_offset + + input_ptr = A.access_ptr("r", offset=offset) + sub_tile = T.int32(sub_tile_idx) + T.evaluate( + T.call_llvm_intrin( + "void", + "llvm.aarch64.sme.ld1w.horiz", + T.uint32(4), + ptrue, + input_ptr, + sub_tile, + slice_idx, + ) + ) + + # Store columns to the ouptut matrix + with T.serial(0, SVF) as slice_idx: + for sub_tile_idx in range(0, sub_tile_count): + col_offset = SVF if sub_tile_idx >= (sub_tile_count // 2) else 0 + row_offset = SVF if sub_tile_idx % 2 else 0 + offset = (slice_idx + row_offset) * A_t.strides[0] + col_offset + + output_ptr = A_t.access_ptr("w", offset=offset) + sub_tile = T.int32(sub_tile_idx) + T.evaluate( + T.call_llvm_intrin( + "void", + "llvm.aarch64.sme.st1w.vert", + T.uint32(4), + ptrue, + output_ptr, + sub_tile, + slice_idx, + ) + ) + + return ib.get() + + return desc, impl() + + +def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K): + """ + Compute a GEMM of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length using + outer product operations from the Scalable Matrix Extension (SME). + + The inputs A and B are expected to be of size K x 2SVL and produce a result C of + size 2SVL x 2SVL. + + The SME accumulator tile is divided into sub-tiles, each of which is utilized to + calculate the outer-product using columns / rows of A and B respectively. For each + sub-tile, elements in the first column of input matrix A (accessed sequentially due + to being transpose-interleaved) and first row of input matrix B are used to calculate + an outer-product. This is then accumulated with the result of performing an + outer-product on the second column and row of A and B respectively. This process is + repeated K times. Finally, the results of the accumulation are stored. + + Note: The input tensor 'A' must be transpose-interleaved. + Note: Currently only supports the fp32 datatype. + + Example + ------- + + Diagram showing outer-product performed on each of the accumulator sub-tiles + for the fp32 datatype: + + SVL SVL + +----------------------------+ + | l | h | K + K +----------------------------+ + +---+ +----------------------------+ + | | | 0: 1: |-+ + | | | mopa(l, l) mopa(l, h) | |-+ + l | | | | | | + | | | | | | + |---| | | | | + | | | 2: 3: | | | + h | | | mopa(h, l) mopa(h, h) | | | + | | | | | | + | | | | | | + +---+ +----------------------------+ | | + +----------------------------+ | + +---------------------------+ + (accumulate K times) + + Pseudo code computing 2SVL x 2SVL GEMM for fp32 inputs: + + .. code-block:: c + + // Number of fp32 elements in a scalable vector + int SVF = SVL / 32; + + // Reset the accumulator tile + sme.zero(); + + // Calculate outer products and accumulate + for (k = 0; k < K; k++) { + float32xSVF A_row_0 = A[k][0]; + float32xSVF A_row_1 = A[k][SVF]; + float32xSVF B_row_0 = B[k][0]; + float32xSVF B_row_1 = B[k][SVF]; + + float32xSVFxSVF sub_tile_0 += sme.mopa(A_row_0, B_row_0); + float32xSVFxSVF sub_tile_1 += sme.mopa(A_row_0, B_row_1); + float32xSVFxSVF sub_tile_2 += sme.mopa(A_row_1, B_row_0); + float32xSVFxSVF sub_tile_3 += sme.mopa(A_row_1, B_row_1); + } + + // Store the results of accumulation + for (i = 0; i < SVF; i++) { + C[i][0] = sme.horiz(sub_tile_0[i]); + C[i][0] = sme.horiz(sub_tile_0[i + SVF]); + C[i + SVF][0] = sme.horiz(sub_tile_0[i]); + C[i + SVF][0] = sme.horiz(sub_tile_0[i + SVF]); + } + + Notes: + - Recall that A has been transposed beforehand such that each column is now accessed + by row. + - 'sme.zero' resets the accumulator tile to contain all zero's. + - 'sme.mopa' is the outer product and accumulate intrinsic. + - 'sme.horiz' stores rows of an accumulator sub-tile to memory. + + Returns + ------- + intrin : TensorIntrin + The SME TensorIntrin that can be used in tensorizing a schedule. + + """ + SVF = 4 * T.vscale() + SVF2 = 2 * SVF + + @T.prim_func + def desc(a: T.handle, b: T.handle, c: T.handle): + A = T.match_buffer(a, (K, SVF2), dtype="float32", offset_factor=1) + B = T.match_buffer(b, (K, SVF2), dtype="float32", offset_factor=1) + C = T.match_buffer(c, (SVF2, SVF2), dtype="float32", offset_factor=1) + + with T.block("root"): + T.reads(C[0:SVF2, 0:SVF2], A[0:K, 0:SVF2], B[0:K, 0:SVF2]) + T.writes(C[0:SVF2, 0:SVF2]) + for m, n, k in T.grid(SVF2, SVF2, K): + with T.block("gemm"): + v_m, v_n, v_k = T.axis.remap("SSR", [m, n, k]) + C[v_m, v_n] += A[v_k, v_m] * B[v_k, v_n] + + def impl(): + # Accumulation sub-tile count. For fp32 it is 4 + sub_tile_count = 4 + + with IRBuilder() as ib: + with build_prim_func(): + a = T.arg("a", T.handle()) + b = T.arg("b", T.handle()) + c = T.arg("c", T.handle()) + + A = T.match_buffer(a, (K, SVF2), "float32", offset_factor=1, strides=[T.int32(), 1]) + B = T.match_buffer(b, (K, SVF2), "float32", offset_factor=1, strides=[T.int32(), 1]) + C = T.match_buffer( + c, (SVF2, SVF2), "float32", offset_factor=1, strides=[T.int32(), 1] + ) + + ptrue = T.broadcast(T.IntImm("int1", 1), T.vscale() * 4) + + with T.block("root"): + T.reads(C[0:SVF2, 0:SVF2], A[0:K, 0:SVF2], B[0:K, 0:SVF2]) + T.writes(C[0:SVF2, 0:SVF2]) + + # Iterate over the reduction axis applying outer product and accumulate + with T.serial(K) as k: + a_low = T.BufferLoad(A, [k, T.Ramp(0, 1, T.vscale() * 4)]) + a_high = T.BufferLoad(A, [k, T.Ramp(SVF, 1, T.vscale() * 4)]) + b_low = T.BufferLoad(B, [k, T.Ramp(0, 1, T.vscale() * 4)]) + b_high = T.BufferLoad(B, [k, T.Ramp(SVF, 1, T.vscale() * 4)]) + + input_combinations = [ + (a_low, b_low), + (a_low, b_high), + (a_high, b_low), + (a_high, b_high), + ] + for sub_tile_idx in range(0, sub_tile_count): + sub_tile = T.int32(sub_tile_idx) + input_1 = input_combinations[sub_tile_idx][0] + input_2 = input_combinations[sub_tile_idx][1] + + T.evaluate( + T.call_llvm_intrin( + "void", + "llvm.aarch64.sme.mopa.nxv4f32", + T.uint32(5), + sub_tile, + ptrue, + ptrue, + input_1, + input_2, + ) + ) + + # Store the accumulated tile results + with T.serial(SVF) as slice_idx: + for sub_tile_idx in range(sub_tile_count): + vert_offset = SVF if sub_tile_idx >= (sub_tile_count // 2) else 0 + horiz_offset = SVF if sub_tile_idx % 2 else 0 + local_offset = (slice_idx + vert_offset) * C.strides[0] + horiz_offset + output_ptr = C.access_ptr("w", offset=local_offset, extent=SVF) + + T.evaluate( + T.call_llvm_intrin( + "void", + "llvm.aarch64.sme.st1w.horiz", + T.uint32(4), + ptrue, + output_ptr, + T.int32(sub_tile_idx), + T.int32(slice_idx), + ) + ) + + return ib.get() + + return desc, impl() + + +def get_sme_init_intrin(): + """ + Reset the entire matrix tile storage to 0. + """ + SVF2 = 2 * 4 * T.vscale() + + @T.prim_func + def desc(c: T.handle) -> None: + C = T.match_buffer(c, (SVF2, SVF2), "float32", offset_factor=1) + with T.block("root"): + T.reads() + T.writes(C[0:SVF2, 0:SVF2]) + for m, n in T.grid(SVF2, SVF2): + with T.block("init"): + v_m, v_n = T.axis.remap("SS", [m, n]) + C[v_m, v_n] = T.float32(0) + + @T.prim_func + def impl(c: T.handle) -> None: + C = T.match_buffer(c, (SVF2, SVF2), "float32", offset_factor=1) + with T.block("root"): + T.reads() + T.writes(C[0:SVF2, 0:SVF2]) + clear_all_tiles = T.int32(255) + T.evaluate( + T.call_llvm_intrin("void", "llvm.aarch64.sme.zero", T.uint32(1), clear_all_tiles) + ) + + return desc, impl + + ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon" ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot" ARM_DOT_4x4_u8_UDOT_INTRIN = "dot_4x4_u8u8u32_udot" ARM_DOT_4x4_u8_HDOT_INTRIN = "dot_4x4_u8u8i32_hdot" TensorIntrin.register(ARM_DOT_4x4_i8_NEON_INTRIN, neon_4x4_i8i8i32_desc, neon_4x4_i8i8i32_impl) - TensorIntrin.register(ARM_DOT_4x4_i8_SDOT_INTRIN, *get_dotprod_intrin("int8", "int32")) - TensorIntrin.register(ARM_DOT_4x4_u8_UDOT_INTRIN, *get_dotprod_intrin("uint8", "uint32")) - TensorIntrin.register(ARM_DOT_4x4_u8_HDOT_INTRIN, *get_dotprod_intrin("uint8", "int32")) + +ARM_SME_INIT = "sme_init" +ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE = "sme_2svlx2svl_transpose_interleave" +ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA = "sme_2svlx2svl_gemm_interleaved_mopa" + +# The following tensor intrinsics use LLVM intrinsics that are only available +# in versions of LLVM >= 15. Installations with older versions of LLVM will +# not be able to use them. +if llvm_version_major() >= 15: + TensorIntrin.register( + ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE, *get_sme_transpose_interleave_2svlx2svl_intrin() + ) + TensorIntrin.register(ARM_SME_INIT, *get_sme_init_intrin()) diff --git a/python/tvm/topi/arm_cpu/__init__.py b/python/tvm/topi/arm_cpu/__init__.py index 054103f43bef..5484adaa6409 100644 --- a/python/tvm/topi/arm_cpu/__init__.py +++ b/python/tvm/topi/arm_cpu/__init__.py @@ -22,13 +22,16 @@ from .depthwise_conv2d import * from .conv2d_transpose import * from .conv2d_int8 import * -from . import conv2d_alter_op from .bitserial_conv2d import * from .bitserial_dense import * from .injective import * from .group_conv2d import * from .pooling import * from .dense import * +from .matmul import * from .qnn import * + +from . import conv2d_alter_op +from . import dense_alter_op from . import qnn_alter_op from . import qnn_legalize diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py index c350b87167b2..f2e01c5aefd6 100644 --- a/python/tvm/topi/arm_cpu/arm_utils.py +++ b/python/tvm/topi/arm_cpu/arm_utils.py @@ -19,6 +19,7 @@ import tvm from tvm.target import Target +from tvm.tir.expr import PrimExpr def get_tiling_A(interleave_A, in_dtype): @@ -186,6 +187,31 @@ def get_conv2d_im2col_padding(M, K, tile_M, tile_K): return pad_M, pad_K +def pad_dim_to_multiple(dim: PrimExpr, multiple: PrimExpr): + """ + Compute the padding required to reach specified multiple. + + Parameters + ---------- + dim : PrimExpr + Current size of the dim. + multiple : PrimExpr + Multiple to pad up to. + + Returns + ------- + padded_dim : PrimExpr + The new dim size. + pad_value : PrimExpr + The padding required. + """ + pad_value = 0 + if dim % multiple != 0: + pad_value = multiple - (dim % multiple) + padded_dim = dim + pad_value + return padded_dim, pad_value + + def get_conv2d_weights_padding(N, K, tile_N, tile_K): """Compute the necessary padding for matrix B', where B' is the transformed version of matrix B in C=A*B. diff --git a/python/tvm/topi/arm_cpu/dense.py b/python/tvm/topi/arm_cpu/dense.py index dd66b0d531bc..6a44cc89b0a6 100644 --- a/python/tvm/topi/arm_cpu/dense.py +++ b/python/tvm/topi/arm_cpu/dense.py @@ -14,16 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel """Dense schedule for ARM CPU""" - from tvm import autotvm -from .mprofile.dsp.dense import dense_dsp_schedule, dense_dsp_compute + +from .mprofile.dsp.dense import ( + dense_dsp_schedule, + dense_dsp_compute, +) @autotvm.register_topi_compute("dense_dsp.arm_cpu") def dense_dsp(cfg, data, weight, bias, out_dtype): - """Compute conv2d_nhwc with v7e-m DSP instructions.""" + """Compute dense_dsp with v7e-m DSP instructions.""" return dense_dsp_compute(cfg, data, weight, bias=bias, out_dtype=out_dtype) diff --git a/python/tvm/topi/arm_cpu/dense_alter_op.py b/python/tvm/topi/arm_cpu/dense_alter_op.py new file mode 100644 index 000000000000..208b923e68e4 --- /dev/null +++ b/python/tvm/topi/arm_cpu/dense_alter_op.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Dense alter op definitions for the `arm_cpu` device key.""" + +import tvm +from tvm import relay +from tvm import autotvm +from tvm import te + +from ..nn import dense_alter_layout + + +@dense_alter_layout.register("arm_cpu") +def _alter_dense(attrs, inputs, tinfos, out_type): + target = tvm.target.Target.current(allow_none=False) + dispatch_ctx = autotvm.task.DispatchContext.current + + _, outs = relay.backend.te_compiler.select_implementation( + relay.op.get("nn.dense"), + attrs, + tinfos, + out_type, + target, + ) + workload = autotvm.task.get_workload(outs) + if workload is None: + # The best implementation is not an AutoTVM template, + # we then assume it's not necessary to alter this op. + return None + + cfg = dispatch_ctx.query(target, workload) + topi_impl = workload[0] + if topi_impl == "matmul.arm_cpu.sme": + # Pre-compute transposed weights and convert to a matmul + assert isinstance( + inputs[1], relay.Constant + ), "matmul_sme.arm_cpu requires weights be a Relay Constant" + + weight_dtype = tinfos[1].dtype + weight_data = inputs[1].data.numpy() + interleaved = weight_data.transpose() + encoded_weight = relay.const(interleaved, weight_dtype) + + new_weight = te.placeholder((weight_data.shape), dtype=weight_dtype) + new_workload = autotvm.task.args_to_workload( + [tinfos[0], new_weight, None, out_type.dtype], topi_impl + ) + dispatch_ctx.update(target, new_workload, cfg) + + return relay.nn.matmul( + inputs[0], + encoded_weight, + units=attrs.units, + out_dtype=attrs.out_dtype, + transpose_a=False, + transpose_b=False, + ) + + # x86 schedules are used as a fallback + return tvm.topi.x86.dense_alter_op._alter_dense_layout(attrs, inputs, tinfos, out_type) diff --git a/python/tvm/topi/arm_cpu/matmul.py b/python/tvm/topi/arm_cpu/matmul.py new file mode 100644 index 000000000000..ea8b27cabcf6 --- /dev/null +++ b/python/tvm/topi/arm_cpu/matmul.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument + +"""Matmul schedules for the `arm_cpu` device key.""" + +import tvm +from tvm import te +from tvm import autotvm +from tvm.script import tir as T +from tvm.topi import nn +from tvm.topi.utils import get_const_tuple +from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes +from tvm.topi.arm_cpu.arm_utils import pad_dim_to_multiple + + +@autotvm.register_topi_compute("matmul.arm_cpu.sme") +def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, transpose_a=False, transpose_b=False): + """ + SME Matmul compute definition. + """ + assert ( + transpose_a == transpose_b == False + ), "Compute definition currently does not support transposed inputs." + + M, K = get_const_tuple(data_a.shape) + N = get_const_tuple(data_b.shape)[1] + + if not out_dtype: + out_dtype = data_a.dtype + + tile_m = 2 * 4 * tvm.tir.vscale() + tile_n = 2 * 4 * tvm.tir.vscale() + + M_padded, pad_M = pad_dim_to_multiple(M, tile_m) + N_padded, pad_N = pad_dim_to_multiple(N, tile_n) + if pad_M != 0: + data_a = nn.pad(data_a, pad_before=(0, 0), pad_after=(pad_M, 0)) + if pad_N != 0: + data_b = nn.pad(data_b, pad_before=(0, 0), pad_after=(0, pad_N)) + + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M_padded, N_padded), + lambda m, n: te.sum( + data_a[m, k].astype(data_a.dtype) * data_b[k, n].astype(data_b.dtype), + axis=k, + ).astype(out_dtype), + name="matmul_sme_gemm", + ) + C = te.compute((M, N), lambda m, n: C[m, n]) + return C + + +def tir_schedule_matmul_sme(sch): + """ + SME STIR Matmul schedule. + """ + # pylint: disable=import-outside-toplevel + from tvm.tir.tensor_intrin.arm_cpu import ( + ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE, + ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA, + ARM_SME_INIT, + get_sme_gemm_interleaved_mopa_2svlx2svl_intrin, + ) + + gemm_block = sch.get_block("matmul_sme_gemm") + m, n, k = sch.get_loops(gemm_block) + + extent_m = sch.get(m).extent + extent_k = sch.get(k).extent + + tile_m = T.cast(2 * 4 * T.vscale(), extent_m.dtype) + tile_k = T.cast(2 * 4 * T.vscale(), extent_k.dtype) + tile_n = T.cast(2 * 4 * T.vscale(), sch.get(n).extent.dtype) + + # Interleave the input utilizing the matrix tile + interleave_a_block = sch.cache_read(gemm_block, 0, "global") + sch.transform_layout(interleave_a_block, ("write", 0), lambda m, k: (k, m)) + m, k = sch.get_loops(interleave_a_block) + outer_m, inner_m = sch.split(m, factors=(None, tile_m), disable_predication=True) + outer_k, inner_k = sch.split(k, factors=(None, tile_k), disable_predication=True) + sch.reorder(outer_k, outer_m, inner_k, inner_m) + sch.tensorize(inner_k, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE) + + # Split and reorder the loops of the GeMM for tensorization + m, n, k = sch.get_loops(gemm_block) + outer_m, inner_m = sch.split(m, factors=(None, tile_m), disable_predication=True) + outer_n, inner_n = sch.split(n, factors=(None, tile_n), disable_predication=True) + sch.reorder(outer_m, outer_n, inner_m, inner_n, k) + + # Tensorize the GeMM initialization + init_block = sch.decompose_reduction(gemm_block, inner_m) + sch.tensorize(sch.get_loops(init_block)[-2], ARM_SME_INIT) + + # Tensorize the GeMM update + sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{extent_k}" + tvm.tir.TensorIntrin.register( + sme_gemm_interleaved_intrin_name, + *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(extent_k), + override=True, + ) + sch.tensorize(inner_m, sme_gemm_interleaved_intrin_name) + + # Add pstate annotations + root_block = sch.get_block("root") + sch.annotate( + root_block, SMEAttributes.STREAMING_MODE, SMEAttributes.StreamingModeValues.ENABLED + ) + sch.annotate(root_block, SMEAttributes.ZA_STORAGE, SMEAttributes.ZAStorageValues.NEW) diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 0e9b1f7b65f0..10b1248c6a3a 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -39,7 +39,7 @@ def check_int8_applicable(x, y, allow_padding=False): ) -@dense_alter_layout.register(["cpu", "arm_cpu"]) +@dense_alter_layout.register(["cpu"]) def _alter_dense_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index b82fff218f68..bf3028b7b3e8 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -371,7 +371,7 @@ class ConstIntBoundAnalyzer::Impl } else if (op->op.same_as(tir::builtin::bitwise_and())) { return VisitBitwiseAnd(op); } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE()) { - return MakeBound(1, 16); + return MakeBound(1, kAArch64VScaleValues.size()); } else { return Everything(op->dtype); } diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index b747855bff59..2655cf66719c 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -476,12 +476,10 @@ class ScheduleBuilder : public ExprVisitor { mod_eq_structural_(meta_schedule::ModuleEquality::Create("ignore-ndarray")) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); + database_ = meta_schedule::Database::Current(); if (backend::IsMetaScheduleEnabled()) { - database_ = meta_schedule::Database::Current(); CHECK(database_.defined()) << "ValueError: `use_meta_schedule` is enabled in Relay " "build, but no `meta_schedule.Database` context is provided. "; - } else { - database_ = NullOpt; } } diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 9e2fe63b006a..ccc973485529 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -193,6 +193,7 @@ RELAY_REGISTER_OP("nn.matmul") .add_argument("tensor_a", "nD Tensor", "The first input Tensor.") .add_argument("tensor_b", "2D Tensor", "The second input Tensor.") .set_support_level(1) + .set_attr("FInferCorrectLayout", DenseInferCorrectLayout) .add_type_rel("Matmul", MatmulRel) .set_attr("TOpPattern", kOutEWiseFusable); diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 00e573eaf6e4..a97cda266f53 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -18,6 +18,8 @@ */ #include "./ir_comparator.h" +#include "../../arith/scalable_expression.h" + namespace tvm { namespace tir { @@ -74,7 +76,9 @@ bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) { bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { bool equal = n.same_as(other) || ((n->type_index() == other->type_index()) && - n.dtype().code() == other.dtype().code() && ExprComparator::VisitExpr(n, other)); + n.dtype().code() == other.dtype().code() && ExprComparator::VisitExpr(n, other)) || + (tvm::arith::ContainsVscaleCall(n) && analyzer_.CanProveEqual(n, other)); + if (!equal && assert_mode_) { std::ostringstream os; os << "Expression mismatch: " << n << " vs " << other; diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 8f22ba5b73ed..06bbcac57872 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -15,15 +15,17 @@ # specific language governing permissions and limitations # under the License. -import re +""" +Codegen tests for AArch64 +""" +import re import pytest import tvm from tvm import te from tvm.script import tir as T from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes - from tvm.target.codegen import llvm_version_major @@ -496,6 +498,46 @@ def main(A: T.Buffer((5,), "int32")): assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM." +@pytest.mark.skipif( + llvm_version_major() < 16, reason="SME is not supported in earlier versions of LLVM" +) +@pytest.mark.parametrize("dtype", ["float32"]) +def test_matmul_sme(dtype): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+v9a,+sme" + + def check_correct_assembly(dtype): + A = te.placeholder((32, 32), dtype=dtype, name="A") + B = te.placeholder((32, 32), dtype=dtype, name="B") + + with tvm.target.Target(target): + C = tvm.topi.arm_cpu.matmul.compute_matmul_sme(A, B, None, dtype, False, False) + prim_func = te.create_prim_func([A, B, C]) + + sch = tvm.tir.Schedule(prim_func) + tvm.topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch) + prim_func = sch.mod + + f = tvm.build(prim_func, target=target) + + assembly = f.get_source("asm") + smstart = re.findall(r"smstart\t(sm|za)", assembly) + loads = re.findall(r"ld1[whdb]\t{\s?za", assembly) + mopa = re.findall( + r"fmopa\tza[0-9].[shdb],( p[0-9]/[zm],)?( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", + assembly, + ) + stores = re.findall(r"st1[whdb]\t{\s?za", assembly) + smstop = re.findall(r"smstop\t(sm|za)", assembly) + + assert len(smstart) > 0 + assert len(loads) > 0 + assert len(mopa) > 0 + assert len(stores) > 0 + assert len(smstop) > 0 + + check_correct_assembly(dtype=dtype) + + @pytest.mark.skipif( llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" ) diff --git a/tests/python/integration/test_arm_aprofile.py b/tests/python/integration/test_arm_aprofile.py index af35a1429735..d32fed00afe8 100644 --- a/tests/python/integration/test_arm_aprofile.py +++ b/tests/python/integration/test_arm_aprofile.py @@ -16,7 +16,6 @@ # under the License. """Tests for Arm(R) A-Profile Architecture.""" import os -import subprocess import numpy as np import pytest @@ -26,8 +25,6 @@ from tvm import relay from tvm.relay.transform import ToMixedPrecision, FoldConstant from tvm.relay.build_module import bind_params_by_name -from tvm.testing.aot import AOTTestModel, AOTTestRunner, generate_ref_data, compile_and_run -from tvm.contrib import utils def get_mattr(dtype): @@ -80,96 +77,5 @@ def test_conv2d(dtype): lib.export_library(lib_path, cc="aarch64-linux-gnu-gcc") -# AOT Test Runner using the AArch64 Architecture Envelope Model (AEM) -# Fixed Virtual Platform (FVP) reference system. -# See: https://developer.arm.com/Tools%20and%20Software/Fixed%20Virtual%20Platforms -AOT_APROFILE_AEM_RUNNER = AOTTestRunner( - makefile="aprofile_aem", - pass_config={ - "tir.usmp.enable": False, - "tir.disable_assert": True, # AOT test infra creates 'fake' inputs that fail asserts - }, -) - - -@tvm.testing.requires_x86 -@tvm.testing.skip_if_32bit -def test_aem_simple_addition(): - """Tests a simple addition running on the AArch64 AEM.""" - inp = relay.var("data", shape=(1, 2, 4, 4)) - add = relay.add(inp, relay.const(np.ones((1, 2, 4, 4)))) - func = relay.Function([inp], add) - ir_mod = tvm.IRModule.from_expr(func) - ir_mod = tvm.relay.transform.InferType()(ir_mod) - - main_func = ir_mod["main"] - shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params} - type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params} - - input_data = np.random.uniform(size=shape_dict["data"]).astype(type_dict["data"]) - params = {} - inputs = {"data": input_data} - ref_outputs = generate_ref_data(ir_mod, inputs, params) - - compile_and_run( - AOTTestModel(module=ir_mod, inputs=inputs, outputs=ref_outputs, params=params), - target=tvm.target.Target("llvm -mtriple=aarch64-none-elf"), - runtime=tvm.relay.backend.Runtime("crt", {"system-lib": True}), - interface_api="packed", - use_unpacked_api=False, - runner=AOT_APROFILE_AEM_RUNNER, - ) - - -@tvm.testing.requires_x86 -@tvm.testing.skip_if_32bit -def test_aem_asm_sme(): - """ - Tests SME assembly runs on the AArch64 AEM. This test is used as a simple - sanity check until the TVM schedules are able to produce SME. - """ - c_code = """ - #include - - int main(void) { - __asm volatile( - "smstart\\n" - "smstop\\n" - ); - printf("EXITTHESIM\\n"); - return 0; - } - """ - runner = AOT_APROFILE_AEM_RUNNER - - tmpdir = utils.tempdir() - build_path = os.path.join(tmpdir.path, "build") - os.makedirs(build_path, exist_ok=True) - - with open(build_path + "/test.c", "w") as f: - f.write(c_code) - - file_dir = os.path.dirname(os.path.abspath(__file__)) - makefile_dir = os.path.join(file_dir, "../../../tests/python/relay/aot") - makefile = os.path.join(makefile_dir, f"{runner.makefile}.mk") - - make_command = ( - f"make -f {makefile} build_dir={build_path}" - + f" TVM_ROOT={file_dir}/../../.." - + f" AOT_TEST_ROOT={makefile_dir}" - + " FVP_DIR=/opt/arm/fvp/Base_RevC_AEMvA_pkg/models/Linux64_GCC-9.3/" - ) - - compile_command = f"{make_command} aot_test_runner" - popen = subprocess.Popen(compile_command, cwd=build_path, shell=True, stdout=subprocess.PIPE) - return_code = popen.wait() - assert not return_code, "Failed to compile" - - run_command = f"{make_command} run" - popen = subprocess.Popen(run_command, cwd=build_path, shell=True, stdout=subprocess.PIPE) - return_code = popen.wait() - assert not return_code, "Failed to run" - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py index 8cc1c7c7aa44..1272b35451f9 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py @@ -18,7 +18,7 @@ import tvm from tvm import meta_schedule as ms from tvm.script import tir as T -from tvm.tir.tensor_intrin import arm_cpu, cuda, rocm, x86 +from tvm.tir.tensor_intrin import cuda, rocm, x86 @tvm.script.ir_module diff --git a/tests/python/relay/strategy/arm_cpu/scalable_utils.py b/tests/python/relay/strategy/arm_cpu/scalable_utils.py new file mode 100644 index 000000000000..ad16a47612d0 --- /dev/null +++ b/tests/python/relay/strategy/arm_cpu/scalable_utils.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.tir.stmt_functor import post_order_visit, ir_transform + + +def calculate_extra_workspace_size_from_scalable_extents(func, known_vscale_value): + """ + The AOT executor needs to know the size of the workspace ahead of time, but this + isn't possible when some allocations are scalable (vscale is not known at compile-time). + If we know the target hardware, we can reason about the value of vscale ahead of time. + This function will calculate an upper-bound for the extra workspace bytes required by the + AOT executor given TIR function and a known value for vscale. + """ + extra_workspace_bytes = 0 + is_scalable_extent = False + ana = tvm.arith.Analyzer() + + def replace_vscale_with_known_value(stmt): + nonlocal is_scalable_extent + if isinstance(stmt, tvm.tir.expr.Call) and stmt.op.name == "tir.vscale": + is_scalable_extent = True + return tvm.tir.IntImm(stmt.dtype, known_vscale_value) + + def calculate_workspace_bytes(stmt): + nonlocal extra_workspace_bytes, is_scalable_extent + if isinstance(stmt, tvm.tir.stmt.Allocate): + for extent in stmt.extents: + extent_stmt = tvm.tir.Evaluate(extent) + is_scalable_extent = False + mutated_extent = ir_transform(extent_stmt, replace_vscale_with_known_value, None) + # Non scalable extents are already included in the calculation by AOT + if is_scalable_extent: + alloc_bytes = ana.simplify(mutated_extent.value) * tvm.DataType(stmt.dtype).bits + extra_workspace_bytes += alloc_bytes + + post_order_visit(func.body, calculate_workspace_bytes) + return extra_workspace_bytes diff --git a/tests/python/relay/strategy/arm_cpu/test_dense_dsp.py b/tests/python/relay/strategy/arm_cpu/test_dense.py similarity index 50% rename from tests/python/relay/strategy/arm_cpu/test_dense_dsp.py rename to tests/python/relay/strategy/arm_cpu/test_dense.py index abd3ac4a3f6a..b9384e532e7d 100644 --- a/tests/python/relay/strategy/arm_cpu/test_dense_dsp.py +++ b/tests/python/relay/strategy/arm_cpu/test_dense.py @@ -14,14 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import numpy as np + import tvm import tvm.testing from tvm import relay -from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data -from tvm.micro.testing.aot_test_utils import ( - AOT_CORSTONE300_RUNNER, +from tvm import meta_schedule +from tvm.testing.aot import ( + AOTTestModel, + AOTCompiledTestModel, + compile_and_run, + run_and_check, + generate_ref_data, ) +from tvm.micro.testing.aot_test_utils import AOT_CORSTONE300_RUNNER, AOT_APROFILE_AEM_RUNNER +from tvm.target.codegen import llvm_version_major +from tvm.relay.op.strategy.arm_cpu import arm_cpu_tir_strategy +from scalable_utils import calculate_extra_workspace_size_from_scalable_extents class BasicDenseTests: @@ -84,5 +94,80 @@ class TestDense(BasicDenseTests): enable_bias = tvm.testing.parameter(False, True) +@pytest.mark.skipif( + llvm_version_major() < 17, reason="SME is not supported in earlier versions of LLVM" +) +@tvm.testing.requires_aprofile_aem_fvp +@pytest.mark.parametrize( + "data_shape,weight_shape", + [ + ((32, 32), (32, 32)), + ((2, 35), (6, 35)), + ((3, 3), (68, 3)), + ((79, 65), (152, 65)), + ], +) +@pytest.mark.parametrize("dtype", ["float32"]) +def test_sme_dense(data_shape, weight_shape, dtype): + np.random.seed(0) + + input_data = np.random.uniform(size=data_shape).astype(dtype) + inp = relay.var("data", shape=data_shape, dtype=dtype) + weight_data = np.random.uniform(size=weight_shape).astype(dtype) + weight = relay.const(weight_data, dtype=dtype) + + dense = relay.nn.dense(inp, weight) + func = relay.Function(relay.analysis.free_vars(dense), dense) + + ir_mod = tvm.IRModule.from_expr(func) + ir_mod = tvm.relay.transform.InferType()(ir_mod) + + inputs = {"data": input_data} + params = {} + ref_outputs = generate_ref_data(ir_mod, inputs, params) + + target = tvm.target.Target("llvm -mtriple=aarch64-none-elf -mattr=+v9.2a,+sme") + runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True}) + executor = tvm.relay.backend.Executor( + "aot", + { + "interface-api": "packed", + "unpacked-api": False, + }, + ) + + with tvm.transform.PassContext( + opt_level=3, config=AOT_APROFILE_AEM_RUNNER.pass_config + ), meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy): + executor_factory = tvm.relay.build( + ir_mod, + target=target, + executor=executor, + runtime=runtime, + params=params, + ) + generated_func = executor_factory.lowered_ir_mods.items()[0][1][ + "tvmgen_default_fused_nn_matmul" + ] + extra_memory_in_bytes = calculate_extra_workspace_size_from_scalable_extents(generated_func, 4) + + test_model = AOTTestModel( + ir_mod, inputs, ref_outputs, params=params, extra_memory_in_bytes=extra_memory_in_bytes + ) + compiled = AOTCompiledTestModel(test_model, executor_factory) + + assembly = ( + compiled.executor_factory.module.imported_modules[0].imported_modules[0].get_source("asm") + ) + assert "fmopa" in assembly + + assert run_and_check( + models=[compiled], + interface_api="packed", + runner=AOT_APROFILE_AEM_RUNNER, + print_output_on_mismatch=True, + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relay/strategy/arm_cpu/test_matmul.py b/tests/python/relay/strategy/arm_cpu/test_matmul.py new file mode 100644 index 000000000000..3b46c8019a65 --- /dev/null +++ b/tests/python/relay/strategy/arm_cpu/test_matmul.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import numpy as np + +import tvm +from tvm import relay +from tvm import meta_schedule +from tvm.testing.aot import ( + AOTTestModel, + AOTCompiledTestModel, + run_and_check, + generate_ref_data, +) +from tvm.micro.testing.aot_test_utils import AOT_APROFILE_AEM_RUNNER +from tvm.target.codegen import llvm_version_major +from tvm.relay.op.strategy.arm_cpu import arm_cpu_tir_strategy +from scalable_utils import calculate_extra_workspace_size_from_scalable_extents + + +@pytest.mark.skipif( + llvm_version_major() < 17, reason="SME is not supported in earlier versions of LLVM" +) +@tvm.testing.requires_aprofile_aem_fvp +@pytest.mark.parametrize( + "data_shape,weight_shape,transpose_a,transpose_b", + [ + ((4, 63), (63, 10), False, False), + ((64, 32), (32, 32), False, True), + ((96, 64), (64, 32), False, False), + ((62, 3), (3, 3), False, False), + ((4, 5), (79, 5), False, True), + ((134, 36), (36, 111), False, False), + ((3, 10), (10, 72), False, False), + # Tensorization does not work when the reduction axis has unit iters. + # See https://github.com/apache/tvm/issues/16566 + # ((5, 1), (1, 5), False, False), + ], +) +@pytest.mark.parametrize("dtype", ["float32"]) +def test_sme_matmul_with_const_b(data_shape, weight_shape, transpose_a, transpose_b, dtype): + """ + Execution tests for matmul Scalable Matrix Extension (SME) schedule. + """ + np.random.seed(0) + + input_data = np.random.uniform(size=data_shape).astype(dtype) + inp = relay.var("data", shape=data_shape, dtype=dtype) + weight_data = np.random.uniform(size=weight_shape).astype(dtype) + weight = relay.const(weight_data, dtype=dtype) + + matmul = relay.nn.matmul(inp, weight, transpose_a=transpose_a, transpose_b=transpose_b) + func = relay.Function(relay.analysis.free_vars(matmul), matmul) + + ir_mod = tvm.IRModule.from_expr(func) + ir_mod = tvm.relay.transform.InferType()(ir_mod) + + inputs = {"data": input_data} + params = {} + ref_outputs = generate_ref_data(ir_mod, inputs, params) + + target = tvm.target.Target("llvm -mtriple=aarch64-none-elf -mattr=+v9.2a,+sme") + runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True}) + executor = tvm.relay.backend.Executor( + "aot", + { + "interface-api": "packed", + "unpacked-api": False, + }, + ) + with tvm.transform.PassContext( + opt_level=3, config=AOT_APROFILE_AEM_RUNNER.pass_config + ), meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy): + executor_factory = tvm.relay.build( + ir_mod, + target=target, + executor=executor, + runtime=runtime, + params=params, + ) + generated_func = executor_factory.lowered_ir_mods.items()[0][1][ + "tvmgen_default_fused_nn_matmul" + ] + extra_memory_in_bytes = calculate_extra_workspace_size_from_scalable_extents(generated_func, 4) + + test_model = AOTTestModel( + ir_mod, inputs, ref_outputs, params=params, extra_memory_in_bytes=extra_memory_in_bytes + ) + compiled = AOTCompiledTestModel(test_model, executor_factory) + + assembly = executor_factory.module.imported_modules[0].imported_modules[0].get_source("asm") + assert "fmopa" in assembly + + assert run_and_check( + models=[compiled], + interface_api="packed", + runner=AOT_APROFILE_AEM_RUNNER, + print_output_on_mismatch=True, + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py index d0767175d3d8..71dd688e2929 100644 --- a/tests/python/relay/strategy/test_select_implementation.py +++ b/tests/python/relay/strategy/test_select_implementation.py @@ -258,18 +258,23 @@ def test_int8_depthwise_conv2d(target, expected_impl): @pytest.mark.parametrize( "target,expected_valid_impl,expected_impl", - [("llvm -device=arm_cpu", ["dense_pack.x86", "dense_nopack.x86"], "dense_pack.x86")], + [ + ( + "llvm -device=arm_cpu", + ["dense_pack.x86", "dense_nopack.x86"], + "dense_pack.x86", + ), + ], ) def test_dense(target, expected_valid_impl, expected_impl): target = tvm.target.Target(target) - data_shape = (30, 40) weight_shape = (30, 40) dtype = "float32" out = relay.nn.dense( relay.var("data", shape=data_shape, dtype=dtype), - relay.var("weight", shape=weight_shape, dtype=dtype), + relay.const(np.zeros((weight_shape)).astype(dtype)), out_dtype=dtype, ) out = run_infer_type(out) @@ -284,7 +289,51 @@ def test_dense(target, expected_valid_impl, expected_impl): ] valid_impl = relay.backend.te_compiler.get_valid_implementations(*args) selected_impl, _ = relay.backend.te_compiler.select_implementation(*args, use_autotvm=False) + assert len(valid_impl) == len(expected_valid_impl) + for impl in valid_impl: + assert impl.name in expected_valid_impl + assert selected_impl.name == expected_impl + +@pytest.mark.skipif(llvm_version_major() < 15, reason="Older versions of LLVM don't support SME.") +@pytest.mark.parametrize( + "shape,expected_valid_impl,expected_impl", + [ + ( + (30, 40), + ["matmul.arm_cpu.sme", "dense_pack.x86", "dense_nopack.x86"], + "matmul.arm_cpu.sme", + ), + ( + (5, 1), + ["dense_pack.x86", "dense_nopack.x86"], + "dense_pack.x86", + ), + ], +) +def test_dense_with_sme_target(shape, expected_valid_impl, expected_impl): + target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme") + data_shape = shape + weight_shape = shape + dtype = "float32" + + out = relay.nn.dense( + relay.var("data", shape=data_shape, dtype=dtype), + relay.const(np.zeros((weight_shape)).astype(dtype)), + out_dtype=dtype, + ) + out = run_infer_type(out) + + with target: + args = [ + out.op, + out.attrs, + [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)], + out.checked_type, + target, + ] + valid_impl = relay.backend.te_compiler.get_valid_implementations(*args) + selected_impl, _ = relay.backend.te_compiler.select_implementation(*args, use_autotvm=False) assert len(valid_impl) == len(expected_valid_impl) for impl in valid_impl: assert impl.name in expected_valid_impl diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 831070299f56..f74b31157ae2 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -23,6 +23,7 @@ from tvm.relay import transform, analysis from tvm.relay.testing.temp_op_attr import TempOpAttr from tvm.relay.testing import run_infer_type +from tvm.target.codegen import llvm_version_major import numpy as np import tvm.testing from tvm.relay import testing @@ -1451,6 +1452,61 @@ def expected(): assert tvm.ir.structural_equal(a, b) +@pytest.mark.skipif( + llvm_version_major() < 17, reason="SME is not supported in earlier versions of LLVM" +) +def test_alter_op_dense_arm_cpu_sme(): + np.random.seed(0) + y_data = np.random.uniform(size=(64, 32)).astype("float32") + + def before(): + x = relay.var("x", shape=(32, 32), dtype="float32") + y = relay.const(y_data, dtype="float32") + dense = relay.nn.dense(x, y) + return relay.Function(analysis.free_vars(dense), dense) + + def expected(): + x = relay.var("x", shape=(32, 32), dtype="float32") + y = relay.const(y_data.transpose(), dtype="float32") + matmul = relay.nn.matmul(x, y) + return relay.Function(analysis.free_vars(matmul), matmul) + + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme"): + with TempOpAttr("nn.dense", "FTVMAlterOpLayout", topi.arm_cpu.dense_alter_op._alter_dense): + a = run_opt_pass(before(), transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b) + + +@pytest.mark.skipif( + llvm_version_major() < 17, reason="SME is not supported in earlier versions of LLVM" +) +@pytest.mark.parametrize( + "transpose_b,transform_b", [(False, lambda x: x), (True, lambda x: x.transpose())] +) +def test_alter_op_matmul_arm_cpu_sme(transpose_b, transform_b): + np.random.seed(0) + y_data = np.random.uniform(size=(64, 32)).astype("float32") + + def before(): + x = relay.var("x", shape=(96, 32), dtype="float32") + y = relay.const(y_data, dtype="float32") + dense = relay.nn.matmul(x, y, transpose_a=False, transpose_b=transpose_b) + return relay.Function(analysis.free_vars(dense), dense) + + def expected(): + x = relay.var("x", shape=(96, 32), dtype="float32") + y = relay.const(transform_b(y_data), dtype="float32") + matmul = relay.nn.matmul(x, y) + return relay.Function(analysis.free_vars(matmul), matmul) + + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme"): + with TempOpAttr("nn.dense", "FTVMAlterOpLayout", topi.arm_cpu.dense_alter_op._alter_dense): + a = run_opt_pass(before(), transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b) + + def test_conv2d_strided_slice_packed_to_unpacked(): """We do not support propagating through packed to unpacked layout""" x_shape = (1, 1, 1, 1, 4) diff --git a/tests/python/topi/test_topi_matmul.py b/tests/python/topi/test_topi_matmul.py index 4b05dd3813e2..a7b3965aeed3 100644 --- a/tests/python/topi/test_topi_matmul.py +++ b/tests/python/topi/test_topi_matmul.py @@ -14,12 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import pytest import numpy as np + import tvm import tvm.testing from tvm import te from tvm import topi from tvm.topi.utils import get_const_tuple +from tvm.topi.arm_cpu.matmul import compute_matmul_sme def with_tvm(lam, *args): @@ -148,7 +152,17 @@ def test_tensordot(): verify_tensordot((4, 3, 2, 2), (2, 4, 3, 5), ((1, 2, 0), (2, 0, 1))) +@pytest.mark.parametrize("transpose_a,transpose_b", [(True, False), (False, True)]) +def test_unsupported_sme_matmul_compute_transpose(transpose_a, transpose_b): + """ + SME matmul compute does not support transposed inputs for now. + """ + err_msg = "Compute definition currently does not support transposed inputs." + with pytest.raises(AssertionError, match=err_msg) as e: + compute_matmul_sme( + te.placeholder((32, 32)), te.placeholder((32, 32)), None, None, transpose_a, transpose_b + ) + + if __name__ == "__main__": - test_nn_matmul() - test_matmul() - test_tensordot() + tvm.testing.main()