Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add kron and expand_dims #7

Merged
merged 10 commits into from
Oct 23, 2024
2 changes: 2 additions & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@
:toctree: generated

atleast_nd
expand_dims
kron
```
44 changes: 29 additions & 15 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pylint = "*"
# import dependencies for mypy:
array-api-strict = "*"
numpy = "*"
pytest = "*"

[tool.pixi.feature.lint.tasks]
pre-commit = { cmd = "pre-commit install && pre-commit run -v --all-files --show-diff-on-failure" }
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from ._funcs import atleast_nd
from ._funcs import atleast_nd, expand_dims, kron

__version__ = "0.1.2.dev0"

__all__ = ["__version__", "atleast_nd"]
__all__ = ["__version__", "atleast_nd", "expand_dims", "kron"]
194 changes: 192 additions & 2 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
if TYPE_CHECKING:
from ._typing import Array, ModuleType

__all__ = ["atleast_nd"]
__all__ = ["atleast_nd", "expand_dims", "kron"]


def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
"""
Recursively expand the dimension of an array to at least `ndim`.

Expand Down Expand Up @@ -46,3 +46,193 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
x = xp.expand_dims(x, axis=0)
x = atleast_nd(x, ndim=ndim, xp=xp)
return x


def expand_dims(
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType
) -> Array:
"""
Expand the shape of an array.

Insert (a) new axis/axes that will appear at the position(s) specified by
`axis` in the expanded array shape.

This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays.

Parameters
----------
a : array
axis : int or tuple of ints, optional
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
If multiple positions are provided, they should be unique (note that a position
given by a positive index could also be referred to by a negative index -
that will also result in an error).
Default: ``(0,)``.
xp : array_namespace
The standard-compatible namespace for `a`.

Returns
-------
res : array
`a` with an expanded shape.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([1, 2])
>>> x.shape
(2,)

The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``:

>>> y = xpx.expand_dims(x, axis=0, xp=xp)
>>> y
Array([[1, 2]], dtype=array_api_strict.int64)
>>> y.shape
(1, 2)

The following is equivalent to ``x[:, xp.newaxis]``:

>>> y = xpx.expand_dims(x, axis=1, xp=xp)
>>> y
Array([[1],
[2]], dtype=array_api_strict.int64)
>>> y.shape
(2, 1)

``axis`` may also be a tuple:

>>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp)
>>> y
Array([[[1, 2]]], dtype=array_api_strict.int64)

>>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp)
>>> y
Array([[[1],
[2]]], dtype=array_api_strict.int64)

"""
if not isinstance(axis, tuple):
axis = (axis,)
ndim = a.ndim + len(axis)
if axis != () and (min(axis) < -ndim or max(axis) >= ndim):
err_msg = (
f"a provided axis position is out of bounds for array of dimension {a.ndim}"
)
raise IndexError(err_msg)
axis = tuple(dim % ndim for dim in axis)
if len(set(axis)) != len(axis):
err_msg = "Duplicate dimensions specified in `axis`."
raise ValueError(err_msg)
for i in sorted(axis):
a = xp.expand_dims(a, axis=i)
return a


def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
"""
Kronecker product of two arrays.

Computes the Kronecker product, a composite array made of blocks of the
second array scaled by the first.

Equivalent to ``numpy.kron`` for NumPy arrays.

Parameters
----------
a, b : array
xp : array_namespace
The standard-compatible namespace for `a` and `b`.

Returns
-------
res : array
The Kronecker product of `a` and `b`.

Notes
-----
The function assumes that the number of dimensions of `a` and `b`
are the same, if necessary prepending the smallest with ones.
If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``,
the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``.
The elements are products of elements from `a` and `b`, organized
explicitly by::

kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]

where::

kt = it * st + jt, t = 0,...,N

In the common 2-D case (N=1), the block structure can be visualized::

[[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
[ ... ... ],
[ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]


Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp)
Array([ 5, 6, 7, 50, 60, 70, 500,
600, 700], dtype=array_api_strict.int64)

>>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp)
Array([ 5, 50, 500, 6, 60, 600, 7,
70, 700], dtype=array_api_strict.int64)

>>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp)
Array([[1., 1., 0., 0.],
[1., 1., 0., 0.],
[0., 0., 1., 1.],
[0., 0., 1., 1.]], dtype=array_api_strict.float64)


>>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
>>> b = xp.reshape(xp.arange(24), (2, 3, 4))
>>> c = xpx.kron(a, b, xp=xp)
>>> c.shape
(2, 10, 6, 20)
>>> I = (1, 3, 0, 2)
>>> J = (0, 2, 1)
>>> J1 = (0,) + J # extend to ndim=4
>>> S1 = (1,) + b.shape
>>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
>>> c[K] == a[I]*b[J]
Array(True, dtype=array_api_strict.bool)

"""

b = xp.asarray(b)
singletons = (1,) * (b.ndim - a.ndim)
a = xp.broadcast_to(xp.asarray(a), singletons + a.shape)

nd_b, nd_a = b.ndim, a.ndim
nd_max = max(nd_b, nd_a)
if nd_a == 0 or nd_b == 0:
return xp.multiply(a, b)

a_shape = a.shape
b_shape = b.shape

# Equalise the shapes by prepending smaller one with 1s
a_shape = (1,) * max(0, nd_b - nd_a) + a_shape
b_shape = (1,) * max(0, nd_a - nd_b) + b_shape

# Insert empty dimensions
a_arr = expand_dims(a, axis=tuple(range(nd_b - nd_a)), xp=xp)
b_arr = expand_dims(b, axis=tuple(range(nd_a - nd_b)), xp=xp)

# Compute the product
a_arr = expand_dims(a_arr, axis=tuple(range(1, nd_max * 2, 2)), xp=xp)
b_arr = expand_dims(b_arr, axis=tuple(range(0, nd_max * 2, 2)), xp=xp)
result = xp.multiply(a_arr, b_arr)

# Reshape back and return
a_shape = xp.asarray(a_shape)
b_shape = xp.asarray(b_shape)
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))
Loading