diff --git a/options/message_sets.go b/internal/messageset/messageset.go similarity index 91% rename from options/message_sets.go rename to internal/messageset/messageset.go index 2096eb58..850a0c66 100644 --- a/options/message_sets.go +++ b/internal/messageset/messageset.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package options +package messageset import ( "math" @@ -28,7 +28,9 @@ var ( messageSetSupportInit sync.Once ) -func canSerializeMessageSets() bool { +// CanSupportMessageSets returns true if the protobuf-go runtime supports +// serializing messages with the message set wire format. +func CanSupportMessageSets() bool { messageSetSupportInit.Do(func() { // We check using the protodesc package, instead of just relying // on protolegacy build tag, in case someone links in a fork of diff --git a/options/message_sets_protolegacy_test.go b/internal/messageset/messageset_protolegacy_test.go similarity index 92% rename from options/message_sets_protolegacy_test.go rename to internal/messageset/messageset_protolegacy_test.go index a09aca93..c5c712cc 100644 --- a/options/message_sets_protolegacy_test.go +++ b/internal/messageset/messageset_protolegacy_test.go @@ -15,7 +15,7 @@ //go:build protolegacy // +build protolegacy -package options +package messageset import ( "testing" @@ -25,5 +25,5 @@ import ( func TestCanSerializeMessageSets(t *testing.T) { t.Parallel() - assert.True(t, canSerializeMessageSets()) + assert.True(t, CanSupportMessageSets()) } diff --git a/options/message_sets_test.go b/internal/messageset/messageset_test.go similarity index 92% rename from options/message_sets_test.go rename to internal/messageset/messageset_test.go index aaca726f..0b3c2bc3 100644 --- a/options/message_sets_test.go +++ b/internal/messageset/messageset_test.go @@ -15,7 +15,7 @@ //go:build !protolegacy // +build !protolegacy -package options +package messageset import ( "testing" @@ -25,5 +25,5 @@ import ( func TestCanSerializeMessageSets(t *testing.T) { t.Parallel() - assert.False(t, canSerializeMessageSets()) + assert.False(t, CanSupportMessageSets()) } diff --git a/linker/linker_test.go b/linker/linker_test.go index 42757981..97e3d885 100644 --- a/linker/linker_test.go +++ b/linker/linker_test.go @@ -30,14 +30,18 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" "github.com/bufbuild/protocompile" "github.com/bufbuild/protocompile/internal/editions" + "github.com/bufbuild/protocompile/internal/messageset" "github.com/bufbuild/protocompile/internal/protoc" "github.com/bufbuild/protocompile/internal/prototest" "github.com/bufbuild/protocompile/linker" + "github.com/bufbuild/protocompile/protoutil" "github.com/bufbuild/protocompile/reporter" ) @@ -125,6 +129,7 @@ func TestLinkerValidation(t *testing.T) { // Expected error message - leave empty if input is expected to succeed expectedErr string expectedDiffWithProtoc bool + expectProtodescFail bool }{ "success_multi_namespace": { input: map[string]string{ @@ -485,6 +490,7 @@ func TestLinkerValidation(t *testing.T) { input: map[string]string{ "foo.proto": "message Foo { option message_set_wire_format = true; extensions 1 to 100; } extend Foo { optional Foo bar = 1; }", }, + expectProtodescFail: !messageset.CanSupportMessageSets(), }, "failure_tag_out_of_range": { input: map[string]string{ @@ -496,6 +502,7 @@ func TestLinkerValidation(t *testing.T) { input: map[string]string{ "foo.proto": "message Foo { option message_set_wire_format = true; extensions 1 to max; } extend Foo { optional Foo bar = 536870912; }", }, + expectProtodescFail: !messageset.CanSupportMessageSets(), }, "failure_message_set_wire_format_repeated": { input: map[string]string{ @@ -1566,6 +1573,10 @@ func TestLinkerValidation(t *testing.T) { string FOO_BAR = 2; }`, }, + // protodesc.NewFile is applying overly strict checks on name + // collisions in proto3 files. + // https://github.com/golang/protobuf/issues/1616 + expectProtodescFail: true, }, "failure_json_name_conflict_leading_underscores": { input: map[string]string{ @@ -3354,7 +3365,7 @@ func TestLinkerValidation(t *testing.T) { for filename, data := range tc.input { tc.input[filename] = removePrefixIndent(data) } - _, errs := compile(t, tc.input) + files, errs := compile(t, tc.input) actualErrs := make([]string, len(errs)) for i := range errs { @@ -3413,6 +3424,17 @@ func TestLinkerValidation(t *testing.T) { } } + // Make sure protobuf-go can handle resulting files + if len(errs) == 0 && len(files) > 0 { + err := convertToProtoreflectDescriptors(files) + if tc.expectProtodescFail { + // This is a known case where it cannot handle the file. + require.Error(t, err) + } else { + require.NoError(t, err) + } + } + // parse with protoc passProtoc := testByProtoc(t, tc.input, tc.inputOrder) if tc.expectedErr == "" { @@ -3984,3 +4006,28 @@ func testByProtoc(t *testing.T, files map[string]string, fileNames []string) boo require.NoError(t, err) return true } + +func convertToProtoreflectDescriptors(files linker.Files) error { + allFiles := make(map[string]*descriptorpb.FileDescriptorProto, len(files)) + addFileDescriptorsToMap(files, allFiles) + fileSlice := make([]*descriptorpb.FileDescriptorProto, 0, len(allFiles)) + for _, fileProto := range allFiles { + fileSlice = append(fileSlice, fileProto) + } + _, err := protodesc.NewFiles(&descriptorpb.FileDescriptorSet{File: fileSlice}) + return err +} + +func addFileDescriptorsToMap[F protoreflect.FileDescriptor](files []F, allFiles map[string]*descriptorpb.FileDescriptorProto) { + for _, file := range files { + if _, exists := allFiles[file.Path()]; exists { + continue // already added this one + } + allFiles[file.Path()] = protoutil.ProtoFromFileDescriptor(file) + deps := make([]protoreflect.FileDescriptor, file.Imports().Len()) + for i := 0; i < file.Imports().Len(); i++ { + deps[i] = file.Imports().Get(i).FileDescriptor + } + addFileDescriptorsToMap(deps, allFiles) + } +} diff --git a/options/options.go b/options/options.go index 1e822531..266a35c1 100644 --- a/options/options.go +++ b/options/options.go @@ -41,6 +41,7 @@ import ( "github.com/bufbuild/protocompile/ast" "github.com/bufbuild/protocompile/internal" + "github.com/bufbuild/protocompile/internal/messageset" "github.com/bufbuild/protocompile/linker" "github.com/bufbuild/protocompile/parser" "github.com/bufbuild/protocompile/reporter" @@ -659,7 +660,7 @@ func (interp *interpreter) checkFieldUsage( node ast.Node, ) error { msgOpts, _ := fld.ContainingMessage().Options().(*descriptorpb.MessageOptions) - if msgOpts.GetMessageSetWireFormat() && !canSerializeMessageSets() { + if msgOpts.GetMessageSetWireFormat() && !messageset.CanSupportMessageSets() { err := interp.reporter.HandleErrorf(interp.nodeInfo(node), "field %q may not be used in an option: it uses 'message set wire format' legacy proto1 feature which is not supported", fld.FullName()) if err != nil { return err