diff --git a/agent/consul/snapshot_endpoint.go b/agent/consul/snapshot_endpoint.go index 85ef8a3047e4..f00ba1931a64 100644 --- a/agent/consul/snapshot_endpoint.go +++ b/agent/consul/snapshot_endpoint.go @@ -9,7 +9,6 @@ package consul import ( "bytes" - "crypto/tls" "errors" "fmt" "io" @@ -193,7 +192,7 @@ RESPOND: func SnapshotRPC(connPool *pool.ConnPool, dc string, addr net.Addr, useTLS bool, args *structs.SnapshotRequest, in io.Reader, reply *structs.SnapshotResponse) (io.ReadCloser, error) { - conn, err := connPool.DialTimeout(dc, addr, 10*time.Second, useTLS) + conn, hc, err := connPool.DialTimeout(dc, addr, 10*time.Second, useTLS) if err != nil { return nil, err } @@ -226,21 +225,12 @@ func SnapshotRPC(connPool *pool.ConnPool, dc string, addr net.Addr, useTLS bool, // the other side that they are done reading the stream, since we don't // know the size in advance. This saves us from having to buffer just to // calculate the size. - switch connType := conn.(type) { - case *tls.Conn: - if tlsConn, ok := conn.(*tls.Conn); ok { - if err := tlsConn.CloseWrite(); err != nil { - return nil, fmt.Errorf("failed to half close TLS snapshot connection: %v", err) - } - } - case *net.TCPConn: - if tcpConn, ok := conn.(*net.TCPConn); ok { - if err := tcpConn.CloseWrite(); err != nil { - return nil, fmt.Errorf("failed to half close non-TLS snapshot connection: %v", err) - } + if hc != nil { + if err := hc.CloseWrite(); err != nil { + return nil, fmt.Errorf("failed to half close snapshot connection: %v", err) } - default: - return nil, fmt.Errorf("unexpected Conn type: %T", connType) + } else { + return nil, fmt.Errorf("snapshot connection requires half-close support") } // Pull the header decoded as msgpack. The caller can continue to read diff --git a/agent/consul/status_endpoint_test.go b/agent/consul/status_endpoint_test.go index b35419bfb795..0403e6320578 100644 --- a/agent/consul/status_endpoint_test.go +++ b/agent/consul/status_endpoint_test.go @@ -37,7 +37,7 @@ func insecureRPCClient(s *Server, c tlsutil.Config) (rpc.ClientCodec, error) { if wrapper == nil { return nil, err } - conn, err := pool.DialTimeoutWithRPCType(s.config.Datacenter, addr, nil, time.Second, true, wrapper, pool.RPCTLSInsecure) + conn, _, err := pool.DialTimeoutWithRPCType(s.config.Datacenter, addr, nil, time.Second, true, wrapper, pool.RPCTLSInsecure) if err != nil { return nil, err } diff --git a/agent/pool/pool.go b/agent/pool/pool.go index fee1b844703e..0593532a30c3 100644 --- a/agent/pool/pool.go +++ b/agent/pool/pool.go @@ -2,6 +2,7 @@ package pool import ( "container/list" + "crypto/tls" "fmt" "io" "net" @@ -257,9 +258,15 @@ func (p *ConnPool) acquire(dc string, addr net.Addr, version int, useTLS bool) ( return nil, fmt.Errorf("rpc error: lead thread didn't get connection") } +// HalfCloser is an interface that exposes a TCP half-close without exposing +// the underlying TLS or raw TCP connection. +type HalfCloser interface { + CloseWrite() error +} + // DialTimeout is used to establish a raw connection to the given server, with // given connection timeout. It also writes RPCTLS as the first byte. -func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration, useTLS bool) (net.Conn, error) { +func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration, useTLS bool) (net.Conn, HalfCloser, error) { p.once.Do(p.init) return DialTimeoutWithRPCType(dc, addr, p.SrcAddr, timeout, useTLS || p.ForceTLS, p.TLSConfigurator.OutgoingRPCWrapper(), RPCTLS) @@ -269,28 +276,31 @@ func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration, // server, with given connection timeout. It also writes RPCTLSInsecure as the // first byte to indicate that the client cannot provide a certificate. This is // so far only used for AutoEncrypt.Sign. -func (p *ConnPool) DialTimeoutInsecure(dc string, addr net.Addr, timeout time.Duration, wrapper tlsutil.DCWrapper) (net.Conn, error) { +func (p *ConnPool) DialTimeoutInsecure(dc string, addr net.Addr, timeout time.Duration, wrapper tlsutil.DCWrapper) (net.Conn, HalfCloser, error) { p.once.Do(p.init) if wrapper == nil { - return nil, fmt.Errorf("wrapper cannot be nil") + return nil, nil, fmt.Errorf("wrapper cannot be nil") } return DialTimeoutWithRPCType(dc, addr, p.SrcAddr, timeout, true, wrapper, RPCTLSInsecure) } -func DialTimeoutWithRPCType(dc string, addr net.Addr, src *net.TCPAddr, timeout time.Duration, useTLS bool, wrapper tlsutil.DCWrapper, rpcType RPCType) (net.Conn, error) { +func DialTimeoutWithRPCType(dc string, addr net.Addr, src *net.TCPAddr, timeout time.Duration, useTLS bool, wrapper tlsutil.DCWrapper, rpcType RPCType) (net.Conn, HalfCloser, error) { // Try to dial the conn d := &net.Dialer{LocalAddr: src, Timeout: timeout} conn, err := d.Dial("tcp", addr.String()) if err != nil { - return nil, err + return nil, nil, err } + var hc HalfCloser + // Cast to TCPConn if tcp, ok := conn.(*net.TCPConn); ok { tcp.SetKeepAlive(true) tcp.SetNoDelay(true) + hc = tcp } // Check if TLS is enabled @@ -298,25 +308,29 @@ func DialTimeoutWithRPCType(dc string, addr net.Addr, src *net.TCPAddr, timeout // Switch the connection into TLS mode if _, err := conn.Write([]byte{byte(rpcType)}); err != nil { conn.Close() - return nil, err + return nil, nil, err } // Wrap the connection in a TLS client tlsConn, err := wrapper(dc, conn) if err != nil { conn.Close() - return nil, err + return nil, nil, err } conn = tlsConn + + if tlsConn, ok := conn.(*tls.Conn); ok { + hc = tlsConn + } } - return conn, nil + return conn, hc, nil } // getNewConn is used to return a new connection func (p *ConnPool) getNewConn(dc string, addr net.Addr, version int, useTLS bool) (*Conn, error) { // Get a new, raw connection. - conn, err := p.DialTimeout(dc, addr, defaultDialTimeout, useTLS) + conn, _, err := p.DialTimeout(dc, addr, defaultDialTimeout, useTLS) if err != nil { return nil, err } @@ -423,7 +437,7 @@ func (p *ConnPool) RPC(dc string, addr net.Addr, version int, method string, use // connection if it is not being reused. func (p *ConnPool) rpcInsecure(dc string, addr net.Addr, method string, args interface{}, reply interface{}) error { var codec rpc.ClientCodec - conn, err := p.DialTimeoutInsecure(dc, addr, 1*time.Second, p.TLSConfigurator.OutgoingRPCWrapper()) + conn, _, err := p.DialTimeoutInsecure(dc, addr, 1*time.Second, p.TLSConfigurator.OutgoingRPCWrapper()) if err != nil { return fmt.Errorf("rpcinsecure error establishing connection: %v", err) }