Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make client and server to resync active connections #74

Merged
merged 23 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e0dcdc0
Make client and server to resync active connections
aruiz14 Apr 3, 2024
ea777e2
add encoding/decoding of connection IDs
aruiz14 Apr 3, 2024
e35f83f
small refactor
aruiz14 Apr 4, 2024
3bcfd6a
allow listing active connection IDs
aruiz14 Apr 4, 2024
aad6651
new SyncConnections message type
aruiz14 Apr 4, 2024
a83c84d
send and receive resync messages
aruiz14 Apr 4, 2024
9d01b7c
add unit tests
aruiz14 Apr 5, 2024
80fa9cf
Move metrics recording inside doTunnelClose
aruiz14 Apr 5, 2024
0f8773d
rename close function, it should be a reserved word
aruiz14 Apr 5, 2024
abbd9e6
make comparison code more readable
aruiz14 Apr 8, 2024
fc34b29
docs: add comments to message types
aruiz14 Apr 26, 2024
53ecd7b
small refactor and avoid deadlock
aruiz14 Apr 29, 2024
5ecbee7
add comment to removedFromSlice
aruiz14 Apr 29, 2024
129b5d7
redo possibly flaky test
aruiz14 Apr 29, 2024
509131f
docs: add comments to new constant durations
aruiz14 Apr 29, 2024
2939ade
Rename removedFromSlice to diffSortedSetsGetRemoved
aruiz14 May 28, 2024
78d91a0
Fix typo
aruiz14 Jun 17, 2024
97868f9
Mark tests as safe to run in parallel
aruiz14 Jun 17, 2024
01aa1cd
Add more cases for diffSortedSetsGetRemoved test
aruiz14 Jun 17, 2024
ee938dc
Rename lockedRemoveConnection to removeConnectionLocked for consistency
aruiz14 Jun 17, 2024
14c3d27
Rename closeStaleConnections to compareAndCloseStaleConnections
aruiz14 Jun 17, 2024
0260c52
Avoid using named return parameter
aruiz14 Jun 17, 2024
9dc7b39
Fix loop variable access during parallelization
aruiz14 Jun 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions client_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func pipe(client *connection, server net.Conn) {
wg := sync.WaitGroup{}
wg.Add(1)

close := func(err error) error {
closePipe := func(err error) error {
if err == nil {
err = io.EOF
}
Expand All @@ -50,11 +50,11 @@ func pipe(client *connection, server net.Conn) {
go func() {
defer wg.Done()
_, err := io.Copy(server, client)
close(err)
closePipe(err)
}()

_, err := io.Copy(client, server)
err = close(err)
err = closePipe(err)
wg.Wait()

// Write tunnel error after no more I/O is happening, just in case messages get out of order
Expand Down
2 changes: 1 addition & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ func newConnection(connID int64, session *Session, proto, address string) *conne
}

func (c *connection) tunnelClose(err error) {
metrics.IncSMTotalRemoveConnectionsForWS(c.session.clientKey, c.addr.Network(), c.addr.String())
c.writeErr(err)
c.doTunnelClose(err)
}
Expand All @@ -46,6 +45,7 @@ func (c *connection) doTunnelClose(err error) {
return
}

metrics.IncSMTotalRemoveConnectionsForWS(c.session.clientKey, c.addr.Network(), c.addr.String())
aruiz14 marked this conversation as resolved.
Show resolved Hide resolved
c.err = err
if c.err == nil {
c.err = io.ErrClosedPipe
Expand Down
5 changes: 5 additions & 0 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ const (
Pause
// Resume is a message type used to resume a paused connection
Resume
// SyncConnections is a message type used to communicate active connection IDs.
// The receiver can consider any ID not present in this message as stale and free any associated resource.
SyncConnections
aruiz14 marked this conversation as resolved.
Show resolved Hide resolved
)

var (
Expand Down Expand Up @@ -244,6 +247,8 @@ func (m *message) String() string {
return fmt.Sprintf("%d PAUSE [%d]", m.id, m.connID)
case Resume:
return fmt.Sprintf("%d RESUME [%d]", m.id, m.connID)
case SyncConnections:
return fmt.Sprintf("%d SYNCCONNS [%d]", m.id, m.connID)
}
return fmt.Sprintf("%d UNKNOWN[%d]: %d", m.id, m.connID, m.messageType)
}
39 changes: 36 additions & 3 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"os"
"sort"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -83,14 +84,21 @@ func (s *Session) removeConnection(connID int64) *connection {
s.Lock()
defer s.Unlock()

conn := s.conns[connID]
delete(s.conns, connID)
conn := s.removeConnectionLocked(connID)
if PrintTunnelData {
defer logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
}
return conn
}

// removeConnectionLocked drops the connection registered under connID from the
// session and returns it, or nil when no connection is stored for that ID.
// It does no locking of its own: the caller must already hold the session lock.
func (s *Session) removeConnectionLocked(connID int64) *connection {
	removed, found := s.conns[connID]
	if found {
		delete(s.conns, connID)
	}
	return removed
}

// getConnection retrieves a connection by ID
func (s *Session) getConnection(connID int64) *connection {
s.RLock()
Expand All @@ -99,6 +107,19 @@ func (s *Session) getConnection(connID int64) *connection {
return s.conns[connID]
}

// activeConnectionIDs takes a snapshot, under the session read lock, of the IDs
// of every connection currently held by the session.
// The returned slice is sorted in ascending order.
func (s *Session) activeConnectionIDs() []int64 {
	s.RLock()
	defer s.RUnlock()

	ids := make([]int64, len(s.conns))
	next := 0
	for connID := range s.conns {
		ids[next] = connID
		next++
	}
	// Map iteration order is random, so sort to give callers a stable ordering
	sort.Slice(ids, func(a, b int) bool { return ids[a] < ids[b] })
	return ids
}

// addSessionKey registers a new session key for a given client key
func (s *Session) addSessionKey(clientKey string, sessionKey int) {
s.Lock()
Expand Down Expand Up @@ -142,12 +163,19 @@ func (s *Session) startPings(rootCtx context.Context) {
t := time.NewTicker(PingWriteInterval)
defer t.Stop()

syncConnections := time.NewTicker(SyncConnectionsInterval)
defer syncConnections.Stop()

for {
select {
case <-ctx.Done():
return
case <-syncConnections.C:
if err := s.sendSyncConnections(); err != nil {
logrus.WithError(err).Error("Error syncing connections")
}
case <-t.C:
if err := s.conn.WriteControl(websocket.PingMessage, time.Now().Add(PingWaitDuration), []byte("")); err != nil {
if err := s.sendPing(); err != nil {
logrus.WithError(err).Error("Error writing ping")
}
logrus.Debug("Wrote ping")
Expand All @@ -156,6 +184,11 @@ func (s *Session) startPings(rootCtx context.Context) {
}()
}

// sendPing writes a websocket Ping control message with an empty payload on the
// session's connection, using a write deadline of PingWaitDuration from now.
func (s *Session) sendPing() error {
	deadline := time.Now().Add(PingWaitDuration)
	return s.conn.WriteControl(websocket.PingMessage, deadline, []byte(""))
}

func (s *Session) stopPings() {
if s.pingCancel == nil {
return
Expand Down
17 changes: 17 additions & 0 deletions session_serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ func (s *Session) serveMessage(ctx context.Context, reader io.Reader) error {
return s.addRemoteClient(message.address)
case RemoveClient:
return s.removeRemoteClient(message.address)
case SyncConnections:
return s.syncConnections(message.body)
case Data:
s.connectionData(message.connID, message.body)
case Pause:
Expand Down Expand Up @@ -87,6 +89,21 @@ func (s *Session) removeRemoteClient(address string) error {
return nil
}

// syncConnections handles the body of a SyncConnections message: it reads the
// whole body, decodes it as a list of connection IDs and hands that list to
// compareAndCloseStaleConnections.
// An error is returned when the body cannot be read or decoded.
func (s *Session) syncConnections(r io.Reader) error {
	raw, readErr := io.ReadAll(r)
	if readErr != nil {
		return fmt.Errorf("reading message body: %w", readErr)
	}

	ids, decodeErr := decodeConnectionIDs(raw)
	if decodeErr != nil {
		return fmt.Errorf("decoding sync connections payload: %w", decodeErr)
	}

	s.compareAndCloseStaleConnections(ids)
	return nil
}

// closeConnection removes a connection for a given ID from the session, sending an error message to communicate the closing to the other end.
// If an error is not provided, io.EOF will be used instead.
func (s *Session) closeConnection(connID int64, err error) {
Expand Down
86 changes: 86 additions & 0 deletions session_sync.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package remotedialer

import (
"encoding/binary"
"errors"
"fmt"
"time"
)

// errCloseSyncConnections is the error handed to doTunnelClose when a connection is closed
// because its ID was missing from the list received in a SyncConnections message
var errCloseSyncConnections = errors.New("sync from client")

// encodeConnectionIDs serializes a slice of connection IDs.
// Each ID is written as 8 bytes in little-endian order, in the same order as the input,
// so the result is always 8*len(ids) bytes long (empty for an empty or nil input).
func encodeConnectionIDs(ids []int64) []byte {
	// Capacity is known up front: exactly 8 bytes per ID
	payload := make([]byte, 0, 8*len(ids))
	for _, id := range ids {
		payload = binary.LittleEndian.AppendUint64(payload, uint64(id))
	}
	return payload
}

// decodeConnectionIDs turns a payload of consecutive 8-byte little-endian values
// back into a slice of connection IDs, preserving their order.
// It returns an error when the payload length is not a multiple of 8.
func decodeConnectionIDs(payload []byte) ([]int64, error) {
	const idSize = 8
	if len(payload)%idSize != 0 {
		return nil, fmt.Errorf("incorrect data format")
	}

	ids := make([]int64, len(payload)/idSize)
	for n := range ids {
		chunk := payload[n*idSize : (n+1)*idSize]
		ids[n] = int64(binary.LittleEndian.Uint64(chunk))
	}
	return ids, nil
}

// newSyncConnectionsMessage builds a message of type SyncConnections with a fresh
// message ID, whose bytes are the encoded form of the given connection IDs
func newSyncConnectionsMessage(connectionIDs []int64) *message {
	msg := new(message)
	msg.id = nextid()
	msg.messageType = SyncConnections
	msg.bytes = encodeConnectionIDs(connectionIDs)
	return msg
}

// sendSyncConnections writes a SyncConnections message carrying the IDs of this
// session's currently active connections, with a write deadline of
// SyncConnectionsTimeout from now. Only the write error, if any, is returned.
func (s *Session) sendSyncConnections() error {
	deadline := time.Now().Add(SyncConnectionsTimeout)
	msg := newSyncConnectionsMessage(s.activeConnectionIDs())
	_, err := s.writeMessage(deadline, msg)
	return err
}

// compareAndCloseStaleConnections compares the Session's activeConnectionIDs with the provided list from the client, then closing every connection not present in it
func (s *Session) compareAndCloseStaleConnections(clientIDs []int64) {
serverIDs := s.activeConnectionIDs()
aruiz14 marked this conversation as resolved.
Show resolved Hide resolved
toClose := diffSortedSetsGetRemoved(serverIDs, clientIDs)
if len(toClose) == 0 {
return
}

s.Lock()
defer s.Unlock()
for _, id := range toClose {
// Connection no longer active in the client, close it server-side
conn := s.removeConnectionLocked(id)
if conn != nil {
// Using doTunnelClose directly instead of tunnelClose, omitting unnecessarily sending an Error message
conn.doTunnelClose(errCloseSyncConnections)
}
}
}

// diffSortedSetsGetRemoved walks two slices sorted in ascending order and collects
// the items of a that have no counterpart in b, keeping their order
// (the same output as coreutils' "comm -23").
// The result is nil when nothing was removed.
func diffSortedSetsGetRemoved(a, b []int64) []int64 {
	var removed []int64
	ai, bi := 0, 0
	for ai < len(a) && bi < len(b) {
		switch x, y := a[ai], b[bi]; {
		case x < y: // only in "a"
			removed = append(removed, x)
			ai++
		case x > y: // only in "b"
			bi++
		default: // in both
			ai++
			bi++
		}
	}
	// "b" is exhausted: whatever is left in "a" cannot be present in "b"
	return append(removed, a[ai:]...)
}
Loading