Skip to content

Commit

Permalink
Add support for null treatment clause to various window functions
Browse files Browse the repository at this point in the history
Add parser null treatment clause test

Fix compilation error

Clean up

Remove ignoreNulls attribute from Aggregation class

Fix documentation
  • Loading branch information
ptkool authored and Rongrong Zhong committed Dec 3, 2019
1 parent 350f1cd commit 7efeecc
Show file tree
Hide file tree
Showing 37 changed files with 923 additions and 58 deletions.
12 changes: 9 additions & 3 deletions presto-docs/src/main/sphinx/functions/window.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ Ranking Functions
Value Functions
---------------

Value functions provide an option to specify how null values should be treated when evaluating the
function. Nulls can either be ignored (``IGNORE NULLS``) or respected (``RESPECT NULLS``). By default,
null values are respected. If ``IGNORE NULLS`` is specified, all rows where the value expresssion is
null are excluded from the calculation. If ``IGNORE NULLS`` is specified and the value expression is
null for all rows, the ``default_value`` is returned, or if it is not specified, ``null`` is returned.

.. function:: first_value(x) -> [same as input]

Returns the first value of the window.
Expand All @@ -110,14 +116,14 @@ Value Functions

Returns the value at ``offset`` rows after the current row in the window.
Offsets start at ``0``, which is the current row. The
offset can be any scalar expression. The default ``offset`` is ``1``. If the
offset can be any scalar expression. The default ``offset`` is ``1``. If the
offset is null or larger than the window, the ``default_value`` is returned,
or if it is not specified ``null`` is returned.

.. function:: lag(x[, offset [, default_value]]) -> [same as input]

Returns the value at ``offset`` rows before the current row in the window
Offsets start at ``0``, which is the current row. The
offset can be any scalar expression. The default ``offset`` is ``1``. If the
Offsets start at ``0``, which is the current row. The
offset can be any scalar expression. The default ``offset`` is ``1``. If the
offset is null or larger than the window, the ``default_value`` is returned,
or if it is not specified ``null`` is returned.
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,29 @@ public class WindowFunctionDefinition
private final Type type;
private final FrameInfo frameInfo;
private final List<Integer> argumentChannels;
private final boolean ignoreNulls;

public static WindowFunctionDefinition window(WindowFunctionSupplier functionSupplier, Type type, FrameInfo frameInfo, List<Integer> inputs)
{
return new WindowFunctionDefinition(functionSupplier, type, frameInfo, inputs);
return new WindowFunctionDefinition(functionSupplier, type, frameInfo, false, inputs);
}

public static WindowFunctionDefinition window(WindowFunctionSupplier functionSupplier, Type type, FrameInfo frameInfo, Integer... inputs)
{
return window(functionSupplier, type, frameInfo, Arrays.asList(inputs));
}

WindowFunctionDefinition(WindowFunctionSupplier functionSupplier, Type type, FrameInfo frameInfo, List<Integer> argumentChannels)
public static WindowFunctionDefinition window(WindowFunctionSupplier functionSupplier, Type type, FrameInfo frameInfo, boolean ignoreNulls, List<Integer> inputs)
{
return new WindowFunctionDefinition(functionSupplier, type, frameInfo, ignoreNulls, inputs);
}

public static WindowFunctionDefinition window(WindowFunctionSupplier functionSupplier, Type type, FrameInfo frameInfo, boolean ignoreNulls, Integer... inputs)
{
return window(functionSupplier, type, frameInfo, ignoreNulls, Arrays.asList(inputs));
}

WindowFunctionDefinition(WindowFunctionSupplier functionSupplier, Type type, FrameInfo frameInfo, boolean ignoreNulls, List<Integer> argumentChannels)
{
requireNonNull(functionSupplier, "functionSupplier is null");
requireNonNull(type, "type is null");
Expand All @@ -51,6 +62,7 @@ public static WindowFunctionDefinition window(WindowFunctionSupplier functionSup
this.functionSupplier = functionSupplier;
this.type = type;
this.frameInfo = frameInfo;
this.ignoreNulls = ignoreNulls;
this.argumentChannels = ImmutableList.copyOf(argumentChannels);
}

Expand All @@ -66,6 +78,6 @@ public Type getType()

public WindowFunction createWindowFunction()
{
return functionSupplier.createWindowFunction(argumentChannels);
return functionSupplier.createWindowFunction(argumentChannels, ignoreNulls);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public final String getDescription()
}

@Override
public final WindowFunction createWindowFunction(List<Integer> argumentChannels)
public final WindowFunction createWindowFunction(List<Integer> argumentChannels, boolean ignoreNulls)
{
requireNonNull(argumentChannels, "inputs is null");
checkArgument(argumentChannels.size() == signature.getArgumentTypes().size(),
Expand All @@ -55,12 +55,12 @@ public final WindowFunction createWindowFunction(List<Integer> argumentChannels)
signature.getNameSuffix(),
argumentChannels.size());

return newWindowFunction(argumentChannels);
return newWindowFunction(argumentChannels, ignoreNulls);
}

/**
* Create window function instance using the supplied arguments. The
* inputs have already validated.
*/
protected abstract WindowFunction newWindowFunction(List<Integer> inputs);
protected abstract WindowFunction newWindowFunction(List<Integer> inputs, boolean ignoreNulls);
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public static WindowFunctionSupplier supplier(Signature signature, final Interna
return new AbstractWindowFunctionSupplier(signature, null)
{
@Override
protected WindowFunction newWindowFunction(List<Integer> inputs)
protected WindowFunction newWindowFunction(List<Integer> inputs, boolean ignoreNulls)
{
return new AggregateWindowFunction(function, inputs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@ public void processRow(BlockBuilder output, int frameStart, int frameEnd, int cu
return;
}

windowIndex.appendTo(argumentChannel, frameStart, output);
int valuePosition = frameStart;

if (ignoreNulls) {
while (valuePosition >= 0 && valuePosition <= frameEnd) {
if (!windowIndex.isNull(argumentChannel, valuePosition)) {
break;
}

valuePosition++;
}

if (valuePosition > frameEnd) {
output.appendNull();
return;
}
}

windowIndex.appendTo(argumentChannel, valuePosition, output);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,26 @@ public void processRow(BlockBuilder output, int frameStart, int frameEnd, int cu
long offset = (offsetChannel < 0) ? 1 : windowIndex.getLong(offsetChannel, currentPosition);
checkCondition(offset >= 0, INVALID_FUNCTION_ARGUMENT, "Offset must be at least 0");

long valuePosition = currentPosition - offset;
long valuePosition;

if ((valuePosition >= 0) && (valuePosition <= currentPosition)) {
if (ignoreNulls && (offset > 0)) {
long count = 0;
valuePosition = currentPosition - 1;
while (withinPartition(valuePosition, currentPosition)) {
if (!windowIndex.isNull(valueChannel, toIntExact(valuePosition))) {
count++;
if (count == offset) {
break;
}
}
valuePosition--;
}
}
else {
valuePosition = currentPosition - offset;
}

if (withinPartition(valuePosition, currentPosition)) {
windowIndex.appendTo(valueChannel, toIntExact(valuePosition), output);
}
else if (defaultChannel >= 0) {
Expand All @@ -63,4 +80,9 @@ else if (defaultChannel >= 0) {
}
}
}

private boolean withinPartition(long valuePosition, long currentPosition)
{
return valuePosition >= 0 && valuePosition <= currentPosition;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@ public void processRow(BlockBuilder output, int frameStart, int frameEnd, int cu
return;
}

windowIndex.appendTo(argumentChannel, frameEnd, output);
int valuePosition = frameEnd;

if (ignoreNulls) {
while (valuePosition >= frameStart) {
if (!windowIndex.isNull(argumentChannel, valuePosition)) {
break;
}

valuePosition--;
}

if (valuePosition < frameStart) {
output.appendNull();
return;
}
}

windowIndex.appendTo(argumentChannel, valuePosition, output);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,26 @@ public void processRow(BlockBuilder output, int frameStart, int frameEnd, int cu
long offset = (offsetChannel < 0) ? 1 : windowIndex.getLong(offsetChannel, currentPosition);
checkCondition(offset >= 0, INVALID_FUNCTION_ARGUMENT, "Offset must be at least 0");

long valuePosition = currentPosition + offset;
long valuePosition;

if ((valuePosition >= 0) && (valuePosition < windowIndex.size())) {
if (ignoreNulls && (offset > 0)) {
long count = 0;
valuePosition = currentPosition + 1;
while (withinPartition(valuePosition)) {
if (!windowIndex.isNull(valueChannel, toIntExact(valuePosition))) {
count++;
if (count == offset) {
break;
}
}
valuePosition++;
}
}
else {
valuePosition = currentPosition + offset;
}

if (withinPartition(valuePosition)) {
windowIndex.appendTo(valueChannel, toIntExact(valuePosition), output);
}
else if (defaultChannel >= 0) {
Expand All @@ -63,4 +80,9 @@ else if (defaultChannel >= 0) {
}
}
}

private boolean withinPartition(long valuePosition)
{
return (valuePosition >= 0) && (valuePosition < windowIndex.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,25 @@ public void processRow(BlockBuilder output, int frameStart, int frameEnd, int cu
long offset = windowIndex.getLong(offsetChannel, currentPosition);
checkCondition(offset >= 1, INVALID_FUNCTION_ARGUMENT, "Offset must be at least 1");

// offset is base 1
long valuePosition = frameStart + (offset - 1);
long valuePosition;

if (ignoreNulls) {
long count = 0;
valuePosition = frameStart;
while (valuePosition >= 0 && valuePosition <= frameEnd) {
if (!windowIndex.isNull(valueChannel, toIntExact(valuePosition))) {
count++;
if (count == offset) {
break;
}
}
valuePosition++;
}
}
else {
// offset is base 1
valuePosition = frameStart + (offset - 1);
}

if ((valuePosition >= frameStart) && (valuePosition <= frameEnd)) {
windowIndex.appendTo(valueChannel, toIntExact(valuePosition), output);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.QualifiedFunctionName;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.ValueWindowFunction;
import com.facebook.presto.spi.function.WindowFunction;
import com.facebook.presto.spi.type.Type;
import com.google.common.collect.Lists;
Expand Down Expand Up @@ -55,15 +56,23 @@ public ReflectionWindowFunctionSupplier(Signature signature, Class<T> type)
}

@Override
protected T newWindowFunction(List<Integer> inputs)
protected T newWindowFunction(List<Integer> inputs, boolean ignoreNulls)
{
try {
T windowFunction;

if (getSignature().getArgumentTypes().isEmpty()) {
return constructor.newInstance();
windowFunction = constructor.newInstance();
}
else {
return constructor.newInstance(inputs);
windowFunction = constructor.newInstance(inputs);
}

if (windowFunction instanceof ValueWindowFunction) {
((ValueWindowFunction) windowFunction).setIgnoreNulls(ignoreNulls);
}

return windowFunction;
}
catch (ReflectiveOperationException e) {
throw new RuntimeException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ public static List<SqlWindowFunction> parseFunctionDefinition(Class<? extends Wi
.collect(toImmutableList());
}

private static SqlWindowFunction parse(Class<? extends WindowFunction> clazz, WindowFunctionSignature window)
private static SqlWindowFunction parse(
Class<? extends WindowFunction> clazz,
WindowFunctionSignature window)
{
List<TypeVariableConstraint> typeVariables = ImmutableList.of();
if (!window.typeVariable().isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ public interface WindowFunctionSupplier

String getDescription();

WindowFunction createWindowFunction(List<Integer> argumentChannels);
WindowFunction createWindowFunction(List<Integer> argumentChannels, boolean ignoreNulls);
}
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ protected Object visitFunctionCall(FunctionCall node, Object context)

// do not optimize non-deterministic functions
if (optimize && (!functionMetadata.isDeterministic() || hasUnresolvedValue(argumentValues) || node.getName().equals(QualifiedName.of("fail")))) {
return new FunctionCall(node.getName(), node.getWindow(), node.isDistinct(), toExpressions(argumentValues, argumentTypes));
return new FunctionCall(node.getName(), node.getWindow(), node.isDistinct(), node.isIgnoreNulls(), toExpressions(argumentValues, argumentTypes));
}

Object result;
Expand All @@ -935,7 +935,7 @@ protected Object visitFunctionCall(FunctionCall node, Object context)
}

if (optimize && !isSerializable(result, type(node))) {
return new FunctionCall(node.getName(), node.getWindow(), node.isDistinct(), toExpressions(argumentValues, argumentTypes));
return new FunctionCall(node.getName(), node.getWindow(), node.isDistinct(), node.isIgnoreNulls(), toExpressions(argumentValues, argumentTypes));
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -948,8 +948,9 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext

FrameInfo frameInfo = new FrameInfo(frame.getType(), frame.getStartType(), frameStartChannel, frame.getEndType(), frameEndChannel);

CallExpression call = entry.getValue().getFunctionCall();
FunctionHandle functionHandle = entry.getValue().getFunctionHandle();
WindowNode.Function function = entry.getValue();
CallExpression call = function.getFunctionCall();
FunctionHandle functionHandle = function.getFunctionHandle();
ImmutableList.Builder<Integer> arguments = ImmutableList.builder();
for (RowExpression argument : call.getArguments()) {
checkState(argument instanceof VariableReferenceExpression);
Expand All @@ -959,7 +960,7 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext
FunctionManager functionManager = metadata.getFunctionManager();
WindowFunctionSupplier windowFunctionSupplier = functionManager.getWindowFunctionImplementation(functionHandle);
Type type = metadata.getType(functionManager.getFunctionMetadata(functionHandle).getReturnType());
windowFunctionsBuilder.add(window(windowFunctionSupplier, type, frameInfo, arguments.build()));
windowFunctionsBuilder.add(window(windowFunctionSupplier, type, frameInfo, function.isIgnoreNulls(), arguments.build()));
windowFunctionOutputVariablesBuilder.add(variable);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,8 @@ private PlanBuilder window(PlanBuilder subPlan, List<FunctionCall> windowFunctio
analysis.getFunctionHandle(windowFunction),
returnType,
((FunctionCall) rewritten).getArguments().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList())),
frame);
frame,
windowFunction.isIgnoreNulls());

ImmutableList.Builder<VariableReferenceExpression> orderByVariables = ImmutableList.builder();
orderByVariables.addAll(orderings.keySet());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ public Result apply(WindowNode windowNode, Captures captures, Context context)
callExpression.getFunctionHandle(),
callExpression.getType(),
newArguments.build()),
entry.getValue().getFrame()));
entry.getValue().getFrame(),
entry.getValue().isIgnoreNulls()));
}
if (anyRewritten) {
return Result.ofPlanNode(new WindowNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
callExpression.getFunctionHandle(),
callExpression.getType(),
rewrittenArguments),
canonicalFrame));
canonicalFrame,
entry.getValue().isIgnoreNulls()));
}

return new WindowNode(
Expand Down
Loading

0 comments on commit 7efeecc

Please sign in to comment.