[SYCL] Reintroduce experimental bfloat16 math functions #7567

Merged (3 commits) on Nov 30, 2022
45 changes: 34 additions & 11 deletions sycl/include/sycl/ext/oneapi/bfloat16.hpp
@@ -25,9 +25,21 @@ __SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext {
namespace oneapi {

class bfloat16;

namespace detail {
using Bfloat16StorageT = uint16_t;
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value);
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value);
} // namespace detail

class bfloat16 {
using storage_t = uint16_t;
storage_t value;
detail::Bfloat16StorageT value;

friend inline detail::Bfloat16StorageT
detail::bfloat16ToBits(const bfloat16 &Value);
friend inline bfloat16
detail::bitsToBfloat16(const detail::Bfloat16StorageT Value);

public:
bfloat16() = default;
@@ -36,7 +48,7 @@ class bfloat16 {

private:
// Explicit conversion functions
static storage_t from_float(const float &a) {
static detail::Bfloat16StorageT from_float(const float &a) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
#if (__CUDA_ARCH__ >= 800)
@@ -72,7 +84,7 @@
#endif
}

static float to_float(const storage_t &a) {
static float to_float(const detail::Bfloat16StorageT &a) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
return __devicelib_ConvertBF16ToFINTEL(a);
#else
@@ -85,12 +97,6 @@ class bfloat16 {
#endif
}

static bfloat16 from_bits(const storage_t &a) {
bfloat16 res;
res.value = a;
return res;
}

public:
// Implicit conversion from float to bfloat16
bfloat16(const float &a) { value = from_float(a); }
@@ -122,7 +128,7 @@ class bfloat16 {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
#if (__CUDA_ARCH__ >= 800)
return from_bits(__nvvm_neg_bf16(lhs.value));
return detail::bitsToBfloat16(__nvvm_neg_bf16(lhs.value));
#else
return -to_float(lhs.value);
#endif
@@ -203,6 +209,23 @@ class bfloat16 {
// for floating-point types.
};

namespace detail {

// Helper function for getting the internal representation of a bfloat16.
inline Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value) {
return Value.value;
}

// Helper function for creating a bfloat16 from a value with the same type as
// the internal representation.
inline bfloat16 bitsToBfloat16(const Bfloat16StorageT Value) {
bfloat16 res;
res.value = Value;
return res;
}

} // namespace detail

} // namespace oneapi
} // namespace ext
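The practical effect of this part of the change: the raw 16-bit payload of a bfloat16 is no longer reachable only through the removed private from_bits, but through the new free functions detail::bfloat16ToBits and detail::bitsToBfloat16, which the math header below relies on. A minimal sketch of that round trip (flip_sign is a hypothetical helper for illustration, not part of the patch):

#include <sycl/ext/oneapi/bfloat16.hpp>

using namespace sycl::ext::oneapi;

// Illustrative only: negate a bfloat16 by toggling its sign bit.
bfloat16 flip_sign(bfloat16 x) {
  detail::Bfloat16StorageT bits = detail::bfloat16ToBits(x); // raw uint16_t payload
  bits ^= 0x8000;                      // bfloat16 keeps the sign in the top bit
  return detail::bitsToBfloat16(bits); // rebuild a bfloat16 from the bits
}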

208 changes: 208 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp
@@ -0,0 +1,208 @@
//==-------- bfloat16_math.hpp - SYCL bfloat16 math functions --------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

#include <sycl/detail/defines_elementary.hpp>
#include <sycl/exception.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp>
#include <sycl/marray.hpp>

#include <cstring>
#include <tuple>
#include <type_traits>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext {
namespace oneapi {
namespace experimental {

namespace detail {
template <size_t N>
uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
uint32_t res;
std::memcpy(&res, &x[start], sizeof(uint32_t));
return res;
}
} // namespace detail
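// Note on the marray overloads below: to_uint32_t bit-copies two adjacent
// bfloat16 elements into one 32-bit word, so each packed __clc_* call handles
// a pair of values (N / 2 loop iterations); when N is odd, the trailing
// element is processed with a separate scalar call on its 16-bit payload.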

template <typename T>
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
return oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
#else
std::ignore = x;
throw runtime_error("bfloat16 is not currently supported on the host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <size_t N>
sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
sycl::marray<bfloat16, N> res;

for (size_t i = 0; i < N / 2; i++) {
auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2));
std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
}

if (N % 2) {
oneapi::detail::Bfloat16StorageT XBits =
oneapi::detail::bfloat16ToBits(x[N - 1]);
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
}
return res;
#else
std::ignore = x;
throw runtime_error("bfloat16 is not currently supported on the host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <typename T>
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmin(T x, T y) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
return oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
#else
std::ignore = x;
std::ignore = y;
throw runtime_error("bfloat16 is not currently supported on the host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <size_t N>
sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
sycl::marray<bfloat16, N> y) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
sycl::marray<bfloat16, N> res;

for (size_t i = 0; i < N / 2; i++) {
auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2),
detail::to_uint32_t(y, i * 2));
std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
}

if (N % 2) {
oneapi::detail::Bfloat16StorageT XBits =
oneapi::detail::bfloat16ToBits(x[N - 1]);
oneapi::detail::Bfloat16StorageT YBits =
oneapi::detail::bfloat16ToBits(y[N - 1]);
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
}

return res;
#else
std::ignore = x;
std::ignore = y;
throw runtime_error("bfloat16 is not currently supported on the host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <typename T>
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmax(T x, T y) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
#else
std::ignore = x;
std::ignore = y;
throw runtime_error("bfloat16 is not currently supported on the host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <size_t N>
sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
sycl::marray<bfloat16, N> y) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
sycl::marray<bfloat16, N> res;

for (size_t i = 0; i < N / 2; i++) {
auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2),
detail::to_uint32_t(y, i * 2));
std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
}

if (N % 2) {
oneapi::detail::Bfloat16StorageT XBits =
oneapi::detail::bfloat16ToBits(x[N - 1]);
oneapi::detail::Bfloat16StorageT YBits =
oneapi::detail::bfloat16ToBits(y[N - 1]);
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
}
return res;
#else
std::ignore = x;
std::ignore = y;
throw runtime_error("bfloat16 is not currently supported on the host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <typename T>
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fma(T x, T y, T z) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
oneapi::detail::Bfloat16StorageT ZBits = oneapi::detail::bfloat16ToBits(z);
return oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
#else
std::ignore = x;
std::ignore = y;
std::ignore = z;
throw runtime_error("bfloat16 is not currently supported on the host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <size_t N>
sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
sycl::marray<bfloat16, N> y,
sycl::marray<bfloat16, N> z) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
sycl::marray<bfloat16, N> res;

for (size_t i = 0; i < N / 2; i++) {
auto partial_res =
__clc_fma(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2),
detail::to_uint32_t(z, i * 2));
std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
}

if (N % 2) {
oneapi::detail::Bfloat16StorageT XBits =
oneapi::detail::bfloat16ToBits(x[N - 1]);
oneapi::detail::Bfloat16StorageT YBits =
oneapi::detail::bfloat16ToBits(y[N - 1]);
oneapi::detail::Bfloat16StorageT ZBits =
oneapi::detail::bfloat16ToBits(z[N - 1]);
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
}
return res;
#else
std::ignore = x;
std::ignore = y;
std::ignore = z;
throw runtime_error("bfloat16 is not currently supported on the host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

} // namespace experimental
} // namespace oneapi
} // namespace ext
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
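Taken together, the new header provides fabs, fmin, fmax, and fma for scalar bfloat16 and for sycl::marray<bfloat16, N>, currently only on the CUDA device path (the host path throws). A minimal usage sketch, assuming a queue bound to an sm_80-class CUDA device; the variable names and the use of USM shared memory are illustrative, not part of the patch:

#include <sycl/sycl.hpp>
#include <sycl/ext/oneapi/experimental/bfloat16_math.hpp>

namespace bf16_math = sycl::ext::oneapi::experimental;
using sycl::ext::oneapi::bfloat16;

int main() {
  sycl::queue q; // assumed: a CUDA device with __CUDA_ARCH__ >= 800
  bfloat16 *out = sycl::malloc_shared<bfloat16>(4, q);

  q.single_task([=] {
    bfloat16 a = 1.5f, b = -2.0f, c = 0.25f;
    out[0] = bf16_math::fabs(b);      // 2.0
    out[1] = bf16_math::fmin(a, b);   // -2.0
    out[2] = bf16_math::fmax(a, b);   // 1.5
    out[3] = bf16_math::fma(a, b, c); // 1.5 * -2.0 + 0.25 = -2.75
  }).wait();

  sycl::free(out, q);
  return 0;
}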
9 changes: 0 additions & 9 deletions sycl/include/sycl/ext/oneapi/experimental/builtins.hpp
@@ -32,15 +32,6 @@ namespace ext {
namespace oneapi {
namespace experimental {

namespace detail {
template <size_t N>
uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
uint32_t res;
std::memcpy(&res, &x[start], sizeof(uint32_t));
return res;
}
} // namespace detail

// Provides functionality to print data from kernels in a C way:
// - On non-host devices this function is directly mapped to printf from
// OpenCL C