Skip to content
This repository has been archived by the owner on May 21, 2024. It is now read-only.

This is a port of the PR changes found here: #25

Merged
merged 2 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 151 additions & 3 deletions grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package grpc

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -106,6 +109,101 @@ func (c *Client) LoadProtoset(protosetPath string) ([]MethodInfo, error) {
return c.convertToMethodInfo(fdset)
}

// Note: this function was lifted from `lib/options.go`
func decryptPrivateKey(key, password []byte) ([]byte, error) {
block, _ := pem.Decode(key)
if block == nil {
return nil, fmt.Errorf("failed to decode PEM key")
}

blockType := block.Type
if blockType == "ENCRYPTED PRIVATE KEY" {
return nil, fmt.Errorf("encrypted pkcs8 formatted key is not supported")
}
/*
Even though `DecryptPEMBlock` has been deprecated since 1.16.x it is still
being used here because it is deprecated due to it not supporting *good* cryptography
ultimately though we want to support something so we will be using it for now.
*/
decryptedKey, err := x509.DecryptPEMBlock(block, password) //nolint:staticcheck
if err != nil {
return nil, err
}
key = pem.EncodeToMemory(&pem.Block{
Type: blockType,
Bytes: decryptedKey,
})
return key, nil
}

func buildTLSConfig(parentConfig *tls.Config, certificate, key []byte, caCertificates [][]byte) (*tls.Config, error) {
var cp *x509.CertPool
if len(caCertificates) > 0 {
cp, _ = x509.SystemCertPool()
for i, caCert := range caCertificates {
if ok := cp.AppendCertsFromPEM(caCert); !ok {
return nil, fmt.Errorf("failed to append ca certificate [%d] from PEM", i)
}
}
}

// Ignoring 'TLS MinVersion is too low' because this tls.Config will inherit MinValue and MaxValue
// from the vu state tls.Config

//nolint:golint,gosec
tlsCfg := &tls.Config{
CipherSuites: parentConfig.CipherSuites,
InsecureSkipVerify: parentConfig.InsecureSkipVerify,
MinVersion: parentConfig.MinVersion,
MaxVersion: parentConfig.MaxVersion,
Renegotiation: parentConfig.Renegotiation,
RootCAs: cp,
}
if len(certificate) > 0 && len(key) > 0 {
cert, err := tls.X509KeyPair(certificate, key)
if err != nil {
return nil, fmt.Errorf("failed to append certificate from PEM: %w", err)
}
tlsCfg.Certificates = []tls.Certificate{cert}
}
return tlsCfg, nil
}

func buildTLSConfigFromMap(parentConfig *tls.Config, tlsConfigMap map[string]interface{}) (*tls.Config, error) {
var cert, key, pass []byte
var ca [][]byte
var err error
if certstr, ok := tlsConfigMap["cert"].(string); ok {
cert = []byte(certstr)
}
if keystr, ok := tlsConfigMap["key"].(string); ok {
key = []byte(keystr)
}
if passwordStr, ok := tlsConfigMap["password"].(string); ok {
pass = []byte(passwordStr)
if len(pass) > 0 {
if key, err = decryptPrivateKey(key, pass); err != nil {
return nil, err
}
}
}
if cas, ok := tlsConfigMap["cacerts"]; ok {
var caCertsArray []interface{}
if caCertsArray, ok = cas.([]interface{}); ok {
ca = make([][]byte, len(caCertsArray))
for i, entry := range caCertsArray {
var entryStr string
if entryStr, ok = entry.(string); ok {
ca[i] = []byte(entryStr)
}
}
} else if caCertStr, caCertStrOk := cas.(string); caCertStrOk {
ca = [][]byte{[]byte(caCertStr)}
}
}
return buildTLSConfig(parentConfig, cert, key, ca)
}

// Connect is a block dial to the gRPC server at the given address (host:port)
func (c *Client) Connect(addr string, params map[string]interface{}) (bool, error) {
state := c.vu.State()
Expand All @@ -122,10 +220,16 @@ func (c *Client) Connect(addr string, params map[string]interface{}) (bool, erro

var tcred credentials.TransportCredentials
if !p.IsPlaintext {
tlsCfg := state.TLSConfig.Clone()
var tlsCfg *tls.Config
if len(p.TLS) > 0 {
if tlsCfg, err = buildTLSConfigFromMap(state.TLSConfig.Clone(), p.TLS); err != nil {
return false, err
}
} else {
tlsCfg = state.TLSConfig.Clone()
}
tlsCfg.NextProtos = []string{"h2"}

// TODO(rogchap): Would be good to add support for custom RootCAs (self signed)
tcred = credentials.NewTLS(tlsCfg)
} else {
tcred = insecure.NewCredentials()
Expand Down Expand Up @@ -322,6 +426,7 @@ type connectParams struct {
Timeout time.Duration
MaxReceiveSize int64
MaxSendSize int64
TLS map[string]interface{}
}

func (c *Client) parseConnectParams(raw map[string]interface{}) (connectParams, error) {
Expand Down Expand Up @@ -370,14 +475,57 @@ func (c *Client) parseConnectParams(raw map[string]interface{}) (connectParams,
if params.MaxSendSize < 0 {
return params, fmt.Errorf("invalid maxSendSize value: '%#v, it needs to be a positive integer", v)
}

case "tls":
if err := parseConnectTLSParam(&params, v); err != nil {
return params, err
}
default:
return params, fmt.Errorf("unknown connect param: %q", k)
}
}
return params, nil
}

func parseConnectTLSParam(params *connectParams, v interface{}) error {
var ok bool
params.TLS, ok = v.(map[string]interface{})

if !ok {
return fmt.Errorf("invalid tls value: '%#v', expected (optional) keys: cert, key, password, and cacerts", v)
}
// optional map keys below
if cert, certok := params.TLS["cert"]; certok {
if _, ok = cert.(string); !ok {
return fmt.Errorf("invalid tls cert value: '%#v', it needs to be a PEM formatted string", v)
}
}
if key, keyok := params.TLS["key"]; keyok {
if _, ok = key.(string); !ok {
return fmt.Errorf("invalid tls key value: '%#v', it needs to be a PEM formatted string", v)
}
}
if pass, passok := params.TLS["password"]; passok {
if _, ok = pass.(string); !ok {
return fmt.Errorf("invalid tls password value: '%#v', it needs to be a string", v)
}
}
if cacerts, cacertsok := params.TLS["cacerts"]; cacertsok {
var cacertsArray []interface{}
if cacertsArray, ok = cacerts.([]interface{}); ok {
for _, cacertsArrayEntry := range cacertsArray {
if _, ok = cacertsArrayEntry.(string); !ok {
return fmt.Errorf("invalid tls cacerts value: '%#v',"+
" it needs to be a string or string[] of PEM formatted strings", v)
}
}
} else if _, ok = cacerts.(string); !ok {
return fmt.Errorf("invalid tls cacerts value: '%#v',"+
" it needs to be a string or string[] of PEM formatted strings", v)
}
}
return nil
}

func walkFileDescriptors(seen map[string]struct{}, fd *desc.FileDescriptor) []*descriptorpb.FileDescriptorProto {
fds := []*descriptorpb.FileDescriptorProto{}

Expand Down
Loading