Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consul Session TTLs #524

Merged
merged 9 commits into from
Dec 12, 2014
1 change: 1 addition & 0 deletions command/agent/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) {

s.mux.HandleFunc("/v1/session/create", s.wrap(s.SessionCreate))
s.mux.HandleFunc("/v1/session/destroy/", s.wrap(s.SessionDestroy))
s.mux.HandleFunc("/v1/session/renew/", s.wrap(s.SessionRenew))
s.mux.HandleFunc("/v1/session/info/", s.wrap(s.SessionGet))
s.mux.HandleFunc("/v1/session/node/", s.wrap(s.SessionsForNode))
s.mux.HandleFunc("/v1/session/list", s.wrap(s.SessionList))
Expand Down
49 changes: 49 additions & 0 deletions command/agent/session_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request)
Checks: []string{consul.SerfCheckID},
LockDelay: 15 * time.Second,
Behavior: structs.SessionKeysRelease,
TTL: "",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If TTL is set, we should also validate the value here. Avoid a trip to the server if the input is invalid.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added checks to parseDuration after we've decoded the body.

},
}
s.parseDC(req, &args.Datacenter)
Expand All @@ -51,6 +52,21 @@ func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request)
resp.Write([]byte(fmt.Sprintf("Request decode failed: %v", err)))
return nil, nil
}

if args.Session.TTL != "" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test with an invalid TTL to check this code path?

ttl, err := time.ParseDuration(args.Session.TTL)
if err != nil {
resp.WriteHeader(400)
resp.Write([]byte(fmt.Sprintf("Request TTL decode failed: %v", err)))
return nil, nil
}

if ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax {
resp.WriteHeader(400)
resp.Write([]byte(fmt.Sprintf("Request TTL '%s', must be between [%v-%v]", args.Session.TTL, structs.SessionTTLMin, structs.SessionTTLMax)))
return nil, nil
}
}
}

// Create the session, get the ID
Expand Down Expand Up @@ -130,6 +146,39 @@ func (s *HTTPServer) SessionDestroy(resp http.ResponseWriter, req *http.Request)
return true, nil
}

// SessionRenew is used to renew the TTL on an existing TTL session
func (s *HTTPServer) SessionRenew(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Mandate a PUT request
if req.Method != "PUT" {
resp.WriteHeader(405)
return nil, nil
}

args := structs.SessionSpecificRequest{}
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}

// Pull out the session id
args.Session = strings.TrimPrefix(req.URL.Path, "/v1/session/renew/")
if args.Session == "" {
resp.WriteHeader(400)
resp.Write([]byte("Missing session"))
return nil, nil
}

var out structs.IndexedSessions
if err := s.agent.RPC("Session.Renew", &args, &out); err != nil {
return nil, err
} else if out.Sessions == nil {
resp.WriteHeader(404)
resp.Write([]byte(fmt.Sprintf("Session id '%s' not found", args.Session)))
return nil, nil
}

return out.Sessions, nil
}

// SessionGet is used to get info for a particular session
func (s *HTTPServer) SessionGet(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
args := structs.SessionSpecificRequest{}
Expand Down
144 changes: 144 additions & 0 deletions command/agent/session_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,28 @@ func makeTestSessionDelete(t *testing.T, srv *HTTPServer) string {
return sessResp.ID
}

func makeTestSessionTTL(t *testing.T, srv *HTTPServer, ttl string) string {
// Create Session with TTL
body := bytes.NewBuffer(nil)
enc := json.NewEncoder(body)
raw := map[string]interface{}{
"TTL": ttl,
}
enc.Encode(raw)

req, err := http.NewRequest("PUT", "/v1/session/create", body)
if err != nil {
t.Fatalf("err: %v", err)
}
resp := httptest.NewRecorder()
obj, err := srv.SessionCreate(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
sessResp := obj.(sessionCreateResponse)
return sessResp.ID
}

func TestSessionDestroy(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
id := makeTestSession(t, srv)
Expand All @@ -192,6 +214,128 @@ func TestSessionDestroy(t *testing.T) {
})
}

func TestSessionTTL(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
TTL := "10s" // use the minimum legal ttl
ttl := 10 * time.Second

id := makeTestSessionTTL(t, srv, TTL)

req, err := http.NewRequest("GET",
"/v1/session/info/"+id, nil)
resp := httptest.NewRecorder()
obj, err := srv.SessionGet(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
respObj, ok := obj.(structs.Sessions)
if !ok {
t.Fatalf("should work")
}
if len(respObj) != 1 {
t.Fatalf("bad: %v", respObj)
}
if respObj[0].TTL != TTL {
t.Fatalf("Incorrect TTL: %s", respObj[0].TTL)
}

time.Sleep(ttl*structs.SessionTTLMultiplier + ttl)

req, err = http.NewRequest("GET",
"/v1/session/info/"+id, nil)
resp = httptest.NewRecorder()
obj, err = srv.SessionGet(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
respObj, ok = obj.(structs.Sessions)
if len(respObj) != 0 {
t.Fatalf("session '%s' should have been destroyed", id)
}
})
}

func TestSessionTTLRenew(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
TTL := "10s" // use the minimum legal ttl
ttl := 10 * time.Second

id := makeTestSessionTTL(t, srv, TTL)

req, err := http.NewRequest("GET",
"/v1/session/info/"+id, nil)
resp := httptest.NewRecorder()
obj, err := srv.SessionGet(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
respObj, ok := obj.(structs.Sessions)
if !ok {
t.Fatalf("should work")
}
if len(respObj) != 1 {
t.Fatalf("bad: %v", respObj)
}
if respObj[0].TTL != TTL {
t.Fatalf("Incorrect TTL: %s", respObj[0].TTL)
}

// Sleep to consume some time before renew
time.Sleep(ttl * (structs.SessionTTLMultiplier / 2))

req, err = http.NewRequest("PUT",
"/v1/session/renew/"+id, nil)
resp = httptest.NewRecorder()
obj, err = srv.SessionRenew(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
respObj, ok = obj.(structs.Sessions)
if !ok {
t.Fatalf("should work")
}
if len(respObj) != 1 {
t.Fatalf("bad: %v", respObj)
}

// Sleep for ttl * TTL Multiplier
time.Sleep(ttl * structs.SessionTTLMultiplier)

req, err = http.NewRequest("GET",
"/v1/session/info/"+id, nil)
resp = httptest.NewRecorder()
obj, err = srv.SessionGet(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
respObj, ok = obj.(structs.Sessions)
if !ok {
t.Fatalf("session '%s' should have renewed", id)
}
if len(respObj) != 1 {
t.Fatalf("session '%s' should have renewed", id)
}

// now wait for timeout and expect session to get destroyed
time.Sleep(ttl * structs.SessionTTLMultiplier)

req, err = http.NewRequest("GET",
"/v1/session/info/"+id, nil)
resp = httptest.NewRecorder()
obj, err = srv.SessionGet(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
respObj, ok = obj.(structs.Sessions)
if !ok {
t.Fatalf("session '%s' should have destroyed", id)
}
if len(respObj) != 0 {
t.Fatalf("session '%s' should have destroyed", id)
}
})
}

func TestSessionGet(t *testing.T) {
httpTest(t, func(srv *HTTPServer) {
id := makeTestSession(t, srv)
Expand Down
7 changes: 7 additions & 0 deletions consul/leader.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ func (s *Server) leaderLoop(stopCh chan struct{}) {
s.logger.Printf("[ERR] consul: ACL initialization failed: %v", err)
}

// Setup Session Timers if we are the leader and need to
if err := s.initializeSessionTimers(); err != nil {
s.logger.Printf("[ERR] consul: Session Timers initialization failed: %v", err)
}
// clear the session timers if we are no longer leader and exit the leaderLoop
defer s.clearAllSessionTimers()

// Reconcile channel is only used once initial reconcile
// has succeeded
var reconcileCh chan serf.Member
Expand Down
3 changes: 3 additions & 0 deletions consul/leader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,9 @@ func TestLeader_LeftLeader(t *testing.T) {
break
}
}
if leader == nil {
t.Fatalf("Should have a leader")
}
leader.Leave()
leader.Shutdown()
time.Sleep(100 * time.Millisecond)
Expand Down
6 changes: 6 additions & 0 deletions consul/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ type Server struct {
// which SHOULD only consist of Consul servers
serfWAN *serf.Serf

// sessionTimers track the expiration time of each Session that has
// a TTL. On expiration, a SessionDestroy event will occur, and
// destroy the session via standard session destory processing
sessionTimers map[string]*time.Timer
sessionTimersLock sync.RWMutex

shutdown bool
shutdownCh chan struct{}
shutdownLock sync.Mutex
Expand Down
38 changes: 38 additions & 0 deletions consul/session_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
default:
return fmt.Errorf("Invalid Behavior setting '%s'", args.Session.Behavior)
}
if args.Session.TTL != "" {
ttl, err := time.ParseDuration(args.Session.TTL)
if err != nil {
return fmt.Errorf("Session TTL '%s' invalid: %v", args.Session.TTL, err)
}

if ttl < structs.SessionTTLMin || ttl > structs.SessionTTLMax {
return fmt.Errorf("Invalid Session TTL '%d', must be between [%v=%v]", ttl, structs.SessionTTLMin, structs.SessionTTLMax)
}
}

// If this is a create, we must generate the Session ID. This must
// be done prior to appending to the raft log, because the ID is not
Expand Down Expand Up @@ -63,6 +73,13 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
s.srv.logger.Printf("[ERR] consul.session: Apply failed: %v", err)
return err
}

if args.Op == structs.SessionCreate && args.Session.TTL != "" {
s.srv.resetSessionTimer(args.Session.ID, nil)
} else if args.Op == structs.SessionDestroy && args.Session.TTL != "" {
s.srv.clearSessionTimer(args.Session.ID)
}

if respErr, ok := resp.(error); ok {
return respErr
}
Expand Down Expand Up @@ -133,3 +150,24 @@ func (s *Session) NodeSessions(args *structs.NodeSpecificRequest,
return err
})
}

// Renew is used to renew the TTL on a single session
func (s *Session) Renew(args *structs.SessionSpecificRequest,
reply *structs.IndexedSessions) error {
if done, err := s.srv.forward("Session.Renew", args, args, reply); done {
return err
}

// Get the local state
state := s.srv.fsm.State()
// Get the session, from local state
index, session, err := state.SessionGet(args.Session)
reply.Index = index
if session != nil {
reply.Sessions = structs.Sessions{session}
// reset the session TTL timer
err = s.srv.resetSessionTimer(args.Session, session)
}

return err
}
Loading