-
Notifications
You must be signed in to change notification settings - Fork 745
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ESIMD] Implement the new non-experimental low-level API for DPAS (#6834
) * The new DPAS API are added to the new esimd::xmx (Xe Matrix eXtension) namespace. * The old/experimental DPAS API is marked as deprecated and now it simply calls the new DPAS API. * The DPAS emulation sequences has got the automatic detection of the execution size instead of being defined through the macro ESIMD_XE_HPC. Signed-off-by: Vyacheslav N Klochkov <vyacheslav.n.klochkov@intel.com>
- Loading branch information
1 parent
f9d8059
commit 55bf1a0
Showing
7 changed files
with
454 additions
and
369 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
//==-------------- xmx/common.hpp - DPC++ Explicit SIMD API ----------------==// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// Explicit SIMD API types used in ESIMD Intel Xe Matrix eXtension. | ||
//===----------------------------------------------------------------------===// | ||
|
||
#pragma once | ||
|
||
#include <sycl/detail/defines_elementary.hpp> | ||
|
||
namespace sycl { | ||
__SYCL_INLINE_VER_NAMESPACE(_V1) { | ||
namespace ext::intel::esimd::xmx { | ||
|
||
enum class dpas_argument_type { | ||
Invalid = 0, | ||
u1 = 1, // unsigned 1 bit | ||
U1 __SYCL_DEPRECATED("use u1") = u1, | ||
s1 = 2, // signed 1 bit | ||
S1 __SYCL_DEPRECATED("use s1") = s1, | ||
u2 = 3, // unsigned 2 bits | ||
U2 __SYCL_DEPRECATED("use u2") = u2, | ||
s2 = 4, // signed 2 bits | ||
S2 __SYCL_DEPRECATED("use s2") = s2, | ||
u4 = 5, // unsigned 4 bits | ||
U4 __SYCL_DEPRECATED("use u4") = u4, | ||
s4 = 6, // signed 4 bits | ||
S4 __SYCL_DEPRECATED("use s4") = s4, | ||
u8 = 7, // unsigned 8 bits | ||
U8 __SYCL_DEPRECATED("use u8") = u8, | ||
s8 = 8, // signed 8 bits | ||
S8 __SYCL_DEPRECATED("use s8") = s8, | ||
bf16 = 9, // bfloat 16 | ||
BF16 __SYCL_DEPRECATED("use bf16") = bf16, | ||
fp16 = 10, // half float | ||
FP16 __SYCL_DEPRECATED("use fp16") = fp16, | ||
tf32 = 12, // tensorfloat 32 | ||
TF32 __SYCL_DEPRECATED("use tf32") = tf32 | ||
}; | ||
|
||
} // namespace ext::intel::esimd::xmx | ||
} // __SYCL_INLINE_VER_NAMESPACE(_V1) | ||
} // namespace sycl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,340 @@ | ||
//==----------------- xmx/dpas.hpp - DPC++ Explicit SIMD API ---------------==// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// Explicit SIMD API for DPAS Intel Xe Matrix eXtension. | ||
//===----------------------------------------------------------------------===// | ||
|
||
#pragma once | ||
|
||
#include <sycl/detail/defines_elementary.hpp> | ||
#include <sycl/ext/intel/esimd/detail/types.hpp> | ||
#include <sycl/ext/intel/esimd/xmx/common.hpp> | ||
#include <sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp> | ||
#include <sycl/ext/oneapi/experimental/bfloat16.hpp> | ||
|
||
namespace sycl { | ||
__SYCL_INLINE_VER_NAMESPACE(_V1) { | ||
|
||
namespace ext::intel::esimd::xmx { | ||
|
||
namespace detail { | ||
|
||
template <typename T> constexpr dpas_argument_type dpas_precision_from_type() { | ||
// TODO: add support for tfloat32 here. | ||
if constexpr (std::is_same_v<T, sycl::half>) | ||
return dpas_argument_type::FP16; | ||
else if constexpr (std::is_same_v<T, | ||
sycl::ext::oneapi::experimental::bfloat16>) | ||
return dpas_argument_type::BF16; | ||
else if constexpr (std::is_same_v<T, unsigned char>) | ||
return dpas_argument_type::U8; | ||
else if constexpr (__ESIMD_DNS::is_type<T, char, signed char>()) | ||
return dpas_argument_type::S8; | ||
else | ||
return dpas_argument_type::Invalid; | ||
} | ||
|
||
template <dpas_argument_type T> constexpr int dpas_bitsize_from_precision() { | ||
if constexpr (T == dpas_argument_type::U2 || T == dpas_argument_type::S2) | ||
return 2; | ||
else if constexpr (T == dpas_argument_type::U4 || T == dpas_argument_type::S4) | ||
return 4; | ||
else if constexpr (T == dpas_argument_type::U8 || T == dpas_argument_type::S8) | ||
return 8; | ||
else if constexpr (T == dpas_argument_type::BF16 || | ||
T == dpas_argument_type::FP16) | ||
return 16; | ||
else if constexpr (T == dpas_argument_type::TF32) | ||
return 32; | ||
else | ||
return -1; | ||
} | ||
|
||
template <int RepeatCount, int AElemBitSize, int BElemBitSize, bool IsDPASW> | ||
constexpr void verify_repeat_count() { | ||
static_assert(RepeatCount >= 1 && RepeatCount <= 8, | ||
"Repeat count must be within 1 to 8 range"); | ||
|
||
if constexpr (IsDPASW && RepeatCount != 8) { | ||
static_assert(!(AElemBitSize == 2 && BElemBitSize > 4), | ||
"Unsupported repeat count for DPASW operation"); | ||
|
||
static_assert( | ||
RepeatCount == 4 || | ||
(AElemBitSize != 2 && (AElemBitSize != 4 || BElemBitSize <= 4)), | ||
"Unsupported repeat count for DPASW operation"); | ||
} | ||
} | ||
|
||
template <int SystolicDepth, int RepeatCount, typename T, typename CT, | ||
typename BT, typename AT, dpas_argument_type BPrecision, | ||
dpas_argument_type APrecision, int BN, int AN, bool IsDPASW = false> | ||
constexpr int verify_parameters_and_deduce_exec_size() { | ||
|
||
static_assert(SystolicDepth == 8, "Systolic depth must be equal to 8"); | ||
static_assert( | ||
APrecision != dpas_argument_type::Invalid && | ||
BPrecision != dpas_argument_type::Invalid, | ||
"The types of dpas arguments are either incorrect or cannot be deduced." | ||
"Fix the types and/or explicitly specify them."); | ||
|
||
constexpr int AElemBitSize = dpas_bitsize_from_precision<APrecision>(); | ||
constexpr int BElemBitSize = dpas_bitsize_from_precision<BPrecision>(); | ||
static_assert(AElemBitSize != -1 && BElemBitSize != -1, | ||
"Cannot deduce element size of input arguments"); | ||
verify_repeat_count<RepeatCount, AElemBitSize, BElemBitSize, IsDPASW>(); | ||
|
||
constexpr int OpsPerChannel = | ||
std::min(32 / std::max(AElemBitSize, BElemBitSize), 8); | ||
|
||
// A(_Mx_K) * B(_Kx_N) + C(_Mx_N) | ||
// where: | ||
// _M = RepeatCount; | ||
// _K = SystolicDepth * OpsPerChannel; | ||
// _N = ExecutionSize (unknown, but deducible), must be 8 or 16. | ||
constexpr int _M = RepeatCount; | ||
constexpr int _K = SystolicDepth * OpsPerChannel; | ||
|
||
// Compute _N (aka ExecutionSize) from the matrix B. | ||
// It has _K*_N elements of BPrecision type, and BN elements of BT type | ||
// hold those _K*_N*BPrecision bits, which let's us compute _N. | ||
constexpr int BMatrixBitSize = sizeof(BT) * BN * 8; | ||
constexpr int BNumElems = BMatrixBitSize / BElemBitSize; | ||
constexpr int _N = BNumElems / _K; | ||
static_assert(_K * _N == BNumElems, "Cannot deduce the execution size."); | ||
|
||
// Now verify that AN elements of AT type hold exactly _M*_K elements | ||
// of APrecision type/size. Similarly for B: BN elements of BT type must | ||
// hold _K*_N elements of BPrecision type/size. | ||
// DPASW accepts 2x less expected AN elements than regular DPAS. | ||
constexpr int AFactorForDPASW = IsDPASW ? 2 : 1; | ||
static_assert(_M * _K * AElemBitSize == AN * sizeof(AT) * 8 * AFactorForDPASW, | ||
"The first matrix multiplier has wrong size."); | ||
static_assert(_K * _N * BElemBitSize == BN * sizeof(BT) * 8, | ||
"The second matrix multiplier has wrong size."); | ||
|
||
// Execution size may be 8 or 16 depending on the target device. | ||
// User must check if used execution size is supported before calling DPAS. | ||
constexpr int ExecutionSize = _N; | ||
|
||
static_assert(ExecutionSize == 8 || (!IsDPASW && ExecutionSize == 16), | ||
"Execution size must be 8 or 16 for DPAS and 8 for DPASW."); | ||
|
||
if constexpr (APrecision == dpas_argument_type::FP16 || | ||
BPrecision == dpas_argument_type::FP16) { | ||
if constexpr (ExecutionSize == 8) { | ||
static_assert(APrecision == BPrecision && | ||
__ESIMD_DNS::is_type<T, float>() && | ||
__ESIMD_DNS::is_type<CT, float>(), | ||
"Unsupported DPAS types! The supported types are:\n" | ||
" Result | C | B | A \n" | ||
" f | f | hf | hf \n"); | ||
} else { | ||
static_assert(APrecision == BPrecision && | ||
__ESIMD_DNS::is_type<T, float, sycl::half>() && | ||
__ESIMD_DNS::is_type<CT, float, sycl::half>(), | ||
"Unsupported DPAS types! The supported types are:\n" | ||
" Result | C | B | A \n" | ||
" f, hf | f, hf | hf | hf \n"); | ||
} | ||
} else if constexpr (APrecision == dpas_argument_type::BF16 || | ||
BPrecision == dpas_argument_type::BF16) { | ||
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; | ||
if constexpr (ExecutionSize == 8) { | ||
static_assert(APrecision == BPrecision && | ||
__ESIMD_DNS::is_type<T, float, bfloat16>() && | ||
__ESIMD_DNS::is_type<CT, float, bfloat16>(), | ||
"Unsupported DPAS types! The supported types are:\n" | ||
" Result | C | B | A \n" | ||
" f | f | bf | bf \n"); | ||
} else { | ||
static_assert(APrecision == BPrecision && | ||
__ESIMD_DNS::is_type<T, float, bfloat16>() && | ||
__ESIMD_DNS::is_type<CT, float, bfloat16>(), | ||
"Unsupported DPAS types! The supported types are:\n" | ||
" Result | C | B | A \n" | ||
" f, bf | f, bf | bf | bf \n"); | ||
} | ||
} else if constexpr (APrecision == dpas_argument_type::TF32 || | ||
BPrecision == dpas_argument_type::TF32) { | ||
static_assert(ExecutionSize == 16, | ||
"tf32 type can be used only with ExecutionSize=16"); | ||
static_assert(APrecision == BPrecision && std::is_same_v<T, float> && | ||
std::is_same_v<CT, float>, | ||
"Unsupported DPAS types! The supported types are:\n" | ||
" Result | C | B | A \n" | ||
" f | f | tf32 | tf32 \n"); | ||
} else { | ||
static_assert((APrecision == dpas_argument_type::U2 || | ||
APrecision == dpas_argument_type::S2 || | ||
APrecision == dpas_argument_type::U4 || | ||
APrecision == dpas_argument_type::S4 || | ||
APrecision == dpas_argument_type::U8 || | ||
APrecision == dpas_argument_type::S8) && | ||
(BPrecision == dpas_argument_type::U2 || | ||
BPrecision == dpas_argument_type::S2 || | ||
BPrecision == dpas_argument_type::U4 || | ||
BPrecision == dpas_argument_type::S4 || | ||
BPrecision == dpas_argument_type::U8 || | ||
BPrecision == dpas_argument_type::S8), | ||
"Unsupported DPAS types! The supported types are:\n" | ||
" Result | C | B | A \n" | ||
" ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n"); | ||
} | ||
return ExecutionSize; | ||
} | ||
|
||
} // namespace detail | ||
|
||
/// @defgroup sycl_esimd_xmx_systolic_array_api Systolic Array APIs. | ||
/// APIs below are used to implement dot product accumulate systolic functions | ||
/// @ingroup sycl_esimd | ||
|
||
/// @addtogroup sycl_esimd_xmx_systolic_array_api | ||
/// @{ | ||
/// DPAS (Dot Product Accumulate Systolic) | ||
/// Computes the result of matrix operations: Result = C + A x B; | ||
/// @param C represents DPAS accumulator operand. | ||
/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded | ||
/// layout. | ||
/// @param A represents the 1st matrix multiplier. | ||
/// @return the vector value of DPAS computation result. | ||
template < | ||
int SystolicDepth, int RepeatCount, typename T, typename CT, typename BT, | ||
typename AT, | ||
dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(), | ||
dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(), | ||
int N, int BN, int AN> | ||
__ESIMD_NS::simd<T, N> dpas(__ESIMD_NS::simd<CT, N> C, | ||
__ESIMD_NS::simd<BT, BN> B, | ||
__ESIMD_NS::simd<AT, AN> A) { | ||
(void)detail::verify_parameters_and_deduce_exec_size< | ||
SystolicDepth, RepeatCount, T, CT, BT, AT, BPrecision, APrecision, BN, | ||
AN>(); | ||
|
||
constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); | ||
constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); | ||
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>(); | ||
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>(); | ||
using CRawT = typename __ESIMD_NS::simd<CT, N>::raw_element_type; | ||
return __esimd_dpas2<BPrecision, APrecision, SystolicDepth, RepeatCount, T, | ||
CRawT, int, int, N, BNCasted, ANCasted>( | ||
C.data(), BCasted.data(), ACasted.data()); | ||
} | ||
|
||
/// DPAS (Dot Product Accumulate Systolic) | ||
/// Computes the result of matrix operations: Result = A x B; | ||
/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded | ||
/// layout. | ||
/// @param A represents the 1st matrix multiplier. | ||
/// @return the vector value of DPAS computation result. | ||
template < | ||
int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT, | ||
dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(), | ||
dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(), | ||
int BN, int AN> | ||
auto dpas(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) { | ||
|
||
constexpr int ExecutionSize = | ||
detail::verify_parameters_and_deduce_exec_size<SystolicDepth, RepeatCount, | ||
T, T, BT, AT, BPrecision, | ||
APrecision, BN, AN>(); | ||
// Result(_Mx_N) = A(_Mx_K) * B(_Kx_N) | ||
// where: | ||
// _M = RepeatCount; | ||
// _K = SystolicDepth * OpsPerChannel; | ||
// _N = ExecutionSize (unknown, but deducible), must be 8 or 16. | ||
constexpr int ResultN = RepeatCount * ExecutionSize; | ||
|
||
constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); | ||
constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); | ||
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>(); | ||
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>(); | ||
|
||
constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + | ||
((int)APrecision << 8) + (int)BPrecision; | ||
__ESIMD_NS::simd<T, ResultN> Result = | ||
__esimd_dpas_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>( | ||
BCasted.data(), ACasted.data()); | ||
return Result; | ||
} | ||
|
||
/// DPAS (Dot Product Accumulate Systolic) | ||
/// Computes the result of matrix operations: Result = C + A x B; | ||
/// @param C represents DPAS accumulator operand. | ||
/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded | ||
/// layout. | ||
/// @param A represents the 1st matrix multiplier. | ||
/// @return the vector value of DPAS computation result. | ||
template < | ||
int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT, | ||
dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(), | ||
dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(), | ||
int N, int BN, int AN> | ||
__ESIMD_NS::simd<T, N> dpasw(__ESIMD_NS::simd<T, N> C, | ||
__ESIMD_NS::simd<BT, BN> B, | ||
__ESIMD_NS::simd<AT, AN> A) { | ||
|
||
constexpr bool IsDPASW = true; | ||
(void)detail::verify_parameters_and_deduce_exec_size< | ||
SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN, | ||
IsDPASW>(); | ||
|
||
constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); | ||
constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); | ||
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>(); | ||
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>(); | ||
|
||
constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + | ||
((int)APrecision << 8) + (int)BPrecision; | ||
return __esimd_dpasw<Info, T, int, int, N, BNCasted, ANCasted>( | ||
C.data(), BCasted.data(), ACasted.data()); | ||
} | ||
|
||
/// DPAS (Dot Product Accumulate Systolic) | ||
/// Computes the result of matrix operations: Result = A x B; | ||
/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded | ||
/// layout. | ||
/// @param A represents the 1st matrix multiplier. | ||
/// @return the vector value of DPAS computation result. | ||
template < | ||
int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT, | ||
dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(), | ||
dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(), | ||
int BN, int AN> | ||
auto dpasw(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) { | ||
|
||
constexpr bool IsDPASW = true; | ||
constexpr int ExecutionSize = detail::verify_parameters_and_deduce_exec_size< | ||
SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN, | ||
IsDPASW>(); | ||
|
||
// Result(_Mx_N) = A(_Mx_K) * B(_Kx_N) | ||
// where: | ||
// _M = RepeatCount; | ||
// _K = SystolicDepth * OpsPerChannel; | ||
// _N = ExecutionSize (unknown, but deducible), must be 8 or 16. | ||
constexpr int ResultN = RepeatCount * ExecutionSize; | ||
|
||
constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); | ||
constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); | ||
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>(); | ||
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>(); | ||
|
||
constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + | ||
((int)APrecision << 8) + (int)BPrecision; | ||
__ESIMD_NS::simd<T, ResultN> Result = | ||
__esimd_dpasw_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>( | ||
BCasted.data(), ACasted.data()); | ||
return Result; | ||
} | ||
|
||
/// @} sycl_esimd_xmx_systolic_array_api | ||
|
||
} // namespace ext::intel::esimd::xmx | ||
} // __SYCL_INLINE_VER_NAMESPACE(_V1) | ||
} // namespace sycl |
Oops, something went wrong.