Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Fix] Add ctx to the original ndarray and revise the usage of context to ctx #16819

Merged
merged 3 commits into from
Nov 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _gather_type_ctx_info(args):
Context of the first appeared NDArray (for backward-compatibility)
"""
if isinstance(args, NDArray):
return False, True, {args.context}, args.context
return False, True, {args.ctx}, args.ctx
elif isinstance(args, Symbol):
return True, False, set(), None
elif isinstance(args, (list, tuple)):
Expand Down Expand Up @@ -1141,7 +1141,7 @@ def forward(self, x, *args):
if len(ctx_set) > 1:
raise ValueError('Find multiple contexts in the input, '
'After hybridized, the HybridBlock only supports one input '
'context. You can print the ele.context in the '
'context. You can print the ele.ctx in the '
'input arguments to inspect their contexts. '
'Find all contexts = {}'.format(ctx_set))
with ctx:
Expand Down Expand Up @@ -1324,7 +1324,7 @@ def __init__(self, outputs, inputs, params=None):

def forward(self, x, *args):
if isinstance(x, NDArray):
with x.context:
with x.ctx:
return self._call_cached_op(x, *args)

assert isinstance(x, Symbol), \
Expand Down
8 changes: 4 additions & 4 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,10 @@ def _init_grad(self):
if self._grad_stype != 'default':
raise ValueError("mxnet.numpy.zeros does not support stype = {}"
.format(self._grad_stype))
self._grad = [_mx_np.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context)
self._grad = [_mx_np.zeros(shape=i.shape, dtype=i.dtype, ctx=i.ctx)
for i in self._data]
else:
self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.ctx,
stype=self._grad_stype) for i in self._data]

autograd.mark_variables(self._check_and_get(self._data, list),
Expand Down Expand Up @@ -522,7 +522,7 @@ def row_sparse_data(self, row_id):
raise RuntimeError("Cannot return a copy of Parameter %s via row_sparse_data() " \
"because its storage type is %s. Please use data() instead." \
%(self.name, self._stype))
return self._get_row_sparse(self._data, row_id.context, row_id)
return self._get_row_sparse(self._data, row_id.ctx, row_id)

def list_row_sparse_data(self, row_id):
"""Returns copies of the 'row_sparse' parameter on all contexts, in the same order
Expand Down Expand Up @@ -897,7 +897,7 @@ def zero_grad(self):
if g.stype == 'row_sparse':
ndarray.zeros_like(g, out=g)
else:
arrays[g.context].append(g)
arrays[g.ctx].append(g)

if len(arrays) == 0:
return
Expand Down
37 changes: 27 additions & 10 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def __repr__(self):
shape_info = 'x'.join(['%d' % x for x in self.shape])
return '\n%s\n<%s %s @%s>' % (str(self.asnumpy()),
self.__class__.__name__,
shape_info, self.context)
shape_info, self.ctx)

def __reduce__(self):
return NDArray, (None,), self.__getstate__()
Expand Down Expand Up @@ -729,14 +729,14 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None):
`squeeze_axes`: a sequence of axes to squeeze in the value array.
"""
if isinstance(value, numeric_types):
value_nd = full(bcast_shape, value, ctx=self.context, dtype=self.dtype)
value_nd = full(bcast_shape, value, ctx=self.ctx, dtype=self.dtype)
elif type(value) == self.__class__: # pylint: disable=unidiomatic-typecheck
value_nd = value.as_in_context(self.context)
value_nd = value.as_in_context(self.ctx)
if value_nd.dtype != self.dtype:
value_nd = value_nd.astype(self.dtype)
else:
try:
value_nd = array(value, ctx=self.context, dtype=self.dtype)
value_nd = array(value, ctx=self.ctx, dtype=self.dtype)
except:
raise TypeError('{} does not support assignment with non-array-like '
'object {} of type {}'.format(self.__class__, value, type(value)))
Expand Down Expand Up @@ -1220,7 +1220,7 @@ def _get_index_nd(self, key):

shape_nd_permut = tuple(self.shape[ax] for ax in axs_nd_permut)
converted_idcs_short = [
self._advanced_index_to_array(idx, ax_len, self.context)
self._advanced_index_to_array(idx, ax_len, self.ctx)
for idx, ax_len in zip(idcs_permut_short, shape_nd_permut)
]
bcast_idcs_permut_short = self._broadcast_advanced_indices(
Expand All @@ -1229,7 +1229,7 @@ def _get_index_nd(self, key):

# Get the ndim of advanced indexing subspace
converted_advanced_idcs = [
self._advanced_index_to_array(idx, ax_len, self.context)
self._advanced_index_to_array(idx, ax_len, self.ctx)
for idx, ax_len in zip(adv_idcs_nd, [self.shape[ax] for ax in adv_axs_nd])
]
bcast_advanced_shape = _broadcast_shapes(converted_advanced_idcs)
Expand Down Expand Up @@ -2433,6 +2433,23 @@ def context(self):
self.handle, ctypes.byref(dev_typeid), ctypes.byref(dev_id)))
return Context(Context.devtype2str[dev_typeid.value], dev_id.value)

@property
def ctx(self):
"""Device context of the array. Has the same meaning as context.

Examples
--------
>>> x = mx.nd.array([1, 2, 3, 4])
>>> x.ctx
cpu(0)
>>> type(x.ctx)
<class 'mxnet.context.Context'>
>>> y = mx.nd.zeros((2,3), mx.gpu(0))
>>> y.ctx
gpu(0)
"""
return self.context

@property
def dtype(self):
"""Data-type of the array's elements.
Expand Down Expand Up @@ -2580,7 +2597,7 @@ def astype(self, dtype, copy=True):
if not copy and np.dtype(dtype) == self.dtype:
return self

res = empty(self.shape, ctx=self.context, dtype=dtype)
res = empty(self.shape, ctx=self.ctx, dtype=dtype)
self.copyto(res)
return res

Expand Down Expand Up @@ -2646,7 +2663,7 @@ def copy(self):
array([[ 1., 1., 1.],
[ 1., 1., 1.]], dtype=float32)
"""
return self.copyto(self.context)
return self.copyto(self.ctx)

def slice_assign_scalar(self, value, begin, end, step):
"""
Expand Down Expand Up @@ -2904,7 +2921,7 @@ def _full(self, value):
"""
This is added as an NDArray class method in order to support polymorphism in NDArray and numpy.ndarray indexing
"""
return _internal._full(self.shape, value=value, ctx=self.context, dtype=self.dtype, out=self)
return _internal._full(self.shape, value=value, ctx=self.ctx, dtype=self.dtype, out=self)

def _scatter_set_nd(self, value_nd, indices):
"""
Expand Down Expand Up @@ -4542,7 +4559,7 @@ def concatenate(arrays, axis=0, always_copy=True):
assert shape_rest2 == arr.shape[axis+1:]
assert dtype == arr.dtype
ret_shape = shape_rest1 + (shape_axis,) + shape_rest2
ret = empty(ret_shape, ctx=arrays[0].context, dtype=dtype)
ret = empty(ret_shape, ctx=arrays[0].ctx, dtype=dtype)

idx = 0
begin = [0 for _ in ret_shape]
Expand Down
11 changes: 6 additions & 5 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,15 +921,15 @@ def __repr__(self):
elif dtype not in (_np.float32, _np.bool_):
array_str = array_str[:-1] + ', dtype={})'.format(dtype)

context = self.context
context = self.ctx
if context.device_type == 'cpu':
return array_str
return array_str[:-1] + ', ctx={})'.format(str(context))

def __str__(self):
"""Returns a string representation of the array."""
array_str = self.asnumpy().__str__()
context = self.context
context = self.ctx
if context.device_type == 'cpu' or self.ndim == 0:
return array_str
return '{array} @{ctx}'.format(array=array_str, ctx=context)
Expand Down Expand Up @@ -994,7 +994,7 @@ def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ,unused-ar
if not copy and _np.dtype(dtype) == self.dtype:
return self

res = empty(self.shape, dtype=dtype, ctx=self.context)
res = empty(self.shape, dtype=dtype, ctx=self.ctx)
self.copyto(res)
return res

Expand Down Expand Up @@ -1051,7 +1051,8 @@ def argmax(self, axis=None, out=None): # pylint: disable=arguments-differ

def as_in_context(self, context):
"""This function has been deprecated. Please refer to ``ndarray.as_in_ctx``."""
warnings.warn('ndarray.context has been renamed to ndarray.ctx', DeprecationWarning)
warnings.warn('ndarray.as_in_context has been renamed to'
' ndarray.as_in_ctx', DeprecationWarning)
return self.as_nd_ndarray().as_in_context(context).as_np_ndarray()

def as_in_ctx(self, ctx):
Expand Down Expand Up @@ -1864,7 +1865,7 @@ def _full(self, value):
Currently for internal use only. Implemented for __setitem__.
Assign to self an array of self's same shape and type, filled with value.
"""
return _mx_nd_np.full(self.shape, value, ctx=self.context, dtype=self.dtype, out=self)
return _mx_nd_np.full(self.shape, value, ctx=self.ctx, dtype=self.dtype, out=self)

# pylint: disable=redefined-outer-name
def _scatter_set_nd(self, value_nd, indices):
Expand Down