diff --git a/session/server.go b/session/server.go index da3dfb811..2acb802e5 100644 --- a/session/server.go +++ b/session/server.go @@ -19,7 +19,14 @@ type GRPCServerCreator func(opts ...grpc.ServerOption) *grpc.Server type mailboxSession struct { server *grpc.Server - wg sync.WaitGroup + wg sync.WaitGroup + quit chan struct{} +} + +func newMailboxSession() *mailboxSession { + return &mailboxSession{ + quit: make(chan struct{}), + } } func (m *mailboxSession) start(session *Session, @@ -62,6 +69,7 @@ func (m *mailboxSession) run(mailboxServer *mailbox.Server) { func (m *mailboxSession) stop() { m.server.Stop() + close(m.quit) m.wg.Wait() } @@ -82,7 +90,9 @@ func NewServer(serverCreator GRPCServerCreator) *Server { } } -func (s *Server) StartSession(session *Session, authData []byte) error { +func (s *Server) StartSession(session *Session, authData []byte) (chan struct{}, + error) { + s.activeSessionsMtx.Lock() defer s.activeSessionsMtx.Unlock() @@ -91,11 +101,13 @@ func (s *Server) StartSession(session *Session, authData []byte) error { _, ok := s.activeSessions[id] if ok { - return fmt.Errorf("session %x is already active", id[:]) + return nil, fmt.Errorf("session %x is already active", id[:]) } - s.activeSessions[id] = &mailboxSession{} - return s.activeSessions[id].start(session, s.serverCreator, authData) + sess := newMailboxSession() + s.activeSessions[id] = sess + + return sess.quit, sess.start(session, s.serverCreator, authData) } func (s *Server) StopSession(localPublicKey *btcec.PublicKey) error { diff --git a/session_rpcserver.go b/session_rpcserver.go index 88f78d5a5..d2e459f02 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strings" + "sync" "time" "github.com/btcsuite/btcd/btcec" @@ -20,6 +21,18 @@ type sessionRpcServer struct { db *session.DB sessionServer *session.Server + + quit chan struct{} + wg sync.WaitGroup + stopOnce sync.Once +} + +// stop cleans up any sessionRpcServer resources. +func (s *sessionRpcServer) stop() { + s.stopOnce.Do(func() { + close(s.quit) + s.wg.Wait() + }) } // AddSession adds and starts a new Terminal Connect session. @@ -72,6 +85,9 @@ func (s *sessionRpcServer) AddSession(_ context.Context, // resumeSession tries to start an existing session if it is not expired, not // revoked and a LiT session. func (s *sessionRpcServer) resumeSession(sess *session.Session) error { + pubKey := sess.LocalPublicKey + pubKeyBytes := pubKey.SerializeCompressed() + // We only start non-revoked, non-expired LiT sessions. Everything else // we just skip. if sess.State != session.StateInUse && @@ -88,12 +104,50 @@ func (s *sessionRpcServer) resumeSession(sess *session.Session) error { } if sess.Expiry.Before(time.Now()) { log.Debugf("Not resuming session %x with expiry %s", - sess.LocalPublicKey.SerializeCompressed(), sess.Expiry) + pubKeyBytes, sess.Expiry) + + if err := s.db.RevokeSession(pubKey); err != nil { + return fmt.Errorf("error revoking session: %v", err) + } + return nil } authData := []byte("Authorization: Basic " + s.basicAuth) - return s.sessionServer.StartSession(sess, authData) + sessionClosedSub, err := s.sessionServer.StartSession(sess, authData) + if err != nil { + return err + } + + s.wg.Add(1) + go func() { + defer s.wg.Done() + + ticker := time.NewTimer(time.Until(sess.Expiry)) + defer ticker.Stop() + + select { + case <-s.quit: + case <-sessionClosedSub: + case <-ticker.C: + log.Debugf("Stopping expired session %x with "+ + "type %d", pubKeyBytes, sess.Type) + + err = s.sessionServer.StopSession(pubKey) + if err != nil { + log.Debugf("Error stopping session: "+ + "%v", err) + } + + err = s.db.RevokeSession(pubKey) + if err != nil { + log.Debugf("error revoking session: "+ + "%v", err) + } + } + }() + + return err } // ListSessions returns all sessions known to the session store. diff --git a/terminal.go b/terminal.go index 943696d28..ae1704392 100644 --- a/terminal.go +++ b/terminal.go @@ -227,6 +227,7 @@ func (g *LightningTerminal) Run() error { basicAuth: g.rpcProxy.basicAuth, db: g.sessionDB, sessionServer: g.sessionServer, + quit: make(chan struct{}), } // Now start up all previously created sessions. @@ -838,6 +839,7 @@ func (g *LightningTerminal) shutdown() error { } } + g.sessionRpcServer.stop() if err := g.sessionDB.Close(); err != nil { log.Errorf("Error closing session DB: %v", err) returnErr = err