Skip to content

Commit

Permalink
SQL: Prevent grouping over grouping functions (#38649)
Browse files Browse the repository at this point in the history
Improve verifier to disallow grouping over grouping functions (e.g.
HISTOGRAM over HISTOGRAM).

Close #38308

(cherry picked from commit 4e9b1cf)
(cherry picked from commit 794ee4f)
  • Loading branch information
costin committed Feb 9, 2019
1 parent 7881639 commit 91cb151
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -593,20 +593,36 @@ private static void checkGroupingFunctionInGroupBy(LogicalPlan p, Set<Failure> l
// check if the query has a grouping function (Histogram) but no GROUP BY
if (p instanceof Project) {
Project proj = (Project) p;
proj.projections().forEach(e -> e.forEachDown(f ->
proj.projections().forEach(e -> e.forEachDown(f ->
localFailures.add(fail(f, "[{}] needs to be part of the grouping", Expressions.name(f))), GroupingFunction.class));
} else if (p instanceof Aggregate) {
// if it does have a GROUP BY, check if the groupings contain the grouping functions (Histograms)
// if it does have a GROUP BY, check if the groupings contain the grouping functions (Histograms)
Aggregate a = (Aggregate) p;
a.aggregates().forEach(agg -> agg.forEachDown(e -> {
if (a.groupings().size() == 0
if (a.groupings().size() == 0
|| Expressions.anyMatch(a.groupings(), g -> g instanceof Function && e.functionEquals((Function) g)) == false) {
localFailures.add(fail(e, "[{}] needs to be part of the grouping", Expressions.name(e)));
}
else {
checkGroupingFunctionTarget(e, localFailures);
}
}, GroupingFunction.class));

a.groupings().forEach(g -> g.forEachDown(e -> {
checkGroupingFunctionTarget(e, localFailures);
}, GroupingFunction.class));
}
}

private static void checkGroupingFunctionTarget(GroupingFunction f, Set<Failure> localFailures) {
f.field().forEachDown(e -> {
if (e instanceof GroupingFunction) {
localFailures.add(fail(f.field(), "Cannot embed grouping functions within each other, found [{}] in [{}]",
Expressions.name(f.field()), Expressions.name(f)));
}
});
}

private static void checkFilterOnAggs(LogicalPlan p, Set<Failure> localFailures) {
if (p instanceof Filter) {
Filter filter = (Filter) p;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
import org.elasticsearch.xpack.sql.util.StringUtils;

import java.util.List;
import java.util.Locale;

import static java.lang.String.format;

/**
* In a SQL statement, an Expression is whatever a user specifies inside an
Expand All @@ -39,10 +36,6 @@ public TypeResolution(String message) {
this(true, message);
}

TypeResolution(String message, Object... args) {
this(true, format(Locale.ROOT, message, args));
}

private TypeResolution(boolean unresolved, String message) {
this.failed = unresolved;
this.message = message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import java.util.StringJoiner;
import java.util.function.Predicate;

import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.sql.type.DataType.BOOLEAN;

public final class Expressions {
Expand Down Expand Up @@ -186,7 +186,7 @@ public static TypeResolution typeMustBe(Expression e,
String... acceptedTypes) {
return predicate.test(e.dataType()) || DataTypes.isNull(e.dataType())?
TypeResolution.TYPE_RESOLVED :
new TypeResolution(format(Locale.ROOT, "[%s]%s argument must be [%s], found value [%s] type [%s]",
new TypeResolution(format(null, "[{}]{} argument must be [{}], found value [{}] type [{}]",
operationName,
paramOrd == null || paramOrd == ParamOrdinal.DEFAULT ? "" : " " + paramOrd.name().toLowerCase(Locale.ROOT),
acceptedTypesForErrorMsg(acceptedTypes),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -566,10 +566,20 @@ public void testGroupByScalarOnTopOfGrouping() {
}

public void testAggsInHistogram() {
assertEquals("1:47: Cannot use an aggregate [MAX] for grouping",
error("SELECT MAX(date) FROM test GROUP BY HISTOGRAM(MAX(int), 1)"));
assertEquals("1:37: Cannot use an aggregate [MAX] for grouping",
error("SELECT MAX(date) FROM test GROUP BY MAX(int)"));
}


public void testGroupingsInHistogram() {
assertEquals(
"1:47: Cannot embed grouping functions within each other, found [HISTOGRAM(int, 1)] in [HISTOGRAM(HISTOGRAM(int, 1), 1)]",
error("SELECT MAX(date) FROM test GROUP BY HISTOGRAM(HISTOGRAM(int, 1), 1)"));
}

public void testCastInHistogram() {
accept("SELECT MAX(date) FROM test GROUP BY HISTOGRAM(CAST(int AS LONG), 1)");
}

public void testHistogramNotInGrouping() {
assertEquals("1:8: [HISTOGRAM(date, INTERVAL 1 MONTH)] needs to be part of the grouping",
error("SELECT HISTOGRAM(date, INTERVAL 1 MONTH) AS h FROM test"));
Expand Down

0 comments on commit 91cb151

Please sign in to comment.