Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL: Prevent grouping over grouping functions #38649

Merged
merged 2 commits into from
Feb 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -593,20 +593,36 @@ private static void checkGroupingFunctionInGroupBy(LogicalPlan p, Set<Failure> l
// check if the query has a grouping function (Histogram) but no GROUP BY
if (p instanceof Project) {
Project proj = (Project) p;
proj.projections().forEach(e -> e.forEachDown(f ->
proj.projections().forEach(e -> e.forEachDown(f ->
localFailures.add(fail(f, "[{}] needs to be part of the grouping", Expressions.name(f))), GroupingFunction.class));
} else if (p instanceof Aggregate) {
// if it does have a GROUP BY, check if the groupings contain the grouping functions (Histograms)
// if it does have a GROUP BY, check if the groupings contain the grouping functions (Histograms)
Aggregate a = (Aggregate) p;
a.aggregates().forEach(agg -> agg.forEachDown(e -> {
if (a.groupings().size() == 0
if (a.groupings().size() == 0
|| Expressions.anyMatch(a.groupings(), g -> g instanceof Function && e.functionEquals((Function) g)) == false) {
localFailures.add(fail(e, "[{}] needs to be part of the grouping", Expressions.name(e)));
}
else {
checkGroupingFunctionTarget(e, localFailures);
}
}, GroupingFunction.class));

a.groupings().forEach(g -> g.forEachDown(e -> {
checkGroupingFunctionTarget(e, localFailures);
}, GroupingFunction.class));
}
}

private static void checkGroupingFunctionTarget(GroupingFunction f, Set<Failure> localFailures) {
f.field().forEachDown(e -> {
if (e instanceof GroupingFunction) {
localFailures.add(fail(f.field(), "Cannot embed grouping functions within each other, found [{}] in [{}]",
Expressions.name(f.field()), Expressions.name(f)));
}
});
}

private static void checkFilterOnAggs(LogicalPlan p, Set<Failure> localFailures) {
if (p instanceof Filter) {
Filter filter = (Filter) p;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
import org.elasticsearch.xpack.sql.util.StringUtils;

import java.util.List;
import java.util.Locale;

import static java.lang.String.format;

/**
* In a SQL statement, an Expression is whatever a user specifies inside an
Expand All @@ -39,10 +36,6 @@ public TypeResolution(String message) {
this(true, message);
}

TypeResolution(String message, Object... args) {
this(true, format(Locale.ROOT, message, args));
}

private TypeResolution(boolean unresolved, String message) {
this.failed = unresolved;
this.message = message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import java.util.StringJoiner;
import java.util.function.Predicate;

import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.sql.type.DataType.BOOLEAN;

public final class Expressions {
Expand Down Expand Up @@ -186,7 +186,7 @@ public static TypeResolution typeMustBe(Expression e,
String... acceptedTypes) {
return predicate.test(e.dataType()) || DataTypes.isNull(e.dataType())?
TypeResolution.TYPE_RESOLVED :
new TypeResolution(format(Locale.ROOT, "[%s]%s argument must be [%s], found value [%s] type [%s]",
new TypeResolution(format(null, "[{}]{} argument must be [{}], found value [{}] type [{}]",
operationName,
paramOrd == null || paramOrd == ParamOrdinal.DEFAULT ? "" : " " + paramOrd.name().toLowerCase(Locale.ROOT),
acceptedTypesForErrorMsg(acceptedTypes),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -566,10 +566,20 @@ public void testGroupByScalarOnTopOfGrouping() {
}

public void testAggsInHistogram() {
assertEquals("1:47: Cannot use an aggregate [MAX] for grouping",
error("SELECT MAX(date) FROM test GROUP BY HISTOGRAM(MAX(int), 1)"));
assertEquals("1:37: Cannot use an aggregate [MAX] for grouping",
error("SELECT MAX(date) FROM test GROUP BY MAX(int)"));
}


public void testGroupingsInHistogram() {
assertEquals(
"1:47: Cannot embed grouping functions within each other, found [HISTOGRAM(int, 1)] in [HISTOGRAM(HISTOGRAM(int, 1), 1)]",
error("SELECT MAX(date) FROM test GROUP BY HISTOGRAM(HISTOGRAM(int, 1), 1)"));
}

public void testCastInHistogram() {
accept("SELECT MAX(date) FROM test GROUP BY HISTOGRAM(CAST(int AS LONG), 1)");
}

public void testHistogramNotInGrouping() {
assertEquals("1:8: [HISTOGRAM(date, INTERVAL 1 MONTH)] needs to be part of the grouping",
error("SELECT HISTOGRAM(date, INTERVAL 1 MONTH) AS h FROM test"));
Expand Down