diff --git a/python/mxnet/rtc.py b/python/mxnet/rtc.py index ada487663866..f03685a09940 100644 --- a/python/mxnet/rtc.py +++ b/python/mxnet/rtc.py @@ -141,21 +141,22 @@ def get_kernel(self, name, signature): is_ndarray = [] is_const = [] dtypes = [] - pattern = re.compile(r"""^\s*(const)?\s*([\w_]+)\s*(\*)?\s*([\w_]+)?\s*$""") + pattern = re.compile(r"""^(const)?\s?([\w_]+)\s?(\*)?\s?([\w_]+)?$""") args = re.sub(r"\s+", " ", signature).split(",") for arg in args: - match = pattern.match(arg) + sanitized_arg = " ".join(arg.split()) + match = pattern.match(sanitized_arg) if not match or match.groups()[1] == 'const': raise ValueError( 'Invalid function prototype "%s". Must be in the ' - 'form of "(const) type (*) (name)"'%arg) + 'form of "(const) type (*) (name)"'%sanitized_arg) is_const.append(bool(match.groups()[0])) dtype = match.groups()[1] is_ndarray.append(bool(match.groups()[2])) if dtype not in _DTYPE_CPP_TO_NP: raise TypeError( "Unsupported kernel argument type %s. Supported types are: %s."%( - arg, ','.join(_DTYPE_CPP_TO_NP.keys()))) + sanitized_arg, ','.join(_DTYPE_CPP_TO_NP.keys()))) dtypes.append(_DTYPE_NP_TO_MX[_DTYPE_CPP_TO_NP[dtype]]) check_call(_LIB.MXRtcCudaKernelCreate(