Skip to content

Commit

Permalink
[ROCm]: Add ROCm command buffer support for triton kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
Rahul Batra committed Feb 5, 2024
1 parent 789ef4b commit f01c27f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
8 changes: 4 additions & 4 deletions jaxlib/gpu/triton_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -524,16 +524,16 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {

gpustreamCaptureStatus_t capture_status;
GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status));
bool is_capturing = capture_status == CU_STREAM_CAPTURE_STATUS_ACTIVE;
bool is_capturing = capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE;

gpustreamCaptureMode_t capture_mode = CU_STREAM_CAPTURE_MODE_RELAXED;
gpustreamCaptureMode_t capture_mode = GPU_STREAM_CAPTURE_MODE_RELAXED;
gpuStream_t autotune_stream = stream;

if (is_capturing) {

GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
// Need a side stream so as not to interfere with graph capture.
GPU_RETURN_IF_ERROR(
gpuStreamCreate(&autotune_stream, CU_STREAM_NON_BLOCKING));
GPU_RETURN_IF_ERROR(gpuStreamCreate(&autotune_stream, GPU_STREAM_NON_BLOCKING));
}

// If an input aliases with an output, it will get overwritten during the
Expand Down
14 changes: 14 additions & 0 deletions jaxlib/gpu/vendor.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT CUSPARSE_SPARSETODENSE_ALG_DEFAULT
#define GPUSPARSE_STATUS_SUCCESS CUSPARSE_STATUS_SUCCESS

#define GPU_STREAM_CAPTURE_STATUS_ACTIVE CU_STREAM_CAPTURE_STATUS_ACTIVE
#define GPU_STREAM_CAPTURE_MODE_RELAXED CU_STREAM_CAPTURE_MODE_RELAXED
#define GPU_STREAM_NON_BLOCKING CU_STREAM_NON_BLOCKING

#define gpuCtxGetDevice cuCtxGetDevice
#define gpuCtxPopCurrent cuCtxPopCurrent
#define gpuCtxPushCurrent cuCtxPushCurrent
Expand Down Expand Up @@ -332,6 +336,8 @@ typedef hipsolverFillMode_t gpusolverFillMode_t;
typedef hipblasHandle_t gpublasHandle_t;
typedef hipblasStatus_t gpublasStatus_t;
typedef hipCtx_t gpuContext_t;
typedef hipStreamCaptureMode gpustreamCaptureMode_t;
typedef hipStreamCaptureStatus gpustreamCaptureStatus_t;
typedef hipDataType gpuDataType;
typedef hipDevice_t gpuDevice_t;
typedef hipDeviceptr_t gpuDevicePtr_t;
Expand Down Expand Up @@ -494,6 +500,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT
#define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS

#define GPU_STREAM_CAPTURE_STATUS_ACTIVE hipStreamCaptureStatusActive
#define GPU_STREAM_CAPTURE_MODE_RELAXED hipStreamCaptureModeRelaxed
#define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking

#define gpuGetLastError hipGetLastError
#define gpuGetErrorString hipGetErrorString
#define gpuMemcpyAsync hipMemcpyAsync
Expand Down Expand Up @@ -526,6 +536,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
#define gpuMemcpyDtoHAsync hipMemcpyDtoHAsync
#define gpuMemcpyHtoDAsync hipMemcpyHtoDAsync
#define gpuMemsetD8Async hipMemsetD8Async
#define gpuThreadExchangeStreamCaptureMode hipThreadExchangeStreamCaptureMode
#define gpuStreamCreate hipStreamCreateWithFlags
#define gpuStreamDestroy hipStreamDestroy
#define gpuStreamIsCapturing hipStreamIsCapturing

#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR \
hipDeviceAttributeComputeCapabilityMajor
Expand Down

0 comments on commit f01c27f

Please sign in to comment.