Skip to content

Commit

Permalink
Add an optional implementation of streams using generics (#7057)
Browse files Browse the repository at this point in the history
  • Loading branch information
aarongable authored May 3, 2024
1 parent a87e923 commit bb9882e
Show file tree
Hide file tree
Showing 21 changed files with 462 additions and 906 deletions.
4 changes: 2 additions & 2 deletions balancer/grpclb/grpc_lb_v1/load_balancer_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

103 changes: 84 additions & 19 deletions cmd/protoc-gen-go-grpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,13 @@ func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.

g.P("// This is a compile-time assertion to ensure that this generated file")
g.P("// is compatible with the grpc package it is being compiled against.")
g.P("// Requires gRPC-Go v1.62.0 or later.")
g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion8")) // When changing, update version number above.
if *useGenericStreams {
g.P("// Requires gRPC-Go v1.64.0 or later.")
g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion9"))
} else {
g.P("// Requires gRPC-Go v1.62.0 or later.")
g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion8")) // When changing, update version number above.
}
g.P()
for _, service := range file.Services {
genService(gen, file, g, service)
Expand Down Expand Up @@ -299,12 +304,27 @@ func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
s += "*" + g.QualifiedGoIdent(method.Output.GoIdent)
} else {
s += method.Parent.GoName + "_" + method.GoName + "Client"
if *useGenericStreams {
s += clientStreamInterface(g, method)
} else {
s += method.Parent.GoName + "_" + method.GoName + "Client"
}
}
s += ", error)"
return s
}

func clientStreamInterface(g *protogen.GeneratedFile, method *protogen.Method) string {
typeParam := g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent)
if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() {
return g.QualifiedGoIdent(grpcPackage.Ident("BidiStreamingClient")) + "[" + typeParam + "]"
} else if method.Desc.IsStreamingClient() {
return g.QualifiedGoIdent(grpcPackage.Ident("ClientStreamingClient")) + "[" + typeParam + "]"
} else { // i.e. if method.Desc.IsStreamingServer()
return g.QualifiedGoIdent(grpcPackage.Ident("ServerStreamingClient")) + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]"
}
}

func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, index int) {
service := method.Parent
fmSymbol := helper.formatFullMethodSymbol(service, method)
Expand All @@ -323,11 +343,17 @@ func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
g.P()
return
}
streamType := unexport(service.GoName) + method.GoName + "Client"

streamImpl := unexport(service.GoName) + method.GoName + "Client"
if *useGenericStreams {
typeParam := g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent)
streamImpl = g.QualifiedGoIdent(grpcPackage.Ident("GenericClientStream")) + "[" + typeParam + "]"
}

serviceDescVar := service.GoName + "_ServiceDesc"
g.P("stream, err := c.cc.NewStream(ctx, &", serviceDescVar, ".Streams[", index, `], `, fmSymbol, `, cOpts...)`)
g.P("if err != nil { return nil, err }")
g.P("x := &", streamType, "{stream}")
g.P("x := &", streamImpl, "{ClientStream: stream}")
if !method.Desc.IsStreamingClient() {
g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }")
g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
Expand All @@ -336,11 +362,20 @@ func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
g.P("}")
g.P()

// Auxiliary types aliases, for backwards compatibility.
if *useGenericStreams {
g.P("// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.")
g.P("type ", service.GoName, "_", method.GoName, "Client = ", clientStreamInterface(g, method))
g.P()
return
}

// Stream auxiliary types and methods, if we're not taking advantage of the
// pre-implemented generic types and their methods.
genSend := method.Desc.IsStreamingClient()
genRecv := method.Desc.IsStreamingServer()
genCloseAndRecv := !method.Desc.IsStreamingServer()

// Stream auxiliary types and methods.
g.P("type ", service.GoName, "_", method.GoName, "Client interface {")
if genSend {
g.P("Send(*", method.Input.GoIdent, ") error")
Expand All @@ -355,27 +390,27 @@ func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
g.P("}")
g.P()

g.P("type ", streamType, " struct {")
g.P("type ", streamImpl, " struct {")
g.P(grpcPackage.Ident("ClientStream"))
g.P("}")
g.P()

if genSend {
g.P("func (x *", streamType, ") Send(m *", method.Input.GoIdent, ") error {")
g.P("func (x *", streamImpl, ") Send(m *", method.Input.GoIdent, ") error {")
g.P("return x.ClientStream.SendMsg(m)")
g.P("}")
g.P()
}
if genRecv {
g.P("func (x *", streamType, ") Recv() (*", method.Output.GoIdent, ", error) {")
g.P("func (x *", streamImpl, ") Recv() (*", method.Output.GoIdent, ", error) {")
g.P("m := new(", method.Output.GoIdent, ")")
g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
g.P("return m, nil")
g.P("}")
g.P()
}
if genCloseAndRecv {
g.P("func (x *", streamType, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
g.P("func (x *", streamImpl, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
g.P("m := new(", method.Output.GoIdent, ")")
g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
Expand All @@ -396,7 +431,11 @@ func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string
reqArgs = append(reqArgs, "*"+g.QualifiedGoIdent(method.Input.GoIdent))
}
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
reqArgs = append(reqArgs, method.Parent.GoName+"_"+method.GoName+"Server")
if *useGenericStreams {
reqArgs = append(reqArgs, serverStreamInterface(g, method))
} else {
reqArgs = append(reqArgs, method.Parent.GoName+"_"+method.GoName+"Server")
}
}
return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
}
Expand Down Expand Up @@ -442,6 +481,17 @@ func genServiceDesc(file *protogen.File, g *protogen.GeneratedFile, serviceDescV
g.P()
}

func serverStreamInterface(g *protogen.GeneratedFile, method *protogen.Method) string {
typeParam := g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent)
if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() {
return g.QualifiedGoIdent(grpcPackage.Ident("BidiStreamingServer")) + "[" + typeParam + "]"
} else if method.Desc.IsStreamingClient() {
return g.QualifiedGoIdent(grpcPackage.Ident("ClientStreamingServer")) + "[" + typeParam + "]"
} else { // i.e. if method.Desc.IsStreamingServer()
return g.QualifiedGoIdent(grpcPackage.Ident("ServerStreamingServer")) + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]"
}
}

func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, hnameFuncNameFormatter func(string) string) string {
service := method.Parent
hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)
Expand All @@ -464,23 +514,38 @@ func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
g.P()
return hname
}
streamType := unexport(service.GoName) + method.GoName + "Server"

streamImpl := unexport(service.GoName) + method.GoName + "Server"
if *useGenericStreams {
typeParam := g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent)
streamImpl = g.QualifiedGoIdent(grpcPackage.Ident("GenericServerStream")) + "[" + typeParam + "]"
}

g.P("func ", hnameFuncNameFormatter(hname), "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
if !method.Desc.IsStreamingClient() {
g.P("m := new(", method.Input.GoIdent, ")")
g.P("if err := stream.RecvMsg(m); err != nil { return err }")
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{stream})")
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamImpl, "{ServerStream: stream})")
} else {
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{stream})")
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamImpl, "{ServerStream: stream})")
}
g.P("}")
g.P()

// Auxiliary types aliases, for backwards compatibility.
if *useGenericStreams {
g.P("// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.")
g.P("type ", service.GoName, "_", method.GoName, "Server = ", serverStreamInterface(g, method))
g.P()
return hname
}

// Stream auxiliary types and methods, if we're not taking advantage of the
// pre-implemented generic types and their methods.
genSend := method.Desc.IsStreamingServer()
genSendAndClose := !method.Desc.IsStreamingServer()
genRecv := method.Desc.IsStreamingClient()

// Stream auxiliary types and methods.
g.P("type ", service.GoName, "_", method.GoName, "Server interface {")
if genSend {
g.P("Send(*", method.Output.GoIdent, ") error")
Expand All @@ -495,25 +560,25 @@ func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
g.P("}")
g.P()

g.P("type ", streamType, " struct {")
g.P("type ", streamImpl, " struct {")
g.P(grpcPackage.Ident("ServerStream"))
g.P("}")
g.P()

if genSend {
g.P("func (x *", streamType, ") Send(m *", method.Output.GoIdent, ") error {")
g.P("func (x *", streamImpl, ") Send(m *", method.Output.GoIdent, ") error {")
g.P("return x.ServerStream.SendMsg(m)")
g.P("}")
g.P()
}
if genSendAndClose {
g.P("func (x *", streamType, ") SendAndClose(m *", method.Output.GoIdent, ") error {")
g.P("func (x *", streamImpl, ") SendAndClose(m *", method.Output.GoIdent, ") error {")
g.P("return x.ServerStream.SendMsg(m)")
g.P("}")
g.P()
}
if genRecv {
g.P("func (x *", streamType, ") Recv() (*", method.Input.GoIdent, ", error) {")
g.P("func (x *", streamImpl, ") Recv() (*", method.Input.GoIdent, ", error) {")
g.P("m := new(", method.Input.GoIdent, ")")
g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }")
g.P("return m, nil")
Expand Down
2 changes: 2 additions & 0 deletions cmd/protoc-gen-go-grpc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
const version = "1.3.0"

var requireUnimplemented *bool
var useGenericStreams *bool

func main() {
showVersion := flag.Bool("version", false, "print the version and exit")
Expand All @@ -55,6 +56,7 @@ func main() {

var flags flag.FlagSet
requireUnimplemented = flags.Bool("require_unimplemented_servers", true, "set to false to match legacy behavior")
useGenericStreams = flags.Bool("use_generic_streams_experimental", false, "set to true to use generic types for streaming client and server objects; this flag is EXPERIMENTAL and may be changed or removed in a future release")

protogen.Options{
ParamFunc: flags.Set,
Expand Down
2 changes: 1 addition & 1 deletion cmd/protoc-gen-go-grpc/protoc-gen-go-grpc_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ popd

protoc \
--go-grpc_out="${TEMPDIR}" \
--go-grpc_opt=paths=source_relative \
--go-grpc_opt=paths=source_relative,use_generic_streams_experimental=true \
"examples/route_guide/routeguide/route_guide.proto"

GOLDENFILE="examples/route_guide/routeguide/route_guide_grpc.pb.go"
Expand Down
62 changes: 12 additions & 50 deletions credentials/alts/internal/proto/grpc_gcp/handshaker_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit bb9882e

Please sign in to comment.