Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gpu: nvidia: Updated matmul labda capture #2110

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 88 additions & 72 deletions src/gpu/nvidia/cudnn_matmul_executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#ifndef GPU_NVIDIA_CUDNN_MATMUL_EXECUTOR_HPP
#define GPU_NVIDIA_CUDNN_MATMUL_EXECUTOR_HPP

#include "common/compiler_workarounds.hpp"
#include "common/primitive_exec_types.hpp"
#include "gpu/nvidia/cudnn_matmul.hpp"
#include "gpu/nvidia/cudnn_matmul_impl.hpp"
Expand Down Expand Up @@ -60,35 +61,39 @@ struct cudnn_matmul_base_exec_t {
arg_dst_scale,
uint8_t *bias_scratch_ptr) {

compat::host_task(cgh, [=](const compat::interop_handle &ih) {
auto &sycl_engine = *utils::downcast<nvidia::engine_t *>(
cuda_stream->engine());
auto sc = cuda_sycl_scoped_context_handler_t(sycl_engine);
// SYCL out-of-order queue encapsulates multiple CUstream objects.
// Every query of the CUstream object can return any of those
// therefore we need to make sure that we activate both cuDNN and
// cuBLAS handles for the same CUstream object.
auto native_stream = cuda_stream->get_underlying_stream();
auto cublas_handle = cuda_stream->get_cublas_handle(native_stream);
auto cudnn_handle = cuda_stream->get_cudnn_handle(native_stream);

void *reorder_scratch = arg_bias_scratch.get_native_pointer(ih);
void *bias = arg_bias.get_native_pointer(ih);
void *weights = arg_weights.get_native_pointer(ih);
void *src = arg_src.get_native_pointer(ih);
void *dst = arg_dst.get_native_pointer(ih);

void *src_scale = arg_src_scale.get_native_pointer(ih);
void *wei_scale = arg_wei_scale.get_native_pointer(ih);
void *dst_scale = arg_dst_scale.get_native_pointer(ih);

matmul_impl_->execute(cublas_handle, cudnn_handle, weights, src,
dst, bias, reorder_scratch, src_scale, wei_scale,
dst_scale);

free_runtime_scratch(matmul_impl_->has_runtime_params(),
cublas_handle, cuda_stream, bias_scratch_ptr);
});
compat::host_task(cgh,
[= WA_THIS_COPY_CAPTURE](const compat::interop_handle &ih) {
auto &sycl_engine = *utils::downcast<nvidia::engine_t *>(
cuda_stream->engine());
auto sc = cuda_sycl_scoped_context_handler_t(sycl_engine);
// SYCL out-of-order queue encapsulates multiple CUstream objects.
// Every query of the CUstream object can return any of those
// therefore we need to make sure that we activate both cuDNN and
// cuBLAS handles for the same CUstream object.
auto native_stream = cuda_stream->get_underlying_stream();
auto cublas_handle
= cuda_stream->get_cublas_handle(native_stream);
auto cudnn_handle
= cuda_stream->get_cudnn_handle(native_stream);

void *reorder_scratch
= arg_bias_scratch.get_native_pointer(ih);
void *bias = arg_bias.get_native_pointer(ih);
void *weights = arg_weights.get_native_pointer(ih);
void *src = arg_src.get_native_pointer(ih);
void *dst = arg_dst.get_native_pointer(ih);

void *src_scale = arg_src_scale.get_native_pointer(ih);
void *wei_scale = arg_wei_scale.get_native_pointer(ih);
void *dst_scale = arg_dst_scale.get_native_pointer(ih);

matmul_impl_->execute(cublas_handle, cudnn_handle, weights,
src, dst, bias, reorder_scratch, src_scale,
wei_scale, dst_scale);

free_runtime_scratch(matmul_impl_->has_runtime_params(),
cublas_handle, cuda_stream, bias_scratch_ptr);
});
}

void free_runtime_scratch(bool has_runtime_params,
Expand Down Expand Up @@ -233,47 +238,56 @@ struct cudnn_matmul_lt_base_exec_t {
uint8_t *block_c_scratch_ptr, uint8_t *src_scale_scratch_ptr,
uint8_t *wei_scale_scratch_ptr) {

compat::host_task(cgh, [=](const compat::interop_handle &ih) {
auto &sycl_engine = *utils::downcast<nvidia::engine_t *>(
cuda_stream->engine());
auto sc = cuda_sycl_scoped_context_handler_t(sycl_engine);
// SYCL out-of-order queue encapsulates multiple CUstream objects.
// Every query of the CUstream object can return any of those
// therefore we need to make sure that we activate both cuDNN and
// cuBLAS handles for the same CUstream object.
auto native_stream = cuda_stream->get_underlying_stream();
auto cublas_handle = cuda_stream->get_cublas_handle(native_stream);
auto cudnn_handle = cuda_stream->get_cudnn_handle(native_stream);

void *reorder_scratch = arg_bias_scratch.get_native_pointer(ih);
void *algo_scratch = arg_algo_scratch.get_native_pointer(ih);
void *block_a_scratch = arg_block_a_scratch.get_native_pointer(ih);
void *block_b_scratch = arg_block_b_scratch.get_native_pointer(ih);
void *block_c_scratch = arg_block_c_scratch.get_native_pointer(ih);

void *scaled_src = scaled_arg_src.get_native_pointer(ih);
void *scaled_wt = scaled_arg_wt.get_native_pointer(ih);

void *bias = arg_bias.get_native_pointer(ih);
void *weights = arg_weights.get_native_pointer(ih);
void *src = arg_src.get_native_pointer(ih);
void *dst = arg_dst.get_native_pointer(ih);

void *src_scale = arg_src_scale.get_native_pointer(ih);
void *wei_scale = arg_wei_scale.get_native_pointer(ih);
void *dst_scale = arg_dst_scale.get_native_pointer(ih);

matmul_impl_->execute(cublas_handle, cudnn_handle, weights, src,
dst, bias, algo_scratch, reorder_scratch, block_a_scratch,
block_b_scratch, block_c_scratch, scaled_src, scaled_wt,
src_scale, wei_scale, dst_scale);

free_runtime_scratch(matmul_impl_->has_runtime_params(),
cublas_handle, cuda_stream, algo_scratch_ptr,
bias_scratch_ptr, block_a_scratch_ptr, block_b_scratch_ptr,
block_c_scratch_ptr, src_scale_scratch_ptr,
wei_scale_scratch_ptr);
});
compat::host_task(cgh,
[= WA_THIS_COPY_CAPTURE](const compat::interop_handle &ih) {
auto &sycl_engine = *utils::downcast<nvidia::engine_t *>(
cuda_stream->engine());
auto sc = cuda_sycl_scoped_context_handler_t(sycl_engine);
// SYCL out-of-order queue encapsulates multiple CUstream objects.
// Every query of the CUstream object can return any of those
// therefore we need to make sure that we activate both cuDNN and
// cuBLAS handles for the same CUstream object.
auto native_stream = cuda_stream->get_underlying_stream();
auto cublas_handle
= cuda_stream->get_cublas_handle(native_stream);
auto cudnn_handle
= cuda_stream->get_cudnn_handle(native_stream);

void *reorder_scratch
= arg_bias_scratch.get_native_pointer(ih);
void *algo_scratch
= arg_algo_scratch.get_native_pointer(ih);
void *block_a_scratch
= arg_block_a_scratch.get_native_pointer(ih);
void *block_b_scratch
= arg_block_b_scratch.get_native_pointer(ih);
void *block_c_scratch
= arg_block_c_scratch.get_native_pointer(ih);

void *scaled_src = scaled_arg_src.get_native_pointer(ih);
void *scaled_wt = scaled_arg_wt.get_native_pointer(ih);

void *bias = arg_bias.get_native_pointer(ih);
void *weights = arg_weights.get_native_pointer(ih);
void *src = arg_src.get_native_pointer(ih);
void *dst = arg_dst.get_native_pointer(ih);

void *src_scale = arg_src_scale.get_native_pointer(ih);
void *wei_scale = arg_wei_scale.get_native_pointer(ih);
void *dst_scale = arg_dst_scale.get_native_pointer(ih);

matmul_impl_->execute(cublas_handle, cudnn_handle, weights,
src, dst, bias, algo_scratch, reorder_scratch,
block_a_scratch, block_b_scratch, block_c_scratch,
scaled_src, scaled_wt, src_scale, wei_scale,
dst_scale);

free_runtime_scratch(matmul_impl_->has_runtime_params(),
cublas_handle, cuda_stream, algo_scratch_ptr,
bias_scratch_ptr, block_a_scratch_ptr,
block_b_scratch_ptr, block_c_scratch_ptr,
src_scale_scratch_ptr, wei_scale_scratch_ptr);
});
}

protected:
Expand Down Expand Up @@ -336,7 +350,8 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t {
nvidia::stream_t *cuda_stream
= utils::downcast<nvidia::stream_t *>(ctx.stream());

return cuda_stream->interop_task([=](::sycl::handler &cgh) {
return cuda_stream->interop_task([= WA_THIS_COPY_CAPTURE](
::sycl::handler &cgh) {
auto arg_src = CTX_IN_SYCL_MEMORY(DNNL_ARG_SRC);
auto arg_wt = CTX_IN_SYCL_MEMORY(DNNL_ARG_WEIGHTS);
auto arg_bias = CTX_IN_SYCL_MEMORY(DNNL_ARG_BIAS);
Expand Down Expand Up @@ -410,7 +425,8 @@ struct cudnn_matmul_lt_runtime_args_exec_t final
uint8_t *wei_scale_scratch_ptr
= alloc_ptr(wei_scale_scratchpad_size, cuda_stream->queue());

return cuda_stream->interop_task([=](::sycl::handler &cgh) {
return cuda_stream->interop_task([= WA_THIS_COPY_CAPTURE](
::sycl::handler &cgh) {
auto arg_src = CTX_IN_SYCL_MEMORY(DNNL_ARG_SRC);
auto arg_wt = CTX_IN_SYCL_MEMORY(DNNL_ARG_WEIGHTS);
auto arg_dst = CTX_OUT_SYCL_MEMORY(DNNL_ARG_DST);
Expand Down