Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SVE] Add codegen support for vscale_range() function attribute #16962

Merged
merged 2 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/target/llvm/codegen_aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <llvm/Target/TargetMachine.h>
#include <tvm/runtime/registry.h>

#include "../../arith/scalable_expression.h"
#include "codegen_cpu.h"
#include "llvm_instance.h"

Expand All @@ -40,6 +41,7 @@ class CodeGenAArch64 final : public CodeGenCPU {

void VisitStmt_(const AttrStmtNode* op);
void AddFunction(const GlobalVar& gvar, const PrimFunc& f);
void SetTargetAttributes(llvm::Function* func);

bool func_has_pstate_sm = false;
bool func_has_pstate_za = false;
Expand All @@ -51,6 +53,17 @@ void CodeGenAArch64::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
CodeGenCPU::AddFunction(gvar, f);
}

/*!
 * \brief Set AArch64 specific attributes on a function, then apply the
 *  attributes set by CodeGenCPU.
 *
 *  When built against LLVM 13 or newer and the target reports the "sve" or
 *  "sme" CPU feature, a vscale_range() function attribute is attached with a
 *  minimum of 1 and a maximum equal to the number of entries in
 *  tvm::arith::kAArch64VScaleValues.
 *
 * \param func The function to set attributes on.
 */
void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) {
#if TVM_LLVM_VERSION >= 130
  // Only targets with scalable vector support receive a vscale_range() attribute.
  const bool has_scalable_vectors =
      llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme");
  if (has_scalable_vectors) {
    const unsigned min_vscale = 1;
    const unsigned max_vscale = tvm::arith::kAArch64VScaleValues.size();
    const llvm::Attribute vscale_range_attr = llvm::Attribute::getWithVScaleRangeArgs(
        *llvm_target_->GetContext(), min_vscale, max_vscale);
    func->addFnAttr(vscale_range_attr);
  }
#endif
  // Generic CPU attributes are applied regardless of the LLVM version in use.
  CodeGenCPU::SetTargetAttributes(func);
}

/*!
* \brief Visit and handle AArch64 specific pragmas. To be AArch64 specific,
* the expectation is that they are prepended with "pragma_aarch64".
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*
* \param func The function to set attributes on.
*/
void SetTargetAttributes(llvm::Function* func);
virtual void SetTargetAttributes(llvm::Function* func);
/*!
* \brief Emit LLVM IR for conversion functions __extendhfsf2 and __truncsfhf2
* into the current llvm::Module.
Expand Down
38 changes: 38 additions & 0 deletions tests/python/codegen/test_target_codegen_aarch64.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,44 @@ def my_func(a: T.handle):
assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."


@pytest.mark.skipif(
    llvm_version_major() < 13,
    reason="Function attribute vscale_range() is not supported in earlier versions of LLVM",
)
@pytest.mark.parametrize(
    "mattr,expect_attr",
    [
        ("+neon", False),
        ("+sve", True),
        ("+v9a", True),
        ("+sme", True),
    ],
)
def test_vscale_range_function_attribute(mattr, expect_attr):
    """Check that vscale_range() appears in the generated LLVM IR only when expected.

    Parameters
    ----------
    mattr : str
        Value passed to ``-mattr`` in the AArch64 LLVM target string.
    expect_attr : bool
        Whether the generated LLVM IR is expected to contain a
        ``vscale_range(<min>,<max>)`` function attribute.
    """
    target_str = f"llvm -mtriple=aarch64-linux-gnu -mattr={mattr}"

    m = te.var("m")
    A = te.placeholder(m, dtype="float32", name="A")
    C = te.compute((m), lambda i: A[i] + 1, name="C")
    s = te.create_schedule([C.op])

    with tvm.target.Target(target_str) as target:
        f = tvm.build(s, [A, C], target)

    # Check if the vscale_range() attribute exists. The pattern requires the
    # closing parenthesis so that only a complete attribute is counted.
    ll = f.get_source("ll")
    attr = re.findall(r"vscale_range\(\d+,\d+\)", ll)

    if expect_attr:
        assert (
            len(attr) > 0
        ), "Function attribute vscale_range() was not found in generated LLVM IR"
    else:
        assert (
            len(attr) == 0
        ), "Unexpected function attribute vscale_range() was found in generated LLVM IR"


@pytest.mark.skipif(
llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SME"
)
Expand Down
Loading