From 8cb805a93048528ad1cecba046c69dc36ed67466 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 19 Sep 2023 14:19:52 -0700 Subject: [PATCH] Add MutableAst PiperOrigin-RevId: 566741837 --- .../main/java/dev/cel/common/ast/CelExpr.java | 125 ++++++- .../common/ast/CelExprIdGeneratorFactory.java | 11 +- .../java/dev/cel/common/ast/CelExprTest.java | 23 ++ optimizer/BUILD.bazel | 12 + .../main/java/dev/cel/optimizer/BUILD.bazel | 23 +- .../optimizer/CelOptimizationException.java | 23 ++ .../java/dev/cel/optimizer/CelOptimizer.java | 6 +- .../dev/cel/optimizer/CelOptimizerImpl.java | 17 +- .../java/dev/cel/optimizer/MutableAst.java | 165 +++++++++ .../test/java/dev/cel/optimizer/BUILD.bazel | 7 + .../cel/optimizer/CelOptimizerImplTest.java | 17 +- .../dev/cel/optimizer/MutableAstTest.java | 331 ++++++++++++++++++ .../src/main/java/dev/cel/parser/BUILD.bazel | 1 + .../dev/cel/parser/CelUnparserFactory.java | 24 ++ 14 files changed, 761 insertions(+), 24 deletions(-) create mode 100644 optimizer/src/main/java/dev/cel/optimizer/CelOptimizationException.java create mode 100644 optimizer/src/main/java/dev/cel/optimizer/MutableAst.java create mode 100644 optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java create mode 100644 parser/src/main/java/dev/cel/parser/CelUnparserFactory.java diff --git a/common/src/main/java/dev/cel/common/ast/CelExpr.java b/common/src/main/java/dev/cel/common/ast/CelExpr.java index 3e9f0bc1..82d3e42e 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExpr.java +++ b/common/src/main/java/dev/cel/common/ast/CelExpr.java @@ -221,6 +221,80 @@ public abstract static class Builder { public abstract Builder setExprKind(ExprKind value); + public abstract ExprKind exprKind(); + + /** + * Gets the underlying constant expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#CONSTANT}. + */ + public CelConstant constant() { + return exprKind().constant(); + } + + /** + * Gets the underlying identifier expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#IDENT}. + */ + public CelIdent ident() { + return exprKind().ident(); + } + + /** + * Gets the underlying select expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#SELECT}. + */ + public CelSelect select() { + return exprKind().select(); + } + + /** + * Gets the underlying call expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#CALL}. + */ + public CelCall call() { + return exprKind().call(); + } + + /** + * Gets the underlying createList expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#CREATE_LIST}. + */ + public CelCreateList createList() { + return exprKind().createList(); + } + + /** + * Gets the underlying createStruct expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#CREATE_STRUCT}. + */ + public CelCreateStruct createStruct() { + return exprKind().createStruct(); + } + + /** + * Gets the underlying createMap expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#createMap}. + */ + public CelCreateMap createMap() { + return exprKind().createMap(); + } + + /** + * Gets the underlying comprehension expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#COMPREHENSION}. + */ + public CelComprehension comprehension() { + return exprKind().comprehension(); + } + public Builder setConstant(CelConstant constant) { return setExprKind(AutoOneOf_CelExpr_ExprKind.constant(constant)); } @@ -373,6 +447,11 @@ public abstract static class CelSelect { /** Builder for CelSelect. */ @AutoValue.Builder public abstract static class Builder { + public abstract CelExpr operand(); + + public abstract String field(); + + public abstract boolean testOnly(); public abstract Builder setOperand(CelExpr value); @@ -418,9 +497,9 @@ public abstract static class CelCall { /** Builder for CelCall. */ @AutoValue.Builder public abstract static class Builder { - List mutableArgs = new ArrayList<>(); + private List mutableArgs = new ArrayList<>(); - abstract ImmutableList args(); + public abstract ImmutableList args(); public abstract Builder setTarget(CelExpr value); @@ -428,6 +507,8 @@ public abstract static class Builder { public abstract Builder setFunction(String value); + public abstract Optional target(); + // Not public. This only exists to make AutoValue.Builder work. abstract Builder setArgs(ImmutableList value); @@ -501,16 +582,23 @@ public abstract static class CelCreateList { /** Builder for CelCreateList. */ @AutoValue.Builder public abstract static class Builder { - List mutableElements = new ArrayList<>(); + private List mutableElements = new ArrayList<>(); + // Not public. This only exists to make AutoValue.Builder work. abstract ImmutableList elements(); + // Not public. This only exists to make AutoValue.Builder work. abstract ImmutableList.Builder optionalIndicesBuilder(); // Not public. This only exists to make AutoValue.Builder work. @CanIgnoreReturnValue abstract Builder setElements(ImmutableList elements); + /** Returns an immutable copy of the current mutable elements present in the builder. */ + public ImmutableList getElements() { + return ImmutableList.copyOf(mutableElements); + } + @CanIgnoreReturnValue public Builder setElement(int index, CelExpr element) { checkNotNull(element); @@ -586,8 +674,9 @@ public abstract static class CelCreateStruct { /** Builder for CelCreateStruct. */ @AutoValue.Builder public abstract static class Builder { - List mutableEntries = new ArrayList<>(); + private List mutableEntries = new ArrayList<>(); + // Not public. This only exists to make AutoValue.Builder work. abstract ImmutableList entries(); @CanIgnoreReturnValue @@ -597,6 +686,11 @@ public abstract static class Builder { @CanIgnoreReturnValue abstract Builder setEntries(ImmutableList entries); + /** Returns an immutable copy of the current mutable entries present in the builder. */ + public ImmutableList getEntries() { + return ImmutableList.copyOf(mutableEntries); + } + @CanIgnoreReturnValue public Builder setEntry(int index, CelCreateStruct.Entry entry) { checkNotNull(entry); @@ -669,6 +763,8 @@ public abstract static class Entry { @AutoValue.Builder public abstract static class Builder { + public abstract CelExpr value(); + public abstract Builder setId(long value); public abstract Builder setFieldKey(String value); @@ -704,14 +800,20 @@ public abstract static class CelCreateMap { @AutoValue.Builder public abstract static class Builder { - List mutableEntries = new ArrayList<>(); + private List mutableEntries = new ArrayList<>(); + // Not public. This only exists to make AutoValue.Builder work. abstract ImmutableList entries(); // Not public. This only exists to make AutoValue.Builder work. @CanIgnoreReturnValue abstract Builder setEntries(ImmutableList entries); + /** Returns an immutable copy of the current mutable entries present in the builder. */ + public ImmutableList getEntries() { + return ImmutableList.copyOf(mutableEntries); + } + @CanIgnoreReturnValue public Builder setEntry(int index, CelCreateMap.Entry entry) { checkNotNull(entry); @@ -784,6 +886,10 @@ public abstract static class Entry { @AutoValue.Builder public abstract static class Builder { + public abstract CelExpr key(); + + public abstract CelExpr value(); + public abstract CelCreateMap.Entry.Builder setId(long value); public abstract CelCreateMap.Entry.Builder setKey(CelExpr value); @@ -868,6 +974,15 @@ public abstract static class CelComprehension { /** Builder for Comprehension. */ @AutoValue.Builder public abstract static class Builder { + public abstract CelExpr iterRange(); + + public abstract CelExpr accuInit(); + + public abstract CelExpr loopCondition(); + + public abstract CelExpr loopStep(); + + public abstract CelExpr result(); public abstract Builder setIterVar(String value); diff --git a/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java b/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java index d7f312d2..493bdc2d 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java +++ b/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java @@ -18,10 +18,15 @@ import java.util.HashMap; /** Factory for populating expression IDs */ -final class CelExprIdGeneratorFactory { +public final class CelExprIdGeneratorFactory { - /** MonotonicIdGenerator increments expression IDs from an initial seed value. */ - static CelExprIdGenerator newMonotonicIdGenerator(long exprId) { + /** + * MonotonicIdGenerator increments expression IDs from an initial seed value. + * + * @param exprId Seed value. Must be non-negative. For example, if 1 is provided {@link + * CelExprIdGenerator#nextExprId} will return 2. + */ + public static CelExprIdGenerator newMonotonicIdGenerator(long exprId) { return new MonotonicIdGenerator(exprId); } diff --git a/common/src/test/java/dev/cel/common/ast/CelExprTest.java b/common/src/test/java/dev/cel/common/ast/CelExprTest.java index 42a40d7a..12c9e972 100644 --- a/common/src/test/java/dev/cel/common/ast/CelExprTest.java +++ b/common/src/test/java/dev/cel/common/ast/CelExprTest.java @@ -111,6 +111,7 @@ public void celExprBuilder_setConstant() { CelExpr celExpr = CelExpr.newBuilder().setConstant(celConstant).build(); assertThat(celExpr.constant()).isEqualTo(celConstant); + assertThat(celExpr.toBuilder().constant()).isEqualTo(celConstant); } @Test @@ -119,6 +120,7 @@ public void celExprBuilder_setIdent() { CelExpr celExpr = CelExpr.newBuilder().setIdent(celIdent).build(); assertThat(celExpr.ident()).isEqualTo(celIdent); + assertThat(celExpr.toBuilder().ident()).isEqualTo(celIdent); } @Test @@ -131,6 +133,7 @@ public void celExprBuilder_setCall() { CelExpr celExpr = CelExpr.newBuilder().setCall(celCall).build(); assertThat(celExpr.call()).isEqualTo(celCall); + assertThat(celExpr.toBuilder().call()).isEqualTo(celCall); } @Test @@ -144,6 +147,8 @@ public void celExprBuilder_setCall_clearTarget() { CelExpr.newBuilder().setCall(celCall.toBuilder().clearTarget().build()).build(); assertThat(celExpr.call()).isEqualTo(CelCall.newBuilder().setFunction("function").build()); + assertThat(celExpr.toBuilder().call()) + .isEqualTo(CelCall.newBuilder().setFunction("function").build()); } @Test @@ -182,6 +187,7 @@ public void celExprBuilder_setSelect() { assertThat(celExpr.select().testOnly()).isFalse(); assertThat(celExpr.select()).isEqualTo(celSelect); + assertThat(celExpr.toBuilder().select()).isEqualTo(celSelect); } @Test @@ -193,6 +199,7 @@ public void celExprBuilder_setCreateList() { CelExpr celExpr = CelExpr.newBuilder().setCreateList(celCreateList).build(); assertThat(celExpr.createList()).isEqualTo(celCreateList); + assertThat(celExpr.toBuilder().createList()).isEqualTo(celCreateList); } @Test @@ -236,6 +243,7 @@ public void celExprBuilder_setCreateStruct() { assertThat(celExpr.createStruct().entries().get(0).optionalEntry()).isFalse(); assertThat(celExpr.createStruct()).isEqualTo(celCreateStruct); + assertThat(celExpr.toBuilder().createStruct()).isEqualTo(celCreateStruct); } @Test @@ -309,6 +317,7 @@ public void celExprBuilder_setComprehension() { CelExpr celExpr = CelExpr.newBuilder().setComprehension(celComprehension).build(); assertThat(celExpr.comprehension()).isEqualTo(celComprehension); + assertThat(celExpr.toBuilder().comprehension()).isEqualTo(celComprehension); } @Test @@ -316,30 +325,44 @@ public void getUnderlyingExpression_unmatchedKind_throws( @TestParameter BuilderExprKindTestCase testCase) { if (!testCase.expectedExprKind.equals(Kind.NOT_SET)) { assertThrows(UnsupportedOperationException.class, () -> testCase.expr.exprKind().notSet()); + assertThrows( + UnsupportedOperationException.class, () -> testCase.expr.toBuilder().exprKind().notSet()); } if (!testCase.expectedExprKind.equals(Kind.CONSTANT)) { assertThrows(UnsupportedOperationException.class, testCase.expr::constant); + assertThrows(UnsupportedOperationException.class, () -> testCase.expr.toBuilder().constant()); } if (!testCase.expectedExprKind.equals(Kind.IDENT)) { assertThrows(UnsupportedOperationException.class, testCase.expr::ident); + assertThrows(UnsupportedOperationException.class, () -> testCase.expr.toBuilder().ident()); } if (!testCase.expectedExprKind.equals(Kind.SELECT)) { assertThrows(UnsupportedOperationException.class, testCase.expr::select); + assertThrows(UnsupportedOperationException.class, () -> testCase.expr.toBuilder().select()); } if (!testCase.expectedExprKind.equals(Kind.CALL)) { assertThrows(UnsupportedOperationException.class, testCase.expr::call); + assertThrows(UnsupportedOperationException.class, () -> testCase.expr.toBuilder().call()); } if (!testCase.expectedExprKind.equals(Kind.CREATE_LIST)) { assertThrows(UnsupportedOperationException.class, testCase.expr::createList); + assertThrows( + UnsupportedOperationException.class, () -> testCase.expr.toBuilder().createList()); } if (!testCase.expectedExprKind.equals(Kind.CREATE_STRUCT)) { assertThrows(UnsupportedOperationException.class, testCase.expr::createStruct); + assertThrows( + UnsupportedOperationException.class, () -> testCase.expr.toBuilder().createStruct()); } if (!testCase.expectedExprKind.equals(Kind.CREATE_MAP)) { assertThrows(UnsupportedOperationException.class, testCase.expr::createMap); + assertThrows( + UnsupportedOperationException.class, () -> testCase.expr.toBuilder().createMap()); } if (!testCase.expectedExprKind.equals(Kind.COMPREHENSION)) { assertThrows(UnsupportedOperationException.class, testCase.expr::comprehension); + assertThrows( + UnsupportedOperationException.class, () -> testCase.expr.toBuilder().comprehension()); } } diff --git a/optimizer/BUILD.bazel b/optimizer/BUILD.bazel index e5060d65..49cfa4ba 100644 --- a/optimizer/BUILD.bazel +++ b/optimizer/BUILD.bazel @@ -18,6 +18,18 @@ java_library( exports = ["//optimizer/src/main/java/dev/cel/optimizer:ast_optimizer"], ) +java_library( + name = "optimization_exception", + exports = ["//optimizer/src/main/java/dev/cel/optimizer:optimization_exception"], +) + +java_library( + name = "mutable_ast", + testonly = 1, + visibility = ["//optimizer/src/test/java/dev/cel/optimizer:__pkg__"], + exports = ["//optimizer/src/main/java/dev/cel/optimizer:mutable_ast"], +) + java_library( name = "optimizer_impl", testonly = 1, diff --git a/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel index 888efa8b..f99d21c8 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel @@ -34,8 +34,8 @@ java_library( ], deps = [ ":ast_optimizer", + ":optimization_exception", "//common", - "//common:compiler_common", "@maven//:com_google_errorprone_error_prone_annotations", ], ) @@ -49,6 +49,7 @@ java_library( ], deps = [ ":ast_optimizer", + ":optimization_exception", ":optimizer_builder", "//bundle:cel", "//common", @@ -69,3 +70,23 @@ java_library( "//common/navigation", ], ) + +java_library( + name = "mutable_ast", + srcs = ["MutableAst.java"], + tags = [ + ], + deps = [ + "//common/annotations", + "//common/ast", + "//common/ast:expr_factory", + "@maven//:com_google_guava_guava", + ], +) + +java_library( + name = "optimization_exception", + srcs = ["CelOptimizationException.java"], + tags = [ + ], +) diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelOptimizationException.java b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizationException.java new file mode 100644 index 00000000..baa79e61 --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizationException.java @@ -0,0 +1,23 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +/** Checked exception thrown by CelOptimizer during AST optimization. */ +public final class CelOptimizationException extends Exception { + + public CelOptimizationException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizer.java index 2faa18e3..a47b83e9 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/CelOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizer.java @@ -15,7 +15,6 @@ package dev.cel.optimizer; import dev.cel.common.CelAbstractSyntaxTree; -import dev.cel.common.CelValidationException; /** Public interface for optimizing an AST. */ public interface CelOptimizer { @@ -31,8 +30,7 @@ public interface CelOptimizer { * equal to the original expression. * * @param ast A type-checked AST. - * @throws CelValidationException If the optimized AST fails to type-check after a single - * optimization pass. + * @throws CelOptimizationException If any failures occur during any of the AST optimization pass. */ - CelAbstractSyntaxTree optimize(CelAbstractSyntaxTree ast) throws CelValidationException; + CelAbstractSyntaxTree optimize(CelAbstractSyntaxTree ast) throws CelOptimizationException; } diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerImpl.java b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerImpl.java index e5049011..8ce74e1a 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerImpl.java +++ b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerImpl.java @@ -33,16 +33,23 @@ final class CelOptimizerImpl implements CelOptimizer { } @Override - public CelAbstractSyntaxTree optimize(CelAbstractSyntaxTree ast) throws CelValidationException { + public CelAbstractSyntaxTree optimize(CelAbstractSyntaxTree ast) throws CelOptimizationException { if (!ast.isChecked()) { throw new IllegalArgumentException("AST must be type-checked."); } CelAbstractSyntaxTree optimizedAst = ast; - for (CelAstOptimizer optimizer : astOptimizers) { - CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); - optimizedAst = optimizer.optimize(navigableAst, cel); - optimizedAst = cel.check(optimizedAst).getAst(); + try { + for (CelAstOptimizer optimizer : astOptimizers) { + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + optimizedAst = optimizer.optimize(navigableAst, cel); + optimizedAst = cel.check(optimizedAst).getAst(); + } + } catch (CelValidationException e) { + throw new CelOptimizationException( + "Optimized AST failed to type-check: " + e.getMessage(), e); + } catch (RuntimeException e) { + throw new CelOptimizationException("Optimization failure: " + e.getMessage(), e); } return optimizedAst; diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java new file mode 100644 index 00000000..26ae0e0d --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java @@ -0,0 +1,165 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +import com.google.common.collect.ImmutableList; +import dev.cel.common.annotations.Internal; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelCall; +import dev.cel.common.ast.CelExpr.CelCreateList; +import dev.cel.common.ast.CelExpr.CelCreateMap; +import dev.cel.common.ast.CelExpr.CelCreateStruct; +import dev.cel.common.ast.CelExpr.CelSelect; +import dev.cel.common.ast.CelExprIdGenerator; +import dev.cel.common.ast.CelExprIdGeneratorFactory; + +/** MutableAst contains logic for mutating a {@link CelExpr}. */ +@Internal +final class MutableAst { + private static final int MAX_ITERATION_COUNT = 500; + private final CelExpr.Builder newExpr; + private final long exprIdToReplace; + private final CelExprIdGenerator celExprIdGenerator; + private int iterationCount; + + private MutableAst(CelExprIdGenerator celExprIdGenerator, CelExpr.Builder newExpr, long exprId) { + this.celExprIdGenerator = celExprIdGenerator; + this.newExpr = newExpr; + this.exprIdToReplace = exprId; + } + + /** + * Replaces a subtree in the given CelExpr. This is a very dangerous operation. Callers should + * re-typecheck the mutated AST and additionally verify that the resulting AST is semantically + * valid. + * + *

This method should remain package-private. + */ + static CelExpr replaceSubtree(CelExpr root, CelExpr newExpr, long exprIdToReplace) { + // Zero out the expr IDs in the new expression tree first. This ensures that no ID collision + // occurs while attempting to replace the subtree, potentially leading to infinite loop + CelExpr.Builder newExprBuilder = newExpr.toBuilder(); + MutableAst mutableAst = new MutableAst(() -> 0, CelExpr.newBuilder(), -1); + newExprBuilder = mutableAst.visit(newExprBuilder); + + // Replace the subtree + mutableAst = + new MutableAst( + CelExprIdGeneratorFactory.newMonotonicIdGenerator(0), newExprBuilder, exprIdToReplace); + + // TODO: Normalize IDs for macro calls + + return mutableAst.visit(root.toBuilder()).build(); + } + + private CelExpr.Builder visit(CelExpr.Builder expr) { + if (++iterationCount > MAX_ITERATION_COUNT) { + throw new IllegalStateException("Max iteration count reached."); + } + + if (expr.id() == exprIdToReplace) { + return visit(newExpr); + } + + switch (expr.exprKind().getKind()) { + case SELECT: + return visit(expr, expr.select().toBuilder()); + case CALL: + return visit(expr, expr.call().toBuilder()); + case CREATE_LIST: + return visit(expr, expr.createList().toBuilder()); + case CREATE_STRUCT: + return visit(expr, expr.createStruct().toBuilder()); + case CREATE_MAP: + return visit(expr, expr.createMap().toBuilder()); + case COMPREHENSION: + // TODO: Implement functionality. + throw new UnsupportedOperationException("Augmenting comprehensions is not supported yet."); + case CONSTANT: // Fall-through is intended. + case IDENT: + return expr.setId(celExprIdGenerator.nextExprId()); + default: + throw new IllegalArgumentException("unexpected expr kind: " + expr.exprKind().getKind()); + } + } + + private CelExpr.Builder visit(CelExpr.Builder celExpr, CelSelect.Builder selectExpr) { + CelExpr.Builder visitedOperand = visit(selectExpr.operand().toBuilder()); + selectExpr = selectExpr.setOperand(visitedOperand.build()); + + return celExpr.setSelect(selectExpr.build()).setId(celExprIdGenerator.nextExprId()); + } + + private CelExpr.Builder visit(CelExpr.Builder celExpr, CelCall.Builder callExpr) { + if (callExpr.target().isPresent()) { + CelExpr.Builder visitedTargetExpr = visit(callExpr.target().get().toBuilder()); + callExpr = callExpr.setTarget(visitedTargetExpr.build()); + + celExpr.setCall(callExpr.build()); + } + + ImmutableList args = callExpr.args(); + for (int i = 0; i < args.size(); i++) { + CelExpr arg = args.get(i); + CelExpr.Builder visitedArg = visit(arg.toBuilder()); + callExpr.setArg(i, visitedArg.build()); + } + + return celExpr.setCall(callExpr.build()).setId(celExprIdGenerator.nextExprId()); + } + + private CelExpr.Builder visit(CelExpr.Builder celExpr, CelCreateList.Builder createListBuilder) { + ImmutableList elements = createListBuilder.getElements(); + for (int i = 0; i < elements.size(); i++) { + CelExpr.Builder visitedElement = visit(elements.get(i).toBuilder()); + createListBuilder.setElement(i, visitedElement.build()); + } + + return celExpr.setCreateList(createListBuilder.build()).setId(celExprIdGenerator.nextExprId()); + } + + private CelExpr.Builder visit( + CelExpr.Builder celExpr, CelCreateStruct.Builder createStructBuilder) { + ImmutableList entries = createStructBuilder.getEntries(); + for (int i = 0; i < entries.size(); i++) { + CelCreateStruct.Entry.Builder entryBuilder = + entries.get(i).toBuilder().setId(celExprIdGenerator.nextExprId()); + CelExpr.Builder visitedValue = visit(entryBuilder.value().toBuilder()); + entryBuilder.setValue(visitedValue.build()); + + createStructBuilder.setEntry(i, entryBuilder.build()); + } + + return celExpr + .setCreateStruct(createStructBuilder.build()) + .setId(celExprIdGenerator.nextExprId()); + } + + private CelExpr.Builder visit(CelExpr.Builder celExpr, CelCreateMap.Builder createMapBuilder) { + ImmutableList entries = createMapBuilder.getEntries(); + for (int i = 0; i < entries.size(); i++) { + CelCreateMap.Entry.Builder entryBuilder = + entries.get(i).toBuilder().setId(celExprIdGenerator.nextExprId()); + CelExpr.Builder visitedKey = visit(entryBuilder.key().toBuilder()); + entryBuilder.setKey(visitedKey.build()); + CelExpr.Builder visitedValue = visit(entryBuilder.value().toBuilder()); + entryBuilder.setValue(visitedValue.build()); + + createMapBuilder.setEntry(i, entryBuilder.build()); + } + + return celExpr.setCreateMap(createMapBuilder.build()).setId(celExprIdGenerator.nextExprId()); + } +} diff --git a/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel index 99f5babd..075aaad2 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel +++ b/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel @@ -12,12 +12,19 @@ java_library( "//common", "//common:compiler_common", "//common/ast", + "//common/resources/testdata/proto3:test_all_types_java_proto", + "//common/types", "//compiler", "//optimizer", + "//optimizer:optimization_exception", "//optimizer:optimizer_builder", "//optimizer:optimizer_impl", + "//optimizer/src/main/java/dev/cel/optimizer:mutable_ast", "//parser", + "//parser:macro", + "//parser:unparser", "//runtime", + "@maven//:com_google_testparameterinjector_test_parameter_injector", "@maven//:junit_junit", ], ) diff --git a/optimizer/src/test/java/dev/cel/optimizer/CelOptimizerImplTest.java b/optimizer/src/test/java/dev/cel/optimizer/CelOptimizerImplTest.java index 903e726c..fdb6cb96 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/CelOptimizerImplTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/CelOptimizerImplTest.java @@ -87,11 +87,12 @@ public void optimizer_whenAstOptimizerThrows_throwsException() { }) .build(); - IllegalArgumentException e = + CelOptimizationException e = assertThrows( - IllegalArgumentException.class, + CelOptimizationException.class, () -> celOptimizer.optimize(CEL.compile("'hello world'").getAst())); - assertThat(e).hasMessageThat().isEqualTo("Test exception"); + assertThat(e).hasMessageThat().isEqualTo("Optimization failure: Test exception"); + assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); } @Test @@ -115,12 +116,16 @@ public void optimizedAst_failsToTypeCheck_throwsException() { CelExpr.ofIdentExpr(1, "undeclared_ident"), CelSource.newBuilder().build())) .build(); - CelValidationException e = + CelOptimizationException e = assertThrows( - CelValidationException.class, + CelOptimizationException.class, () -> celOptimizer.optimize(CEL.compile("'hello world'").getAst())); + assertThat(e) .hasMessageThat() - .contains("ERROR: :1:1: undeclared reference to 'undeclared_ident' (in container '')"); + .contains( + "Optimized AST failed to type-check: ERROR: :1:1: undeclared reference to" + + " 'undeclared_ident' (in container '')"); + assertThat(e).hasCauseThat().isInstanceOf(CelValidationException.class); } } diff --git a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java new file mode 100644 index 00000000..5cff92d0 --- /dev/null +++ b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java @@ -0,0 +1,331 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelFunctionDecl; +import dev.cel.common.CelOverloadDecl; +import dev.cel.common.CelSource; +import dev.cel.common.ast.CelConstant; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelIdent; +import dev.cel.common.ast.CelExpr.CelSelect; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.StructTypeReference; +import dev.cel.parser.CelStandardMacro; +import dev.cel.parser.CelUnparser; +import dev.cel.parser.CelUnparserFactory; +import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class MutableAstTest { + private static final Cel CEL = + CelFactory.standardCelBuilder() + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer("dev.cel.testing.testdata.proto3") + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .addVar("x", SimpleType.INT) + .build(); + + private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); + + @Test + public void constExpr() throws Exception { + CelExpr root = CEL.compile("10").getAst().getExpr(); + + CelExpr replacedExpr = + MutableAst.replaceSubtree( + root, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); + + assertThat(replacedExpr).isEqualTo(CelExpr.ofConstantExpr(1, CelConstant.ofValue(true))); + } + + @Test + public void globalCallExpr_replaceRoot() throws Exception { + // Tree shape (brackets are expr IDs): + // + [4] + // + [2] x [5] + // 1 [1] 2 [3] + CelExpr root = CEL.compile("1 + 2 + x").getAst().getExpr(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + root, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 4); + + assertThat(replacedRoot).isEqualTo(CelExpr.ofConstantExpr(1, CelConstant.ofValue(10))); + } + + @Test + public void globalCallExpr_replaceLeaf() throws Exception { + // Tree shape (brackets are expr IDs): + // + [4] + // + [2] x [5] + // 1 [1] 2 [3] + CelExpr root = CEL.compile("1 + 2 + x").getAst().getExpr(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + root, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 1); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10 + 2 + x"); + } + + @Test + public void globalCallExpr_replaceMiddleBranch() throws Exception { + // Tree shape (brackets are expr IDs): + // + [4] + // + [2] x [5] + // 1 [1] 2 [3] + CelExpr root = CEL.compile("1 + 2 + x").getAst().getExpr(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + root, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 2); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10 + x"); + } + + @Test + public void globalCallExpr_replaceMiddleBranch_withCallExpr() throws Exception { + // Tree shape (brackets are expr IDs): + // + [4] + // + [2] x [5] + // 1 [1] 2 [3] + CelExpr root = CEL.compile("1 + 2 + x").getAst().getExpr(); + CelExpr root2 = CEL.compile("4 + 5 + 6").getAst().getExpr(); + + CelExpr replacedRoot = MutableAst.replaceSubtree(root, root2, 2); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("4 + 5 + 6 + x"); + } + + @Test + public void memberCallExpr_replaceLeafTarget() throws Exception { + // Tree shape (brackets are expr IDs): + // func [2] + // 10 [1] func [4] + // 4 [3] 5 [5] + Cel cel = + CelFactory.standardCelBuilder() + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "func", + CelOverloadDecl.newMemberOverload( + "func_overload", SimpleType.INT, SimpleType.INT, SimpleType.INT))) + .build(); + CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 3); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10.func(20.func(5))"); + } + + @Test + public void memberCallExpr_replaceLeafArgument() throws Exception { + // Tree shape (brackets are expr IDs): + // func [2] + // 10 [1] func [4] + // 4 [3] 5 [5] + Cel cel = + CelFactory.standardCelBuilder() + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "func", + CelOverloadDecl.newMemberOverload( + "func_overload", SimpleType.INT, SimpleType.INT, SimpleType.INT))) + .build(); + CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 5); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10.func(4.func(20))"); + } + + @Test + public void memberCallExpr_replaceMiddleBranchTarget() throws Exception { + // Tree shape (brackets are expr IDs): + // func [2] + // 10 [1] func [4] + // 4 [3] 5 [5] + Cel cel = + CelFactory.standardCelBuilder() + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "func", + CelOverloadDecl.newMemberOverload( + "func_overload", SimpleType.INT, SimpleType.INT, SimpleType.INT))) + .build(); + CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 1); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("20.func(4.func(5))"); + } + + @Test + public void memberCallExpr_replaceMiddleBranchArgument() throws Exception { + // Tree shape (brackets are expr IDs): + // func [2] + // 10 [1] func [4] + // 4 [3] 5 [5] + Cel cel = + CelFactory.standardCelBuilder() + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "func", + CelOverloadDecl.newMemberOverload( + "func_overload", SimpleType.INT, SimpleType.INT, SimpleType.INT))) + .build(); + CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 4); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10.func(20)"); + } + + @Test + public void select_replaceField() throws Exception { + // Tree shape (brackets are expr IDs): + // + [2] + // 5 [1] select [4] + // msg [3] + CelAbstractSyntaxTree ast = CEL.compile("5 + msg.single_int64").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), + CelExpr.newBuilder() + .setSelect( + CelSelect.newBuilder() + .setField("single_sint32") + .setOperand( + CelExpr.newBuilder() + .setIdent(CelIdent.newBuilder().setName("test").build()) + .build()) + .build()) + .build(), + 4); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("5 + test.single_sint32"); + } + + @Test + public void select_replaceOperand() throws Exception { + // Tree shape (brackets are expr IDs): + // + [2] + // 5 [1] select [4] + // msg [3] + CelAbstractSyntaxTree ast = CEL.compile("5 + msg.single_int64").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), + CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName("test").build()).build(), + 3); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("5 + test.single_int64"); + } + + @Test + public void list_replaceElement() throws Exception { + // Tree shape (brackets are expr IDs): + // list [1] + // 2 [2] 3 [3] 4 [4] + CelAbstractSyntaxTree ast = CEL.compile("[2, 3, 4]").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 4); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("[2, 3, 5]"); + } + + @Test + public void createStruct_replaceValue() throws Exception { + // Tree shape (brackets are expr IDs): + // TestAllTypes [1] + // single_int64 [2] + // 2 [3] + CelAbstractSyntaxTree ast = CEL.compile("TestAllTypes{single_int64: 2}").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 3); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("TestAllTypes{single_int64: 5}"); + } + + @Test + public void createMap_replaceKey() throws Exception { + // Tree shape (brackets are expr IDs): + // map [1] + // map_entry [2] + // 'a' [3] : 1 [4] + CelAbstractSyntaxTree ast = CEL.compile("{'a': 1}").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 3); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("{5: 1}"); + } + + @Test + public void createMap_replaceValue() throws Exception { + // Tree shape (brackets are expr IDs): + // map [1] + // map_entry [2] + // 'a' [3] : 1 [4] + CelAbstractSyntaxTree ast = CEL.compile("{'a': 1}").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 4); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("{\"a\": 5}"); + } + + @Test + public void invalidCelExprKind_throwsException() { + assertThrows( + IllegalArgumentException.class, + () -> + MutableAst.replaceSubtree( + CelExpr.ofConstantExpr(1, CelConstant.ofValue("test")), CelExpr.ofNotSet(1), 1)); + } + + private static String getUnparsedExpression(CelExpr expr) { + CelAbstractSyntaxTree ast = + CelAbstractSyntaxTree.newParsedAst(expr, CelSource.newBuilder().build()); + return CEL_UNPARSER.unparse(ast); + } +} diff --git a/parser/src/main/java/dev/cel/parser/BUILD.bazel b/parser/src/main/java/dev/cel/parser/BUILD.bazel index d9d826c4..3def9c47 100644 --- a/parser/src/main/java/dev/cel/parser/BUILD.bazel +++ b/parser/src/main/java/dev/cel/parser/BUILD.bazel @@ -34,6 +34,7 @@ MACRO_SOURCES = [ # keep sorted UNPARSER_SOURCES = [ "CelUnparser.java", + "CelUnparserFactory.java", "CelUnparserImpl.java", ] diff --git a/parser/src/main/java/dev/cel/parser/CelUnparserFactory.java b/parser/src/main/java/dev/cel/parser/CelUnparserFactory.java new file mode 100644 index 00000000..e7a14bb8 --- /dev/null +++ b/parser/src/main/java/dev/cel/parser/CelUnparserFactory.java @@ -0,0 +1,24 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.parser; + +/** Factory class for producing {@link CelUnparser} instances and builders. */ +public final class CelUnparserFactory { + public static CelUnparser newUnparser() { + return new CelUnparserImpl(); + } + + private CelUnparserFactory() {} +}