diff --git a/command/agent/http.go b/command/agent/http.go index a7a3c4d58756..8c174711bf70 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -188,6 +188,7 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) { s.mux.HandleFunc("/v1/session/create", s.wrap(s.SessionCreate)) s.mux.HandleFunc("/v1/session/destroy/", s.wrap(s.SessionDestroy)) + s.mux.HandleFunc("/v1/session/renew/", s.wrap(s.SessionRenew)) s.mux.HandleFunc("/v1/session/info/", s.wrap(s.SessionGet)) s.mux.HandleFunc("/v1/session/node/", s.wrap(s.SessionsForNode)) s.mux.HandleFunc("/v1/session/list", s.wrap(s.SessionList)) diff --git a/command/agent/session_endpoint.go b/command/agent/session_endpoint.go index cd9fa7ecdb74..878a0b5843bf 100644 --- a/command/agent/session_endpoint.go +++ b/command/agent/session_endpoint.go @@ -40,6 +40,7 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request) Checks: []string{consul.SerfCheckID}, LockDelay: 15 * time.Second, Behavior: structs.SessionKeysRelease, + TTL: "", }, } s.parseDC(req, &args.Datacenter) @@ -51,6 +52,21 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request) resp.Write([]byte(fmt.Sprintf("Request decode failed: %v", err))) return nil, nil } + + if args.Session.TTL != "" { + ttl, err := time.ParseDuration(args.Session.TTL) + if err != nil { + resp.WriteHeader(400) + resp.Write([]byte(fmt.Sprintf("Request TTL decode failed: %v", err))) + return nil, nil + } + + if ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax { + resp.WriteHeader(400) + resp.Write([]byte(fmt.Sprintf("Request TTL '%s', must be between [%v-%v]", args.Session.TTL, structs.SessionTTLMin, structs.SessionTTLMax))) + return nil, nil + } + } } // Create the session, get the ID @@ -130,6 +146,39 @@ func (s *HTTPServer) SessionDestroy(resp http.ResponseWriter, req *http.Request) return true, nil } +// SessionRenew is used to renew the TTL on an existing TTL session +func (s *HTTPServer) SessionRenew(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + // Mandate a PUT request + if req.Method != "PUT" { + resp.WriteHeader(405) + return nil, nil + } + + args := structs.SessionSpecificRequest{} + if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { + return nil, nil + } + + // Pull out the session id + args.Session = strings.TrimPrefix(req.URL.Path, "/v1/session/renew/") + if args.Session == "" { + resp.WriteHeader(400) + resp.Write([]byte("Missing session")) + return nil, nil + } + + var out structs.IndexedSessions + if err := s.agent.RPC("Session.Renew", &args, &out); err != nil { + return nil, err + } else if out.Sessions == nil { + resp.WriteHeader(404) + resp.Write([]byte(fmt.Sprintf("Session id '%s' not found", args.Session))) + return nil, nil + } + + return out.Sessions, nil +} + // SessionGet is used to get info for a particular session func (s *HTTPServer) SessionGet(resp http.ResponseWriter, req *http.Request) (interface{}, error) { args := structs.SessionSpecificRequest{} diff --git a/command/agent/session_endpoint_test.go b/command/agent/session_endpoint_test.go index 74ec20ebf27a..edfa074402f4 100644 --- a/command/agent/session_endpoint_test.go +++ b/command/agent/session_endpoint_test.go @@ -176,6 +176,28 @@ func makeTestSessionDelete(t *testing.T, srv *HTTPServer) string { return sessResp.ID } +func makeTestSessionTTL(t *testing.T, srv *HTTPServer, ttl string) string { + // Create Session with TTL + body := bytes.NewBuffer(nil) + enc := json.NewEncoder(body) + raw := map[string]interface{}{ + "TTL": ttl, + } + enc.Encode(raw) + + req, err := http.NewRequest("PUT", "/v1/session/create", body) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := httptest.NewRecorder() + obj, err := srv.SessionCreate(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + sessResp := obj.(sessionCreateResponse) + return sessResp.ID +} + func TestSessionDestroy(t *testing.T) { httpTest(t, func(srv *HTTPServer) { id := makeTestSession(t, srv) @@ -192,6 +214,206 @@ func TestSessionDestroy(t *testing.T) { }) } +func TestSessionTTL(t *testing.T) { + httpTest(t, func(srv *HTTPServer) { + TTL := "10s" // use the minimum legal ttl + ttl := 10 * time.Second + + id := makeTestSessionTTL(t, srv, TTL) + + req, err := http.NewRequest("GET", + "/v1/session/info/"+id, nil) + resp := httptest.NewRecorder() + obj, err := srv.SessionGet(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + respObj, ok := obj.(structs.Sessions) + if !ok { + t.Fatalf("should work") + } + if len(respObj) != 1 { + t.Fatalf("bad: %v", respObj) + } + if respObj[0].TTL != TTL { + t.Fatalf("Incorrect TTL: %s", respObj[0].TTL) + } + + time.Sleep(ttl*structs.SessionTTLMultiplier + ttl) + + req, err = http.NewRequest("GET", + "/v1/session/info/"+id, nil) + resp = httptest.NewRecorder() + obj, err = srv.SessionGet(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + respObj, ok = obj.(structs.Sessions) + if len(respObj) != 0 { + t.Fatalf("session '%s' should have been destroyed", id) + } + }) +} + +func TestSessionBadTTL(t *testing.T) { + httpTest(t, func(srv *HTTPServer) { + badTTL := "10z" + + // Create Session with illegal TTL + body := bytes.NewBuffer(nil) + enc := json.NewEncoder(body) + raw := map[string]interface{}{ + "TTL": badTTL, + } + enc.Encode(raw) + + req, err := http.NewRequest("PUT", "/v1/session/create", body) + if err != nil { + t.Fatalf("err: %v", err) + } + resp := httptest.NewRecorder() + obj, err := srv.SessionCreate(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if obj != nil { + t.Fatalf("illegal TTL '%s' allowed", badTTL) + } + if resp.Code != 400 { + t.Fatalf("Bad response code, should be 400") + } + + // less than SessionTTLMin + body = bytes.NewBuffer(nil) + enc = json.NewEncoder(body) + raw = map[string]interface{}{ + "TTL": "5s", + } + enc.Encode(raw) + + req, err = http.NewRequest("PUT", "/v1/session/create", body) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = httptest.NewRecorder() + obj, err = srv.SessionCreate(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if obj != nil { + t.Fatalf("illegal TTL '%s' allowed", badTTL) + } + if resp.Code != 400 { + t.Fatalf("Bad response code, should be 400") + } + + // more than SessionTTLMax + body = bytes.NewBuffer(nil) + enc = json.NewEncoder(body) + raw = map[string]interface{}{ + "TTL": "4000s", + } + enc.Encode(raw) + + req, err = http.NewRequest("PUT", "/v1/session/create", body) + if err != nil { + t.Fatalf("err: %v", err) + } + resp = httptest.NewRecorder() + obj, err = srv.SessionCreate(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if obj != nil { + t.Fatalf("illegal TTL '%s' allowed", badTTL) + } + if resp.Code != 400 { + t.Fatalf("Bad response code, should be 400") + } + }) +} + +func TestSessionTTLRenew(t *testing.T) { + httpTest(t, func(srv *HTTPServer) { + TTL := "10s" // use the minimum legal ttl + ttl := 10 * time.Second + + id := makeTestSessionTTL(t, srv, TTL) + + req, err := http.NewRequest("GET", + "/v1/session/info/"+id, nil) + resp := httptest.NewRecorder() + obj, err := srv.SessionGet(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + respObj, ok := obj.(structs.Sessions) + if !ok { + t.Fatalf("should work") + } + if len(respObj) != 1 { + t.Fatalf("bad: %v", respObj) + } + if respObj[0].TTL != TTL { + t.Fatalf("Incorrect TTL: %s", respObj[0].TTL) + } + + // Sleep to consume some time before renew + time.Sleep(ttl * (structs.SessionTTLMultiplier / 2)) + + req, err = http.NewRequest("PUT", + "/v1/session/renew/"+id, nil) + resp = httptest.NewRecorder() + obj, err = srv.SessionRenew(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + respObj, ok = obj.(structs.Sessions) + if !ok { + t.Fatalf("should work") + } + if len(respObj) != 1 { + t.Fatalf("bad: %v", respObj) + } + + // Sleep for ttl * TTL Multiplier + time.Sleep(ttl * structs.SessionTTLMultiplier) + + req, err = http.NewRequest("GET", + "/v1/session/info/"+id, nil) + resp = httptest.NewRecorder() + obj, err = srv.SessionGet(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + respObj, ok = obj.(structs.Sessions) + if !ok { + t.Fatalf("session '%s' should have renewed", id) + } + if len(respObj) != 1 { + t.Fatalf("session '%s' should have renewed", id) + } + + // now wait for timeout and expect session to get destroyed + time.Sleep(ttl * structs.SessionTTLMultiplier) + + req, err = http.NewRequest("GET", + "/v1/session/info/"+id, nil) + resp = httptest.NewRecorder() + obj, err = srv.SessionGet(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + respObj, ok = obj.(structs.Sessions) + if !ok { + t.Fatalf("session '%s' should have destroyed", id) + } + if len(respObj) != 0 { + t.Fatalf("session '%s' should have destroyed", id) + } + }) +} + func TestSessionGet(t *testing.T) { httpTest(t, func(srv *HTTPServer) { id := makeTestSession(t, srv) diff --git a/consul/leader.go b/consul/leader.go index 7f4f378a64ef..c57adb272ef9 100644 --- a/consul/leader.go +++ b/consul/leader.go @@ -61,6 +61,13 @@ func (s *Server) leaderLoop(stopCh chan struct{}) { s.logger.Printf("[ERR] consul: ACL initialization failed: %v", err) } + // Setup Session Timers if we are the leader and need to + if err := s.initializeSessionTimers(); err != nil { + s.logger.Printf("[ERR] consul: Session Timers initialization failed: %v", err) + } + // clear the session timers if we are no longer leader and exit the leaderLoop + defer s.clearAllSessionTimers() + // Reconcile channel is only used once initial reconcile // has succeeded var reconcileCh chan serf.Member diff --git a/consul/leader_test.go b/consul/leader_test.go index 75535d56bb8a..1bd910858bda 100644 --- a/consul/leader_test.go +++ b/consul/leader_test.go @@ -370,6 +370,9 @@ func TestLeader_LeftLeader(t *testing.T) { break } } + if leader == nil { + t.Fatalf("Should have a leader") + } leader.Leave() leader.Shutdown() time.Sleep(100 * time.Millisecond) diff --git a/consul/server.go b/consul/server.go index cba7f11bad90..f14dc2aa1176 100644 --- a/consul/server.go +++ b/consul/server.go @@ -128,6 +128,12 @@ type Server struct { // which SHOULD only consist of Consul servers serfWAN *serf.Serf + // sessionTimers track the expiration time of each Session that has + // a TTL. On expiration, a SessionDestroy event will occur, and + // destroy the session via standard session destory processing + sessionTimers map[string]*time.Timer + sessionTimersLock sync.RWMutex + shutdown bool shutdownCh chan struct{} shutdownLock sync.Mutex diff --git a/consul/session_endpoint.go b/consul/session_endpoint.go index cfe60ea5702e..98728d8a6a99 100644 --- a/consul/session_endpoint.go +++ b/consul/session_endpoint.go @@ -36,6 +36,16 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { default: return fmt.Errorf("Invalid Behavior setting '%s'", args.Session.Behavior) } + if args.Session.TTL != "" { + ttl, err := time.ParseDuration(args.Session.TTL) + if err != nil { + return fmt.Errorf("Session TTL '%s' invalid: %v", args.Session.TTL, err) + } + + if ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax { + return fmt.Errorf("Invalid Session TTL '%d', must be between [%v=%v]", ttl, structs.SessionTTLMin, structs.SessionTTLMax) + } + } // If this is a create, we must generate the Session ID. This must // be done prior to appending to the raft log, because the ID is not @@ -63,6 +73,13 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { s.srv.logger.Printf("[ERR] consul.session: Apply failed: %v", err) return err } + + if args.Op == structs.SessionCreate && args.Session.TTL != "" { + s.srv.resetSessionTimer(args.Session.ID, nil) + } else if args.Op == structs.SessionDestroy && args.Session.TTL != "" { + s.srv.clearSessionTimer(args.Session.ID) + } + if respErr, ok := resp.(error); ok { return respErr } @@ -133,3 +150,24 @@ func (s *Session) NodeSessions(args *structs.NodeSpecificRequest, return err }) } + +// Renew is used to renew the TTL on a single session +func (s *Session) Renew(args *structs.SessionSpecificRequest, + reply *structs.IndexedSessions) error { + if done, err := s.srv.forward("Session.Renew", args, args, reply); done { + return err + } + + // Get the local state + state := s.srv.fsm.State() + // Get the session, from local state + index, session, err := state.SessionGet(args.Session) + reply.Index = index + if session != nil { + reply.Sessions = structs.Sessions{session} + // reset the session TTL timer + err = s.srv.resetSessionTimer(args.Session, session) + } + + return err +} diff --git a/consul/session_endpoint_test.go b/consul/session_endpoint_test.go index 794975294790..635d2479b21c 100644 --- a/consul/session_endpoint_test.go +++ b/consul/session_endpoint_test.go @@ -5,6 +5,7 @@ import ( "github.com/hashicorp/consul/testutil" "os" "testing" + "time" ) func TestSessionEndpoint_Apply(t *testing.T) { @@ -223,6 +224,161 @@ func TestSessionEndpoint_List(t *testing.T) { } } +func TestSessionEndpoint_Renew(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + client := rpcClient(t, s1) + defer client.Close() + + testutil.WaitForLeader(t, client.Call, "dc1") + TTL := "10s" // the minimum allowed ttl + ttl := 10 * time.Second + + s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + ids := []string{} + for i := 0; i < 5; i++ { + arg := structs.SessionRequest{ + Datacenter: "dc1", + Op: structs.SessionCreate, + Session: structs.Session{ + Node: "foo", + TTL: TTL, + }, + } + var out string + if err := client.Call("Session.Apply", &arg, &out); err != nil { + t.Fatalf("err: %v", err) + } + ids = append(ids, out) + } + + getR := structs.DCSpecificRequest{ + Datacenter: "dc1", + } + + var sessions structs.IndexedSessions + if err := client.Call("Session.List", &getR, &sessions); err != nil { + t.Fatalf("err: %v", err) + } + + if sessions.Index == 0 { + t.Fatalf("Bad: %v", sessions) + } + if len(sessions.Sessions) != 5 { + t.Fatalf("Bad: %v", sessions.Sessions) + } + for i := 0; i < len(sessions.Sessions); i++ { + s := sessions.Sessions[i] + if !strContains(ids, s.ID) { + t.Fatalf("bad: %v", s) + } + if s.Node != "foo" { + t.Fatalf("bad: %v", s) + } + if s.TTL != TTL { + t.Fatalf("bad session TTL: %s %v", s.TTL, s) + } + t.Logf("Created session '%s'", s.ID) + } + + // Sleep for time shorter than internal destroy ttl + time.Sleep(ttl * structs.SessionTTLMultiplier / 2) + + // renew 3 out of 5 sessions + for i := 0; i < 3; i++ { + renewR := structs.SessionSpecificRequest{ + Datacenter: "dc1", + Session: ids[i], + } + var session structs.IndexedSessions + if err := client.Call("Session.Renew", &renewR, &session); err != nil { + t.Fatalf("err: %v", err) + } + + if session.Index == 0 { + t.Fatalf("Bad: %v", session) + } + if len(session.Sessions) != 1 { + t.Fatalf("Bad: %v", session.Sessions) + } + + s := session.Sessions[0] + if !strContains(ids, s.ID) { + t.Fatalf("bad: %v", s) + } + if s.Node != "foo" { + t.Fatalf("bad: %v", s) + } + + t.Logf("Renewed session '%s'", s.ID) + } + + // now sleep for 2/3 the internal destroy TTL time for renewed sessions + // which is more than the internal destroy TTL time for the non-renewed sessions + time.Sleep((ttl * structs.SessionTTLMultiplier) * 2.0 / 3.0) + + var sessionsL1 structs.IndexedSessions + if err := client.Call("Session.List", &getR, &sessionsL1); err != nil { + t.Fatalf("err: %v", err) + } + + if sessionsL1.Index == 0 { + t.Fatalf("Bad: %v", sessionsL1) + } + + t.Logf("Expect 2 sessions to be destroyed") + + for i := 0; i < len(sessionsL1.Sessions); i++ { + s := sessionsL1.Sessions[i] + if !strContains(ids, s.ID) { + t.Fatalf("bad: %v", s) + } + if s.Node != "foo" { + t.Fatalf("bad: %v", s) + } + if s.TTL != TTL { + t.Fatalf("bad: %v", s) + } + if i > 2 { + t.Errorf("session '%s' should be destroyed", s.ID) + } + } + + if len(sessionsL1.Sessions) > 3 { + t.Fatalf("Bad: %v", sessionsL1.Sessions) + } + + // now sleep again for ttl*2 - no sessions should still be alive + time.Sleep(ttl * structs.SessionTTLMultiplier) + + var sessionsL2 structs.IndexedSessions + if err := client.Call("Session.List", &getR, &sessionsL2); err != nil { + t.Fatalf("err: %v", err) + } + + if sessionsL2.Index == 0 { + t.Fatalf("Bad: %v", sessionsL2) + } + if len(sessionsL2.Sessions) != 0 { + for i := 0; i < len(sessionsL2.Sessions); i++ { + s := sessionsL2.Sessions[i] + if !strContains(ids, s.ID) { + t.Fatalf("bad: %v", s) + } + if s.Node != "foo" { + t.Fatalf("bad: %v", s) + } + if s.TTL != TTL { + t.Fatalf("bad: %v", s) + } + t.Errorf("session '%s' should be destroyed", s.ID) + } + + t.Fatalf("Bad: %v", sessionsL2.Sessions) + } +} + func TestSessionEndpoint_NodeSessions(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) diff --git a/consul/session_ttl.go b/consul/session_ttl.go new file mode 100644 index 000000000000..06c572be2591 --- /dev/null +++ b/consul/session_ttl.go @@ -0,0 +1,106 @@ +package consul + +import ( + "fmt" + "github.com/hashicorp/consul/consul/structs" + "time" +) + +func (s *Server) initializeSessionTimers() error { + s.sessionTimersLock.Lock() + s.sessionTimers = make(map[string]*time.Timer) + s.sessionTimersLock.Unlock() + + // walk the TTL index and resetSessionTimer for each non-zero TTL + state := s.fsm.State() + _, sessions, err := state.SessionListTTL() + if err != nil { + return err + } + for _, session := range sessions { + err := s.resetSessionTimer(session.ID, session) + if err != nil { + return err + } + } + return nil +} + +// invalidate the session when timer expires, called by AfterFunc +func (s *Server) invalidateSession(id string) { + args := structs.SessionRequest{ + Datacenter: s.config.Datacenter, + Op: structs.SessionDestroy, + } + args.Session.ID = id + + // Apply the update to destroy the session + _, err := s.raftApply(structs.SessionRequestType, args) + if err != nil { + s.logger.Printf("[ERR] consul.session: Apply failed: %v", err) + } +} + +func (s *Server) resetSessionTimer(id string, session *structs.Session) error { + if session == nil { + var err error + + // find the session + state := s.fsm.State() + _, session, err = state.SessionGet(id) + if err != nil || session == nil { + return fmt.Errorf("Could not find session for '%s'\n", id) + } + } + + if session.TTL == "" { + return nil + } + + ttl, err := time.ParseDuration(session.TTL) + if err != nil { + return fmt.Errorf("Invalid Session TTL '%s': %v", session.TTL, err) + } + if ttl == 0 { + return nil + } + + s.sessionTimersLock.Lock() + if s.sessionTimers == nil { + s.sessionTimers = make(map[string]*time.Timer) + } + defer s.sessionTimersLock.Unlock() + if t := s.sessionTimers[id]; t != nil { + // TBD may modify the session's active TTL based on load here + t.Reset(ttl * structs.SessionTTLMultiplier) + } else { + s.sessionTimers[session.ID] = time.AfterFunc(ttl*structs.SessionTTLMultiplier, func() { + s.invalidateSession(session.ID) + }) + } + + return nil +} + +func (s *Server) clearSessionTimer(id string) error { + s.sessionTimersLock.Lock() + defer s.sessionTimersLock.Unlock() + if s.sessionTimers[id] != nil { + // stop the session timer and delete from the map + s.sessionTimers[id].Stop() + delete(s.sessionTimers, id) + } + return nil +} + +func (s *Server) clearAllSessionTimers() error { + s.sessionTimersLock.Lock() + defer s.sessionTimersLock.Unlock() + + // stop all timers and clear out the map + for _, t := range s.sessionTimers { + t.Stop() + } + s.sessionTimers = nil + return nil +} diff --git a/consul/session_ttl_test.go b/consul/session_ttl_test.go new file mode 100644 index 000000000000..d26264f03649 --- /dev/null +++ b/consul/session_ttl_test.go @@ -0,0 +1,168 @@ +package consul + +import ( + "errors" + "fmt" + "os" + "testing" + "time" + + "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/consul/testutil" +) + +func TestServer_sessionTTL(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + dir2, s2 := testServerDCBootstrap(t, "dc1", false) + defer os.RemoveAll(dir2) + defer s2.Shutdown() + + dir3, s3 := testServerDCBootstrap(t, "dc1", false) + defer os.RemoveAll(dir3) + defer s3.Shutdown() + servers := []*Server{s1, s2, s3} + + // Try to join + addr := fmt.Sprintf("127.0.0.1:%d", + s1.config.SerfLANConfig.MemberlistConfig.BindPort) + if _, err := s2.JoinLAN([]string{addr}); err != nil { + t.Fatalf("err: %v", err) + } + if _, err := s3.JoinLAN([]string{addr}); err != nil { + t.Fatalf("err: %v", err) + } + + for _, s := range servers { + testutil.WaitForResult(func() (bool, error) { + peers, _ := s.raftPeers.Peers() + return len(peers) == 3, nil + }, func(err error) { + t.Fatalf("should have 3 peers") + }) + } + + // Find the leader + var leader *Server + for _, s := range servers { + // check that s.sessionTimers is empty + if len(s.sessionTimers) != 0 { + t.Fatalf("should have no sessionTimers") + } + // find the leader too + if s.IsLeader() { + leader = s + } + } + + if leader == nil { + t.Fatalf("Should have a leader") + } + + client := rpcClient(t, leader) + defer client.Close() + + leader.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + + // create a TTL session + arg := structs.SessionRequest{ + Datacenter: "dc1", + Op: structs.SessionCreate, + Session: structs.Session{ + Node: "foo", + TTL: "10s", + }, + } + var id1 string + if err := client.Call("Session.Apply", &arg, &id1); err != nil { + t.Fatalf("err: %v", err) + } + + // check that leader.sessionTimers has the session id in it + // means initializeSessionTimers was called and resetSessionTimer was called + if len(leader.sessionTimers) == 0 || leader.sessionTimers[id1] == nil { + t.Fatalf("sessionTimers not initialized and does not contain session timer for session") + } + + time.Sleep(100 * time.Millisecond) + leader.Leave() + leader.Shutdown() + + // leader.sessionTimers should be empty due to clearAllSessionTimers getting called + if len(leader.sessionTimers) != 0 { + t.Fatalf("session timers should be empty on the shutdown leader") + } + + time.Sleep(100 * time.Millisecond) + + var remain *Server + for _, s := range servers { + if s == leader { + continue + } + remain = s + testutil.WaitForResult(func() (bool, error) { + peers, _ := s.raftPeers.Peers() + return len(peers) == 2, errors.New(fmt.Sprintf("%v", peers)) + }, func(err error) { + t.Fatalf("should have 2 peers: %v", err) + }) + } + + // Verify the old leader is deregistered + state := remain.fsm.State() + testutil.WaitForResult(func() (bool, error) { + _, found, _ := state.GetNode(leader.config.NodeName) + return !found, nil + }, func(err error) { + t.Fatalf("leader should be deregistered") + }) + + // Find the new leader + leader = nil + for _, s := range servers { + // find the leader too + if s.IsLeader() { + leader = s + } + } + + if leader == nil { + t.Fatalf("Should have a new leader") + } + + // check that new leader.sessionTimers has the session id in it + if len(leader.sessionTimers) == 0 || leader.sessionTimers[id1] == nil { + t.Fatalf("sessionTimers not initialized and does not contain session timer for session") + } + + // create another TTL session with the same parameters + var id2 string + if err := client.Call("Session.Apply", &arg, &id2); err != nil { + t.Fatalf("err: %v", err) + } + + if len(leader.sessionTimers) != 2 { + t.Fatalf("sessionTimes length should be 2") + } + + // destroy the via invalidateSession as if on TTL expiry + leader.invalidateSession(id2) + + if len(leader.sessionTimers) != 1 { + t.Fatalf("sessionTimers length should 1") + } + + // destroy the id2 session (test clearSessionTimer) + arg.Op = structs.SessionDestroy + arg.Session.ID = id2 + if err := client.Call("Session.Apply", &arg, &id2); err != nil { + t.Fatalf("err: %v", err) + } + + if len(leader.sessionTimers) != 0 { + t.Fatalf("sessionTimers length should be 0") + } +} diff --git a/consul/state_store.go b/consul/state_store.go index 0173cff697d8..72538bb783e4 100644 --- a/consul/state_store.go +++ b/consul/state_store.go @@ -294,6 +294,10 @@ func (s *StateStore) initialize() error { AllowBlank: true, Fields: []string{"Node"}, }, + "ttl": &MDBIndex{ + AllowBlank: true, + Fields: []string{"TTL"}, + }, }, Decoder: func(buf []byte) interface{} { out := new(structs.Session) @@ -369,6 +373,7 @@ func (s *StateStore) initialize() error { "KVSListKeys": MDBTables{s.kvsTable}, "SessionGet": MDBTables{s.sessionTable}, "SessionList": MDBTables{s.sessionTable}, + "SessionListTTL": MDBTables{s.sessionTable}, "NodeSessions": MDBTables{s.sessionTable}, "ACLGet": MDBTables{s.aclTable}, "ACLList": MDBTables{s.aclTable}, @@ -1336,6 +1341,17 @@ func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error return fmt.Errorf("Invalid Session Behavior setting '%s'", session.Behavior) } + if session.TTL != "" { + ttl, err := time.ParseDuration(session.TTL) + if err != nil { + return fmt.Errorf("Invalid Session TTL '%s': %v", session.TTL, err) + } + + if ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax { + return fmt.Errorf("Invalid Session TTL '%s', must be between [%v-%v]", session.TTL, structs.SessionTTLMin, structs.SessionTTLMax) + } + } + // Assign the create index session.CreateIndex = index @@ -1445,6 +1461,16 @@ func (s *StateStore) SessionList() (uint64, []*structs.Session, error) { return idx, out, err } +// SessionListTTL is used to list all the open ttl sessions +func (s *StateStore) SessionListTTL() (uint64, []*structs.Session, error) { + idx, res, err := s.sessionTable.Get("ttl") + out := make([]*structs.Session, len(res)) + for i, raw := range res { + out[i] = raw.(*structs.Session) + } + return idx, out, err +} + // NodeSessions is used to list all the open sessions for a node func (s *StateStore) NodeSessions(node string) (uint64, []*structs.Session, error) { idx, res, err := s.sessionTable.Get("node", node) diff --git a/consul/state_store_test.go b/consul/state_store_test.go index 0dba19abda2e..c933dbdb0d94 100644 --- a/consul/state_store_test.go +++ b/consul/state_store_test.go @@ -703,13 +703,17 @@ func TestStoreSnapshot(t *testing.T) { if ok, err := store.KVSLock(18, d); err != nil || !ok { t.Fatalf("err: %v", err) } + session = &structs.Session{ID: generateUUID(), Node: "baz", TTL: "60s"} + if err := store.SessionCreate(19, session); err != nil { + t.Fatalf("err: %v", err) + } a1 := &structs.ACL{ ID: generateUUID(), Name: "User token", Type: structs.ACLTypeClient, } - if err := store.ACLSet(19, a1); err != nil { + if err := store.ACLSet(20, a1); err != nil { t.Fatalf("err: %v", err) } @@ -718,7 +722,7 @@ func TestStoreSnapshot(t *testing.T) { Name: "User token", Type: structs.ACLTypeClient, } - if err := store.ACLSet(20, a2); err != nil { + if err := store.ACLSet(21, a2); err != nil { t.Fatalf("err: %v", err) } @@ -730,7 +734,7 @@ func TestStoreSnapshot(t *testing.T) { defer snap.Close() // Check the last nodes - if idx := snap.LastIndex(); idx != 20 { + if idx := snap.LastIndex(); idx != 21 { t.Fatalf("bad: %v", idx) } @@ -785,15 +789,25 @@ func TestStoreSnapshot(t *testing.T) { t.Fatalf("missing KVS entries!") } - // Check there are 2 sessions + // Check there are 3 sessions sessions, err := snap.SessionList() if err != nil { t.Fatalf("err: %v", err) } - if len(sessions) != 2 { + if len(sessions) != 3 { t.Fatalf("missing sessions") } + ttls := 0 + for _, session := range sessions { + if session.TTL != "" { + ttls++ + } + } + if ttls != 1 { + t.Fatalf("Wrong number of sessions with TTL") + } + // Check for an acl acls, err := snap.ACLList() if err != nil { @@ -804,13 +818,13 @@ func TestStoreSnapshot(t *testing.T) { } // Make some changes! - if err := store.EnsureService(21, "foo", &structs.NodeService{"db", "db", []string{"slave"}, 8000}); err != nil { + if err := store.EnsureService(22, "foo", &structs.NodeService{"db", "db", []string{"slave"}, 8000}); err != nil { t.Fatalf("err: %v", err) } - if err := store.EnsureService(22, "bar", &structs.NodeService{"db", "db", []string{"master"}, 8000}); err != nil { + if err := store.EnsureService(23, "bar", &structs.NodeService{"db", "db", []string{"master"}, 8000}); err != nil { t.Fatalf("err: %v", err) } - if err := store.EnsureNode(23, structs.Node{"baz", "127.0.0.3"}); err != nil { + if err := store.EnsureNode(24, structs.Node{"baz", "127.0.0.3"}); err != nil { t.Fatalf("err: %v", err) } checkAfter := &structs.HealthCheck{ @@ -820,16 +834,16 @@ func TestStoreSnapshot(t *testing.T) { Status: structs.HealthCritical, ServiceID: "db", } - if err := store.EnsureCheck(24, checkAfter); err != nil { + if err := store.EnsureCheck(26, checkAfter); err != nil { t.Fatalf("err: %v", err) } - if err := store.KVSDelete(25, "/web/b"); err != nil { + if err := store.KVSDelete(26, "/web/b"); err != nil { t.Fatalf("err: %v", err) } // Nuke an ACL - if err := store.ACLDelete(26, a1.ID); err != nil { + if err := store.ACLDelete(27, a1.ID); err != nil { t.Fatalf("err: %v", err) } @@ -883,12 +897,12 @@ func TestStoreSnapshot(t *testing.T) { t.Fatalf("missing KVS entries!") } - // Check there are 2 sessions + // Check there are 3 sessions sessions, err = snap.SessionList() if err != nil { t.Fatalf("err: %v", err) } - if len(sessions) != 2 { + if len(sessions) != 3 { t.Fatalf("missing sessions") } diff --git a/consul/structs/structs.go b/consul/structs/structs.go index ced8567d2645..2072780f3696 100644 --- a/consul/structs/structs.go +++ b/consul/structs/structs.go @@ -385,6 +385,12 @@ const ( SessionKeysDelete = "delete" ) +const ( + SessionTTLMin = 10 * time.Second + SessionTTLMax = 3600 * time.Second + SessionTTLMultiplier = 2 +) + // Session is used to represent an open session in the KV store. // This issued to associate node checks with acquired locks. type Session struct { @@ -395,6 +401,7 @@ type Session struct { Checks []string LockDelay time.Duration Behavior SessionBehavior // What to do when session is invalidated + TTL string } type Sessions []*Session