From 00426f847681a2cb5a2bb930369ef799c7f3b72e Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 3 Oct 2023 12:16:49 -0700 Subject: [PATCH] Add constant folding PiperOrigin-RevId: 570460653 --- .../java/dev/cel/common/ast/CelConstant.java | 43 ++ .../common/navigation/CelNavigableExpr.java | 33 +- .../dev/cel/common/ast/CelConstantTest.java | 24 + .../CelNavigableExprVisitorTest.java | 79 ++- optimizer/optimizers/BUILD.bazel | 9 + .../main/java/dev/cel/optimizer/BUILD.bazel | 3 + .../dev/cel/optimizer/CelAstOptimizer.java | 27 +- .../dev/cel/optimizer/optimizers/BUILD.bazel | 28 + .../optimizers/ConstantFoldingOptimizer.java | 496 ++++++++++++++++++ .../dev/cel/optimizer/optimizers/BUILD.bazel | 34 ++ .../ConstantFoldingOptimizerTest.java | 201 +++++++ .../HomogeneousLiteralValidator.java | 2 +- .../validators/RegexLiteralValidator.java | 2 +- 13 files changed, 942 insertions(+), 39 deletions(-) create mode 100644 optimizer/optimizers/BUILD.bazel create mode 100644 optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel create mode 100644 optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java create mode 100644 optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel create mode 100644 optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java diff --git a/common/src/main/java/dev/cel/common/ast/CelConstant.java b/common/src/main/java/dev/cel/common/ast/CelConstant.java index a21639fc..f981225e 100644 --- a/common/src/main/java/dev/cel/common/ast/CelConstant.java +++ b/common/src/main/java/dev/cel/common/ast/CelConstant.java @@ -16,6 +16,7 @@ import com.google.auto.value.AutoOneOf; import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableSet; import com.google.common.primitives.UnsignedLong; import com.google.errorprone.annotations.Immutable; import com.google.protobuf.ByteString; @@ -33,6 +34,16 @@ @Internal @Immutable public abstract class CelConstant { + private static final ImmutableSet> CONSTANT_CLASSES = + ImmutableSet.of( + NullValue.class, + Boolean.class, + Long.class, + UnsignedLong.class, + Double.class, + String.class, + ByteString.class); + /** Represents the type of the Constant */ public enum Kind { NOT_SET, @@ -127,6 +138,38 @@ public static CelConstant ofValue(ByteString value) { return AutoOneOf_CelConstant.bytesValue(value); } + /** Checks whether the provided Java object is a valid CelConstant value. */ + public static boolean isConstantValue(Object value) { + return CONSTANT_CLASSES.contains(value.getClass()); + } + + /** + * Converts the given Java object into a CelConstant value. This is equivalent of calling {@link + * CelConstant#ofValue} with concrete types. + * + * @throws IllegalArgumentException If the value is not a supported CelConstant. This includes the + * deprecated duration and timestamp values. + */ + public static CelConstant ofObjectValue(Object value) { + if (value instanceof NullValue) { + return ofValue((NullValue) value); + } else if (value instanceof Boolean) { + return ofValue((boolean) value); + } else if (value instanceof Long) { + return ofValue((long) value); + } else if (value instanceof UnsignedLong) { + return ofValue((UnsignedLong) value); + } else if (value instanceof Double) { + return ofValue((double) value); + } else if (value instanceof String) { + return ofValue((String) value); + } else if (value instanceof ByteString) { + return ofValue((ByteString) value); + } + + throw new IllegalArgumentException("Value is not a CelConstant: " + value); + } + /** * @deprecated Do not use. Duration is no longer built-in CEL type. */ diff --git a/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java b/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java index 1e4411f2..774176b3 100644 --- a/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java +++ b/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java @@ -59,32 +59,49 @@ public enum TraversalOrder { * Returns a stream of {@link CelNavigableExpr} collected from the current node down to the last * leaf-level member using post-order traversal. */ - public Stream descendants() { - return descendants(TraversalOrder.POST_ORDER); + public Stream allNodes() { + return allNodes(TraversalOrder.POST_ORDER); } /** * Returns a stream of {@link CelNavigableExpr} collected from the current node down to the last * leaf-level member using the specified traversal order. */ - public Stream descendants(TraversalOrder traversalOrder) { + public Stream allNodes(TraversalOrder traversalOrder) { return CelNavigableExprVisitor.collect(this, traversalOrder); } /** - * Returns a stream of {@link CelNavigableExpr} collected from the current node to its immediate - * children using post-order traversal. + * Returns a stream of {@link CelNavigableExpr} collected down to the last leaf-level member using + * post-order traversal. + */ + public Stream descendants() { + return descendants(TraversalOrder.POST_ORDER); + } + + /** + * Returns a stream of {@link CelNavigableExpr} collected down to the last leaf-level member using + * the specified traversal order. + */ + public Stream descendants(TraversalOrder traversalOrder) { + return CelNavigableExprVisitor.collect(this, traversalOrder).filter(node -> !node.equals(this)); + } + + /** + * Returns a stream of {@link CelNavigableExpr} collected from its immediate children using + * post-order traversal. */ public Stream children() { return children(TraversalOrder.POST_ORDER); } /** - * Returns a stream of {@link CelNavigableExpr} collected from the current node to its immediate - * children using the specified traversal order. + * Returns a stream of {@link CelNavigableExpr} collected from its immediate children using the + * specified traversal order. */ public Stream children(TraversalOrder traversalOrder) { - return CelNavigableExprVisitor.collect(this, 1, traversalOrder); + return CelNavigableExprVisitor.collect(this, this.depth() + 1, traversalOrder) + .filter(node -> !node.equals(this)); } /** Returns the underlying kind of the {@link CelExpr}. */ diff --git a/common/src/test/java/dev/cel/common/ast/CelConstantTest.java b/common/src/test/java/dev/cel/common/ast/CelConstantTest.java index abf09ac4..462267ba 100644 --- a/common/src/test/java/dev/cel/common/ast/CelConstantTest.java +++ b/common/src/test/java/dev/cel/common/ast/CelConstantTest.java @@ -196,4 +196,28 @@ public void getValueOnInvalidKindCase_throwsException( assertThrows(UnsupportedOperationException.class, constant::timestampValue); } } + + @Test + public void getObjectValue_success() { + assertThat(CelConstant.ofObjectValue(NullValue.NULL_VALUE)) + .isEqualTo(CelConstant.ofValue(NullValue.NULL_VALUE)); + assertThat(CelConstant.ofObjectValue(true)).isEqualTo(CelConstant.ofValue(true)); + assertThat(CelConstant.ofObjectValue(2L)).isEqualTo(CelConstant.ofValue(2L)); + assertThat(CelConstant.ofObjectValue(UnsignedLong.valueOf(3L))) + .isEqualTo(CelConstant.ofValue(UnsignedLong.valueOf(3L))); + assertThat(CelConstant.ofObjectValue(3.0d)).isEqualTo(CelConstant.ofValue(3.0d)); + assertThat(CelConstant.ofObjectValue("test")).isEqualTo(CelConstant.ofValue("test")); + assertThat(CelConstant.ofObjectValue(ByteString.copyFromUtf8("hello"))) + .isEqualTo(CelConstant.ofValue(ByteString.copyFromUtf8("hello"))); + } + + @Test + public void getObjectValue_invalidParameter_throws() { + assertThrows( + IllegalArgumentException.class, + () -> CelConstant.ofObjectValue(Duration.getDefaultInstance())); + assertThrows( + IllegalArgumentException.class, + () -> CelConstant.ofObjectValue(Timestamp.getDefaultInstance())); + } } diff --git a/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java b/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java index 5b83e64d..a496a133 100644 --- a/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java +++ b/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java @@ -72,7 +72,7 @@ public void collectWithMaxDepth_expectedNodeCountReturned(int maxDepth, int expe } @Test - public void add_descendants_allNodesReturned() throws Exception { + public void add_allNodes_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); // Tree shape: @@ -83,7 +83,7 @@ public void add_descendants_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants().map(CelNavigableExpr::expr).collect(toImmutableList()); + navigableAst.getRoot().allNodes().map(CelNavigableExpr::expr).collect(toImmutableList()); CelExpr childAddCall = CelExpr.ofCallExpr( @@ -109,7 +109,7 @@ public void add_descendants_allNodesReturned() throws Exception { } @Test - public void add_filterConstants_descendantsReturned() throws Exception { + public void add_filterConstants_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); // Tree shape: @@ -122,7 +122,7 @@ public void add_filterConstants_descendantsReturned() throws Exception { ImmutableList allConstants = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.CONSTANT)) .collect(toImmutableList()); @@ -147,7 +147,7 @@ public void add_filterConstants_parentsPopulated() throws Exception { ImmutableList allConstants = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.CONSTANT)) .collect(toImmutableList()); @@ -221,6 +221,33 @@ public void add_filterConstants_parentOfChildPopulated() throws Exception { assertThat(parentExpr.call().args().get(1).constant()).isEqualTo(CelConstant.ofValue(2)); } + @Test + public void add_childrenOfMiddleBranch_success() throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); + // Tree shape: + // + + // + 2 + // 1 a + CelAbstractSyntaxTree ast = compiler.compile("1 + a + 2").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + CelNavigableExpr ident = + navigableAst + .getRoot() + .allNodes() + .filter(node -> node.getKind().equals(Kind.IDENT)) // Find "a" + .findAny() + .get(); + + ImmutableList children = + ident.parent().get().children().collect(toImmutableList()); + + // Assert that the children of add call in the middle branch are const(1) and ident("a") + assertThat(children).hasSize(2); + assertThat(children.get(0).expr()).isEqualTo(CelExpr.ofConstantExpr(1, CelConstant.ofValue(1))); + assertThat(children.get(1)).isEqualTo(ident); + } + @Test public void stringFormatCall_filterList_success() throws Exception { CelCompiler compiler = @@ -239,7 +266,7 @@ public void stringFormatCall_filterList_success() throws Exception { ImmutableList allConstants = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.CREATE_LIST)) .collect(toImmutableList()); @@ -270,7 +297,7 @@ public void stringFormatCall_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList targetExprs = - navigableAst.getRoot().descendants().collect(toImmutableList()); + navigableAst.getRoot().allNodes().collect(toImmutableList()); assertThat(targetExprs).hasSize(5); } @@ -286,7 +313,7 @@ public void message_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants().map(CelNavigableExpr::expr).collect(toImmutableList()); + navigableAst.getRoot().allNodes().map(CelNavigableExpr::expr).collect(toImmutableList()); CelExpr operand = CelExpr.ofIdentExpr(1, "msg"); assertThat(allNodes) @@ -294,7 +321,7 @@ public void message_allNodesReturned() throws Exception { } @Test - public void nestedMessage_filterSelect_descendantsReturned() throws Exception { + public void nestedMessage_filterSelect_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder() .addMessageTypes(TestAllTypes.getDescriptor()) @@ -306,7 +333,7 @@ public void nestedMessage_filterSelect_descendantsReturned() throws Exception { ImmutableList allSelects = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.SELECT)) .map(CelNavigableExpr::expr) .collect(toImmutableList()); @@ -360,7 +387,7 @@ public void presenceTest_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants().map(CelNavigableExpr::expr).collect(toImmutableList()); + navigableAst.getRoot().allNodes().map(CelNavigableExpr::expr).collect(toImmutableList()); assertThat(allNodes).hasSize(2); assertThat(allNodes) @@ -380,7 +407,7 @@ public void messageConstruction_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants().map(CelNavigableExpr::expr).collect(toImmutableList()); + navigableAst.getRoot().allNodes().map(CelNavigableExpr::expr).collect(toImmutableList()); CelExpr constExpr = CelExpr.ofConstantExpr(3, CelConstant.ofValue(1)); assertThat(allNodes) @@ -394,7 +421,7 @@ public void messageConstruction_allNodesReturned() throws Exception { } @Test - public void messageConstruction_filterCreateStruct_descendantsReturned() throws Exception { + public void messageConstruction_filterCreateStruct_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder() .addMessageTypes(TestAllTypes.getDescriptor()) @@ -406,7 +433,7 @@ public void messageConstruction_filterCreateStruct_descendantsReturned() throws ImmutableList allNodes = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.CREATE_STRUCT)) .collect(toImmutableList()); @@ -431,7 +458,7 @@ public void mapConstruction_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants().map(CelNavigableExpr::expr).collect(toImmutableList()); + navigableAst.getRoot().allNodes().map(CelNavigableExpr::expr).collect(toImmutableList()); assertThat(allNodes).hasSize(3); CelExpr mapKeyExpr = CelExpr.ofConstantExpr(3, CelConstant.ofValue("key")); @@ -447,7 +474,7 @@ public void mapConstruction_allNodesReturned() throws Exception { } @Test - public void mapConstruction_filterCreateMap_descendantsReturned() throws Exception { + public void mapConstruction_filterCreateMap_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder().build(); CelAbstractSyntaxTree ast = compiler.compile("{'key': 2}").getAst(); CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); @@ -455,7 +482,7 @@ public void mapConstruction_filterCreateMap_descendantsReturned() throws Excepti ImmutableList allNodes = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.CREATE_MAP)) .collect(toImmutableList()); @@ -477,7 +504,7 @@ public void emptyMapConstruction_allNodesReturned() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants().collect(toImmutableList()); + navigableAst.getRoot().allNodes().collect(toImmutableList()); assertThat(allNodes).hasSize(1); assertThat(allNodes.get(0).expr()).isEqualTo(CelExpr.ofCreateMapExpr(1, ImmutableList.of())); @@ -495,7 +522,7 @@ public void comprehension_preOrder_allNodesReturned() throws Exception { ImmutableList allNodes = navigableAst .getRoot() - .descendants(TraversalOrder.PRE_ORDER) + .allNodes(TraversalOrder.PRE_ORDER) .map(CelNavigableExpr::expr) .collect(toImmutableList()); @@ -557,7 +584,7 @@ public void comprehension_postOrder_allNodesReturned() throws Exception { ImmutableList allNodes = navigableAst .getRoot() - .descendants(TraversalOrder.POST_ORDER) + .allNodes(TraversalOrder.POST_ORDER) .map(CelNavigableExpr::expr) .collect(toImmutableList()); @@ -617,7 +644,7 @@ public void comprehension_allNodes_parentsPopulated() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); ImmutableList allNodes = - navigableAst.getRoot().descendants(TraversalOrder.PRE_ORDER).collect(toImmutableList()); + navigableAst.getRoot().allNodes(TraversalOrder.PRE_ORDER).collect(toImmutableList()); CelExpr iterRangeConstExpr = CelExpr.ofConstantExpr(2, CelConstant.ofValue(true)); CelExpr iterRange = @@ -668,7 +695,7 @@ public void comprehension_allNodes_parentsPopulated() throws Exception { } @Test - public void comprehension_filterComprehension_descendantsReturned() throws Exception { + public void comprehension_filterComprehension_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder() .setStandardMacros(CelStandardMacro.EXISTS) @@ -679,7 +706,7 @@ public void comprehension_filterComprehension_descendantsReturned() throws Excep ImmutableList allNodes = navigableAst .getRoot() - .descendants() + .allNodes() .filter(x -> x.getKind().equals(Kind.COMPREHENSION)) .collect(toImmutableList()); @@ -736,7 +763,7 @@ public void callExpr_preOrder() throws Exception { ImmutableList allNodes = navigableAst .getRoot() - .descendants(TraversalOrder.PRE_ORDER) + .allNodes(TraversalOrder.PRE_ORDER) .map(CelNavigableExpr::expr) .collect(toImmutableList()); @@ -773,7 +800,7 @@ public void callExpr_postOrder() throws Exception { ImmutableList allNodes = navigableAst .getRoot() - .descendants(TraversalOrder.POST_ORDER) + .allNodes(TraversalOrder.POST_ORDER) .map(CelNavigableExpr::expr) .collect(toImmutableList()); @@ -805,7 +832,7 @@ public void maxRecursionLimitReached_throws() throws Exception { CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); IllegalStateException e = - assertThrows(IllegalStateException.class, () -> navigableAst.getRoot().descendants()); + assertThrows(IllegalStateException.class, () -> navigableAst.getRoot().allNodes()); assertThat(e).hasMessageThat().contains("Max recursion depth reached."); } } diff --git a/optimizer/optimizers/BUILD.bazel b/optimizer/optimizers/BUILD.bazel new file mode 100644 index 00000000..f4095d8b --- /dev/null +++ b/optimizer/optimizers/BUILD.bazel @@ -0,0 +1,9 @@ +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//visibility:public"], # TODO: Expose to public +) + +java_library( + name = "constant_folding", + exports = ["//optimizer/src/main/java/dev/cel/optimizer/optimizers:constant_folding"], +) diff --git a/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel index f99d21c8..00b4249e 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel @@ -65,8 +65,11 @@ java_library( tags = [ ], deps = [ + ":mutable_ast", + ":optimization_exception", "//bundle:cel", "//common", + "//common/ast", "//common/navigation", ], ) diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java index 8dc80f15..aa0c5e48 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java @@ -16,13 +16,34 @@ import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.ast.CelExpr; import dev.cel.common.navigation.CelNavigableAst; /** Public interface for performing a single, custom optimization on an AST. */ public interface CelAstOptimizer { + /** Optimizes a single AST. */ + CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) + throws CelOptimizationException; + /** - * Optimizes a single AST. - **/ - CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel); + * Replaces a subtree in the given CelExpr. This operation is intended for AST optimization + * purposes. + * + *

This is a very dangerous operation. Callers should re-typecheck the mutated AST and + * additionally verify that the resulting AST is semantically valid. + * + *

All expression IDs will be renumbered in a stable manner to ensure there's no ID collision + * between the nodes. The renumbering occurs even if the subtree was not replaced. + * + * @param ast Original ast to mutate. + * @param newExpr New CelExpr to replace the subtree with. + * @param exprIdToReplace Expression id of the subtree that is getting replaced. + */ + default CelAbstractSyntaxTree replaceSubtree( + CelAbstractSyntaxTree ast, CelExpr newExpr, long exprIdToReplace) { + CelExpr newRoot = MutableAst.replaceSubtree(ast.getExpr(), newExpr, exprIdToReplace); + + return CelAbstractSyntaxTree.newParsedAst(newRoot, ast.getSource()); + } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel new file mode 100644 index 00000000..f463604d --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -0,0 +1,28 @@ +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//optimizer/optimizers:__pkg__", + ], +) + +java_library( + name = "constant_folding", + srcs = [ + "ConstantFoldingOptimizer.java", + ], + tags = [ + ], + deps = [ + "//bundle:cel", + "//common", + "//common:compiler_common", + "//common/ast", + "//common/ast:expr_util", + "//common/navigation", + "//optimizer:ast_optimizer", + "//optimizer:optimization_exception", + "//parser:operator", + "//runtime", + "@maven//:com_google_guava_guava", + ], +) diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java new file mode 100644 index 00000000..174d8b04 --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -0,0 +1,496 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package dev.cel.optimizer.optimizers; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.MoreCollectors.onlyElement; + +import com.google.common.collect.ImmutableList; +import dev.cel.bundle.Cel; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelValidationException; +import dev.cel.common.ast.CelConstant; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelCall; +import dev.cel.common.ast.CelExpr.CelCreateList; +import dev.cel.common.ast.CelExpr.CelCreateMap; +import dev.cel.common.ast.CelExpr.CelCreateStruct; +import dev.cel.common.ast.CelExpr.ExprKind.Kind; +import dev.cel.common.ast.CelExprUtil; +import dev.cel.common.navigation.CelNavigableAst; +import dev.cel.common.navigation.CelNavigableExpr; +import dev.cel.optimizer.CelAstOptimizer; +import dev.cel.optimizer.CelOptimizationException; +import dev.cel.parser.Operator; +import dev.cel.runtime.CelEvaluationException; +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; + +/** + * Performs optimization for inlining constant scalar and aggregate literal values within function + * calls and select statements with their evaluated result. + */ +public final class ConstantFoldingOptimizer implements CelAstOptimizer { + public static final ConstantFoldingOptimizer INSTANCE = new ConstantFoldingOptimizer(); + private static final int MAX_ITERATION_COUNT = 400; + + // Use optional.of and optional.none as sentinel function names for folding optional calls. + // TODO: Leverage CelValue representation of Optionals instead when available. + private static final String OPTIONAL_OF_FUNCTION = "optional.of"; + private static final String OPTIONAL_NONE_FUNCTION = "optional.none"; + private static final CelExpr OPTIONAL_NONE_EXPR = + CelExpr.ofCallExpr(0, Optional.empty(), OPTIONAL_NONE_FUNCTION, ImmutableList.of()); + + @Override + public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) + throws CelOptimizationException { + Set visitedExprs = new HashSet<>(); + int iterCount = 0; + while (true) { + iterCount++; + if (iterCount == MAX_ITERATION_COUNT) { + throw new IllegalStateException("Max iteration count reached."); + } + Optional foldableExpr = + navigableAst + .getRoot() + .allNodes() + .filter(ConstantFoldingOptimizer::canFold) + .map(CelNavigableExpr::expr) + .filter(expr -> !visitedExprs.contains(expr)) + .findAny(); + if (!foldableExpr.isPresent()) { + break; + } + visitedExprs.add(foldableExpr.get()); + + Optional mutatedAst; + // Attempt to prune if it is a non-strict call + mutatedAst = maybePruneBranches(navigableAst.getAst(), foldableExpr.get()); + if (!mutatedAst.isPresent()) { + // Evaluate the call then fold + mutatedAst = maybeFold(cel, navigableAst.getAst(), foldableExpr.get()); + } + + if (!mutatedAst.isPresent()) { + // Skip this expr. It's neither prune-able nor foldable. + continue; + } + + visitedExprs.clear(); + navigableAst = CelNavigableAst.fromAst(mutatedAst.get()); + } + + // If the output is a list, map, or struct which contains optional entries, then prune it + // to make sure that the optionals, if resolved, do not surface in the output literal. + navigableAst = CelNavigableAst.fromAst(pruneOptionalElements(navigableAst)); + + return navigableAst.getAst(); + } + + private static boolean canFold(CelNavigableExpr navigableExpr) { + switch (navigableExpr.getKind()) { + case CALL: + CelCall celCall = navigableExpr.expr().call(); + String functionName = celCall.function(); + + // These are already folded or do not need to be folded. + if (functionName.equals(OPTIONAL_OF_FUNCTION) + || functionName.equals(OPTIONAL_NONE_FUNCTION)) { + return false; + } + + // Check non-strict calls + if (functionName.equals(Operator.LOGICAL_AND.getFunction()) + || functionName.equals(Operator.LOGICAL_OR.getFunction())) { + + // If any element is a constant, this could be a foldable expr (e.g: x && false -> x) + return celCall.args().stream() + .anyMatch(node -> node.exprKind().getKind().equals(Kind.CONSTANT)); + } + + if (functionName.equals(Operator.CONDITIONAL.getFunction())) { + CelExpr cond = celCall.args().get(0); + + // A ternary with a constant condition is trivially foldable + return cond.constantOrDefault().getKind().equals(CelConstant.Kind.BOOLEAN_VALUE); + } + + if (functionName.equals(Operator.IN.getFunction())) { + return true; + } + + // Default case: all call arguments must be constants. If the argument is a container (ex: + // list, map), then its arguments must be a constant. + return areChildrenArgConstant(navigableExpr); + case SELECT: + CelNavigableExpr operand = navigableExpr.children().collect(onlyElement()); + return areChildrenArgConstant(operand); + default: + return false; + } + } + + private static boolean areChildrenArgConstant(CelNavigableExpr expr) { + if (expr.getKind().equals(Kind.CONSTANT)) { + return true; + } + + if (expr.getKind().equals(Kind.CALL) + || expr.getKind().equals(Kind.CREATE_LIST) + || expr.getKind().equals(Kind.CREATE_MAP) + || expr.getKind().equals(Kind.CREATE_STRUCT)) { + return expr.children().allMatch(ConstantFoldingOptimizer::areChildrenArgConstant); + } + + return false; + } + + private Optional maybeFold( + Cel cel, CelAbstractSyntaxTree ast, CelExpr expr) throws CelOptimizationException { + Object result; + try { + result = CelExprUtil.evaluateExpr(cel, expr); + } catch (CelValidationException | CelEvaluationException e) { + throw new CelOptimizationException( + "Constant folding failure. Failed to evaluate subtree due to: " + e.getMessage(), e); + } + + // Rewrite optional calls to use the sentinel optional functions. + // ex1: optional.ofNonZeroValue(0) -> optional.none(). + // ex2: optional.ofNonZeroValue(5) -> optional.of(5) + if (result instanceof Optional) { + Optional optResult = ((Optional) result); + return maybeRewriteOptional(optResult, ast, expr); + } + + if (!CelConstant.isConstantValue(result)) { + // Evaluated result is not a constant (e.g: unknowns) + return Optional.empty(); + } + + return Optional.of( + replaceSubtree( + ast, + CelExpr.newBuilder().setConstant(CelConstant.ofObjectValue(result)).build(), + expr.id())); + } + + private Optional maybeRewriteOptional( + Optional optResult, CelAbstractSyntaxTree ast, CelExpr expr) { + if (!optResult.isPresent()) { + if (!expr.callOrDefault().function().equals(OPTIONAL_NONE_FUNCTION)) { + // An empty optional value was encountered. Rewrite the tree with optional.none call. + // This is to account for other optional functions returning an empty optional value + // e.g: optional.ofNonZeroValue(0) + return Optional.of(replaceSubtree(ast, OPTIONAL_NONE_EXPR, expr.id())); + } + } else if (!expr.callOrDefault().function().equals(OPTIONAL_OF_FUNCTION)) { + Object unwrappedResult = optResult.get(); + if (!CelConstant.isConstantValue(unwrappedResult)) { + // Evaluated result is not a constant. Leave the optional as is. + return Optional.empty(); + } + + CelExpr newOptionalOfCall = + CelExpr.newBuilder() + .setCall( + CelCall.newBuilder() + .setFunction(OPTIONAL_OF_FUNCTION) + .addArgs( + CelExpr.newBuilder() + .setConstant(CelConstant.ofObjectValue(unwrappedResult)) + .build()) + .build()) + .build(); + return Optional.of(replaceSubtree(ast, newOptionalOfCall, expr.id())); + } + + return Optional.empty(); + } + + /** Inspects the non-strict calls to determine whether a branch can be removed. */ + private Optional maybePruneBranches( + CelAbstractSyntaxTree ast, CelExpr expr) { + if (!expr.exprKind().getKind().equals(Kind.CALL)) { + return Optional.empty(); + } + + CelCall call = expr.call(); + String function = call.function(); + if (function.equals(Operator.LOGICAL_AND.getFunction()) + || function.equals(Operator.LOGICAL_OR.getFunction())) { + return maybeShortCircuitCall(ast, expr); + } else if (function.equals(Operator.CONDITIONAL.getFunction())) { + CelExpr cond = call.args().get(0); + CelExpr truthy = call.args().get(1); + CelExpr falsy = call.args().get(2); + if (!cond.exprKind().getKind().equals(Kind.CONSTANT)) { + throw new IllegalStateException( + String.format( + "Expected constant condition. Got: %s instead.", cond.exprKind().getKind())); + } + CelExpr result = cond.constant().booleanValue() ? truthy : falsy; + + return Optional.of(replaceSubtree(ast, result, expr.id())); + } else if (function.equals(Operator.IN.getFunction())) { + CelCreateList haystack = call.args().get(1).createList(); + if (haystack.elements().isEmpty()) { + return Optional.of( + replaceSubtree( + ast, + CelExpr.newBuilder().setConstant(CelConstant.ofValue(false)).build(), + expr.id())); + } + + CelExpr needle = call.args().get(0); + if (needle.exprKind().getKind().equals(Kind.CONSTANT) + || needle.exprKind().getKind().equals(Kind.IDENT)) { + Object needleValue = + needle.exprKind().getKind().equals(Kind.CONSTANT) ? needle.constant() : needle.ident(); + for (CelExpr elem : haystack.elements()) { + if (elem.constantOrDefault().equals(needleValue) + || elem.identOrDefault().equals(needleValue)) { + return Optional.of( + replaceSubtree( + ast, + CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), + expr.id())); + } + } + } + } + + return Optional.empty(); + } + + private Optional maybeShortCircuitCall( + CelAbstractSyntaxTree ast, CelExpr expr) { + CelCall call = expr.call(); + boolean shortCircuit = false; + boolean skip = true; + if (call.function().equals(Operator.LOGICAL_OR.getFunction())) { + shortCircuit = true; + skip = false; + } + ImmutableList.Builder newArgsBuilder = new ImmutableList.Builder<>(); + + for (CelExpr arg : call.args()) { + if (!arg.exprKind().getKind().equals(Kind.CONSTANT)) { + newArgsBuilder.add(arg); + continue; + } + if (arg.constant().booleanValue() == skip) { + continue; + } + + if (arg.constant().booleanValue() == shortCircuit) { + return Optional.of(replaceSubtree(ast, arg, expr.id())); + } + } + + ImmutableList newArgs = newArgsBuilder.build(); + if (newArgs.isEmpty()) { + return Optional.of(replaceSubtree(ast, call.args().get(0), expr.id())); + } + if (newArgs.size() == 1) { + return Optional.of(replaceSubtree(ast, newArgs.get(0), expr.id())); + } + + // TODO: Support folding variadic AND/ORs. + throw new UnsupportedOperationException( + "Folding variadic logical operator is not supported yet."); + } + + private CelAbstractSyntaxTree pruneOptionalElements(CelNavigableAst navigableAst) { + ImmutableList aggregateLiterals = + navigableAst + .getRoot() + .allNodes() + .filter( + node -> + node.getKind().equals(Kind.CREATE_LIST) + || node.getKind().equals(Kind.CREATE_MAP) + || node.getKind().equals(Kind.CREATE_STRUCT)) + .map(CelNavigableExpr::expr) + .collect(toImmutableList()); + + CelAbstractSyntaxTree ast = navigableAst.getAst(); + for (CelExpr expr : aggregateLiterals) { + switch (expr.exprKind().getKind()) { + case CREATE_LIST: + ast = pruneOptionalListElements(ast, expr); + break; + case CREATE_MAP: + ast = pruneOptionalMapElements(ast, expr); + break; + case CREATE_STRUCT: + ast = pruneOptionalStructElements(ast, expr); + break; + default: + throw new IllegalArgumentException("Unexpected exprKind: " + expr.exprKind()); + } + } + return ast; + } + + private CelAbstractSyntaxTree pruneOptionalListElements(CelAbstractSyntaxTree ast, CelExpr expr) { + CelCreateList createList = expr.createList(); + if (createList.optionalIndices().isEmpty()) { + return ast; + } + + HashSet optionalIndices = new HashSet<>(createList.optionalIndices()); + ImmutableList.Builder updatedElemBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder updatedIndicesBuilder = new ImmutableList.Builder<>(); + int newOptIndex = -1; + for (int i = 0; i < createList.elements().size(); i++) { + newOptIndex++; + CelExpr element = createList.elements().get(i); + if (!optionalIndices.contains(i)) { + updatedElemBuilder.add(element); + continue; + } + + if (element.exprKind().getKind().equals(Kind.CALL)) { + CelCall call = element.call(); + if (call.function().equals(OPTIONAL_NONE_FUNCTION)) { + // Skip optional.none. + // Skipping causes the list to get smaller. + newOptIndex--; + continue; + } else if (call.function().equals(OPTIONAL_OF_FUNCTION)) { + CelExpr arg = call.args().get(0); + if (arg.exprKind().getKind().equals(Kind.CONSTANT)) { + updatedElemBuilder.add(call.args().get(0)); + continue; + } + } + } + + updatedElemBuilder.add(element); + updatedIndicesBuilder.add(newOptIndex); + } + + return replaceSubtree( + ast, + CelExpr.newBuilder() + .setCreateList( + CelCreateList.newBuilder() + .addElements(updatedElemBuilder.build()) + .addOptionalIndices(updatedIndicesBuilder.build()) + .build()) + .build(), + expr.id()); + } + + private CelAbstractSyntaxTree pruneOptionalMapElements(CelAbstractSyntaxTree ast, CelExpr expr) { + CelCreateMap createMap = expr.createMap(); + ImmutableList.Builder updatedEntryBuilder = new ImmutableList.Builder<>(); + boolean modified = false; + for (CelCreateMap.Entry entry : createMap.entries()) { + CelExpr key = entry.key(); + Kind keyKind = key.exprKind().getKind(); + CelExpr value = entry.value(); + Kind valueKind = value.exprKind().getKind(); + if (!entry.optionalEntry() + || !keyKind.equals(Kind.CONSTANT) + || !valueKind.equals(Kind.CALL)) { + updatedEntryBuilder.add(entry); + continue; + } + + CelCall call = value.call(); + if (call.function().equals(OPTIONAL_NONE_FUNCTION)) { + // Skip the element. This is resolving an optional.none: ex {?1: optional.none()}. + modified = true; + continue; + } else if (call.function().equals(OPTIONAL_OF_FUNCTION)) { + CelExpr arg = call.args().get(0); + if (arg.exprKind().getKind().equals(Kind.CONSTANT)) { + modified = true; + updatedEntryBuilder.add( + entry.toBuilder().setOptionalEntry(false).setValue(call.args().get(0)).build()); + continue; + } + } + + updatedEntryBuilder.add(entry); + } + + if (modified) { + return replaceSubtree( + ast, + CelExpr.newBuilder() + .setCreateMap( + CelCreateMap.newBuilder().addEntries(updatedEntryBuilder.build()).build()) + .build(), + expr.id()); + } + + return ast; + } + + private CelAbstractSyntaxTree pruneOptionalStructElements( + CelAbstractSyntaxTree ast, CelExpr expr) { + CelCreateStruct createStruct = expr.createStruct(); + ImmutableList.Builder updatedEntryBuilder = + new ImmutableList.Builder<>(); + boolean modified = false; + for (CelCreateStruct.Entry entry : createStruct.entries()) { + CelExpr value = entry.value(); + Kind valueKind = value.exprKind().getKind(); + if (!entry.optionalEntry() || !valueKind.equals(Kind.CALL)) { + // Preserve the entry as is + updatedEntryBuilder.add(entry); + continue; + } + + CelCall call = value.call(); + if (call.function().equals(OPTIONAL_NONE_FUNCTION)) { + // Skip the element. This is resolving an optional.none: ex msg{?field: optional.none()}. + modified = true; + continue; + } else if (call.function().equals(OPTIONAL_OF_FUNCTION)) { + CelExpr arg = call.args().get(0); + if (arg.exprKind().getKind().equals(Kind.CONSTANT)) { + modified = true; + updatedEntryBuilder.add( + entry.toBuilder().setOptionalEntry(false).setValue(call.args().get(0)).build()); + continue; + } + } + + updatedEntryBuilder.add(entry); + } + + if (modified) { + return replaceSubtree( + ast, + CelExpr.newBuilder() + .setCreateStruct( + CelCreateStruct.newBuilder() + .setMessageName(createStruct.messageName()) + .addEntries(updatedEntryBuilder.build()) + .build()) + .build(), + expr.id()); + } + + return ast; + } + + private ConstantFoldingOptimizer() {} +} diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel new file mode 100644 index 00000000..e19c36fb --- /dev/null +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -0,0 +1,34 @@ +load("//:testing.bzl", "junit4_test_suites") + +package(default_applicable_licenses = ["//:license"]) + +java_library( + name = "tests", + testonly = 1, + srcs = glob(["*.java"]), + deps = [ + "//:java_truth", + "//bundle:cel", + "//common", + "//common:options", + "//common/resources/testdata/proto3:test_all_types_java_proto", + "//common/types", + "//extensions:optional_library", + "//optimizer", + "//optimizer:optimization_exception", + "//optimizer:optimizer_builder", + "//optimizer/optimizers:constant_folding", + "//parser:unparser", + "@maven//:com_google_testparameterinjector_test_parameter_injector", + "@maven//:junit_junit", + ], +) + +junit4_test_suites( + name = "test_suites", + sizes = [ + "small", + ], + src_dir = "src/test/java", + deps = [":tests"], +) diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java new file mode 100644 index 00000000..bec2d009 --- /dev/null +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -0,0 +1,201 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer.optimizers; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelOptions; +import dev.cel.common.types.SimpleType; +import dev.cel.extensions.CelOptionalLibrary; +import dev.cel.optimizer.CelOptimizationException; +import dev.cel.optimizer.CelOptimizer; +import dev.cel.optimizer.CelOptimizerFactory; +import dev.cel.parser.CelUnparser; +import dev.cel.parser.CelUnparserFactory; +import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class ConstantFoldingOptimizerTest { + private static final Cel CEL = + CelFactory.standardCelBuilder() + .addVar("x", SimpleType.DYN) + .addVar("y", SimpleType.DYN) + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer("dev.cel.testing.testdata.proto3") + .addCompilerLibraries(CelOptionalLibrary.INSTANCE) + .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) + .build(); + + private static final CelOptimizer CEL_OPTIMIZER = + CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers(ConstantFoldingOptimizer.INSTANCE) + .build(); + + private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); + + @Test + @TestParameters("{source: '1 + 2', expected: '3'}") + @TestParameters("{source: '1 + 2 + 3', expected: '6'}") + @TestParameters("{source: '1 + 2 + x', expected: '3 + x'}") + @TestParameters("{source: 'true && true', expected: 'true'}") + @TestParameters("{source: 'true && false', expected: 'false'}") + @TestParameters("{source: 'true || false', expected: 'true'}") + @TestParameters("{source: 'false || false', expected: 'false'}") + @TestParameters("{source: 'true && false || true', expected: 'true'}") + @TestParameters("{source: 'false && true || false', expected: 'false'}") + @TestParameters("{source: 'true && x', expected: 'x'}") + @TestParameters("{source: 'x && true', expected: 'x'}") + @TestParameters("{source: 'false && x', expected: 'false'}") + @TestParameters("{source: 'x && false', expected: 'false'}") + @TestParameters("{source: 'true || x', expected: 'true'}") + @TestParameters("{source: 'x || true', expected: 'true'}") + @TestParameters("{source: 'false || x', expected: 'x'}") + @TestParameters("{source: 'x || false', expected: 'x'}") + @TestParameters("{source: 'true && x && true && x', expected: 'x && x'}") + @TestParameters("{source: 'false || x || false || x', expected: 'x || x'}") + @TestParameters("{source: 'false || x || false || y', expected: 'x || y'}") + @TestParameters("{source: 'true ? x + 1 : x + 2', expected: 'x + 1'}") + @TestParameters("{source: 'false ? x + 1 : x + 2', expected: 'x + 2'}") + @TestParameters( + "{source: 'false ? x + ''world'' : ''hello'' + ''world''', expected: '\"helloworld\"'}") + @TestParameters("{source: 'true ? (false ? x + 1 : x + 2) : x', expected: 'x + 2'}") + @TestParameters("{source: 'false ? x : (true ? x + 1 : x + 2)', expected: 'x + 1'}") + @TestParameters("{source: '[1, 1 + 2, 1 + (2 + 3)]', expected: '[1, 3, 6]'}") + @TestParameters("{source: '1 in []', expected: 'false'}") + @TestParameters("{source: 'x in []', expected: 'false'}") + @TestParameters("{source: '6 in [1, 1 + 2, 1 + (2 + 3)]', expected: 'true'}") + @TestParameters("{source: '5 in [1, 1 + 2, 1 + (2 + 3)]', expected: 'false'}") + @TestParameters("{source: '5 in [1, x, y, 5]', expected: 'true'}") + @TestParameters("{source: '!(5 in [1, x, y, 5])', expected: 'false'}") + @TestParameters("{source: 'x in [1, x, y, 5]', expected: 'true'}") + @TestParameters("{source: 'x in [1, 1 + 2, 1 + (2 + 3)]', expected: 'x in [1, 3, 6]'}") + @TestParameters("{source: 'duration(string(7 * 24) + ''h'')', expected: 'duration(\"168h\")'}") + @TestParameters("{source: '[1, ?optional.of(3)]', expected: '[1, 3]'}") + @TestParameters("{source: '[1, ?optional.ofNonZeroValue(0)]', expected: '[1]'}") + @TestParameters("{source: '[1, optional.of(3)]', expected: '[1, optional.of(3)]'}") + @TestParameters("{source: '[?optional.none(), ?x]', expected: '[?x]'}") + @TestParameters("{source: '[?optional.of(1 + 2 + 3)]', expected: '[6]'}") + @TestParameters("{source: '[?optional.of(3)]', expected: '[3]'}") + @TestParameters("{source: '[?optional.of(x)]', expected: '[?optional.of(x)]'}") + @TestParameters("{source: '[?optional.ofNonZeroValue(0)]', expected: '[]'}") + @TestParameters("{source: '[?optional.ofNonZeroValue(3)]', expected: '[3]'}") + @TestParameters("{source: '[optional.none(), ?x]', expected: '[optional.none(), ?x]'}") + @TestParameters("{source: '[optional.of(1 + 2 + 3)]', expected: '[optional.of(6)]'}") + @TestParameters("{source: '[optional.of(3)]', expected: '[optional.of(3)]'}") + @TestParameters("{source: '[optional.of(x)]', expected: '[optional.of(x)]'}") + @TestParameters("{source: '[optional.ofNonZeroValue(1 + 2 + 3)]', expected: '[optional.of(6)]'}") + @TestParameters("{source: '[optional.ofNonZeroValue(3)]', expected: '[optional.of(3)]'}") + @TestParameters("{source: 'optional.none()', expected: 'optional.none()'}") + @TestParameters( + "{source: '[1, x, optional.of(1), ?optional.of(1), optional.ofNonZeroValue(3)," + + " ?optional.ofNonZeroValue(3), ?optional.ofNonZeroValue(0), ?y, ?x.?y]', " + + "expected: '[1, x, optional.of(1), 1, optional.of(3), 3, ?y, ?x.?y]'}") + @TestParameters( + "{source: '[1, x, ?optional.ofNonZeroValue(3), ?x.?y].size() > 3'," + + " expected: '[1, x, 3, ?x.?y].size() > 3'}") + @TestParameters("{source: '{?1: optional.none()}', expected: '{}'}") + @TestParameters( + "{source: '{?1: optional.of(\"hello\"), ?2: optional.ofNonZeroValue(0), 3:" + + " optional.ofNonZeroValue(0), ?4: optional.of(x)}', expected: '{1: \"hello\", 3:" + + " optional.none(), ?4: optional.of(x)}'}") + @TestParameters( + "{source: '{?x: optional.of(1), ?y: optional.ofNonZeroValue(0)}', expected: '{?x:" + + " optional.of(1), ?y: optional.none()}'}") + @TestParameters( + "{source: 'TestAllTypes{single_int64: 1 + 2 + 3 + x}', " + + " expected: 'TestAllTypes{single_int64: 6 + x}'}") + @TestParameters( + "{source: 'TestAllTypes{?single_int64: optional.ofNonZeroValue(1)}', " + + " expected: 'TestAllTypes{single_int64: 1}'}") + @TestParameters( + "{source: 'TestAllTypes{?single_int64: optional.ofNonZeroValue(0)}', " + + " expected: 'TestAllTypes{}'}") + @TestParameters( + "{source: 'TestAllTypes{?single_int64: optional.ofNonZeroValue(1), ?single_int32:" + + " optional.of(4), ?single_uint64: optional.ofNonZeroValue(x)}', expected:" + + " 'TestAllTypes{single_int64: 1, single_int32: 4, ?single_uint64:" + + " optional.ofNonZeroValue(x)}'}") + @TestParameters("{source: '{\"hello\": \"world\"}.hello == x', expected: '\"world\" == x'}") + @TestParameters("{source: '{\"hello\": \"world\"}[\"hello\"] == x', expected: '\"world\" == x'}") + @TestParameters("{source: '{\"hello\": \"world\"}.?hello', expected: 'optional.of(\"world\")'}") + @TestParameters( + "{source: '{\"hello\": \"world\"}.?hello.orValue(\"default\") == x', " + + "expected: '\"world\" == x'}") + @TestParameters( + "{source: '{?\"hello\": optional.of(\"world\")}[\"hello\"] == x', expected: '\"world\" ==" + + " x'}") + public void constantFold_success(String source, String expected) throws Exception { + CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); + + CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(expected); + } + + @Test + @TestParameters("{source: '1'}") + @TestParameters("{source: '5.0'}") + @TestParameters("{source: 'true'}") + @TestParameters("{source: 'x'}") + @TestParameters("{source: 'x + 1 + 2'}") + @TestParameters("{source: 'duration(\"16800s\")'}") + @TestParameters("{source: 'timestamp(\"1970-01-01T00:00:00Z\")'}") + @TestParameters( + "{source: 'type(1)'}") // This folds in cel-go implementation but Java does not yet have a + // literal representation for type values + @TestParameters("{source: '[1, 2, 3]'}") + @TestParameters("{source: 'optional.of(\"hello\")'}") + @TestParameters("{source: 'optional.none()'}") + @TestParameters("{source: '[optional.none()]'}") + @TestParameters("{source: '[?x.?y]'}") + @TestParameters("{source: 'TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}'}") + public void constantFold_noOp(String source) throws Exception { + CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); + + CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source); + } + + @Test + public void maxIterationCountReached_throws() throws Exception { + StringBuilder sb = new StringBuilder(); + sb.append("0"); + for (int i = 1; i < 400; i++) { + sb.append(" + ").append(i); + } // 0 + 1 + 2 + 3 + ... 400 + Cel cel = + CelFactory.standardCelBuilder() + .setOptions(CelOptions.current().maxParseRecursionDepth(400).build()) + .build(); + CelAbstractSyntaxTree ast = cel.compile(sb.toString()).getAst(); + CelOptimizer optimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(cel) + .addAstOptimizers(ConstantFoldingOptimizer.INSTANCE) + .build(); + + CelOptimizationException e = + assertThrows(CelOptimizationException.class, () -> optimizer.optimize(ast)); + assertThat(e).hasMessageThat().contains("Optimization failure: Max iteration count reached."); + } +} diff --git a/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java b/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java index a6d64131..45ada6ca 100644 --- a/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java +++ b/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java @@ -57,7 +57,7 @@ public static HomogeneousLiteralValidator newInstance(String... exemptFunctions) public void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issuesFactory) { navigableAst .getRoot() - .descendants() + .allNodes() .filter( node -> node.getKind().equals(Kind.CREATE_LIST) || node.getKind().equals(Kind.CREATE_MAP)) diff --git a/validator/src/main/java/dev/cel/validator/validators/RegexLiteralValidator.java b/validator/src/main/java/dev/cel/validator/validators/RegexLiteralValidator.java index fc79083c..a596de9c 100644 --- a/validator/src/main/java/dev/cel/validator/validators/RegexLiteralValidator.java +++ b/validator/src/main/java/dev/cel/validator/validators/RegexLiteralValidator.java @@ -31,7 +31,7 @@ public final class RegexLiteralValidator implements CelAstValidator { public void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issuesFactory) { navigableAst .getRoot() - .descendants() + .allNodes() .filter(node -> node.expr().callOrDefault().function().equals("matches")) .filter(node -> ImmutableList.of(1, 2).contains(node.expr().call().args().size())) .map(node -> node.expr().call())