-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add derivation expression evaluator (#63)
Adds support for the evaluation of extension function's return type expressions. It's meant to be used by producers to correctly infer return types. It uses antlr grammar (temporarily) copy-pasted from `substrait-java` w/o changes.
- Loading branch information
Showing
14 changed files
with
3,940 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
FROM mcr.microsoft.com/vscode/devcontainers/python:3.10-buster | ||
USER vscode | ||
RUN curl -s "https://get.sdkman.io" | bash | ||
SHELL ["/bin/bash", "-c"] | ||
RUN source "/home/vscode/.sdkman/bin/sdkman-init.sh" && sdk install java 20.0.2-graalce | ||
RUN mkdir -p ~/lib && cd ~/lib && curl -L -O http://www.antlr.org/download/antlr-4.13.1-complete.jar | ||
ENV ANTLR_JAR="~/lib/antlr-4.13.1-complete.jar" | ||
USER root |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
{ | ||
"name": "substrait-python-devcontainer", | ||
"build": { | ||
"context": "..", | ||
"dockerfile": "Dockerfile" | ||
}, | ||
|
||
// Features to add to the dev container. More info: https://containers.dev/features. | ||
// "features": { | ||
// "ghcr.io/devcontainers/features/nix:1": {} | ||
// }, | ||
|
||
// Use 'forwardPorts' to make a list of ports inside the container available locally. | ||
// "forwardPorts": [], | ||
|
||
// Use 'postCreateCommand' to run commands after the container is created. | ||
// "postCreateCommand": "poetry install" | ||
|
||
// Configure tool-specific properties. | ||
// "customizations": {}, | ||
|
||
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. | ||
// "remoteUser": "root" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
antlr: | ||
java -jar ${ANTLR_JAR} -o src/substrait/gen/antlr -Dlanguage=Python3 SubstraitType.g4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
grammar SubstraitType; | ||
|
||
// | ||
fragment A : [aA]; | ||
fragment B : [bB]; | ||
fragment C : [cC]; | ||
fragment D : [dD]; | ||
fragment E : [eE]; | ||
fragment F : [fF]; | ||
fragment G : [gG]; | ||
fragment H : [hH]; | ||
fragment I : [iI]; | ||
fragment J : [jJ]; | ||
fragment K : [kK]; | ||
fragment L : [lL]; | ||
fragment M : [mM]; | ||
fragment N : [nN]; | ||
fragment O : [oO]; | ||
fragment P : [pP]; | ||
fragment Q : [qQ]; | ||
fragment R : [rR]; | ||
fragment S : [sS]; | ||
fragment T : [tT]; | ||
fragment U : [uU]; | ||
fragment V : [vV]; | ||
fragment W : [wW]; | ||
fragment X : [xX]; | ||
fragment Y : [yY]; | ||
fragment Z : [zZ]; | ||
|
||
|
||
If : I F; | ||
Then : T H E N; | ||
Else : E L S E; | ||
|
||
// TYPES | ||
Boolean : B O O L E A N; | ||
I8 : I '8'; | ||
I16 : I '16'; | ||
I32 : I '32'; | ||
I64 : I '64'; | ||
FP32 : F P '32'; | ||
FP64 : F P '64'; | ||
String : S T R I N G; | ||
Binary : B I N A R Y; | ||
Timestamp: T I M E S T A M P; | ||
TimestampTZ: T I M E S T A M P '_' T Z; | ||
Date : D A T E; | ||
Time : T I M E; | ||
IntervalYear: I N T E R V A L '_' Y E A R; | ||
IntervalDay: I N T E R V A L '_' D A Y; | ||
IntervalCompound: I N T E R V A L '_' C O M P O U N D; | ||
UUID : U U I D; | ||
Decimal : D E C I M A L; | ||
PrecisionTimestamp: P R E C I S I O N '_' T I M E S T A M P; | ||
PrecisionTimestampTZ: P R E C I S I O N '_' T I M E S T A M P '_' T Z; | ||
FixedChar: F I X E D C H A R; | ||
VarChar : V A R C H A R; | ||
FixedBinary: F I X E D B I N A R Y; | ||
Struct : S T R U C T; | ||
NStruct : N S T R U C T; | ||
List : L I S T; | ||
Map : M A P; | ||
ANY : A N Y; | ||
UserDefined: U '!'; | ||
|
||
|
||
// OPERATIONS | ||
And : A N D; | ||
Or : O R; | ||
Assign : ':='; | ||
|
||
// COMPARE | ||
Eq : '='; | ||
NotEquals: '!='; | ||
Gte : '>='; | ||
Lte : '<='; | ||
Gt : '>'; | ||
Lt : '<'; | ||
Bang : '!'; | ||
|
||
|
||
// MATH | ||
Plus : '+'; | ||
Minus : '-'; | ||
Asterisk : '*'; | ||
ForwardSlash : '/'; | ||
Percent : '%'; | ||
|
||
// ORGANIZE | ||
OBracket : '['; | ||
CBracket : ']'; | ||
OParen : '('; | ||
CParen : ')'; | ||
SColon : ';'; | ||
Comma : ','; | ||
QMark : '?'; | ||
Colon : ':'; | ||
SingleQuote: '\''; | ||
|
||
|
||
Number | ||
: '-'? Int | ||
; | ||
|
||
Identifier | ||
: ('a'..'z' | 'A'..'Z' | '_' | '$') ('a'..'z' | 'A'..'Z' | '_' | '$' | Digit)* | ||
; | ||
|
||
LineComment | ||
: '//' ~[\r\n]* -> channel(HIDDEN) | ||
; | ||
|
||
BlockComment | ||
: ( '/*' | ||
( '/'* BlockComment | ||
| ~[/*] | ||
| '/'+ ~[/*] | ||
| '*'+ ~[/*] | ||
)* | ||
'*'* | ||
'*/' | ||
) -> channel(HIDDEN) | ||
; | ||
Whitespace | ||
: [ \t]+ -> channel(HIDDEN) | ||
; | ||
Newline | ||
: ( '\r' '\n'? | ||
| '\n' | ||
) | ||
; | ||
fragment Int | ||
: '1'..'9' Digit* | ||
| '0' | ||
; | ||
fragment Digit | ||
: '0'..'9' | ||
; | ||
start: expr EOF; | ||
scalarType | ||
: Boolean #Boolean | ||
| I8 #i8 | ||
| I16 #i16 | ||
| I32 #i32 | ||
| I64 #i64 | ||
| FP32 #fp32 | ||
| FP64 #fp64 | ||
| String #string | ||
| Binary #binary | ||
| Timestamp #timestamp | ||
| TimestampTZ #timestampTz | ||
| Date #date | ||
| Time #time | ||
| IntervalYear #intervalYear | ||
| UUID #uuid | ||
| UserDefined Identifier #userDefined | ||
; | ||
parameterizedType | ||
: FixedChar isnull='?'? Lt len=numericParameter Gt #fixedChar | ||
| VarChar isnull='?'? Lt len=numericParameter Gt #varChar | ||
| FixedBinary isnull='?'? Lt len=numericParameter Gt #fixedBinary | ||
| Decimal isnull='?'? Lt precision=numericParameter Comma scale=numericParameter Gt #decimal | ||
| IntervalDay isnull='?'? Lt precision=numericParameter Gt #intervalDay | ||
| IntervalCompound isnull='?'? Lt precision=numericParameter Gt #intervalCompound | ||
| PrecisionTimestamp isnull='?'? Lt precision=numericParameter Gt #precisionTimestamp | ||
| PrecisionTimestampTZ isnull='?'? Lt precision=numericParameter Gt #precisionTimestampTZ | ||
| Struct isnull='?'? Lt expr (Comma expr)* Gt #struct | ||
| NStruct isnull='?'? Lt Identifier expr (Comma Identifier expr)* Gt #nStruct | ||
| List isnull='?'? Lt expr Gt #list | ||
| Map isnull='?'? Lt key=expr Comma value=expr Gt #map | ||
; | ||
numericParameter | ||
: Number #numericLiteral | ||
| Identifier #numericParameterName | ||
| expr #numericExpression | ||
; | ||
anyType: ANY; | ||
type | ||
: scalarType isnull='?'? | ||
| parameterizedType | ||
| anyType isnull='?'? | ||
; | ||
// : (OParen innerExpr CParen | innerExpr) | ||
expr | ||
: OParen expr CParen #ParenExpression | ||
| Identifier Eq expr Newline+ (Identifier Eq expr Newline+)* finalType=type Newline* #MultilineDefinition | ||
| type #TypeLiteral | ||
| number=Number #LiteralNumber | ||
| identifier=Identifier isnull='?'? #TypeParam | ||
| Identifier OParen (expr (Comma expr)*)? CParen #FunctionCall | ||
| left=expr op=(And | Or | Plus | Minus | Lt | Gt | Eq | NotEquals | Lte | Gte | Asterisk | ForwardSlash) right=expr #BinaryExpr | ||
| If ifExpr=expr Then thenExpr=expr Else elseExpr=expr #IfExpr | ||
| (Bang) expr #NotExpr | ||
| ifExpr=expr QMark thenExpr=expr Colon elseExpr=expr #Ternary | ||
; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from typing import Optional | ||
from antlr4 import InputStream, CommonTokenStream | ||
from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer | ||
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser | ||
from substrait.gen.proto.type_pb2 import Type | ||
|
||
|
||
def _evaluate(x, values: dict): | ||
if type(x) == SubstraitTypeParser.BinaryExprContext: | ||
left = _evaluate(x.left, values) | ||
right = _evaluate(x.right, values) | ||
|
||
if x.op.text == "+": | ||
return left + right | ||
elif x.op.text == "-": | ||
return left - right | ||
elif x.op.text == "*": | ||
return left * right | ||
elif x.op.text == ">": | ||
return left > right | ||
elif x.op.text == ">=": | ||
return left >= right | ||
elif x.op.text == "<": | ||
return left < right | ||
elif x.op.text == "<=": | ||
return left <= right | ||
else: | ||
raise Exception(f"Unknown binary op {x.op.text}") | ||
elif type(x) == SubstraitTypeParser.LiteralNumberContext: | ||
return int(x.number.text) | ||
elif type(x) == SubstraitTypeParser.TypeParamContext: | ||
return values[x.identifier.text] | ||
elif type(x) == SubstraitTypeParser.NumericParameterNameContext: | ||
return values[x.Identifier().symbol.text] | ||
elif type(x) == SubstraitTypeParser.ParenExpressionContext: | ||
return _evaluate(x.expr(), values) | ||
elif type(x) == SubstraitTypeParser.FunctionCallContext: | ||
exprs = [_evaluate(e, values) for e in x.expr()] | ||
func = x.Identifier().symbol.text | ||
|
||
if func == "min": | ||
return min(*exprs) | ||
elif func == "max": | ||
return max(*exprs) | ||
else: | ||
raise Exception(f"Unknown function {func}") | ||
elif type(x) == SubstraitTypeParser.TypeContext: | ||
scalar_type = x.scalarType() | ||
parametrized_type = x.parameterizedType() | ||
if scalar_type: | ||
if isinstance(scalar_type, SubstraitTypeParser.I8Context): | ||
return Type(i8=Type.I8()) | ||
elif isinstance(scalar_type, SubstraitTypeParser.I16Context): | ||
return Type(i16=Type.I16()) | ||
elif isinstance(scalar_type, SubstraitTypeParser.I32Context): | ||
return Type(i32=Type.I32()) | ||
elif isinstance(scalar_type, SubstraitTypeParser.I64Context): | ||
return Type(i64=Type.I64()) | ||
elif isinstance(scalar_type, SubstraitTypeParser.Fp32Context): | ||
return Type(fp32=Type.FP32()) | ||
elif isinstance(scalar_type, SubstraitTypeParser.Fp64Context): | ||
return Type(fp64=Type.FP64()) | ||
elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext): | ||
return Type(bool=Type.Boolean()) | ||
else: | ||
raise Exception(f"Unknown scalar type {type(scalar_type)}") | ||
elif parametrized_type: | ||
if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext): | ||
precision = _evaluate(parametrized_type.precision, values) | ||
scale = _evaluate(parametrized_type.scale, values) | ||
return Type(decimal=Type.Decimal(precision=precision, scale=scale)) | ||
raise Exception(f"Unknown parametrized type {type(parametrized_type)}") | ||
else: | ||
raise Exception("either scalar_type or parametrized_type is required") | ||
elif type(x) == SubstraitTypeParser.NumericExpressionContext: | ||
return _evaluate(x.expr(), values) | ||
elif type(x) == SubstraitTypeParser.TernaryContext: | ||
ifExpr = _evaluate(x.ifExpr, values) | ||
thenExpr = _evaluate(x.thenExpr, values) | ||
elseExpr = _evaluate(x.elseExpr, values) | ||
|
||
return thenExpr if ifExpr else elseExpr | ||
elif type(x) == SubstraitTypeParser.MultilineDefinitionContext: | ||
lines = zip(x.Identifier(), x.expr()) | ||
|
||
for i, e in lines: | ||
identifier = i.symbol.text | ||
expr_eval = _evaluate(e, values) | ||
values[identifier] = expr_eval | ||
|
||
return _evaluate(x.finalType, values) | ||
elif type(x) == SubstraitTypeParser.TypeLiteralContext: | ||
return _evaluate(x.type_(), values) | ||
else: | ||
raise Exception(f"Unknown token type {type(x)}") | ||
|
||
|
||
def evaluate(x: str, values: Optional[dict] = None): | ||
lexer = SubstraitTypeLexer(InputStream(x)) | ||
stream = CommonTokenStream(lexer) | ||
parser = SubstraitTypeParser(stream) | ||
return _evaluate(parser.expr(), values) |
Oops, something went wrong.