From 6fb556b41a65d528488379828b6fda81f13a8dc9 Mon Sep 17 00:00:00 2001 From: mseth10 Date: Fri, 16 Aug 2019 02:11:22 +0000 Subject: [PATCH] adding infer shape type before partition --- python/mxnet/symbol/symbol.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 4b71177afaa8..d40d1b39bcf6 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -1437,10 +1437,39 @@ def _gen_atomic_symbol(self): return Symbol(handle) - def optimizeFor(self, backend): + def optimizeFor(self, backend, ctx=None, args=None, **kwargs): """Partition symbol and optimize it for a given backend""" + if (!ctx or !args): + infer_shape_type = False + else: + infer_shape_type = True + if not isinstance(ctx, Context): + raise TypeError("Context type error") + listed_arguments = self.list_arguments() + args_handle, args = self._get_ndarray_inputs('args', args, listed_arguments, False) + aux_states = [] + aux_args_handle, aux_states = self._get_ndarray_inputs( + 'aux_states', aux_states, self.list_auxiliary_states(), False) + ctx_map_keys = [] + ctx_map_dev_types = [] + ctx_map_dev_ids = [] + out = SymbolHandle() - check_call(_LIB.MXOptimizeForBackend(self.handle, c_str(backend), ctypes.byref(out))) + check_call(_LIB.MXOptimizeForBackend(self.handle, + c_str(backend), + ctypes.byref(out), + ctypes.c_bool(infer_shape_type) + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + mx_uint(len(ctx_map_keys)), + c_str_array(ctx_map_keys), + c_array_buf(ctypes.c_int, array('i', ctx_map_dev_types)), + c_array_buf(ctypes.c_int, array('i', ctx_map_dev_ids)), + mx_uint(len(args)), + args_handle, + mx_uint(len(aux_states)), + aux_args_handle)) + return out # pylint: disable=too-many-locals