diff --git a/common/src/main/java/dev/cel/common/CelOptions.java b/common/src/main/java/dev/cel/common/CelOptions.java index 5c49936d..3e841471 100644 --- a/common/src/main/java/dev/cel/common/CelOptions.java +++ b/common/src/main/java/dev/cel/common/CelOptions.java @@ -67,6 +67,8 @@ public enum ProtoUnsetFieldOptions { public abstract boolean retainUnbalancedLogicalExpressions(); + public abstract boolean enableHiddenAccumulatorVar(); + // Type-Checker related options public abstract boolean enableCompileTimeOverloadResolution(); @@ -188,6 +190,7 @@ public static Builder newBuilder() { .populateMacroCalls(false) .retainRepeatedUnaryOperators(false) .retainUnbalancedLogicalExpressions(false) + .enableHiddenAccumulatorVar(false) // Type-Checker options .enableCompileTimeOverloadResolution(false) .enableHomogeneousLiterals(false) @@ -319,6 +322,16 @@ public abstract static class Builder { */ public abstract Builder retainUnbalancedLogicalExpressions(boolean value); + /** + * Enable the use of a hidden accumulator variable name. + * + *

This is a temporary option to transition to using an internal identifier for the + * accumulator variable used by builtin comprehension macros. When enabled, parses result in a + * semantically equivalent AST, but with a different accumulator variable that can't be directly + * referenced in the source expression. + */ + public abstract Builder enableHiddenAccumulatorVar(boolean value); + // Type-Checker related options /** diff --git a/parser/src/main/java/dev/cel/parser/CelMacroExprFactory.java b/parser/src/main/java/dev/cel/parser/CelMacroExprFactory.java index d7c917c2..2363bd51 100644 --- a/parser/src/main/java/dev/cel/parser/CelMacroExprFactory.java +++ b/parser/src/main/java/dev/cel/parser/CelMacroExprFactory.java @@ -51,6 +51,9 @@ public final CelExpr reportError(String message) { /** Reports a {@link CelIssue} and returns a sentinel {@link CelExpr} that indicates an error. */ public abstract CelExpr reportError(CelIssue error); + /** Returns the default accumulator variable name used by macros implementing comprehensions. */ + public abstract String getAccumulatorVarName(); + /** Retrieves the source location for the given {@link CelExpr} ID. */ public final CelSourceLocation getSourceLocation(CelExpr expr) { return getSourceLocation(expr.id()); diff --git a/parser/src/main/java/dev/cel/parser/CelStandardMacro.java b/parser/src/main/java/dev/cel/parser/CelStandardMacro.java index 3ea54f9e..0dce8b29 100644 --- a/parser/src/main/java/dev/cel/parser/CelStandardMacro.java +++ b/parser/src/main/java/dev/cel/parser/CelStandardMacro.java @@ -78,8 +78,6 @@ public enum CelStandardMacro { public static final ImmutableSet STANDARD_MACROS = ImmutableSet.of(HAS, ALL, EXISTS, EXISTS_ONE, MAP, MAP_FILTER, FILTER); - private static final String ACCUMULATOR_VAR = "__result__"; - private final CelMacro macro; CelStandardMacro(CelMacro macro) { @@ -123,14 +121,23 @@ private static Optional expandAllMacro( CelExpr accuInit = exprFactory.newBoolLiteral(true); CelExpr condition = exprFactory.newGlobalCall( - Operator.NOT_STRICTLY_FALSE.getFunction(), exprFactory.newIdentifier(ACCUMULATOR_VAR)); + Operator.NOT_STRICTLY_FALSE.getFunction(), + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())); CelExpr step = exprFactory.newGlobalCall( - Operator.LOGICAL_AND.getFunction(), exprFactory.newIdentifier(ACCUMULATOR_VAR), arg1); - CelExpr result = exprFactory.newIdentifier(ACCUMULATOR_VAR); + Operator.LOGICAL_AND.getFunction(), + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()), + arg1); + CelExpr result = exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()); return Optional.of( exprFactory.fold( - arg0.ident().name(), target, ACCUMULATOR_VAR, accuInit, condition, step, result)); + arg0.ident().name(), + target, + exprFactory.getAccumulatorVarName(), + accuInit, + condition, + step, + result)); } // CelMacroExpander implementation for CEL's exists() macro. @@ -149,14 +156,23 @@ private static Optional expandExistsMacro( exprFactory.newGlobalCall( Operator.NOT_STRICTLY_FALSE.getFunction(), exprFactory.newGlobalCall( - Operator.LOGICAL_NOT.getFunction(), exprFactory.newIdentifier(ACCUMULATOR_VAR))); + Operator.LOGICAL_NOT.getFunction(), + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()))); CelExpr step = exprFactory.newGlobalCall( - Operator.LOGICAL_OR.getFunction(), exprFactory.newIdentifier(ACCUMULATOR_VAR), arg1); - CelExpr result = exprFactory.newIdentifier(ACCUMULATOR_VAR); + Operator.LOGICAL_OR.getFunction(), + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()), + arg1); + CelExpr result = exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()); return Optional.of( exprFactory.fold( - arg0.ident().name(), target, ACCUMULATOR_VAR, accuInit, condition, step, result)); + arg0.ident().name(), + target, + exprFactory.getAccumulatorVarName(), + accuInit, + condition, + step, + result)); } // CelMacroExpander implementation for CEL's exists_one() macro. @@ -178,17 +194,23 @@ private static Optional expandExistsOneMacro( arg1, exprFactory.newGlobalCall( Operator.ADD.getFunction(), - exprFactory.newIdentifier(ACCUMULATOR_VAR), + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()), exprFactory.newIntLiteral(1)), - exprFactory.newIdentifier(ACCUMULATOR_VAR)); + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())); CelExpr result = exprFactory.newGlobalCall( Operator.EQUALS.getFunction(), - exprFactory.newIdentifier(ACCUMULATOR_VAR), + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()), exprFactory.newIntLiteral(1)); return Optional.of( exprFactory.fold( - arg0.ident().name(), target, ACCUMULATOR_VAR, accuInit, condition, step, result)); + arg0.ident().name(), + target, + exprFactory.getAccumulatorVarName(), + accuInit, + condition, + step, + result)); } // CelMacroExpander implementation for CEL's map() macro. @@ -218,7 +240,7 @@ private static Optional expandMapMacro( CelExpr step = exprFactory.newGlobalCall( Operator.ADD.getFunction(), - exprFactory.newIdentifier(ACCUMULATOR_VAR), + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()), exprFactory.newList(arg1)); if (arg2 != null) { step = @@ -226,17 +248,17 @@ private static Optional expandMapMacro( Operator.CONDITIONAL.getFunction(), arg2, step, - exprFactory.newIdentifier(ACCUMULATOR_VAR)); + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())); } return Optional.of( exprFactory.fold( arg0.ident().name(), target, - ACCUMULATOR_VAR, + exprFactory.getAccumulatorVarName(), accuInit, condition, step, - exprFactory.newIdentifier(ACCUMULATOR_VAR))); + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()))); } // CelMacroExpander implementation for CEL's filter() macro. @@ -255,23 +277,23 @@ private static Optional expandFilterMacro( CelExpr step = exprFactory.newGlobalCall( Operator.ADD.getFunction(), - exprFactory.newIdentifier(ACCUMULATOR_VAR), + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()), exprFactory.newList(arg0)); step = exprFactory.newGlobalCall( Operator.CONDITIONAL.getFunction(), arg1, step, - exprFactory.newIdentifier(ACCUMULATOR_VAR)); + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())); return Optional.of( exprFactory.fold( arg0.ident().name(), target, - ACCUMULATOR_VAR, + exprFactory.getAccumulatorVarName(), accuInit, condition, step, - exprFactory.newIdentifier(ACCUMULATOR_VAR))); + exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()))); } private static CelExpr reportArgumentError(CelMacroExprFactory exprFactory, CelExpr argument) { diff --git a/parser/src/main/java/dev/cel/parser/Parser.java b/parser/src/main/java/dev/cel/parser/Parser.java index 61caff01..52ba6db3 100644 --- a/parser/src/main/java/dev/cel/parser/Parser.java +++ b/parser/src/main/java/dev/cel/parser/Parser.java @@ -125,6 +125,8 @@ final class Parser extends CELBaseVisitor { "var", "void", "while"); + private static final String ACCUMULATOR_NAME = "__result__"; + private static final String HIDDEN_ACCUMULATOR_NAME = "@result"; static CelValidationResult parse(CelParserImpl parser, CelSource source, CelOptions options) { if (source.getContent().size() > options.maxExpressionCodePointSize()) { @@ -142,7 +144,11 @@ static CelValidationResult parse(CelParserImpl parser, CelSource source, CelOpti CELParser antlrParser = new CELParser(new CommonTokenStream(antlrLexer)); CelSource.Builder sourceInfo = source.toBuilder(); sourceInfo.setDescription(source.getDescription()); - ExprFactory exprFactory = new ExprFactory(antlrParser, sourceInfo); + ExprFactory exprFactory = + new ExprFactory( + antlrParser, + sourceInfo, + options.enableHiddenAccumulatorVar() ? HIDDEN_ACCUMULATOR_NAME : ACCUMULATOR_NAME); Parser parserImpl = new Parser(parser, options, sourceInfo, exprFactory); ErrorListener errorListener = new ErrorListener(exprFactory); antlrLexer.removeErrorListeners(); @@ -1033,12 +1039,17 @@ private static final class ExprFactory extends CelMacroExprFactory { private final CelSource.Builder sourceInfo; private final ArrayList issues; private final ArrayDeque positions; + private final String accumulatorVarName; - private ExprFactory(org.antlr.v4.runtime.Parser recognizer, CelSource.Builder sourceInfo) { + private ExprFactory( + org.antlr.v4.runtime.Parser recognizer, + CelSource.Builder sourceInfo, + String accumulatorVarName) { this.recognizer = recognizer; this.sourceInfo = sourceInfo; this.issues = new ArrayList<>(); this.positions = new ArrayDeque<>(1); // Currently this usually contains at most 1 position. + this.accumulatorVarName = accumulatorVarName; } // Implementation of CelExprFactory. @@ -1062,6 +1073,11 @@ public CelExpr reportError(CelIssue error) { return ERROR; } + @Override + public String getAccumulatorVarName() { + return accumulatorVarName; + } + // Internal methods used by the parser but not part of the public API. @FormatMethod @CanIgnoreReturnValue @@ -1079,6 +1095,7 @@ private CelExpr reportError(ParserRuleContext context, String message) { private CelExpr reportError(Token token, String message) { return reportError(CelIssue.formatError(getLocation(token), message)); } + // Implementation of CelExprFactory. diff --git a/parser/src/test/java/dev/cel/parser/CelMacroExprFactoryTest.java b/parser/src/test/java/dev/cel/parser/CelMacroExprFactoryTest.java index 4d110b79..daac056d 100644 --- a/parser/src/test/java/dev/cel/parser/CelMacroExprFactoryTest.java +++ b/parser/src/test/java/dev/cel/parser/CelMacroExprFactoryTest.java @@ -61,6 +61,11 @@ public CelExpr reportError(CelIssue issue) { return CelExpr.newBuilder().setId(nextExprId()).setConstant(Constants.ERROR).build(); } + @Override + public String getAccumulatorVarName() { + return "__result__"; + } + @Override protected CelSourceLocation getSourceLocation(long exprId) { return CelSourceLocation.NONE; diff --git a/parser/src/test/java/dev/cel/parser/CelParserParameterizedTest.java b/parser/src/test/java/dev/cel/parser/CelParserParameterizedTest.java index a091cc60..073b548b 100644 --- a/parser/src/test/java/dev/cel/parser/CelParserParameterizedTest.java +++ b/parser/src/test/java/dev/cel/parser/CelParserParameterizedTest.java @@ -72,6 +72,16 @@ public final class CelParserParameterizedTest extends BaselineTestCase { .setOptions(CelOptions.current().populateMacroCalls(true).build()) .build(); + private static final CelParser PARSER_WITH_UPDATED_ACCU_VAR = + PARSER + .toParserBuilder() + .setOptions( + CelOptions.current() + .populateMacroCalls(true) + .enableHiddenAccumulatorVar(true) + .build()) + .build(); + @Test public void parser() { runTest(PARSER, "x * 2"); @@ -193,6 +203,17 @@ public void parser() { "while"); } + @Test + public void parser_updatedAccuVar() { + runTest(PARSER_WITH_UPDATED_ACCU_VAR, "x * 2"); + runTest(PARSER_WITH_UPDATED_ACCU_VAR, "has(m.f)"); + runTest(PARSER_WITH_UPDATED_ACCU_VAR, "m.exists_one(v, f)"); + runTest(PARSER_WITH_UPDATED_ACCU_VAR, "m.all(v, f)"); + runTest(PARSER_WITH_UPDATED_ACCU_VAR, "m.map(v, f)"); + runTest(PARSER_WITH_UPDATED_ACCU_VAR, "m.map(v, p, f)"); + runTest(PARSER_WITH_UPDATED_ACCU_VAR, "m.filter(v, p)"); + } + @Test public void parser_errors() { runTest(PARSER, "*@a | b"); diff --git a/parser/src/test/resources/parser_updatedAccuVar.baseline b/parser/src/test/resources/parser_updatedAccuVar.baseline new file mode 100644 index 00000000..e2d198fc --- /dev/null +++ b/parser/src/test/resources/parser_updatedAccuVar.baseline @@ -0,0 +1,280 @@ +I: x * 2 +=====> +P: _*_( + x^#1:Expr.Ident#, + 2^#3:int64# +)^#2:Expr.Call# +L: _*_( + x^#1[1,0]#, + 2^#3[1,4]# +)^#2[1,2]# + +I: has(m.f) +=====> +P: m^#2:Expr.Ident#.f~test-only~^#4:Expr.Select# +L: m^#2[1,4]#.f~test-only~^#4[1,3]# +M: has( + m^#2:Expr.Ident#.f^#3:Expr.Select# +)^#0:Expr.Call# + +I: m.exists_one(v, f) +=====> +P: __comprehension__( + // Variable + v, + // Target + m^#1:Expr.Ident#, + // Accumulator + @result, + // Init + 0^#5:int64#, + // LoopCondition + true^#6:bool#, + // LoopStep + _?_:_( + f^#4:Expr.Ident#, + _+_( + @result^#7:Expr.Ident#, + 1^#8:int64# + )^#9:Expr.Call#, + @result^#10:Expr.Ident# + )^#11:Expr.Call#, + // Result + _==_( + @result^#12:Expr.Ident#, + 1^#13:int64# + )^#14:Expr.Call#)^#15:Expr.Comprehension# +L: __comprehension__( + // Variable + v, + // Target + m^#1[1,0]#, + // Accumulator + @result, + // Init + 0^#5[1,12]#, + // LoopCondition + true^#6[1,12]#, + // LoopStep + _?_:_( + f^#4[1,16]#, + _+_( + @result^#7[1,12]#, + 1^#8[1,12]# + )^#9[1,12]#, + @result^#10[1,12]# + )^#11[1,12]#, + // Result + _==_( + @result^#12[1,12]#, + 1^#13[1,12]# + )^#14[1,12]#)^#15[1,12]# +M: m^#1:Expr.Ident#.exists_one( + v^#3:Expr.Ident#, + f^#4:Expr.Ident# +)^#0:Expr.Call# + +I: m.all(v, f) +=====> +P: __comprehension__( + // Variable + v, + // Target + m^#1:Expr.Ident#, + // Accumulator + @result, + // Init + true^#5:bool#, + // LoopCondition + @not_strictly_false( + @result^#6:Expr.Ident# + )^#7:Expr.Call#, + // LoopStep + _&&_( + @result^#8:Expr.Ident#, + f^#4:Expr.Ident# + )^#9:Expr.Call#, + // Result + @result^#10:Expr.Ident#)^#11:Expr.Comprehension# +L: __comprehension__( + // Variable + v, + // Target + m^#1[1,0]#, + // Accumulator + @result, + // Init + true^#5[1,5]#, + // LoopCondition + @not_strictly_false( + @result^#6[1,5]# + )^#7[1,5]#, + // LoopStep + _&&_( + @result^#8[1,5]#, + f^#4[1,9]# + )^#9[1,5]#, + // Result + @result^#10[1,5]#)^#11[1,5]# +M: m^#1:Expr.Ident#.all( + v^#3:Expr.Ident#, + f^#4:Expr.Ident# +)^#0:Expr.Call# + +I: m.map(v, f) +=====> +P: __comprehension__( + // Variable + v, + // Target + m^#1:Expr.Ident#, + // Accumulator + @result, + // Init + []^#5:Expr.CreateList#, + // LoopCondition + true^#6:bool#, + // LoopStep + _+_( + @result^#7:Expr.Ident#, + [ + f^#4:Expr.Ident# + ]^#8:Expr.CreateList# + )^#9:Expr.Call#, + // Result + @result^#10:Expr.Ident#)^#11:Expr.Comprehension# +L: __comprehension__( + // Variable + v, + // Target + m^#1[1,0]#, + // Accumulator + @result, + // Init + []^#5[1,5]#, + // LoopCondition + true^#6[1,5]#, + // LoopStep + _+_( + @result^#7[1,5]#, + [ + f^#4[1,9]# + ]^#8[1,5]# + )^#9[1,5]#, + // Result + @result^#10[1,5]#)^#11[1,5]# +M: m^#1:Expr.Ident#.map( + v^#3:Expr.Ident#, + f^#4:Expr.Ident# +)^#0:Expr.Call# + +I: m.map(v, p, f) +=====> +P: __comprehension__( + // Variable + v, + // Target + m^#1:Expr.Ident#, + // Accumulator + @result, + // Init + []^#6:Expr.CreateList#, + // LoopCondition + true^#7:bool#, + // LoopStep + _?_:_( + p^#4:Expr.Ident#, + _+_( + @result^#8:Expr.Ident#, + [ + f^#5:Expr.Ident# + ]^#9:Expr.CreateList# + )^#10:Expr.Call#, + @result^#11:Expr.Ident# + )^#12:Expr.Call#, + // Result + @result^#13:Expr.Ident#)^#14:Expr.Comprehension# +L: __comprehension__( + // Variable + v, + // Target + m^#1[1,0]#, + // Accumulator + @result, + // Init + []^#6[1,5]#, + // LoopCondition + true^#7[1,5]#, + // LoopStep + _?_:_( + p^#4[1,9]#, + _+_( + @result^#8[1,5]#, + [ + f^#5[1,12]# + ]^#9[1,5]# + )^#10[1,5]#, + @result^#11[1,5]# + )^#12[1,5]#, + // Result + @result^#13[1,5]#)^#14[1,5]# +M: m^#1:Expr.Ident#.map( + v^#3:Expr.Ident#, + p^#4:Expr.Ident#, + f^#5:Expr.Ident# +)^#0:Expr.Call# + +I: m.filter(v, p) +=====> +P: __comprehension__( + // Variable + v, + // Target + m^#1:Expr.Ident#, + // Accumulator + @result, + // Init + []^#5:Expr.CreateList#, + // LoopCondition + true^#6:bool#, + // LoopStep + _?_:_( + p^#4:Expr.Ident#, + _+_( + @result^#7:Expr.Ident#, + [ + v^#3:Expr.Ident# + ]^#8:Expr.CreateList# + )^#9:Expr.Call#, + @result^#10:Expr.Ident# + )^#11:Expr.Call#, + // Result + @result^#12:Expr.Ident#)^#13:Expr.Comprehension# +L: __comprehension__( + // Variable + v, + // Target + m^#1[1,0]#, + // Accumulator + @result, + // Init + []^#5[1,8]#, + // LoopCondition + true^#6[1,8]#, + // LoopStep + _?_:_( + p^#4[1,12]#, + _+_( + @result^#7[1,8]#, + [ + v^#3[1,9]# + ]^#8[1,8]# + )^#9[1,8]#, + @result^#10[1,8]# + )^#11[1,8]#, + // Result + @result^#12[1,8]#)^#13[1,8]# +M: m^#1:Expr.Ident#.filter( + v^#3:Expr.Ident#, + p^#4:Expr.Ident# +)^#0:Expr.Call# \ No newline at end of file