Skip to content

Commit

Permalink
Rework fix according to @dai-chen's comment.
Browse files Browse the repository at this point in the history
#1196 (review)

* Revert `BinaryPredicateOperator`.
* Add automatic cast rules `DATE`/`TIME`/`DATETIME` -> `DATETIME`/TIMESTAMP`.
* Update unit tests.

Signed-off-by: Yury-Fridlyand <yury.fridlyand@improving.com>
  • Loading branch information
Yury-Fridlyand committed Jan 13, 2023
1 parent 2d0d36a commit 2705283
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ public enum ExprCoreType implements ExprType {
* Date.
* Todo. compatible relationship.
*/
TIMESTAMP(STRING),
DATE(STRING),
TIME(STRING),
DATETIME(STRING),
DATETIME(STRING, DATE, TIME),
TIMESTAMP(STRING, DATE, TIME, DATETIME),
INTERVAL(UNDEFINED),

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import static org.opensearch.sql.data.type.ExprCoreType.TIME;
import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP;
import static org.opensearch.sql.expression.function.FunctionDSL.impl;
import static org.opensearch.sql.expression.function.FunctionDSL.implWithProperties;
import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling;
import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandlingWithProperties;

import java.util.Arrays;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -182,6 +184,11 @@ private static DefaultFunctionResolver castToTimestamp() {
(v) -> new ExprTimestampValue(v.stringValue())), TIMESTAMP, STRING),
impl(nullMissingHandling(
(v) -> new ExprTimestampValue(v.timestampValue())), TIMESTAMP, DATETIME),
impl(nullMissingHandling(
(v) -> new ExprTimestampValue(v.timestampValue())), TIMESTAMP, DATE),
implWithProperties(nullMissingHandlingWithProperties(
(fp, v) -> new ExprTimestampValue(((ExprTimeValue)v).timestampValue(fp))),
TIMESTAMP, TIME),
impl(nullMissingHandling((v) -> v), TIMESTAMP, TIMESTAMP)
);
}
Expand All @@ -193,7 +200,11 @@ private static DefaultFunctionResolver castToDatetime() {
impl(nullMissingHandling(
(v) -> new ExprDatetimeValue(v.datetimeValue())), DATETIME, TIMESTAMP),
impl(nullMissingHandling(
(v) -> new ExprDatetimeValue(v.datetimeValue())), DATETIME, DATE)
(v) -> new ExprDatetimeValue(v.datetimeValue())), DATETIME, DATE),
implWithProperties(nullMissingHandlingWithProperties(
(fp, v) -> new ExprDatetimeValue(((ExprTimeValue)v).datetimeValue(fp))),
DATETIME, TIME),
impl(nullMissingHandling((v) -> v), DATETIME, DATETIME)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,20 @@
import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_NULL;
import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN;
import static org.opensearch.sql.data.type.ExprCoreType.DATE;
import static org.opensearch.sql.data.type.ExprCoreType.DATETIME;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.TIME;
import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP;
import static org.opensearch.sql.expression.function.FunctionDSL.impl;
import static org.opensearch.sql.expression.function.FunctionDSL.implWithProperties;
import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling;
import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandlingWithProperties;
import static org.opensearch.sql.utils.DateTimeUtils.extractDateTime;

import com.google.common.collect.ImmutableTable;
import com.google.common.collect.Table;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.experimental.UtilityClass;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.data.model.ExprBooleanValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.DefaultFunctionResolver;
import org.opensearch.sql.expression.function.FunctionDSL;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.SerializableBiFunction;
import org.opensearch.sql.utils.OperatorUtils;

/**
Expand Down Expand Up @@ -155,80 +141,116 @@ public static void register(BuiltinFunctionRepository repository) {
.build();

private static DefaultFunctionResolver and() {
return FunctionDSL.define(BuiltinFunctionName.AND.getName(),
impl((v1, v2) -> lookupTableFunction(v1, v2, andTable), BOOLEAN, BOOLEAN, BOOLEAN));
return FunctionDSL.define(BuiltinFunctionName.AND.getName(), FunctionDSL
.impl((v1, v2) -> lookupTableFunction(v1, v2, andTable), BOOLEAN, BOOLEAN,
BOOLEAN));
}

private static DefaultFunctionResolver or() {
return FunctionDSL.define(BuiltinFunctionName.OR.getName(),
impl((v1, v2) -> lookupTableFunction(v1, v2, orTable), BOOLEAN, BOOLEAN, BOOLEAN));
return FunctionDSL.define(BuiltinFunctionName.OR.getName(), FunctionDSL
.impl((v1, v2) -> lookupTableFunction(v1, v2, orTable), BOOLEAN, BOOLEAN,
BOOLEAN));
}

private static DefaultFunctionResolver xor() {
return FunctionDSL.define(BuiltinFunctionName.XOR.getName(),
impl((v1, v2) -> lookupTableFunction(v1, v2, xorTable), BOOLEAN, BOOLEAN, BOOLEAN));
return FunctionDSL.define(BuiltinFunctionName.XOR.getName(), FunctionDSL
.impl((v1, v2) -> lookupTableFunction(v1, v2, xorTable), BOOLEAN, BOOLEAN,
BOOLEAN));
}

private static DefaultFunctionResolver equal() {
return compareImpl(BuiltinFunctionName.EQUAL.getName(),
(Comparable v1, Comparable v2) -> v1.equals(v2));
return FunctionDSL.define(BuiltinFunctionName.EQUAL.getName(),
ExprCoreType.coreTypes().stream()
.map(type -> FunctionDSL.impl(
FunctionDSL.nullMissingHandling((v1, v2) -> ExprBooleanValue.of(v1.equals(v2))),
BOOLEAN, type, type))
.collect(
Collectors.toList()));
}

private static DefaultFunctionResolver notEqual() {
return compareImpl(BuiltinFunctionName.NOTEQUAL.getName(),
(Comparable v1, Comparable v2) -> !v1.equals(v2));
return FunctionDSL
.define(BuiltinFunctionName.NOTEQUAL.getName(), ExprCoreType.coreTypes().stream()
.map(type -> FunctionDSL
.impl(
FunctionDSL
.nullMissingHandling((v1, v2) -> ExprBooleanValue.of(!v1.equals(v2))),
BOOLEAN,
type,
type))
.collect(
Collectors.toList()));
}

private static DefaultFunctionResolver less() {
return compareImpl(BuiltinFunctionName.LESS.getName(),
(Comparable v1, Comparable v2) -> v1.compareTo(v2) < 0);
return FunctionDSL
.define(BuiltinFunctionName.LESS.getName(), ExprCoreType.coreTypes().stream()
.map(type -> FunctionDSL
.impl(FunctionDSL
.nullMissingHandling((v1, v2) -> ExprBooleanValue.of(v1.compareTo(v2) < 0)),
BOOLEAN,
type, type))
.collect(
Collectors.toList()));
}

private static DefaultFunctionResolver lte() {
return compareImpl(BuiltinFunctionName.LTE.getName(),
(Comparable v1, Comparable v2) -> v1.compareTo(v2) <= 0);
return FunctionDSL
.define(BuiltinFunctionName.LTE.getName(), ExprCoreType.coreTypes().stream()
.map(type -> FunctionDSL
.impl(
FunctionDSL
.nullMissingHandling(
(v1, v2) -> ExprBooleanValue.of(v1.compareTo(v2) <= 0)),
BOOLEAN,
type, type))
.collect(
Collectors.toList()));
}

private static DefaultFunctionResolver greater() {
return compareImpl(BuiltinFunctionName.GREATER.getName(),
(Comparable v1, Comparable v2) -> v1.compareTo(v2) > 0);
return FunctionDSL
.define(BuiltinFunctionName.GREATER.getName(), ExprCoreType.coreTypes().stream()
.map(type -> FunctionDSL
.impl(FunctionDSL
.nullMissingHandling((v1, v2) -> ExprBooleanValue.of(v1.compareTo(v2) > 0)),
BOOLEAN, type, type))
.collect(
Collectors.toList()));
}

private static DefaultFunctionResolver gte() {
return compareImpl(BuiltinFunctionName.GTE.getName(),
(Comparable v1, Comparable v2) -> v1.compareTo(v2) >= 0);
}

private static DefaultFunctionResolver compareImpl(
FunctionName function, SerializableBiFunction<Comparable, Comparable, Boolean> comparator) {
return FunctionDSL.define(function,
Stream.concat(
ExprCoreType.coreTypes().stream()
.map(type -> impl(nullMissingHandling(
(v1, v2) -> ExprBooleanValue.of(comparator.apply(v1, v2))),
BOOLEAN, type, type)),
permuteTemporalTypesByPairs().stream()
.map(pair -> implWithProperties(nullMissingHandlingWithProperties(
(fp, v1, v2) -> ExprBooleanValue.of(comparator.apply(
extractDateTime(v1, fp), extractDateTime(v2, fp)))),
BOOLEAN, pair.getLeft(), pair.getRight())))
.collect(Collectors.toList()));
return FunctionDSL
.define(BuiltinFunctionName.GTE.getName(), ExprCoreType.coreTypes().stream()
.map(type -> FunctionDSL
.impl(
FunctionDSL.nullMissingHandling(
(v1, v2) -> ExprBooleanValue.of(v1.compareTo(v2) >= 0)),
BOOLEAN,
type, type))
.collect(
Collectors.toList()));
}

private static DefaultFunctionResolver like() {
return FunctionDSL.define(BuiltinFunctionName.LIKE.getName(),
impl(nullMissingHandling(OperatorUtils::matches), BOOLEAN, STRING, STRING));
return FunctionDSL.define(BuiltinFunctionName.LIKE.getName(), FunctionDSL
.impl(FunctionDSL.nullMissingHandling(OperatorUtils::matches), BOOLEAN, STRING,
STRING));
}

private static DefaultFunctionResolver regexp() {
return FunctionDSL.define(BuiltinFunctionName.REGEXP.getName(),
impl(nullMissingHandling(OperatorUtils::matchesRegexp), INTEGER, STRING, STRING));
return FunctionDSL.define(BuiltinFunctionName.REGEXP.getName(), FunctionDSL
.impl(FunctionDSL.nullMissingHandling(OperatorUtils::matchesRegexp),
INTEGER, STRING, STRING));
}

private static DefaultFunctionResolver notLike() {
return FunctionDSL.define(BuiltinFunctionName.NOT_LIKE.getName(), impl(nullMissingHandling(
return FunctionDSL.define(BuiltinFunctionName.NOT_LIKE.getName(), FunctionDSL
.impl(FunctionDSL.nullMissingHandling(
(v1, v2) -> UnaryPredicateOperator.not(OperatorUtils.matches(v1, v2))),
BOOLEAN, STRING, STRING));
BOOLEAN,
STRING,
STRING));
}

private static ExprValue lookupTableFunction(ExprValue arg1, ExprValue arg2,
Expand All @@ -239,13 +261,4 @@ private static ExprValue lookupTableFunction(ExprValue arg1, ExprValue arg2,
return table.get(arg2, arg1);
}
}

private static List<Pair<ExprCoreType, ExprCoreType>> permuteTemporalTypesByPairs() {
return List.of(
Pair.of(DATE, TIME), Pair.of(DATE, DATETIME), Pair.of(DATE, TIMESTAMP),
Pair.of(TIME, DATE), Pair.of(TIME, DATETIME), Pair.of(TIME, TIMESTAMP),
Pair.of(DATETIME, TIME), Pair.of(DATETIME, DATE), Pair.of(DATETIME, TIMESTAMP),
Pair.of(TIMESTAMP, TIME), Pair.of(TIMESTAMP, DATE), Pair.of(TIMESTAMP, DATETIME)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class WideningTypeRuleTest {
.put(STRING, DATE, 1)
.put(STRING, TIME, 1)
.put(STRING, DATETIME, 1)
.put(DATE, DATETIME, 1)
.put(TIME, DATETIME, 1)
.put(DATE, TIMESTAMP, 1)
.put(TIME, TIMESTAMP, 1)
.put(DATETIME, TIMESTAMP, 1)
.put(UNDEFINED, BYTE, 1)
.put(UNDEFINED, SHORT, 2)
.put(UNDEFINED, INTEGER, 3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
import static org.opensearch.sql.data.model.ExprValueUtils.fromObjectValue;
import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN;
import static org.opensearch.sql.data.type.ExprCoreType.DATETIME;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP;
import static org.opensearch.sql.utils.ComparisonUtil.compare;
import static org.opensearch.sql.utils.OperatorUtils.matches;

Expand Down Expand Up @@ -54,6 +56,7 @@
import org.opensearch.sql.data.model.ExprStringValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.ExpressionTestBase;
Expand Down Expand Up @@ -419,7 +422,25 @@ public void test_equal(ExprValue v1, ExprValue v2) {
assertEquals(0 == compare(functionProperties, v1, v2),
ExprValueUtils.getBooleanValue(equal.valueOf(valueEnv())));
}
assertEquals(String.format("=(%s, %s)", v1.toString(), v2.toString()), equal.toString());
assertStringRepr(v1, v2, "=", equal);
}

private void assertStringRepr(ExprValue v1, ExprValue v2, String function,
FunctionExpression functionExpression) {
if (v1.type() == v2.type()) {
assertEquals(String.format("%s(%s, %s)", function, v1, v2), functionExpression.toString());
} else {
var widerType = v1.type() == TIMESTAMP || v2.type() == TIMESTAMP ? TIMESTAMP : DATETIME;
assertEquals(String.format("%s(%s, %s)", function, getExpectedStringRepr(widerType, v1),
getExpectedStringRepr(widerType, v2)), functionExpression.toString());
}
}

private String getExpectedStringRepr(ExprType widerType, ExprValue value) {
if (widerType == value.type()) {
return value.toString();
}
return String.format("cast_to_%s(%s)", widerType.toString().toLowerCase(), value);
}

@Test
Expand Down Expand Up @@ -475,7 +496,7 @@ public void test_notequal(ExprValue v1, ExprValue v2) {
assertEquals(0 != compare(functionProperties, v1, v2),
ExprValueUtils.getBooleanValue(notequal.valueOf(valueEnv())));
}
assertEquals(String.format("!=(%s, %s)", v1.toString(), v2.toString()), notequal.toString());
assertStringRepr(v1, v2, "!=", notequal);
}

@Test
Expand Down Expand Up @@ -528,7 +549,7 @@ public void test_less(ExprValue v1, ExprValue v2) {
assertEquals(BOOLEAN, less.type());
assertEquals(compare(functionProperties, v1, v2) < 0,
ExprValueUtils.getBooleanValue(less.valueOf(valueEnv())));
assertEquals(String.format("<(%s, %s)", v1.toString(), v2.toString()), less.toString());
assertStringRepr(v1, v2, "<", less);
}

@Test
Expand Down Expand Up @@ -585,7 +606,7 @@ public void test_lte(ExprValue v1, ExprValue v2) {
assertEquals(BOOLEAN, lte.type());
assertEquals(compare(functionProperties, v1, v2) <= 0,
ExprValueUtils.getBooleanValue(lte.valueOf(valueEnv())));
assertEquals(String.format("<=(%s, %s)", v1.toString(), v2.toString()), lte.toString());
assertStringRepr(v1, v2, "<=", lte);
}

@Test
Expand Down Expand Up @@ -642,7 +663,7 @@ public void test_greater(ExprValue v1, ExprValue v2) {
assertEquals(BOOLEAN, greater.type());
assertEquals(compare(functionProperties, v1, v2) > 0,
ExprValueUtils.getBooleanValue(greater.valueOf(valueEnv())));
assertEquals(String.format(">(%s, %s)", v1.toString(), v2.toString()), greater.toString());
assertStringRepr(v1, v2, ">", greater);
}

@Test
Expand Down Expand Up @@ -699,7 +720,7 @@ public void test_gte(ExprValue v1, ExprValue v2) {
assertEquals(BOOLEAN, gte.type());
assertEquals(compare(functionProperties, v1, v2) >= 0,
ExprValueUtils.getBooleanValue(gte.valueOf(valueEnv())));
assertEquals(String.format(">=(%s, %s)", v1.toString(), v2.toString()), gte.toString());
assertStringRepr(v1, v2, ">=", gte);
}

@Test
Expand Down

0 comments on commit 2705283

Please sign in to comment.