diff --git a/p2p/transport/websocket/addrs.go b/p2p/transport/websocket/addrs.go index fed649dcbc..6ac23d01d8 100644 --- a/p2p/transport/websocket/addrs.go +++ b/p2p/transport/websocket/addrs.go @@ -132,12 +132,27 @@ func parseMultiaddr(maddr ma.Multiaddr) (*url.URL, error) { type parsedWebsocketMultiaddr struct { isWSS bool - // sni is the SNI value for the TLS handshake + // sni is the SNI value for the TLS handshake, and for setting HTTP Host header sni *ma.Component // the rest of the multiaddr before the /tls/sni/example.com/ws or /ws or /wss restMultiaddr ma.Multiaddr } +func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { + if !pwma.isWSS { + if pwma.sni == nil { + return pwma.restMultiaddr.Encapsulate(wsComponent) + } + return pwma.restMultiaddr.Encapsulate(pwma.sni).Encapsulate(wsComponent) + } + + if pwma.sni == nil { + return pwma.restMultiaddr.Encapsulate(tlsComponent).Encapsulate(wsComponent) + } + + return pwma.restMultiaddr.Encapsulate(tlsComponent).Encapsulate(pwma.sni).Encapsulate(wsComponent) +} + func parseWebsocketMultiaddr(a ma.Multiaddr) (parsedWebsocketMultiaddr, error) { out := parsedWebsocketMultiaddr{} // First check if we have a WSS component. If so we'll canonicalize it into a /tls/ws diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 128fdf5eb5..c1cf051121 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -20,18 +20,6 @@ type listener struct { incoming chan *Conn } -func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { - if !pwma.isWSS { - return pwma.restMultiaddr.Encapsulate(wsComponent) - } - - if pwma.sni == nil { - return pwma.restMultiaddr.Encapsulate(tlsComponent).Encapsulate(wsComponent) - } - - return pwma.restMultiaddr.Encapsulate(tlsComponent).Encapsulate(pwma.sni).Encapsulate(wsComponent) -} - // newListener creates a new listener from a raw net.Listener. // tlsConf may be nil (for unencrypted websockets). func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 2e9fc0b032..6a8c4b594d 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -4,6 +4,7 @@ package websocket import ( "context" "crypto/tls" + "net" "net/http" "time" @@ -125,11 +126,6 @@ func (t *WebsocketTransport) Resolve(ctx context.Context, maddr ma.Multiaddr) ([ return nil, err } - if !parsed.isWSS { - // No /tls/ws component, this isn't a secure websocket multiaddr. We can just return it here - return []ma.Multiaddr{maddr}, nil - } - if parsed.sni == nil { var err error // We don't have an sni component, we'll use dns/dnsaddr @@ -174,14 +170,30 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma return nil, err } isWss := wsurl.Scheme == "wss" - dialer := ws.Dialer{HandshakeTimeout: 30 * time.Second} - if isWss { - sni := "" - sni, err = raddr.ValueForProtocol(ma.P_SNI) - if err != nil { - sni = "" + + sni := "" + sni, err = raddr.ValueForProtocol(ma.P_SNI) + if err != nil { + sni = "" + } + + host := wsurl.Host + + var dialer ws.Dialer + if sni == "" { + dialer = ws.Dialer{HandshakeTimeout: 30 * time.Second} + } else { + dialer = ws.Dialer{ + HandshakeTimeout: 30 * time.Second, + NetDial: func(network, address string) (net.Conn, error) { + tcpAddr, _ := net.ResolveTCPAddr(network, host) + return net.DialTCP("tcp", nil, tcpAddr) + }, } + wsurl.Host = sni + ":" + wsurl.Port() + } + if isWss { if sni != "" { copytlsClientConf := t.tlsClientConf.Clone() copytlsClientConf.ServerName = sni diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 1961e9cec9..df27f2fa26 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -479,6 +479,8 @@ func TestWriteZero(t *testing.T) { func TestResolveMultiaddr(t *testing.T) { // map[unresolved]resolved testCases := map[string]string{ + "/ip4/1.2.3.4/tcp/1234/ws": "/ip4/1.2.3.4/tcp/1234/ws", + "/dns4/example.com/tcp/1234/ws": "/dns4/example.com/tcp/1234/sni/example.com/ws", "/dns4/example.com/tcp/1234/wss": "/dns4/example.com/tcp/1234/tls/sni/example.com/ws", "/dns6/example.com/tcp/1234/wss": "/dns6/example.com/tcp/1234/tls/sni/example.com/ws", "/dnsaddr/example.com/tcp/1234/wss": "/dnsaddr/example.com/tcp/1234/tls/sni/example.com/ws",