Skip to content

Commit

Permalink
Add option to enable updated accumulator variable.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707303232
  • Loading branch information
jnthntatum authored and copybara-github committed Dec 19, 2024
1 parent 47f3ddd commit f34ebbd
Show file tree
Hide file tree
Showing 7 changed files with 385 additions and 24 deletions.
13 changes: 13 additions & 0 deletions common/src/main/java/dev/cel/common/CelOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ public enum ProtoUnsetFieldOptions {

public abstract boolean retainUnbalancedLogicalExpressions();

public abstract boolean enableHiddenAccumulatorVar();

// Type-Checker related options

public abstract boolean enableCompileTimeOverloadResolution();
Expand Down Expand Up @@ -188,6 +190,7 @@ public static Builder newBuilder() {
.populateMacroCalls(false)
.retainRepeatedUnaryOperators(false)
.retainUnbalancedLogicalExpressions(false)
.enableHiddenAccumulatorVar(false)
// Type-Checker options
.enableCompileTimeOverloadResolution(false)
.enableHomogeneousLiterals(false)
Expand Down Expand Up @@ -319,6 +322,16 @@ public abstract static class Builder {
*/
public abstract Builder retainUnbalancedLogicalExpressions(boolean value);

/**
* Enable the use of a hidden accumulator variable name.
*
* <p>This is a temporary option to transition to using an internal identifier for the
* accumulator variable used by builtin comprehension macros. When enabled, parses result in a
* semantically equivalent AST, but with a different accumulator variable that can't be directly
* referenced in the source expression.
*/
public abstract Builder enableHiddenAccumulatorVar(boolean value);

// Type-Checker related options

/**
Expand Down
3 changes: 3 additions & 0 deletions parser/src/main/java/dev/cel/parser/CelMacroExprFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ public final CelExpr reportError(String message) {
/** Reports a {@link CelIssue} and returns a sentinel {@link CelExpr} that indicates an error. */
public abstract CelExpr reportError(CelIssue error);

/** Returns the default accumulator variable name used by macros implementing comprehensions. */
public abstract String getAccumulatorVarName();

/** Retrieves the source location for the given {@link CelExpr} ID. */
public final CelSourceLocation getSourceLocation(CelExpr expr) {
return getSourceLocation(expr.id());
Expand Down
66 changes: 44 additions & 22 deletions parser/src/main/java/dev/cel/parser/CelStandardMacro.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ public enum CelStandardMacro {
public static final ImmutableSet<CelStandardMacro> STANDARD_MACROS =
ImmutableSet.of(HAS, ALL, EXISTS, EXISTS_ONE, MAP, MAP_FILTER, FILTER);

private static final String ACCUMULATOR_VAR = "__result__";

private final CelMacro macro;

CelStandardMacro(CelMacro macro) {
Expand Down Expand Up @@ -123,14 +121,23 @@ private static Optional<CelExpr> expandAllMacro(
CelExpr accuInit = exprFactory.newBoolLiteral(true);
CelExpr condition =
exprFactory.newGlobalCall(
Operator.NOT_STRICTLY_FALSE.getFunction(), exprFactory.newIdentifier(ACCUMULATOR_VAR));
Operator.NOT_STRICTLY_FALSE.getFunction(),
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
CelExpr step =
exprFactory.newGlobalCall(
Operator.LOGICAL_AND.getFunction(), exprFactory.newIdentifier(ACCUMULATOR_VAR), arg1);
CelExpr result = exprFactory.newIdentifier(ACCUMULATOR_VAR);
Operator.LOGICAL_AND.getFunction(),
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
arg1);
CelExpr result = exprFactory.newIdentifier(exprFactory.getAccumulatorVarName());
return Optional.of(
exprFactory.fold(
arg0.ident().name(), target, ACCUMULATOR_VAR, accuInit, condition, step, result));
arg0.ident().name(),
target,
exprFactory.getAccumulatorVarName(),
accuInit,
condition,
step,
result));
}

// CelMacroExpander implementation for CEL's exists() macro.
Expand All @@ -149,14 +156,23 @@ private static Optional<CelExpr> expandExistsMacro(
exprFactory.newGlobalCall(
Operator.NOT_STRICTLY_FALSE.getFunction(),
exprFactory.newGlobalCall(
Operator.LOGICAL_NOT.getFunction(), exprFactory.newIdentifier(ACCUMULATOR_VAR)));
Operator.LOGICAL_NOT.getFunction(),
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
CelExpr step =
exprFactory.newGlobalCall(
Operator.LOGICAL_OR.getFunction(), exprFactory.newIdentifier(ACCUMULATOR_VAR), arg1);
CelExpr result = exprFactory.newIdentifier(ACCUMULATOR_VAR);
Operator.LOGICAL_OR.getFunction(),
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
arg1);
CelExpr result = exprFactory.newIdentifier(exprFactory.getAccumulatorVarName());
return Optional.of(
exprFactory.fold(
arg0.ident().name(), target, ACCUMULATOR_VAR, accuInit, condition, step, result));
arg0.ident().name(),
target,
exprFactory.getAccumulatorVarName(),
accuInit,
condition,
step,
result));
}

// CelMacroExpander implementation for CEL's exists_one() macro.
Expand All @@ -178,17 +194,23 @@ private static Optional<CelExpr> expandExistsOneMacro(
arg1,
exprFactory.newGlobalCall(
Operator.ADD.getFunction(),
exprFactory.newIdentifier(ACCUMULATOR_VAR),
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
exprFactory.newIntLiteral(1)),
exprFactory.newIdentifier(ACCUMULATOR_VAR));
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
CelExpr result =
exprFactory.newGlobalCall(
Operator.EQUALS.getFunction(),
exprFactory.newIdentifier(ACCUMULATOR_VAR),
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
exprFactory.newIntLiteral(1));
return Optional.of(
exprFactory.fold(
arg0.ident().name(), target, ACCUMULATOR_VAR, accuInit, condition, step, result));
arg0.ident().name(),
target,
exprFactory.getAccumulatorVarName(),
accuInit,
condition,
step,
result));
}

// CelMacroExpander implementation for CEL's map() macro.
Expand Down Expand Up @@ -218,25 +240,25 @@ private static Optional<CelExpr> expandMapMacro(
CelExpr step =
exprFactory.newGlobalCall(
Operator.ADD.getFunction(),
exprFactory.newIdentifier(ACCUMULATOR_VAR),
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
exprFactory.newList(arg1));
if (arg2 != null) {
step =
exprFactory.newGlobalCall(
Operator.CONDITIONAL.getFunction(),
arg2,
step,
exprFactory.newIdentifier(ACCUMULATOR_VAR));
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
}
return Optional.of(
exprFactory.fold(
arg0.ident().name(),
target,
ACCUMULATOR_VAR,
exprFactory.getAccumulatorVarName(),
accuInit,
condition,
step,
exprFactory.newIdentifier(ACCUMULATOR_VAR)));
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
}

// CelMacroExpander implementation for CEL's filter() macro.
Expand All @@ -255,23 +277,23 @@ private static Optional<CelExpr> expandFilterMacro(
CelExpr step =
exprFactory.newGlobalCall(
Operator.ADD.getFunction(),
exprFactory.newIdentifier(ACCUMULATOR_VAR),
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
exprFactory.newList(arg0));
step =
exprFactory.newGlobalCall(
Operator.CONDITIONAL.getFunction(),
arg1,
step,
exprFactory.newIdentifier(ACCUMULATOR_VAR));
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
return Optional.of(
exprFactory.fold(
arg0.ident().name(),
target,
ACCUMULATOR_VAR,
exprFactory.getAccumulatorVarName(),
accuInit,
condition,
step,
exprFactory.newIdentifier(ACCUMULATOR_VAR)));
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
}

private static CelExpr reportArgumentError(CelMacroExprFactory exprFactory, CelExpr argument) {
Expand Down
21 changes: 19 additions & 2 deletions parser/src/main/java/dev/cel/parser/Parser.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ final class Parser extends CELBaseVisitor<CelExpr> {
"var",
"void",
"while");
private static final String ACCUMULATOR_NAME = "__result__";
private static final String HIDDEN_ACCUMULATOR_NAME = "@result";

static CelValidationResult parse(CelParserImpl parser, CelSource source, CelOptions options) {
if (source.getContent().size() > options.maxExpressionCodePointSize()) {
Expand All @@ -142,7 +144,11 @@ static CelValidationResult parse(CelParserImpl parser, CelSource source, CelOpti
CELParser antlrParser = new CELParser(new CommonTokenStream(antlrLexer));
CelSource.Builder sourceInfo = source.toBuilder();
sourceInfo.setDescription(source.getDescription());
ExprFactory exprFactory = new ExprFactory(antlrParser, sourceInfo);
ExprFactory exprFactory =
new ExprFactory(
antlrParser,
sourceInfo,
options.enableHiddenAccumulatorVar() ? HIDDEN_ACCUMULATOR_NAME : ACCUMULATOR_NAME);
Parser parserImpl = new Parser(parser, options, sourceInfo, exprFactory);
ErrorListener errorListener = new ErrorListener(exprFactory);
antlrLexer.removeErrorListeners();
Expand Down Expand Up @@ -1033,12 +1039,17 @@ private static final class ExprFactory extends CelMacroExprFactory {
private final CelSource.Builder sourceInfo;
private final ArrayList<CelIssue> issues;
private final ArrayDeque<Integer> positions;
private final String accumulatorVarName;

private ExprFactory(org.antlr.v4.runtime.Parser recognizer, CelSource.Builder sourceInfo) {
private ExprFactory(
org.antlr.v4.runtime.Parser recognizer,
CelSource.Builder sourceInfo,
String accumulatorVarName) {
this.recognizer = recognizer;
this.sourceInfo = sourceInfo;
this.issues = new ArrayList<>();
this.positions = new ArrayDeque<>(1); // Currently this usually contains at most 1 position.
this.accumulatorVarName = accumulatorVarName;
}

// Implementation of CelExprFactory.
Expand All @@ -1062,6 +1073,11 @@ public CelExpr reportError(CelIssue error) {
return ERROR;
}

@Override
public String getAccumulatorVarName() {
return accumulatorVarName;
}

// Internal methods used by the parser but not part of the public API.
@FormatMethod
@CanIgnoreReturnValue
Expand All @@ -1079,6 +1095,7 @@ private CelExpr reportError(ParserRuleContext context, String message) {
private CelExpr reportError(Token token, String message) {
return reportError(CelIssue.formatError(getLocation(token), message));
}


// Implementation of CelExprFactory.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ public CelExpr reportError(CelIssue issue) {
return CelExpr.newBuilder().setId(nextExprId()).setConstant(Constants.ERROR).build();
}

@Override
public String getAccumulatorVarName() {
return "__result__";
}

@Override
protected CelSourceLocation getSourceLocation(long exprId) {
return CelSourceLocation.NONE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ public final class CelParserParameterizedTest extends BaselineTestCase {
.setOptions(CelOptions.current().populateMacroCalls(true).build())
.build();

private static final CelParser PARSER_WITH_UPDATED_ACCU_VAR =
PARSER
.toParserBuilder()
.setOptions(
CelOptions.current()
.populateMacroCalls(true)
.enableHiddenAccumulatorVar(true)
.build())
.build();

@Test
public void parser() {
runTest(PARSER, "x * 2");
Expand Down Expand Up @@ -193,6 +203,17 @@ public void parser() {
"while");
}

@Test
public void parser_updatedAccuVar() {
runTest(PARSER_WITH_UPDATED_ACCU_VAR, "x * 2");
runTest(PARSER_WITH_UPDATED_ACCU_VAR, "has(m.f)");
runTest(PARSER_WITH_UPDATED_ACCU_VAR, "m.exists_one(v, f)");
runTest(PARSER_WITH_UPDATED_ACCU_VAR, "m.all(v, f)");
runTest(PARSER_WITH_UPDATED_ACCU_VAR, "m.map(v, f)");
runTest(PARSER_WITH_UPDATED_ACCU_VAR, "m.map(v, p, f)");
runTest(PARSER_WITH_UPDATED_ACCU_VAR, "m.filter(v, p)");
}

@Test
public void parser_errors() {
runTest(PARSER, "*@a | b");
Expand Down
Loading

0 comments on commit f34ebbd

Please sign in to comment.