Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCM] Add ROCm support for warpctc op (#31817) #31971

Merged
merged 1 commit into from
Mar 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion cmake/external/warpctc.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@

INCLUDE(ExternalProject)

IF(WITH_ROCM)
add_definitions(-DWARPCTC_WITH_HIP)
ENDIF()

SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc)
SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc)
SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc)
set(WARPCTC_REPOSITORY ${GIT_URL}/baidu-research/warp-ctc.git)
set(WARPCTC_TAG cd828e5b6c3b953b82af73f7f44cddc393a20efa)
set(WARPCTC_TAG c690fc5755abbdbdc98ef78d51ec10a6748a8cd1)

SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include"
CACHE PATH "Warp-ctc Directory" FORCE)
Expand Down Expand Up @@ -57,6 +61,7 @@ ExternalProject_Add(
-DCMAKE_CXX_FLAGS_DEBUG=$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>
-DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DWITH_TORCH=OFF
-DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/operators/warpctc_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,7 @@ class WarpCTCFunctor {
warpctc_version_ = platform::dynload::get_warpctc_version();

if (platform::is_gpu_place(ctx.GetPlace())) {
// HIP not support ctcOptions in third-party warpctc
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
options_.loc = CTC_GPU;
options_.stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
Expand Down
29 changes: 25 additions & 4 deletions python/paddle/fluid/tests/unittests/test_warpctc_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from op_test import OpTest
from test_softmax_op import stable_softmax
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
import paddle
import paddle.nn.functional as F
Expand Down Expand Up @@ -240,8 +241,18 @@ def test_check_output(self):

def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
self.check_grad(
["Logits"], "Loss", max_relative_error=0.007, check_dygraph=False)
if core.is_compiled_with_rocm():
self.check_grad(
["Logits"],
"Loss",
max_relative_error=0.009,
check_dygraph=False)
else:
self.check_grad(
["Logits"],
"Loss",
max_relative_error=0.007,
check_dygraph=False)


class TestWarpCTCOpCase1(TestWarpCTCOp):
Expand Down Expand Up @@ -335,8 +346,18 @@ def test_check_output(self):

def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
self.check_grad(
["Logits"], "Loss", max_relative_error=0.007, check_dygraph=False)
if core.is_compiled_with_rocm():
self.check_grad(
["Logits"],
"Loss",
max_relative_error=0.009,
check_dygraph=False)
else:
self.check_grad(
["Logits"],
"Loss",
max_relative_error=0.007,
check_dygraph=False)


class TestWarpCTCOpWithPaddingCase1(TestWarpCTCOpWithPadding):
Expand Down