Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

Added wss dialing #46

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions addrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,17 @@ func ParseWebsocketNetAddr(a net.Addr) (ma.Multiaddr, error) {
}

func parseMultiaddr(a ma.Multiaddr) (string, error) {
_, host, err := manet.DialArgs(a)
p := a.Protocols()
host, err := a.ValueForProtocol(p[0].Code)
if err != nil {
return "", err
}

return "ws://" + host, nil
if p[0].Code == ma.P_IP6 {
host = "[" + host + "]"
}
port, err := a.ValueForProtocol(ma.P_TCP)
if err != nil {
return "", err
}
return p[2].Name + "://" + host + ":" + port, nil
}
13 changes: 13 additions & 0 deletions addrs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ func TestMultiaddrParsing(t *testing.T) {
if wsaddr != "ws://127.0.0.1:5555" {
t.Fatalf("expected ws://127.0.0.1:5555, got %s", wsaddr)
}

addr, err = ma.NewMultiaddr("/dnsaddr/example.com/tcp/5555/wss")
if err != nil {
t.Fatal(err)
}

wsaddr, err = parseMultiaddr(addr)
if err != nil {
t.Fatal(err)
}
if wsaddr != "wss://example.com:5555" {
t.Fatalf("expected wss://example.com:5555, got %s", wsaddr)
}
}

type httpAddr struct {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/libp2p/go-libp2p-testing v0.0.3
github.com/libp2p/go-libp2p-transport-upgrader v0.1.1
github.com/multiformats/go-multiaddr v0.0.4
github.com/multiformats/go-multiaddr-fmt v0.0.1
github.com/multiformats/go-multiaddr-net v0.0.1
github.com/whyrusleeping/mafmt v1.2.8
)
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ github.com/multiformats/go-multiaddr v0.0.4 h1:WgMSI84/eRLdbptXMkMWDXPjPq7SPLIgG
github.com/multiformats/go-multiaddr v0.0.4/go.mod h1:xKVEak1K9cS1VdmPZW3LSIb6lgmoS58qz/pzqmAxV44=
github.com/multiformats/go-multiaddr-dns v0.0.1 h1:jQt9c6tDSdQLIlBo4tXYx7QUHCPjxsB1zXcag/2S7zc=
github.com/multiformats/go-multiaddr-dns v0.0.1/go.mod h1:9kWcqw/Pj6FwxAwW38n/9403szc57zJPs45fmnznu3Q=
github.com/multiformats/go-multiaddr-dns v0.0.2 h1:/Bbsgsy3R6e3jf2qBahzNHzww6usYaZ0NhNH3sqdFS8=
github.com/multiformats/go-multiaddr-dns v0.0.2/go.mod h1:9kWcqw/Pj6FwxAwW38n/9403szc57zJPs45fmnznu3Q=
github.com/multiformats/go-multiaddr-fmt v0.0.1 h1:5YjeOIzbX8OTKVaN72aOzGIYW7PnrZrnkDyOfAWRSMA=
github.com/multiformats/go-multiaddr-fmt v0.0.1/go.mod h1:aBYjqL4T/7j4Qx+R73XSv/8JsgnRFlf0w2KGLCmXl3Q=
github.com/multiformats/go-multiaddr-net v0.0.1 h1:76O59E3FavvHqNg7jvzWzsPSW5JSi/ek0E4eiDVbg9g=
github.com/multiformats/go-multiaddr-net v0.0.1/go.mod h1:nw6HSxNmCIQH27XPGBuX+d1tnvM7ihcFwHMSstNAVUU=
github.com/multiformats/go-multibase v0.0.1/go.mod h1:bja2MqRZ3ggyXtZSEDKpl0uO/gviWFaSteVbWT51qgs=
Expand Down
107 changes: 78 additions & 29 deletions websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package websocket

import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
Expand All @@ -15,8 +16,8 @@ import (

ws "github.com/gorilla/websocket"
ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
manet "github.com/multiformats/go-multiaddr-net"
mafmt "github.com/whyrusleeping/mafmt"
)

// WsProtocol is the multiaddr protocol definition for this transport.
Expand All @@ -25,9 +26,17 @@ var WsProtocol = ma.Protocol{
Name: "ws",
VCode: ma.CodeToVarint(477),
}
var WssProtocol = ma.Protocol{
Code: 478,
Name: "wss",
VCode: ma.CodeToVarint(478),
}

// WsFmt is multiaddr formatter for WsProtocol
var WsFmt = mafmt.And(mafmt.TCP, mafmt.Base(WsProtocol.Code))
var WsFmt = mafmt.And(mafmt.TCP, mafmt.Or(
mafmt.Base(WsProtocol.Code),
mafmt.Base(WssProtocol.Code),
))

// WsCodec is the multiaddr-net codec definition for the websocket transport
var WsCodec = &manet.NetCodec{
Expand All @@ -36,6 +45,12 @@ var WsCodec = &manet.NetCodec{
ConvertMultiaddr: ConvertWebsocketMultiaddrToNetAddr,
ParseNetAddr: ParseWebsocketNetAddr,
}
var WssCodec = &manet.NetCodec{
NetAddrNetworks: []string{"websocket+tls"},
ProtocolName: "wss",
ConvertMultiaddr: ConvertWebsocketMultiaddrToNetAddr,
ParseNetAddr: ParseWebsocketNetAddr,
}

// Default gorilla upgrader
var upgrader = ws.Upgrader{
Expand All @@ -50,8 +65,13 @@ func init() {
if err != nil {
panic(fmt.Errorf("error registering websocket protocol: %s", err))
}
err = ma.AddProtocol(WssProtocol)
if err != nil {
panic(fmt.Errorf("error registering websocket+tls protocol: %s", err))
}

manet.RegisterNetCodec(WsCodec)
manet.RegisterNetCodec(WssCodec)
}

// WebsocketTransport is the actual go-libp2p transport
Expand All @@ -60,7 +80,9 @@ type WebsocketTransport struct {
}

func New(u *tptu.Upgrader) *WebsocketTransport {
return &WebsocketTransport{u}
return &WebsocketTransport{
Upgrader: u,
}
}

var _ transport.Transport = (*WebsocketTransport)(nil)
Expand All @@ -70,20 +92,41 @@ func (t *WebsocketTransport) CanDial(a ma.Multiaddr) bool {
}

func (t *WebsocketTransport) Protocols() []int {
return []int{WsProtocol.Code}
return []int{WsProtocol.Code, WssProtocol.Code}
}

func (t *WebsocketTransport) Proxy() bool {
return false
}

var CertCfg *tls.Config = &tls.Config{RootCAs: nil, InsecureSkipVerify: true}

func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
wsurl, err := parseMultiaddr(raddr)
if err != nil {
return nil, err
}

wscon, _, err := ws.DefaultDialer.Dial(wsurl, nil)
var wscon *ws.Conn
if raddr.Protocols()[2].Code == WsProtocol.Code {
wscon, _, err = ws.DefaultDialer.Dial(wsurl, nil)
} else {
u, err := url.Parse(wsurl)
if err != nil {
return nil, err
}

wsHeaders := http.Header{
"Origin": {u.Host},
// your milage may differ
"Sec-WebSocket-Extensions": {"permessage-deflate; client_max_window_bits, x-webkit-deflate-frame"},
}
tlsconn, err := tls.Dial("tcp", u.Host, CertCfg)
if err != nil {
return nil, err
}
wscon, _, err = ws.NewClient(tlsconn, u, wsHeaders, 1024, 1024)
}
if err != nil {
return nil, err
}
Expand All @@ -104,32 +147,38 @@ func (t *WebsocketTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p pee
return t.Upgrader.UpgradeOutbound(ctx, t, macon, p)
}

func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) {
lnet, lnaddr, err := manet.DialArgs(a)
if err != nil {
return nil, err
}

nl, err := net.Listen(lnet, lnaddr)
if err != nil {
return nil, err
}
var listenWss = fmt.Errorf("Unable to listen wss, you should use a reverse proxy like nginx or apache.")

u, err := url.Parse("http://" + nl.Addr().String())
if err != nil {
nl.Close()
return nil, err
}

malist, err := t.wrapListener(nl, u)
if err != nil {
nl.Close()
return nil, err
func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) {
if a.Protocols()[2].Code == WsProtocol.Code {
lnet, lnaddr, err := manet.DialArgs(a)
if err != nil {
return nil, err
}

nl, err := net.Listen(lnet, lnaddr)
if err != nil {
return nil, err
}

u, err := url.Parse("http://" + nl.Addr().String())
if err != nil {
nl.Close()
return nil, err
}

malist, err := t.wrapListener(nl, u)
if err != nil {
nl.Close()
return nil, err
}

go malist.serve()

return malist, nil
} else {
return nil, listenWss
}

go malist.serve()

return malist, nil
}

func (t *WebsocketTransport) Listen(a ma.Multiaddr) (transport.Listener, error) {
Expand Down
83 changes: 83 additions & 0 deletions websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ package websocket
import (
"bytes"
"context"
//"crypto/x509"
"io"
"io/ioutil"
"testing"
"testing/iotest"
/*"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
*/

"github.com/libp2p/go-libp2p-core/sec/insecure"

Expand All @@ -21,6 +26,10 @@ func TestCanDial(t *testing.T) {
if err != nil {
t.Fatal(err)
}
addrWss, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/5555/wss")
if err != nil {
t.Fatal(err)
}

addrTCP, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/5555")
if err != nil {
Expand All @@ -29,11 +38,15 @@ func TestCanDial(t *testing.T) {

d := &WebsocketTransport{}
matchTrue := d.CanDial(addrWs)
matchTrueWss := d.CanDial(addrWss)
matchFalse := d.CanDial(addrTCP)

if !matchTrue {
t.Fatal("expected to match websocket maddr, but did not")
}
if !matchTrueWss {
t.Fatal("expected to match websocket+tls maddr, but did not")
}

if matchFalse {
t.Fatal("expected to not match tcp maddr, but did")
Expand All @@ -54,6 +67,20 @@ func TestWebsocketTransport(t *testing.T) {
ttransport.SubtestTransport(t, ta, tb, zero, "peerA")
}

func TestWebsocketTransport6(t *testing.T) {
ta := New(&tptu.Upgrader{
Secure: insecure.New("peerA"),
Muxer: new(mplex.Transport),
})
tb := New(&tptu.Upgrader{
Secure: insecure.New("peerB"),
Muxer: new(mplex.Transport),
})

zero := "/ip6/::1/tcp/0/ws"
ttransport.SubtestTransport(t, ta, tb, zero, "peerA")
}

func TestWebsocketListen(t *testing.T) {
zero, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0/ws")
if err != nil {
Expand Down Expand Up @@ -194,3 +221,59 @@ func TestWriteZero(t *testing.T) {
t.Errorf("expected EOF, got err: %s", err)
}
}
/*
func TestWebsocketSecureComplete(t *testing.T) {
var cert *[]byte

cert, _ = &CreateCertificate()

listen, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0/ws")
if err != nil {
t.Fatal(err)
}

tpt := &WebsocketTransport{}
l, err := tpt.maListen(zero)
if err != nil {
t.Fatal(err)
}
defer l.Close()
fmt.Printf("%+v\n", l)

msg := []byte("HELLO WORLD")

go func() {
c, err := tpt.maDial(context.Background(), l.Multiaddr())
if err != nil {
t.Error(err)
return
}

_, err = c.Write(msg)
if err != nil {
t.Error(err)
}
err = c.Close()
if err != nil {
t.Error(err)
}
}()

c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
defer c.Close()

obr := iotest.OneByteReader(c)

out, err := ioutil.ReadAll(obr)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal(out, msg) {
t.Fatal("got wrong message", out, msg)
}
}
*/