Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
adding infer shape type before partition
Browse files Browse the repository at this point in the history
  • Loading branch information
mseth10 committed Aug 16, 2019
1 parent bb4fb4f commit 6fb556b
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,10 +1437,39 @@ def _gen_atomic_symbol(self):
return Symbol(handle)


def optimizeFor(self, backend):
def optimizeFor(self, backend, ctx=None, args=None, **kwargs):
"""Partition symbol and optimize it for a given backend"""
if (!ctx or !args):
infer_shape_type = False
else:
infer_shape_type = True
if not isinstance(ctx, Context):
raise TypeError("Context type error")
listed_arguments = self.list_arguments()
args_handle, args = self._get_ndarray_inputs('args', args, listed_arguments, False)
aux_states = []
aux_args_handle, aux_states = self._get_ndarray_inputs(
'aux_states', aux_states, self.list_auxiliary_states(), False)
ctx_map_keys = []
ctx_map_dev_types = []
ctx_map_dev_ids = []

out = SymbolHandle()
check_call(_LIB.MXOptimizeForBackend(self.handle, c_str(backend), ctypes.byref(out)))
check_call(_LIB.MXOptimizeForBackend(self.handle,
c_str(backend),
ctypes.byref(out),
ctypes.c_bool(infer_shape_type)
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
mx_uint(len(ctx_map_keys)),
c_str_array(ctx_map_keys),
c_array_buf(ctypes.c_int, array('i', ctx_map_dev_types)),
c_array_buf(ctypes.c_int, array('i', ctx_map_dev_ids)),
mx_uint(len(args)),
args_handle,
mx_uint(len(aux_states)),
aux_args_handle))
return out


# pylint: disable=too-many-locals
Expand Down

0 comments on commit 6fb556b

Please sign in to comment.