Skip to content

Commit

Permalink
remove HalfCloser interface
Browse files Browse the repository at this point in the history
Directly calls net.TCPConn.CloseWrite or mtls.Conn.CloseWrite, which was added in https://go-review.googlesource.com/c/go/+/31318/
  • Loading branch information
mikemorris committed Aug 6, 2019
1 parent d848637 commit 558ad5e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 28 deletions.
22 changes: 16 additions & 6 deletions agent/consul/snapshot_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package consul

import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -192,7 +193,7 @@ RESPOND:
func SnapshotRPC(connPool *pool.ConnPool, dc string, addr net.Addr, useTLS bool,
args *structs.SnapshotRequest, in io.Reader, reply *structs.SnapshotResponse) (io.ReadCloser, error) {

conn, hc, err := connPool.DialTimeout(dc, addr, 10*time.Second, useTLS)
conn, err := connPool.DialTimeout(dc, addr, 10*time.Second, useTLS)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -225,12 +226,21 @@ func SnapshotRPC(connPool *pool.ConnPool, dc string, addr net.Addr, useTLS bool,
// the other side that they are done reading the stream, since we don't
// know the size in advance. This saves us from having to buffer just to
// calculate the size.
if hc != nil {
if err := hc.CloseWrite(); err != nil {
return nil, fmt.Errorf("failed to half close snapshot connection: %v", err)
switch connType := conn.(type) {
case *tls.Conn:
if tlsConn, ok := conn.(*tls.Conn); ok {
if err := tlsConn.CloseWrite(); err != nil {
return nil, fmt.Errorf("failed to half close TLS snapshot connection: %v", err)
}
}
case *net.TCPConn:
if tcpConn, ok := conn.(*net.TCPConn); ok {
if err := tcpConn.CloseWrite(); err != nil {
return nil, fmt.Errorf("failed to half close non-TLS snapshot connection: %v", err)
}
}
} else {
return nil, fmt.Errorf("snapshot connection requires half-close support")
default:
return nil, fmt.Errorf("unexpected Conn type: %T", connType)
}

// Pull the header decoded as msgpack. The caller can continue to read
Expand Down
2 changes: 1 addition & 1 deletion agent/consul/status_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func insecureRPCClient(s *Server, c tlsutil.Config) (rpc.ClientCodec, error) {
if wrapper == nil {
return nil, err
}
conn, _, err := pool.DialTimeoutWithRPCType(s.config.Datacenter, addr, nil, time.Second, true, wrapper, pool.RPCTLSInsecure)
conn, err := pool.DialTimeoutWithRPCType(s.config.Datacenter, addr, nil, time.Second, true, wrapper, pool.RPCTLSInsecure)
if err != nil {
return nil, err
}
Expand Down
31 changes: 10 additions & 21 deletions agent/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,9 @@ func (p *ConnPool) acquire(dc string, addr net.Addr, version int, useTLS bool) (
return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
}

// HalfCloser is an interface that exposes a TCP half-close. We need this
// because we want to expose the raw TCP connection underlying a TLS one in a
// way that's hard to screw up and use for anything else. There's a change
// brewing that will allow us to use the TLS connection for this instead -
// https://go-review.googlesource.com/#/c/25159/.
type HalfCloser interface {
CloseWrite() error
}

// DialTimeout is used to establish a raw connection to the given server, with
// given connection timeout. It also writes RPCTLS as the first byte.
func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration, useTLS bool) (net.Conn, HalfCloser, error) {
func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration, useTLS bool) (net.Conn, error) {
p.once.Do(p.init)

return DialTimeoutWithRPCType(dc, addr, p.SrcAddr, timeout, useTLS || p.ForceTLS, p.TLSConfigurator.OutgoingRPCWrapper(), RPCTLS)
Expand All @@ -278,56 +269,54 @@ func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration,
// server, with given connection timeout. It also writes RPCTLSInsecure as the
// first byte to indicate that the client cannot provide a certificate. This is
// so far only used for AutoEncrypt.Sign.
func (p *ConnPool) DialTimeoutInsecure(dc string, addr net.Addr, timeout time.Duration, wrapper tlsutil.DCWrapper) (net.Conn, HalfCloser, error) {
func (p *ConnPool) DialTimeoutInsecure(dc string, addr net.Addr, timeout time.Duration, wrapper tlsutil.DCWrapper) (net.Conn, error) {
p.once.Do(p.init)

if wrapper == nil {
return nil, nil, fmt.Errorf("wrapper cannot be nil")
return nil, fmt.Errorf("wrapper cannot be nil")
}

return DialTimeoutWithRPCType(dc, addr, p.SrcAddr, timeout, true, wrapper, RPCTLSInsecure)
}

func DialTimeoutWithRPCType(dc string, addr net.Addr, src *net.TCPAddr, timeout time.Duration, useTLS bool, wrapper tlsutil.DCWrapper, rpcType RPCType) (net.Conn, HalfCloser, error) {
func DialTimeoutWithRPCType(dc string, addr net.Addr, src *net.TCPAddr, timeout time.Duration, useTLS bool, wrapper tlsutil.DCWrapper, rpcType RPCType) (net.Conn, error) {
// Try to dial the conn
d := &net.Dialer{LocalAddr: src, Timeout: timeout}
conn, err := d.Dial("tcp", addr.String())
if err != nil {
return nil, nil, err
return nil, err
}

// Cast to TCPConn
var hc HalfCloser
if tcp, ok := conn.(*net.TCPConn); ok {
tcp.SetKeepAlive(true)
tcp.SetNoDelay(true)
hc = tcp
}

// Check if TLS is enabled
if (useTLS) && wrapper != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcType)}); err != nil {
conn.Close()
return nil, nil, err
return nil, err
}

// Wrap the connection in a TLS client
tlsConn, err := wrapper(dc, conn)
if err != nil {
conn.Close()
return nil, nil, err
return nil, err
}
conn = tlsConn
}

return conn, hc, nil
return conn, nil
}

// getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(dc string, addr net.Addr, version int, useTLS bool) (*Conn, error) {
// Get a new, raw connection.
conn, _, err := p.DialTimeout(dc, addr, defaultDialTimeout, useTLS)
conn, err := p.DialTimeout(dc, addr, defaultDialTimeout, useTLS)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -434,7 +423,7 @@ func (p *ConnPool) RPC(dc string, addr net.Addr, version int, method string, use
// connection if it is not being reused.
func (p *ConnPool) rpcInsecure(dc string, addr net.Addr, method string, args interface{}, reply interface{}) error {
var codec rpc.ClientCodec
conn, _, err := p.DialTimeoutInsecure(dc, addr, 1*time.Second, p.TLSConfigurator.OutgoingRPCWrapper())
conn, err := p.DialTimeoutInsecure(dc, addr, 1*time.Second, p.TLSConfigurator.OutgoingRPCWrapper())
if err != nil {
return fmt.Errorf("rpcinsecure error establishing connection: %v", err)
}
Expand Down

0 comments on commit 558ad5e

Please sign in to comment.