Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU][Kernel] Single socket spmm #3024

Merged
merged 23 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
7c944c0
optimizations of spmm for CPU
sanchit-misra Jun 15, 2021
e68fa07
Added names of contributors
sanchit-misra Jun 15, 2021
30e3951
Minor code cleanup
sanchit-misra Jun 15, 2021
a797833
Moved the spmm optimization code to a new header file
sanchit-misra Jun 18, 2021
78cece7
Moved to DGL's logging method
sanchit-misra Jun 22, 2021
e47c0c0
removed duplicate code between SpMMSumCsr and SpMMCmpCsr
sanchit-misra Jun 22, 2021
ad78519
Changes made to follow Google coding style
sanchit-misra Jun 23, 2021
d51301f
Fixed lint errors in spmm.h
sanchit-misra Jun 23, 2021
bed06d6
Fixed some lint errors from spmm_blocking_libxsmm.h
sanchit-misra Jun 23, 2021
9c941ba
Fixed lint errors from spmm_blocking_libxsmm.h
sanchit-misra Jun 23, 2021
e11f8e8
Added comments to SpMMCreateLibxsmmKernel
sanchit-misra Jun 23, 2021
3eaa153
to enable building of tests, and other cosmetic changes
sanchit-misra Jun 26, 2021
ffc8520
disabling libxsmm on windows
sanchit-misra Jun 28, 2021
5dec62f
Put a condition to avoid opt impl for FP64 as libxsmm does not have F…
sanchit-misra Jul 3, 2021
7b8781b
Merge branch 'master' into single-socket-spmm
jermainewang Jul 5, 2021
7c0a9c3
cosmetic changes and documentation
sanchit-misra Jul 5, 2021
5f4ce36
Merge branch 'single-socket-spmm' of https://github.com/sanchit-misra…
sanchit-misra Jul 5, 2021
ada9e41
cosmetic changes
sanchit-misra Jul 7, 2021
1e32045
to pass lint tests
sanchit-misra Jul 7, 2021
e1288f2
replaced multiple allocations for buffers of indices and edges with a…
sanchit-misra Jul 7, 2021
a1c3734
Merge branch 'master' into single-socket-spmm
jermainewang Jul 12, 2021
dae56cc
Merge branch 'master' into single-socket-spmm
jermainewang Jul 12, 2021
64fbcad
Merge branch 'master' into single-socket-spmm
jermainewang Jul 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@
[submodule "third_party/nccl"]
path = third_party/nccl
url = https://github.com/nvidia/nccl
[submodule "third_party/libxsmm"]
path = third_party/libxsmm
url = https://github.com/hfp/libxsmm.git
29 changes: 23 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dgl_option(USE_CUDA "Build with CUDA" OFF)
dgl_option(USE_SYSTEM_NCCL "Build using system's NCCL library" OFF)
dgl_option(USE_OPENMP "Build with OpenMP" ON)
dgl_option(USE_AVX "Build with AVX optimization" ON)
dgl_option(USE_LIBXSMM "Build with LIBXSMM library optimization" ON)
dgl_option(USE_FP16 "Build with fp16 support to enable mixed precision training" OFF)
dgl_option(USE_TVM "Build with TVM kernels" OFF)
dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
Expand All @@ -35,7 +36,7 @@ dgl_option(USE_HDFS "Build with HDFS support" OFF) # Set env HADOOP_HDFS_HOME if

# Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG
if (NOT MSVC)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -DDEBUG -O0 -g3 -ggdb")
endif(NOT MSVC)

if(USE_CUDA)
Expand Down Expand Up @@ -88,10 +89,10 @@ if(MSVC)
else(MSVC)
include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11)
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -march=native ${CMAKE_C_FLAGS}")
# We still use c++11 flag in CPU build because gcc5.4 (our default compiler) is
# not fully compatible with c++14 feature.
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++11 -march=native ${CMAKE_CXX_FLAGS}")
if(NOT APPLE)
set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--warn-common ${CMAKE_SHARED_LINKER_FLAGS}")
endif(NOT APPLE)
Expand All @@ -107,9 +108,15 @@ if(USE_OPENMP)
endif(USE_OPENMP)

if(USE_AVX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_AVX")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_AVX")
message(STATUS "Build with AVX optimization.")
if(USE_LIBXSMM)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_AVX -DUSE_LIBXSMM -DDGL_CPU_LLC_SIZE=40000000")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_AVX -DUSE_LIBXSMM -DDGL_CPU_LLC_SIZE=40000000")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How should users set a proper DGL_CPU_LLC_SIZE value?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the size of the LLC of the CPU the code is run on. This macro here is just failsafe. The code automatically gets the LLC size using sysconf() in the function getLLCSize(). Only if that fails, it uses this number. The number I have used as default here is quite safe for most server class CPUs. Only if sysconf fails and this number is also bigger than the user's LLC, then they will have to set the LLC size here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only if sysconf fails and this number is also bigger than the user's LLC, then they will have to set the LLC size here.

In this case, will the program crash or just run with under-optimal config?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It won't crash. Will just run with under-optimal config.

message(STATUS "Build with LIBXSMM optimization.")
else(USE_LIBXSMM)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_AVX")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_AVX")
message(STATUS "Build with AVX optimization.")
endif(USE_LIBXSMM)
endif(USE_AVX)

# Build with fp16 to support mixed precision training.
Expand Down Expand Up @@ -191,6 +198,7 @@ target_include_directories(dgl PRIVATE "third_party/xbyak/")
target_include_directories(dgl PRIVATE "third_party/METIS/include/")
target_include_directories(dgl PRIVATE "tensoradapter/include")
target_include_directories(dgl PRIVATE "third_party/nanoflann/include")
target_include_directories(dgl PRIVATE "third_party/libxsmm/include")

# For serialization
if (USE_HDFS)
Expand All @@ -210,6 +218,15 @@ if(NOT MSVC)
list(APPEND DGL_LINKER_LIBS metis)
endif(NOT MSVC)

# Compile LIBXSMM
if(USE_LIBXSMM)
add_custom_target(libxsmm COMMAND make realclean COMMAND make -j BLAS=0
sanchit-misra marked this conversation as resolved.
Show resolved Hide resolved
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/third_party/libxsmm
)
add_dependencies(dgl libxsmm)
list(APPEND DGL_LINKER_LIBS -L${CMAKE_SOURCE_DIR}/third_party/libxsmm/lib/ xsmm)
endif(USE_LIBXSMM)

# Compile TVM Runtime and Featgraph
# (NOTE) We compile a dynamic library called featgraph_runtime, which the DGL library links to.
# Kernels are packed in a separate dynamic library called featgraph_kernels, which DGL
Expand Down
127 changes: 94 additions & 33 deletions src/array/cpu/spmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
#if !defined(_WIN32)
#ifdef USE_AVX
#include "intel/cpu_support.h"
#ifdef USE_LIBXSMM
#include "spmm_blocking_libxsmm.h"
#endif // USE_LIBXSMM
#endif // USE_AVX
#endif // _WIN32
namespace dgl {
Expand Down Expand Up @@ -42,52 +45,75 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
const DType* W = efeat.Ptr<DType>();
int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
DType* O = out.Ptr<DType>();
assert(indptr != nullptr);
sanchit-misra marked this conversation as resolved.
Show resolved Hide resolved
assert(O != nullptr);
if(Op::use_lhs)
sanchit-misra marked this conversation as resolved.
Show resolved Hide resolved
{
assert(indices != nullptr);
assert(X != nullptr);
sanchit-misra marked this conversation as resolved.
Show resolved Hide resolved
}
if(Op::use_rhs)
{
if(has_idx)
assert(edges != nullptr);
assert(W != nullptr);
}
#if !defined(_WIN32)
#ifdef USE_AVX
typedef dgl::ElemWiseAddUpdate<Op> ElemWiseUpd;
/* Prepare an assembler kernel */
static std::unique_ptr<ElemWiseUpd> asm_kernel_ptr(
(dgl::IntelKernel<>::IsEnabled()) ? new ElemWiseUpd() : nullptr);
/* Distribute the kernel among OMP threads */
ElemWiseUpd* cpu_spec = (asm_kernel_ptr && asm_kernel_ptr->applicable())
? asm_kernel_ptr.get()
: nullptr;
if (cpu_spec && dim > 16 && !bcast.use_bcast) {
#ifdef USE_LIBXSMM
bool special_condition = bcast.use_bcast || (Op::use_lhs && (dim != lhs_dim)) || (Op::use_rhs && (dim != rhs_dim));
if(!special_condition)
{
SpMMSumCsrOpt<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
} else {
#endif
sanchit-misra marked this conversation as resolved.
Show resolved Hide resolved
typedef dgl::ElemWiseAddUpdate<Op> ElemWiseUpd;
/* Prepare an assembler kernel */
static std::unique_ptr<ElemWiseUpd> asm_kernel_ptr(
(dgl::IntelKernel<>::IsEnabled()) ? new ElemWiseUpd() : nullptr);
/* Distribute the kernel among OMP threads */
ElemWiseUpd* cpu_spec = (asm_kernel_ptr && asm_kernel_ptr->applicable())
? asm_kernel_ptr.get()
: nullptr;
if (cpu_spec && dim > 16 && !bcast.use_bcast) {
#pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
cpu_spec->run(out_off, X + cid * lhs_dim, W + eid * rhs_dim, dim);
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
cpu_spec->run(out_off, X + cid * lhs_dim, W + eid * rhs_dim, dim);
}
}
}
} else {
} else {
sanchit-misra marked this conversation as resolved.
Show resolved Hide resolved
#endif // USE_AVX
#endif // _WIN32

#pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off =
Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off =
Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
out_off[k] += Op::Call(lhs_off, rhs_off);
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off =
Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off =
Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
out_off[k] += Op::Call(lhs_off, rhs_off);
}
}
}
}
#if !defined(_WIN32)
#ifdef USE_AVX
}
#ifdef USE_LIBXSMM
}
#endif // USE_LIBXSMM
#endif // USE_AVX
#endif // _WIN32
}
Expand Down Expand Up @@ -172,6 +198,34 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
DType* O = static_cast<DType*>(out->data);
IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
assert(indptr != nullptr);
assert(O != nullptr);
if(Op::use_lhs)
{
assert(indices != nullptr);
assert(X != nullptr);
assert(argX != nullptr);
}
if(Op::use_rhs)
{
if(has_idx)
assert(edges != nullptr);
assert(W != nullptr);
assert(argW != nullptr);
}
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM

bool special_condition = bcast.use_bcast || (Op::use_lhs && (dim != lhs_dim)) || (Op::use_rhs && (dim != rhs_dim));
if(!special_condition)
{
SpMMCmpCsrOpt<IdType, DType, Op, Cmp>(bcast, csr, ufeat, efeat, out, argu, arge);
} else {
#endif // USE_LIBXSMM
#endif // USE_AVX
#endif // _WIN32

#pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
Expand All @@ -197,6 +251,13 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
}
}
}
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
}
#endif // USE_LIBXSMM
#endif // USE_AVX
#endif // _WIN32
}

/*!
Expand Down
Loading