[Feature] Add bfloat16 support for CPU (dmlc#5497)
Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
2 people authored and DominikaJedynak committed Mar 12, 2024
1 parent 503576f commit df3f316
Showing 16 changed files with 338 additions and 58 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -203,11 +203,11 @@ endif(NOT MSVC)
 # Compile LIBXSMM
 if((NOT MSVC) AND USE_LIBXSMM)
 if(REBUILD_LIBXSMM)
-add_custom_target(libxsmm COMMAND make realclean COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0
+add_custom_target(libxsmm COMMAND make realclean COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0 CC=${CMAKE_C_COMPILER}
 WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/third_party/libxsmm
 )
 else(REBUILD_LIBXSMM)
-add_custom_target(libxsmm COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0
+add_custom_target(libxsmm COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0 CC=${CMAKE_C_COMPILER}
 WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/third_party/libxsmm
 )
 endif(REBUILD_LIBXSMM)
79 changes: 53 additions & 26 deletions include/dgl/aten/macro.h
@@ -152,42 +152,69 @@
 XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
 typedef __nv_bfloat16 FloatType; \
 { __VA_ARGS__ } \
-} else if (XPU == kDGLCPU) { \
-LOG(FATAL) << (val_name) << " can only be float32 or float64 on CPU"; \
+} else if ( \
+XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \
+LOG(FATAL) << (val_name) << " can't be float16 on CPU"; \
+} else if ( \
+XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) { \
+typedef BFloat16 FloatType; \
+{ __VA_ARGS__ } \
 } else { \
 LOG(FATAL) << (val_name) \
 << " can only be float16/bfloat16/float32/float64 on GPU"; \
 } \
 } while (0)
 #else // BF16_ENABLED
-#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
-do { \
-CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat)) \
-<< (val_name) << " must be float type"; \
-if ((val).bits == 32) { \
-typedef float FloatType; \
-{ __VA_ARGS__ } \
-} else if ((val).bits == 64) { \
-typedef double FloatType; \
-{ __VA_ARGS__ } \
-} else if ( \
-XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
-typedef __half FloatType; \
-{ __VA_ARGS__ } \
-} else if ( \
-XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
-LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
-} else if (XPU == kDGLCPU) { \
-LOG(FATAL) << (val_name) << " can only be float32 or float64 on CPU"; \
-} else { \
-LOG(FATAL) << (val_name) \
-<< " can only be float16/float32/float64 on GPU"; \
-} \
+#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
+do { \
+CHECK((val).code == kDGLFloat || ((val).code == kDGLBfloat)) \
+<< (val_name) << " must be float type"; \
+if ((val).bits == 32) { \
+typedef float FloatType; \
+{ __VA_ARGS__ } \
+} else if ((val).bits == 64) { \
+typedef double FloatType; \
+{ __VA_ARGS__ } \
+} else if ( \
+XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
+typedef __half FloatType; \
+{ __VA_ARGS__ } \
+} else if ( \
+XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
+LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
+} else if ( \
+XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \
+LOG(FATAL) << (val_name) << " can't be float16 on CPU"; \
+} else if ( \
+XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) { \
+typedef BFloat16 FloatType; \
+{ __VA_ARGS__ } \
+} else { \
+LOG(FATAL) << (val_name) \
+<< " can only be float16/float32/float64 on GPU"; \
+} \
 } while (0)
 #endif // BF16_ENABLED
 #else // DGL_USE_CUDA
 #define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
-ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, {__VA_ARGS__})
+do { \
+CHECK((val).code == kDGLFloat || ((val).code == kDGLBfloat)) \
+<< (val_name) << " must be float type"; \
+if ((val).bits == 32) { \
+typedef float FloatType; \
+{ __VA_ARGS__ } \
+} else if ((val).bits == 64) { \
+typedef double FloatType; \
+{ __VA_ARGS__ } \
+} else if ( \
+XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) { \
+typedef BFloat16 FloatType; \
+{ __VA_ARGS__ } \
+} else { \
+LOG(FATAL) << (val_name) \
+<< " can only be bfloat16/float32/float64 on CPU"; \
+} \
+} while (0)
 #endif // DGL_USE_CUDA

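For orientation (an editor's sketch, not part of the commit): a caller hands the macro an array's dtype and device type, and the matching branch typedefs FloatType for the code block. The helper below is hypothetical; FillOnes and its loop are purely illustrative, and the field accesses (arr->dtype, arr->ctx, arr->data, arr->shape) follow DGL's NDArray conventions.

#include <cstdint>

#include <dgl/aten/macro.h>
#include <dgl/runtime/ndarray.h>

// Illustrative only: writes 1.0 into a 1-D floating-point NDArray on CPU.
// For a bfloat16 array, FloatType resolves to BFloat16, so the assignment
// runs through the BFloat16(float) converting constructor.
void FillOnes(dgl::runtime::NDArray arr) {
  ATEN_FLOAT_TYPE_SWITCH_16BITS(
      arr->dtype, FloatType, arr->ctx.device_type, "arr", {
        FloatType* data = static_cast<FloatType*>(arr->data);
        for (int64_t i = 0; i < arr->shape[0]; ++i)
          data[i] = static_cast<FloatType>(1.0f);
      });
}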
68 changes: 68 additions & 0 deletions include/dgl/runtime/bfloat16.h
@@ -0,0 +1,68 @@
/**
* Copyright (c) 2023 by Contributors
* @file dgl/runtime/bfloat16.h
* @brief BFloat16 CPU header
*/
#ifndef DGL_RUNTIME_BFLOAT16_H_
#define DGL_RUNTIME_BFLOAT16_H_

#include <cmath>
#include <cstdint>

class BFloat16 {
  uint16_t val;

 public:
  constexpr BFloat16() : val(0) {}

  // Implicit conversion from float is intended here, so the lint
  // "explicit" warning is disabled.
  BFloat16(float f) {  // NOLINT
    if (std::isnan(f)) {
      val = 0x7FC0;  // canonical quiet NaN
    } else {
      // Type-pun through a union (a little-endian layout is assumed, as
      // in operator float() below) and round to nearest even: the bias
      // adds 0x7FFF plus the lowest kept mantissa bit, then the low 16
      // bits are truncated away.
      union {
        uint16_t iraw16[2];
        uint32_t iraw32;
        float f32;
      };

      f32 = f;
      const uint32_t rounding_bias = 0x00007FFF + (iraw16[1] & 0x1);
      val = static_cast<uint16_t>((iraw32 + rounding_bias) >> 16);
    }
  }

  // 0xFF80 is the bit pattern of -infinity (e.g. the identity value for
  // max-reductions). Mutating a local inside a constexpr function needs
  // C++14, which matches the -std=c++14 bump in python/setup.py below.
  static constexpr BFloat16 Min() {
    BFloat16 min;
    min.val = 0xFF80;
    return min;
  }

  // 0x7F80 is the bit pattern of +infinity (e.g. the identity value for
  // min-reductions).
  static constexpr BFloat16 Max() {
    BFloat16 max;
    max.val = 0x7F80;
    return max;
  }

  // Compound assignments widen to float, compute there, and round back.
  BFloat16& operator-=(const float& rhs) {
    float lhs = (*this);
    (*this) = lhs - rhs;
    return *this;
  }

  BFloat16& operator+=(const float& rhs) {
    float lhs = (*this);
    (*this) = lhs + rhs;
    return *this;
  }

  // Widening is exact: a bfloat16 is the high half of a float32, so the
  // low 16 bits are simply zero-filled (little-endian again).
  operator float() const {
    union {
      float f;
      uint16_t raw[2];
    };
    raw[0] = 0;
    raw[1] = val;
    return f;
  }
};

#endif // DGL_RUNTIME_BFLOAT16_H_
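A quick self-check of the semantics above (editor's sketch, not part of the commit; assumes the header is reachable at the path shown and a little-endian host):

#include <cassert>
#include <limits>

#include <dgl/runtime/bfloat16.h>

int main() {
  // Exactly representable values survive the round trip.
  BFloat16 one = 1.0f;
  assert(static_cast<float>(one) == 1.0f);

  // 1 + 2^-8 sits halfway between the bf16 neighbors 1 and 1 + 2^-7;
  // the rounding bias breaks the tie toward the even mantissa, i.e. 1.0.
  BFloat16 tie_down = 1.00390625f;
  assert(static_cast<float>(tie_down) == 1.0f);

  // 1 + 3*2^-8 is the next halfway point; here the even mantissa lies
  // above, so the value rounds up to 1 + 2^-6 = 1.015625.
  BFloat16 tie_up = 1.01171875f;
  assert(static_cast<float>(tie_up) == 1.015625f);

  // Min()/Max() hold the bit patterns of -inf/+inf.
  assert(static_cast<float>(BFloat16::Max()) ==
         std::numeric_limits<float>::infinity());
  return 0;
}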
1 change: 1 addition & 0 deletions include/dgl/runtime/ndarray.h
@@ -12,6 +12,7 @@
#include <utility>
#include <vector>

#include "bfloat16.h"
#include "c_runtime_api.h"
#include "serializer.h"
#include "shared_mem.h"
2 changes: 1 addition & 1 deletion python/setup.py
@@ -153,7 +153,7 @@ def config_cython():
 library_dirs=library_dirs,
 libraries=libraries,
 # Crashes without this flag with GCC 5.3.1
-extra_compile_args=["-std=c++11"],
+extra_compile_args=["-std=c++14"],
 language="c++",
 )
 )
22 changes: 22 additions & 0 deletions src/array/cpu/gather_mm.cc
@@ -40,6 +40,12 @@ void GatherMMScatter(
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}

template void GatherMM<kDGLCPU, int32_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
@@ -53,6 +59,12 @@ template void GatherMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);

template void GatherMMScatter<kDGLCPU, int32_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
@@ -66,6 +78,12 @@ template void GatherMMScatter<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);

template void SegmentMM<kDGLCPU, int32_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
@@ -79,6 +97,10 @@ template void SegmentMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);

template void SegmentMMBackwardB<kDGLCPU, int32_t, BFloat16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, BFloat16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, float>(
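Why the long instantiation lists: the kernels are defined in the .cc files, so every (index type, value type) pair a caller may dispatch to must be explicitly instantiated there, and this commit extends each list with BFloat16. The shape of the pattern, reduced to a toy example (editor's sketch with illustrative names, not code from the commit):

// saxpy.h (sketch): declaration only; the definition lives in the .cc file.
template <typename DType>
void Saxpy(DType a, const DType* x, DType* y, int n);

// saxpy.cc (sketch): definition plus explicit instantiations. Without the
// two `template void` lines no Saxpy<float>/Saxpy<double> symbols would be
// emitted, and callers in other translation units would fail to link.
template <typename DType>
void Saxpy(DType a, const DType* x, DType* y, int n) {
  for (int i = 0; i < n; ++i) y[i] += a * x[i];
}
template void Saxpy<float>(float, const float*, float*, int);
template void Saxpy<double>(double, const double*, double*, int);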
36 changes: 36 additions & 0 deletions src/array/cpu/sddmm.cc
@@ -92,6 +92,12 @@ void SDDMMCsrHetero(
});
}

template void SDDMMCsr<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
@@ -120,6 +126,18 @@ template void SDDMMCsrRedirected<kDGLCPU, int64_t, double>(
NDArray lhs, NDArray rhs, NDArray out, NDArray efeats_redirected, int lhs_target, int rhs_target);


template void SDDMMCsrHetero<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
@@ -181,6 +199,12 @@ void SDDMMCooHetero(
});
}

template void SDDMMCoo<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
@@ -194,6 +218,18 @@ template void SDDMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

template void SDDMMCooHetero<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
34 changes: 34 additions & 0 deletions src/array/cpu/segment_reduce.cc
@@ -56,6 +56,12 @@ void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}

template void SegmentReduce<kDGLCPU, int32_t, BFloat16>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, BFloat16>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, float>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
@@ -69,6 +75,16 @@ template void SegmentReduce<kDGLCPU, int64_t, double>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);

template <>
void ScatterAdd<kDGLCPU, int32_t, BFloat16>(
NDArray feat, NDArray idx, NDArray out) {
LOG(FATAL) << "Unsupported CPU kernel for ScatterAdd for BF16.";
}
template <>
void ScatterAdd<kDGLCPU, int64_t, BFloat16>(
NDArray feat, NDArray idx, NDArray out) {
LOG(FATAL) << "Unsupported CPU kernel for ScatterAdd for BF16.";
}
template void ScatterAdd<kDGLCPU, int32_t, float>(
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, float>(
@@ -78,6 +94,20 @@ template void ScatterAdd<kDGLCPU, int32_t, double>(
template void ScatterAdd<kDGLCPU, int64_t, double>(
NDArray feat, NDArray arg, NDArray out);

template <>
void UpdateGradMinMax_hetero<kDGLCPU, int32_t, BFloat16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
LOG(FATAL) << "Unsupported CPU kernel for UpdateGradMinMax_hetero for BF16.";
}
template <>
void UpdateGradMinMax_hetero<kDGLCPU, int64_t, BFloat16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
LOG(FATAL) << "Unsupported CPU kernel for UpdateGradMinMax_hetero for BF16.";
}
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
@@ -95,6 +125,10 @@ template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);

template void BackwardSegmentCmp<kDGLCPU, int32_t, BFloat16>(
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, BFloat16>(
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, float>(
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, float>(
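Where no BF16 CPU kernel exists yet (ScatterAdd and UpdateGradMinMax_hetero above), the commit still provides the symbol by fully specializing the template to fail loudly at runtime, which keeps generic dispatch code linking. The pattern in miniature (editor's sketch with illustrative names, not code from the commit):

#include <cstdio>
#include <cstdlib>

struct BF16Stub {};  // stand-in for a value type without a kernel

template <typename DType>
void ScatterAddSketch(const DType* src, DType* dst, int n) {
  for (int i = 0; i < n; ++i) dst[i] += src[i];
}

// Full specialization: ScatterAddSketch<BF16Stub> exists and links, but
// calling it aborts with a clear message instead of computing garbage.
template <>
void ScatterAddSketch<BF16Stub>(const BF16Stub*, BF16Stub*, int) {
  std::fprintf(stderr, "ScatterAddSketch: no BF16 kernel on CPU\n");
  std::abort();
}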
2 changes: 2 additions & 0 deletions src/array/cpu/segment_reduce.h
@@ -25,6 +25,8 @@ namespace cpu {
*/
template <typename IdType, typename DType>
void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
// No BF16 CPU kernel here; the guard fails at runtime rather than via a
// static_assert so that the explicit BFloat16 instantiations of
// SegmentReduce still compile, and only a bfloat16 "sum" call aborts.
if (std::is_same<DType, BFloat16>::value)
  LOG(FATAL) << "Unsupported CPU kernel for SegmentSum for BF16.";
int n = out->shape[0];
int dim = 1;
for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];