Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 2e2d717

Browse files
josephevansptrendx
andauthored
Fix the regular expression in RTC code (#20810) (#20839)
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
1 parent 4e64e2c commit 2e2d717

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

python/mxnet/rtc.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,21 +141,22 @@ def get_kernel(self, name, signature):
141141
is_ndarray = []
142142
is_const = []
143143
dtypes = []
144-
pattern = re.compile(r"""^\s*(const)?\s*([\w_]+)\s*(\*)?\s*([\w_]+)?\s*$""")
144+
pattern = re.compile(r"""^(const)?\s?([\w_]+)\s?(\*)?\s?([\w_]+)?$""")
145145
args = re.sub(r"\s+", " ", signature).split(",")
146146
for arg in args:
147-
match = pattern.match(arg)
147+
sanitized_arg = " ".join(arg.split())
148+
match = pattern.match(sanitized_arg)
148149
if not match or match.groups()[1] == 'const':
149150
raise ValueError(
150151
'Invalid function prototype "%s". Must be in the '
151-
'form of "(const) type (*) (name)"'%arg)
152+
'form of "(const) type (*) (name)"'%sanitized_arg)
152153
is_const.append(bool(match.groups()[0]))
153154
dtype = match.groups()[1]
154155
is_ndarray.append(bool(match.groups()[2]))
155156
if dtype not in _DTYPE_CPP_TO_NP:
156157
raise TypeError(
157158
"Unsupported kernel argument type %s. Supported types are: %s."%(
158-
arg, ','.join(_DTYPE_CPP_TO_NP.keys())))
159+
sanitized_arg, ','.join(_DTYPE_CPP_TO_NP.keys())))
159160
dtypes.append(_DTYPE_NP_TO_MX[_DTYPE_CPP_TO_NP[dtype]])
160161

161162
check_call(_LIB.MXRtcCudaKernelCreate(

0 commit comments

Comments
 (0)