Skip to content

Commit

Permalink
[SYCL] Add some trivial util functions for half type. (#6691)
Browse files Browse the repository at this point in the history
The usage of sycl::half and CUDA half is different, sycl::half overrides
arithmetic and comparison operators(+, -, *, /, >, <,==.!=...) while
CUDA math provides those arithmetic and comparison functionalities via a
bunch of util functions. This PR provides some sycl::half util functions
aligning with CUDA math, users who are porting half related CUDA code to
SYCL will spend less effort with those util functions' help.

Signed-off-by: jinge90 <ge.jin@intel.com>
  • Loading branch information
jinge90 authored Nov 1, 2022
1 parent a776f3c commit b4ce7c0
Show file tree
Hide file tree
Showing 2 changed files with 320 additions and 0 deletions.
1 change: 1 addition & 0 deletions sycl/include/sycl/ext/intel/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#pragma once
#include <sycl/ext/intel/math/imf_half_trivial.hpp>
#include <sycl/half_type.hpp>
#include <type_traits>

Expand Down
319 changes: 319 additions & 0 deletions sycl/include/sycl/ext/intel/math/imf_half_trivial.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
//==------------- imf_half_trivial.hpp - trivial half utils ----------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Trival half util functions.
//===----------------------------------------------------------------------===//

#pragma once
#include <sycl/half_type.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext {
namespace intel {
namespace math {
sycl::half hadd(sycl::half x, sycl::half y) { return x + y; }

sycl::half hadd_sat(sycl::half x, sycl::half y) {
return sycl::clamp((x + y), sycl::half(0.f), sycl::half(1.0f));
}

sycl::half hfma(sycl::half x, sycl::half y, sycl::half z) {
return sycl::fma(x, y, z);
}

sycl::half hfma_sat(sycl::half x, sycl::half y, sycl::half z) {
return sycl::clamp(sycl::fma(x, y, z), sycl::half(0.f), sycl::half(1.0f));
}

sycl::half hmul(sycl::half x, sycl::half y) { return x * y; }

sycl::half hmul_sat(sycl::half x, sycl::half y) {
return sycl::clamp((x * y), sycl::half(0.f), sycl::half(1.0f));
}

sycl::half hneg(sycl::half x) { return -x; }

sycl::half hsub(sycl::half x, sycl::half y) { return x - y; }

sycl::half hsub_sat(sycl::half x, sycl::half y) {
return sycl::clamp((x - y), sycl::half(0.f), sycl::half(1.0f));
}

sycl::half hdiv(sycl::half x, sycl::half y) { return x / y; }

bool heq(sycl::half x, sycl::half y) { return x == y; }

bool hequ(sycl::half x, sycl::half y) {
if (sycl::isnan(x) || sycl::isnan(y))
return true;
else
return x == y;
}

bool hge(sycl::half x, sycl::half y) { return x >= y; }

bool hgeu(sycl::half x, sycl::half y) {
if (sycl::isnan(x) || sycl::isnan(y))
return true;
else
return x >= y;
}

bool hgt(sycl::half x, sycl::half y) { return x > y; }

bool hgtu(sycl::half x, sycl::half y) {
if (sycl::isnan(x) || sycl::isnan(y))
return true;
else
return x > y;
}

bool hle(sycl::half x, sycl::half y) { return x <= y; }

bool hleu(sycl::half x, sycl::half y) {
if (sycl::isnan(x) || sycl::isnan(y))
return true;
else
return x <= y;
}

bool hlt(sycl::half x, sycl::half y) { return x < y; }

bool hltu(sycl::half x, sycl::half y) {
if (sycl::isnan(x) || sycl::isnan(y))
return true;
return x < y;
}

bool hne(sycl::half x, sycl::half y) {
if (sycl::isnan(x) || sycl::isnan(y))
return false;
return x != y;
}

bool hneu(sycl::half x, sycl::half y) {
if (sycl::isnan(x) || sycl::isnan(y))
return true;
else
return x != y;
}

bool hisinf(sycl::half x) { return sycl::isinf(x); }
bool hisnan(sycl::half y) { return sycl::isnan(y); }

sycl::half2 hadd2(sycl::half2 x, sycl::half2 y) { return x + y; }

sycl::half2 hadd2_sat(sycl::half2 x, sycl::half2 y) {
return sycl::clamp((x + y), sycl::half2{0.f, 0.f}, sycl::half2{1.f, 1.f});
}

sycl::half2 hfma2(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
return sycl::fma(x, y, z);
}

sycl::half2 hfma2_sat(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
return sycl::clamp(sycl::fma(x, y, z), sycl::half2{0.f, 0.f},
sycl::half2{1.f, 1.f});
}

sycl::half2 hmul2(sycl::half2 x, sycl::half2 y) { return x * y; }

sycl::half2 hmul2_sat(sycl::half2 x, sycl::half2 y) {
return sycl::clamp((x * y), sycl::half2{0.f, 0.f}, sycl::half2{1.f, 1.f});
}

sycl::half2 h2div(sycl::half2 x, sycl::half2 y) { return x / y; }

sycl::half2 hneg2(sycl::half2 x) { return -x; }

sycl::half2 hsub2(sycl::half2 x, sycl::half2 y) { return x - y; }

sycl::half2 hsub2_sat(sycl::half2 x, sycl::half2 y) {
return sycl::clamp((x - y), sycl::half2{0.f, 0.f}, sycl::half2{1.f, 1.f});
}

bool hbeq2(sycl::half2 x, sycl::half2 y) {
return heq(x.s0(), y.s0()) && heq(x.s1(), y.s1());
}

bool hbequ2(sycl::half2 x, sycl::half2 y) {
return hequ(x.s0(), y.s0()) && hequ(x.s1(), y.s1());
}

bool hbge2(sycl::half2 x, sycl::half2 y) {
return hge(x.s0(), y.s0()) && hge(x.s1(), y.s1());
}

bool hbgeu2(sycl::half2 x, sycl::half2 y) {
return hgeu(x.s0(), y.s0()) && hgeu(x.s1(), y.s1());
}

bool hbgt2(sycl::half2 x, sycl::half2 y) {
return hgt(x.s0(), y.s0()) && hgt(x.s1(), y.s1());
}

bool hbgtu2(sycl::half2 x, sycl::half2 y) {
return hgtu(x.s0(), y.s0()) && hgtu(x.s1(), y.s1());
}

bool hble2(sycl::half2 x, sycl::half2 y) {
return hle(x.s0(), y.s0()) && hle(x.s1(), y.s1());
}

bool hbleu2(sycl::half2 x, sycl::half2 y) {
return hleu(x.s0(), y.s0()) && hleu(x.s1(), y.s1());
}

bool hblt2(sycl::half2 x, sycl::half2 y) {
return hlt(x.s0(), y.s0()) && hlt(x.s1(), y.s1());
}

bool hbltu2(sycl::half2 x, sycl::half2 y) {
return hltu(x.s0(), y.s0()) && hltu(x.s1(), y.s1());
}

bool hbne2(sycl::half2 x, sycl::half2 y) {
return hne(x.s0(), y.s0()) && hne(x.s1(), y.s1());
}

bool hbneu2(sycl::half2 x, sycl::half2 y) {
return hneu(x.s0(), y.s0()) && hneu(x.s1(), y.s1());
}

sycl::half2 heq2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{(heq(x.s0(), y.s0()) ? 1.0f : 0.f),
(heq(x.s1(), y.s1()) ? 1.0f : 0.f)};
}

sycl::half2 hequ2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{(hequ(x.s0(), y.s0()) ? 1.0f : 0.f),
(hequ(x.s1(), y.s1()) ? 1.0f : 0.f)};
}

sycl::half2 hge2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{(hge(x.s0(), y.s0()) ? 1.0f : 0.f),
(hge(x.s1(), y.s1()) ? 1.0f : 0.f)};
}

sycl::half2 hgeu2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{(hgeu(x.s0(), y.s0()) ? 1.0f : 0.f),
(hgeu(x.s1(), y.s1()) ? 1.0f : 0.f)};
}

sycl::half2 hgt2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{(hgt(x.s0(), y.s0()) ? 1.0f : 0.f),
(hgt(x.s1(), y.s1()) ? 1.0f : 0.f)};
}

sycl::half2 hgtu2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{(hgtu(x.s0(), y.s0()) ? 1.0f : 0.f),
(hgtu(x.s1(), y.s1()) ? 1.0f : 0.f)};
}

sycl::half2 hle2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{(hle(x.s0(), y.s0()) ? 1.0f : 0.f),
(hle(x.s1(), y.s1()) ? 1.0f : 0.f)};
}

sycl::half2 hleu2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{(hleu(x.s0(), y.s0()) ? 1.0f : 0.f),
(hleu(x.s1(), y.s1()) ? 1.0f : 0.f)};
}

sycl::half2 hlt2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{(hlt(x.s0(), y.s0()) ? 1.0f : 0.f),
(hlt(x.s1(), y.s1()) ? 1.0f : 0.f)};
}

sycl::half2 hltu2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{(hltu(x.s0(), y.s0()) ? 1.0f : 0.f),
(hltu(x.s1(), y.s1()) ? 1.0f : 0.f)};
}

sycl::half2 hisnan2(sycl::half2 x) {
return sycl::half2{(hisnan(x.s0()) ? 1.0f : 0.f),
(hisnan(x.s1()) ? 1.0f : 0.f)};
}

sycl::half2 hne2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{(hne(x.s0(), y.s0()) ? 1.0f : 0.f),
(hne(x.s1(), y.s1()) ? 1.0f : 0.f)};
}

sycl::half2 hneu2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{(hneu(x.s0(), y.s0()) ? 1.0f : 0.f),
(hneu(x.s1(), y.s1()) ? 1.0f : 0.f)};
}

sycl::half hmax(sycl::half x, sycl::half y) { return sycl::fmax(x, y); }

sycl::half hmax_nan(sycl::half x, sycl::half y) {
if (hisnan(x) || hisnan(y))
return sycl::half(NAN);
else
return sycl::fmax(x, y);
}

sycl::half2 hmax2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{hmax(x.s0(), y.s0()), hmax(x.s1(), y.s1())};
}

sycl::half2 hmax2_nan(sycl::half2 x, sycl::half2 y) {
return sycl::half2{hmax_nan(x.s0(), y.s0()), hmax_nan(x.s1(), y.s1())};
}

sycl::half hmin(sycl::half x, sycl::half y) { return sycl::fmin(x, y); }

sycl::half hmin_nan(sycl::half x, sycl::half y) {
if (hisnan(x) || hisnan(y))
return sycl::half(NAN);
else
return sycl::fmin(x, y);
}

sycl::half2 hmin2(sycl::half2 x, sycl::half2 y) {
return sycl::half2{hmin(x.s0(), y.s0()), hmin(x.s1(), y.s1())};
}

sycl::half2 hmin2_nan(sycl::half2 x, sycl::half2 y) {
return sycl::half2{hmin_nan(x.s0(), y.s0()), hmin_nan(x.s1(), y.s1())};
}

sycl::half2 hcmadd(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
return sycl::half2{x.s0() * y.s0() - x.s1() * y.s1() + z.s0(),
x.s0() * y.s1() + x.s1() * y.s0() + z.s1()};
}

sycl::half hfma_relu(sycl::half x, sycl::half y, sycl::half z) {
sycl::half r = sycl::fma(x, y, z);
if (!hisnan(r)) {
if (r < 0.f)
return sycl::half{0.f};
else
return r;
}
return r;
}

sycl::half2 hfma2_relu(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
sycl::half2 r = sycl::fma(x, y, z);
if (!hisnan(r.s0()) && r.s0() < 0.f)
r.s0() = 0.f;
if (!hisnan(r.s1()) && r.s1() < 0.f)
r.s1() = 0.f;
return r;
}

sycl::half habs(sycl::half x) { return sycl::fabs(x); }

sycl::half2 habs2(sycl::half2 x) { return sycl::fabs(x); }
} // namespace math
} // namespace intel
} // namespace ext
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl

0 comments on commit b4ce7c0

Please sign in to comment.