[SME] Introduce scalable fp32 dense schedule #16921

Merged 7 commits on May 15, 2024

10 changes: 10 additions & 0 deletions python/tvm/micro/testing/aot_test_utils.py
@@ -65,6 +65,16 @@
    },
)

AOT_APROFILE_AEM_RUNNER = AOTTestRunner(
    makefile="aprofile_aem",
    includes=[],
    pass_config={
        "tir.usmp.enable": False,
        # AOT test infra generates 'fake' tensor inputs which fail asserts
        "tir.disable_assert": True,
    },
)


def parametrize_aot_options(test):
"""Parametrize over valid option combinations"""
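
As a usage note on the runner added above (not part of this patch): a test can hand AOT_APROFILE_AEM_RUNNER to the existing AOT helpers in tvm.testing.aot, which compile the model, generate reference outputs on the host, and execute the artefact on the Arm A-profile Architecture Envelope Model FVP. A minimal sketch, with the workload, shapes and interface options assumed; a real test would also select an appropriate AArch64 target:

import numpy as np
import tvm
from tvm import relay
from tvm.micro.testing.aot_test_utils import AOT_APROFILE_AEM_RUNNER
from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data

# A tiny fp32 workload; the shapes are arbitrary.
dtype = "float32"
x = relay.var("x", shape=(4, 8), dtype=dtype)
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))
inputs = {"x": np.random.uniform(size=(4, 8)).astype(dtype)}

# Compile, run on the FVP via the new runner and compare against the host reference.
compile_and_run(
    AOTTestModel(module=mod, inputs=inputs, outputs=generate_ref_data(mod, inputs)),
    runner=AOT_APROFILE_AEM_RUNNER,
    interface_api="packed",
    use_unpacked_api=False,
)
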
69 changes: 68 additions & 1 deletion python/tvm/relay/op/strategy/arm_cpu.py
@@ -21,7 +21,9 @@
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import re

import tvm
from tvm import relay, topi, tir
from tvm.tir.schedule.analysis import has_block

from ....auto_scheduler import is_auto_scheduler_enabled
from ....meta_schedule import is_meta_schedule_enabled
@@ -639,7 +641,7 @@ def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target):
def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
    """dense arm cpu strategy"""
    strategy = _op.OpStrategy()
    data, _ = inputs
    data, weight = inputs

    if target.features.has_dsp and data.dtype in ["int8", "int16"]:
        strategy.add_implementation(
@@ -680,6 +682,23 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
            plevel=11,
        )

    if (
        target.features.has_sme
        and data.dtype in ["float32"]
        and weight.dtype in ["float32"]
        and out_type.dtype in ["float32"]
        # The schedule uses tensorization which does not work when the
        # reduction axis has unit iters. See
        # https://github.com/apache/tvm/issues/16566
        and data.shape[1] > 1
    ):
        strategy.add_implementation(
            wrap_compute_dense(topi.arm_cpu.compute_matmul_sme),
            lambda: None,
            name="matmul.arm_cpu.sme",
            plevel=12,
        )

    # Fallback to x86 schedules as there is currently no arm_cpu schedule for dense
    strategy.add_implementation(
        wrap_compute_dense(topi.x86.dense_nopack),
@@ -697,6 +716,40 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
    return strategy


@matmul_strategy.register("arm_cpu")
def matmul_strategy_arm_cpu(attrs, inputs, out_type, target):
    """matmul arm cpu strategy"""
    strategy = _op.OpStrategy()
    data, weight = inputs

    if (
        target.features.has_sme
        and data.dtype in ["float32"]
        and weight.dtype in ["float32"]
        and out_type.dtype in ["float32"]
        and not (attrs.transpose_a or attrs.transpose_b)
        and len(data.shape) == 2
        # The schedule uses tensorization which does not work when the
        # reduction axis has unit iters. See
        # https://github.com/apache/tvm/issues/16566
        and data.shape[1] > 1
    ):
        # Ideally we should check that weight is a Relay constant, but strategy functions
        # don't have access to the data needed to check this.
        strategy.add_implementation(
            wrap_compute_matmul(topi.arm_cpu.compute_matmul_sme),
            lambda: None,
            name="matmul.arm_cpu.sme",
        )
        return strategy

    logger.warning("matmul is not optimized for arm cpu.")
    strategy.add_implementation(
        wrap_compute_matmul(topi.nn.matmul), naive_schedule, name="matmul.generic"
    )
    return strategy


@conv1d_strategy.register("arm_cpu")
def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target):
"""conv1d strategy"""
@@ -737,3 +790,17 @@ def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target):
f"Unsupported kernel layout {kernel_layout} for conv1d {layout} for arm cpu."
)
return strategy


def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool:
    """
    Strategy for arm_cpu STIR schedules.
    """
    current_target = tvm.target.Target.current()

    if current_target.features.has_sme and has_block(sch, "matmul_sme_gemm"):
        topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
        return True

    # Fallback to TE schedule for operators we have not written a special TIR schedule for
    return False
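
As a usage note on the strategy changes above (not part of this patch): both guards fire only for fp32 operands with a non-unit reduction axis, and the dense implementation is registered at plevel=12 so it wins over the plevel-11 and fallback x86 implementations. The schedule argument is lambda: None because no TE schedule exists for this compute; arm_cpu_tir_strategy above recognises the "matmul_sme_gemm" block during lowering and applies topi.arm_cpu.matmul.tir_schedule_matmul_sme instead. A minimal sketch of a workload that would take this path, assuming an SME-enabled AArch64 target string and an LLVM build that understands the +sme attribute:

import numpy as np
import tvm
from tvm import relay

# Assumed target string; any AArch64 target whose parsed features report
# has_sme would select the SME implementation.
target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme")

# fp32 dense whose reduction axis (64) is larger than one, so the guard is met.
data = relay.var("data", shape=(32, 64), dtype="float32")
weight = relay.const(np.random.uniform(size=(16, 64)).astype("float32"))
func = relay.Function([data], relay.nn.dense(data, weight))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(tvm.IRModule.from_expr(func), target=target)
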
17 changes: 17 additions & 0 deletions python/tvm/testing/utils.py
@@ -1023,6 +1023,19 @@ def _corstone300_compile_time_check():
    parent_features="cmsisnn",
)


def _aprofile_aem_fvp_compile_time_check():
    if shutil.which("FVP_Base_RevC-2xAEMvA") is None:
        return "AProfile AEM is not available"
    return True


requires_aprofile_aem_fvp = Feature(
    "aprofile-aem-fvp",
    "AProfile AEM FVP",
    compile_time_check=_aprofile_aem_fvp_compile_time_check,
)

# Mark a test as requiring Vitis AI to run
requires_vitis_ai = Feature("vitis_ai", "Vitis AI", cmake_flag="USE_VITIS_AI")

@@ -1205,6 +1218,10 @@ def decorator(*args):
    return decorator


def skip_if_no_reference_system(func):
    return skip_if_32bit(reason="Reference system unavailable in i386 container")(func)


def requires_package(*packages):
"""Mark a test as requiring python packages to run.

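
As a usage note on the testing hooks above (not part of this patch): the requires_* Feature objects defined in this module are applied as decorators, so a test that needs the FVP is skipped automatically when FVP_Base_RevC-2xAEMvA is not found on the PATH. A minimal sketch, with the test name and body assumed:

import tvm.testing


@tvm.testing.requires_aprofile_aem_fvp
def test_sme_dense_on_fvp():
    # Hypothetical test body: build an fp32 dense for an SME-enabled target and
    # execute it on the FVP through AOT_APROFILE_AEM_RUNNER.
    ...
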
1 change: 0 additions & 1 deletion python/tvm/tir/tensor_intrin/__init__.py
@@ -16,4 +16,3 @@
# under the License.
# pylint: disable=unused-import
"""Intrinsics for tensorization."""
from . import arm_cpu, cuda, rocm, x86, hexagon