From 2ebeafb1db772ba0e30f13f02e587eb9edf66c3c Mon Sep 17 00:00:00 2001 From: furnace <34057289+windstamp@users.noreply.github.com> Date: Wed, 31 Mar 2021 10:17:25 +0800 Subject: [PATCH] [ROCM] Add ROCm support for warpctc op (#31817) * bugfix for warpctc * fix warpctc commit id * fix warpctc commit id * fix warpctc commit id * fix warpctc commit id * fix warpctc commit id * fix WARPCTC_WITH_HIP invalid * Add logs to find out why libwarpctc.so cannot be dlopened * fix warpctc commit id * fix unit test test_warpctc_op * Optimize failed log for dlopen * Optimize failed log for dlopen * Delete extra changes * fix warpctc commit id * fix warpctc commit id * Add is_compiled_with_rocm for test_warpctc_op * fix warpctc commit id * Cancel the dlopen failure-reason optimization, move it to the next PR, because it makes the Windows CI fail * Cancel the dlopen failure-reason optimization, move it to the next PR, because it makes the Windows CI fail * Cancel the dlopen failure-reason optimization, move it to the next PR, because it makes the Windows CI fail * fix code style problems --- cmake/external/warpctc.cmake | 7 ++++- paddle/fluid/operators/warpctc_op.h | 3 +- .../fluid/tests/unittests/test_warpctc_op.py | 29 ++++++++++++++++--- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/cmake/external/warpctc.cmake b/cmake/external/warpctc.cmake index b0ef575f64323..ac28f7561f60c 100644 --- a/cmake/external/warpctc.cmake +++ b/cmake/external/warpctc.cmake @@ -14,11 +14,15 @@ INCLUDE(ExternalProject) +IF(WITH_ROCM) + add_definitions(-DWARPCTC_WITH_HIP) +ENDIF() + SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc) SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc) SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc) set(WARPCTC_REPOSITORY ${GIT_URL}/baidu-research/warp-ctc.git) -set(WARPCTC_TAG cd828e5b6c3b953b82af73f7f44cddc393a20efa) +set(WARPCTC_TAG c690fc5755abbdbdc98ef78d51ec10a6748a8cd1) SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include" CACHE PATH "Warp-ctc Directory" FORCE) @@ -57,6 +61,7 
@@ ExternalProject_Add( -DCMAKE_CXX_FLAGS_DEBUG=$ -DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR} -DWITH_GPU=${WITH_GPU} + -DWITH_ROCM=${WITH_ROCM} -DWITH_OMP=${USE_OMP} -DWITH_TORCH=OFF -DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index 7451cac63d0ce..e90eefd72d4ce 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -159,8 +159,7 @@ class WarpCTCFunctor { warpctc_version_ = platform::dynload::get_warpctc_version(); if (platform::is_gpu_place(ctx.GetPlace())) { -// HIP not support ctcOptions in third-party warpctc -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) options_.loc = CTC_GPU; options_.stream = reinterpret_cast( ctx.device_context()) diff --git a/python/paddle/fluid/tests/unittests/test_warpctc_op.py b/python/paddle/fluid/tests/unittests/test_warpctc_op.py index 6310a76d8d000..53f3b3cf53d76 100644 --- a/python/paddle/fluid/tests/unittests/test_warpctc_op.py +++ b/python/paddle/fluid/tests/unittests/test_warpctc_op.py @@ -20,6 +20,7 @@ from op_test import OpTest from test_softmax_op import stable_softmax import paddle.fluid as fluid +import paddle.fluid.core as core from paddle.fluid import Program, program_guard import paddle import paddle.nn.functional as F @@ -240,8 +241,18 @@ def test_check_output(self): def test_check_grad(self): self.outputs['WarpCTCGrad'] = self.gradient - self.check_grad( - ["Logits"], "Loss", max_relative_error=0.007, check_dygraph=False) + if core.is_compiled_with_rocm(): + self.check_grad( + ["Logits"], + "Loss", + max_relative_error=0.009, + check_dygraph=False) + else: + self.check_grad( + ["Logits"], + "Loss", + max_relative_error=0.007, + check_dygraph=False) class TestWarpCTCOpCase1(TestWarpCTCOp): @@ -335,8 +346,18 @@ def test_check_output(self): def test_check_grad(self): self.outputs['WarpCTCGrad'] = self.gradient - self.check_grad( - ["Logits"], "Loss", 
max_relative_error=0.007, check_dygraph=False) + if core.is_compiled_with_rocm(): + self.check_grad( + ["Logits"], + "Loss", + max_relative_error=0.009, + check_dygraph=False) + else: + self.check_grad( + ["Logits"], + "Loss", + max_relative_error=0.007, + check_dygraph=False) class TestWarpCTCOpWithPaddingCase1(TestWarpCTCOpWithPadding):