Skip to content

Commit

Permalink
[ESIMD] Implement the new non-experimental low-level API for DPAS (#6834
Browse files Browse the repository at this point in the history
)

* The new DPAS API are added to the new esimd::xmx (Xe Matrix eXtension)
namespace.
* The old/experimental DPAS API is marked as deprecated and now it
simply calls the new DPAS API.
* The DPAS emulation sequences has got the automatic detection of
the execution size instead of being defined through the macro
ESIMD_XE_HPC.

Signed-off-by: Vyacheslav N Klochkov <vyacheslav.n.klochkov@intel.com>
  • Loading branch information
v-klochkov authored Sep 22, 2022
1 parent f9d8059 commit 55bf1a0
Show file tree
Hide file tree
Showing 7 changed files with 454 additions and 369 deletions.
1 change: 1 addition & 0 deletions sycl/include/sycl/ext/intel/esimd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
#include <sycl/ext/intel/esimd/detail/half_type_traits.hpp>
#include <sycl/ext/intel/esimd/simd.hpp>
#include <sycl/ext/intel/esimd/simd_view.hpp>
#include <sycl/ext/intel/esimd/xmx/dpas.hpp>
#include <sycl/ext/intel/experimental/esimd/kernel_properties.hpp>
#include <sycl/ext/intel/experimental/esimd/math.hpp>
#include <sycl/ext/intel/experimental/esimd/memory.hpp>
Expand Down
47 changes: 47 additions & 0 deletions sycl/include/sycl/ext/intel/esimd/xmx/common.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//==-------------- xmx/common.hpp - DPC++ Explicit SIMD API ----------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Explicit SIMD API types used in ESIMD Intel Xe Matrix eXtension.
//===----------------------------------------------------------------------===//

#pragma once

#include <sycl/detail/defines_elementary.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext::intel::esimd::xmx {

enum class dpas_argument_type {
Invalid = 0,
u1 = 1, // unsigned 1 bit
U1 __SYCL_DEPRECATED("use u1") = u1,
s1 = 2, // signed 1 bit
S1 __SYCL_DEPRECATED("use s1") = s1,
u2 = 3, // unsigned 2 bits
U2 __SYCL_DEPRECATED("use u2") = u2,
s2 = 4, // signed 2 bits
S2 __SYCL_DEPRECATED("use s2") = s2,
u4 = 5, // unsigned 4 bits
U4 __SYCL_DEPRECATED("use u4") = u4,
s4 = 6, // signed 4 bits
S4 __SYCL_DEPRECATED("use s4") = s4,
u8 = 7, // unsigned 8 bits
U8 __SYCL_DEPRECATED("use u8") = u8,
s8 = 8, // signed 8 bits
S8 __SYCL_DEPRECATED("use s8") = s8,
bf16 = 9, // bfloat 16
BF16 __SYCL_DEPRECATED("use bf16") = bf16,
fp16 = 10, // half float
FP16 __SYCL_DEPRECATED("use fp16") = fp16,
tf32 = 12, // tensorfloat 32
TF32 __SYCL_DEPRECATED("use tf32") = tf32
};

} // namespace ext::intel::esimd::xmx
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
340 changes: 340 additions & 0 deletions sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,340 @@
//==----------------- xmx/dpas.hpp - DPC++ Explicit SIMD API ---------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Explicit SIMD API for DPAS Intel Xe Matrix eXtension.
//===----------------------------------------------------------------------===//

#pragma once

#include <sycl/detail/defines_elementary.hpp>
#include <sycl/ext/intel/esimd/detail/types.hpp>
#include <sycl/ext/intel/esimd/xmx/common.hpp>
#include <sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp>
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {

namespace ext::intel::esimd::xmx {

namespace detail {

template <typename T> constexpr dpas_argument_type dpas_precision_from_type() {
// TODO: add support for tfloat32 here.
if constexpr (std::is_same_v<T, sycl::half>)
return dpas_argument_type::FP16;
else if constexpr (std::is_same_v<T,
sycl::ext::oneapi::experimental::bfloat16>)
return dpas_argument_type::BF16;
else if constexpr (std::is_same_v<T, unsigned char>)
return dpas_argument_type::U8;
else if constexpr (__ESIMD_DNS::is_type<T, char, signed char>())
return dpas_argument_type::S8;
else
return dpas_argument_type::Invalid;
}

template <dpas_argument_type T> constexpr int dpas_bitsize_from_precision() {
if constexpr (T == dpas_argument_type::U2 || T == dpas_argument_type::S2)
return 2;
else if constexpr (T == dpas_argument_type::U4 || T == dpas_argument_type::S4)
return 4;
else if constexpr (T == dpas_argument_type::U8 || T == dpas_argument_type::S8)
return 8;
else if constexpr (T == dpas_argument_type::BF16 ||
T == dpas_argument_type::FP16)
return 16;
else if constexpr (T == dpas_argument_type::TF32)
return 32;
else
return -1;
}

template <int RepeatCount, int AElemBitSize, int BElemBitSize, bool IsDPASW>
constexpr void verify_repeat_count() {
static_assert(RepeatCount >= 1 && RepeatCount <= 8,
"Repeat count must be within 1 to 8 range");

if constexpr (IsDPASW && RepeatCount != 8) {
static_assert(!(AElemBitSize == 2 && BElemBitSize > 4),
"Unsupported repeat count for DPASW operation");

static_assert(
RepeatCount == 4 ||
(AElemBitSize != 2 && (AElemBitSize != 4 || BElemBitSize <= 4)),
"Unsupported repeat count for DPASW operation");
}
}

template <int SystolicDepth, int RepeatCount, typename T, typename CT,
typename BT, typename AT, dpas_argument_type BPrecision,
dpas_argument_type APrecision, int BN, int AN, bool IsDPASW = false>
constexpr int verify_parameters_and_deduce_exec_size() {

static_assert(SystolicDepth == 8, "Systolic depth must be equal to 8");
static_assert(
APrecision != dpas_argument_type::Invalid &&
BPrecision != dpas_argument_type::Invalid,
"The types of dpas arguments are either incorrect or cannot be deduced."
"Fix the types and/or explicitly specify them.");

constexpr int AElemBitSize = dpas_bitsize_from_precision<APrecision>();
constexpr int BElemBitSize = dpas_bitsize_from_precision<BPrecision>();
static_assert(AElemBitSize != -1 && BElemBitSize != -1,
"Cannot deduce element size of input arguments");
verify_repeat_count<RepeatCount, AElemBitSize, BElemBitSize, IsDPASW>();

constexpr int OpsPerChannel =
std::min(32 / std::max(AElemBitSize, BElemBitSize), 8);

// A(_Mx_K) * B(_Kx_N) + C(_Mx_N)
// where:
// _M = RepeatCount;
// _K = SystolicDepth * OpsPerChannel;
// _N = ExecutionSize (unknown, but deducible), must be 8 or 16.
constexpr int _M = RepeatCount;
constexpr int _K = SystolicDepth * OpsPerChannel;

// Compute _N (aka ExecutionSize) from the matrix B.
// It has _K*_N elements of BPrecision type, and BN elements of BT type
// hold those _K*_N*BPrecision bits, which let's us compute _N.
constexpr int BMatrixBitSize = sizeof(BT) * BN * 8;
constexpr int BNumElems = BMatrixBitSize / BElemBitSize;
constexpr int _N = BNumElems / _K;
static_assert(_K * _N == BNumElems, "Cannot deduce the execution size.");

// Now verify that AN elements of AT type hold exactly _M*_K elements
// of APrecision type/size. Similarly for B: BN elements of BT type must
// hold _K*_N elements of BPrecision type/size.
// DPASW accepts 2x less expected AN elements than regular DPAS.
constexpr int AFactorForDPASW = IsDPASW ? 2 : 1;
static_assert(_M * _K * AElemBitSize == AN * sizeof(AT) * 8 * AFactorForDPASW,
"The first matrix multiplier has wrong size.");
static_assert(_K * _N * BElemBitSize == BN * sizeof(BT) * 8,
"The second matrix multiplier has wrong size.");

// Execution size may be 8 or 16 depending on the target device.
// User must check if used execution size is supported before calling DPAS.
constexpr int ExecutionSize = _N;

static_assert(ExecutionSize == 8 || (!IsDPASW && ExecutionSize == 16),
"Execution size must be 8 or 16 for DPAS and 8 for DPASW.");

if constexpr (APrecision == dpas_argument_type::FP16 ||
BPrecision == dpas_argument_type::FP16) {
if constexpr (ExecutionSize == 8) {
static_assert(APrecision == BPrecision &&
__ESIMD_DNS::is_type<T, float>() &&
__ESIMD_DNS::is_type<CT, float>(),
"Unsupported DPAS types! The supported types are:\n"
" Result | C | B | A \n"
" f | f | hf | hf \n");
} else {
static_assert(APrecision == BPrecision &&
__ESIMD_DNS::is_type<T, float, sycl::half>() &&
__ESIMD_DNS::is_type<CT, float, sycl::half>(),
"Unsupported DPAS types! The supported types are:\n"
" Result | C | B | A \n"
" f, hf | f, hf | hf | hf \n");
}
} else if constexpr (APrecision == dpas_argument_type::BF16 ||
BPrecision == dpas_argument_type::BF16) {
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
if constexpr (ExecutionSize == 8) {
static_assert(APrecision == BPrecision &&
__ESIMD_DNS::is_type<T, float, bfloat16>() &&
__ESIMD_DNS::is_type<CT, float, bfloat16>(),
"Unsupported DPAS types! The supported types are:\n"
" Result | C | B | A \n"
" f | f | bf | bf \n");
} else {
static_assert(APrecision == BPrecision &&
__ESIMD_DNS::is_type<T, float, bfloat16>() &&
__ESIMD_DNS::is_type<CT, float, bfloat16>(),
"Unsupported DPAS types! The supported types are:\n"
" Result | C | B | A \n"
" f, bf | f, bf | bf | bf \n");
}
} else if constexpr (APrecision == dpas_argument_type::TF32 ||
BPrecision == dpas_argument_type::TF32) {
static_assert(ExecutionSize == 16,
"tf32 type can be used only with ExecutionSize=16");
static_assert(APrecision == BPrecision && std::is_same_v<T, float> &&
std::is_same_v<CT, float>,
"Unsupported DPAS types! The supported types are:\n"
" Result | C | B | A \n"
" f | f | tf32 | tf32 \n");
} else {
static_assert((APrecision == dpas_argument_type::U2 ||
APrecision == dpas_argument_type::S2 ||
APrecision == dpas_argument_type::U4 ||
APrecision == dpas_argument_type::S4 ||
APrecision == dpas_argument_type::U8 ||
APrecision == dpas_argument_type::S8) &&
(BPrecision == dpas_argument_type::U2 ||
BPrecision == dpas_argument_type::S2 ||
BPrecision == dpas_argument_type::U4 ||
BPrecision == dpas_argument_type::S4 ||
BPrecision == dpas_argument_type::U8 ||
BPrecision == dpas_argument_type::S8),
"Unsupported DPAS types! The supported types are:\n"
" Result | C | B | A \n"
" ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n");
}
return ExecutionSize;
}

} // namespace detail

/// @defgroup sycl_esimd_xmx_systolic_array_api Systolic Array APIs.
/// APIs below are used to implement dot product accumulate systolic functions
/// @ingroup sycl_esimd

/// @addtogroup sycl_esimd_xmx_systolic_array_api
/// @{
/// DPAS (Dot Product Accumulate Systolic)
/// Computes the result of matrix operations: Result = C + A x B;
/// @param C represents DPAS accumulator operand.
/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded
/// layout.
/// @param A represents the 1st matrix multiplier.
/// @return the vector value of DPAS computation result.
template <
int SystolicDepth, int RepeatCount, typename T, typename CT, typename BT,
typename AT,
dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
int N, int BN, int AN>
__ESIMD_NS::simd<T, N> dpas(__ESIMD_NS::simd<CT, N> C,
__ESIMD_NS::simd<BT, BN> B,
__ESIMD_NS::simd<AT, AN> A) {
(void)detail::verify_parameters_and_deduce_exec_size<
SystolicDepth, RepeatCount, T, CT, BT, AT, BPrecision, APrecision, BN,
AN>();

constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT));
constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT));
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();
using CRawT = typename __ESIMD_NS::simd<CT, N>::raw_element_type;
return __esimd_dpas2<BPrecision, APrecision, SystolicDepth, RepeatCount, T,
CRawT, int, int, N, BNCasted, ANCasted>(
C.data(), BCasted.data(), ACasted.data());
}

/// DPAS (Dot Product Accumulate Systolic)
/// Computes the result of matrix operations: Result = A x B;
/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded
/// layout.
/// @param A represents the 1st matrix multiplier.
/// @return the vector value of DPAS computation result.
template <
int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT,
dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
int BN, int AN>
auto dpas(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {

constexpr int ExecutionSize =
detail::verify_parameters_and_deduce_exec_size<SystolicDepth, RepeatCount,
T, T, BT, AT, BPrecision,
APrecision, BN, AN>();
// Result(_Mx_N) = A(_Mx_K) * B(_Kx_N)
// where:
// _M = RepeatCount;
// _K = SystolicDepth * OpsPerChannel;
// _N = ExecutionSize (unknown, but deducible), must be 8 or 16.
constexpr int ResultN = RepeatCount * ExecutionSize;

constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT));
constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT));
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();

constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
((int)APrecision << 8) + (int)BPrecision;
__ESIMD_NS::simd<T, ResultN> Result =
__esimd_dpas_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>(
BCasted.data(), ACasted.data());
return Result;
}

/// DPAS (Dot Product Accumulate Systolic)
/// Computes the result of matrix operations: Result = C + A x B;
/// @param C represents DPAS accumulator operand.
/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded
/// layout.
/// @param A represents the 1st matrix multiplier.
/// @return the vector value of DPAS computation result.
template <
int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT,
dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
int N, int BN, int AN>
__ESIMD_NS::simd<T, N> dpasw(__ESIMD_NS::simd<T, N> C,
__ESIMD_NS::simd<BT, BN> B,
__ESIMD_NS::simd<AT, AN> A) {

constexpr bool IsDPASW = true;
(void)detail::verify_parameters_and_deduce_exec_size<
SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN,
IsDPASW>();

constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT));
constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT));
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();

constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
((int)APrecision << 8) + (int)BPrecision;
return __esimd_dpasw<Info, T, int, int, N, BNCasted, ANCasted>(
C.data(), BCasted.data(), ACasted.data());
}

/// DPAS (Dot Product Accumulate Systolic)
/// Computes the result of matrix operations: Result = A x B;
/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded
/// layout.
/// @param A represents the 1st matrix multiplier.
/// @return the vector value of DPAS computation result.
template <
int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT,
dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
int BN, int AN>
auto dpasw(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {

constexpr bool IsDPASW = true;
constexpr int ExecutionSize = detail::verify_parameters_and_deduce_exec_size<
SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN,
IsDPASW>();

// Result(_Mx_N) = A(_Mx_K) * B(_Kx_N)
// where:
// _M = RepeatCount;
// _K = SystolicDepth * OpsPerChannel;
// _N = ExecutionSize (unknown, but deducible), must be 8 or 16.
constexpr int ResultN = RepeatCount * ExecutionSize;

constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT));
constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT));
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();

constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
((int)APrecision << 8) + (int)BPrecision;
__ESIMD_NS::simd<T, ResultN> Result =
__esimd_dpasw_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>(
BCasted.data(), ACasted.data());
return Result;
}

/// @} sycl_esimd_xmx_systolic_array_api

} // namespace ext::intel::esimd::xmx
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
Loading

0 comments on commit 55bf1a0

Please sign in to comment.