Skip to content

Commit

Permalink
add preliminary support for where-clause evaluation in 'registerCommand'
Browse files Browse the repository at this point in the history
Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>
  • Loading branch information
jeremiah-corrado committed Oct 15, 2024
1 parent f4a9d4b commit 16395f1
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 99 deletions.
12 changes: 3 additions & 9 deletions src/AryUtil.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -945,9 +945,9 @@ module AryUtil
flatten a multi-dimensional array into a 1D array
*/
@arkouda.registerCommand
proc flatten(const ref a: [?d] ?t): [] t throws
where a.rank > 1
{
proc flatten(const ref a: [?d] ?t): [] t throws {
if a.rank == 1 then return a;

var flat = makeDistArray(d.size, t);

// ranges of flat indices owned by each locale
Expand Down Expand Up @@ -1004,12 +1004,6 @@ module AryUtil
return flat;
}

proc flatten(const ref a: [?d] ?t): [] t throws
where a.rank == 1
{
return a;
}

// helper for computing an array element's index from its order
record orderer {
param rank: int;
Expand Down
5 changes: 0 additions & 5 deletions src/LinalgMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,6 @@ module LinalgMsg {
return ret;
}

proc transpose(array: [?d] ?t): [d] t throws
where d.rank < 2 {
throw new Error("Matrix transpose with arrays of dimension < 2 is not supported");
}

/*
Compute the generalized dot product of two tensors along the specified axis.
Expand Down
10 changes: 0 additions & 10 deletions src/MsgProcessing.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,6 @@ module MsgProcessing
return msg;
}

proc chunkInfoAsString(array: [?d] ?t): string throws
where (t != bool) && (t != int(64)) && (t != uint(64)) && (t != uint(8)) && (t != real){
throw new Error("chunkInfo does not support dtype %s".format(t:string));
}

@arkouda.registerCommand
proc chunkInfoAsArray(array: [?d] ?t):[] int throws
where (t == bool) || (t == int(64)) || (t == uint(64)) || (t == uint(8)) ||(t == real) {
Expand All @@ -357,9 +352,4 @@ module MsgProcessing
}
return blockSizes;
}

proc chunkInfoAsArray(array: [?d] ?t): [d] int throws
where (t != bool) && (t != int(64)) && (t != uint(64)) && (t != uint(8)) && (t != real){
throw new Error("chunkInfo does not support dtype %s".format(t:string));
}
}
28 changes: 11 additions & 17 deletions src/SortMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ module SortMsg

/* sort takes pdarray and returns a sorted copy of the array */
@arkouda.registerCommand
proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws
where ((t == real) || (t == int) || (t == uint(64))) && (d.rank == 1) {
proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws
where ((t == real) || (t == int) || (t == uint(64)))
do return sortHelp(array, alg, axis);

proc sortHelp(array: [?d] ?t, alg: string, axis: int): [d] t throws
where d.rank == 1
{
var algorithm: SortingAlgorithm = ArgSortMsg.getSortingAlgorithm(alg);
const itemsize = dtypeSize(whichDtype(t));
overMemLimit(radixSortLSD_keys_memEst(d.size, itemsize));
Expand All @@ -48,9 +52,9 @@ module SortMsg
}
}

proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws
where ((t == real) || (t==int) || (t==uint(64))) && (d.rank > 1) {

proc sortHelp(array: [?d] ?t, alg: string, axis: int): [d] t throws
where d.rank > 1
{
var algorithm: SortingAlgorithm = ArgSortMsg.getSortingAlgorithm(alg);
const itemsize = dtypeSize(whichDtype(t));
overMemLimit(radixSortLSD_keys_memEst(d.size, itemsize));
Expand Down Expand Up @@ -91,16 +95,11 @@ module SortMsg
return sorted;
}

proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws
where ((t != real) && (t!=int) && (t!=uint(64))) {
throw new Error("sort does not support type %s".format(type2str(t)));
}

// https://data-apis.org/array-api/latest/API_specification/generated/array_api.searchsorted.html#array_api.searchsorted
@arkouda.registerCommand
proc searchSorted(x1: [?d1] real, x2: [?d2] real, side: string): [d2] int throws
where (d1.rank == 1) {

where d1.rank == 1
{
if side != "left" && side != "right" {
throw new Error("searchSorted side must be a string with value 'left' or 'right'.");
}
Expand All @@ -123,11 +122,6 @@ module SortMsg
return ret;
}

proc searchSorted(x1: [?d1] real, x2: [?d2] real, side: string): [d2] int throws
where (d1.rank != 1){
throw new Error("searchSorted only arrays x1 of dimension 1.");
}

record leftCmp: relativeComparator {
proc compare(a: real, b: real): int {
if a < b then return -1;
Expand Down
2 changes: 1 addition & 1 deletion src/SparseMatrixMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ module SparseMatrixMsg {
return MsgTuple.fromResponses(responses);
}

@arkouda.registerCommand("fill_sparse_vals")
@arkouda.registerCommand("fill_sparse_vals", ignoreWhereClause=true)
proc fillSparseMatrixMsg(matrix: borrowed SparseSymEntry(?), vals: [?d] ?t /* matrix.etype */) throws
where t == matrix.etype && d.rank == 1
do fillSparseMatrix(matrix.a, vals, matrix.matLayout);
Expand Down
18 changes: 0 additions & 18 deletions src/StatsMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,6 @@ module StatsMsg {
return (+ reduce ((x:real - mx) * (y:real - my))) / (dx.size - 1):real;
}

// above registration will instantiate `cov` for all combinations of array ranks
// even though it is only valid when the ranks are the same
// (respecting the where clause in the signature is future work for 'registerCommand')
proc cov(const ref x: [?dx], const ref y: [?dy]): real throws
where dx.rank != dy.rank
{
throw new Error("x and y must have the same rank");
}

@arkouda.registerCommand()
proc corr(const ref x: [?dx] ?tx, const ref y: [?dy] ?ty): real throws
where dx.rank == dy.rank
Expand All @@ -107,15 +98,6 @@ module StatsMsg {
return cov(x, y) / (std(x, 1) * std(y, 1));
}

// above registration will instantiate `corr` for all combinations of array ranks
// even though it is only valid when the ranks are the same
// (respecting the where clause in the signature is future work for 'registerCommand')
proc corr(const ref x: [?dx], const ref y: [?dy]): real throws
where dx.rank != dy.rank
{
throw new Error("x and y must have the same rank");
}

@arkouda.registerCommand()
proc cumSum(const ref x: [?d] ?t, axis: int, includeInitial: bool): [] t throws {
if d.rank == 1 {
Expand Down
10 changes: 0 additions & 10 deletions src/UtilMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,6 @@ module UtilMsg {
return y;
}

proc clip(const ref x: [?d] ?t, min: real, max: real): [d] t throws
where (t != int) && (t != real) && (t != uint(8)) && (t != uint(64)){
throw new Error("clip does not support dtype %s".format(t:string));
}

/*
Compute the n'th order discrete difference along a given axis
Expand Down Expand Up @@ -95,11 +90,6 @@ module UtilMsg {
}
}

proc diff(x: [?d] ?t, n: int, axis: int): [d] t throws
where (t != real) && (t != int) && (t != uint(8)) && (t != uint(64)){
throw new Error("diff does not support dtype %s".format(t:string));
}

// helper to create a domain that's 'n' elements smaller in the 'axis' dimension
private proc subDomain(shape: ?N*int, axis: int, n: int) {
var rngs: N*range;
Expand Down
Loading

0 comments on commit 16395f1

Please sign in to comment.