Skip to content

Commit

Permalink
Merge pull request #300 from ellemouton/revokeExpiredSessions
Browse files Browse the repository at this point in the history
multi: revoke expired sessions
  • Loading branch information
guggero authored Feb 8, 2022
2 parents 2c446cd + 7a0a18d commit 954daad
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 7 deletions.
22 changes: 17 additions & 5 deletions session/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ type GRPCServerCreator func(opts ...grpc.ServerOption) *grpc.Server
type mailboxSession struct {
server *grpc.Server

wg sync.WaitGroup
wg sync.WaitGroup
quit chan struct{}
}

func newMailboxSession() *mailboxSession {
return &mailboxSession{
quit: make(chan struct{}),
}
}

func (m *mailboxSession) start(session *Session,
Expand Down Expand Up @@ -62,6 +69,7 @@ func (m *mailboxSession) run(mailboxServer *mailbox.Server) {

func (m *mailboxSession) stop() {
m.server.Stop()
close(m.quit)
m.wg.Wait()
}

Expand All @@ -82,7 +90,9 @@ func NewServer(serverCreator GRPCServerCreator) *Server {
}
}

func (s *Server) StartSession(session *Session, authData []byte) error {
func (s *Server) StartSession(session *Session, authData []byte) (chan struct{},
error) {

s.activeSessionsMtx.Lock()
defer s.activeSessionsMtx.Unlock()

Expand All @@ -91,11 +101,13 @@ func (s *Server) StartSession(session *Session, authData []byte) error {

_, ok := s.activeSessions[id]
if ok {
return fmt.Errorf("session %x is already active", id[:])
return nil, fmt.Errorf("session %x is already active", id[:])
}

s.activeSessions[id] = &mailboxSession{}
return s.activeSessions[id].start(session, s.serverCreator, authData)
sess := newMailboxSession()
s.activeSessions[id] = sess

return sess.quit, sess.start(session, s.serverCreator, authData)
}

func (s *Server) StopSession(localPublicKey *btcec.PublicKey) error {
Expand Down
58 changes: 56 additions & 2 deletions session_rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strings"
"sync"
"time"

"github.com/btcsuite/btcd/btcec"
Expand All @@ -20,6 +21,18 @@ type sessionRpcServer struct {

db *session.DB
sessionServer *session.Server

quit chan struct{}
wg sync.WaitGroup
stopOnce sync.Once
}

// stop cleans up any sessionRpcServer resources.
func (s *sessionRpcServer) stop() {
s.stopOnce.Do(func() {
close(s.quit)
s.wg.Wait()
})
}

// AddSession adds and starts a new Terminal Connect session.
Expand Down Expand Up @@ -72,6 +85,9 @@ func (s *sessionRpcServer) AddSession(_ context.Context,
// resumeSession tries to start an existing session if it is not expired, not
// revoked and a LiT session.
func (s *sessionRpcServer) resumeSession(sess *session.Session) error {
pubKey := sess.LocalPublicKey
pubKeyBytes := pubKey.SerializeCompressed()

// We only start non-revoked, non-expired LiT sessions. Everything else
// we just skip.
if sess.State != session.StateInUse &&
Expand All @@ -88,12 +104,50 @@ func (s *sessionRpcServer) resumeSession(sess *session.Session) error {
}
if sess.Expiry.Before(time.Now()) {
log.Debugf("Not resuming session %x with expiry %s",
sess.LocalPublicKey.SerializeCompressed(), sess.Expiry)
pubKeyBytes, sess.Expiry)

if err := s.db.RevokeSession(pubKey); err != nil {
return fmt.Errorf("error revoking session: %v", err)
}

return nil
}

authData := []byte("Authorization: Basic " + s.basicAuth)
return s.sessionServer.StartSession(sess, authData)
sessionClosedSub, err := s.sessionServer.StartSession(sess, authData)
if err != nil {
return err
}

s.wg.Add(1)
go func() {
defer s.wg.Done()

ticker := time.NewTimer(time.Until(sess.Expiry))
defer ticker.Stop()

select {
case <-s.quit:
case <-sessionClosedSub:
case <-ticker.C:
log.Debugf("Stopping expired session %x with "+
"type %d", pubKeyBytes, sess.Type)

err = s.sessionServer.StopSession(pubKey)
if err != nil {
log.Debugf("Error stopping session: "+
"%v", err)
}

err = s.db.RevokeSession(pubKey)
if err != nil {
log.Debugf("error revoking session: "+
"%v", err)
}
}
}()

return err
}

// ListSessions returns all sessions known to the session store.
Expand Down
2 changes: 2 additions & 0 deletions terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ func (g *LightningTerminal) Run() error {
basicAuth: g.rpcProxy.basicAuth,
db: g.sessionDB,
sessionServer: g.sessionServer,
quit: make(chan struct{}),
}

// Now start up all previously created sessions.
Expand Down Expand Up @@ -838,6 +839,7 @@ func (g *LightningTerminal) shutdown() error {
}
}

g.sessionRpcServer.stop()
if err := g.sessionDB.Close(); err != nil {
log.Errorf("Error closing session DB: %v", err)
returnErr = err
Expand Down

0 comments on commit 954daad

Please sign in to comment.