Skip to content

Commit

Permalink
Add optional client TLS auth
Browse files Browse the repository at this point in the history
Signed-off-by: Tristan Colgate <tristan@qubit.com>
  • Loading branch information
Tristan Colgate committed Jun 7, 2019
1 parent 7f2a49d commit 1cdefd7
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 11 deletions.
45 changes: 39 additions & 6 deletions cmd/agent/app/reporter/grpc/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
package grpc

import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"strings"

grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
Expand All @@ -39,6 +43,8 @@ type ConnBuilder struct {
TLS bool
TLSCA string
TLSServerName string
TLSCert string
TLSKey string

DiscoveryMinPeers int
Notifier discovery.Notifier
Expand All @@ -56,20 +62,47 @@ func (b *ConnBuilder) CreateConnection(logger *zap.Logger) (*grpc.ClientConn, er
var dialTarget string
if b.TLS { // user requested a secure connection
logger.Info("Agent requested secure grpc connection to collector(s)")
var creds credentials.TransportCredentials
var err error
var certPool *x509.CertPool
if len(b.TLSCA) == 0 { // no truststore given, use SystemCertPool
pool, err := systemCertPool()
certPool, err = systemCertPool()
if err != nil {
return nil, err
}
creds = credentials.NewClientTLSFromCert(pool, b.TLSServerName)
} else { // setup user specified truststore
var err error
creds, err = credentials.NewClientTLSFromFile(b.TLSCA, b.TLSServerName)
caPEM, err := ioutil.ReadFile(b.TLSCA)
if err != nil {
return nil, err
return nil, fmt.Errorf("reading client CA failed, %v", err)
}

certPool = x509.NewCertPool()
if !certPool.AppendCertsFromPEM(caPEM) {
return nil, fmt.Errorf("building client CA failed, %v", err)
}
}

tlsCfg := &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: certPool,
ServerName: b.TLSServerName,
}

if (b.TLSKey == "" || b.TLSCert == "") &&
(b.TLSKey != "" || b.TLSCert != "") {
return nil, fmt.Errorf("for client auth, both client certificate and key must be supplied")
}

if b.TLSKey != "" && b.TLSCert != "" {
tlsCert, err := tls.LoadX509KeyPair(b.TLSCert, b.TLSKey)
if err != nil {
return nil, fmt.Errorf("could not load server TLS cert and key, %v", err)
}

logger.Info("client TLS authentication enabled")
tlsCfg.Certificates = []tls.Certificate{tlsCert}
}

creds := credentials.NewTLS(tlsCfg)
dialOptions = append(dialOptions, grpc.WithTransportCredentials(creds))
} else { // insecure connection
logger.Info("Agent requested insecure grpc connection to collector(s)")
Expand Down
6 changes: 6 additions & 0 deletions cmd/agent/app/reporter/grpc/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ const (
defaultMaxRetry = 3
collectorTLS = gRPCPrefix + "tls"
collectorTLSCA = gRPCPrefix + "tls.ca"
agentCert = gRPCPrefix + "tls.cert"
agentKey = gRPCPrefix + "tls.key"
collectorTLSServerName = gRPCPrefix + "tls.server-name"
discoveryMinPeers = gRPCPrefix + "discovery.min-peers"
)
Expand All @@ -39,6 +41,8 @@ func AddFlags(flags *flag.FlagSet) {
flags.Bool(collectorTLS, false, "Enable TLS.")
flags.String(collectorTLSCA, "", "Path to a TLS CA file. (default use the systems truststore)")
flags.String(collectorTLSServerName, "", "Override the TLS server name.")
flags.String(agentCert, "", "Path to a TLS client certificate file.")
flags.String(agentKey, "", "Path to a TLS client key file.")
flags.Int(discoveryMinPeers, 3, "Max number of collectors to which the agent will try to connect at any given time")
}

Expand All @@ -52,6 +56,8 @@ func (b *ConnBuilder) InitFromViper(v *viper.Viper) *ConnBuilder {
b.TLS = v.GetBool(collectorTLS)
b.TLSCA = v.GetString(collectorTLSCA)
b.TLSServerName = v.GetString(collectorTLSServerName)
b.TLSCert = v.GetString(agentCert)
b.TLSKey = v.GetString(agentKey)
b.DiscoveryMinPeers = v.GetInt(discoveryMinPeers)
return b
}
5 changes: 5 additions & 0 deletions cmd/collector/app/builder/builder_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ const (
collectorGRPCTLS = "collector.grpc.tls"
collectorGRPCCert = "collector.grpc.tls.cert"
collectorGRPCKey = "collector.grpc.tls.key"
collectorGRPCClientCA = "collector.grpc.tls.client.ca"
collectorZipkinHTTPort = "collector.zipkin.http-port"
collectorZipkinAllowedOrigins = "collector.zipkin.allowed-origins"
collectorZipkinAllowedHeaders = "collector.zipkin.allowed-headers"
Expand All @@ -53,6 +54,8 @@ type CollectorOptions struct {
CollectorGRPCTLS bool
// CollectorGRPCCert is the path to a TLS certificate file for the server
CollectorGRPCCert string
// CollectorGRPCClientCA is the path to a TLS certificate file for // authenticating clients
CollectorGRPCClientCA string
// CollectorGRPCKey is the path to a TLS key file for the server
CollectorGRPCKey string
// CollectorZipkinHTTPPort is the port that the Zipkin collector service listens in on for http requests
Expand All @@ -74,6 +77,7 @@ func AddFlags(flags *flag.FlagSet) {
flags.Bool(collectorGRPCTLS, false, "Enable TLS")
flags.String(collectorGRPCCert, "", "Path to TLS certificate file")
flags.String(collectorGRPCKey, "", "Path to TLS key file")
flags.String(collectorGRPCClientCA, "", "Path to TLS certificates for authenticating clients")
flags.String(collectorZipkinAllowedOrigins, "*", "Allowed origins for the Zipkin collector service, default accepts all")
flags.String(collectorZipkinAllowedHeaders, "content-type", "Allowed headers for the Zipkin collector service, default content-type")
}
Expand All @@ -87,6 +91,7 @@ func (cOpts *CollectorOptions) InitFromViper(v *viper.Viper) *CollectorOptions {
cOpts.CollectorGRPCPort = v.GetInt(collectorGRPCPort)
cOpts.CollectorGRPCTLS = v.GetBool(collectorGRPCTLS)
cOpts.CollectorGRPCCert = v.GetString(collectorGRPCCert)
cOpts.CollectorGRPCClientCA = v.GetString(collectorGRPCClientCA)
cOpts.CollectorGRPCKey = v.GetString(collectorGRPCKey)
cOpts.CollectorZipkinHTTPPort = v.GetInt(collectorZipkinHTTPort)
cOpts.CollectorZipkinAllowedOrigins = v.GetString(collectorZipkinAllowedOrigins)
Expand Down
33 changes: 28 additions & 5 deletions cmd/collector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
package main

import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -195,13 +198,33 @@ func startGRPCServer(
if opts.CollectorGRPCCert == "" || opts.CollectorGRPCKey == "" {
return nil, fmt.Errorf("you requested TLS but configuration does not include a path to cert and/or key")
}
creds, err := credentials.NewServerTLSFromFile(
opts.CollectorGRPCCert,
opts.CollectorGRPCKey,
)

tlsCfg := &tls.Config{
MinVersion: tls.VersionTLS12,
}

tlsCert, err := tls.LoadX509KeyPair(opts.CollectorGRPCCert, opts.CollectorGRPCKey)
if err != nil {
return nil, fmt.Errorf("failed to load TLS keys: %s", err)
return nil, fmt.Errorf("could not load server TLS cert and key, %v", err)
}

tlsCfg.Certificates = []tls.Certificate{tlsCert}

if opts.CollectorGRPCClientCA != "" {
caPEM, err := ioutil.ReadFile(opts.CollectorGRPCClientCA)
if err != nil {
return nil, fmt.Errorf("load TLS client CA, %v", err)
}

certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(caPEM) {
return nil, fmt.Errorf("building TLS client CA, %v", err)
}
tlsCfg.ClientCAs = certPool
tlsCfg.ClientAuth = tls.RequireAndVerifyClientCert
}

creds := credentials.NewTLS(tlsCfg)
server = grpc.NewServer(grpc.Creds(creds))
} else { // server without TLS
server = grpc.NewServer()
Expand Down

0 comments on commit 1cdefd7

Please sign in to comment.