Skip to content

Commit

Permalink
Support where-clause evaluation in registration annotations (Bears-R-…
Browse files Browse the repository at this point in the history
…Us#3841)

* add primitive where-clause evaluation (for instantiateAndRegister) to register_commands.py

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* remove uninstantiated overloads from commands annotated with instantiateAndRegister

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* remove match statements from register_commands.py

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* add error message for 'pad' w/ bigint

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* add preliminary support for where-clause evaluation in 'registerCommand'

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix misinterpretation of explicit type widths (e.g., uint(64)) in where clauses

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix error message

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix error message

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* refactor 'sum' to not rely on where-clauses for dispatching

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* refactor prod, max, and min to take their 'axis' argument as a list

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix mypy errors and multi-dim build failure

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* create list-specific implementations of aryUtil helper procedures used in ReductionMsg. Fix 'axis' argument in 'ak.sum'

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix return type change to reduction procedures

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix reducedShape helper for lists

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

---------

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>
  • Loading branch information
jeremiah-corrado authored Oct 21, 2024
1 parent 0ee2a36 commit 76408db
Show file tree
Hide file tree
Showing 18 changed files with 351 additions and 357 deletions.
9 changes: 9 additions & 0 deletions arkouda/array_api/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def clip(a: Array, a_min, a_max, /) -> Array:
a_max : scalar
The maximum value
"""
if a.dtype == ak.bigint or a.dtype == ak.bool_:
raise RuntimeError(f"Error executing command: clip does not support dtype {a.dtype}")

return Array._new(
create_pdarray(
generic_msg(
Expand Down Expand Up @@ -99,6 +102,9 @@ def diff(a: Array, /, n: int = 1, axis: int = -1, prepend=None, append=None) ->
append : Array, optional
Array to append to `a` along `axis` before calculating the difference.
"""
if a.dtype == ak.bigint or a.dtype == ak.bool_:
raise RuntimeError(f"Error executing command: diff does not support dtype {a.dtype}")

if prepend is not None and append is not None:
a_ = concat((prepend, a, append), axis=axis)
elif prepend is not None:
Expand Down Expand Up @@ -146,6 +152,9 @@ def pad(
if mode != "constant":
raise NotImplementedError(f"pad mode '{mode}' is not supported")

if array.dtype == ak.bigint:
raise RuntimeError("Error executing command: pad does not support dtype bigint")

if "constant_values" not in kwargs:
cvals = 0
else:
Expand Down
43 changes: 17 additions & 26 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2720,16 +2720,6 @@ def is_sorted(pda: pdarray) -> np.bool_:
)


def _get_axis_pdarray(axis: Optional[Union[int, Tuple[int, ...]]] = None):
from arkouda import array as ak_array

axis_list = []
if axis is not None:
axis_list = list(axis) if isinstance(axis, tuple) else [axis]

return ak_array(axis_list, dtype="int64")


@typechecked
def sum(
pda: pdarray, axis: Optional[Union[int, Tuple[int, ...]]] = None
Expand Down Expand Up @@ -2757,12 +2747,13 @@ def sum(
RuntimeError
Raised if there's a server-side error thrown
"""
axis_arry = _get_axis_pdarray(axis)
axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis))
repMsg = generic_msg(
cmd=f"sum<{pda.dtype.name},{pda.ndim},{axis_arry.ndim}>",
args={"x": pda, "axis": axis_arry, "skipNan": False},
cmd=f"sum<{pda.dtype.name},{pda.ndim}>",
args={"x": pda, "axis": axis_, "skipNan": False},
)
if axis is None or len(axis_arry) == 0 or pda.ndim == 1:
if axis is None or len(axis_) == 0 or pda.ndim == 1:
# TODO: remove call to 'flatten'
return create_pdarray(cast(str, repMsg)).flatten()[0]
else:
return create_pdarray(cast(str, repMsg))
Expand Down Expand Up @@ -2845,12 +2836,12 @@ def prod(pda: pdarray, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Un
RuntimeError
Raised if there's a server-side error thrown
"""
axis_arry = _get_axis_pdarray(axis)
axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis))
repMsg = generic_msg(
cmd=f"prod<{pda.dtype.name},{pda.ndim},{axis_arry.ndim}>",
args={"x": pda, "axis": axis_arry, "skipNan": False},
cmd=f"prod<{pda.dtype.name},{pda.ndim}>",
args={"x": pda, "axis": axis_, "skipNan": False},
)
if axis is None or len(axis_arry) == 0 or pda.ndim == 1:
if axis is None or len(axis_) == 0 or pda.ndim == 1:
return create_pdarray(cast(str, repMsg)).flatten()[0]
else:
return create_pdarray(cast(str, repMsg))
Expand Down Expand Up @@ -2882,12 +2873,12 @@ def min(
RuntimeError
Raised if there's a server-side error thrown
"""
axis_arry = _get_axis_pdarray(axis)
axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis))
repMsg = generic_msg(
cmd=f"min<{pda.dtype.name},{pda.ndim},{axis_arry.ndim}>",
args={"x": pda, "axis": axis_arry, "skipNan": False},
cmd=f"min<{pda.dtype.name},{pda.ndim}>",
args={"x": pda, "axis": axis_, "skipNan": False},
)
if axis is None or len(axis_arry) == 0 or pda.ndim == 1:
if axis is None or len(axis_) == 0 or pda.ndim == 1:
return create_pdarray(cast(str, repMsg)).flatten()[0]
else:
return create_pdarray(cast(str, repMsg))
Expand Down Expand Up @@ -2920,12 +2911,12 @@ def max(
RuntimeError
Raised if there's a server-side error thrown
"""
axis_arry = _get_axis_pdarray(axis)
axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis))
repMsg = generic_msg(
cmd=f"max<{pda.dtype.name},{pda.ndim},{axis_arry.ndim}>",
args={"x": pda, "axis": axis_arry, "skipNan": False},
cmd=f"max<{pda.dtype.name},{pda.ndim}>",
args={"x": pda, "axis": axis_, "skipNan": False},
)
if axis is None or len(axis_arry) == 0 or pda.ndim == 1:
if axis is None or len(axis_) == 0 or pda.ndim == 1:
return create_pdarray(cast(str, repMsg)).flatten()[0]
else:
return create_pdarray(cast(str, repMsg))
Expand Down
8 changes: 1 addition & 7 deletions src/ArgSortMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -434,17 +434,11 @@ module ArgSortMsg
axis = msgArgs["axis"].toScalar(int),
symEntry = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
vals = if (array_dtype == bool) then (symEntry.a:int) else (symEntry.a: array_dtype);

const iv = argsortDefault(vals, algorithm=algorithm, axis);
return st.insert(new shared SymEntry(iv));
}

proc argsort(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8))
{
return MsgTuple.error("argsort does not support the %s dtype".format(array_dtype:string));
}

proc argsortStrings(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
const name = msgArgs["name"].toScalar(string),
strings = getSegString(name, st),
Expand Down
36 changes: 27 additions & 9 deletions src/AryUtil.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,20 @@ module AryUtil
return (true, ret);
}

proc validateNegativeAxes(axes: list(int), param nd: int): (bool, list(int)) {
var ret = new list(int);
for a in axes {
if a >= 0 && a < nd {
ret.pushBack(a);
} else if a < 0 && a >= -nd {
ret.pushBack(nd + a);
} else {
return (false, ret);
}
}
return (true, ret);
}

/*
Get a domain that selects out the idx'th set of indices along the specified axes
Expand Down Expand Up @@ -328,6 +342,16 @@ module AryUtil
return ret;
}

proc reducedShape(shape: ?N*int, axes: list(int)): N*int {
var ret: N*int;
for param i in 0..<N {
if N == 1 || axes.size == 0 || axes.contains(i)
then ret[i] = 1;
else ret[i] = shape[i];
}
return ret;
}

/*
Returns stats on a given array in form (int,int,real,real,real).
Expand Down Expand Up @@ -947,9 +971,9 @@ module AryUtil
flatten a multi-dimensional array into a 1D array
*/
@arkouda.registerCommand
proc flatten(const ref a: [?d] ?t): [] t throws
where a.rank > 1
{
proc flatten(const ref a: [?d] ?t): [] t throws {
if a.rank == 1 then return a;

var flat = makeDistArray(d.size, t);

// ranges of flat indices owned by each locale
Expand Down Expand Up @@ -1006,12 +1030,6 @@ module AryUtil
return flat;
}

proc flatten(const ref a: [?d] ?t): [] t throws
where a.rank == 1
{
return a;
}

// helper for computing an array element's index from its order
record orderer {
param rank: int;
Expand Down
22 changes: 1 addition & 21 deletions src/CastMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,13 @@ module CastMsg {
private config const logChannel = ServerConfig.logChannel;
const castLogger = new Logger(logLevel, logChannel);

proc isFloatingType(type t) param : bool {
return isRealType(t) || isImagType(t) || isComplexType(t);
}

@arkouda.instantiateAndRegister(prefix="cast")
proc castArray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype_from,
type array_dtype_to,
param array_nd: int
): MsgTuple throws
where !(isFloatingType(array_dtype_from) && array_dtype_to == bigint) &&
where !((isRealType(array_dtype_from) || isImagType(array_dtype_from) || isComplexType(array_dtype_from)) && array_dtype_to == bigint) &&
!(array_dtype_from == bigint && array_dtype_to == bool)
{
const a = st[msgArgs["name"]]: SymEntry(array_dtype_from, array_nd);
Expand All @@ -40,22 +36,6 @@ module CastMsg {
}
}

// cannot cast float types to bigint, cannot cast bigint to bool
proc castArray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype_from,
type array_dtype_to,
param array_nd: int
): MsgTuple throws
where (isFloatingType(array_dtype_from) && array_dtype_to == bigint) ||
(array_dtype_from == bigint && array_dtype_to == bool)
{
return MsgTuple.error(
"cannot cast array of type %s to %s".format(
type2str(array_dtype_from),
type2str(array_dtype_to)
));
}

@arkouda.instantiateAndRegister(prefix="castToStrings")
proc castArrayToStrings(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws {
const name = msgArgs["name"].toScalar(string);
Expand Down
12 changes: 0 additions & 12 deletions src/GenSymIO.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ module GenSymIO {
return st.insert(new shared SymEntry(makeArrayFromBytes(msgArgs.payload, shape, array_dtype)));
}

proc array(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype == bigint
{
return MsgTuple.error("Array creation from binary payload is not supported for bigint arrays");
}

proc makeArrayFromBytes(ref payload: bytes, shape: ?N*int, type t): [] t throws {
var size = 1;
for s in shape do size *= s;
Expand Down Expand Up @@ -138,12 +132,6 @@ module GenSymIO {
return MsgTuple.payload(bytes.createAdoptingBuffer(ptr:c_ptr(uint(8)), size, size));
}

proc tondarray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype == bigint
{
return MsgTuple.error("cannot create ndarray from bigint array");
}

/*
* Utility proc to test casting a string to a specified type
* :arg c: String to cast
Expand Down
14 changes: 0 additions & 14 deletions src/IndexingMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,6 @@ module IndexingMsg
}
}

proc multiPDArrayIndex(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_a, type array_dtype_idx, param array_nd: int): MsgTuple throws
where array_dtype_idx != int && array_dtype_idx != uint
{
return MsgTuple.error("Invalid index type: %s; must be 'int' or 'uint'".format(type2str(array_dtype_idx)));
}

private proc multiIndexShape(inShape: ?N*int, idxDims: [?d] int, outSize: int): (bool, int, N*int) {
var minShape: N*int = inShape,
firstRank = -1;
Expand Down Expand Up @@ -960,14 +954,6 @@ module IndexingMsg
return st.insert(new shared SymEntry(y, x.max_bits));
}

proc takeAlongAxis(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype_x,
type array_dtype_idx,
param array_nd: int
): MsgTuple throws {
return MsgTuple.error("Cannot take along axis with non-integer index array");
}

use CommandMap;
registerFunction("arrayViewMixedIndex", arrayViewMixedIndexMsg, getModuleName());
registerFunction("[pdarray]", pdarrayIndexMsg, getModuleName());
Expand Down
53 changes: 3 additions & 50 deletions src/LinalgMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,6 @@ module LinalgMsg {
return st.insert(e);
}


proc eye(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws
where array_dtype == BigInteger.bigint
{
return MsgTuple.error("eye does not support the bigint dtype");
}

// tril and triu are identical except for the argument they pass to triluHandler (true for upper, false for lower)
// The zeros are written into the upper (or lower) triangle of the array, offset by the value of diag.

Expand All @@ -79,11 +72,6 @@ module LinalgMsg {
return triluHandler(cmd, msgArgs, st, array_dtype, array_nd, false);
}

proc tril(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_nd < 2 {
return MsgTuple.error("Array must be at least 2 dimensional for 'tril'");
}

// Create an array from an existing array with its lower triangle zeroed out

@arkouda.instantiateAndRegister
Expand All @@ -92,13 +80,9 @@ module LinalgMsg {
return triluHandler(cmd, msgArgs, st, array_dtype, array_nd, true);
}

proc triu(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_nd < 2 {
return MsgTuple.error("Array must be at least 2 dimensional for 'triu'");
}

// Fetch the arguments, call zeroTri, return result.

// Fetch the arguments, call zeroTri, return result.
// TODO: support instantiating param bools with 'true' and 'false' s.t. we'd have 'triluHandler<true>' and 'triluHandler<false>'
// cmds if this procedure were annotated instead of the two above.
proc triluHandler(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype, param array_nd: int, param upper: bool
): MsgTuple throws {
Expand Down Expand Up @@ -195,16 +179,6 @@ module LinalgMsg {

}

proc matmul(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_nd < 2) && (array_dtype_x1 != BigInteger.bigint) && (array_dtype_x2 != BigInteger.bigint) {
return MsgTuple.error("Matrix multiplication with arrays of dimension < 2 is not supported");
}

proc matmul(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_dtype_x1 == BigInteger.bigint) || (array_dtype_x2 == BigInteger.bigint) {
return MsgTuple.error("Matrix multiplication with arrays of bigint type is not supported");
}

proc compute_result_type_matmul(type t1, type t2) type {
if t1 == real || t2 == real then return real;
if t1 == int || t2 == int then return int;
Expand Down Expand Up @@ -302,11 +276,6 @@ module LinalgMsg {
return ret;
}

proc transpose(array: [?d] ?t): [d] t throws
where d.rank < 2 {
throw new Error("Matrix transpose with arrays of dimension < 2 is not supported");
}

/*
Compute the generalized dot product of two tensors along the specified axis.
Expand Down Expand Up @@ -366,22 +335,6 @@ module LinalgMsg {
return bool;
}

proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_nd < 2) && ((array_dtype_x1 != bool) || (array_dtype_x2 != bool))
&& (array_dtype_x1 != BigInteger.bigint) && (array_dtype_x2 != BigInteger.bigint) {
return MsgTuple.error("VecDot with arrays of dimension < 2 is not supported");
}

proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_dtype_x1 == bool) && (array_dtype_x2 == bool) {
return MsgTuple.error("VecDot with arrays both of type bool is not supported");
}

proc vecdot(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_x1, type array_dtype_x2, param array_nd: int): MsgTuple throws
where (array_dtype_x1 == BigInteger.bigint) || (array_dtype_x2 == BigInteger.bigint) {
return MsgTuple.error("VecDot with arrays of type bigint is not supported");
}

// @arkouda.registerND(???)
// proc tensorDotMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd1: int, param nd2: int): MsgTuple throws {
// if nd < 3 {
Expand Down
Loading

0 comments on commit 76408db

Please sign in to comment.