Skip to content

Commit

Permalink
snapshot: add TLS support to HalfCloser interface
Browse files Browse the repository at this point in the history
  • Loading branch information
mikemorris committed Aug 6, 2019
1 parent 558ad5e commit 3f55083
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 27 deletions.
22 changes: 6 additions & 16 deletions agent/consul/snapshot_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ package consul

import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -193,7 +192,7 @@ RESPOND:
func SnapshotRPC(connPool *pool.ConnPool, dc string, addr net.Addr, useTLS bool,
args *structs.SnapshotRequest, in io.Reader, reply *structs.SnapshotResponse) (io.ReadCloser, error) {

conn, err := connPool.DialTimeout(dc, addr, 10*time.Second, useTLS)
conn, hc, err := connPool.DialTimeout(dc, addr, 10*time.Second, useTLS)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -226,21 +225,12 @@ func SnapshotRPC(connPool *pool.ConnPool, dc string, addr net.Addr, useTLS bool,
// the other side that they are done reading the stream, since we don't
// know the size in advance. This saves us from having to buffer just to
// calculate the size.
switch connType := conn.(type) {
case *tls.Conn:
if tlsConn, ok := conn.(*tls.Conn); ok {
if err := tlsConn.CloseWrite(); err != nil {
return nil, fmt.Errorf("failed to half close TLS snapshot connection: %v", err)
}
}
case *net.TCPConn:
if tcpConn, ok := conn.(*net.TCPConn); ok {
if err := tcpConn.CloseWrite(); err != nil {
return nil, fmt.Errorf("failed to half close non-TLS snapshot connection: %v", err)
}
if hc != nil {
if err := hc.CloseWrite(); err != nil {
return nil, fmt.Errorf("failed to half close snapshot connection: %v", err)
}
default:
return nil, fmt.Errorf("unexpected Conn type: %T", connType)
} else {
return nil, fmt.Errorf("snapshot connection requires half-close support")
}

// Pull the header decoded as msgpack. The caller can continue to read
Expand Down
2 changes: 1 addition & 1 deletion agent/consul/status_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func insecureRPCClient(s *Server, c tlsutil.Config) (rpc.ClientCodec, error) {
if wrapper == nil {
return nil, err
}
conn, err := pool.DialTimeoutWithRPCType(s.config.Datacenter, addr, nil, time.Second, true, wrapper, pool.RPCTLSInsecure)
conn, _, err := pool.DialTimeoutWithRPCType(s.config.Datacenter, addr, nil, time.Second, true, wrapper, pool.RPCTLSInsecure)
if err != nil {
return nil, err
}
Expand Down
34 changes: 24 additions & 10 deletions agent/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pool

import (
"container/list"
"crypto/tls"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -257,9 +258,15 @@ func (p *ConnPool) acquire(dc string, addr net.Addr, version int, useTLS bool) (
return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
}

// HalfCloser is an interface that exposes a TCP half-close without exposing
// the underlying TLS or raw TCP connection.
type HalfCloser interface {
CloseWrite() error
}

// DialTimeout is used to establish a raw connection to the given server, with
// given connection timeout. It also writes RPCTLS as the first byte.
func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration, useTLS bool) (net.Conn, error) {
func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration, useTLS bool) (net.Conn, HalfCloser, error) {
p.once.Do(p.init)

return DialTimeoutWithRPCType(dc, addr, p.SrcAddr, timeout, useTLS || p.ForceTLS, p.TLSConfigurator.OutgoingRPCWrapper(), RPCTLS)
Expand All @@ -269,54 +276,61 @@ func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration,
// server, with given connection timeout. It also writes RPCTLSInsecure as the
// first byte to indicate that the client cannot provide a certificate. This is
// so far only used for AutoEncrypt.Sign.
func (p *ConnPool) DialTimeoutInsecure(dc string, addr net.Addr, timeout time.Duration, wrapper tlsutil.DCWrapper) (net.Conn, error) {
func (p *ConnPool) DialTimeoutInsecure(dc string, addr net.Addr, timeout time.Duration, wrapper tlsutil.DCWrapper) (net.Conn, HalfCloser, error) {
p.once.Do(p.init)

if wrapper == nil {
return nil, fmt.Errorf("wrapper cannot be nil")
return nil, nil, fmt.Errorf("wrapper cannot be nil")
}

return DialTimeoutWithRPCType(dc, addr, p.SrcAddr, timeout, true, wrapper, RPCTLSInsecure)
}

func DialTimeoutWithRPCType(dc string, addr net.Addr, src *net.TCPAddr, timeout time.Duration, useTLS bool, wrapper tlsutil.DCWrapper, rpcType RPCType) (net.Conn, error) {
func DialTimeoutWithRPCType(dc string, addr net.Addr, src *net.TCPAddr, timeout time.Duration, useTLS bool, wrapper tlsutil.DCWrapper, rpcType RPCType) (net.Conn, HalfCloser, error) {
// Try to dial the conn
d := &net.Dialer{LocalAddr: src, Timeout: timeout}
conn, err := d.Dial("tcp", addr.String())
if err != nil {
return nil, err
return nil, nil, err
}

var hc HalfCloser

// Cast to TCPConn
if tcp, ok := conn.(*net.TCPConn); ok {
tcp.SetKeepAlive(true)
tcp.SetNoDelay(true)
hc = tcp
}

// Check if TLS is enabled
if (useTLS) && wrapper != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcType)}); err != nil {
conn.Close()
return nil, err
return nil, nil, err
}

// Wrap the connection in a TLS client
tlsConn, err := wrapper(dc, conn)
if err != nil {
conn.Close()
return nil, err
return nil, nil, err
}
conn = tlsConn

if tlsConn, ok := conn.(*tls.Conn); ok {
hc = tlsConn
}
}

return conn, nil
return conn, hc, nil
}

// getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(dc string, addr net.Addr, version int, useTLS bool) (*Conn, error) {
// Get a new, raw connection.
conn, err := p.DialTimeout(dc, addr, defaultDialTimeout, useTLS)
conn, _, err := p.DialTimeout(dc, addr, defaultDialTimeout, useTLS)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -423,7 +437,7 @@ func (p *ConnPool) RPC(dc string, addr net.Addr, version int, method string, use
// connection if it is not being reused.
func (p *ConnPool) rpcInsecure(dc string, addr net.Addr, method string, args interface{}, reply interface{}) error {
var codec rpc.ClientCodec
conn, err := p.DialTimeoutInsecure(dc, addr, 1*time.Second, p.TLSConfigurator.OutgoingRPCWrapper())
conn, _, err := p.DialTimeoutInsecure(dc, addr, 1*time.Second, p.TLSConfigurator.OutgoingRPCWrapper())
if err != nil {
return fmt.Errorf("rpcinsecure error establishing connection: %v", err)
}
Expand Down

0 comments on commit 3f55083

Please sign in to comment.