Skip to content

Commit

Permalink
feat: Snowflake math functions, complete
Browse files Browse the repository at this point in the history
Signed-off-by: Andreas Reichel <andreas@manticore-projects.com>
  • Loading branch information
manticore-projects committed May 4, 2024
1 parent 649edd4 commit 4cb7e85
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import net.sf.jsqlparser.expression.CaseExpression;
import net.sf.jsqlparser.expression.CastExpression;
import net.sf.jsqlparser.expression.DateTimeLiteralExpression;
import net.sf.jsqlparser.expression.DoubleValue;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.HexValue;
Expand Down Expand Up @@ -71,6 +72,8 @@ enum TranspiledFunction {

, VARIANCE_POP, VARIANCE_SAMP, BITAND_AGG, BITOR_AGG, BITXOR_AGG, BOOLAND_AGG, BOOLOR_AGG, BOOLXOR_AGG, SKEW, GROUPING_ID, TO_VARCHAR, TO_BINARY, TRY_TO_BINARY, TO_DECIMAL, TO_NUMBER, TO_NUMERIC, TRY_TO_DECIMAL, TRY_TO_NUMBER, TRY_TO_NUMERIC, TO_DOUBLE, TRY_TO_DOUBLE, TO_BOOLEAN, TRY_TO_BOOLEAN, TRY_TO_DATE, TRY_TO_TIME, TRY_TO_TIMESTAMP, TRY_TO_TIMESTAMP_LTZ, TRY_TO_TIMESTAMP_NTZ, TRY_TO_TIMESTAMP_TZ

, RANDOM, DIV0, DIV0NULL, ROUND, SQUARE

;
// @FORMATTER:ON

Expand Down Expand Up @@ -266,12 +269,12 @@ public void visit(Function function) {
}
break;
case DATE_PART:
if (paramCount==2) {
if (paramCount == 2) {
function.setParameters(toDateTimePart(parameters.get(0)), parameters.get(1));
}
break;
case LAST_DAY:
if (paramCount==2) {
if (paramCount == 2) {
throw new RuntimeException("LAST_DATE with DatePart is not supported.");
}
break;
Expand Down Expand Up @@ -837,19 +840,16 @@ public void visit(Function function) {
case TRY_TO_NUMBER:
case TRY_TO_NUMERIC:
// TO_DECIMAL( <expr> [, '<format>' ] [, <precision> [, <scale> ] ] )
// list_aggregate(regexp_extract_all('-1,000.00', '[\+|\-\d|\.]'),'string_agg', '')::NUMERIC
// list_aggregate(regexp_extract_all('-1,000.00', '[\+|\-\d|\.]'),'string_agg',
// '')::NUMERIC

Function f1 = new Function(
"List_Aggregate"
, new Function("Regexp_Extract_All", parameters.get(0), new StringValue("[\\+|\\-\\d|\\.]"))
, new StringValue("string_agg")
, new StringValue("")
);
f1 = new Function("If"
, new EqualsTo(new Function("TypeOf", parameters.get(0)), new StringValue("VARCHAR"))
, f1
, parameters.get(0)
);
Function f1 = new Function("List_Aggregate",
new Function("Regexp_Extract_All", parameters.get(0),
new StringValue("[\\+|\\-\\d|\\.]")),
new StringValue("string_agg"), new StringValue(""));
f1 = new Function("If",
new EqualsTo(new Function("TypeOf", parameters.get(0)), new StringValue("VARCHAR")),
f1, parameters.get(0));

switch (paramCount) {
case 4:
Expand All @@ -858,45 +858,39 @@ public void visit(Function function) {
&& parameters.get(3) instanceof LongValue) {
String typeStr = "DECIMAL(" + ((LongValue) parameters.get(2)).getValue() + ", "
+ ((LongValue) parameters.get(3)).getValue() + ")";
rewrittenExpression =
new CastExpression(functionName.startsWith("TRY") ? "Try_Cast" : "Cast",
f1, typeStr);
rewrittenExpression = new CastExpression(
functionName.startsWith("TRY") ? "Try_Cast" : "Cast", f1, typeStr);
}
break;
case 3:
if (parameters.get(1) instanceof StringValue
&& parameters.get(2) instanceof LongValue) {
warning("Format Parameter not supported.");
String typeStr = "DECIMAL(" + ((LongValue) parameters.get(2)).getValue() + ")";
rewrittenExpression =
new CastExpression(functionName.startsWith("TRY") ? "Try_Cast" : "Cast",
f1, typeStr);
rewrittenExpression = new CastExpression(
functionName.startsWith("TRY") ? "Try_Cast" : "Cast", f1, typeStr);
} else if (parameters.get(1) instanceof LongValue
&& parameters.get(2) instanceof LongValue) {
String typeStr = "DECIMAL(" + ((LongValue) parameters.get(1)).getValue() + ", "
+ ((LongValue) parameters.get(2)).getValue() + ")";
rewrittenExpression =
new CastExpression(functionName.startsWith("TRY") ? "Try_Cast" : "Cast",
f1, typeStr);
rewrittenExpression = new CastExpression(
functionName.startsWith("TRY") ? "Try_Cast" : "Cast", f1, typeStr);
}
break;
case 2:
if (parameters.get(1) instanceof StringValue) {
warning("Format Parameter not supported.");
rewrittenExpression =
new CastExpression(functionName.startsWith("TRY") ? "Try_Cast" : "Cast",
f1, "DECIMAL(12,0)");
rewrittenExpression = new CastExpression(
functionName.startsWith("TRY") ? "Try_Cast" : "Cast", f1, "DECIMAL(12,0)");
} else if (parameters.get(1) instanceof LongValue) {
String typeStr = "DECIMAL(" + ((LongValue) parameters.get(1)).getValue() + ")";
rewrittenExpression =
new CastExpression(functionName.startsWith("TRY") ? "Try_Cast" : "Cast",
f1, typeStr);
rewrittenExpression = new CastExpression(
functionName.startsWith("TRY") ? "Try_Cast" : "Cast", f1, typeStr);
}
break;
case 1:
rewrittenExpression =
new CastExpression(functionName.startsWith("TRY") ? "Try_Cast" : "Cast",
f1, "DECIMAL(12,0)");
rewrittenExpression = new CastExpression(
functionName.startsWith("TRY") ? "Try_Cast" : "Cast", f1, "DECIMAL(12,0)");
break;
}
break;
Expand All @@ -922,6 +916,44 @@ public void visit(Function function) {
parameters.get(0), "BOOLEAN");
}
break;
case RANDOM:
if (paramCount == 1) {
warning("SEED parameter not supported");
}
// ((random() - 0.5) * 1E19)::int64
rewrittenExpression = new CastExpression("Cast",
BinaryExpression.multiply(
new ParenthesedExpressionList<>(
BinaryExpression.subtract(new Function("Random$$"), new DoubleValue(0.5d))),
new DoubleValue("1E19")),
"INT64");
break;
case DIV0:
case DIV0NULL:
if (paramCount == 2) {
function.setName("Coalesce");
function.setParameters(new Function("Divide", parameters.get(0), parameters.get(1)),
new LongValue(0));
}
break;
case ROUND:
switch (paramCount) {
case 3:
warning("Limited support for rounding mode");
if ("'HALF_TO_EVEN'".equalsIgnoreCase(parameters.get(2).toString())) {
function.setName("Round_Even");
}
case 2:
function.setParameters(parameters.get(0), parameters.get(1));
break;
}
break;
case SQUARE:
if (paramCount == 1) {
function.setName("Power");
function.setParameters(parameters.get(0), new LongValue(2));
}
break;
}
}
if (rewrittenExpression == null) {
Expand Down Expand Up @@ -991,7 +1023,7 @@ public void visit(AnalyticExpression function) {
}
}

super.visit(function);
super.visit(function);
}

public void visit(Column column) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public void visit(TableFunction tableFunction) {
PlainSelect select = new PlainSelect();
for (Expression expression : tableFunction.getFunction().getParameters()) {
Alias alias = null;
boolean addExpression = false;
if (expression instanceof Function) {
Function f = (Function) expression;
String fName = f.getName().toUpperCase();
Expand All @@ -91,15 +92,24 @@ public void visit(TableFunction tableFunction) {

select.addSelectItem(expression, alias);
alias = new Alias("value", true);
addExpression = true;
} else if (fName.equalsIgnoreCase("GENERATOR")) {
// select range AS seq4 FROM range(0,3);
select.addSelectItem(new Column("range"), new Alias("seq4", true));
select.setFromItem(new TableFunction(
new Function("Range", new LongValue(0), f.getParameters().get(0))));
alias = new Alias("value", true);
}
}
select.addSelectItem(expression, alias);
if (addExpression) {
select.addSelectItem(expression, alias);
}
}
ParenthesedSelect parenthesedSelect =
new ParenthesedSelect().withSelect(select).withAlias(tableFunction.getAlias());

visit(parenthesedSelect);
} else if (prefix.equalsIgnoreCase("lateral")) {
} else if ("lateral".equalsIgnoreCase(prefix)) {
PlainSelect select = new PlainSelect();
if (name.equalsIgnoreCase("SPLIT_TO_TABLE")
|| name.equalsIgnoreCase("STRTOK_SPLIT_TO_TABLE")) {
Expand Down
Binary file modified src/main/resources/doc/JSQLTranspiler.ods
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
-- provided
SELECT random() AS rand FROM table(generator(rowCount => 3));

-- expected
SELECT CAST((RANDOM()-0.5)*1E19 AS INT64)AS RAND FROM(SELECT RANGE AS SEQ4 FROM RANGE(0,ROWCOUNT=>3))

-- count
3


-- provided
SELECT DIV0(1, 0) AS div0;

-- expected
SELECT COALESCE(DIVIDE(1,0), 0) AS div0;

-- result
"div0"
"0"


-- provided
SELECT MOD(3, 2) AS mod1, MOD(4.5, 1.2) AS mod2;

-- result
"mod1","mod2"
"1","0.9"

-- provided
SELECT ROUND(2.5, 0) AS r1, ROUND(2.5, 0, 'HALF_TO_EVEN') AS r2;

-- expected
SELECT ROUND(2.5, 0) AS r1, ROUND_EVEN(2.5, 0) AS r2;

-- result
"r1","r2"
"3","2.0"


-- provided
SELECT FACTORIAL(0), FACTORIAL(1), FACTORIAL(5), FACTORIAL(10);

-- result
"factorial(0)","factorial(1)","factorial(5)","factorial(10)"
"1","1","120","3628800"


-- provided
select square(12) AS square;

-- expected
select power(12,2) AS square;

-- returns
"square"
"144.0"

0 comments on commit 4cb7e85

Please sign in to comment.