From 042668a122473ce999501f793d939cc50bf60405 Mon Sep 17 00:00:00 2001 From: Oleg Bespalov Date: Thu, 14 Sep 2023 10:49:08 +0200 Subject: [PATCH 1/3] Refactoring of the connection params This minor refactoring is a pre-requisition of the following extraction of the logic of parsing a gRPC metadata. --- grpc/client.go | 111 +-------------------------------------------- grpc/params.go | 121 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 122 insertions(+), 110 deletions(-) diff --git a/grpc/client.go b/grpc/client.go index 7607d9c..8a56939 100644 --- a/grpc/client.go +++ b/grpc/client.go @@ -14,7 +14,6 @@ import ( "github.com/grafana/xk6-grpc/lib/netext/grpcext" "go.k6.io/k6/js/common" "go.k6.io/k6/js/modules" - "go.k6.io/k6/lib/types" "github.com/dop251/goja" "github.com/jhump/protoreflect/desc" @@ -205,13 +204,13 @@ func buildTLSConfigFromMap(parentConfig *tls.Config, tlsConfigMap map[string]int } // Connect is a block dial to the gRPC server at the given address (host:port) -func (c *Client) Connect(addr string, params map[string]interface{}) (bool, error) { +func (c *Client) Connect(addr string, params goja.Value) (bool, error) { state := c.vu.State() if state == nil { return false, common.NewInitContextError("connecting to a gRPC server in the init context is not supported") } - p, err := c.parseConnectParams(params) + p, err := newConnectParams(c.vu, params) if err != nil { return false, fmt.Errorf("invalid grpc.connect() parameters: %w", err) } @@ -418,112 +417,6 @@ func (c *Client) convertToMethodInfo(fdset *descriptorpb.FileDescriptorSet) ([]M return rtn, nil } -type connectParams struct { - IsPlaintext bool - UseReflectionProtocol bool - Timeout time.Duration - MaxReceiveSize int64 - MaxSendSize int64 - TLS map[string]interface{} -} - -func (c *Client) parseConnectParams(raw map[string]interface{}) (connectParams, error) { - params := connectParams{ - IsPlaintext: false, - UseReflectionProtocol: false, - Timeout: time.Minute, - MaxReceiveSize: 0, - MaxSendSize: 0, - } - for k, v := range raw { - switch k { - case "plaintext": - var ok bool - params.IsPlaintext, ok = v.(bool) - if !ok { - return params, fmt.Errorf("invalid plaintext value: '%#v', it needs to be boolean", v) - } - case "timeout": - var err error - params.Timeout, err = types.GetDurationValue(v) - if err != nil { - return params, fmt.Errorf("invalid timeout value: %w", err) - } - case "reflect": - var ok bool - params.UseReflectionProtocol, ok = v.(bool) - if !ok { - return params, fmt.Errorf("invalid reflect value: '%#v', it needs to be boolean", v) - } - case "maxReceiveSize": - var ok bool - params.MaxReceiveSize, ok = v.(int64) - if !ok { - return params, fmt.Errorf("invalid maxReceiveSize value: '%#v', it needs to be an integer", v) - } - if params.MaxReceiveSize < 0 { - return params, fmt.Errorf("invalid maxReceiveSize value: '%#v, it needs to be a positive integer", v) - } - case "maxSendSize": - var ok bool - params.MaxSendSize, ok = v.(int64) - if !ok { - return params, fmt.Errorf("invalid maxSendSize value: '%#v', it needs to be an integer", v) - } - if params.MaxSendSize < 0 { - return params, fmt.Errorf("invalid maxSendSize value: '%#v, it needs to be a positive integer", v) - } - case "tls": - if err := parseConnectTLSParam(¶ms, v); err != nil { - return params, err - } - default: - return params, fmt.Errorf("unknown connect param: %q", k) - } - } - return params, nil -} - -func parseConnectTLSParam(params *connectParams, v interface{}) error { - var ok bool - params.TLS, ok = v.(map[string]interface{}) - - if !ok { - return fmt.Errorf("invalid tls value: '%#v', expected (optional) keys: cert, key, password, and cacerts", v) - } - // optional map keys below - if cert, certok := params.TLS["cert"]; certok { - if _, ok = cert.(string); !ok { - return fmt.Errorf("invalid tls cert value: '%#v', it needs to be a PEM formatted string", v) - } - } - if key, keyok := params.TLS["key"]; keyok { - if _, ok = key.(string); !ok { - return fmt.Errorf("invalid tls key value: '%#v', it needs to be a PEM formatted string", v) - } - } - if pass, passok := params.TLS["password"]; passok { - if _, ok = pass.(string); !ok { - return fmt.Errorf("invalid tls password value: '%#v', it needs to be a string", v) - } - } - if cacerts, cacertsok := params.TLS["cacerts"]; cacertsok { - var cacertsArray []interface{} - if cacertsArray, ok = cacerts.([]interface{}); ok { - for _, cacertsArrayEntry := range cacertsArray { - if _, ok = cacertsArrayEntry.(string); !ok { - return fmt.Errorf("invalid tls cacerts value: '%#v',"+ - " it needs to be a string or an array of PEM formatted strings", v) - } - } - } else if _, ok = cacerts.(string); !ok { - return fmt.Errorf("invalid tls cacerts value: '%#v',"+ - " it needs to be a string or an array of PEM formatted strings", v) - } - } - return nil -} - func walkFileDescriptors(seen map[string]struct{}, fd *desc.FileDescriptor) []*descriptorpb.FileDescriptorProto { fds := []*descriptorpb.FileDescriptorProto{} diff --git a/grpc/params.go b/grpc/params.go index 14e3af9..6c08c7d 100644 --- a/grpc/params.go +++ b/grpc/params.go @@ -31,7 +31,7 @@ func newCallParams(vu modules.VU, input goja.Value) (*callParams, error) { TagsAndMeta: vu.State().Tags.GetCurrentValues(), } - if input == nil || goja.IsUndefined(input) || goja.IsNull(input) { + if common.IsNullish(input) { return result, nil } @@ -46,6 +46,7 @@ func newCallParams(vu modules.VU, input goja.Value) (*callParams, error) { if !ok { return result, errors.New("metadata must be an object with key-value pairs") } + for hk, kv := range rawHeaders { var val string @@ -99,3 +100,121 @@ func (p *callParams) SetSystemTags(state *lib.State, addr string, methodName str p.TagsAndMeta.SetSystemTagOrMetaIfEnabled(state.Options.SystemTags, metrics.TagName, methodName) } } + +// connectParams is the parameters that can be passed to a gRPC connect call. +type connectParams struct { + IsPlaintext bool + UseReflectionProtocol bool + Timeout time.Duration + MaxReceiveSize int64 + MaxSendSize int64 + TLS map[string]interface{} +} + +func newConnectParams(vu modules.VU, input goja.Value) (*connectParams, error) { + result := &connectParams{ + IsPlaintext: false, + UseReflectionProtocol: false, + Timeout: time.Minute, + MaxReceiveSize: 0, + MaxSendSize: 0, + } + + if common.IsNullish(input) { + return result, nil + } + + rt := vu.Runtime() + params := input.ToObject(rt) + + for _, k := range params.Keys() { + v := params.Get(k).Export() + + switch k { + case "plaintext": + var ok bool + result.IsPlaintext, ok = v.(bool) + if !ok { + return result, fmt.Errorf("invalid plaintext value: '%#v', it needs to be boolean", v) + } + case "timeout": + var err error + result.Timeout, err = types.GetDurationValue(v) + if err != nil { + return result, fmt.Errorf("invalid timeout value: %w", err) + } + case "reflect": + var ok bool + result.UseReflectionProtocol, ok = v.(bool) + if !ok { + return result, fmt.Errorf("invalid reflect value: '%#v', it needs to be boolean", v) + } + case "maxReceiveSize": + var ok bool + result.MaxReceiveSize, ok = v.(int64) + if !ok { + return result, fmt.Errorf("invalid maxReceiveSize value: '%#v', it needs to be an integer", v) + } + if result.MaxReceiveSize < 0 { + return result, fmt.Errorf("invalid maxReceiveSize value: '%#v, it needs to be a positive integer", v) + } + case "maxSendSize": + var ok bool + result.MaxSendSize, ok = v.(int64) + if !ok { + return result, fmt.Errorf("invalid maxSendSize value: '%#v', it needs to be an integer", v) + } + if result.MaxSendSize < 0 { + return result, fmt.Errorf("invalid maxSendSize value: '%#v, it needs to be a positive integer", v) + } + case "tls": + if err := parseConnectTLSParam(result, v); err != nil { + return result, err + } + default: + return result, fmt.Errorf("unknown connect param: %q", k) + } + } + + return result, nil +} + +func parseConnectTLSParam(params *connectParams, v interface{}) error { + var ok bool + params.TLS, ok = v.(map[string]interface{}) + + if !ok { + return fmt.Errorf("invalid tls value: '%#v', expected (optional) keys: cert, key, password, and cacerts", v) + } + // optional map keys below + if cert, certok := params.TLS["cert"]; certok { + if _, ok = cert.(string); !ok { + return fmt.Errorf("invalid tls cert value: '%#v', it needs to be a PEM formatted string", v) + } + } + if key, keyok := params.TLS["key"]; keyok { + if _, ok = key.(string); !ok { + return fmt.Errorf("invalid tls key value: '%#v', it needs to be a PEM formatted string", v) + } + } + if pass, passok := params.TLS["password"]; passok { + if _, ok = pass.(string); !ok { + return fmt.Errorf("invalid tls password value: '%#v', it needs to be a string", v) + } + } + if cacerts, cacertsok := params.TLS["cacerts"]; cacertsok { + var cacertsArray []interface{} + if cacertsArray, ok = cacerts.([]interface{}); ok { + for _, cacertsArrayEntry := range cacertsArray { + if _, ok = cacertsArrayEntry.(string); !ok { + return fmt.Errorf("invalid tls cacerts value: '%#v',"+ + " it needs to be a string or an array of PEM formatted strings", v) + } + } + } else if _, ok = cacerts.(string); !ok { + return fmt.Errorf("invalid tls cacerts value: '%#v',"+ + " it needs to be a string or an array of PEM formatted strings", v) + } + } + return nil +} From c5913adef6b17ae12abca2d2bd925f2f4f16871f Mon Sep 17 00:00:00 2001 From: Oleg Bespalov Date: Thu, 14 Sep 2023 11:02:09 +0200 Subject: [PATCH 2/3] Extracting metadata construction logic Extracting logic + few more tests --- grpc/params.go | 64 +++++++++++++++++++++++++++++---------------- grpc/params_test.go | 46 ++++++++++++++++++++++++++++++-- 2 files changed, 85 insertions(+), 25 deletions(-) diff --git a/grpc/params.go b/grpc/params.go index 6c08c7d..04983ad 100644 --- a/grpc/params.go +++ b/grpc/params.go @@ -41,31 +41,12 @@ func newCallParams(vu modules.VU, input goja.Value) (*callParams, error) { for _, k := range params.Keys() { switch k { case "metadata": - v := params.Get(k).Export() - rawHeaders, ok := v.(map[string]interface{}) - if !ok { - return result, errors.New("metadata must be an object with key-value pairs") + md, err := newMetadata(params.Get(k)) + if err != nil { + return result, fmt.Errorf("invalid metadata param: %w", err) } - for hk, kv := range rawHeaders { - var val string - - // The gRPC spec defines that Binary-valued keys end in -bin - // https://grpc.io/docs/what-is-grpc/core-concepts/#metadata - if strings.HasSuffix(hk, "-bin") { - var binVal []byte - if binVal, ok = kv.([]byte); !ok { - return result, fmt.Errorf("metadata %q value must be binary", hk) - } - - // https://github.com/grpc/grpc-go/blob/v1.57.0/Documentation/grpc-metadata.md#storing-binary-data-in-metadata - val = string(binVal) - } else if val, ok = kv.(string); !ok { - return result, fmt.Errorf("metadata %q value must be a string", hk) - } - - result.Metadata.Append(hk, val) - } + result.Metadata = md case "tags": if err := common.ApplyCustomUserTags(rt, &result.TagsAndMeta, params.Get(k)); err != nil { return result, fmt.Errorf("metric tags: %w", err) @@ -85,6 +66,43 @@ func newCallParams(vu modules.VU, input goja.Value) (*callParams, error) { return result, nil } +// newMetadata constructs a metadata.MD from the input value. +func newMetadata(input goja.Value) (metadata.MD, error) { + md := metadata.New(nil) + + if common.IsNullish(input) { + return md, nil + } + + v := input.Export() + + rawHeaders, ok := v.(map[string]interface{}) + if !ok { + return md, errors.New("must be an object with key-value pairs") + } + + for hk, kv := range rawHeaders { + var val string + // The gRPC spec defines that Binary-valued keys end in -bin + // https://grpc.io/docs/what-is-grpc/core-concepts/#metadata + if strings.HasSuffix(hk, "-bin") { + var binVal []byte + if binVal, ok = kv.([]byte); !ok { + return md, fmt.Errorf("%q value must be binary", hk) + } + + // https://github.com/grpc/grpc-go/blob/v1.57.0/Documentation/grpc-metadata.md#storing-binary-data-in-metadata + val = string(binVal) + } else if val, ok = kv.(string); !ok { + return md, fmt.Errorf("%q value must be a string", hk) + } + + md.Append(hk, val) + } + + return md, nil +} + // SetSystemTags sets the system tags for the call. func (p *callParams) SetSystemTags(state *lib.State, addr string, methodName string) { if state.Options.SystemTags.Has(metrics.TagURL) { diff --git a/grpc/params_test.go b/grpc/params_test.go index 641cb4f..c43c6d9 100644 --- a/grpc/params_test.go +++ b/grpc/params_test.go @@ -13,10 +13,11 @@ import ( "go.k6.io/k6/js/modulestest" "go.k6.io/k6/lib" "go.k6.io/k6/metrics" + "google.golang.org/grpc/metadata" "gopkg.in/guregu/null.v3" ) -func TestParamsInvalidInput(t *testing.T) { +func TestCallParamsInvalidInput(t *testing.T) { t.Parallel() testCases := []struct { @@ -39,6 +40,11 @@ func TestParamsInvalidInput(t *testing.T) { JSON: `{ timeout: "please" }`, ErrContains: `invalid duration`, }, + { + Name: "InvalidMetadata", + JSON: `{ metadata: "lorem" }`, + ErrContains: `invalid metadata param: must be an object with key-value pairs`, + }, } for _, tc := range testCases { @@ -56,7 +62,43 @@ func TestParamsInvalidInput(t *testing.T) { } } -func TestParamsTimeOutParse(t *testing.T) { +func TestCallParamsMetadata(t *testing.T) { + t.Parallel() + + testCases := []struct { + Name string + JSON string + ExpectedMetadata metadata.MD + }{ + { + Name: "EmptyMetadata", + JSON: `{}`, + ExpectedMetadata: metadata.New(nil), + }, + { + Name: "Metadata", + JSON: `{metadata: {foo: "bar", baz: "qux"}}`, + ExpectedMetadata: metadata.New(map[string]string{"foo": "bar", "baz": "qux"}), + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + + testRuntime, params := newParamsTestRuntime(t, tc.JSON) + + p, err := newCallParams(testRuntime.VU, params) + + require.NoError(t, err) + assert.Equal(t, tc.ExpectedMetadata, p.Metadata) + }) + } +} + +func TestCallParamsTimeOutParse(t *testing.T) { t.Parallel() testCases := []struct { From 350679e8e6a398ddf09d56df22b99931ff0c8fc9 Mon Sep 17 00:00:00 2001 From: Oleg Bespalov Date: Thu, 14 Sep 2023 13:33:17 +0200 Subject: [PATCH 3/3] Reflection metadata --- grpc/client.go | 4 ++++ grpc/client_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++ grpc/params.go | 11 ++++++++++- grpc/teststate_test.go | 6 ++++++ 4 files changed, 64 insertions(+), 1 deletion(-) diff --git a/grpc/client.go b/grpc/client.go index 8a56939..ef46ffc 100644 --- a/grpc/client.go +++ b/grpc/client.go @@ -21,6 +21,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" @@ -257,6 +258,9 @@ func (c *Client) Connect(addr string, params goja.Value) (bool, error) { if !p.UseReflectionProtocol { return true, nil } + + ctx = metadata.NewOutgoingContext(ctx, p.ReflectionMetadata) + fdset, err := c.conn.Reflect(ctx) if err != nil { return false, err diff --git a/grpc/client_test.go b/grpc/client_test.go index 4f8585c..f76429f 100644 --- a/grpc/client_test.go +++ b/grpc/client_test.go @@ -16,6 +16,7 @@ import ( "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/known/wrapperspb" + "gopkg.in/guregu/null.v3" "github.com/golang/protobuf/ptypes/any" "github.com/golang/protobuf/ptypes/wrappers" @@ -1128,6 +1129,49 @@ func TestDebugStat(t *testing.T) { } } +func TestClientConnectionReflectMetadata(t *testing.T) { + t.Parallel() + + ts := newTestState(t) + + reflection.Register(ts.httpBin.ServerGRPC) + + initString := codeBlock{ + code: `var client = new grpc.Client();`, + } + vuString := codeBlock{ + code: `client.connect("GRPCBIN_ADDR", {reflect: true, reflectMetadata: {"x-test": "custom-header-for-reflection"}})`, + } + + val, err := ts.Run(initString.code) + assertResponse(t, initString, err, val, ts) + + ts.ToVUContext() + + // this should trigger logging of the outgoing gRPC metadata + ts.VU.State().Options.HTTPDebug = null.NewString("full", true) + + val, err = ts.Run(vuString.code) + assertResponse(t, vuString, err, val, ts) + + entries := ts.loggerHook.Drain() + + // since we enable debug logging, we should see the metadata in the logs + foundReflectionCall := false + for _, entry := range entries { + if strings.Contains(entry.Message, "ServerReflection/ServerReflectionInfo") { + foundReflectionCall = true + + // check that the metadata is present + assert.Contains(t, entry.Message, "x-test: custom-header-for-reflection") + // check that user-agent header is present + assert.Contains(t, entry.Message, "user-agent: k6-test") + } + } + + assert.True(t, foundReflectionCall, "expected to find a reflection call in the logs, but didn't") +} + func TestClientLoadProto(t *testing.T) { t.Parallel() diff --git a/grpc/params.go b/grpc/params.go index 04983ad..d9fb913 100644 --- a/grpc/params.go +++ b/grpc/params.go @@ -123,19 +123,21 @@ func (p *callParams) SetSystemTags(state *lib.State, addr string, methodName str type connectParams struct { IsPlaintext bool UseReflectionProtocol bool + ReflectionMetadata metadata.MD Timeout time.Duration MaxReceiveSize int64 MaxSendSize int64 TLS map[string]interface{} } -func newConnectParams(vu modules.VU, input goja.Value) (*connectParams, error) { +func newConnectParams(vu modules.VU, input goja.Value) (*connectParams, error) { //nolint:gocognit result := &connectParams{ IsPlaintext: false, UseReflectionProtocol: false, Timeout: time.Minute, MaxReceiveSize: 0, MaxSendSize: 0, + ReflectionMetadata: metadata.New(nil), } if common.IsNullish(input) { @@ -167,6 +169,13 @@ func newConnectParams(vu modules.VU, input goja.Value) (*connectParams, error) { if !ok { return result, fmt.Errorf("invalid reflect value: '%#v', it needs to be boolean", v) } + case "reflectMetadata": + md, err := newMetadata(params.Get(k)) + if err != nil { + return result, fmt.Errorf("invalid reflectMetadata param: %w", err) + } + + result.ReflectionMetadata = md case "maxReceiveSize": var ok bool result.MaxReceiveSize, ok = v.(int64) diff --git a/grpc/teststate_test.go b/grpc/teststate_test.go index f3bc669..fbc8ae4 100644 --- a/grpc/teststate_test.go +++ b/grpc/teststate_test.go @@ -15,6 +15,7 @@ import ( "go.k6.io/k6/js/modulestest" "go.k6.io/k6/lib" "go.k6.io/k6/lib/fsext" + "go.k6.io/k6/lib/testutils" "go.k6.io/k6/lib/testutils/httpmultibin" "go.k6.io/k6/metrics" "gopkg.in/guregu/null.v3" @@ -78,6 +79,7 @@ type testState struct { httpBin *httpmultibin.HTTPMultiBin samples chan metrics.SampleContainer logger logrus.FieldLogger + loggerHook *testutils.SimpleLogrusHook callRecorder *callRecorder } @@ -114,6 +116,9 @@ func newTestState(t *testing.T) testState { logger.SetLevel(logrus.InfoLevel) logger.Out = io.Discard + hook := testutils.NewLogHook() + logger.AddHook(hook) + recorder := &callRecorder{ calls: make([]string, 0), } @@ -123,6 +128,7 @@ func newTestState(t *testing.T) testState { httpBin: tb, samples: samples, logger: logger, + loggerHook: hook, callRecorder: recorder, }