diff --git a/src/gpu/nvidia/cudnn_matmul_executor.hpp b/src/gpu/nvidia/cudnn_matmul_executor.hpp
index 454c7e2fff4..a8e790d7193 100644
--- a/src/gpu/nvidia/cudnn_matmul_executor.hpp
+++ b/src/gpu/nvidia/cudnn_matmul_executor.hpp
@@ -18,6 +18,7 @@
 #ifndef GPU_NVIDIA_CUDNN_MATMUL_EXECUTOR_HPP
 #define GPU_NVIDIA_CUDNN_MATMUL_EXECUTOR_HPP
 
+#include "common/compiler_workarounds.hpp"
 #include "common/primitive_exec_types.hpp"
 #include "gpu/nvidia/cudnn_matmul.hpp"
 #include "gpu/nvidia/cudnn_matmul_impl.hpp"
@@ -60,35 +61,39 @@ struct cudnn_matmul_base_exec_t {
             arg_dst_scale, uint8_t *bias_scratch_ptr) {
 
-        compat::host_task(cgh, [=](const compat::interop_handle &ih) {
-            auto &sycl_engine = *utils::downcast(
-                    cuda_stream->engine());
-            auto sc = cuda_sycl_scoped_context_handler_t(sycl_engine);
-            // SYCL out-of-order queue encapsulates multiple CUstream objects.
-            // Every query of the CUstream object can return any of those
-            // therefore we need to make sure that we activate both cuDNN and
-            // cuBLAS handles for the same CUstream object.
-            auto native_stream = cuda_stream->get_underlying_stream();
-            auto cublas_handle = cuda_stream->get_cublas_handle(native_stream);
-            auto cudnn_handle = cuda_stream->get_cudnn_handle(native_stream);
-
-            void *reorder_scratch = arg_bias_scratch.get_native_pointer(ih);
-            void *bias = arg_bias.get_native_pointer(ih);
-            void *weights = arg_weights.get_native_pointer(ih);
-            void *src = arg_src.get_native_pointer(ih);
-            void *dst = arg_dst.get_native_pointer(ih);
-
-            void *src_scale = arg_src_scale.get_native_pointer(ih);
-            void *wei_scale = arg_wei_scale.get_native_pointer(ih);
-            void *dst_scale = arg_dst_scale.get_native_pointer(ih);
-
-            matmul_impl_->execute(cublas_handle, cudnn_handle, weights, src,
-                    dst, bias, reorder_scratch, src_scale, wei_scale,
-                    dst_scale);
-
-            free_runtime_scratch(matmul_impl_->has_runtime_params(),
-                    cublas_handle, cuda_stream, bias_scratch_ptr);
-        });
+        compat::host_task(cgh,
+                [= WA_THIS_COPY_CAPTURE](const compat::interop_handle &ih) {
+                    auto &sycl_engine = *utils::downcast(
+                            cuda_stream->engine());
+                    auto sc = cuda_sycl_scoped_context_handler_t(sycl_engine);
+                    // SYCL out-of-order queue encapsulates multiple CUstream objects.
+                    // Every query of the CUstream object can return any of those
+                    // therefore we need to make sure that we activate both cuDNN and
+                    // cuBLAS handles for the same CUstream object.
+                    auto native_stream = cuda_stream->get_underlying_stream();
+                    auto cublas_handle
+                            = cuda_stream->get_cublas_handle(native_stream);
+                    auto cudnn_handle
+                            = cuda_stream->get_cudnn_handle(native_stream);
+
+                    void *reorder_scratch
+                            = arg_bias_scratch.get_native_pointer(ih);
+                    void *bias = arg_bias.get_native_pointer(ih);
+                    void *weights = arg_weights.get_native_pointer(ih);
+                    void *src = arg_src.get_native_pointer(ih);
+                    void *dst = arg_dst.get_native_pointer(ih);
+
+                    void *src_scale = arg_src_scale.get_native_pointer(ih);
+                    void *wei_scale = arg_wei_scale.get_native_pointer(ih);
+                    void *dst_scale = arg_dst_scale.get_native_pointer(ih);
+
+                    matmul_impl_->execute(cublas_handle, cudnn_handle, weights,
+                            src, dst, bias, reorder_scratch, src_scale,
+                            wei_scale, dst_scale);
+
+                    free_runtime_scratch(matmul_impl_->has_runtime_params(),
+                            cublas_handle, cuda_stream, bias_scratch_ptr);
+                });
     }
 
     void free_runtime_scratch(bool has_runtime_params,
@@ -233,47 +238,56 @@ struct cudnn_matmul_lt_base_exec_t {
             uint8_t *block_c_scratch_ptr, uint8_t *src_scale_scratch_ptr,
             uint8_t *wei_scale_scratch_ptr) {
 
-        compat::host_task(cgh, [=](const compat::interop_handle &ih) {
-            auto &sycl_engine = *utils::downcast(
-                    cuda_stream->engine());
-            auto sc = cuda_sycl_scoped_context_handler_t(sycl_engine);
-            // SYCL out-of-order queue encapsulates multiple CUstream objects.
-            // Every query of the CUstream object can return any of those
-            // therefore we need to make sure that we activate both cuDNN and
-            // cuBLAS handles for the same CUstream object.
-            auto native_stream = cuda_stream->get_underlying_stream();
-            auto cublas_handle = cuda_stream->get_cublas_handle(native_stream);
-            auto cudnn_handle = cuda_stream->get_cudnn_handle(native_stream);
-
-            void *reorder_scratch = arg_bias_scratch.get_native_pointer(ih);
-            void *algo_scratch = arg_algo_scratch.get_native_pointer(ih);
-            void *block_a_scratch = arg_block_a_scratch.get_native_pointer(ih);
-            void *block_b_scratch = arg_block_b_scratch.get_native_pointer(ih);
-            void *block_c_scratch = arg_block_c_scratch.get_native_pointer(ih);
-
-            void *scaled_src = scaled_arg_src.get_native_pointer(ih);
-            void *scaled_wt = scaled_arg_wt.get_native_pointer(ih);
-
-            void *bias = arg_bias.get_native_pointer(ih);
-            void *weights = arg_weights.get_native_pointer(ih);
-            void *src = arg_src.get_native_pointer(ih);
-            void *dst = arg_dst.get_native_pointer(ih);
-
-            void *src_scale = arg_src_scale.get_native_pointer(ih);
-            void *wei_scale = arg_wei_scale.get_native_pointer(ih);
-            void *dst_scale = arg_dst_scale.get_native_pointer(ih);
-
-            matmul_impl_->execute(cublas_handle, cudnn_handle, weights, src,
-                    dst, bias, algo_scratch, reorder_scratch, block_a_scratch,
-                    block_b_scratch, block_c_scratch, scaled_src, scaled_wt,
-                    src_scale, wei_scale, dst_scale);
-
-            free_runtime_scratch(matmul_impl_->has_runtime_params(),
-                    cublas_handle, cuda_stream, algo_scratch_ptr,
-                    bias_scratch_ptr, block_a_scratch_ptr, block_b_scratch_ptr,
-                    block_c_scratch_ptr, src_scale_scratch_ptr,
-                    wei_scale_scratch_ptr);
-        });
+        compat::host_task(cgh,
+                [= WA_THIS_COPY_CAPTURE](const compat::interop_handle &ih) {
+                    auto &sycl_engine = *utils::downcast(
+                            cuda_stream->engine());
+                    auto sc = cuda_sycl_scoped_context_handler_t(sycl_engine);
+                    // SYCL out-of-order queue encapsulates multiple CUstream objects.
+                    // Every query of the CUstream object can return any of those
+                    // therefore we need to make sure that we activate both cuDNN and
+                    // cuBLAS handles for the same CUstream object.
+                    auto native_stream = cuda_stream->get_underlying_stream();
+                    auto cublas_handle
+                            = cuda_stream->get_cublas_handle(native_stream);
+                    auto cudnn_handle
+                            = cuda_stream->get_cudnn_handle(native_stream);
+
+                    void *reorder_scratch
+                            = arg_bias_scratch.get_native_pointer(ih);
+                    void *algo_scratch
+                            = arg_algo_scratch.get_native_pointer(ih);
+                    void *block_a_scratch
+                            = arg_block_a_scratch.get_native_pointer(ih);
+                    void *block_b_scratch
+                            = arg_block_b_scratch.get_native_pointer(ih);
+                    void *block_c_scratch
+                            = arg_block_c_scratch.get_native_pointer(ih);
+
+                    void *scaled_src = scaled_arg_src.get_native_pointer(ih);
+                    void *scaled_wt = scaled_arg_wt.get_native_pointer(ih);
+
+                    void *bias = arg_bias.get_native_pointer(ih);
+                    void *weights = arg_weights.get_native_pointer(ih);
+                    void *src = arg_src.get_native_pointer(ih);
+                    void *dst = arg_dst.get_native_pointer(ih);
+
+                    void *src_scale = arg_src_scale.get_native_pointer(ih);
+                    void *wei_scale = arg_wei_scale.get_native_pointer(ih);
+                    void *dst_scale = arg_dst_scale.get_native_pointer(ih);
+
+                    matmul_impl_->execute(cublas_handle, cudnn_handle, weights,
+                            src, dst, bias, algo_scratch, reorder_scratch,
+                            block_a_scratch, block_b_scratch, block_c_scratch,
+                            scaled_src, scaled_wt, src_scale, wei_scale,
+                            dst_scale);
+
+                    free_runtime_scratch(matmul_impl_->has_runtime_params(),
+                            cublas_handle, cuda_stream, algo_scratch_ptr,
+                            bias_scratch_ptr, block_a_scratch_ptr,
+                            block_b_scratch_ptr, block_c_scratch_ptr,
+                            src_scale_scratch_ptr, wei_scale_scratch_ptr);
+                });
     }
 
 protected:
@@ -336,7 +350,8 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t {
         nvidia::stream_t *cuda_stream
                 = utils::downcast(ctx.stream());
 
-        return cuda_stream->interop_task([=](::sycl::handler &cgh) {
+        return cuda_stream->interop_task([= WA_THIS_COPY_CAPTURE](
+                                                 ::sycl::handler &cgh) {
             auto arg_src = CTX_IN_SYCL_MEMORY(DNNL_ARG_SRC);
             auto arg_wt = CTX_IN_SYCL_MEMORY(DNNL_ARG_WEIGHTS);
             auto arg_bias = CTX_IN_SYCL_MEMORY(DNNL_ARG_BIAS);
@@ -410,7 +425,8 @@ struct cudnn_matmul_lt_runtime_args_exec_t final
         uint8_t *wei_scale_scratch_ptr
                 = alloc_ptr(wei_scale_scratchpad_size, cuda_stream->queue());
 
-        return cuda_stream->interop_task([=](::sycl::handler &cgh) {
+        return cuda_stream->interop_task([= WA_THIS_COPY_CAPTURE](
+                                                 ::sycl::handler &cgh) {
             auto arg_src = CTX_IN_SYCL_MEMORY(DNNL_ARG_SRC);
             auto arg_wt = CTX_IN_SYCL_MEMORY(DNNL_ARG_WEIGHTS);
             auto arg_dst = CTX_OUT_SYCL_MEMORY(DNNL_ARG_DST);
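Background note (not part of the patch): the host_task and interop_task lambdas above call the executor's own member free_runtime_scratch, so a plain [=] capture implicitly captures `this`. The standalone sketch below illustrates the kind of workaround WA_THIS_COPY_CAPTURE appears to provide: making the object capture explicit (here assumed to be a by-value `, *this` expansion) so an asynchronously executed callback does not depend on the implicit `this` pointer. The macro definition and class names in the sketch are illustrative assumptions only; the real expansion and the compilers it targets live in common/compiler_workarounds.hpp.

#include <functional>
#include <iostream>

// Illustrative stand-in only: assumed expansion; the real macro comes from
// common/compiler_workarounds.hpp and may expand differently (or to nothing)
// depending on the compiler being worked around.
#define WA_THIS_COPY_CAPTURE , *this

struct exec_sketch_t {
    int scratch_id = 7;

    // Member function invoked from inside the lambda, analogous to
    // free_runtime_scratch in the executor: this is what forces the lambda
    // to capture the enclosing object in the first place.
    void release_scratch() const {
        std::cout << "releasing scratch " << scratch_id << "\n";
    }

    // Builds a callback that may run after the enclosing call returns,
    // similar to the asynchronous host_task in the executor.
    std::function<void()> make_task() const {
        return [= WA_THIS_COPY_CAPTURE]() {
            // With the explicit *this copy, release_scratch() runs on a copy
            // of the object rather than through a `this` pointer captured
            // implicitly by [=], which could dangle by the time the task runs.
            release_scratch();
        };
    }
};

int main() {
    std::function<void()> task;
    {
        exec_sketch_t ex;
        task = ex.make_task(); // ex goes out of scope before the task runs
    }
    task(); // safe: the lambda owns its own copy of the object
    return 0;
}

With the assumed `, *this` expansion the capture list reads [=, *this], which is valid C++17 and is also the form recommended once the implicit `this` capture via [=] was deprecated in C++20; on compilers that need no workaround the macro would presumably expand to nothing, leaving a plain [=].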