Skip to content

Commit

Permalink
restore some code
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Oct 16, 2024
1 parent 411f3bd commit 3f7266c
Showing 1 changed file with 146 additions and 91 deletions.
237 changes: 146 additions & 91 deletions src/ReductionMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ module ReductionMsg
use RadixSortLSD;

private config const lBins = 2**25 * numLocales;

private config const logLevel = ServerConfig.logLevel;
private config const logChannel = ServerConfig.logChannel;
const rmLogger = new Logger(logLevel, logChannel);
Expand All @@ -41,6 +40,100 @@ module ReductionMsg
Supports: 'sum', 'prod', 'min', 'max'
*/


@arkouda.registerND(cmd_prefix="reduce")
proc argTypeReductionMessage(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
use SliceReductionOps;
param pn = Reflection.getRoutineName();
const x = msgArgs.getValueOf("x"),
op = msgArgs.getValueOf("op"),
nAxes = msgArgs.get("nAxes").getIntValue(),
axesRaw = msgArgs.get("axis").toScalarArray(int, nAxes),
skipNan = msgArgs.get("skipNan").getBoolValue(),
rname = st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(x, st);

if !basicReductionOps.contains(op) {
const errorMsg = notImplementedError(pn,op,gEnt.dtype);
rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}

proc computeReduction(type t): MsgTuple throws {
const eIn = toSymEntry(gEnt, t, nd);
type opType = if t == bool then int else t;

if nd == 1 || nAxes == 0 {
var s: opType;
select op {
when "sum" do s = if skipNan then sumSkipNan(eIn.a, opType) else (+ reduce eIn.a:opType):opType;
when "prod" do s = if skipNan then prodSkipNan(eIn.a, opType) else (* reduce eIn.a:opType):opType;
when "min" do s = if skipNan then getMinSkipNan(eIn.a) else min reduce eIn.a;
when "max" do s = if skipNan then getMaxSkipNan(eIn.a) else max reduce eIn.a;
otherwise halt("unreachable");
}

const scalarValue = if (t == bool && (op == "min" || op == "max"))
then "bool " + bool2str(if s == 1 then true else false)
else (type2str(opType) + " " + type2fmt(opType)).format(s);
rmLogger.debug(getModuleName(),pn,getLineNumber(),scalarValue);
return new MsgTuple(scalarValue, MsgType.NORMAL);
} else {
const (valid, axes) = validateNegativeAxes(axesRaw, nd);
if !valid {
var errorMsg = "Invalid axis value(s) '%?' in slicing reduction".format(axesRaw);
rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg,MsgType.ERROR);
} else {
const outShape = reducedShape(eIn.a.shape, axes);
var eOut = st.addEntry(rname, outShape, opType);

forall sliceIdx in domOffAxis(eIn.a.domain, axes) {
const sliceDom = domOnAxis(eIn.a.domain, sliceIdx, axes);
var s: opType;
select op {
when "sum" do s = if skipNan
then sumSkipNan(eIn.a, sliceDom, opType)
else sum(eIn.a, sliceDom, opType);
when "prod" do s =if skipNan
then prodSkipNan(eIn.a, sliceDom, opType)
else prod(eIn.a, sliceDom, opType);
when "min" do s = if skipNan
then getMinSkipNan(eIn.a, sliceDom)
else getMin(eIn.a, sliceDom);
when "max" do s = if skipNan
then getMaxSkipNan(eIn.a, sliceDom)
else getMax(eIn.a, sliceDom);
otherwise halt("unreachable");
}
eOut.a[sliceIdx] = s;
}

const repMsg = "created " + st.attrib(rname);
rmLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}
}
}

select gEnt.dtype {
when DType.Int64 do return computeReduction(int);
when DType.UInt64 do return computeReduction(uint);
when DType.Float64 do return computeReduction(real);
when DType.Bool do return computeReduction(bool);
otherwise {
var errorMsg = notImplementedError(pn,dtype2str(gEnt.dtype));
rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg,MsgType.ERROR);
}
}
}




@arkouda.registerCommand
proc sum(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] t throws
where (t==int || t==real || t==uint(64)) && (x.rank == 1) && (axis.rank == 1) {
Expand Down Expand Up @@ -87,7 +180,6 @@ module ReductionMsg
throw new Error("sum does not support type %s".format(type2str(t)));
}


@arkouda.registerCommand
proc prod(ref x:[?d] ?t, axis: [?d2] int, skipNan: bool): [] t throws
where (t==int || t==real || t==uint(64)) && (x.rank == 1) && (axis.rank == 1) {
Expand Down Expand Up @@ -224,95 +316,6 @@ module ReductionMsg
throw new Error("min does not support type %s".format(type2str(t)));
}

@arkouda.registerND(cmd_prefix="reduce")
proc argTypeReductionMessage(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
use SliceReductionOps;
param pn = Reflection.getRoutineName();
const x = msgArgs.getValueOf("x"),
op = msgArgs.getValueOf("op"),
nAxes = msgArgs.get("nAxes").getIntValue(),
axesRaw = msgArgs.get("axis").toScalarArray(int, nAxes),
skipNan = msgArgs.get("skipNan").getBoolValue(),
rname = st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(x, st);

if !basicReductionOps.contains(op) {
const errorMsg = notImplementedError(pn,op,gEnt.dtype);
rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}

proc computeReduction(type t): MsgTuple throws {
const eIn = toSymEntry(gEnt, t, nd);
type opType = if t == bool then int else t;

if nd == 1 || nAxes == 0 {
var s: opType;
select op {
when "sum" do s = if skipNan then sumSkipNan(eIn.a, opType) else (+ reduce eIn.a:opType):opType;
when "prod" do s = if skipNan then prodSkipNan(eIn.a, opType) else (* reduce eIn.a:opType):opType;
when "min" do s = if skipNan then getMinSkipNan(eIn.a) else min reduce eIn.a;
when "max" do s = if skipNan then getMaxSkipNan(eIn.a) else max reduce eIn.a;
otherwise halt("unreachable");
}

const scalarValue = if (t == bool && (op == "min" || op == "max"))
then "bool " + bool2str(if s == 1 then true else false)
else (type2str(opType) + " " + type2fmt(opType)).format(s);
rmLogger.debug(getModuleName(),pn,getLineNumber(),scalarValue);
return new MsgTuple(scalarValue, MsgType.NORMAL);
} else {
const (valid, axes) = validateNegativeAxes(axesRaw, nd);
if !valid {
var errorMsg = "Invalid axis value(s) '%?' in slicing reduction".format(axesRaw);
rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg,MsgType.ERROR);
} else {
const outShape = reducedShape(eIn.a.shape, axes);
var eOut = st.addEntry(rname, outShape, opType);

forall sliceIdx in domOffAxis(eIn.a.domain, axes) {
const sliceDom = domOnAxis(eIn.a.domain, sliceIdx, axes);
var s: opType;
select op {
when "sum" do s = if skipNan
then sumSkipNan(eIn.a, sliceDom, opType)
else sum(eIn.a, sliceDom, opType);
when "prod" do s =if skipNan
then prodSkipNan(eIn.a, sliceDom, opType)
else prod(eIn.a, sliceDom, opType);
when "min" do s = if skipNan
then getMinSkipNan(eIn.a, sliceDom)
else getMin(eIn.a, sliceDom);
when "max" do s = if skipNan
then getMaxSkipNan(eIn.a, sliceDom)
else getMax(eIn.a, sliceDom);
otherwise halt("unreachable");
}
eOut.a[sliceIdx] = s;
}

const repMsg = "created " + st.attrib(rname);
rmLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}
}
}

select gEnt.dtype {
when DType.Int64 do return computeReduction(int);
when DType.UInt64 do return computeReduction(uint);
when DType.Float64 do return computeReduction(real);
when DType.Bool do return computeReduction(bool);
otherwise {
var errorMsg = notImplementedError(pn,dtype2str(gEnt.dtype));
rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg,MsgType.ERROR);
}
}
}

/*
Compute an array reduction along one or more axes.
(where the result has a bool data type)
Expand Down Expand Up @@ -639,6 +642,7 @@ module ReductionMsg
return sum == a.size;
}


proc sum(ref a: [] ?t, slice, type opType): opType {
var sum = 0:opType;
forall i in slice with (+ reduce sum) do sum += a[i]:opType;
Expand Down Expand Up @@ -710,6 +714,57 @@ module ReductionMsg
}
return maxVal;
}
proc sumSlice(ref a: [?d] ?t, slice, type opType, skipNan: bool): opType {
var sum = 0:opType;
if skipNan{
forall i in slice with (+ reduce sum) {
if isArgandType(t) { if isNan(a[i]) then continue; }
sum += a[i]:opType;
}
}else{
forall i in slice with (+ reduce sum) do sum += a[i]:opType;
}
return sum;
}

proc prodSlice(ref a: [] ?t, slice, type opType, skipNan: bool): opType {
var prod = 1.0; // always use real(64) to avoid int overflow
if skipNan{
forall i in slice with (* reduce prod) {
if isArgandType(t) { if isNan(a[i]) then continue; }
prod *= a[i]:opType;
}
}else{
forall i in slice with (* reduce prod) do prod *= a[i]:opType;
}
return prod:opType;
}

proc getMinSlice(ref a: [] ?t, slice, skipNan: bool): t {
var minVal = max(t);
if skipNan{
forall i in slice with (min reduce minVal) {
if isArgandType(t) { if isNan(a[i]) then continue; }
minVal reduce= a[i];
}
}else{
forall i in slice with (min reduce minVal) do minVal reduce= a[i];
}
return minVal;
}

proc getMaxSlice(ref a: [] ?t, slice, skipNan: bool): t {
var maxVal = min(t);
if skipNan{
forall i in slice with (max reduce maxVal) {
if isArgandType(t) { if isNan(a[i]) then continue; }
maxVal reduce= a[i];
}
}else{
forall i in slice with (max reduce maxVal) do maxVal reduce= a[i];
}
return maxVal;
}

proc argmin(ref a: [?d] ?t, slice, axis: int): d.idxType {
var minValLoc = (max(t), d.low);
Expand Down

0 comments on commit 3f7266c

Please sign in to comment.