From 8b1508a8d5b5ca44a0def938b77e84762f7364d9 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 17 Jan 2022 23:06:38 -0800 Subject: [PATCH] Fix the regular expression in RTC code (#20810) --- python/mxnet/rtc.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/mxnet/rtc.py b/python/mxnet/rtc.py index ada487663866..f03685a09940 100644 --- a/python/mxnet/rtc.py +++ b/python/mxnet/rtc.py @@ -141,21 +141,22 @@ def get_kernel(self, name, signature): is_ndarray = [] is_const = [] dtypes = [] - pattern = re.compile(r"""^\s*(const)?\s*([\w_]+)\s*(\*)?\s*([\w_]+)?\s*$""") + pattern = re.compile(r"""^(const)?\s?([\w_]+)\s?(\*)?\s?([\w_]+)?$""") args = re.sub(r"\s+", " ", signature).split(",") for arg in args: - match = pattern.match(arg) + sanitized_arg = " ".join(arg.split()) + match = pattern.match(sanitized_arg) if not match or match.groups()[1] == 'const': raise ValueError( 'Invalid function prototype "%s". Must be in the ' - 'form of "(const) type (*) (name)"'%arg) + 'form of "(const) type (*) (name)"'%sanitized_arg) is_const.append(bool(match.groups()[0])) dtype = match.groups()[1] is_ndarray.append(bool(match.groups()[2])) if dtype not in _DTYPE_CPP_TO_NP: raise TypeError( "Unsupported kernel argument type %s. Supported types are: %s."%( - arg, ','.join(_DTYPE_CPP_TO_NP.keys()))) + sanitized_arg, ','.join(_DTYPE_CPP_TO_NP.keys()))) dtypes.append(_DTYPE_NP_TO_MX[_DTYPE_CPP_TO_NP[dtype]]) check_call(_LIB.MXRtcCudaKernelCreate(