diff --git a/proto/compiler.bzl b/proto/compiler.bzl index 3601c7bf85..3ad7dca3d6 100644 --- a/proto/compiler.bzl +++ b/proto/compiler.bzl @@ -87,14 +87,24 @@ def go_proto_compile(go, compiler, protos, imports, importpath): continue proto_paths[path] = src - out = go.declare_file( - go, - path = importpath + "/" + src.basename[:-len(".proto")], - ext = compiler.internal.suffix, - ) - go_srcs.append(out) + if compiler.internal.suffixes: + for suffix in compiler.internal.suffixes: + out = go.declare_file( + go, + path = importpath + "/" + src.basename[:-len(".proto")], + ext = suffix, + ) + go_srcs.append(out) + else: + out = go.declare_file( + go, + path = importpath + "/" + src.basename[:-len(".proto")], + ext = compiler.internal.suffix, + ) + go_srcs.append(out) + if outpath == None: - outpath = out.dirname[:-len(importpath)] + outpath = go_srcs[0].dirname[:-len(importpath)] transitive_descriptor_sets = depset(direct = [], transitive = desc_sets) @@ -174,6 +184,7 @@ def _go_proto_compiler_impl(ctx): internal = struct( options = ctx.attr.options, suffix = ctx.attr.suffix, + suffixes = ctx.attr.suffixes, protoc = ctx.executable._protoc, go_protoc = ctx.executable._go_protoc, plugin = ctx.executable.plugin, @@ -190,6 +201,7 @@ _go_proto_compiler = rule( "deps": attr.label_list(providers = [GoLibrary]), "options": attr.string_list(), "suffix": attr.string(default = ".pb.go"), + "suffixes": attr.string_list(), "valid_archive": attr.bool(default = True), "import_path_option": attr.bool(default = False), "plugin": attr.label( diff --git a/tests/core/go_proto_library/BUILD.bazel b/tests/core/go_proto_library/BUILD.bazel index 0c1f8e23c6..c38953b049 100644 --- a/tests/core/go_proto_library/BUILD.bazel +++ b/tests/core/go_proto_library/BUILD.bazel @@ -32,6 +32,14 @@ proto_library( srcs = ["grpc.proto"], ) +proto_library( + name = "enum_proto", + srcs = ["enum.proto"], + deps = [ + "@com_google_protobuf//:descriptor_proto", + ], +) + # embed_test go_proto_library( name = "embed_go_proto", @@ -199,6 +207,27 @@ go_proto_library( protos = [":grpc_proto"], ) +# compilers with multiple suffixes +go_test( + name = "compilers_multi_suffix_test", + srcs = ["compiler_multi_suffix_test.go"], + deps = [ + ":compilers_multi_suffix", + ], +) + +go_proto_library( + name = "compilers_multi_suffix", + compilers = ["//tests/core/go_proto_library/compilers:dbenum_compiler"], + importpath = "github.com/bazelbuild/rules_go/tests/core/go_proto_library/enum", + protos = [":enum_proto"], + deps = [ + "@com_github_gogo_protobuf//proto", + "@com_github_gogo_protobuf//protoc-gen-gogo/descriptor", + "@com_github_gogo_protobuf//types", + ], +) + # adjusted_import_test # TODO(#1851): uncomment when Bazel 0.22.0 is the minimum version. # go_test( diff --git a/tests/core/go_proto_library/compiler_multi_suffix_test.go b/tests/core/go_proto_library/compiler_multi_suffix_test.go new file mode 100644 index 0000000000..bd43aab2da --- /dev/null +++ b/tests/core/go_proto_library/compiler_multi_suffix_test.go @@ -0,0 +1,38 @@ +/* Copyright 2019 The Bazel Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package multi_suffix_compiler + +import ( + "testing" + + "github.com/bazelbuild/rules_go/tests/core/go_proto_library/enum" +) + +func use(interface{}) {} + +func TestMultiSuffixCompiler(t *testing.T) { + // just make sure types and generated functions exist + v := enum.Enum_BYTES + expected := "bytes_type" + if v.String() != expected { + panic(v.String()) + } + v = enum.Enum_INT32 + expected = "INT32" + if v.String() != expected { + panic(v.String()) + } +} diff --git a/tests/core/go_proto_library/compilers/BUILD.bazel b/tests/core/go_proto_library/compilers/BUILD.bazel new file mode 100644 index 0000000000..aeaedd563d --- /dev/null +++ b/tests/core/go_proto_library/compilers/BUILD.bazel @@ -0,0 +1,51 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library") +load( + "//proto:compiler.bzl", + "go_proto_compiler", +) +load( + "//proto/wkt:well_known_types.bzl", + "GOGO_WELL_KNOWN_TYPE_REMAPS", +) + +go_library( + name = "dbenum", + srcs = ["dbenums.go"], + importpath = "github.com/bazelbuild/rules_go/tests/core/go_proto_library/compilers", + visibility = ["//visibility:private"], + deps = [ + "@com_github_gogo_protobuf//proto", + "@com_github_gogo_protobuf//protoc-gen-gogo/descriptor", + "@com_github_gogo_protobuf//protoc-gen-gogo/generator", + ], +) + +go_library( + name = "protoc-gen-dbenum_lib", + srcs = ["main.go"], + importpath = "github.com/bazelbuild/rules_go/tests/core/go_proto_library/compilers", + visibility = ["//visibility:private"], + deps = [ + "//tests/core/go_proto_library/compilers:dbenum", + "@com_github_gogo_protobuf//protoc-gen-gogo/plugin", + "@com_github_gogo_protobuf//vanity", + "@com_github_gogo_protobuf//vanity/command", + ], +) + +go_binary( + name = "protoc-gen-dbenum-compiler", + embed = [":protoc-gen-dbenum_lib"], + visibility = ["//visibility:private"], +) + +go_proto_compiler( + name = "dbenum_compiler", + options = GOGO_WELL_KNOWN_TYPE_REMAPS, + plugin = "//tests/core/go_proto_library/compilers:protoc-gen-dbenum-compiler", + suffixes = [ + "_dbenum.pb.go", + ".pb.go", + ], + visibility = ["//visibility:public"], +) diff --git a/tests/core/go_proto_library/compilers/dbenums.go b/tests/core/go_proto_library/compilers/dbenums.go new file mode 100644 index 0000000000..5af5e90177 --- /dev/null +++ b/tests/core/go_proto_library/compilers/dbenums.go @@ -0,0 +1,189 @@ +package dbenum + +import ( + "bytes" + "strings" + "text/template" + + "github.com/gogo/protobuf/proto" + "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + pb "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +func init() { + generator.RegisterPlugin(NewGenerator()) +} + +type Generator struct { + *generator.Generator + generator.PluginImports + write bool +} + +func NewGenerator() *Generator { + return &Generator{} +} + +func (g *Generator) Name() string { + return "dbenum" +} + +func (g *Generator) Init(gen *generator.Generator) { + g.Generator = gen +} + +func (g *Generator) GenerateImports(file *generator.FileDescriptor) { +} + +func (g *Generator) Generate(file *generator.FileDescriptor) { + for _, enum := range file.Enums() { + g.enumHelper(enum) + } + g.writeTrailer(file.Enums()) +} + +func (g *Generator) Write() bool { + return g.write +} + +const initTmpl = ` +` + +func (g *Generator) writeTrailer(enums []*generator.EnumDescriptor) { + type desc struct { + PackageName string + TypeName string + LowerCaseTypeName string + } + if !g.write { + return + } + tmpl := template.Must(template.New("db_enum_trailer").Parse(initTmpl)) + g.P("func init() {") + for _, e := range enums { + if !HasDBEnum(e.Value) { + continue + } + pkg := e.File().GetPackage() + if pkg != "" { + pkg += "." + } + tp := generator.CamelCaseSlice(e.TypeName()) + var buf bytes.Buffer + tmpl.Execute(&buf, desc{ + PackageName: pkg + tp, + TypeName: tp, + LowerCaseTypeName: strings.ToLower(tp), + }) + g.P(buf.String()) + } + g.P("}") +} + +func (g *Generator) enumHelper(enum *generator.EnumDescriptor) { + type anEnum struct { + PBName string + DBName string + } + type typeDesc struct { + TypeName string + TypeNamespace string + LowerCaseTypeName string + Found map[int32]bool + Names []anEnum + AllNames []anEnum + } + tp := generator.CamelCaseSlice(enum.TypeName()) + namespace := tp + enumTypeName := enum.TypeName() + if len(enumTypeName) > 1 { // This is a nested enum. + names := enumTypeName[:len(enumTypeName)-1] + // See https://protobuf.dev/reference/go/go-generated/#enum + namespace = generator.CamelCaseSlice(names) + } + t := typeDesc{ + TypeName: tp, + TypeNamespace: namespace, + LowerCaseTypeName: strings.ToLower(tp), + Found: make(map[int32]bool), + } + for _, v := range enum.Value { + enumValue := v.GetNumber() + if validDbEnum, dbName := getDbEnum(v); validDbEnum { + names := anEnum{PBName: v.GetName(), DBName: dbName} + t.AllNames = append(t.AllNames, names) + // Skip enums that are aliased where one value has already been processed. + if t.Found[enumValue] { + continue + } + t.Found[enumValue] = true + t.Names = append(t.Names, names) + } else { + t.Found[enumValue] = true + } + } + if len(t.AllNames) == 0 { + return + } + g.write = true + tmpl := template.Must(template.New("db_enum").Parse(tmpl)) + var buf bytes.Buffer + tmpl.Execute(&buf, t) + g.P(buf.String()) +} + +var E_DbEnum = &proto.ExtensionDesc{ + ExtendedType: (*descriptor.EnumValueOptions)(nil), + ExtensionType: (*string)(nil), + Field: 5002, + Name: "tests.core.go_proto_library.enum", + Tag: "bytes,5002,opt,name=db_enum", +} + +func getDbEnum(value *pb.EnumValueDescriptorProto) (bool, string) { + if value == nil || value.Options == nil { + return false, "" + } + EDbEnum := E_DbEnum + v, err := proto.GetExtension(value.Options, EDbEnum) + if err != nil { + return false, "" + } + strPtr := v.(*string) + if strPtr == nil { + return false, "" + } + return true, *strPtr +} + +// HasDBEnum returns if there is DBEnums extensions defined in given enums. +func HasDBEnum(enums []*pb.EnumValueDescriptorProto) bool { + for _, enum := range enums { + if validDbEnum, _ := getDbEnum(enum); validDbEnum { + return true + } + } + return false +} + +const tmpl = ` + +var {{ .LowerCaseTypeName }}ToStringValue = ` + + `map[{{ .TypeName }}]string { {{ range $names := .Names }} + {{ $.TypeNamespace }}_{{ $names.PBName }}: ` + + `"{{ $names.DBName }}",{{ end }} +} + + +// String implements the stringer interface and should produce the same output +// that is inserted into the db. +func (v {{ .TypeName }}) String() string { + if val, ok := {{ .LowerCaseTypeName }}ToStringValue[v]; ok { + return val + } else if int(v) == 0 { + return "null" + } else { + return proto.EnumName({{ .TypeName }}_name, int32(v)) + } +}` diff --git a/tests/core/go_proto_library/compilers/main.go b/tests/core/go_proto_library/compilers/main.go new file mode 100644 index 0000000000..6b092c4b19 --- /dev/null +++ b/tests/core/go_proto_library/compilers/main.go @@ -0,0 +1,53 @@ +package main + +import ( + plugin "github.com/gogo/protobuf/protoc-gen-gogo/plugin" + "github.com/gogo/protobuf/vanity" + "github.com/gogo/protobuf/vanity/command" + + dbenum "github.com/bazelbuild/rules_go/tests/core/go_proto_library/compilers" +) + +func main() { + req := command.Read() + files := req.GetProtoFile() + files = vanity.FilterFiles(files, vanity.NotGoogleProtobufDescriptorProto) + + vanity.ForEachFile(files, vanity.TurnOffGoStringerAll) + vanity.ForEachFile(files, vanity.TurnOffGoEnumStringerAll) + + resp := command.Generate(req) + command.Write(resp) + + baseFiles := req.FileToGenerate + + dbenumGenerator := dbenum.NewGenerator() + req = onlyEnumFiles(req, baseFiles) + if len(req.FileToGenerate) > 0 { + resp = command.GeneratePlugin(req, dbenumGenerator, "_dbenum.pb.go") + command.Write(resp) + } +} + +func onlyEnumFiles( + req *plugin.CodeGeneratorRequest, baseFiles []string, +) *plugin.CodeGeneratorRequest { + // Find out files that contains enum value with dbenum extension. + dbEnumFiles := make(map[string]bool) + for _, file := range req.GetProtoFile() { + for _, enum := range file.EnumType { + if dbenum.HasDBEnum(enum.Value) { + dbEnumFiles[*file.Name] = true + break + } + } + } + enumFilesToGenerate := make([]string, 0, len(baseFiles)) + for _, file := range baseFiles { + if dbEnumFiles[file] { + enumFilesToGenerate = append(enumFilesToGenerate, file) + } + } + req.FileToGenerate = enumFilesToGenerate + return req +} diff --git a/tests/core/go_proto_library/enum.proto b/tests/core/go_proto_library/enum.proto new file mode 100644 index 0000000000..ffed47293c --- /dev/null +++ b/tests/core/go_proto_library/enum.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; +package tests.core.go_proto_library.enum; +option go_package = "github.com/bazelbuild/rules_go/tests/core/go_proto_library/enum"; + +import "google/protobuf/descriptor.proto"; + +extend google.protobuf.EnumValueOptions { + optional string db_enum = 5002; +} + + +enum Enum { + INVALID = 0; + BYTES = 1 [(tests.core.go_proto_library.enum.db_enum) = "bytes_type"]; + INT32 = 2; +}