Skip to content

Commit

Permalink
Merge pull request #78 from aruiz14/refactor-session-serve
Browse files Browse the repository at this point in the history
Refactor Session serveMessage and mutex usage
  • Loading branch information
nflynt authored May 28, 2024
2 parents 9da880f + 03cf07c commit c3a9b3c
Show file tree
Hide file tree
Showing 9 changed files with 483 additions and 150 deletions.
4 changes: 2 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ func (c *connection) doTunnelClose(err error) {
c.buffer.Close(c.err)
}

func (c *connection) OnData(m *message) error {
func (c *connection) OnData(r io.Reader) error {
if PrintTunnelData {
defer func() {
logrus.Debugf("ONDATA [%d] %s", c.connID, c.buffer.Status())
}()
}
return c.buffer.Offer(m.body)
return c.buffer.Offer(r)
}

func (c *connection) Close() error {
Expand Down
11 changes: 10 additions & 1 deletion message.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,21 @@ import (
)

const (
// Data is the main message type, used to transport application data
Data messageType = iota + 1
// Connect is a control message type, used to request opening a new connection
Connect
// Error is a message type used to send an error during the communication.
// Any receiver of an Error message can assume the connection can be closed.
// io.EOF is used for graceful termination of connections.
Error
// AddClient is a message type used to open a new client to the peering session
AddClient
// RemoveClient is a message type used to remove an existing client from a peering session
RemoveClient
// Pause is a message type used to temporarily stop a given connection
Pause
// Resume is a message type used to resume a paused connection
Resume
)

Expand Down Expand Up @@ -211,7 +220,7 @@ func (m *message) Read(p []byte) (int, error) {
return m.body.Read(p)
}

func (m *message) WriteTo(deadline time.Time, wsConn *wsConn) (int, error) {
func (m *message) WriteTo(deadline time.Time, wsConn wsConn) (int, error) {
err := wsConn.WriteMessage(websocket.BinaryMessage, deadline, m.Bytes())
return len(m.bytes), err
}
Expand Down
209 changes: 72 additions & 137 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
Expand All @@ -18,12 +17,12 @@ import (
)

type Session struct {
sync.Mutex
sync.RWMutex

nextConnID int64
clientKey string
sessionKey int64
conn *wsConn
conn wsConn
conns map[int64]*connection
remoteClientKeys map[string]map[int]bool
auth ConnectAuthorizer
Expand Down Expand Up @@ -57,17 +56,81 @@ func NewClientSessionWithDialer(auth ConnectAuthorizer, conn *websocket.Conn, di
}
}

func newSession(sessionKey int64, clientKey string, conn *websocket.Conn) *Session {
func newSession(sessionKey int64, clientKey string, conn wsConn) *Session {
return &Session{
nextConnID: 1,
clientKey: clientKey,
sessionKey: sessionKey,
conn: newWSConn(conn),
conn: conn,
conns: map[int64]*connection{},
remoteClientKeys: map[string]map[int]bool{},
}
}

// addConnection safely registers a new connection in the connections map
func (s *Session) addConnection(connID int64, conn *connection) {
s.Lock()
defer s.Unlock()

s.conns[connID] = conn
if PrintTunnelData {
logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
}
}

// removeConnection safely removes a connection by ID, returning the connection object
func (s *Session) removeConnection(connID int64) *connection {
s.Lock()
defer s.Unlock()

conn := s.conns[connID]
delete(s.conns, connID)
if PrintTunnelData {
defer logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
}
return conn
}

// getConnection retrieves a connection by ID
func (s *Session) getConnection(connID int64) *connection {
s.RLock()
defer s.RUnlock()

return s.conns[connID]
}

// addSessionKey registers a new session key for a given client key
func (s *Session) addSessionKey(clientKey string, sessionKey int) {
s.Lock()
defer s.Unlock()

keys := s.remoteClientKeys[clientKey]
if keys == nil {
keys = map[int]bool{}
s.remoteClientKeys[clientKey] = keys
}
keys[sessionKey] = true
}

// removeSessionKey removes a specific session key for a client key
func (s *Session) removeSessionKey(clientKey string, sessionKey int) {
s.Lock()
defer s.Unlock()

keys := s.remoteClientKeys[clientKey]
delete(keys, sessionKey)
if len(keys) == 0 {
delete(s.remoteClientKeys, clientKey)
}
}

// getSessionKeys retrieves all session keys for a given client key
func (s *Session) getSessionKeys(clientKey string) map[int]bool {
s.RLock()
defer s.RUnlock()
return s.remoteClientKeys[clientKey]
}

func (s *Session) startPings(rootCtx context.Context) {
ctx, cancel := context.WithCancel(rootCtx)
s.pingCancel = cancel
Expand All @@ -84,12 +147,10 @@ func (s *Session) startPings(rootCtx context.Context) {
case <-ctx.Done():
return
case <-t.C:
s.conn.Lock()
if err := s.conn.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(PingWaitDuration)); err != nil {
if err := s.conn.WriteControl(websocket.PingMessage, time.Now().Add(PingWaitDuration), []byte("")); err != nil {
logrus.WithError(err).Error("Error writing ping")
}
logrus.Debug("Wrote ping")
s.conn.Unlock()
}
}
}()
Expand Down Expand Up @@ -125,61 +186,6 @@ func (s *Session) Serve(ctx context.Context) (int, error) {
}
}

func (s *Session) serveMessage(ctx context.Context, reader io.Reader) error {
message, err := newServerMessage(reader)
if err != nil {
return err
}

if PrintTunnelData {
logrus.Debug("REQUEST ", message)
}

if message.messageType == Connect {
if s.auth == nil || !s.auth(message.proto, message.address) {
return errors.New("connect not allowed")
}
s.clientConnect(ctx, message)
return nil
}

s.Lock()
if message.messageType == AddClient && s.remoteClientKeys != nil {
err := s.addRemoteClient(message.address)
s.Unlock()
return err
} else if message.messageType == RemoveClient {
err := s.removeRemoteClient(message.address)
s.Unlock()
return err
}
conn := s.conns[message.connID]
s.Unlock()

if conn == nil {
if message.messageType == Data {
err := fmt.Errorf("connection not found %s/%d/%d", s.clientKey, s.sessionKey, message.connID)
newErrorMessage(message.connID, err).WriteTo(defaultDeadline(), s.conn)
}
return nil
}

switch message.messageType {
case Data:
if err := conn.OnData(message); err != nil {
s.closeConnection(message.connID, err)
}
case Pause:
conn.OnPause()
case Resume:
conn.OnResume()
case Error:
s.closeConnection(message.connID, message.Err())
}

return nil
}

func defaultDeadline() time.Time {
return time.Now().Add(time.Minute)
}
Expand All @@ -193,72 +199,6 @@ func parseAddress(address string) (string, int, error) {
return parts[0], v, err
}

func (s *Session) addRemoteClient(address string) error {
clientKey, sessionKey, err := parseAddress(address)
if err != nil {
return fmt.Errorf("invalid remote Session %s: %v", address, err)
}

keys := s.remoteClientKeys[clientKey]
if keys == nil {
keys = map[int]bool{}
s.remoteClientKeys[clientKey] = keys
}
keys[sessionKey] = true

if PrintTunnelData {
logrus.Debugf("ADD REMOTE CLIENT %s, SESSION %d", address, s.sessionKey)
}

return nil
}

func (s *Session) removeRemoteClient(address string) error {
clientKey, sessionKey, err := parseAddress(address)
if err != nil {
return fmt.Errorf("invalid remote Session %s: %v", address, err)
}

keys := s.remoteClientKeys[clientKey]
delete(keys, int(sessionKey))
if len(keys) == 0 {
delete(s.remoteClientKeys, clientKey)
}

if PrintTunnelData {
logrus.Debugf("REMOVE REMOTE CLIENT %s, SESSION %d", address, s.sessionKey)
}

return nil
}

func (s *Session) closeConnection(connID int64, err error) {
s.Lock()
conn := s.conns[connID]
delete(s.conns, connID)
if PrintTunnelData {
logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
}
s.Unlock()

if conn != nil {
conn.tunnelClose(err)
}
}

func (s *Session) clientConnect(ctx context.Context, message *message) {
conn := newConnection(message.connID, s, message.proto, message.address)

s.Lock()
s.conns[message.connID] = conn
if PrintTunnelData {
logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
}
s.Unlock()

go clientDial(ctx, s.dialer, conn, message)
}

type connResult struct {
conn net.Conn
err error
Expand Down Expand Up @@ -299,12 +239,7 @@ func (s *Session) serverConnect(deadline time.Time, proto, address string) (net.
connID := atomic.AddInt64(&s.nextConnID, 1)
conn := newConnection(connID, s, proto, address)

s.Lock()
s.conns[connID] = conn
if PrintTunnelData {
logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
}
s.Unlock()
s.addConnection(connID, conn)

_, err := s.writeMessage(deadline, newConnect(connID, proto, address))
if err != nil {
Expand Down Expand Up @@ -339,14 +274,14 @@ func (s *Session) sessionAdded(clientKey string, sessionKey int64) {
client := fmt.Sprintf("%s/%d", clientKey, sessionKey)
_, err := s.writeMessage(time.Time{}, newAddClient(client))
if err != nil {
s.conn.conn.Close()
s.conn.Close()
}
}

func (s *Session) sessionRemoved(clientKey string, sessionKey int64) {
client := fmt.Sprintf("%s/%d", clientKey, sessionKey)
_, err := s.writeMessage(time.Time{}, newRemoveClient(client))
if err != nil {
s.conn.conn.Close()
s.conn.Close()
}
}
6 changes: 2 additions & 4 deletions session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ func (sm *sessionManager) getDialer(clientKey string) (Dialer, error) {

for _, sessions := range sm.peers {
for _, session := range sessions {
session.Lock()
keys := session.remoteClientKeys[clientKey]
session.Unlock()
keys := session.getSessionKeys(clientKey)
if len(keys) > 0 {
return toDialer(session, clientKey), nil
}
Expand All @@ -91,7 +89,7 @@ func (sm *sessionManager) getDialer(clientKey string) (Dialer, error) {

func (sm *sessionManager) add(clientKey string, conn *websocket.Conn, peer bool) *Session {
sessionKey := rand.Int63()
session := newSession(sessionKey, clientKey, conn)
session := newSession(sessionKey, clientKey, newWSConn(conn))

sm.Lock()
defer sm.Unlock()
Expand Down
Loading

0 comments on commit c3a9b3c

Please sign in to comment.