Skip to content
This repository has been archived by the owner on May 21, 2024. It is now read-only.

Introducing the reflection metadata param #51

Merged
merged 3 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 6 additions & 109 deletions grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ import (
"github.com/grafana/xk6-grpc/lib/netext/grpcext"
"go.k6.io/k6/js/common"
"go.k6.io/k6/js/modules"
"go.k6.io/k6/lib/types"

"github.com/dop251/goja"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/desc/protoparse"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
Expand Down Expand Up @@ -205,13 +205,13 @@ func buildTLSConfigFromMap(parentConfig *tls.Config, tlsConfigMap map[string]int
}

// Connect is a block dial to the gRPC server at the given address (host:port)
func (c *Client) Connect(addr string, params map[string]interface{}) (bool, error) {
func (c *Client) Connect(addr string, params goja.Value) (bool, error) {
state := c.vu.State()
if state == nil {
return false, common.NewInitContextError("connecting to a gRPC server in the init context is not supported")
}

p, err := c.parseConnectParams(params)
p, err := newConnectParams(c.vu, params)
if err != nil {
return false, fmt.Errorf("invalid grpc.connect() parameters: %w", err)
}
Expand Down Expand Up @@ -258,6 +258,9 @@ func (c *Client) Connect(addr string, params map[string]interface{}) (bool, erro
if !p.UseReflectionProtocol {
return true, nil
}

ctx = metadata.NewOutgoingContext(ctx, p.ReflectionMetadata)

fdset, err := c.conn.Reflect(ctx)
if err != nil {
return false, err
Expand Down Expand Up @@ -418,112 +421,6 @@ func (c *Client) convertToMethodInfo(fdset *descriptorpb.FileDescriptorSet) ([]M
return rtn, nil
}

type connectParams struct {
IsPlaintext bool
UseReflectionProtocol bool
Timeout time.Duration
MaxReceiveSize int64
MaxSendSize int64
TLS map[string]interface{}
}

func (c *Client) parseConnectParams(raw map[string]interface{}) (connectParams, error) {
params := connectParams{
IsPlaintext: false,
UseReflectionProtocol: false,
Timeout: time.Minute,
MaxReceiveSize: 0,
MaxSendSize: 0,
}
for k, v := range raw {
switch k {
case "plaintext":
var ok bool
params.IsPlaintext, ok = v.(bool)
if !ok {
return params, fmt.Errorf("invalid plaintext value: '%#v', it needs to be boolean", v)
}
case "timeout":
var err error
params.Timeout, err = types.GetDurationValue(v)
if err != nil {
return params, fmt.Errorf("invalid timeout value: %w", err)
}
case "reflect":
var ok bool
params.UseReflectionProtocol, ok = v.(bool)
if !ok {
return params, fmt.Errorf("invalid reflect value: '%#v', it needs to be boolean", v)
}
case "maxReceiveSize":
var ok bool
params.MaxReceiveSize, ok = v.(int64)
if !ok {
return params, fmt.Errorf("invalid maxReceiveSize value: '%#v', it needs to be an integer", v)
}
if params.MaxReceiveSize < 0 {
return params, fmt.Errorf("invalid maxReceiveSize value: '%#v, it needs to be a positive integer", v)
}
case "maxSendSize":
var ok bool
params.MaxSendSize, ok = v.(int64)
if !ok {
return params, fmt.Errorf("invalid maxSendSize value: '%#v', it needs to be an integer", v)
}
if params.MaxSendSize < 0 {
return params, fmt.Errorf("invalid maxSendSize value: '%#v, it needs to be a positive integer", v)
}
case "tls":
if err := parseConnectTLSParam(&params, v); err != nil {
return params, err
}
default:
return params, fmt.Errorf("unknown connect param: %q", k)
}
}
return params, nil
}

func parseConnectTLSParam(params *connectParams, v interface{}) error {
var ok bool
params.TLS, ok = v.(map[string]interface{})

if !ok {
return fmt.Errorf("invalid tls value: '%#v', expected (optional) keys: cert, key, password, and cacerts", v)
}
// optional map keys below
if cert, certok := params.TLS["cert"]; certok {
if _, ok = cert.(string); !ok {
return fmt.Errorf("invalid tls cert value: '%#v', it needs to be a PEM formatted string", v)
}
}
if key, keyok := params.TLS["key"]; keyok {
if _, ok = key.(string); !ok {
return fmt.Errorf("invalid tls key value: '%#v', it needs to be a PEM formatted string", v)
}
}
if pass, passok := params.TLS["password"]; passok {
if _, ok = pass.(string); !ok {
return fmt.Errorf("invalid tls password value: '%#v', it needs to be a string", v)
}
}
if cacerts, cacertsok := params.TLS["cacerts"]; cacertsok {
var cacertsArray []interface{}
if cacertsArray, ok = cacerts.([]interface{}); ok {
for _, cacertsArrayEntry := range cacertsArray {
if _, ok = cacertsArrayEntry.(string); !ok {
return fmt.Errorf("invalid tls cacerts value: '%#v',"+
" it needs to be a string or an array of PEM formatted strings", v)
}
}
} else if _, ok = cacerts.(string); !ok {
return fmt.Errorf("invalid tls cacerts value: '%#v',"+
" it needs to be a string or an array of PEM formatted strings", v)
}
}
return nil
}

func walkFileDescriptors(seen map[string]struct{}, fd *desc.FileDescriptor) []*descriptorpb.FileDescriptorProto {
fds := []*descriptorpb.FileDescriptorProto{}

Expand Down
44 changes: 44 additions & 0 deletions grpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/known/wrapperspb"
"gopkg.in/guregu/null.v3"

"github.com/golang/protobuf/ptypes/any"
"github.com/golang/protobuf/ptypes/wrappers"
Expand Down Expand Up @@ -1128,6 +1129,49 @@ func TestDebugStat(t *testing.T) {
}
}

func TestClientConnectionReflectMetadata(t *testing.T) {
t.Parallel()

ts := newTestState(t)

reflection.Register(ts.httpBin.ServerGRPC)

initString := codeBlock{
code: `var client = new grpc.Client();`,
}
vuString := codeBlock{
code: `client.connect("GRPCBIN_ADDR", {reflect: true, reflectMetadata: {"x-test": "custom-header-for-reflection"}})`,
}

val, err := ts.Run(initString.code)
assertResponse(t, initString, err, val, ts)

ts.ToVUContext()

// this should trigger logging of the outgoing gRPC metadata
ts.VU.State().Options.HTTPDebug = null.NewString("full", true)

val, err = ts.Run(vuString.code)
assertResponse(t, vuString, err, val, ts)

entries := ts.loggerHook.Drain()

// since we enable debug logging, we should see the metadata in the logs
foundReflectionCall := false
for _, entry := range entries {
if strings.Contains(entry.Message, "ServerReflection/ServerReflectionInfo") {
foundReflectionCall = true

// check that the metadata is present
assert.Contains(t, entry.Message, "x-test: custom-header-for-reflection")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. We do check that specific RPC call contains the headers. The method that you're referring to works for the less complex cases.

// check that user-agent header is present
assert.Contains(t, entry.Message, "user-agent: k6-test")
}
}

assert.True(t, foundReflectionCall, "expected to find a reflection call in the logs, but didn't")
}

func TestClientLoadProto(t *testing.T) {
t.Parallel()

Expand Down
Loading
Loading