diff --git a/connection.go b/connection.go index 0bea4cf..556def5 100644 --- a/connection.go +++ b/connection.go @@ -54,13 +54,13 @@ func (c *connection) doTunnelClose(err error) { c.buffer.Close(c.err) } -func (c *connection) OnData(m *message) error { +func (c *connection) OnData(r io.Reader) error { if PrintTunnelData { defer func() { logrus.Debugf("ONDATA [%d] %s", c.connID, c.buffer.Status()) }() } - return c.buffer.Offer(m.body) + return c.buffer.Offer(r) } func (c *connection) Close() error { diff --git a/message.go b/message.go index ce388ef..6ef894f 100644 --- a/message.go +++ b/message.go @@ -16,12 +16,21 @@ import ( ) const ( + // Data is the main message type, used to transport application data Data messageType = iota + 1 + // Connect is a control message type, used to request opening a new connection Connect + // Error is a message type used to send an error during the communication. + // Any receiver of an Error message can assume the connection can be closed. + // io.EOF is used for graceful termination of connections. Error + // AddClient is a message type used to open a new client to the peering session AddClient + // RemoveClient is a message type used to remove an existing client from a peering session RemoveClient + // Pause is a message type used to temporarily stop a given connection Pause + // Resume is a message type used to resume a paused connection Resume ) @@ -211,7 +220,7 @@ func (m *message) Read(p []byte) (int, error) { return m.body.Read(p) } -func (m *message) WriteTo(deadline time.Time, wsConn *wsConn) (int, error) { +func (m *message) WriteTo(deadline time.Time, wsConn wsConn) (int, error) { err := wsConn.WriteMessage(websocket.BinaryMessage, deadline, m.Bytes()) return len(m.bytes), err } diff --git a/session.go b/session.go index 0c9f84e..e187f49 100644 --- a/session.go +++ b/session.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net" "os" "strconv" @@ -18,12 +17,12 @@ import ( ) type Session struct { - sync.Mutex + sync.RWMutex nextConnID int64 clientKey string sessionKey int64 - conn *wsConn + conn wsConn conns map[int64]*connection remoteClientKeys map[string]map[int]bool auth ConnectAuthorizer @@ -57,17 +56,81 @@ func NewClientSessionWithDialer(auth ConnectAuthorizer, conn *websocket.Conn, di } } -func newSession(sessionKey int64, clientKey string, conn *websocket.Conn) *Session { +func newSession(sessionKey int64, clientKey string, conn wsConn) *Session { return &Session{ nextConnID: 1, clientKey: clientKey, sessionKey: sessionKey, - conn: newWSConn(conn), + conn: conn, conns: map[int64]*connection{}, remoteClientKeys: map[string]map[int]bool{}, } } +// addConnection safely registers a new connection in the connections map +func (s *Session) addConnection(connID int64, conn *connection) { + s.Lock() + defer s.Unlock() + + s.conns[connID] = conn + if PrintTunnelData { + logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) + } +} + +// removeConnection safely removes a connection by ID, returning the connection object +func (s *Session) removeConnection(connID int64) *connection { + s.Lock() + defer s.Unlock() + + conn := s.conns[connID] + delete(s.conns, connID) + if PrintTunnelData { + defer logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) + } + return conn +} + +// getConnection retrieves a connection by ID +func (s *Session) getConnection(connID int64) *connection { + s.RLock() + defer s.RUnlock() + + return s.conns[connID] +} + +// addSessionKey registers a new session key for a given client key +func (s *Session) addSessionKey(clientKey string, sessionKey int) { + s.Lock() + defer s.Unlock() + + keys := s.remoteClientKeys[clientKey] + if keys == nil { + keys = map[int]bool{} + s.remoteClientKeys[clientKey] = keys + } + keys[sessionKey] = true +} + +// removeSessionKey removes a specific session key for a client key +func (s *Session) removeSessionKey(clientKey string, sessionKey int) { + s.Lock() + defer s.Unlock() + + keys := s.remoteClientKeys[clientKey] + delete(keys, sessionKey) + if len(keys) == 0 { + delete(s.remoteClientKeys, clientKey) + } +} + +// getSessionKeys retrieves all session keys for a given client key +func (s *Session) getSessionKeys(clientKey string) map[int]bool { + s.RLock() + defer s.RUnlock() + return s.remoteClientKeys[clientKey] +} + func (s *Session) startPings(rootCtx context.Context) { ctx, cancel := context.WithCancel(rootCtx) s.pingCancel = cancel @@ -84,12 +147,10 @@ func (s *Session) startPings(rootCtx context.Context) { case <-ctx.Done(): return case <-t.C: - s.conn.Lock() - if err := s.conn.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(PingWaitDuration)); err != nil { + if err := s.conn.WriteControl(websocket.PingMessage, time.Now().Add(PingWaitDuration), []byte("")); err != nil { logrus.WithError(err).Error("Error writing ping") } logrus.Debug("Wrote ping") - s.conn.Unlock() } } }() @@ -125,61 +186,6 @@ func (s *Session) Serve(ctx context.Context) (int, error) { } } -func (s *Session) serveMessage(ctx context.Context, reader io.Reader) error { - message, err := newServerMessage(reader) - if err != nil { - return err - } - - if PrintTunnelData { - logrus.Debug("REQUEST ", message) - } - - if message.messageType == Connect { - if s.auth == nil || !s.auth(message.proto, message.address) { - return errors.New("connect not allowed") - } - s.clientConnect(ctx, message) - return nil - } - - s.Lock() - if message.messageType == AddClient && s.remoteClientKeys != nil { - err := s.addRemoteClient(message.address) - s.Unlock() - return err - } else if message.messageType == RemoveClient { - err := s.removeRemoteClient(message.address) - s.Unlock() - return err - } - conn := s.conns[message.connID] - s.Unlock() - - if conn == nil { - if message.messageType == Data { - err := fmt.Errorf("connection not found %s/%d/%d", s.clientKey, s.sessionKey, message.connID) - newErrorMessage(message.connID, err).WriteTo(defaultDeadline(), s.conn) - } - return nil - } - - switch message.messageType { - case Data: - if err := conn.OnData(message); err != nil { - s.closeConnection(message.connID, err) - } - case Pause: - conn.OnPause() - case Resume: - conn.OnResume() - case Error: - s.closeConnection(message.connID, message.Err()) - } - - return nil -} - func defaultDeadline() time.Time { return time.Now().Add(time.Minute) } @@ -193,72 +199,6 @@ func parseAddress(address string) (string, int, error) { return parts[0], v, err } -func (s *Session) addRemoteClient(address string) error { - clientKey, sessionKey, err := parseAddress(address) - if err != nil { - return fmt.Errorf("invalid remote Session %s: %v", address, err) - } - - keys := s.remoteClientKeys[clientKey] - if keys == nil { - keys = map[int]bool{} - s.remoteClientKeys[clientKey] = keys - } - keys[sessionKey] = true - - if PrintTunnelData { - logrus.Debugf("ADD REMOTE CLIENT %s, SESSION %d", address, s.sessionKey) - } - - return nil -} - -func (s *Session) removeRemoteClient(address string) error { - clientKey, sessionKey, err := parseAddress(address) - if err != nil { - return fmt.Errorf("invalid remote Session %s: %v", address, err) - } - - keys := s.remoteClientKeys[clientKey] - delete(keys, int(sessionKey)) - if len(keys) == 0 { - delete(s.remoteClientKeys, clientKey) - } - - if PrintTunnelData { - logrus.Debugf("REMOVE REMOTE CLIENT %s, SESSION %d", address, s.sessionKey) - } - - return nil -} - -func (s *Session) closeConnection(connID int64, err error) { - s.Lock() - conn := s.conns[connID] - delete(s.conns, connID) - if PrintTunnelData { - logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) - } - s.Unlock() - - if conn != nil { - conn.tunnelClose(err) - } -} - -func (s *Session) clientConnect(ctx context.Context, message *message) { - conn := newConnection(message.connID, s, message.proto, message.address) - - s.Lock() - s.conns[message.connID] = conn - if PrintTunnelData { - logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) - } - s.Unlock() - - go clientDial(ctx, s.dialer, conn, message) -} - type connResult struct { conn net.Conn err error @@ -299,12 +239,7 @@ func (s *Session) serverConnect(deadline time.Time, proto, address string) (net. connID := atomic.AddInt64(&s.nextConnID, 1) conn := newConnection(connID, s, proto, address) - s.Lock() - s.conns[connID] = conn - if PrintTunnelData { - logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) - } - s.Unlock() + s.addConnection(connID, conn) _, err := s.writeMessage(deadline, newConnect(connID, proto, address)) if err != nil { @@ -339,7 +274,7 @@ func (s *Session) sessionAdded(clientKey string, sessionKey int64) { client := fmt.Sprintf("%s/%d", clientKey, sessionKey) _, err := s.writeMessage(time.Time{}, newAddClient(client)) if err != nil { - s.conn.conn.Close() + s.conn.Close() } } @@ -347,6 +282,6 @@ func (s *Session) sessionRemoved(clientKey string, sessionKey int64) { client := fmt.Sprintf("%s/%d", clientKey, sessionKey) _, err := s.writeMessage(time.Time{}, newRemoveClient(client)) if err != nil { - s.conn.conn.Close() + s.conn.Close() } } diff --git a/session_manager.go b/session_manager.go index c23d5d6..6fa0a05 100644 --- a/session_manager.go +++ b/session_manager.go @@ -77,9 +77,7 @@ func (sm *sessionManager) getDialer(clientKey string) (Dialer, error) { for _, sessions := range sm.peers { for _, session := range sessions { - session.Lock() - keys := session.remoteClientKeys[clientKey] - session.Unlock() + keys := session.getSessionKeys(clientKey) if len(keys) > 0 { return toDialer(session, clientKey), nil } @@ -91,7 +89,7 @@ func (sm *sessionManager) getDialer(clientKey string) (Dialer, error) { func (sm *sessionManager) add(clientKey string, conn *websocket.Conn, peer bool) *Session { sessionKey := rand.Int63() - session := newSession(sessionKey, clientKey, conn) + session := newSession(sessionKey, clientKey, newWSConn(conn)) sm.Lock() defer sm.Unlock() diff --git a/session_serve.go b/session_serve.go new file mode 100644 index 0000000..45661cd --- /dev/null +++ b/session_serve.go @@ -0,0 +1,124 @@ +package remotedialer + +import ( + "context" + "errors" + "fmt" + "io" + + "github.com/sirupsen/logrus" +) + +// serveMessage accepts an incoming message from the underlying websocket connection and processes the request based on its messageType +func (s *Session) serveMessage(ctx context.Context, reader io.Reader) error { + message, err := newServerMessage(reader) + if err != nil { + return err + } + + if PrintTunnelData { + logrus.Debug("REQUEST ", message) + } + + switch message.messageType { + case Connect: + return s.clientConnect(ctx, message) + case AddClient: + return s.addRemoteClient(message.address) + case RemoveClient: + return s.removeRemoteClient(message.address) + case Data: + s.connectionData(message.connID, message.body) + case Pause: + s.pauseConnection(message.connID) + case Resume: + s.resumeConnection(message.connID) + case Error: + s.closeConnection(message.connID, message.Err()) + } + return nil +} + +// clientConnect accepts a new connection request, dialing back to establish the connection +func (s *Session) clientConnect(ctx context.Context, message *message) error { + if s.auth == nil || !s.auth(message.proto, message.address) { + return errors.New("connect not allowed") + } + + conn := newConnection(message.connID, s, message.proto, message.address) + s.addConnection(message.connID, conn) + + go clientDial(ctx, s.dialer, conn, message) + + return nil +} + +// / addRemoteClient registers a new remote client, making it accessible for requests +func (s *Session) addRemoteClient(address string) error { + if s.remoteClientKeys == nil { + return nil + } + + clientKey, sessionKey, err := parseAddress(address) + if err != nil { + return fmt.Errorf("invalid remote Session %s: %v", address, err) + } + s.addSessionKey(clientKey, sessionKey) + + if PrintTunnelData { + logrus.Debugf("ADD REMOTE CLIENT %s, SESSION %d", address, s.sessionKey) + } + + return nil +} + +// / addRemoteClient removes a given client from a session +func (s *Session) removeRemoteClient(address string) error { + clientKey, sessionKey, err := parseAddress(address) + if err != nil { + return fmt.Errorf("invalid remote Session %s: %v", address, err) + } + s.removeSessionKey(clientKey, sessionKey) + + if PrintTunnelData { + logrus.Debugf("REMOVE REMOTE CLIENT %s, SESSION %d", address, s.sessionKey) + } + + return nil +} + +// closeConnection removes a connection for a given ID from the session, sending an error message to communicate the closing to the other end. +// If an error is not provided, io.EOF will be used instead. +func (s *Session) closeConnection(connID int64, err error) { + if conn := s.removeConnection(connID); conn != nil { + conn.tunnelClose(err) + } +} + +// connectionData process incoming data from connection by reading the body into an internal readBuffer +func (s *Session) connectionData(connID int64, body io.Reader) { + conn := s.getConnection(connID) + if conn == nil { + errMsg := newErrorMessage(connID, fmt.Errorf("connection not found %s/%d/%d", s.clientKey, s.sessionKey, connID)) + _, _ = errMsg.WriteTo(defaultDeadline(), s.conn) + return + } + + if err := conn.OnData(body); err != nil { + s.closeConnection(connID, err) + } +} + +// pauseConnection activates backPressure for a given connection ID +func (s *Session) pauseConnection(connID int64) { + if conn := s.getConnection(connID); conn != nil { + conn.OnPause() + } +} + +// resumeConnection deactivates backPressure for a given connection ID +func (s *Session) resumeConnection(connID int64) { + if conn := s.getConnection(connID); conn != nil { + conn.OnResume() + } +} diff --git a/session_serve_test.go b/session_serve_test.go new file mode 100644 index 0000000..019fece --- /dev/null +++ b/session_serve_test.go @@ -0,0 +1,136 @@ +package remotedialer + +import ( + "bytes" + "context" + "errors" + "fmt" + "math/rand" + "net" + "reflect" + "strings" + "testing" + "time" +) + +func TestSession_clientConnect(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + msgProto, msgAddr := "testproto", "testaddr" + s := setupDummySession(t, 0) + s.auth = func(proto, address string) bool { return proto == msgProto && address == msgAddr } + + dialerC := make(chan struct{}) + s.dialer = func(ctx context.Context, network, address string) (net.Conn, error) { + close(dialerC) + clientConn, _ := net.Pipe() + return clientConn, nil + } + + connID := getDummyConnectionID() + if err := s.clientConnect(ctx, newConnect(connID, msgProto, msgAddr)); err != nil { + t.Fatal(err) + } + + select { + case <-dialerC: + case <-time.After(1 * time.Second): + t.Errorf("timed out waiting for dialer") + } + + if conn := s.getConnection(connID); conn == nil { + t.Errorf("Connection not found in session for ID %d", connID) + } +} + +func TestSession_addRemoveRemoteClient(t *testing.T) { + s := setupDummySession(t, 0) + clientKey, sessionKey := "test", rand.Int() + + msgAddress := fmt.Sprintf("%s/%d", clientKey, sessionKey) + if err := s.addRemoteClient(msgAddress); err != nil { + t.Fatal(err) + } + + if got, want := s.getSessionKeys(clientKey), map[int]bool{sessionKey: true}; !reflect.DeepEqual(got, want) { + t.Errorf("remote client session was not added correctly, got %v, want %v", got, want) + } + + if err := s.removeRemoteClient(msgAddress); err != nil { + t.Fatal(err) + } + + if got, want := s.getSessionKeys(clientKey), 0; len(got) != want { + t.Errorf("remote client session was not removed correctly, got %v, want len(%d)", got, want) + } +} + +func TestSession_connectionData(t *testing.T) { + s := setupDummySession(t, 0) + connID := getDummyConnectionID() + conn := newConnection(connID, s, "test", "test") + s.addConnection(connID, conn) + + data := "testing!" + s.connectionData(connID, strings.NewReader(data)) + + if got, want := conn.buffer.offerCount, int64(len(data)); got != want { + t.Errorf("incorrect data length, got %d, want %d", got, want) + } + + buf := make([]byte, conn.buffer.offerCount) + if _, err := conn.buffer.Read(buf); err != nil { + t.Fatal(err) + } + if got, want := string(buf), data; got != want { + t.Errorf("incorrect data, got %q, want %q", got, want) + } +} + +func TestSession_pauseResumeConnection(t *testing.T) { + s := setupDummySession(t, 0) + connID := getDummyConnectionID() + conn := newConnection(connID, s, "test", "test") + s.addConnection(connID, conn) + + s.pauseConnection(connID) + if !conn.backPressure.paused { + t.Errorf("connection was not paused correctly") + } + + s.resumeConnection(connID) + if conn.backPressure.paused { + t.Errorf("connection was not resumed correctly") + } +} + +func TestSession_closeConnection(t *testing.T) { + s := setupDummySession(t, 0) + var msg *message + s.conn = &fakeWSConn{ + writeMessageCallback: func(msgType int, _ time.Time, data []byte) (err error) { + msg, err = newServerMessage(bytes.NewReader(data)) + return + }, + } + connID := getDummyConnectionID() + conn := newConnection(connID, s, "test", "test") + s.addConnection(connID, conn) + + expectedErr := errors.New("connection closed") + s.closeConnection(connID, expectedErr) + + if s.getConnection(connID) != nil { + t.Errorf("connection was not closed correctly") + } + if conn.err == nil { + t.Fatal("message not sent on closed connection") + } else if msg.messageType != Error { + t.Errorf("incorrect message type sent") + } else if got, want := msg.Err().Error(), expectedErr.Error(); got != want { + t.Errorf("wrong error, got %v, want %v", got, want) + } else if got, want := conn.err, expectedErr; !errors.Is(got, want) { + t.Errorf("wrong error, got %v, want %v", got, want) + } +} diff --git a/session_test.go b/session_test.go new file mode 100644 index 0000000..6c5ae84 --- /dev/null +++ b/session_test.go @@ -0,0 +1,77 @@ +package remotedialer + +import ( + "math/rand" + "reflect" + "sync" + "sync/atomic" + "testing" +) + +var dummyConnectionsNextID int64 = 1 + +func getDummyConnectionID() int64 { + return atomic.AddInt64(&dummyConnectionsNextID, 1) +} + +func setupDummySession(t *testing.T, nConnections int) *Session { + t.Helper() + + s := newSession(rand.Int63(), "", nil) + + var wg sync.WaitGroup + ready := make(chan struct{}) + for i := 0; i < nConnections; i++ { + connID := getDummyConnectionID() + wg.Add(1) + go func() { + defer wg.Done() + <-ready + s.addConnection(connID, &connection{}) + }() + } + close(ready) + wg.Wait() + + if got, want := len(s.conns), nConnections; got != want { + t.Fatalf("incorrect number of connections, got: %d, want %d", got, want) + } + + return s +} + +func TestSession_connections(t *testing.T) { + const n = 10 + s := setupDummySession(t, n) + + connID, conn := getDummyConnectionID(), &connection{} + s.addConnection(connID, conn) + if got, want := len(s.conns), n+1; got != want { + t.Errorf("incorrect number of connections, got: %d, want %d", got, want) + } + if got, want := s.getConnection(connID), conn; got != want { + t.Errorf("incorrect result from getConnection, got: %v, want %v", got, want) + } + if got, want := s.removeConnection(connID), conn; got != want { + t.Errorf("incorrect result from removeConnection, got: %v, want %v", got, want) + } +} + +func TestSession_sessionKeys(t *testing.T) { + s := setupDummySession(t, 0) + + clientKey, sessionKey := "testkey", rand.Int() + s.addSessionKey(clientKey, sessionKey) + if got, want := len(s.remoteClientKeys), 1; got != want { + t.Errorf("incorrect number of remote client keys, got: %d, want %d", got, want) + } + + if got, want := s.getSessionKeys(clientKey), map[int]bool{sessionKey: true}; !reflect.DeepEqual(got, want) { + t.Errorf("incorrect result from getSessionKeys, got: %v, want %v", got, want) + } + + s.removeSessionKey(clientKey, sessionKey) + if got, want := len(s.remoteClientKeys), 0; got != want { + t.Errorf("incorrect number of remote client keys after removal, got: %d, want %d", got, want) + } +} diff --git a/wsconn.go b/wsconn.go index c0a8eae..9ca7664 100644 --- a/wsconn.go +++ b/wsconn.go @@ -10,20 +10,40 @@ import ( "github.com/gorilla/websocket" ) -type wsConn struct { +type wsWrapper struct { + // Mutex is used to protect from concurrent usage of the websocket connection sync.Mutex + // conn is the underlying websocket connection conn *websocket.Conn } -func newWSConn(conn *websocket.Conn) *wsConn { - w := &wsConn{ +func newWSConn(conn *websocket.Conn) *wsWrapper { + w := &wsWrapper{ conn: conn, } w.setupDeadline() return w } -func (w *wsConn) WriteMessage(messageType int, deadline time.Time, data []byte) error { +type wsConn interface { + // Close will indicate the underlying websocket connection + Close() error + // NextReader gets a new reader from the underlying websocket connection + NextReader() (int, io.Reader, error) + // WriteControl writes a new websocket control frame, see https://datatracker.ietf.org/doc/html/rfc6455#section-5.5 + WriteControl(messageType int, deadline time.Time, data []byte) error + // WriteMessage writes a new websocket data frame, see https://datatracker.ietf.org/doc/html/rfc6455#section-6 + WriteMessage(messageType int, deadline time.Time, data []byte) error +} + +func (w *wsWrapper) WriteControl(messageType int, deadline time.Time, data []byte) error { + w.Lock() + defer w.Unlock() + + return w.conn.WriteControl(messageType, data, deadline) +} + +func (w *wsWrapper) WriteMessage(messageType int, deadline time.Time, data []byte) error { if deadline.IsZero() { w.Lock() defer w.Unlock() @@ -48,11 +68,15 @@ func (w *wsConn) WriteMessage(messageType int, deadline time.Time, data []byte) } } -func (w *wsConn) NextReader() (int, io.Reader, error) { +func (w *wsWrapper) NextReader() (int, io.Reader, error) { return w.conn.NextReader() } -func (w *wsConn) setupDeadline() { +func (w *wsWrapper) Close() error { + return w.conn.Close() +} + +func (w *wsWrapper) setupDeadline() { w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration)) w.conn.SetPingHandler(func(string) error { w.Lock() diff --git a/wsconn_test.go b/wsconn_test.go new file mode 100644 index 0000000..3624b17 --- /dev/null +++ b/wsconn_test.go @@ -0,0 +1,30 @@ +package remotedialer + +import ( + "errors" + "io" + "time" +) + +type fakeWSConn struct { + writeMessageCallback func(int, time.Time, []byte) error +} + +func (f fakeWSConn) Close() error { + return nil +} + +func (f fakeWSConn) NextReader() (int, io.Reader, error) { + return 0, nil, errors.New("not implemented") +} + +func (f fakeWSConn) WriteMessage(messageType int, deadline time.Time, data []byte) error { + if cb := f.writeMessageCallback; cb != nil { + return cb(messageType, deadline, data) + } + return errors.New("callback not provided") +} + +func (f fakeWSConn) WriteControl(int, time.Time, []byte) error { + return errors.New("not implemented") +}