From 3833e1426ff813c43fd0bca90a5bf38afa1428ef Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 22 Sep 2023 14:34:28 -0700 Subject: [PATCH] Support folding macros and lists PiperOrigin-RevId: 567719306 --- .../main/java/dev/cel/common/CelSource.java | 9 +- .../java/dev/cel/common/ast/CelConstant.java | 43 ++ .../main/java/dev/cel/common/ast/CelExpr.java | 35 +- .../dev/cel/common/ast/CelExprFactory.java | 38 +- .../cel/common/ast/CelExprIdGenerator.java | 10 +- .../common/ast/CelExprIdGeneratorFactory.java | 22 +- .../common/navigation/CelNavigableAst.java | 1 - .../common/navigation/CelNavigableExpr.java | 40 +- .../navigation/CelNavigableExprVisitor.java | 1 - .../dev/cel/common/ast/CelConstantTest.java | 24 + .../cel/common/ast/CelExprFactoryTest.java | 28 +- .../common/ast/CelExprIdGeneratorTest.java | 10 +- .../CelNavigableExprVisitorTest.java | 79 ++- optimizer/optimizers/BUILD.bazel | 9 + .../main/java/dev/cel/optimizer/BUILD.bazel | 5 + .../dev/cel/optimizer/CelAstOptimizer.java | 30 +- .../java/dev/cel/optimizer/MutableAst.java | 225 ++++--- .../dev/cel/optimizer/optimizers/BUILD.bazel | 28 + .../optimizers/ConstantFoldingOptimizer.java | 549 ++++++++++++++++++ .../test/java/dev/cel/optimizer/BUILD.bazel | 4 + .../dev/cel/optimizer/MutableAstTest.java | 236 ++++++-- .../dev/cel/optimizer/optimizers/BUILD.bazel | 35 ++ .../ConstantFoldingOptimizerTest.java | 301 ++++++++++ .../src/main/java/dev/cel/parser/Parser.java | 2 +- .../HomogeneousLiteralValidator.java | 2 +- .../validators/LiteralValidator.java | 2 +- .../validators/RegexLiteralValidator.java | 2 +- 27 files changed, 1510 insertions(+), 260 deletions(-) create mode 100644 optimizer/optimizers/BUILD.bazel create mode 100644 optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel create mode 100644 optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java create mode 100644 optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel create mode 100644 optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java diff --git a/common/src/main/java/dev/cel/common/CelSource.java b/common/src/main/java/dev/cel/common/CelSource.java index c7547084e..88ab19c65 100644 --- a/common/src/main/java/dev/cel/common/CelSource.java +++ b/common/src/main/java/dev/cel/common/CelSource.java @@ -176,7 +176,8 @@ private static LineAndOffset findLine(List lineOffsets, int offset) { public Builder toBuilder() { return new Builder(codePoints, lineOffsets) .setDescription(description) - .addPositionsMap(positions); + .addPositionsMap(positions) + .addAllMacroCalls(macroCalls); } public static Builder newBuilder() { @@ -270,6 +271,12 @@ public Builder addAllMacroCalls(Map macroCalls) { return this; } + @CanIgnoreReturnValue + public Builder clearMacroCall(long exprId) { + this.macroCalls.remove(exprId); + return this; + } + /** See {@link #getLocationOffset(int, int)}. */ public Optional getLocationOffset(CelSourceLocation location) { checkNotNull(location); diff --git a/common/src/main/java/dev/cel/common/ast/CelConstant.java b/common/src/main/java/dev/cel/common/ast/CelConstant.java index a21639fce..f981225e3 100644 --- a/common/src/main/java/dev/cel/common/ast/CelConstant.java +++ b/common/src/main/java/dev/cel/common/ast/CelConstant.java @@ -16,6 +16,7 @@ import com.google.auto.value.AutoOneOf; import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableSet; import com.google.common.primitives.UnsignedLong; import com.google.errorprone.annotations.Immutable; import com.google.protobuf.ByteString; @@ -33,6 +34,16 @@ @Internal @Immutable public abstract class CelConstant { + private static final ImmutableSet> CONSTANT_CLASSES = + ImmutableSet.of( + NullValue.class, + Boolean.class, + Long.class, + UnsignedLong.class, + Double.class, + String.class, + ByteString.class); + /** Represents the type of the Constant */ public enum Kind { NOT_SET, @@ -127,6 +138,38 @@ public static CelConstant ofValue(ByteString value) { return AutoOneOf_CelConstant.bytesValue(value); } + /** Checks whether the provided Java object is a valid CelConstant value. */ + public static boolean isConstantValue(Object value) { + return CONSTANT_CLASSES.contains(value.getClass()); + } + + /** + * Converts the given Java object into a CelConstant value. This is equivalent of calling {@link + * CelConstant#ofValue} with concrete types. + * + * @throws IllegalArgumentException If the value is not a supported CelConstant. This includes the + * deprecated duration and timestamp values. + */ + public static CelConstant ofObjectValue(Object value) { + if (value instanceof NullValue) { + return ofValue((NullValue) value); + } else if (value instanceof Boolean) { + return ofValue((boolean) value); + } else if (value instanceof Long) { + return ofValue((long) value); + } else if (value instanceof UnsignedLong) { + return ofValue((UnsignedLong) value); + } else if (value instanceof Double) { + return ofValue((double) value); + } else if (value instanceof String) { + return ofValue((String) value); + } else if (value instanceof ByteString) { + return ofValue((ByteString) value); + } + + throw new IllegalArgumentException("Value is not a CelConstant: " + value); + } + /** * @deprecated Do not use. Duration is no longer built-in CEL type. */ diff --git a/common/src/main/java/dev/cel/common/ast/CelExpr.java b/common/src/main/java/dev/cel/common/ast/CelExpr.java index 82d3e42eb..6bafad107 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExpr.java +++ b/common/src/main/java/dev/cel/common/ast/CelExpr.java @@ -15,6 +15,7 @@ package dev.cel.common.ast; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.collect.ImmutableList.toImmutableList; import com.google.auto.value.AutoOneOf; import com.google.auto.value.AutoValue; @@ -499,7 +500,8 @@ public abstract static class CelCall { public abstract static class Builder { private List mutableArgs = new ArrayList<>(); - public abstract ImmutableList args(); + // Not public. This only exists to make AutoValue.Builder work. + abstract ImmutableList args(); public abstract Builder setTarget(CelExpr value); @@ -512,6 +514,16 @@ public abstract static class Builder { // Not public. This only exists to make AutoValue.Builder work. abstract Builder setArgs(ImmutableList value); + /** Returns an immutable copy of the current mutable arguments present in the builder. */ + public ImmutableList getArgs() { + return ImmutableList.copyOf(mutableArgs); + } + + /** Returns an immutable copy of the builders from the current mutable arguments. */ + public ImmutableList getArgsBuilders() { + return mutableArgs.stream().map(CelExpr::toBuilder).collect(toImmutableList()); + } + public Builder setArg(int index, CelExpr arg) { checkNotNull(arg); mutableArgs.set(index, arg); @@ -599,6 +611,11 @@ public ImmutableList getElements() { return ImmutableList.copyOf(mutableElements); } + /** Returns an immutable copy of the builders from the current mutable elements. */ + public ImmutableList getElementsBuilders() { + return mutableElements.stream().map(CelExpr::toBuilder).collect(toImmutableList()); + } + @CanIgnoreReturnValue public Builder setElement(int index, CelExpr element) { checkNotNull(element); @@ -691,6 +708,13 @@ public ImmutableList getEntries() { return ImmutableList.copyOf(mutableEntries); } + /** Returns an immutable copy of the builders from the current mutable entries. */ + public ImmutableList getEntriesBuilders() { + return mutableEntries.stream() + .map(CelCreateStruct.Entry::toBuilder) + .collect(toImmutableList()); + } + @CanIgnoreReturnValue public Builder setEntry(int index, CelCreateStruct.Entry entry) { checkNotNull(entry); @@ -814,6 +838,13 @@ public ImmutableList getEntries() { return ImmutableList.copyOf(mutableEntries); } + /** Returns an immutable copy of the builders from the current mutable entries. */ + public ImmutableList getEntriesBuilders() { + return mutableEntries.stream() + .map(CelCreateMap.Entry::toBuilder) + .collect(toImmutableList()); + } + @CanIgnoreReturnValue public Builder setEntry(int index, CelCreateMap.Entry entry) { checkNotNull(entry); @@ -905,7 +936,7 @@ public abstract static class Builder { public abstract CelCreateMap.Entry.Builder toBuilder(); public static CelCreateMap.Entry.Builder newBuilder() { - return new AutoValue_CelExpr_CelCreateMap_Entry.Builder().setOptionalEntry(false); + return new AutoValue_CelExpr_CelCreateMap_Entry.Builder().setId(0).setOptionalEntry(false); } } } diff --git a/common/src/main/java/dev/cel/common/ast/CelExprFactory.java b/common/src/main/java/dev/cel/common/ast/CelExprFactory.java index 08f9b787b..b96ca0b85 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExprFactory.java +++ b/common/src/main/java/dev/cel/common/ast/CelExprFactory.java @@ -17,10 +17,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.isNullOrEmpty; -import com.google.common.base.Preconditions; import com.google.common.primitives.UnsignedLong; -import com.google.errorprone.annotations.CanIgnoreReturnValue; -import com.google.errorprone.annotations.CheckReturnValue; import com.google.protobuf.ByteString; import java.util.Arrays; @@ -28,31 +25,8 @@ public class CelExprFactory { private final CelExprIdGenerator idGenerator; - /** Builder for configuring {@link CelExprFactory}. */ - public static final class Builder { - - private CelExprIdGenerator exprIdGenerator; - - @CanIgnoreReturnValue - public Builder setIdGenerator(CelExprIdGenerator exprIdGenerator) { - this.exprIdGenerator = exprIdGenerator; - Preconditions.checkNotNull(exprIdGenerator); - return this; - } - - @CheckReturnValue - public CelExprFactory build() { - return new CelExprFactory(exprIdGenerator); - } - - private Builder() { - exprIdGenerator = CelExprIdGeneratorFactory.newMonotonicIdGenerator(0); - } - } - - /** Creates a new builder to configure CelExprFactory. */ - public static CelExprFactory.Builder newBuilder() { - return new Builder(); + public static CelExprFactory newInstance() { + return new CelExprFactory(); } /** Create a new constant expression. */ @@ -569,14 +543,10 @@ public final CelExpr newSelect(CelExpr operand, String field, boolean testOnly) /** Returns the next unique expression ID. */ protected long nextExprId() { - return idGenerator.nextExprId(); - } - - protected CelExprFactory(CelExprIdGenerator idGenerator) { - this.idGenerator = idGenerator; + return idGenerator.generate(0); } protected CelExprFactory() { - this(CelExprIdGeneratorFactory.newMonotonicIdGenerator(0)); + idGenerator = CelExprIdGeneratorFactory.newMonotonicIdGenerator(0); } } diff --git a/common/src/main/java/dev/cel/common/ast/CelExprIdGenerator.java b/common/src/main/java/dev/cel/common/ast/CelExprIdGenerator.java index 703b95d5c..234241243 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExprIdGenerator.java +++ b/common/src/main/java/dev/cel/common/ast/CelExprIdGenerator.java @@ -19,14 +19,10 @@ * renumbering existing expression IDs */ public interface CelExprIdGenerator { - /** Returns the next unique expression ID. */ - long nextExprId(); /** - * Renumber an existing expression ID, ensuring that new IDs are only created the first time they - * are encountered. + * Generates an expression ID. See {@code CelExprIdGeneratorFactory} for specifications on how an + * ID is generated. */ - default long renumberId(long id) { - throw new UnsupportedOperationException(); - } + long generate(long id); } diff --git a/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java b/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java index 493bdc2de..99c4403c8 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java +++ b/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java @@ -24,14 +24,19 @@ public final class CelExprIdGeneratorFactory { * MonotonicIdGenerator increments expression IDs from an initial seed value. * * @param exprId Seed value. Must be non-negative. For example, if 1 is provided {@link - * CelExprIdGenerator#nextExprId} will return 2. + * CelExprIdGenerator#generate} will return 2. */ public static CelExprIdGenerator newMonotonicIdGenerator(long exprId) { return new MonotonicIdGenerator(exprId); } - /** StableIdGenerator ensures new IDs are only created the first time they are encountered. */ - static CelExprIdGenerator newStableIdGenerator(long exprId) { + /** + * StableIdGenerator ensures new IDs are only created the first time they are encountered. + * + * @param exprId Seed value. Must be non-negative. For example, if 1 is provided {@link + * CelExprIdGenerator#generate} will return 2. + */ + public static CelExprIdGenerator newStableIdGenerator(long exprId) { return new StableIdGenerator(exprId); } @@ -39,7 +44,7 @@ private static class MonotonicIdGenerator implements CelExprIdGenerator { private long exprId; @Override - public long nextExprId() { + public long generate(long id) { return ++exprId; } @@ -54,12 +59,7 @@ private static class StableIdGenerator implements CelExprIdGenerator { private long exprId; @Override - public long nextExprId() { - return ++exprId; - } - - @Override - public long renumberId(long id) { + public long generate(long id) { Preconditions.checkArgument(id >= 0); if (id == 0) { return 0; @@ -69,7 +69,7 @@ public long renumberId(long id) { return idSet.get(id); } - long nextExprId = nextExprId(); + long nextExprId = ++exprId; idSet.put(id, nextExprId); return nextExprId; } diff --git a/common/src/main/java/dev/cel/common/navigation/CelNavigableAst.java b/common/src/main/java/dev/cel/common/navigation/CelNavigableAst.java index 9744c0264..3d185bea7 100644 --- a/common/src/main/java/dev/cel/common/navigation/CelNavigableAst.java +++ b/common/src/main/java/dev/cel/common/navigation/CelNavigableAst.java @@ -21,7 +21,6 @@ * node's children, descendants or its parent with ease. */ public final class CelNavigableAst { - private final CelAbstractSyntaxTree ast; private final CelNavigableExpr root; diff --git a/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java b/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java index 1e4411f2a..76fa5bfb0 100644 --- a/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java +++ b/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java @@ -55,36 +55,58 @@ public enum TraversalOrder { /** Represents the count of transitive parents. Depth of an AST's root is 0. */ public abstract int depth(); + /** Constructs a new instance of {@link CelNavigableExpr} from {@link CelExpr}. */ + public static CelNavigableExpr fromExpr(CelExpr expr) { + return CelNavigableExpr.builder().setExpr(expr).build(); + } + /** * Returns a stream of {@link CelNavigableExpr} collected from the current node down to the last * leaf-level member using post-order traversal. */ - public Stream descendants() { - return descendants(TraversalOrder.POST_ORDER); + public Stream allNodes() { + return allNodes(TraversalOrder.POST_ORDER); } /** * Returns a stream of {@link CelNavigableExpr} collected from the current node down to the last * leaf-level member using the specified traversal order. */ - public Stream descendants(TraversalOrder traversalOrder) { + public Stream allNodes(TraversalOrder traversalOrder) { return CelNavigableExprVisitor.collect(this, traversalOrder); } /** - * Returns a stream of {@link CelNavigableExpr} collected from the current node to its immediate - * children using post-order traversal. + * Returns a stream of {@link CelNavigableExpr} collected down to the last leaf-level member using + * post-order traversal. + */ + public Stream descendants() { + return descendants(TraversalOrder.POST_ORDER); + } + + /** + * Returns a stream of {@link CelNavigableExpr} collected down to the last leaf-level member using + * the specified traversal order. + */ + public Stream descendants(TraversalOrder traversalOrder) { + return CelNavigableExprVisitor.collect(this, traversalOrder).filter(node -> !node.equals(this)); + } + + /** + * Returns a stream of {@link CelNavigableExpr} collected from its immediate children using + * post-order traversal. */ public Stream children() { return children(TraversalOrder.POST_ORDER); } /** - * Returns a stream of {@link CelNavigableExpr} collected from the current node to its immediate - * children using the specified traversal order. + * Returns a stream of {@link CelNavigableExpr} collected from its immediate children using the + * specified traversal order. */ public Stream children(TraversalOrder traversalOrder) { - return CelNavigableExprVisitor.collect(this, 1, traversalOrder); + return CelNavigableExprVisitor.collect(this, this.depth() + 1, traversalOrder) + .filter(node -> !node.equals(this)); } /** Returns the underlying kind of the {@link CelExpr}. */ @@ -94,7 +116,7 @@ public ExprKind.Kind getKind() { /** Create a new builder to construct a {@link CelNavigableExpr} instance. */ public static Builder builder() { - return new AutoValue_CelNavigableExpr.Builder(); + return new AutoValue_CelNavigableExpr.Builder().setDepth(0); } /** Builder to configure {@link CelNavigableExpr}. */ diff --git a/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java b/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java index c265354a2..8abcf017f 100644 --- a/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java +++ b/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java @@ -27,7 +27,6 @@ /** Visitor implementation to navigate an AST. */ final class CelNavigableExprVisitor { - private static final int MAX_DESCENDANTS_RECURSION_DEPTH = 500; private final Stream.Builder streamBuilder; diff --git a/common/src/test/java/dev/cel/common/ast/CelConstantTest.java b/common/src/test/java/dev/cel/common/ast/CelConstantTest.java index abf09ac41..462267ba7 100644 --- a/common/src/test/java/dev/cel/common/ast/CelConstantTest.java +++ b/common/src/test/java/dev/cel/common/ast/CelConstantTest.java @@ -196,4 +196,28 @@ public void getValueOnInvalidKindCase_throwsException( assertThrows(UnsupportedOperationException.class, constant::timestampValue); } } + + @Test + public void getObjectValue_success() { + assertThat(CelConstant.ofObjectValue(NullValue.NULL_VALUE)) + .isEqualTo(CelConstant.ofValue(NullValue.NULL_VALUE)); + assertThat(CelConstant.ofObjectValue(true)).isEqualTo(CelConstant.ofValue(true)); + assertThat(CelConstant.ofObjectValue(2L)).isEqualTo(CelConstant.ofValue(2L)); + assertThat(CelConstant.ofObjectValue(UnsignedLong.valueOf(3L))) + .isEqualTo(CelConstant.ofValue(UnsignedLong.valueOf(3L))); + assertThat(CelConstant.ofObjectValue(3.0d)).isEqualTo(CelConstant.ofValue(3.0d)); + assertThat(CelConstant.ofObjectValue("test")).isEqualTo(CelConstant.ofValue("test")); + assertThat(CelConstant.ofObjectValue(ByteString.copyFromUtf8("hello"))) + .isEqualTo(CelConstant.ofValue(ByteString.copyFromUtf8("hello"))); + } + + @Test + public void getObjectValue_invalidParameter_throws() { + assertThrows( + IllegalArgumentException.class, + () -> CelConstant.ofObjectValue(Duration.getDefaultInstance())); + assertThrows( + IllegalArgumentException.class, + () -> CelConstant.ofObjectValue(Timestamp.getDefaultInstance())); + } } diff --git a/common/src/test/java/dev/cel/common/ast/CelExprFactoryTest.java b/common/src/test/java/dev/cel/common/ast/CelExprFactoryTest.java index 75bf7bef7..f3d77ea9e 100644 --- a/common/src/test/java/dev/cel/common/ast/CelExprFactoryTest.java +++ b/common/src/test/java/dev/cel/common/ast/CelExprFactoryTest.java @@ -25,40 +25,16 @@ public class CelExprFactoryTest { @Test public void construct_success() { - CelExprFactory exprFactory = CelExprFactory.newBuilder().build(); + CelExprFactory exprFactory = CelExprFactory.newInstance(); assertThat(exprFactory).isNotNull(); } @Test public void nextExprId_startingDefaultIsOne() { - CelExprFactory exprFactory = CelExprFactory.newBuilder().build(); + CelExprFactory exprFactory = CelExprFactory.newInstance(); assertThat(exprFactory.nextExprId()).isEqualTo(1L); assertThat(exprFactory.nextExprId()).isEqualTo(2L); } - - @Test - public void construct_witMonotonicIdGenerator_success() { - CelExprFactory exprFactory = - CelExprFactory.newBuilder() - .setIdGenerator(CelExprIdGeneratorFactory.newMonotonicIdGenerator(3L)) - .build(); - - assertThat(exprFactory).isNotNull(); - assertThat(exprFactory.nextExprId()).isEqualTo(4L); - assertThat(exprFactory.nextExprId()).isEqualTo(5L); - } - - @Test - public void construct_withStableIdGenerator_success() { - CelExprFactory exprFactory = - CelExprFactory.newBuilder() - .setIdGenerator(CelExprIdGeneratorFactory.newStableIdGenerator(3L)) - .build(); - - assertThat(exprFactory).isNotNull(); - assertThat(exprFactory.nextExprId()).isEqualTo(4L); - assertThat(exprFactory.nextExprId()).isEqualTo(5L); - } } diff --git a/common/src/test/java/dev/cel/common/ast/CelExprIdGeneratorTest.java b/common/src/test/java/dev/cel/common/ast/CelExprIdGeneratorTest.java index c7505fb17..c063e2ed7 100644 --- a/common/src/test/java/dev/cel/common/ast/CelExprIdGeneratorTest.java +++ b/common/src/test/java/dev/cel/common/ast/CelExprIdGeneratorTest.java @@ -40,10 +40,10 @@ public void newStableIdGenerator_throwsIfIdIsNegative() { public void stableIdGenerator_renumberId() { CelExprIdGenerator idGenerator = CelExprIdGeneratorFactory.newStableIdGenerator(0); - assertThat(idGenerator.renumberId(0)).isEqualTo(0); - assertThat(idGenerator.renumberId(2)).isEqualTo(1); - assertThat(idGenerator.renumberId(2)).isEqualTo(1); - assertThat(idGenerator.renumberId(1)).isEqualTo(2); - assertThat(idGenerator.renumberId(3)).isEqualTo(3); + assertThat(idGenerator.generate(0)).isEqualTo(0); + assertThat(idGenerator.generate(2)).isEqualTo(1); + assertThat(idGenerator.generate(2)).isEqualTo(1); + assertThat(idGenerator.generate(1)).isEqualTo(2); + assertThat(idGenerator.generate(3)).isEqualTo(3); } } diff --git a/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java b/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java index 5b83e64d2..a496a133a 100644 --- a/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java +++ b/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java @@ -72,7 +72,7 @@ public void collectWithMaxDepth_expectedNodeCountReturned(int maxDepth, int expe } @Test - public void add_descendants_allNodesReturned() throws Exception { + public void add_allNodes_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); // Tree shape: @@ -83,7 +83,7 @@ public void add_descendants_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants().map(CelNavigableExpr::expr).collect(toImmutableList()); + navigableAst.getRoot().allNodes().map(CelNavigableExpr::expr).collect(toImmutableList()); CelExpr childAddCall = CelExpr.ofCallExpr( @@ -109,7 +109,7 @@ public void add_descendants_allNodesReturned() throws Exception { } @Test - public void add_filterConstants_descendantsReturned() throws Exception { + public void add_filterConstants_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); // Tree shape: @@ -122,7 +122,7 @@ public void add_filterConstants_descendantsReturned() throws Exception { ImmutableList allConstants = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.CONSTANT)) .collect(toImmutableList()); @@ -147,7 +147,7 @@ public void add_filterConstants_parentsPopulated() throws Exception { ImmutableList allConstants = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.CONSTANT)) .collect(toImmutableList()); @@ -221,6 +221,33 @@ public void add_filterConstants_parentOfChildPopulated() throws Exception { assertThat(parentExpr.call().args().get(1).constant()).isEqualTo(CelConstant.ofValue(2)); } + @Test + public void add_childrenOfMiddleBranch_success() throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); + // Tree shape: + // + + // + 2 + // 1 a + CelAbstractSyntaxTree ast = compiler.compile("1 + a + 2").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + CelNavigableExpr ident = + navigableAst + .getRoot() + .allNodes() + .filter(node -> node.getKind().equals(Kind.IDENT)) // Find "a" + .findAny() + .get(); + + ImmutableList children = + ident.parent().get().children().collect(toImmutableList()); + + // Assert that the children of add call in the middle branch are const(1) and ident("a") + assertThat(children).hasSize(2); + assertThat(children.get(0).expr()).isEqualTo(CelExpr.ofConstantExpr(1, CelConstant.ofValue(1))); + assertThat(children.get(1)).isEqualTo(ident); + } + @Test public void stringFormatCall_filterList_success() throws Exception { CelCompiler compiler = @@ -239,7 +266,7 @@ public void stringFormatCall_filterList_success() throws Exception { ImmutableList allConstants = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.CREATE_LIST)) .collect(toImmutableList()); @@ -270,7 +297,7 @@ public void stringFormatCall_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList targetExprs = - navigableAst.getRoot().descendants().collect(toImmutableList()); + navigableAst.getRoot().allNodes().collect(toImmutableList()); assertThat(targetExprs).hasSize(5); } @@ -286,7 +313,7 @@ public void message_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants().map(CelNavigableExpr::expr).collect(toImmutableList()); + navigableAst.getRoot().allNodes().map(CelNavigableExpr::expr).collect(toImmutableList()); CelExpr operand = CelExpr.ofIdentExpr(1, "msg"); assertThat(allNodes) @@ -294,7 +321,7 @@ public void message_allNodesReturned() throws Exception { } @Test - public void nestedMessage_filterSelect_descendantsReturned() throws Exception { + public void nestedMessage_filterSelect_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder() .addMessageTypes(TestAllTypes.getDescriptor()) @@ -306,7 +333,7 @@ public void nestedMessage_filterSelect_descendantsReturned() throws Exception { ImmutableList allSelects = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.SELECT)) .map(CelNavigableExpr::expr) .collect(toImmutableList()); @@ -360,7 +387,7 @@ public void presenceTest_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants().map(CelNavigableExpr::expr).collect(toImmutableList()); + navigableAst.getRoot().allNodes().map(CelNavigableExpr::expr).collect(toImmutableList()); assertThat(allNodes).hasSize(2); assertThat(allNodes) @@ -380,7 +407,7 @@ public void messageConstruction_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants().map(CelNavigableExpr::expr).collect(toImmutableList()); + navigableAst.getRoot().allNodes().map(CelNavigableExpr::expr).collect(toImmutableList()); CelExpr constExpr = CelExpr.ofConstantExpr(3, CelConstant.ofValue(1)); assertThat(allNodes) @@ -394,7 +421,7 @@ public void messageConstruction_allNodesReturned() throws Exception { } @Test - public void messageConstruction_filterCreateStruct_descendantsReturned() throws Exception { + public void messageConstruction_filterCreateStruct_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder() .addMessageTypes(TestAllTypes.getDescriptor()) @@ -406,7 +433,7 @@ public void messageConstruction_filterCreateStruct_descendantsReturned() throws ImmutableList allNodes = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.CREATE_STRUCT)) .collect(toImmutableList()); @@ -431,7 +458,7 @@ public void mapConstruction_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants().map(CelNavigableExpr::expr).collect(toImmutableList()); + navigableAst.getRoot().allNodes().map(CelNavigableExpr::expr).collect(toImmutableList()); assertThat(allNodes).hasSize(3); CelExpr mapKeyExpr = CelExpr.ofConstantExpr(3, CelConstant.ofValue("key")); @@ -447,7 +474,7 @@ public void mapConstruction_allNodesReturned() throws Exception { } @Test - public void mapConstruction_filterCreateMap_descendantsReturned() throws Exception { + public void mapConstruction_filterCreateMap_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder().build(); CelAbstractSyntaxTree ast = compiler.compile("{'key': 2}").getAst(); CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); @@ -455,7 +482,7 @@ public void mapConstruction_filterCreateMap_descendantsReturned() throws Excepti ImmutableList allNodes = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.CREATE_MAP)) .collect(toImmutableList()); @@ -477,7 +504,7 @@ public void emptyMapConstruction_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants().collect(toImmutableList()); + navigableAst.getRoot().allNodes().collect(toImmutableList()); assertThat(allNodes).hasSize(1); assertThat(allNodes.get(0).expr()).isEqualTo(CelExpr.ofCreateMapExpr(1, ImmutableList.of())); @@ -495,7 +522,7 @@ public void comprehension_preOrder_allNodesReturned() throws Exception { ImmutableList allNodes = navigableAst .getRoot() - .descendants(TraversalOrder.PRE_ORDER) + .allNodes(TraversalOrder.PRE_ORDER) .map(CelNavigableExpr::expr) .collect(toImmutableList()); @@ -557,7 +584,7 @@ public void comprehension_postOrder_allNodesReturned() throws Exception { ImmutableList allNodes = navigableAst .getRoot() - .descendants(TraversalOrder.POST_ORDER) + .allNodes(TraversalOrder.POST_ORDER) .map(CelNavigableExpr::expr) .collect(toImmutableList()); @@ -617,7 +644,7 @@ public void comprehension_allNodes_parentsPopulated() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants(TraversalOrder.PRE_ORDER).collect(toImmutableList()); + navigableAst.getRoot().allNodes(TraversalOrder.PRE_ORDER).collect(toImmutableList()); CelExpr iterRangeConstExpr = CelExpr.ofConstantExpr(2, CelConstant.ofValue(true)); CelExpr iterRange = @@ -668,7 +695,7 @@ public void comprehension_allNodes_parentsPopulated() throws Exception { } @Test - public void comprehension_filterComprehension_descendantsReturned() throws Exception { + public void comprehension_filterComprehension_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder() .setStandardMacros(CelStandardMacro.EXISTS) @@ -679,7 +706,7 @@ public void comprehension_filterComprehension_descendantsReturned() throws Excep ImmutableList allNodes = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.COMPREHENSION)) .collect(toImmutableList()); @@ -736,7 +763,7 @@ public void callExpr_preOrder() throws Exception { ImmutableList allNodes = navigableAst .getRoot() - .descendants(TraversalOrder.PRE_ORDER) + .allNodes(TraversalOrder.PRE_ORDER) .map(CelNavigableExpr::expr) .collect(toImmutableList()); @@ -773,7 +800,7 @@ public void callExpr_postOrder() throws Exception { ImmutableList allNodes = navigableAst .getRoot() - .descendants(TraversalOrder.POST_ORDER) + .allNodes(TraversalOrder.POST_ORDER) .map(CelNavigableExpr::expr) .collect(toImmutableList()); @@ -805,7 +832,7 @@ public void maxRecursionLimitReached_throws() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); IllegalStateException e = - assertThrows(IllegalStateException.class, () -> navigableAst.getRoot().descendants()); + assertThrows(IllegalStateException.class, () -> navigableAst.getRoot().allNodes()); assertThat(e).hasMessageThat().contains("Max recursion depth reached."); } } diff --git a/optimizer/optimizers/BUILD.bazel b/optimizer/optimizers/BUILD.bazel new file mode 100644 index 000000000..f4095d8b3 --- /dev/null +++ b/optimizer/optimizers/BUILD.bazel @@ -0,0 +1,9 @@ +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//visibility:public"], # TODO: Expose to public +) + +java_library( + name = "constant_folding", + exports = ["//optimizer/src/main/java/dev/cel/optimizer/optimizers:constant_folding"], +) diff --git a/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel index f99d21c89..dd596aba5 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel @@ -65,8 +65,11 @@ java_library( tags = [ ], deps = [ + ":mutable_ast", + ":optimization_exception", "//bundle:cel", "//common", + "//common/ast", "//common/navigation", ], ) @@ -77,9 +80,11 @@ java_library( tags = [ ], deps = [ + "//common", "//common/annotations", "//common/ast", "//common/ast:expr_factory", + "//common/navigation", "@maven//:com_google_guava_guava", ], ) diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java index 8dc80f150..182cf0fbe 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java @@ -16,13 +16,37 @@ import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.ast.CelExpr; import dev.cel.common.navigation.CelNavigableAst; /** Public interface for performing a single, custom optimization on an AST. */ public interface CelAstOptimizer { + /** Optimizes a single AST. */ + CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) + throws CelOptimizationException; + /** - * Optimizes a single AST. - **/ - CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel); + * Replaces a subtree in the given AST. This operation is intended for AST optimization purposes. + * + *

This is a very dangerous operation. Callers should re-typecheck the mutated AST and + * additionally verify that the resulting AST is semantically valid. + * + *

All expression IDs will be renumbered in a stable manner to ensure there's no ID collision + * between the nodes. The renumbering occurs even if the subtree was not replaced. + * + * @param ast Original ast to mutate. + * @param newExpr New CelExpr to replace the subtree with. + * @param exprIdToReplace Expression id of the subtree that is getting replaced. + */ + default CelAbstractSyntaxTree replaceSubtree( + CelAbstractSyntaxTree ast, CelExpr newExpr, long exprIdToReplace) { + // Clear any macro metadata associated. + // CelSource newSource = ast.getSource().toBuilder().clearMacroCall(exprIdToReplace).build(); + // ast = + // CelAbstractSyntaxTree.newCheckedAst( + // ast.getExpr(), newSource, ast.getReferenceMap(), ast.getTypeMap()); + + return MutableAst.replaceSubtree(ast, newExpr, exprIdToReplace); + } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java index 26ae0e0df..6957ccaeb 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java +++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java @@ -14,25 +14,35 @@ package dev.cel.optimizer; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; + import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelSource; import dev.cel.common.annotations.Internal; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelCall; +import dev.cel.common.ast.CelExpr.CelComprehension; import dev.cel.common.ast.CelExpr.CelCreateList; import dev.cel.common.ast.CelExpr.CelCreateMap; import dev.cel.common.ast.CelExpr.CelCreateStruct; import dev.cel.common.ast.CelExpr.CelSelect; import dev.cel.common.ast.CelExprIdGenerator; import dev.cel.common.ast.CelExprIdGeneratorFactory; +import dev.cel.common.navigation.CelNavigableExpr; +import java.util.Map.Entry; +import java.util.NoSuchElementException; /** MutableAst contains logic for mutating a {@link CelExpr}. */ @Internal final class MutableAst { private static final int MAX_ITERATION_COUNT = 500; private final CelExpr.Builder newExpr; - private final long exprIdToReplace; private final CelExprIdGenerator celExprIdGenerator; private int iterationCount; + private long exprIdToReplace; private MutableAst(CelExprIdGenerator celExprIdGenerator, CelExpr.Builder newExpr, long exprId) { this.celExprIdGenerator = celExprIdGenerator; @@ -41,27 +51,108 @@ private MutableAst(CelExprIdGenerator celExprIdGenerator, CelExpr.Builder newExp } /** - * Replaces a subtree in the given CelExpr. This is a very dangerous operation. Callers should - * re-typecheck the mutated AST and additionally verify that the resulting AST is semantically - * valid. + * Replaces a subtree in the given CelExpr. * *

This method should remain package-private. */ - static CelExpr replaceSubtree(CelExpr root, CelExpr newExpr, long exprIdToReplace) { - // Zero out the expr IDs in the new expression tree first. This ensures that no ID collision + static CelAbstractSyntaxTree replaceSubtree( + CelAbstractSyntaxTree ast, CelExpr newExpr, long exprIdToReplace) { + // Update the IDs in the new expression tree first. This ensures that no ID collision // occurs while attempting to replace the subtree, potentially leading to infinite loop - CelExpr.Builder newExprBuilder = newExpr.toBuilder(); - MutableAst mutableAst = new MutableAst(() -> 0, CelExpr.newBuilder(), -1); - newExprBuilder = mutableAst.visit(newExprBuilder); + CelExpr.Builder newExprBuilder = + renumberExprIds( + CelExprIdGeneratorFactory.newMonotonicIdGenerator(getMaxId(ast.getExpr())), + newExpr.toBuilder()); + + CelExprIdGenerator stableIdGenerator = CelExprIdGeneratorFactory.newStableIdGenerator(0); + CelExpr.Builder mutatedRoot = + replaceSubtreeImpl( + stableIdGenerator, ast.getExpr().toBuilder(), newExprBuilder, exprIdToReplace); + + // If the source info contained macro call information, their IDs must be normalized. + CelSource normalizedSource = + normalizeMacroSource(ast.getSource(), exprIdToReplace, mutatedRoot, stableIdGenerator); + + return CelAbstractSyntaxTree.newParsedAst(mutatedRoot.build(), normalizedSource); + } + + private static CelSource normalizeMacroSource( + CelSource celSource, + long exprIdToReplace, + CelExpr.Builder mutatedRoot, + CelExprIdGenerator idGenerator) { + // Remove the macro metadata that no longer exists in the AST due to being replaced. + celSource = celSource.toBuilder().clearMacroCall(exprIdToReplace).build(); + if (celSource.getMacroCalls().isEmpty()) { + return CelSource.newBuilder().build(); + } + + CelSource.Builder sourceBuilder = CelSource.newBuilder(); + ImmutableMap allExprs = + CelNavigableExpr.fromExpr(mutatedRoot.build()) + .allNodes() + .map(CelNavigableExpr::expr) + .collect( + toImmutableMap( + CelExpr::id, + expr -> expr, + (expr1, expr2) -> { + // Comprehensions can reuse same expression (result). We just need to ensure + // that they are identical. + if (expr1.equals(expr2)) { + return expr1; + } + throw new IllegalStateException( + "Expected expressions to be the same for id: " + expr1.id()); + })); + + // Update the macro call IDs and their call references + for (Entry macroCall : celSource.getMacroCalls().entrySet()) { + long macroId = macroCall.getKey(); + long callId = idGenerator.generate(macroId); + + CelExpr.Builder newCall = renumberExprIds(idGenerator, macroCall.getValue().toBuilder()); + CelNavigableExpr callNav = CelNavigableExpr.fromExpr(newCall.build()); + ImmutableList callDescendants = + callNav.descendants().map(CelNavigableExpr::expr).collect(toImmutableList()); + + for (CelExpr callChild : callDescendants) { + if (!allExprs.containsKey(callChild.id())) { + continue; + } + CelExpr mutatedExpr = allExprs.get(callChild.id()); + if (!callChild.equals(mutatedExpr)) { + newCall = + replaceSubtreeImpl((arg) -> arg, newCall, mutatedExpr.toBuilder(), callChild.id()); + } + } + sourceBuilder.addMacroCalls(callId, newCall.build()); + } - // Replace the subtree - mutableAst = - new MutableAst( - CelExprIdGeneratorFactory.newMonotonicIdGenerator(0), newExprBuilder, exprIdToReplace); + return sourceBuilder.build(); + } + + private static CelExpr.Builder replaceSubtreeImpl( + CelExprIdGenerator idGenerator, + CelExpr.Builder root, + CelExpr.Builder newExpr, + long exprIdToReplace) { + MutableAst mutableAst = new MutableAst(idGenerator, newExpr, exprIdToReplace); + return mutableAst.visit(root); + } - // TODO: Normalize IDs for macro calls + private static CelExpr.Builder renumberExprIds( + CelExprIdGenerator idGenerator, CelExpr.Builder root) { + MutableAst mutableAst = new MutableAst(idGenerator, root, Integer.MIN_VALUE); + return mutableAst.visit(root); + } - return mutableAst.visit(root.toBuilder()).build(); + private static long getMaxId(CelExpr newExpr) { + return CelNavigableExpr.fromExpr(newExpr) + .allNodes() + .mapToLong(node -> node.expr().id()) + .max() + .orElseThrow(NoSuchElementException::new); } private CelExpr.Builder visit(CelExpr.Builder expr) { @@ -70,9 +161,12 @@ private CelExpr.Builder visit(CelExpr.Builder expr) { } if (expr.id() == exprIdToReplace) { - return visit(newExpr); + exprIdToReplace = Integer.MIN_VALUE; // Marks that the subtree has been replaced. + return visit(newExpr.setId(expr.id())); } + expr.setId(celExprIdGenerator.generate(expr.id())); + switch (expr.exprKind().getKind()) { case SELECT: return visit(expr, expr.select().toBuilder()); @@ -85,81 +179,76 @@ private CelExpr.Builder visit(CelExpr.Builder expr) { case CREATE_MAP: return visit(expr, expr.createMap().toBuilder()); case COMPREHENSION: - // TODO: Implement functionality. - throw new UnsupportedOperationException("Augmenting comprehensions is not supported yet."); - case CONSTANT: // Fall-through is intended. + return visit(expr, expr.comprehension().toBuilder()); + case CONSTANT: // Fall-through is intended case IDENT: - return expr.setId(celExprIdGenerator.nextExprId()); + case NOT_SET: // Note: comprehension arguments can contain a not set expr. + return expr; default: throw new IllegalArgumentException("unexpected expr kind: " + expr.exprKind().getKind()); } } - private CelExpr.Builder visit(CelExpr.Builder celExpr, CelSelect.Builder selectExpr) { - CelExpr.Builder visitedOperand = visit(selectExpr.operand().toBuilder()); - selectExpr = selectExpr.setOperand(visitedOperand.build()); - - return celExpr.setSelect(selectExpr.build()).setId(celExprIdGenerator.nextExprId()); + private CelExpr.Builder visit(CelExpr.Builder expr, CelSelect.Builder select) { + select.setOperand(visit(select.operand().toBuilder()).build()); + return expr.setSelect(select.build()); } - private CelExpr.Builder visit(CelExpr.Builder celExpr, CelCall.Builder callExpr) { - if (callExpr.target().isPresent()) { - CelExpr.Builder visitedTargetExpr = visit(callExpr.target().get().toBuilder()); - callExpr = callExpr.setTarget(visitedTargetExpr.build()); - - celExpr.setCall(callExpr.build()); + private CelExpr.Builder visit(CelExpr.Builder expr, CelCall.Builder call) { + if (call.target().isPresent()) { + call.setTarget(visit(call.target().get().toBuilder()).build()); } - - ImmutableList args = callExpr.args(); - for (int i = 0; i < args.size(); i++) { - CelExpr arg = args.get(i); - CelExpr.Builder visitedArg = visit(arg.toBuilder()); - callExpr.setArg(i, visitedArg.build()); + ImmutableList argsBuilders = call.getArgsBuilders(); + for (int i = 0; i < argsBuilders.size(); i++) { + CelExpr.Builder arg = argsBuilders.get(i); + call.setArg(i, visit(arg).build()); } - return celExpr.setCall(callExpr.build()).setId(celExprIdGenerator.nextExprId()); + return expr.setCall(call.build()); } - private CelExpr.Builder visit(CelExpr.Builder celExpr, CelCreateList.Builder createListBuilder) { - ImmutableList elements = createListBuilder.getElements(); - for (int i = 0; i < elements.size(); i++) { - CelExpr.Builder visitedElement = visit(elements.get(i).toBuilder()); - createListBuilder.setElement(i, visitedElement.build()); + private CelExpr.Builder visit(CelExpr.Builder expr, CelCreateStruct.Builder createStruct) { + ImmutableList entries = createStruct.getEntriesBuilders(); + for (int i = 0; i < entries.size(); i++) { + CelCreateStruct.Entry.Builder entry = entries.get(i); + entry.setValue(visit(entry.value().toBuilder()).build()); + + createStruct.setEntry(i, entry.build()); } - return celExpr.setCreateList(createListBuilder.build()).setId(celExprIdGenerator.nextExprId()); + return expr.setCreateStruct(createStruct.build()); } - private CelExpr.Builder visit( - CelExpr.Builder celExpr, CelCreateStruct.Builder createStructBuilder) { - ImmutableList entries = createStructBuilder.getEntries(); - for (int i = 0; i < entries.size(); i++) { - CelCreateStruct.Entry.Builder entryBuilder = - entries.get(i).toBuilder().setId(celExprIdGenerator.nextExprId()); - CelExpr.Builder visitedValue = visit(entryBuilder.value().toBuilder()); - entryBuilder.setValue(visitedValue.build()); + private CelExpr.Builder visit(CelExpr.Builder expr, CelCreateMap.Builder createMap) { + ImmutableList entriesBuilders = createMap.getEntriesBuilders(); + for (int i = 0; i < entriesBuilders.size(); i++) { + CelCreateMap.Entry.Builder entry = entriesBuilders.get(i); + entry.setKey(visit(entry.key().toBuilder()).build()); + entry.setValue(visit(entry.value().toBuilder()).build()); - createStructBuilder.setEntry(i, entryBuilder.build()); + createMap.setEntry(i, entry.build()); } - return celExpr - .setCreateStruct(createStructBuilder.build()) - .setId(celExprIdGenerator.nextExprId()); + return expr.setCreateMap(createMap.build()); } - private CelExpr.Builder visit(CelExpr.Builder celExpr, CelCreateMap.Builder createMapBuilder) { - ImmutableList entries = createMapBuilder.getEntries(); - for (int i = 0; i < entries.size(); i++) { - CelCreateMap.Entry.Builder entryBuilder = - entries.get(i).toBuilder().setId(celExprIdGenerator.nextExprId()); - CelExpr.Builder visitedKey = visit(entryBuilder.key().toBuilder()); - entryBuilder.setKey(visitedKey.build()); - CelExpr.Builder visitedValue = visit(entryBuilder.value().toBuilder()); - entryBuilder.setValue(visitedValue.build()); - - createMapBuilder.setEntry(i, entryBuilder.build()); + private CelExpr.Builder visit(CelExpr.Builder expr, CelCreateList.Builder createList) { + ImmutableList elementsBuilders = createList.getElementsBuilders(); + for (int i = 0; i < elementsBuilders.size(); i++) { + CelExpr.Builder elem = elementsBuilders.get(i); + createList.setElement(i, visit(elem).build()); } - return celExpr.setCreateMap(createMapBuilder.build()).setId(celExprIdGenerator.nextExprId()); + return expr.setCreateList(createList.build()); + } + + private CelExpr.Builder visit(CelExpr.Builder expr, CelComprehension.Builder comprehension) { + comprehension.setIterRange(visit(comprehension.iterRange().toBuilder()).build()); + comprehension.setAccuInit(visit(comprehension.accuInit().toBuilder()).build()); + comprehension.setLoopCondition(visit(comprehension.loopCondition().toBuilder()).build()); + comprehension.setLoopStep(visit(comprehension.loopStep().toBuilder()).build()); + comprehension.setResult(visit(comprehension.result().toBuilder()).build()); + + return expr.setComprehension(comprehension.build()); } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel new file mode 100644 index 000000000..f463604da --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -0,0 +1,28 @@ +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//optimizer/optimizers:__pkg__", + ], +) + +java_library( + name = "constant_folding", + srcs = [ + "ConstantFoldingOptimizer.java", + ], + tags = [ + ], + deps = [ + "//bundle:cel", + "//common", + "//common:compiler_common", + "//common/ast", + "//common/ast:expr_util", + "//common/navigation", + "//optimizer:ast_optimizer", + "//optimizer:optimization_exception", + "//parser:operator", + "//runtime", + "@maven//:com_google_guava_guava", + ], +) diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java new file mode 100644 index 000000000..83114d144 --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -0,0 +1,549 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package dev.cel.optimizer.optimizers; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.MoreCollectors.onlyElement; + +import com.google.common.collect.ImmutableList; +import dev.cel.bundle.Cel; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelValidationException; +import dev.cel.common.ast.CelConstant; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelCall; +import dev.cel.common.ast.CelExpr.CelCreateList; +import dev.cel.common.ast.CelExpr.CelCreateMap; +import dev.cel.common.ast.CelExpr.CelCreateStruct; +import dev.cel.common.ast.CelExpr.ExprKind.Kind; +import dev.cel.common.ast.CelExprUtil; +import dev.cel.common.navigation.CelNavigableAst; +import dev.cel.common.navigation.CelNavigableExpr; +import dev.cel.optimizer.CelAstOptimizer; +import dev.cel.optimizer.CelOptimizationException; +import dev.cel.parser.Operator; +import dev.cel.runtime.CelEvaluationException; +import java.util.Collection; +import java.util.HashSet; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; + +/** + * Performs optimization for inlining constant scalar and aggregate literal values within function + * calls and select statements with their evaluated result. + */ +public final class ConstantFoldingOptimizer implements CelAstOptimizer { + public static final ConstantFoldingOptimizer INSTANCE = new ConstantFoldingOptimizer(); + private static final int MAX_ITERATION_COUNT = 400; + + // Use optional.of and optional.none as sentinel function names for folding optional calls. + // TODO: Leverage CelValue representation of Optionals instead when available. + private static final String OPTIONAL_OF_FUNCTION = "optional.of"; + private static final String OPTIONAL_NONE_FUNCTION = "optional.none"; + private static final CelExpr OPTIONAL_NONE_EXPR = + CelExpr.ofCallExpr(0, Optional.empty(), OPTIONAL_NONE_FUNCTION, ImmutableList.of()); + + @Override + public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) + throws CelOptimizationException { + Set visitedExprs = new HashSet<>(); + int iterCount = 0; + while (true) { + iterCount++; + if (iterCount == MAX_ITERATION_COUNT) { + throw new IllegalStateException("Max iteration count reached."); + } + Optional foldableExpr = + navigableAst + .getRoot() + .allNodes() + .filter(ConstantFoldingOptimizer::canFold) + .map(CelNavigableExpr::expr) + .filter(expr -> !visitedExprs.contains(expr)) + .findAny(); + if (!foldableExpr.isPresent()) { + break; + } + visitedExprs.add(foldableExpr.get()); + + Optional mutatedAst; + // Attempt to prune if it is a non-strict call + mutatedAst = maybePruneBranches(navigableAst.getAst(), foldableExpr.get()); + if (!mutatedAst.isPresent()) { + // Evaluate the call then fold + mutatedAst = maybeFold(cel, navigableAst.getAst(), foldableExpr.get()); + } + + if (!mutatedAst.isPresent()) { + // Skip this expr. It's neither prune-able nor foldable. + continue; + } + + visitedExprs.clear(); + navigableAst = CelNavigableAst.fromAst(mutatedAst.get()); + } + + // If the output is a list, map, or struct which contains optional entries, then prune it + // to make sure that the optionals, if resolved, do not surface in the output literal. + navigableAst = CelNavigableAst.fromAst(pruneOptionalElements(navigableAst)); + + return navigableAst.getAst(); + } + + private static boolean canFold(CelNavigableExpr navigableExpr) { + switch (navigableExpr.getKind()) { + case CALL: + CelCall celCall = navigableExpr.expr().call(); + String functionName = celCall.function(); + + // These are already folded or do not need to be folded. + if (functionName.equals(OPTIONAL_OF_FUNCTION) + || functionName.equals(OPTIONAL_NONE_FUNCTION)) { + return false; + } + + // Check non-strict calls + if (functionName.equals(Operator.LOGICAL_AND.getFunction()) + || functionName.equals(Operator.LOGICAL_OR.getFunction())) { + + // If any element is a constant, this could be a foldable expr (e.g: x && false -> x) + return celCall.args().stream() + .anyMatch(node -> node.exprKind().getKind().equals(Kind.CONSTANT)); + } + + if (functionName.equals(Operator.CONDITIONAL.getFunction())) { + CelExpr cond = celCall.args().get(0); + + // A ternary with a constant condition is trivially foldable + return cond.constantOrDefault().getKind().equals(CelConstant.Kind.BOOLEAN_VALUE); + } + + if (functionName.equals(Operator.IN.getFunction())) { + return true; + } + + // Default case: all call arguments must be constants. If the argument is a container (ex: + // list, map), then its arguments must be a constant. + return areChildrenArgConstant(navigableExpr); + case SELECT: + CelNavigableExpr operand = navigableExpr.children().collect(onlyElement()); + return areChildrenArgConstant(operand); + case COMPREHENSION: + return !isNestedComprehension(navigableExpr); + default: + return false; + } + } + + private static boolean areChildrenArgConstant(CelNavigableExpr expr) { + if (expr.getKind().equals(Kind.CONSTANT)) { + return true; + } + + if (expr.getKind().equals(Kind.CALL) + || expr.getKind().equals(Kind.CREATE_LIST) + || expr.getKind().equals(Kind.CREATE_MAP) + || expr.getKind().equals(Kind.CREATE_STRUCT)) { + return expr.children().allMatch(ConstantFoldingOptimizer::areChildrenArgConstant); + } + + return false; + } + + private static boolean isNestedComprehension(CelNavigableExpr expr) { + Optional maybeParent = expr.parent(); + while (maybeParent.isPresent()) { + CelNavigableExpr parent = maybeParent.get(); + if (parent.getKind().equals(Kind.COMPREHENSION)) { + return true; + } + maybeParent = parent.parent(); + } + + return false; + } + + private Optional maybeFold( + Cel cel, CelAbstractSyntaxTree ast, CelExpr expr) throws CelOptimizationException { + Object result; + try { + result = CelExprUtil.evaluateExpr(cel, expr); + } catch (CelValidationException | CelEvaluationException e) { + throw new CelOptimizationException( + "Constant folding failure. Failed to evaluate subtree due to: " + e.getMessage(), e); + } + + // Rewrite optional calls to use the sentinel optional functions. + // ex1: optional.ofNonZeroValue(0) -> optional.none(). + // ex2: optional.ofNonZeroValue(5) -> optional.of(5) + if (result instanceof Optional) { + Optional optResult = ((Optional) result); + return maybeRewriteOptional(optResult, ast, expr); + } + + return maybeAdaptEvaluatedResult(result) + .map(celExpr -> replaceSubtree(ast, celExpr, expr.id())); + } + + private Optional maybeAdaptEvaluatedResult(Object result) { + if (CelConstant.isConstantValue(result)) { + return Optional.of( + CelExpr.newBuilder().setConstant(CelConstant.ofObjectValue(result)).build()); + } else if (result instanceof Collection) { + Collection collection = (Collection) result; + CelCreateList.Builder createListBuilder = CelCreateList.newBuilder(); + for (Object evaluatedElement : collection) { + Optional adaptedExpr = maybeAdaptEvaluatedResult(evaluatedElement); + if (!adaptedExpr.isPresent()) { + return Optional.empty(); + } + createListBuilder.addElements(adaptedExpr.get()); + } + + return Optional.of(CelExpr.newBuilder().setCreateList(createListBuilder.build()).build()); + } else if (result instanceof Map) { + Map map = (Map) result; + CelCreateMap.Builder createMapBuilder = CelCreateMap.newBuilder(); + for (Entry entry : map.entrySet()) { + Optional adaptedKey = maybeAdaptEvaluatedResult(entry.getKey()); + if (!adaptedKey.isPresent()) { + return Optional.empty(); + } + Optional adaptedValue = maybeAdaptEvaluatedResult(entry.getValue()); + if (!adaptedValue.isPresent()) { + return Optional.empty(); + } + + createMapBuilder.addEntries( + CelCreateMap.Entry.newBuilder() + .setKey(adaptedKey.get()) + .setValue(adaptedValue.get()) + .build()); + } + + return Optional.of(CelExpr.newBuilder().setCreateMap(createMapBuilder.build()).build()); + } + + // Evaluated result cannot be folded (e.g: unknowns) + return Optional.empty(); + } + + private Optional maybeRewriteOptional( + Optional optResult, CelAbstractSyntaxTree ast, CelExpr expr) { + if (!optResult.isPresent()) { + if (!expr.callOrDefault().function().equals(OPTIONAL_NONE_FUNCTION)) { + // An empty optional value was encountered. Rewrite the tree with optional.none call. + // This is to account for other optional functions returning an empty optional value + // e.g: optional.ofNonZeroValue(0) + return Optional.of(replaceSubtree(ast, OPTIONAL_NONE_EXPR, expr.id())); + } + } else if (!expr.callOrDefault().function().equals(OPTIONAL_OF_FUNCTION)) { + Object unwrappedResult = optResult.get(); + if (!CelConstant.isConstantValue(unwrappedResult)) { + // Evaluated result is not a constant. Leave the optional as is. + return Optional.empty(); + } + + CelExpr newOptionalOfCall = + CelExpr.newBuilder() + .setCall( + CelCall.newBuilder() + .setFunction(OPTIONAL_OF_FUNCTION) + .addArgs( + CelExpr.newBuilder() + .setConstant(CelConstant.ofObjectValue(unwrappedResult)) + .build()) + .build()) + .build(); + return Optional.of(replaceSubtree(ast, newOptionalOfCall, expr.id())); + } + + return Optional.empty(); + } + + /** Inspects the non-strict calls to determine whether a branch can be removed. */ + private Optional maybePruneBranches( + CelAbstractSyntaxTree ast, CelExpr expr) { + if (!expr.exprKind().getKind().equals(Kind.CALL)) { + return Optional.empty(); + } + + CelCall call = expr.call(); + String function = call.function(); + if (function.equals(Operator.LOGICAL_AND.getFunction()) + || function.equals(Operator.LOGICAL_OR.getFunction())) { + return maybeShortCircuitCall(ast, expr); + } else if (function.equals(Operator.CONDITIONAL.getFunction())) { + CelExpr cond = call.args().get(0); + CelExpr truthy = call.args().get(1); + CelExpr falsy = call.args().get(2); + if (!cond.exprKind().getKind().equals(Kind.CONSTANT)) { + throw new IllegalStateException( + String.format( + "Expected constant condition. Got: %s instead.", cond.exprKind().getKind())); + } + CelExpr result = cond.constant().booleanValue() ? truthy : falsy; + + return Optional.of(replaceSubtree(ast, result, expr.id())); + } else if (function.equals(Operator.IN.getFunction())) { + CelCreateList haystack = call.args().get(1).createList(); + if (haystack.elements().isEmpty()) { + return Optional.of( + replaceSubtree( + ast, + CelExpr.newBuilder().setConstant(CelConstant.ofValue(false)).build(), + expr.id())); + } + + CelExpr needle = call.args().get(0); + if (needle.exprKind().getKind().equals(Kind.CONSTANT) + || needle.exprKind().getKind().equals(Kind.IDENT)) { + Object needleValue = + needle.exprKind().getKind().equals(Kind.CONSTANT) ? needle.constant() : needle.ident(); + for (CelExpr elem : haystack.elements()) { + if (elem.constantOrDefault().equals(needleValue) + || elem.identOrDefault().equals(needleValue)) { + return Optional.of( + replaceSubtree( + ast, + CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), + expr.id())); + } + } + } + } + + return Optional.empty(); + } + + private Optional maybeShortCircuitCall( + CelAbstractSyntaxTree ast, CelExpr expr) { + CelCall call = expr.call(); + boolean shortCircuit = false; + boolean skip = true; + if (call.function().equals(Operator.LOGICAL_OR.getFunction())) { + shortCircuit = true; + skip = false; + } + ImmutableList.Builder newArgsBuilder = new ImmutableList.Builder<>(); + + for (CelExpr arg : call.args()) { + if (!arg.exprKind().getKind().equals(Kind.CONSTANT)) { + newArgsBuilder.add(arg); + continue; + } + if (arg.constant().booleanValue() == skip) { + continue; + } + + if (arg.constant().booleanValue() == shortCircuit) { + return Optional.of(replaceSubtree(ast, arg, expr.id())); + } + } + + ImmutableList newArgs = newArgsBuilder.build(); + if (newArgs.isEmpty()) { + return Optional.of(replaceSubtree(ast, call.args().get(0), expr.id())); + } + if (newArgs.size() == 1) { + return Optional.of(replaceSubtree(ast, newArgs.get(0), expr.id())); + } + + // TODO: Support folding variadic AND/ORs. + throw new UnsupportedOperationException( + "Folding variadic logical operator is not supported yet."); + } + + private CelAbstractSyntaxTree pruneOptionalElements(CelNavigableAst navigableAst) { + ImmutableList aggregateLiterals = + navigableAst + .getRoot() + .allNodes() + .filter( + node -> + node.getKind().equals(Kind.CREATE_LIST) + || node.getKind().equals(Kind.CREATE_MAP) + || node.getKind().equals(Kind.CREATE_STRUCT)) + .map(CelNavigableExpr::expr) + .collect(toImmutableList()); + + CelAbstractSyntaxTree ast = navigableAst.getAst(); + for (CelExpr expr : aggregateLiterals) { + switch (expr.exprKind().getKind()) { + case CREATE_LIST: + ast = pruneOptionalListElements(ast, expr); + break; + case CREATE_MAP: + ast = pruneOptionalMapElements(ast, expr); + break; + case CREATE_STRUCT: + ast = pruneOptionalStructElements(ast, expr); + break; + default: + throw new IllegalArgumentException("Unexpected exprKind: " + expr.exprKind()); + } + } + return ast; + } + + private CelAbstractSyntaxTree pruneOptionalListElements(CelAbstractSyntaxTree ast, CelExpr expr) { + CelCreateList createList = expr.createList(); + if (createList.optionalIndices().isEmpty()) { + return ast; + } + + HashSet optionalIndices = new HashSet<>(createList.optionalIndices()); + ImmutableList.Builder updatedElemBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder updatedIndicesBuilder = new ImmutableList.Builder<>(); + int newOptIndex = -1; + for (int i = 0; i < createList.elements().size(); i++) { + newOptIndex++; + CelExpr element = createList.elements().get(i); + if (!optionalIndices.contains(i)) { + updatedElemBuilder.add(element); + continue; + } + + if (element.exprKind().getKind().equals(Kind.CALL)) { + CelCall call = element.call(); + if (call.function().equals(OPTIONAL_NONE_FUNCTION)) { + // Skip optional.none. + // Skipping causes the list to get smaller. + newOptIndex--; + continue; + } else if (call.function().equals(OPTIONAL_OF_FUNCTION)) { + CelExpr arg = call.args().get(0); + if (arg.exprKind().getKind().equals(Kind.CONSTANT)) { + updatedElemBuilder.add(call.args().get(0)); + continue; + } + } + } + + updatedElemBuilder.add(element); + updatedIndicesBuilder.add(newOptIndex); + } + + return replaceSubtree( + ast, + CelExpr.newBuilder() + .setCreateList( + CelCreateList.newBuilder() + .addElements(updatedElemBuilder.build()) + .addOptionalIndices(updatedIndicesBuilder.build()) + .build()) + .build(), + expr.id()); + } + + private CelAbstractSyntaxTree pruneOptionalMapElements(CelAbstractSyntaxTree ast, CelExpr expr) { + CelCreateMap createMap = expr.createMap(); + ImmutableList.Builder updatedEntryBuilder = new ImmutableList.Builder<>(); + boolean modified = false; + for (CelCreateMap.Entry entry : createMap.entries()) { + CelExpr key = entry.key(); + Kind keyKind = key.exprKind().getKind(); + CelExpr value = entry.value(); + Kind valueKind = value.exprKind().getKind(); + if (!entry.optionalEntry() + || !keyKind.equals(Kind.CONSTANT) + || !valueKind.equals(Kind.CALL)) { + updatedEntryBuilder.add(entry); + continue; + } + + CelCall call = value.call(); + if (call.function().equals(OPTIONAL_NONE_FUNCTION)) { + // Skip the element. This is resolving an optional.none: ex {?1: optional.none()}. + modified = true; + continue; + } else if (call.function().equals(OPTIONAL_OF_FUNCTION)) { + CelExpr arg = call.args().get(0); + if (arg.exprKind().getKind().equals(Kind.CONSTANT)) { + modified = true; + updatedEntryBuilder.add( + entry.toBuilder().setOptionalEntry(false).setValue(call.args().get(0)).build()); + continue; + } + } + + updatedEntryBuilder.add(entry); + } + + if (modified) { + return replaceSubtree( + ast, + CelExpr.newBuilder() + .setCreateMap( + CelCreateMap.newBuilder().addEntries(updatedEntryBuilder.build()).build()) + .build(), + expr.id()); + } + + return ast; + } + + private CelAbstractSyntaxTree pruneOptionalStructElements( + CelAbstractSyntaxTree ast, CelExpr expr) { + CelCreateStruct createStruct = expr.createStruct(); + ImmutableList.Builder updatedEntryBuilder = + new ImmutableList.Builder<>(); + boolean modified = false; + for (CelCreateStruct.Entry entry : createStruct.entries()) { + CelExpr value = entry.value(); + Kind valueKind = value.exprKind().getKind(); + if (!entry.optionalEntry() || !valueKind.equals(Kind.CALL)) { + // Preserve the entry as is + updatedEntryBuilder.add(entry); + continue; + } + + CelCall call = value.call(); + if (call.function().equals(OPTIONAL_NONE_FUNCTION)) { + // Skip the element. This is resolving an optional.none: ex msg{?field: optional.none()}. + modified = true; + continue; + } else if (call.function().equals(OPTIONAL_OF_FUNCTION)) { + CelExpr arg = call.args().get(0); + if (arg.exprKind().getKind().equals(Kind.CONSTANT)) { + modified = true; + updatedEntryBuilder.add( + entry.toBuilder().setOptionalEntry(false).setValue(call.args().get(0)).build()); + continue; + } + } + + updatedEntryBuilder.add(entry); + } + + if (modified) { + return replaceSubtree( + ast, + CelExpr.newBuilder() + .setCreateStruct( + CelCreateStruct.newBuilder() + .setMessageName(createStruct.messageName()) + .addEntries(updatedEntryBuilder.build()) + .build()) + .build(), + expr.id()); + } + + return ast; + } + + private ConstantFoldingOptimizer() {} +} diff --git a/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel index 075aaad2d..842fcbf52 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel +++ b/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel @@ -11,10 +11,13 @@ java_library( "//bundle:cel", "//common", "//common:compiler_common", + "//common:options", "//common/ast", + "//common/navigation", "//common/resources/testdata/proto3:test_all_types_java_proto", "//common/types", "//compiler", + "//extensions:optional_library", "//optimizer", "//optimizer:optimization_exception", "//optimizer:optimizer_builder", @@ -24,6 +27,7 @@ java_library( "//parser:macro", "//parser:unparser", "//runtime", + "@maven//:com_google_guava_guava", "@maven//:com_google_testparameterinjector_test_parameter_injector", "@maven//:junit_junit", ], diff --git a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java index 5cff92d09..65d94b70e 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java @@ -14,22 +14,26 @@ package dev.cel.optimizer; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertThrows; +import com.google.common.collect.ImmutableMap; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import dev.cel.bundle.Cel; import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; +import dev.cel.common.CelOptions; import dev.cel.common.CelOverloadDecl; -import dev.cel.common.CelSource; import dev.cel.common.ast.CelConstant; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelIdent; import dev.cel.common.ast.CelExpr.CelSelect; +import dev.cel.common.navigation.CelNavigableAst; +import dev.cel.common.navigation.CelNavigableExpr; import dev.cel.common.types.SimpleType; import dev.cel.common.types.StructTypeReference; +import dev.cel.extensions.CelOptionalLibrary; import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparser; import dev.cel.parser.CelUnparserFactory; @@ -42,7 +46,10 @@ public class MutableAstTest { private static final Cel CEL = CelFactory.standardCelBuilder() .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions(CelOptions.current().populateMacroCalls(true).build()) .addMessageTypes(TestAllTypes.getDescriptor()) + .addCompilerLibraries(CelOptionalLibrary.INSTANCE) + .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) .setContainer("dev.cel.testing.testdata.proto3") .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) .addVar("x", SimpleType.INT) @@ -52,13 +59,54 @@ public class MutableAstTest { @Test public void constExpr() throws Exception { - CelExpr root = CEL.compile("10").getAst().getExpr(); + CelAbstractSyntaxTree ast = CEL.compile("10").getAst(); - CelExpr replacedExpr = + CelAbstractSyntaxTree mutatedAst = MutableAst.replaceSubtree( - root, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); - assertThat(replacedExpr).isEqualTo(CelExpr.ofConstantExpr(1, CelConstant.ofValue(true))); + assertThat(mutatedAst.getExpr()) + .isEqualTo(CelExpr.ofConstantExpr(1, CelConstant.ofValue(true))); + } + + @Test + public void mutableAst_returnsParsedAst() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("10").getAst(); + + CelAbstractSyntaxTree mutatedAst = + MutableAst.replaceSubtree( + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); + + assertThat(ast.isChecked()).isTrue(); + assertThat(mutatedAst.isChecked()).isFalse(); + } + + @Test + public void mutableAst_nonMacro_sourceCleared() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("10").getAst(); + + CelAbstractSyntaxTree mutatedAst = + MutableAst.replaceSubtree( + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); + + assertThat(mutatedAst.getSource().getDescription()).isEmpty(); + assertThat(mutatedAst.getSource().getLineOffsets()).isEmpty(); + assertThat(mutatedAst.getSource().getPositionsMap()).isEmpty(); + assertThat(mutatedAst.getSource().getMacroCalls()).isEmpty(); + } + + @Test + public void mutableAst_macro_sourceMacroCallsPopulated() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("has(TestAllTypes{}.single_int32)").getAst(); + + CelAbstractSyntaxTree mutatedAst = + MutableAst.replaceSubtree( + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); + + assertThat(mutatedAst.getSource().getDescription()).isEmpty(); + assertThat(mutatedAst.getSource().getLineOffsets()).isEmpty(); + assertThat(mutatedAst.getSource().getPositionsMap()).isEmpty(); + assertThat(mutatedAst.getSource().getMacroCalls()).isNotEmpty(); } @Test @@ -67,13 +115,13 @@ public void globalCallExpr_replaceRoot() throws Exception { // + [4] // + [2] x [5] // 1 [1] 2 [3] - CelExpr root = CEL.compile("1 + 2 + x").getAst().getExpr(); + CelAbstractSyntaxTree ast = CEL.compile("1 + 2 + x").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - root, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 4); + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 4); - assertThat(replacedRoot).isEqualTo(CelExpr.ofConstantExpr(1, CelConstant.ofValue(10))); + assertThat(replacedAst.getExpr()).isEqualTo(CelExpr.ofConstantExpr(1, CelConstant.ofValue(10))); } @Test @@ -82,13 +130,13 @@ public void globalCallExpr_replaceLeaf() throws Exception { // + [4] // + [2] x [5] // 1 [1] 2 [3] - CelExpr root = CEL.compile("1 + 2 + x").getAst().getExpr(); + CelAbstractSyntaxTree ast = CEL.compile("1 + 2 + x").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - root, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 1); + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 1); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10 + 2 + x"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("10 + 2 + x"); } @Test @@ -97,13 +145,13 @@ public void globalCallExpr_replaceMiddleBranch() throws Exception { // + [4] // + [2] x [5] // 1 [1] 2 [3] - CelExpr root = CEL.compile("1 + 2 + x").getAst().getExpr(); + CelAbstractSyntaxTree ast = CEL.compile("1 + 2 + x").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - root, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 2); + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 2); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10 + x"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("10 + x"); } @Test @@ -112,12 +160,12 @@ public void globalCallExpr_replaceMiddleBranch_withCallExpr() throws Exception { // + [4] // + [2] x [5] // 1 [1] 2 [3] - CelExpr root = CEL.compile("1 + 2 + x").getAst().getExpr(); - CelExpr root2 = CEL.compile("4 + 5 + 6").getAst().getExpr(); + CelAbstractSyntaxTree ast = CEL.compile("1 + 2 + x").getAst(); + CelAbstractSyntaxTree ast2 = CEL.compile("4 + 5 + 6").getAst(); - CelExpr replacedRoot = MutableAst.replaceSubtree(root, root2, 2); + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree(ast, ast2.getExpr(), 2); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("4 + 5 + 6 + x"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("4 + 5 + 6 + x"); } @Test @@ -136,11 +184,11 @@ public void memberCallExpr_replaceLeafTarget() throws Exception { .build(); CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 3); + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 3); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10.func(20.func(5))"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("10.func(20.func(5))"); } @Test @@ -159,11 +207,11 @@ public void memberCallExpr_replaceLeafArgument() throws Exception { .build(); CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 5); + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 5); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10.func(4.func(20))"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("10.func(4.func(20))"); } @Test @@ -182,11 +230,11 @@ public void memberCallExpr_replaceMiddleBranchTarget() throws Exception { .build(); CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 1); + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 1); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("20.func(4.func(5))"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("20.func(4.func(5))"); } @Test @@ -205,11 +253,11 @@ public void memberCallExpr_replaceMiddleBranchArgument() throws Exception { .build(); CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 4); + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 4); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10.func(20)"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("10.func(20)"); } @Test @@ -220,9 +268,9 @@ public void select_replaceField() throws Exception { // msg [3] CelAbstractSyntaxTree ast = CEL.compile("5 + msg.single_int64").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - ast.getExpr(), + ast, CelExpr.newBuilder() .setSelect( CelSelect.newBuilder() @@ -235,7 +283,7 @@ public void select_replaceField() throws Exception { .build(), 4); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("5 + test.single_sint32"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("5 + test.single_sint32"); } @Test @@ -246,13 +294,13 @@ public void select_replaceOperand() throws Exception { // msg [3] CelAbstractSyntaxTree ast = CEL.compile("5 + msg.single_int64").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - ast.getExpr(), + ast, CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName("test").build()).build(), 3); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("5 + test.single_int64"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("5 + test.single_int64"); } @Test @@ -262,11 +310,11 @@ public void list_replaceElement() throws Exception { // 2 [2] 3 [3] 4 [4] CelAbstractSyntaxTree ast = CEL.compile("[2, 3, 4]").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 4); + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 4); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("[2, 3, 5]"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("[2, 3, 5]"); } @Test @@ -277,11 +325,11 @@ public void createStruct_replaceValue() throws Exception { // 2 [3] CelAbstractSyntaxTree ast = CEL.compile("TestAllTypes{single_int64: 2}").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 3); + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 3); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("TestAllTypes{single_int64: 5}"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("TestAllTypes{single_int64: 5}"); } @Test @@ -292,11 +340,11 @@ public void createMap_replaceKey() throws Exception { // 'a' [3] : 1 [4] CelAbstractSyntaxTree ast = CEL.compile("{'a': 1}").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 3); + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 3); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("{5: 1}"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("{5: 1}"); } @Test @@ -307,25 +355,89 @@ public void createMap_replaceValue() throws Exception { // 'a' [3] : 1 [4] CelAbstractSyntaxTree ast = CEL.compile("{'a': 1}").getAst(); - CelExpr replacedRoot = + CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree( - ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 4); + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 4); - assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("{\"a\": 5}"); + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("{\"a\": 5}"); } @Test - public void invalidCelExprKind_throwsException() { - assertThrows( - IllegalArgumentException.class, - () -> - MutableAst.replaceSubtree( - CelExpr.ofConstantExpr(1, CelConstant.ofValue("test")), CelExpr.ofNotSet(1), 1)); + public void comprehension_replaceIterRange() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("[true].exists(i, i)").getAst(); + + CelAbstractSyntaxTree replacedAst = + MutableAst.replaceSubtree( + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(false)).build(), 2); + + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("[false].exists(i, i)"); + assertConsistentMacroCalls(ast); + assertThat(CEL.createProgram(CEL.check(replacedAst).getAst()).eval()).isEqualTo(false); + } + + @Test + public void comprehension_replaceAccuInit() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("[false].exists(i, i)").getAst(); + + CelAbstractSyntaxTree replacedAst = + MutableAst.replaceSubtree( + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 6); + + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("[false].exists(i, i)"); + assertConsistentMacroCalls(ast); + assertThat(CEL.createProgram(CEL.check(replacedAst).getAst()).eval()).isEqualTo(true); + } + + @Test + public void comprehension_replaceLoopStep() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("[false].exists(i, i)").getAst(); + + CelAbstractSyntaxTree replacedAst = + MutableAst.replaceSubtree( + ast, + CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName("test").build()).build(), + 5); + + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("[false].exists(i, test)"); + assertConsistentMacroCalls(ast); + } + + @Test + public void comprehension_astContainsDuplicateNodes() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("[{\"a\": 1}].map(i, i)").getAst(); + + // AST contains two duplicate expr (ID: 9). Just ensure that it doesn't throw. + CelAbstractSyntaxTree replacedAst = + MutableAst.replaceSubtree(ast, CelExpr.newBuilder().build(), -1); + + assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("[{\"a\": 1}].map(i, i)"); + assertConsistentMacroCalls(ast); } - private static String getUnparsedExpression(CelExpr expr) { - CelAbstractSyntaxTree ast = - CelAbstractSyntaxTree.newParsedAst(expr, CelSource.newBuilder().build()); - return CEL_UNPARSER.unparse(ast); + /** + * Asserts that the expressions that appears in source_info's macro calls are consistent with the + * actual expr nodes in the AST. + */ + private void assertConsistentMacroCalls(CelAbstractSyntaxTree ast) { + assertThat(ast.getSource().getMacroCalls()).isNotEmpty(); + ImmutableMap allExprs = + CelNavigableAst.fromAst(ast) + .getRoot() + .allNodes() + .map(CelNavigableExpr::expr) + .collect(toImmutableMap(CelExpr::id, node -> node, (expr1, expr2) -> expr1)); + for (CelExpr macroCall : ast.getSource().getMacroCalls().values()) { + assertThat(macroCall.id()).isEqualTo(0); + CelNavigableExpr.fromExpr(macroCall) + .descendants() + .map(CelNavigableExpr::expr) + .forEach( + node -> { + CelExpr e = allExprs.get(node.id()); + if (e != null) { + assertThat(node).isEqualTo(e); + } + }); + } } } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel new file mode 100644 index 000000000..3e1119847 --- /dev/null +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -0,0 +1,35 @@ +load("//:testing.bzl", "junit4_test_suites") + +package(default_applicable_licenses = ["//:license"]) + +java_library( + name = "tests", + testonly = 1, + srcs = glob(["*.java"]), + deps = [ + "//:java_truth", + "//bundle:cel", + "//common", + "//common:options", + "//common/resources/testdata/proto3:test_all_types_java_proto", + "//common/types", + "//extensions:optional_library", + "//optimizer", + "//optimizer:optimization_exception", + "//optimizer:optimizer_builder", + "//optimizer/optimizers:constant_folding", + "//parser:macro", + "//parser:unparser", + "@maven//:com_google_testparameterinjector_test_parameter_injector", + "@maven//:junit_junit", + ], +) + +junit4_test_suites( + name = "test_suites", + sizes = [ + "small", + ], + src_dir = "src/test/java", + deps = [":tests"], +) diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java new file mode 100644 index 000000000..d93c1e551 --- /dev/null +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -0,0 +1,301 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer.optimizers; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelOptions; +import dev.cel.common.types.SimpleType; +import dev.cel.extensions.CelOptionalLibrary; +import dev.cel.optimizer.CelOptimizationException; +import dev.cel.optimizer.CelOptimizer; +import dev.cel.optimizer.CelOptimizerFactory; +import dev.cel.parser.CelStandardMacro; +import dev.cel.parser.CelUnparser; +import dev.cel.parser.CelUnparserFactory; +import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class ConstantFoldingOptimizerTest { + private static final Cel CEL = + CelFactory.standardCelBuilder() + .addVar("x", SimpleType.DYN) + .addVar("y", SimpleType.DYN) + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer("dev.cel.testing.testdata.proto3") + .addCompilerLibraries(CelOptionalLibrary.INSTANCE) + .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) + .build(); + + private static final CelOptimizer CEL_OPTIMIZER = + CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers(ConstantFoldingOptimizer.INSTANCE) + .build(); + + private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); + + @Test + @TestParameters("{source: '1 + 2', expected: '3'}") + @TestParameters("{source: '1 + 2 + 3', expected: '6'}") + @TestParameters("{source: '1 + 2 + x', expected: '3 + x'}") + @TestParameters("{source: 'true && true', expected: 'true'}") + @TestParameters("{source: 'true && false', expected: 'false'}") + @TestParameters("{source: 'true || false', expected: 'true'}") + @TestParameters("{source: 'false || false', expected: 'false'}") + @TestParameters("{source: 'true && false || true', expected: 'true'}") + @TestParameters("{source: 'false && true || false', expected: 'false'}") + @TestParameters("{source: 'true && x', expected: 'x'}") + @TestParameters("{source: 'x && true', expected: 'x'}") + @TestParameters("{source: 'false && x', expected: 'false'}") + @TestParameters("{source: 'x && false', expected: 'false'}") + @TestParameters("{source: 'true || x', expected: 'true'}") + @TestParameters("{source: 'x || true', expected: 'true'}") + @TestParameters("{source: 'false || x', expected: 'x'}") + @TestParameters("{source: 'x || false', expected: 'x'}") + @TestParameters("{source: 'true && x && true && x', expected: 'x && x'}") + @TestParameters("{source: 'false || x || false || x', expected: 'x || x'}") + @TestParameters("{source: 'false || x || false || y', expected: 'x || y'}") + @TestParameters("{source: 'true ? x + 1 : x + 2', expected: 'x + 1'}") + @TestParameters("{source: 'false ? x + 1 : x + 2', expected: 'x + 2'}") + @TestParameters( + "{source: 'false ? x + ''world'' : ''hello'' + ''world''', expected: '\"helloworld\"'}") + @TestParameters("{source: 'true ? (false ? x + 1 : x + 2) : x', expected: 'x + 2'}") + @TestParameters("{source: 'false ? x : (true ? x + 1 : x + 2)', expected: 'x + 1'}") + @TestParameters("{source: '[1, 1 + 2, 1 + (2 + 3)]', expected: '[1, 3, 6]'}") + @TestParameters("{source: '1 in []', expected: 'false'}") + @TestParameters("{source: 'x in []', expected: 'false'}") + @TestParameters("{source: '6 in [1, 1 + 2, 1 + (2 + 3)]', expected: 'true'}") + @TestParameters("{source: '5 in [1, 1 + 2, 1 + (2 + 3)]', expected: 'false'}") + @TestParameters("{source: '5 in [1, x, y, 5]', expected: 'true'}") + @TestParameters("{source: '!(5 in [1, x, y, 5])', expected: 'false'}") + @TestParameters("{source: 'x in [1, x, y, 5]', expected: 'true'}") + @TestParameters("{source: 'x in [1, 1 + 2, 1 + (2 + 3)]', expected: 'x in [1, 3, 6]'}") + @TestParameters("{source: 'duration(string(7 * 24) + ''h'')', expected: 'duration(\"168h\")'}") + @TestParameters("{source: '[1, ?optional.of(3)]', expected: '[1, 3]'}") + @TestParameters("{source: '[1, ?optional.ofNonZeroValue(0)]', expected: '[1]'}") + @TestParameters("{source: '[1, optional.of(3)]', expected: '[1, optional.of(3)]'}") + @TestParameters("{source: '[?optional.none(), ?x]', expected: '[?x]'}") + @TestParameters("{source: '[?optional.of(1 + 2 + 3)]', expected: '[6]'}") + @TestParameters("{source: '[?optional.of(3)]', expected: '[3]'}") + @TestParameters("{source: '[?optional.of(x)]', expected: '[?optional.of(x)]'}") + @TestParameters("{source: '[?optional.ofNonZeroValue(0)]', expected: '[]'}") + @TestParameters("{source: '[?optional.ofNonZeroValue(3)]', expected: '[3]'}") + @TestParameters("{source: '[optional.none(), ?x]', expected: '[optional.none(), ?x]'}") + @TestParameters("{source: '[optional.of(1 + 2 + 3)]', expected: '[optional.of(6)]'}") + @TestParameters("{source: '[optional.of(3)]', expected: '[optional.of(3)]'}") + @TestParameters("{source: '[optional.of(x)]', expected: '[optional.of(x)]'}") + @TestParameters("{source: '[optional.ofNonZeroValue(1 + 2 + 3)]', expected: '[optional.of(6)]'}") + @TestParameters("{source: '[optional.ofNonZeroValue(3)]', expected: '[optional.of(3)]'}") + @TestParameters("{source: 'optional.none()', expected: 'optional.none()'}") + @TestParameters( + "{source: '[1, x, optional.of(1), ?optional.of(1), optional.ofNonZeroValue(3)," + + " ?optional.ofNonZeroValue(3), ?optional.ofNonZeroValue(0), ?y, ?x.?y]', " + + "expected: '[1, x, optional.of(1), 1, optional.of(3), 3, ?y, ?x.?y]'}") + @TestParameters( + "{source: '[1, x, ?optional.ofNonZeroValue(3), ?x.?y].size() > 3'," + + " expected: '[1, x, 3, ?x.?y].size() > 3'}") + @TestParameters("{source: '{?1: optional.none()}', expected: '{}'}") + @TestParameters( + "{source: '{?1: optional.of(\"hello\"), ?2: optional.ofNonZeroValue(0), 3:" + + " optional.ofNonZeroValue(0), ?4: optional.of(x)}', expected: '{1: \"hello\", 3:" + + " optional.none(), ?4: optional.of(x)}'}") + @TestParameters( + "{source: '{?x: optional.of(1), ?y: optional.ofNonZeroValue(0)}', expected: '{?x:" + + " optional.of(1), ?y: optional.none()}'}") + @TestParameters( + "{source: 'TestAllTypes{single_int64: 1 + 2 + 3 + x}', " + + " expected: 'TestAllTypes{single_int64: 6 + x}'}") + @TestParameters( + "{source: 'TestAllTypes{?single_int64: optional.ofNonZeroValue(1)}', " + + " expected: 'TestAllTypes{single_int64: 1}'}") + @TestParameters( + "{source: 'TestAllTypes{?single_int64: optional.ofNonZeroValue(0)}', " + + " expected: 'TestAllTypes{}'}") + @TestParameters( + "{source: 'TestAllTypes{?single_int64: optional.ofNonZeroValue(1), ?single_int32:" + + " optional.of(4), ?single_uint64: optional.ofNonZeroValue(x)}', expected:" + + " 'TestAllTypes{single_int64: 1, single_int32: 4, ?single_uint64:" + + " optional.ofNonZeroValue(x)}'}") + @TestParameters("{source: '{\"hello\": \"world\"}.hello == x', expected: '\"world\" == x'}") + @TestParameters("{source: '{\"hello\": \"world\"}[\"hello\"] == x', expected: '\"world\" == x'}") + @TestParameters("{source: '{\"hello\": \"world\"}.?hello', expected: 'optional.of(\"world\")'}") + @TestParameters( + "{source: '{\"hello\": \"world\"}.?hello.orValue(\"default\") == x', " + + "expected: '\"world\" == x'}") + @TestParameters( + "{source: '{?\"hello\": optional.of(\"world\")}[\"hello\"] == x', expected: '\"world\" ==" + + " x'}") + @TestParameters("{source: '[1] + [2] + [3]', expected: '[1, 2, 3]'}") + @TestParameters("{source: '[1] + [?optional.of(2)] + [3]', expected: '[1, 2, 3]'}") + @TestParameters("{source: '[1] + [x]', expected: '[1] + [x]'}") + @TestParameters("{source: 'x + dyn([1, 2] + [3, 4])', expected: 'x + [1, 2, 3, 4]'}") + @TestParameters( + "{source: '{\"a\": dyn([1, 2]), \"b\": x}', expected: '{\"a\": [1, 2], \"b\": x}'}") + // TODO: Support folding lists with mixed types. This requires mutable lists. + // @TestParameters("{source: 'dyn([1]) + [1.0]'}") + public void constantFold_success(String source, String expected) throws Exception { + CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); + + CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(expected); + } + + @Test + @TestParameters("{source: '[1 + 1, 1 + 2].exists(i, i < 10)', expected: 'true'}") + @TestParameters("{source: '[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 10)', expected: 'true'}") + @TestParameters("{source: '[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 1 % 2)', expected: 'false'}") + @TestParameters("{source: '[1, 2, 3].map(i, i * 2)', expected: '[2, 4, 6]'}") + @TestParameters( + "{source: '[1, 2, 3].map(i, [1, 2, 3].map(j, i * j))', " + + "expected: '[[1, 2, 3], [2, 4, 6], [3, 6, 9]]'}") + @TestParameters( + "{source: '[1, 2, 3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == 0))', " + + "expected: '[[2], [2, 4, 6], [6]]'}") + @TestParameters( + "{source: '[1, 2, 3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == x))', " + + "expected: '[1, 2, 3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == x))'}") + @TestParameters( + "{source: '[{}, {\"a\": 1}, {\"b\": 2}].filter(m, has(m.a))', expected: '[{\"a\": 1}]'}") + @TestParameters( + "{source: '[{}, {?\"a\": optional.of(1)}, {\"b\": optional.of(2)}].filter(m, has(m.a))'," + + " expected: '[{\"a\": 1}]'}") + @TestParameters( + "{source: '[{}, {?\"a\": optional.of(1)}, {\"b\": optional.of(2)}].map(i, i)'," + + " expected: '[{}, {\"a\": 1}, {\"b\": optional.of(2)}].map(i, i)'}") + @TestParameters( + "{source: '[{}, {\"a\": 1}, {\"b\": 2}].filter(m, has({\"a\": true}.a))'," + + " expected: '[{}, {\"a\": 1}, {\"b\": 2}]'}") + public void constantFold_macros_macroCallMetadataPopulated(String source, String expected) + throws Exception { + Cel cel = + CelFactory.standardCelBuilder() + .addVar("x", SimpleType.DYN) + .addVar("y", SimpleType.DYN) + .addMessageTypes(TestAllTypes.getDescriptor()) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions(CelOptions.current().populateMacroCalls(true).build()) + .addCompilerLibraries(CelOptionalLibrary.INSTANCE) + .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) + .build(); + CelOptimizer celOptimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(cel) + .addAstOptimizers(ConstantFoldingOptimizer.INSTANCE) + .build(); + CelAbstractSyntaxTree ast = cel.compile(source).getAst(); + + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(expected); + } + + @Test + @TestParameters("{source: '[1 + 1, 1 + 2].exists(i, i < 10) == true'}") + @TestParameters("{source: '[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 10) == true'}") + @TestParameters("{source: '[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 1 % 2) == false'}") + @TestParameters("{source: '[1, 2, 3].map(i, i * 2) == [2, 4, 6]'}") + @TestParameters( + "{source: '[1, 2, 3].map(i, [1, 2, 3].map(j, i * j)) == [[1, 2, 3], [2, 4, 6], [3, 6, 9]]'}") + @TestParameters( + "{source: '[1, 2, 3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == 0)) == " + + "[[2], [2, 4, 6], [6]]'}") + @TestParameters("{source: '[{}, {\"a\": 1}, {\"b\": 2}].filter(m, has(m.a)) == [{\"a\": 1}]'}") + @TestParameters( + "{source: '[{}, {?\"a\": optional.of(1)}, {\"b\": optional.of(2)}].filter(m, has(m.a)) == " + + " [{\"a\": 1}]'}") + @TestParameters( + "{source: '[{}, {?\"a\": optional.of(1)}, {\"b\": optional.of(2)}].map(i, i) == " + + " [{}, {\"a\": 1}, {\"b\": optional.of(2)}].map(i, i)'}") + @TestParameters( + "{source: '[{}, {\"a\": 1}, {\"b\": 2}].filter(m, has({\"a\": true}.a)) == " + + " [{}, {\"a\": 1}, {\"b\": 2}]'}") + public void constantFold_macros_withoutMacroCallMetadata(String source) throws Exception { + Cel cel = + CelFactory.standardCelBuilder() + .addVar("x", SimpleType.DYN) + .addVar("y", SimpleType.DYN) + .addMessageTypes(TestAllTypes.getDescriptor()) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions(CelOptions.current().populateMacroCalls(false).build()) + .addCompilerLibraries(CelOptionalLibrary.INSTANCE) + .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) + .build(); + CelOptimizer celOptimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(cel) + .addAstOptimizers(ConstantFoldingOptimizer.INSTANCE) + .build(); + CelAbstractSyntaxTree ast = cel.compile(source).getAst(); + + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat(optimizedAst.getSource().getMacroCalls()).isEmpty(); + assertThat(cel.createProgram(optimizedAst).eval()).isEqualTo(true); + } + + @Test + @TestParameters("{source: '1'}") + @TestParameters("{source: '5.0'}") + @TestParameters("{source: 'true'}") + @TestParameters("{source: 'x'}") + @TestParameters("{source: 'x + 1 + 2'}") + @TestParameters("{source: 'duration(\"16800s\")'}") + @TestParameters("{source: 'timestamp(\"1970-01-01T00:00:00Z\")'}") + @TestParameters( + "{source: 'type(1)'}") // This folds in cel-go implementation but Java does not yet have a + // literal representation for type values + @TestParameters("{source: '[1, 2, 3]'}") + @TestParameters("{source: 'optional.of(\"hello\")'}") + @TestParameters("{source: 'optional.none()'}") + @TestParameters("{source: '[optional.none()]'}") + @TestParameters("{source: '[?x.?y]'}") + @TestParameters("{source: 'TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}'}") + public void constantFold_noOp(String source) throws Exception { + CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); + + CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source); + } + + @Test + public void maxIterationCountReached_throws() throws Exception { + StringBuilder sb = new StringBuilder(); + sb.append("0"); + for (int i = 1; i < 400; i++) { + sb.append(" + ").append(i); + } // 0 + 1 + 2 + 3 + ... 400 + Cel cel = + CelFactory.standardCelBuilder() + .setOptions(CelOptions.current().maxParseRecursionDepth(400).build()) + .build(); + CelAbstractSyntaxTree ast = cel.compile(sb.toString()).getAst(); + CelOptimizer optimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(cel) + .addAstOptimizers(ConstantFoldingOptimizer.INSTANCE) + .build(); + + CelOptimizationException e = + assertThrows(CelOptimizationException.class, () -> optimizer.optimize(ast)); + assertThat(e).hasMessageThat().contains("Optimization failure: Max iteration count reached."); + } +} diff --git a/parser/src/main/java/dev/cel/parser/Parser.java b/parser/src/main/java/dev/cel/parser/Parser.java index abc5eb21f..af455fe01 100644 --- a/parser/src/main/java/dev/cel/parser/Parser.java +++ b/parser/src/main/java/dev/cel/parser/Parser.java @@ -1113,7 +1113,7 @@ private long nextExprId(int position) { @Override public long nextExprId() { checkState(!positions.isEmpty()); // Should only be called while expanding macros. - // Do not call this method directly from within the parser, use nextExprId(int). + // Do not call this method directly from within the parser, use generate(int). return nextExprId(peekPosition()); } diff --git a/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java b/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java index a6d641311..45ada6ca3 100644 --- a/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java +++ b/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java @@ -57,7 +57,7 @@ public static HomogeneousLiteralValidator newInstance(String... exemptFunctions) public void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issuesFactory) { navigableAst .getRoot() - .descendants() + .allNodes() .filter( node -> node.getKind().equals(Kind.CREATE_LIST) || node.getKind().equals(Kind.CREATE_MAP)) diff --git a/validator/src/main/java/dev/cel/validator/validators/LiteralValidator.java b/validator/src/main/java/dev/cel/validator/validators/LiteralValidator.java index 6183febb4..0848870cc 100644 --- a/validator/src/main/java/dev/cel/validator/validators/LiteralValidator.java +++ b/validator/src/main/java/dev/cel/validator/validators/LiteralValidator.java @@ -39,7 +39,7 @@ protected LiteralValidator(String functionName, Class expectedResultType) { @Override public void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issuesFactory) { - CelExprFactory exprFactory = CelExprFactory.newBuilder().build(); + CelExprFactory exprFactory = CelExprFactory.newInstance(); navigableAst .getRoot() diff --git a/validator/src/main/java/dev/cel/validator/validators/RegexLiteralValidator.java b/validator/src/main/java/dev/cel/validator/validators/RegexLiteralValidator.java index fc79083c5..a596de9c1 100644 --- a/validator/src/main/java/dev/cel/validator/validators/RegexLiteralValidator.java +++ b/validator/src/main/java/dev/cel/validator/validators/RegexLiteralValidator.java @@ -31,7 +31,7 @@ public final class RegexLiteralValidator implements CelAstValidator { public void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issuesFactory) { navigableAst .getRoot() - .descendants() + .allNodes() .filter(node -> node.expr().callOrDefault().function().equals("matches")) .filter(node -> ImmutableList.of(1, 2).contains(node.expr().call().args().size())) .map(node -> node.expr().call())