Skip to content

Commit

Permalink
Merge pull request #1008 from helmholtz-analytics/api/1006-rename-kee…
Browse files Browse the repository at this point in the history
…pdim-keepdims

API: Rename `keepdim` kwarg to `keepdims`
  • Loading branch information
mtar authored Apr 24, 2023
2 parents cb46e2c + 6b8a18f commit eaa90d6
Show file tree
Hide file tree
Showing 15 changed files with 120 additions and 120 deletions.
2 changes: 1 addition & 1 deletion heat/cluster/_kcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _assign_to_cluster(self, x: DNDarray):
"""
# calculate the distance matrix and determine the closest centroid
distances = self._metric(x, self._cluster_centers)
matching_centroids = distances.argmin(axis=1, keepdim=True)
matching_centroids = distances.argmin(axis=1, keepdims=True)

return matching_centroids

Expand Down
4 changes: 2 additions & 2 deletions heat/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):
selection = (matching_centroids == i).astype(ht.int64)
# accumulate points and total number of points in cluster
assigned_points = x * selection
points_in_cluster = selection.sum(axis=0, keepdim=True).clip(
points_in_cluster = selection.sum(axis=0, keepdims=True).clip(
1.0, ht.iinfo(ht.int64).max
)

# compute the new centroids
new_cluster_centers[i : i + 1, :] = (assigned_points / points_in_cluster).sum(
axis=0, keepdim=True
axis=0, keepdims=True
)

return new_cluster_centers
Expand Down
2 changes: 1 addition & 1 deletion heat/cluster/kmedians.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):
else:
if clean.shape[0] <= ht.MPI_WORLD.size:
clean.resplit_(axis=None)
median = ht.median(clean, axis=0, keepdim=True)
median = ht.median(clean, axis=0, keepdims=True)
new_cluster_centers[i : i + 1, :] = median

return new_cluster_centers
Expand Down
4 changes: 2 additions & 2 deletions heat/cluster/kmedoids.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):
else:
if clean.shape[0] <= ht.MPI_WORLD.size:
clean.resplit_(axis=None)
median = ht.median(clean, axis=0, keepdim=True)
median = ht.median(clean, axis=0, keepdims=True)

dist = self._metric(x, median)
_, displ, _ = x.comm.counts_displs_shape(shape=x.shape, axis=0)
idx = dist.argmin(axis=0, keepdim=False).item()
idx = dist.argmin(axis=0, keepdims=False).item()
proc = 0
for p in range(x.comm.size):
if displ[p] > idx:
Expand Down
6 changes: 3 additions & 3 deletions heat/core/_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def __reduce_op(
axis = stride_tricks.sanitize_axis(x.shape, kwargs.get("axis"))
if isinstance(axis, int):
axis = (axis,)
keepdim = kwargs.get("keepdim")
keepdims = kwargs.get("keepdims")
out = kwargs.get("out")
split = x.split
balanced = x.balanced
Expand Down Expand Up @@ -446,7 +446,7 @@ def __reduce_op(
): # no neutral element for max/min
partial = partial_op(partial, dim=dim, keepdim=True)
output_shape = output_shape[:dim] + (1,) + output_shape[dim + 1 :]
if not keepdim and not len(partial.shape) == 1:
if not keepdims and not len(partial.shape) == 1:
gshape_losedim = tuple(x.gshape[dim] for dim in range(len(x.gshape)) if dim not in axis)
lshape_losedim = tuple(x.lshape[dim] for dim in range(len(x.lshape)) if dim not in axis)
output_shape = gshape_losedim
Expand All @@ -464,7 +464,7 @@ def __reduce_op(
balanced = True
if x.comm.is_distributed():
x.comm.Allreduce(MPI.IN_PLACE, partial, reduction_op)
elif axis is not None and not keepdim:
elif axis is not None and not keepdims:
down_dims = len(tuple(dim for dim in axis if dim < x.split))
split -= down_dims
balanced = x.balanced
Expand Down
16 changes: 8 additions & 8 deletions heat/core/arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ def prod(
a: DNDarray,
axis: Union[int, Tuple[int, ...]] = None,
out: DNDarray = None,
keepdim: bool = None,
keepdims: bool = None,
) -> DNDarray:
"""
Return the product of array elements over a given axis in form of a DNDarray shaped as a but with the specified axis removed.
Expand All @@ -888,7 +888,7 @@ def prod(
out : DNDarray, optional
Alternative output array in which to place the result. It must have the same shape as the expected output, but
the datatype of the output values will be cast if necessary.
keepdim : bool, optional
keepdims : bool, optional
If this is set to ``True``, the axes which are reduced are left in the result as dimensions with size one. With this
option, the result will broadcast correctly against the input array.
Expand All @@ -907,11 +907,11 @@ def prod(
DNDarray([ 2., 12.], dtype=ht.float32, device=cpu:0, split=None)
"""
return _operations.__reduce_op(
a, torch.prod, MPI.PROD, axis=axis, out=out, neutral=1, keepdim=keepdim
a, torch.prod, MPI.PROD, axis=axis, out=out, neutral=1, keepdims=keepdims
)


DNDarray.prod = lambda self, axis=None, out=None, keepdim=None: prod(self, axis, out, keepdim)
DNDarray.prod = lambda self, axis=None, out=None, keepdims=None: prod(self, axis, out, keepdims)
DNDarray.prod.__doc__ = prod.__doc__


Expand Down Expand Up @@ -961,7 +961,7 @@ def sum(
a: DNDarray,
axis: Union[int, Tuple[int, ...]] = None,
out: DNDarray = None,
keepdim: bool = None,
keepdims: bool = None,
) -> DNDarray:
"""
Sum of array elements over a given axis. An array with the same shape as ``self.__array`` except for the specified
Expand All @@ -978,7 +978,7 @@ def sum(
out : DNDarray, optional
Alternative output array in which to place the result. It must have the same shape as the expected output, but
the datatype of the output values will be cast if necessary.
keepdim : bool, optional
keepdims : bool, optional
If this is set to ``True``, the axes which are reduced are left in the result as dimensions with size one. With this
option, the result will broadcast correctly against the input array.
Expand All @@ -996,8 +996,8 @@ def sum(
"""
# TODO: make me more numpy API complete Issue #101
return _operations.__reduce_op(
a, torch.sum, MPI.SUM, axis=axis, out=out, neutral=0, keepdim=keepdim
a, torch.sum, MPI.SUM, axis=axis, out=out, neutral=0, keepdims=keepdims
)


DNDarray.sum = lambda self, axis=None, out=None, keepdim=None: sum(self, axis, out, keepdim)
DNDarray.sum = lambda self, axis=None, out=None, keepdims=None: sum(self, axis, out, keepdims)
36 changes: 18 additions & 18 deletions heat/core/linalg/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,17 +1185,17 @@ def matrix_norm(
if col_axis > row_axis and not keepdims:
col_axis -= 1
return statistics.max(
arithmetics.sum(rounding.abs(x), axis=row_axis, keepdim=keepdims),
arithmetics.sum(rounding.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis,
keepdim=keepdims,
keepdims=keepdims,
)
elif ord == -1:
if col_axis > row_axis and not keepdims:
col_axis -= 1
return statistics.min(
arithmetics.sum(rounding.abs(x), axis=row_axis, keepdim=keepdims),
arithmetics.sum(rounding.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis,
keepdim=keepdims,
keepdims=keepdims,
)
elif ord == 2:
raise NotImplementedError("The largest singular value can't be computed yet.")
Expand All @@ -1205,21 +1205,21 @@ def matrix_norm(
if row_axis > col_axis and not keepdims:
row_axis -= 1
return statistics.max(
arithmetics.sum(rounding.abs(x), axis=col_axis, keepdim=keepdims),
arithmetics.sum(rounding.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis,
keepdim=keepdims,
keepdims=keepdims,
)
elif ord == -constants.inf:
if row_axis > col_axis and not keepdims:
row_axis -= 1
return statistics.min(
arithmetics.sum(rounding.abs(x), axis=col_axis, keepdim=keepdims),
arithmetics.sum(rounding.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis,
keepdim=keepdims,
keepdims=keepdims,
)
elif ord in [None, "fro"]:
return exponential.sqrt(
arithmetics.sum((complex_math.conj(x) * x).real, axis=axis, keepdim=keepdims)
arithmetics.sum((complex_math.conj(x) * x).real, axis=axis, keepdims=keepdims)
)
elif ord == "nuc":
raise NotImplementedError("The nuclear norm can't be computed yet.")
Expand Down Expand Up @@ -2277,7 +2277,7 @@ def vdot(x1: DNDarray, x2: DNDarray) -> DNDarray:


def vecdot(
x1: DNDarray, x2: DNDarray, axis: Optional[int] = None, keepdim: Optional[bool] = None
x1: DNDarray, x2: DNDarray, axis: Optional[int] = None, keepdims: Optional[bool] = None
) -> DNDarray:
"""
Computes the (vector) dot product of two DNDarrays.
Expand All @@ -2290,7 +2290,7 @@ def vecdot(
second input array. Must be compatible with x1.
axis : int, optional
axis over which to compute the dot product. The last dimension is used if 'None'.
keepdim : bool, optional
keepdims : bool, optional
If this is set to 'True', the axes which are reduced are left in the result as dimensions with size one.
See Also
Expand All @@ -2310,7 +2310,7 @@ def vecdot(
if axis is None:
axis = m.ndim - 1

return arithmetics.sum(m, axis=axis, keepdim=keepdim)
return arithmetics.sum(m, axis=axis, keepdims=keepdims)


def vector_norm(
Expand Down Expand Up @@ -2386,20 +2386,20 @@ def vector_norm(
raise TypeError("'axis' must be an integer or 1-tuple for vectors.")

if ord == constants.INF:
return statistics.max(rounding.abs(x), axis=axis, keepdim=keepdims)
return statistics.max(rounding.abs(x), axis=axis, keepdims=keepdims)
elif ord == -constants.INF:
return statistics.min(rounding.abs(x), axis=axis, keepdim=keepdims)
return statistics.min(rounding.abs(x), axis=axis, keepdims=keepdims)
elif ord == 0:
return arithmetics.sum(x != 0, axis=axis, keepdim=keepdims).astype(types.float)
return arithmetics.sum(x != 0, axis=axis, keepdims=keepdims).astype(types.float)
elif ord == 1:
return arithmetics.sum(rounding.abs(x), axis=axis, keepdim=keepdims)
return arithmetics.sum(rounding.abs(x), axis=axis, keepdims=keepdims)
elif ord is None or ord == 2:
s = (complex_math.conj(x) * x).real
return exponential.sqrt(arithmetics.sum(s, axis=axis, keepdim=keepdims))
return exponential.sqrt(arithmetics.sum(s, axis=axis, keepdims=keepdims))
elif isinstance(ord, str):
raise ValueError("Norm order {} is invalid for vectors".format(ord))
else:
ret = arithmetics.pow(rounding.abs(x), ord)
ret = arithmetics.sum(ret, axis=axis, keepdim=keepdims)
ret = arithmetics.sum(ret, axis=axis, keepdims=keepdims)
ret = arithmetics.pow(ret, 1.0 / ord)
return ret
4 changes: 2 additions & 2 deletions heat/core/linalg/tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,7 +2071,7 @@ def test_vecdot(self):
a = ht.full((4, 4), 2, split=0)
b = ht.ones(4)

c = ht.linalg.vecdot(a, b, axis=0, keepdim=True)
c = ht.linalg.vecdot(a, b, axis=0, keepdims=True)
self.assertEqual(c.dtype, ht.float32)
self.assertEqual(c.device, a.device)
self.assertTrue(ht.equal(c, ht.array([[8, 8, 8, 8]])))
Expand Down Expand Up @@ -2118,7 +2118,7 @@ def test_vector_norm(self):
self.assertEqual(vn.device, b0.device)
self.assertTrue(ht.allclose(vn, ht.array([5.38516481, 1.41421356, 5.38516481], split=0)))

# split matrix axis keepdim norm 3
# split matrix axis keepdims norm 3
vn = ht.vector_norm(b1, axis=1, keepdims=True, ord=3)
self.assertEqual(vn.split, None)
self.assertEqual(vn.dtype, b1.dtype)
Expand Down
20 changes: 10 additions & 10 deletions heat/core/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def all(
x: DNDarray,
axis: Union[int, Tuple[int], None] = None,
out: Optional[DNDarray] = None,
keepdim: bool = False,
keepdims: bool = False,
) -> Union[DNDarray, bool]:
"""
Test whether all array elements along a given axis evaluate to ``True``.
Expand All @@ -57,7 +57,7 @@ def all(
out : DNDarray, optional
Alternate output array in which to place the result. It must have the same shape as the expected output
and its type is preserved.
keepdim : bool, optional
keepdims : bool, optional
If this is set to ``True``, the axes which are reduced are left in the result as dimensions with size one.
With this option, the result will broadcast correctly against the original array.
Expand Down Expand Up @@ -91,17 +91,17 @@ def all(
def local_all(t, *args, **kwargs):
return torch.all(t != 0, *args, **kwargs)

if keepdim and axis is None:
if keepdims and axis is None:
axis = tuple(range(x.ndim))

return _operations.__reduce_op(
x, local_all, MPI.LAND, axis=axis, out=out, neutral=1, keepdim=keepdim
x, local_all, MPI.LAND, axis=axis, out=out, neutral=1, keepdims=keepdims
)


DNDarray.all: Callable[
[Union[int, Tuple[int], None], Optional[DNDarray], bool], Union[DNDarray, bool]
] = lambda self, axis=None, out=None, keepdim=False: all(self, axis, out, keepdim)
] = lambda self, axis=None, out=None, keepdims=False: all(self, axis, out, keepdims)
DNDarray.all.__doc__ = all.__doc__


Expand Down Expand Up @@ -170,7 +170,7 @@ def allclose(


def any(
x, axis: Optional[int] = None, out: Optional[DNDarray] = None, keepdim: bool = False
x, axis: Optional[int] = None, out: Optional[DNDarray] = None, keepdims: bool = False
) -> DNDarray:
"""
Returns a :class:`~heat.core.dndarray.DNDarray` containing the result of the test whether any array elements along a
Expand All @@ -187,7 +187,7 @@ def any(
out : DNDarray, optional
Alternative output tensor in which to place the result. It must have the same shape as the expected output.
The output is a array with ``datatype=bool``.
keepdim : bool, optional
keepdims : bool, optional
If this is set to ``True``, the axes which are reduced are left in the result as dimensions with size one.
With this option, the result will broadcast correctly against the original array.
Expand All @@ -211,17 +211,17 @@ def any(
def local_any(t, *args, **kwargs):
return torch.any(t != 0, *args, **kwargs)

if keepdim and axis is None:
if keepdims and axis is None:
axis = tuple(range(x.ndim))

return _operations.__reduce_op(
x, local_any, MPI.LOR, axis=axis, out=out, neutral=0, keepdim=keepdim
x, local_any, MPI.LOR, axis=axis, out=out, neutral=0, keepdims=keepdims
)


DNDarray.any: Callable[
[DNDarray, Optional[int], Optional[DNDarray], bool], DNDarray
] = lambda self, axis=None, out=None, keepdim=False: any(self, axis, out, keepdim)
] = lambda self, axis=None, out=None, keepdims=False: any(self, axis, out, keepdims)
DNDarray.any.__doc__ = any.__doc__


Expand Down
Loading

0 comments on commit eaa90d6

Please sign in to comment.