Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: substrait-io/substrait-python
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: v0.23.0
Choose a base ref
...
head repository: substrait-io/substrait-python
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: main
Choose a head ref
  • 1 commit
  • 14 files changed
  • 1 contributor

Commits on Oct 28, 2024

  1. feat: add derivation expression evaluator (#63)

    Adds support for the evaluation of extension function's return type
    expressions.
    It's meant to be used by producers to correctly infer return types.
    It uses antlr grammar (temporarily) copy-pasted from `substrait-java`
    w/o changes.
    tokoko authored Oct 28, 2024
    Copy the full SHA
    b78a614 View commit details
8 changes: 8 additions & 0 deletions .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
FROM mcr.microsoft.com/vscode/devcontainers/python:3.10-buster
USER vscode
RUN curl -s "https://get.sdkman.io" | bash
SHELL ["/bin/bash", "-c"]
RUN source "/home/vscode/.sdkman/bin/sdkman-init.sh" && sdk install java 20.0.2-graalce
RUN mkdir -p ~/lib && cd ~/lib && curl -L -O http://www.antlr.org/download/antlr-4.13.1-complete.jar
ENV ANTLR_JAR="~/lib/antlr-4.13.1-complete.jar"
USER root
24 changes: 24 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"name": "substrait-python-devcontainer",
"build": {
"context": "..",
"dockerfile": "Dockerfile"
},

// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {
// "ghcr.io/devcontainers/features/nix:1": {}
// },

// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],

// Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "poetry install"

// Configure tool-specific properties.
// "customizations": {},

// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
}
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
antlr:
java -jar ${ANTLR_JAR} -o src/substrait/gen/antlr -Dlanguage=Python3 SubstraitType.g4
209 changes: 209 additions & 0 deletions SubstraitType.g4
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
grammar SubstraitType;

//
fragment A : [aA];
fragment B : [bB];
fragment C : [cC];
fragment D : [dD];
fragment E : [eE];
fragment F : [fF];
fragment G : [gG];
fragment H : [hH];
fragment I : [iI];
fragment J : [jJ];
fragment K : [kK];
fragment L : [lL];
fragment M : [mM];
fragment N : [nN];
fragment O : [oO];
fragment P : [pP];
fragment Q : [qQ];
fragment R : [rR];
fragment S : [sS];
fragment T : [tT];
fragment U : [uU];
fragment V : [vV];
fragment W : [wW];
fragment X : [xX];
fragment Y : [yY];
fragment Z : [zZ];


If : I F;
Then : T H E N;
Else : E L S E;

// TYPES
Boolean : B O O L E A N;
I8 : I '8';
I16 : I '16';
I32 : I '32';
I64 : I '64';
FP32 : F P '32';
FP64 : F P '64';
String : S T R I N G;
Binary : B I N A R Y;
Timestamp: T I M E S T A M P;
TimestampTZ: T I M E S T A M P '_' T Z;
Date : D A T E;
Time : T I M E;
IntervalYear: I N T E R V A L '_' Y E A R;
IntervalDay: I N T E R V A L '_' D A Y;
IntervalCompound: I N T E R V A L '_' C O M P O U N D;
UUID : U U I D;
Decimal : D E C I M A L;
PrecisionTimestamp: P R E C I S I O N '_' T I M E S T A M P;
PrecisionTimestampTZ: P R E C I S I O N '_' T I M E S T A M P '_' T Z;
FixedChar: F I X E D C H A R;
VarChar : V A R C H A R;
FixedBinary: F I X E D B I N A R Y;
Struct : S T R U C T;
NStruct : N S T R U C T;
List : L I S T;
Map : M A P;
ANY : A N Y;
UserDefined: U '!';


// OPERATIONS
And : A N D;
Or : O R;
Assign : ':=';

// COMPARE
Eq : '=';
NotEquals: '!=';
Gte : '>=';
Lte : '<=';
Gt : '>';
Lt : '<';
Bang : '!';


// MATH
Plus : '+';
Minus : '-';
Asterisk : '*';
ForwardSlash : '/';
Percent : '%';

// ORGANIZE
OBracket : '[';
CBracket : ']';
OParen : '(';
CParen : ')';
SColon : ';';
Comma : ',';
QMark : '?';
Colon : ':';
SingleQuote: '\'';


Number
: '-'? Int
;

Identifier
: ('a'..'z' | 'A'..'Z' | '_' | '$') ('a'..'z' | 'A'..'Z' | '_' | '$' | Digit)*
;

LineComment
: '//' ~[\r\n]* -> channel(HIDDEN)
;

BlockComment
: ( '/*'
( '/'* BlockComment
| ~[/*]
| '/'+ ~[/*]
| '*'+ ~[/*]
)*
'*'*
'*/'
) -> channel(HIDDEN)
;
Whitespace
: [ \t]+ -> channel(HIDDEN)
;
Newline
: ( '\r' '\n'?
| '\n'
)
;
fragment Int
: '1'..'9' Digit*
| '0'
;
fragment Digit
: '0'..'9'
;
start: expr EOF;
scalarType
: Boolean #Boolean
| I8 #i8
| I16 #i16
| I32 #i32
| I64 #i64
| FP32 #fp32
| FP64 #fp64
| String #string
| Binary #binary
| Timestamp #timestamp
| TimestampTZ #timestampTz
| Date #date
| Time #time
| IntervalYear #intervalYear
| UUID #uuid
| UserDefined Identifier #userDefined
;
parameterizedType
: FixedChar isnull='?'? Lt len=numericParameter Gt #fixedChar
| VarChar isnull='?'? Lt len=numericParameter Gt #varChar
| FixedBinary isnull='?'? Lt len=numericParameter Gt #fixedBinary
| Decimal isnull='?'? Lt precision=numericParameter Comma scale=numericParameter Gt #decimal
| IntervalDay isnull='?'? Lt precision=numericParameter Gt #intervalDay
| IntervalCompound isnull='?'? Lt precision=numericParameter Gt #intervalCompound
| PrecisionTimestamp isnull='?'? Lt precision=numericParameter Gt #precisionTimestamp
| PrecisionTimestampTZ isnull='?'? Lt precision=numericParameter Gt #precisionTimestampTZ
| Struct isnull='?'? Lt expr (Comma expr)* Gt #struct
| NStruct isnull='?'? Lt Identifier expr (Comma Identifier expr)* Gt #nStruct
| List isnull='?'? Lt expr Gt #list
| Map isnull='?'? Lt key=expr Comma value=expr Gt #map
;
numericParameter
: Number #numericLiteral
| Identifier #numericParameterName
| expr #numericExpression
;
anyType: ANY;
type
: scalarType isnull='?'?
| parameterizedType
| anyType isnull='?'?
;
// : (OParen innerExpr CParen | innerExpr)
expr
: OParen expr CParen #ParenExpression
| Identifier Eq expr Newline+ (Identifier Eq expr Newline+)* finalType=type Newline* #MultilineDefinition
| type #TypeLiteral
| number=Number #LiteralNumber
| identifier=Identifier isnull='?'? #TypeParam
| Identifier OParen (expr (Comma expr)*)? CParen #FunctionCall
| left=expr op=(And | Or | Plus | Minus | Lt | Gt | Eq | NotEquals | Lte | Gte | Asterisk | ForwardSlash) right=expr #BinaryExpr
| If ifExpr=expr Then thenExpr=expr Else elseExpr=expr #IfExpr
| (Bang) expr #NotExpr
| ifExpr=expr QMark thenExpr=expr Colon elseExpr=expr #Ternary
;
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -12,8 +12,9 @@ dynamic = ["version"]
write_to = "src/substrait/_version.py"

[project.optional-dependencies]
extensions = ["antlr4-python3-runtime"]
gen_proto = ["protobuf == 3.20.1", "protoletariat >= 2.0.0"]
test = ["pytest >= 7.0.0"]
test = ["pytest >= 7.0.0", "antlr4-python3-runtime"]

[tool.pytest.ini_options]
pythonpath = "src"
102 changes: 102 additions & 0 deletions src/substrait/derivation_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import Optional
from antlr4 import InputStream, CommonTokenStream
from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
from substrait.gen.proto.type_pb2 import Type


def _evaluate(x, values: dict):
if type(x) == SubstraitTypeParser.BinaryExprContext:
left = _evaluate(x.left, values)
right = _evaluate(x.right, values)

if x.op.text == "+":
return left + right
elif x.op.text == "-":
return left - right
elif x.op.text == "*":
return left * right
elif x.op.text == ">":
return left > right
elif x.op.text == ">=":
return left >= right
elif x.op.text == "<":
return left < right
elif x.op.text == "<=":
return left <= right
else:
raise Exception(f"Unknown binary op {x.op.text}")
elif type(x) == SubstraitTypeParser.LiteralNumberContext:
return int(x.number.text)
elif type(x) == SubstraitTypeParser.TypeParamContext:
return values[x.identifier.text]
elif type(x) == SubstraitTypeParser.NumericParameterNameContext:
return values[x.Identifier().symbol.text]
elif type(x) == SubstraitTypeParser.ParenExpressionContext:
return _evaluate(x.expr(), values)
elif type(x) == SubstraitTypeParser.FunctionCallContext:
exprs = [_evaluate(e, values) for e in x.expr()]
func = x.Identifier().symbol.text

if func == "min":
return min(*exprs)
elif func == "max":
return max(*exprs)
else:
raise Exception(f"Unknown function {func}")
elif type(x) == SubstraitTypeParser.TypeContext:
scalar_type = x.scalarType()
parametrized_type = x.parameterizedType()
if scalar_type:
if isinstance(scalar_type, SubstraitTypeParser.I8Context):
return Type(i8=Type.I8())
elif isinstance(scalar_type, SubstraitTypeParser.I16Context):
return Type(i16=Type.I16())
elif isinstance(scalar_type, SubstraitTypeParser.I32Context):
return Type(i32=Type.I32())
elif isinstance(scalar_type, SubstraitTypeParser.I64Context):
return Type(i64=Type.I64())
elif isinstance(scalar_type, SubstraitTypeParser.Fp32Context):
return Type(fp32=Type.FP32())
elif isinstance(scalar_type, SubstraitTypeParser.Fp64Context):
return Type(fp64=Type.FP64())
elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext):
return Type(bool=Type.Boolean())
else:
raise Exception(f"Unknown scalar type {type(scalar_type)}")
elif parametrized_type:
if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext):
precision = _evaluate(parametrized_type.precision, values)
scale = _evaluate(parametrized_type.scale, values)
return Type(decimal=Type.Decimal(precision=precision, scale=scale))
raise Exception(f"Unknown parametrized type {type(parametrized_type)}")
else:
raise Exception("either scalar_type or parametrized_type is required")
elif type(x) == SubstraitTypeParser.NumericExpressionContext:
return _evaluate(x.expr(), values)
elif type(x) == SubstraitTypeParser.TernaryContext:
ifExpr = _evaluate(x.ifExpr, values)
thenExpr = _evaluate(x.thenExpr, values)
elseExpr = _evaluate(x.elseExpr, values)

return thenExpr if ifExpr else elseExpr
elif type(x) == SubstraitTypeParser.MultilineDefinitionContext:
lines = zip(x.Identifier(), x.expr())

for i, e in lines:
identifier = i.symbol.text
expr_eval = _evaluate(e, values)
values[identifier] = expr_eval

return _evaluate(x.finalType, values)
elif type(x) == SubstraitTypeParser.TypeLiteralContext:
return _evaluate(x.type_(), values)
else:
raise Exception(f"Unknown token type {type(x)}")


def evaluate(x: str, values: Optional[dict] = None):
lexer = SubstraitTypeLexer(InputStream(x))
stream = CommonTokenStream(lexer)
parser = SubstraitTypeParser(stream)
return _evaluate(parser.expr(), values)
Loading