Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unpacking any messages containing extension fields #155

Merged
merged 1 commit into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions bundle/src/main/java/dev/cel/bundle/CelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import dev.cel.checker.ProtoTypeMask;
import dev.cel.checker.TypeProvider;
Expand Down Expand Up @@ -282,6 +283,13 @@ public interface CelBuilder {
@CanIgnoreReturnValue
CelBuilder addRuntimeLibraries(Iterable<CelRuntimeLibrary> libraries);

/**
* Sets a proto ExtensionRegistry to assist with unpacking Any messages containing a proto2
extension field.
*/
@CanIgnoreReturnValue
CelBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry);

/** Construct a new {@code Cel} instance from the provided configuration. */
Cel build();
}
8 changes: 8 additions & 0 deletions bundle/src/main/java/dev/cel/bundle/CelImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import dev.cel.checker.CelCheckerBuilder;
import dev.cel.checker.ProtoTypeMask;
Expand Down Expand Up @@ -339,6 +340,13 @@ public Builder addRuntimeLibraries(Iterable<CelRuntimeLibrary> libraries) {
return this;
}

@Override
public CelBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry) {
checkNotNull(extensionRegistry);
runtimeBuilder.setExtensionRegistry(extensionRegistry);
return this;
}

@Override
public Cel build() {
return new CelImpl(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ public final class DefaultDescriptorPool implements CelDescriptorPool {

/** A DefaultDescriptorPool instance with just well known types loaded. */
public static final DefaultDescriptorPool INSTANCE =
new DefaultDescriptorPool(WELL_KNOWN_TYPE_DESCRIPTORS, ImmutableMultimap.of());
new DefaultDescriptorPool(
WELL_KNOWN_TYPE_DESCRIPTORS,
ImmutableMultimap.of(),
ExtensionRegistry.getEmptyRegistry());

// K: Fully qualified message type name, V: Message descriptor
private final ImmutableMap<String, Descriptor> descriptorMap;
Expand All @@ -55,7 +58,15 @@ public final class DefaultDescriptorPool implements CelDescriptorPool {
// V: Field descriptor for the extension message
private final ImmutableMultimap<String, FieldDescriptor> extensionDescriptorMap;

@SuppressWarnings("Immutable") // ExtensionRegistry is immutable, just not marked as such.
private final ExtensionRegistry extensionRegistry;

public static DefaultDescriptorPool create(CelDescriptors celDescriptors) {
return create(celDescriptors, ExtensionRegistry.getEmptyRegistry());
}

public static DefaultDescriptorPool create(
CelDescriptors celDescriptors, ExtensionRegistry extensionRegistry) {
Map<String, Descriptor> descriptorMap = new HashMap<>(); // Using a hashmap to allow deduping
stream(WellKnownProto.values()).forEach(d -> descriptorMap.put(d.typeName(), d.descriptor()));

Expand All @@ -64,7 +75,9 @@ public static DefaultDescriptorPool create(CelDescriptors celDescriptors) {
}

return new DefaultDescriptorPool(
ImmutableMap.copyOf(descriptorMap), celDescriptors.extensionDescriptors());
ImmutableMap.copyOf(descriptorMap),
celDescriptors.extensionDescriptors(),
extensionRegistry);
}

@Override
Expand All @@ -83,14 +96,15 @@ public Optional<FieldDescriptor> findExtensionDescriptor(

@Override
public ExtensionRegistry getExtensionRegistry() {
// TODO: Populate one from runtime builder.
return ExtensionRegistry.getEmptyRegistry();
return extensionRegistry;
}

private DefaultDescriptorPool(
ImmutableMap<String, Descriptor> descriptorMap,
ImmutableMultimap<String, FieldDescriptor> extensionDescriptorMap) {
ImmutableMultimap<String, FieldDescriptor> extensionDescriptorMap,
ExtensionRegistry extensionRegistry) {
this.descriptorMap = checkNotNull(descriptorMap);
this.extensionDescriptorMap = checkNotNull(extensionDescriptorMap);
this.extensionRegistry = checkNotNull(extensionRegistry);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.protobuf.Any;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.testing.junit.testparameterinjector.TestParameter;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import com.google.testing.junit.testparameterinjector.TestParameters;
import dev.cel.bundle.Cel;
import dev.cel.bundle.CelFactory;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelOverloadDecl;
Expand Down Expand Up @@ -277,6 +281,29 @@ public void getExt_nonProtoNamespace_success(String expr) throws Exception {
assertThat(result).isTrue();
}

@Test
public void getExt_onAnyPackedExtensionField_success() throws Exception {
ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance();
MessagesProto2Extensions.registerAllExtensions(extensionRegistry);
Cel cel =
CelFactory.standardCelBuilder()
.addCompilerLibraries(CelExtensions.protos())
.addFileTypes(MessagesProto2Extensions.getDescriptor())
.setExtensionRegistry(extensionRegistry)
.addVar(
"msg", StructTypeReference.create("dev.cel.testing.testdata.proto2.Proto2Message"))
.build();
CelAbstractSyntaxTree ast =
cel.compile("proto.getExt(msg, dev.cel.testing.testdata.proto2.int32_ext)").getAst();
Any anyMsg =
Any.pack(
Proto2Message.newBuilder().setExtension(MessagesProto2Extensions.int32Ext, 1).build());

Long result = (Long) cel.createProgram(ast).eval(ImmutableMap.of("msg", anyMsg));

assertThat(result).isEqualTo(1);
}

private enum ParseErrorTestCase {
FIELD_NOT_FULLY_QUALIFIED(
"proto.getExt(Proto2ExtensionScopedMessage{}, int64_ext)",
Expand Down
8 changes: 8 additions & 0 deletions runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import dev.cel.common.CelOptions;
import java.util.function.Function;
Expand Down Expand Up @@ -149,6 +150,13 @@ public interface CelRuntimeBuilder {
@CanIgnoreReturnValue
CelRuntimeBuilder addLibraries(Iterable<? extends CelRuntimeLibrary> libraries);

/**
* Sets a proto ExtensionRegistry to assist with unpacking Any messages containing a proto2
extension field.
*/
@CanIgnoreReturnValue
CelRuntimeBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry);

/** Build a new instance of the {@code CelRuntime}. */
@CheckReturnValue
CelRuntime build();
Expand Down
15 changes: 14 additions & 1 deletion runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelDescriptorUtil;
Expand Down Expand Up @@ -79,6 +80,7 @@ public static final class Builder implements CelRuntimeBuilder {

private boolean standardEnvironmentEnabled;
private Function<String, Message.Builder> customTypeFactory;
private ExtensionRegistry extensionRegistry;

@Override
@CanIgnoreReturnValue
Expand Down Expand Up @@ -161,6 +163,14 @@ public Builder addLibraries(Iterable<? extends CelRuntimeLibrary> libraries) {
return this;
}

@Override
@CanIgnoreReturnValue
public Builder setExtensionRegistry(ExtensionRegistry extensionRegistry) {
checkNotNull(extensionRegistry);
this.extensionRegistry = extensionRegistry.getUnmodifiable();
return this;
}

/** Build a new {@code CelRuntimeLegacyImpl} instance from the builder config. */
@Override
@CanIgnoreReturnValue
Expand All @@ -171,6 +181,7 @@ public CelRuntimeLegacyImpl build() {
CelDescriptorPool celDescriptorPool =
newDescriptorPool(
fileTypes.build(),
extensionRegistry,
options);

@SuppressWarnings("Immutable")
Expand Down Expand Up @@ -214,14 +225,15 @@ public CelRuntimeLegacyImpl build() {

private static CelDescriptorPool newDescriptorPool(
ImmutableSet<FileDescriptor> fileTypeSet,
ExtensionRegistry extensionRegistry,
CelOptions celOptions) {
CelDescriptors celDescriptors =
CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(
fileTypeSet, celOptions.resolveTypeDependencies());

ImmutableList.Builder<CelDescriptorPool> descriptorPools = new ImmutableList.Builder<>();

descriptorPools.add(DefaultDescriptorPool.create(celDescriptors));
descriptorPools.add(DefaultDescriptorPool.create(celDescriptors, extensionRegistry));

return CombinedDescriptorPool.create(descriptorPools.build());
}
Expand All @@ -241,6 +253,7 @@ private Builder() {
this.fileTypes = ImmutableSet.builder();
this.functionBindings = ImmutableMap.builder();
this.celRuntimeLibraries = ImmutableSet.builder();
this.extensionRegistry = ExtensionRegistry.getEmptyRegistry();
this.customTypeFactory = null;
}
}
Expand Down