Support for ARM SVE2. #8051

Merged: 51 commits, Mar 15, 2024
Changes from 17 commits
Commits
51 commits
77ea0a0
Checkpoint SVE2 restart.
Dec 14, 2023
c203d1e
Remove dead code. Add new test.
Dec 14, 2023
27ee93e
Update cmake for new file.
Dec 14, 2023
bf0e925
Checkpoint progress on SVE2.
Dec 16, 2023
f40eeb5
Merge branch 'main' into arm_sve_redux
Jan 9, 2024
deb5fbc
Checkpoint ARM SVE2 support. Passes correctness_simd_op_check_sve2 te…
Jan 18, 2024
51c4568
Merge branch 'main' into arm_sve_redux
Jan 18, 2024
5f98675
Remove an opportunity for RISC V codegen to change due to SVE2 support.
Jan 18, 2024
1b8a75e
Ensure SVE intrinsics get vscale vectors and non-SVE ones get fixed v…
Jan 19, 2024
5eeef77
Checkpoint SVE2 work. Generally passes test, though using both NEON
Jan 26, 2024
f57f1d3
Remove an unfavored implementation possibility.
Jan 26, 2024
da3c259
Fix opcode recognition in test to handle some cases that show up.
Jan 29, 2024
06fa66c
Merge branch 'main' into arm_sve_redux
Jan 29, 2024
a069e6e
Formatting fixes.
Jan 29, 2024
1e8a540
Formatting fix.
Jan 29, 2024
93fb752
Limit SVE2 test to LLVM 19.
Jan 30, 2024
de11e8f
Fix a degenerate case asking for zero-sized vectors via a Halide type
Jan 31, 2024
9b2897c
Merge branch 'main' into arm_sve_redux
Feb 6, 2024
2bc10e3
Merge branch 'main' into arm_sve_redux
steven-johnson Feb 7, 2024
c598c9d
Merge branch 'main' into arm_sve_redux
Feb 7, 2024
bb73c00
Fix confusion about Neon64/Neon128 and make it clear this is just the
Feb 11, 2024
65fff76
Merge branch 'arm_sve_redux' of https://github.com/halide/Halide into…
Feb 11, 2024
93d7ba9
Remove extraneous commented-out line.
Feb 11, 2024
ba934e9
Address some review feedback. Mostly comment fixes.
Feb 11, 2024
00cb4ce
Merge branch 'main' into arm_sve_redux
Feb 11, 2024
229bb60
Fix missed conflict resolution.
Feb 11, 2024
42206a5
Fix some TODOs in SVE code. Move utility function to Util.h and common
Feb 12, 2024
90186ad
Formatting.
Feb 12, 2024
bc149bc
Add missed refactor change.
Feb 12, 2024
79776e0
Add issue to TODO comment.
Feb 12, 2024
c3ca689
Remove TODOs that don't seem necessary.
Feb 13, 2024
b0e4f99
Add issue for TODO.
Feb 13, 2024
417d762
Add issue for TODO.
Feb 13, 2024
6e6e491
Merge branch 'main' into arm_sve_redux
Feb 15, 2024
e25a947
Merge branch 'main' into arm_sve_redux
Feb 21, 2024
fe30990
Remove dubious looking FP to int code that was ifdef'ed out. Doesn't
Feb 21, 2024
dc3be8a
Add issues for TODOs.
Feb 22, 2024
7627e0d
Merge branch 'main' into arm_sve_redux
Feb 22, 2024
4a269bd
Merge branch 'main' into arm_sve_redux
Feb 23, 2024
6afdcff
Update simd_op_check_sve2.cpp
steven-johnson Feb 23, 2024
b03b3c7
Merge branch 'main' into arm_sve_redux
steven-johnson Feb 23, 2024
f8952c2
Make a deep copy of each piece of test IR so that we can parallelize
abadams Feb 23, 2024
eaed2ef
Merge branch 'arm_sve_redux' of https://github.com/halide/Halide into…
Mar 5, 2024
2ac96c8
Merge branch 'main' into arm_sve_redux
steven-johnson Mar 5, 2024
4324bc5
Fix two clang-tidy warnings
steven-johnson Mar 5, 2024
a63439b
Remove try/catch block from simd-op-check-sve2
steven-johnson Mar 5, 2024
f84c764
Merge branch 'arm_sve_redux' of https://github.com/halide/Halide into…
Mar 6, 2024
210e5d7
Don't try to run SVE2 code if vector_bits doesn't match host.
Mar 6, 2024
9d8e2c6
Add support for fcvtm/p, make scalars go through pattern matching too…
abadams Mar 13, 2024
32d1fcb
Merge remote-tracking branch 'origin/main' into arm_sve_redux
abadams Mar 13, 2024
9dbfcd5
Don't do arm neon instruction selection on scalars
abadams Mar 14, 2024
1,368 changes: 1,112 additions & 256 deletions src/CodeGen_ARM.cpp

Large diffs are not rendered by default.

226 changes: 182 additions & 44 deletions src/CodeGen_LLVM.cpp

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions src/CodeGen_LLVM.h
@@ -579,6 +579,13 @@ class CodeGen_LLVM : public IRVisitor {
llvm::Constant *get_splat(int lanes, llvm::Constant *value,
VectorTypeConstraint type_constraint = VectorTypeConstraint::None) const;

/** Make sure a value type has the same scalable/fixed vector type as a guide. */
// @{
llvm::Value *match_vector_type_scalable(llvm::Value *value, VectorTypeConstraint constraint);
llvm::Value *match_vector_type_scalable(llvm::Value *value, llvm::Type *guide);
llvm::Value *match_vector_type_scalable(llvm::Value *value, llvm::Value *guide);
// @}

/** Support for generating LLVM vector predication intrinsics
* ("@llvm.vp.*" and "@llvm.experimental.vp.*")
*/
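The three match_vector_type_scalable overloads exist because SVE2 codegen mixes fixed-width (NEON) and vscale (SVE) vector values in one function: before handing a value to an intrinsic, it is coerced to whichever flavour the guide type or guide value uses. A minimal free-function sketch of the idea, built on LLVM's vector.insert/vector.extract builder helpers — an editorial illustration under assumed behaviour, not the member functions added by this PR:

#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"

// Sketch: give `value` the fixed/scalable flavour of `guide`. Hypothetical helper.
llvm::Value *match_scalable_sketch(llvm::IRBuilderBase &builder,
                                   llvm::Value *value, llvm::Type *guide) {
    auto *from = llvm::dyn_cast<llvm::VectorType>(value->getType());
    auto *to = llvm::dyn_cast<llvm::VectorType>(guide);
    bool from_scalable = from && llvm::isa<llvm::ScalableVectorType>(from);
    bool to_scalable = to && llvm::isa<llvm::ScalableVectorType>(to);
    if (!from || !to || from_scalable == to_scalable) {
        return value;  // Not vectors, or already the same flavour.
    }
    if (to_scalable) {
        // Fixed -> scalable: insert the fixed vector at lane 0 of a poison scalable vector.
        auto *fixed = llvm::cast<llvm::FixedVectorType>(from);
        llvm::Type *sv = llvm::ScalableVectorType::get(fixed->getElementType(),
                                                       fixed->getNumElements());
        return builder.CreateInsertVector(sv, llvm::PoisonValue::get(sv),
                                          value, builder.getInt64(0));
    }
    // Scalable -> fixed: extract the low fixed-width subvector.
    llvm::Type *fv = llvm::FixedVectorType::get(to->getElementType(),
                                                to->getElementCount().getKnownMinValue());
    return builder.CreateExtractVector(fv, value, builder.getInt64(0));
}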
1 change: 1 addition & 0 deletions src/IR.cpp
@@ -689,6 +689,7 @@ const char *const intrinsic_op_names[] = {
"widening_shift_left",
"widening_shift_right",
"widening_sub",
"get_runtime_vscale",
};

static_assert(sizeof(intrinsic_op_names) / sizeof(intrinsic_op_names[0]) == Call::IntrinsicOpCount,
2 changes: 2 additions & 0 deletions src/IR.h
@@ -625,6 +625,8 @@ struct Call : public ExprNode<Call> {
widening_shift_right,
widening_sub,

get_runtime_vscale,

IntrinsicOpCount // Sentinel: keep last.
};

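The new get_runtime_vscale intrinsic gives lowered Halide IR a way to read the hardware's actual vscale multiplier at run time (for SVE, the vector length in bits divided by 128), which is what the runtime check against the compile-time vscale is built on. A plausible lowering sketch in a CodeGen_LLVM-style visitor — the member names and wiring here are assumptions, not the PR's exact code — maps it straight onto LLVM's llvm.vscale intrinsic:

// Hypothetical lowering inside visit(const Call *op).
if (op->is_intrinsic(Call::get_runtime_vscale)) {
    llvm::Function *vscale_fn = llvm::Intrinsic::getDeclaration(
        module.get(), llvm::Intrinsic::vscale, {llvm::Type::getInt32Ty(*context)});
    value = builder->CreateCall(vscale_fn);  // i32 result: the runtime vscale
    return;
}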
6 changes: 6 additions & 0 deletions src/LLVM_Output.cpp
@@ -331,6 +331,12 @@ std::unique_ptr<llvm::Module> clone_module(const llvm::Module &module_in) {
// Read it back in.
llvm::MemoryBufferRef buffer_ref(llvm::StringRef(clone_buffer.data(), clone_buffer.size()), "clone_buffer");
auto cloned_module = llvm::parseBitcodeFile(buffer_ref, module_in.getContext());

// TODO(<add issue>): Add support for returning the error.
if (!cloned_module) {
llvm::dbgs() << cloned_module.takeError();
module_in.print(llvm::dbgs(), nullptr, false, true);
}
internal_assert(cloned_module);

return std::move(cloned_module.get());
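The hunk above logs the parse error and dumps the module before internal_assert fires; the TODO notes that the error is not yet returned to the caller. One hedged alternative — an editorial sketch assuming the standard LLVM bitcode reader/writer APIs, not code from this PR — would hand back the llvm::Expected so callers can handle the failure themselves:

#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical variant of clone_module that propagates the parse failure.
llvm::Expected<std::unique_ptr<llvm::Module>>
clone_module_or_error(const llvm::Module &module_in) {
    llvm::SmallVector<char, 1024> bitcode;
    llvm::raw_svector_ostream stream(bitcode);
    llvm::WriteBitcodeToFile(module_in, stream);
    llvm::MemoryBufferRef buffer_ref(
        llvm::StringRef(bitcode.data(), bitcode.size()), "clone_buffer");
    return llvm::parseBitcodeFile(buffer_ref, module_in.getContext());
}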
5 changes: 1 addition & 4 deletions src/StorageFolding.cpp
@@ -10,17 +10,14 @@
#include "Monotonic.h"
#include "Simplify.h"
#include "Substitute.h"
#include "Util.h"
#include <utility>

namespace Halide {
namespace Internal {

namespace {

int64_t next_power_of_two(int64_t x) {
return static_cast<int64_t>(1) << static_cast<int64_t>(std::ceil(std::log2(x)));
}

using std::map;
using std::string;
using std::vector;
6 changes: 6 additions & 0 deletions src/Util.h
@@ -13,6 +13,7 @@
/** \file
* Various utility functions used internally by Halide. */

#include <cmath>
#include <cstdint>
#include <cstring>
#include <functional>
@@ -532,6 +533,11 @@ int clz64(uint64_t x);
int ctz64(uint64_t x);
// @}

/** Return an integer 2^n, for some n, which is >= x. Argument x must be > 0. */
inline int64_t next_power_of_two(int64_t x) {
return static_cast<int64_t>(1) << static_cast<int64_t>(std::ceil(std::log2(x)));
}

} // namespace Internal
} // namespace Halide

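With next_power_of_two now a shared utility in Util.h, a quick illustration of the behaviour implied by the definition above (assuming a normal Halide source-tree build; illustrative only):

#include <cassert>
#include "Util.h"  // assumed include path within the Halide source tree

int main() {
    using Halide::Internal::next_power_of_two;
    assert(next_power_of_two(1) == 1);    // exact powers of two are unchanged
    assert(next_power_of_two(5) == 8);    // otherwise rounds up
    assert(next_power_of_two(64) == 64);
    assert(next_power_of_two(65) == 128);
    return 0;
}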
5 changes: 5 additions & 0 deletions src/runtime/HalideRuntime.h
@@ -1242,6 +1242,10 @@ enum halide_error_code_t {
/** An explicit storage bound provided is too small to store
* all the values produced by the function. */
halide_error_code_storage_bound_too_small = -45,

/** "vscale" value of Scalable Vector detected in runtime does not match
* the vscale value used in compilation. */
halide_error_code_vscale_invalid = -46,
};

/** Halide calls the functions below on various error conditions. The
@@ -1316,6 +1320,7 @@ extern int halide_error_device_dirty_with_no_device_support(void *user_context,
extern int halide_error_storage_bound_too_small(void *user_context, const char *func_name, const char *var_name,
int provided_size, int required_size);
extern int halide_error_device_crop_failed(void *user_context);
extern int halide_error_vscale_invalid(void *user_context, const char *func_name, int runtime_vscale, int compiletime_vscale);
// @}

/** Optional features a compilation Target can have.
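halide_error_vscale_invalid backs the guard that SVE2-compiled pipelines run on entry: the vscale baked in at compile time is compared against the value the hardware actually reports. A hand-written sketch of that check — the real one is emitted by the compiler, and the helper name and wiring here are hypothetical:

// 'runtime_vscale' would come from the get_runtime_vscale intrinsic
// (llvm.vscale) in generated code; 'compiletime_vscale' from the Target's
// vector_bits setting (vector_bits / 128).
int check_vscale(void *user_context, const char *func_name,
                 int runtime_vscale, int compiletime_vscale) {
    if (compiletime_vscale != 0 && runtime_vscale != compiletime_vscale) {
        return halide_error_vscale_invalid(user_context, func_name,
                                           runtime_vscale, compiletime_vscale);
    }
    return halide_error_code_success;
}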
76 changes: 63 additions & 13 deletions src/runtime/aarch64.ll
@@ -48,25 +48,34 @@ define weak_odr <2 x i64> @vabdl_u32x2(<2 x i32> %a, <2 x i32> %b) nounwind alwa

declare <4 x float> @llvm.aarch64.neon.frecpe.v4f32(<4 x float> %x) nounwind readnone;
declare <2 x float> @llvm.aarch64.neon.frecpe.v2f32(<2 x float> %x) nounwind readnone;
declare float @llvm.aarch64.neon.frecpe.f32(float)
declare <4 x float> @llvm.aarch64.neon.frsqrte.v4f32(<4 x float> %x) nounwind readnone;
declare <2 x float> @llvm.aarch64.neon.frsqrte.v2f32(<2 x float> %x) nounwind readnone;
declare float @llvm.aarch64.neon.frsqrte.f32(float)
declare <4 x float> @llvm.aarch64.neon.frecps.v4f32(<4 x float> %x, <4 x float> %y) nounwind readnone;
declare <2 x float> @llvm.aarch64.neon.frecps.v2f32(<2 x float> %x, <2 x float> %y) nounwind readnone;
declare float @llvm.aarch64.neon.frecps.f32(float, float)
declare <4 x float> @llvm.aarch64.neon.frsqrts.v4f32(<4 x float> %x, <4 x float> %y) nounwind readnone;
declare <2 x float> @llvm.aarch64.neon.frsqrts.v2f32(<2 x float> %x, <2 x float> %y) nounwind readnone;
declare float @llvm.aarch64.neon.frsqrts.f32(float, float)

declare <8 x half> @llvm.aarch64.neon.frecpe.v8f16(<8 x half> %x) nounwind readnone;
declare <4 x half> @llvm.aarch64.neon.frecpe.v4f16(<4 x half> %x) nounwind readnone;
declare half @llvm.aarch64.neon.frecpe.f16(half)
declare <8 x half> @llvm.aarch64.neon.frsqrte.v8f16(<8 x half> %x) nounwind readnone;
declare <4 x half> @llvm.aarch64.neon.frsqrte.v4f16(<4 x half> %x) nounwind readnone;
declare half @llvm.aarch64.neon.frsqrte.f16(half)
declare <8 x half> @llvm.aarch64.neon.frecps.v8f16(<8 x half> %x, <8 x half> %y) nounwind readnone;
declare <4 x half> @llvm.aarch64.neon.frecps.v4f16(<4 x half> %x, <4 x half> %y) nounwind readnone;
declare half @llvm.aarch64.neon.frecps.f16(half, half)
declare <8 x half> @llvm.aarch64.neon.frsqrts.v8f16(<8 x half> %x, <8 x half> %y) nounwind readnone;
declare <4 x half> @llvm.aarch64.neon.frsqrts.v4f16(<4 x half> %x, <4 x half> %y) nounwind readnone;
declare half @llvm.aarch64.neon.frsqrts.f16(half, half)

define weak_odr float @fast_inverse_f32(float %x) nounwind alwaysinline {
%vec = insertelement <2 x float> poison, float %x, i32 0
%approx = tail call <2 x float> @fast_inverse_f32x2(<2 x float> %vec)
%result = extractelement <2 x float> %approx, i32 0
%approx = tail call float @llvm.aarch64.neon.frecpe.f32(float %x)
%correction = tail call float @llvm.aarch64.neon.frecps.f32(float %approx, float %x)
%result = fmul float %approx, %correction
ret float %result
}

@@ -85,9 +94,9 @@ define weak_odr <4 x float> @fast_inverse_f32x4(<4 x float> %x) nounwind alwaysi
}

define weak_odr half @fast_inverse_f16(half %x) nounwind alwaysinline {
%vec = insertelement <4 x half> poison, half %x, i32 0
%approx = tail call <4 x half> @fast_inverse_f16x4(<4 x half> %vec)
%result = extractelement <4 x half> %approx, i32 0
%approx = tail call half @llvm.aarch64.neon.frecpe.f16(half %x)
%correction = tail call half @llvm.aarch64.neon.frecps.f16(half %approx, half %x)
%result = fmul half %approx, %correction
ret half %result
}

@@ -106,9 +115,10 @@ define weak_odr <8 x half> @fast_inverse_f16x8(<8 x half> %x) nounwind alwaysinl
}

define weak_odr float @fast_inverse_sqrt_f32(float %x) nounwind alwaysinline {
%vec = insertelement <2 x float> poison, float %x, i32 0
%approx = tail call <2 x float> @fast_inverse_sqrt_f32x2(<2 x float> %vec)
%result = extractelement <2 x float> %approx, i32 0
%approx = tail call float @llvm.aarch64.neon.frsqrte.f32(float %x)
%approx2 = fmul float %approx, %approx
%correction = tail call float @llvm.aarch64.neon.frsqrts.f32(float %approx2, float %x)
%result = fmul float %approx, %correction
ret float %result
}

@@ -129,9 +139,10 @@ define weak_odr <4 x float> @fast_inverse_sqrt_f32x4(<4 x float> %x) nounwind al
}

define weak_odr half @fast_inverse_sqrt_f16(half %x) nounwind alwaysinline {
%vec = insertelement <4 x half> poison, half %x, i32 0
%approx = tail call <4 x half> @fast_inverse_sqrt_f16x4(<4 x half> %vec)
%result = extractelement <4 x half> %approx, i32 0
%approx = tail call half @llvm.aarch64.neon.frsqrte.f16(half %x)
%approx2 = fmul half %approx, %approx
%correction = tail call half @llvm.aarch64.neon.frsqrts.f16(half %approx2, half %x)
%result = fmul half %approx, %correction
ret half %result
}

@@ -149,4 +160,43 @@ define weak_odr <8 x half> @fast_inverse_sqrt_f16x8(<8 x half> %x) nounwind alwa
%correction = tail call <8 x half> @llvm.aarch64.neon.frsqrts.v8f16(<8 x half> %approx2, <8 x half> %x)
%result = fmul <8 x half> %approx, %correction
ret <8 x half> %result
}
}

declare <vscale x 4 x float> @llvm.aarch64.sve.frecpe.x.nxv4f32(<vscale x 4 x float> %x) nounwind readnone;
declare <vscale x 4 x float> @llvm.aarch64.sve.frsqrte.x.nxv4f32(<vscale x 4 x float> %x) nounwind readnone;
declare <vscale x 4 x float> @llvm.aarch64.sve.frecps.x.nxv4f32(<vscale x 4 x float> %x, <vscale x 4 x float> %y) nounwind readnone;
declare <vscale x 4 x float> @llvm.aarch64.sve.frsqrts.x.nxv4f32(<vscale x 4 x float> %x, <vscale x 4 x float> %y) nounwind readnone;
declare <vscale x 8 x half> @llvm.aarch64.sve.frecpe.x.nxv8f16(<vscale x 8 x half> %x) nounwind readnone;
declare <vscale x 8 x half> @llvm.aarch64.sve.frsqrte.x.nxv8f16(<vscale x 8 x half> %x) nounwind readnone;
declare <vscale x 8 x half> @llvm.aarch64.sve.frecps.x.nxv8f16(<vscale x 8 x half> %x, <vscale x 8 x half> %y) nounwind readnone;
declare <vscale x 8 x half> @llvm.aarch64.sve.frsqrts.x.nxv8f16(<vscale x 8 x half> %x, <vscale x 8 x half> %y) nounwind readnone;

define weak_odr <vscale x 4 x float> @fast_inverse_f32nx4(<vscale x 4 x float> %x) nounwind alwaysinline {
%approx = tail call <vscale x 4 x float> @llvm.aarch64.sve.frecpe.x.nxv4f32(<vscale x 4 x float> %x)
%correction = tail call <vscale x 4 x float> @llvm.aarch64.sve.frecps.x.nxv4f32(<vscale x 4 x float> %approx, <vscale x 4 x float> %x)
%result = fmul <vscale x 4 x float> %approx, %correction
ret <vscale x 4 x float> %result
}

define weak_odr <vscale x 8 x half> @fast_inverse_f16nx8(<vscale x 8 x half> %x) nounwind alwaysinline {
%approx = tail call <vscale x 8 x half> @llvm.aarch64.sve.frecpe.x.nxv8f16(<vscale x 8 x half> %x)
%correction = tail call <vscale x 8 x half> @llvm.aarch64.sve.frecps.x.nxv8f16(<vscale x 8 x half> %approx, <vscale x 8 x half> %x)
%result = fmul <vscale x 8 x half> %approx, %correction
ret <vscale x 8 x half> %result
}

define weak_odr <vscale x 4 x float> @fast_inverse_sqrt_f32nx4(<vscale x 4 x float> %x) nounwind alwaysinline {
%approx = tail call <vscale x 4 x float> @llvm.aarch64.sve.frsqrte.x.nxv4f32(<vscale x 4 x float> %x)
%approx2 = fmul <vscale x 4 x float> %approx, %approx
%correction = tail call <vscale x 4 x float> @llvm.aarch64.sve.frsqrts.x.nxv4f32(<vscale x 4 x float> %approx2, <vscale x 4 x float> %x)
%result = fmul <vscale x 4 x float> %approx, %correction
ret <vscale x 4 x float> %result
}

define weak_odr <vscale x 8 x half> @fast_inverse_sqrt_f16nx8(<vscale x 8 x half> %x) nounwind alwaysinline {
%approx = tail call <vscale x 8 x half> @llvm.aarch64.sve.frsqrte.x.nxv8f16(<vscale x 8 x half> %x)
%approx2 = fmul <vscale x 8 x half> %approx, %approx
%correction = tail call <vscale x 8 x half> @llvm.aarch64.sve.frsqrts.x.nxv8f16(<vscale x 8 x half> %approx2, <vscale x 8 x half> %x)
%result = fmul <vscale x 8 x half> %approx, %correction
ret <vscale x 8 x half> %result
}
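Both the rewritten scalar routines and the new SVE variants above follow the same estimate-plus-refinement pattern: a hardware reciprocal (or reciprocal square root) estimate, then one Newton-Raphson step using the matching "step" instruction. As an editorial note on the math (the FRECPS/FRSQRTS semantics are from the Arm architecture reference; the derivation is not part of the patch):

% Reciprocal: FRECPS(x, a) = 2 - a x, so one refinement of an estimate x_n for 1/a is
x_{n+1} = x_n \cdot \mathrm{FRECPS}(x_n, a) = x_n (2 - a x_n)
% Reciprocal square root: FRSQRTS(x^2, a) = (3 - a x^2)/2, so for 1/\sqrt{a}
x_{n+1} = x_n \cdot \mathrm{FRSQRTS}(x_n^2, a) = \tfrac{1}{2}\, x_n (3 - a x_n^2)
% These are the standard Newton steps for f(x) = 1/x - a and f(x) = 1/x^2 - a respectively.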
8 changes: 8 additions & 0 deletions src/runtime/errors.cpp
@@ -291,4 +291,12 @@ WEAK int halide_error_device_crop_failed(void *user_context) {
return halide_error_code_device_crop_failed;
}

WEAK int halide_error_vscale_invalid(void *user_context, const char *func_name, int runtime_vscale, int compiletime_vscale) {
error(user_context)
<< "The function " << func_name
<< " is compiled with the assumption that vscale of Scalable Vector is " << compiletime_vscale
<< ". However, the detected runtime vscale is " << runtime_vscale << ".";
return halide_error_code_vscale_invalid;
}

} // extern "C"
28 changes: 27 additions & 1 deletion src/runtime/posix_math.ll
@@ -322,4 +322,30 @@ define weak_odr double @neg_inf_f64() nounwind uwtable readnone alwaysinline {

define weak_odr double @nan_f64() nounwind uwtable readnone alwaysinline {
ret double 0x7FF8000000000000
}
}

; For scalable vectors with a non-natural vector size, LLVM doesn't auto-vectorize the scalar versions above
define weak_odr <vscale x 4 x float> @inf_f32nx4() nounwind uwtable readnone alwaysinline {
ret <vscale x 4 x float> shufflevector (<vscale x 4 x float> insertelement (<vscale x 4 x float> undef, float 0x7FF0000000000000, i32 0), <vscale x 4 x float> undef, <vscale x 4 x i32> zeroinitializer)
}

define weak_odr <vscale x 4 x float> @neg_inf_f32nx4() nounwind uwtable readnone alwaysinline {
ret <vscale x 4 x float> shufflevector (<vscale x 4 x float> insertelement (<vscale x 4 x float> undef, float 0xFFF0000000000000, i32 0), <vscale x 4 x float> undef, <vscale x 4 x i32> zeroinitializer)
}

define weak_odr <vscale x 4 x float> @nan_f32nx4() nounwind uwtable readnone alwaysinline {
ret <vscale x 4 x float> shufflevector (<vscale x 4 x float> insertelement (<vscale x 4 x float> undef, float 0x7FF8000000000000, i32 0), <vscale x 4 x float> undef, <vscale x 4 x i32> zeroinitializer)
}


define weak_odr <vscale x 2 x double> @inf_f64nx2() nounwind uwtable readnone alwaysinline {
ret <vscale x 2 x double> shufflevector (<vscale x 2 x double> insertelement (<vscale x 2 x double> undef, double 0x7FF0000000000000, i32 0), <vscale x 2 x double> undef, <vscale x 2 x i32> zeroinitializer)
}

define weak_odr <vscale x 2 x double> @neg_inf_f64nx2() nounwind uwtable readnone alwaysinline {
ret <vscale x 2 x double> shufflevector (<vscale x 2 x double> insertelement (<vscale x 2 x double> undef, double 0xFFF0000000000000, i32 0), <vscale x 2 x double> undef, <vscale x 2 x i32> zeroinitializer)
}

define weak_odr <vscale x 2 x double> @nan_f64nx2() nounwind uwtable readnone alwaysinline {
ret <vscale x 2 x double> shufflevector (<vscale x 2 x double> insertelement (<vscale x 2 x double> undef, double 0x7FF8000000000000, i32 0), <vscale x 2 x double> undef, <vscale x 2 x i32> zeroinitializer)
}
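The scalable-vector constants above use the canonical LLVM splat idiom: insert the scalar into lane 0 of an undef vector, then shufflevector with an all-zero mask so lane 0 is broadcast across every (runtime-sized) lane. For reference, the same splat can be produced from C++ through IRBuilder — a hedged illustration, not code from this PR:

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"

// Hypothetical helper: splat +inf across a <vscale x 4 x float> value.
llvm::Value *splat_inf_f32nx4(llvm::IRBuilderBase &builder, llvm::LLVMContext &ctx) {
    llvm::Constant *inf =
        llvm::ConstantFP::getInfinity(llvm::Type::getFloatTy(ctx), /*Negative=*/false);
    return builder.CreateVectorSplat(llvm::ElementCount::getScalable(4), inf);
}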
1 change: 1 addition & 0 deletions src/runtime/runtime_api.cpp
@@ -89,6 +89,7 @@ extern "C" __attribute__((used)) void *halide_runtime_api_functions[] = {
(void *)&halide_error_unaligned_host_ptr,
(void *)&halide_error_storage_bound_too_small,
(void *)&halide_error_device_crop_failed,
(void *)&halide_error_vscale_invalid,
(void *)&halide_float16_bits_to_double,
(void *)&halide_float16_bits_to_float,
(void *)&halide_free,
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
@@ -276,6 +276,7 @@ tests(GROUPS correctness
simd_op_check_hvx.cpp
simd_op_check_powerpc.cpp
simd_op_check_riscv.cpp
simd_op_check_sve2.cpp
simd_op_check_wasm.cpp
simd_op_check_x86.cpp
simplified_away_embedded_image.cpp