diff --git a/api/lock_test.go b/api/lock_test.go index ceab5cdf9c4a..f4bad9e6b87f 100644 --- a/api/lock_test.go +++ b/api/lock_test.go @@ -57,7 +57,7 @@ func TestLock_LockUnlock(t *testing.T) { t.Fatalf("err: %v", err) } - // Should loose leadership + // Should lose leadership select { case <-leaderCh: case <-time.After(time.Second): @@ -105,32 +105,40 @@ func TestLock_DeleteKey(t *testing.T) { c, s := makeClient(t) defer s.Stop() - lock, err := c.LockKey("test/lock") - if err != nil { - t.Fatalf("err: %v", err) - } - - // Should work - leaderCh, err := lock.Lock(nil) - if err != nil { - t.Fatalf("err: %v", err) - } - if leaderCh == nil { - t.Fatalf("not leader") - } - defer lock.Unlock() + // This uncovered some issues around special-case handling of low index + // numbers where it would work with a low number but fail for higher + // ones, so we loop this a bit to sweep the index up out of that + // territory. + for i := 0; i < 10; i++ { + func() { + lock, err := c.LockKey("test/lock") + if err != nil { + t.Fatalf("err: %v", err) + } - go func() { - // Nuke the key, simulate an operator intervention - kv := c.KV() - kv.Delete("test/lock", nil) - }() + // Should work + leaderCh, err := lock.Lock(nil) + if err != nil { + t.Fatalf("err: %v", err) + } + if leaderCh == nil { + t.Fatalf("not leader") + } + defer lock.Unlock() - // Should loose leadership - select { - case <-leaderCh: - case <-time.After(time.Second): - t.Fatalf("should not be leader") + go func() { + // Nuke the key, simulate an operator intervention + kv := c.KV() + kv.Delete("test/lock", nil) + }() + + // Should lose leadership + select { + case <-leaderCh: + case <-time.After(time.Second): + t.Fatalf("should not be leader") + } + }() } } diff --git a/command/agent/agent.go b/command/agent/agent.go index 4509dce894c1..e4756e42940f 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -15,6 +15,7 @@ import ( "time" "github.com/hashicorp/consul/consul" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/serf/serf" ) @@ -94,7 +95,7 @@ type Agent struct { eventBuf []*UserEvent eventIndex int eventLock sync.RWMutex - eventNotify consul.NotifyGroup + eventNotify state.NotifyGroup shutdown bool shutdownCh chan struct{} diff --git a/command/agent/dns.go b/command/agent/dns.go index b1857c34ac66..8778dedcb5ec 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -390,7 +390,7 @@ RPC: } // Add the node record - records := d.formatNodeRecord(&out.NodeServices.Node, out.NodeServices.Node.Address, + records := d.formatNodeRecord(out.NodeServices.Node, out.NodeServices.Node.Address, req.Question[0].Name, qType, d.config.NodeTTL) if records != nil { resp.Answer = append(resp.Answer, records...) @@ -585,7 +585,7 @@ func (d *DNSServer) serviceNodeRecords(nodes structs.CheckServiceNodes, req, res handled[addr] = struct{}{} // Add the node record - records := d.formatNodeRecord(&node.Node, addr, qName, qType, ttl) + records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl) if records != nil { resp.Answer = append(resp.Answer, records...) } @@ -626,7 +626,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes } // Add the extra record - records := d.formatNodeRecord(&node.Node, addr, srvRec.Target, dns.TypeANY, ttl) + records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl) if records != nil { resp.Extra = append(resp.Extra, records...)
} diff --git a/command/agent/local_test.go b/command/agent/local_test.go index d8f6cfbfee73..7f45267bbdd4 100644 --- a/command/agent/local_test.go +++ b/command/agent/local_test.go @@ -127,6 +127,7 @@ func TestAgentAntiEntropy_Services(t *testing.T) { // All the services should match for id, serv := range services.NodeServices.Services { + serv.CreateIndex, serv.ModifyIndex = 0, 0 switch id { case "mysql": if !reflect.DeepEqual(serv, srv1) { @@ -236,6 +237,7 @@ func TestAgentAntiEntropy_EnableTagOverride(t *testing.T) { // All the services should match for id, serv := range services.NodeServices.Services { + serv.CreateIndex, serv.ModifyIndex = 0, 0 switch id { case "svc_id1": if serv.ID != "svc_id1" || @@ -455,6 +457,7 @@ func TestAgentAntiEntropy_Services_ACLDeny(t *testing.T) { // All the services should match for id, serv := range services.NodeServices.Services { + serv.CreateIndex, serv.ModifyIndex = 0, 0 switch id { case "mysql": t.Fatalf("should not be permitted") @@ -581,6 +584,7 @@ func TestAgentAntiEntropy_Checks(t *testing.T) { // All the checks should match for _, chk := range checks.HealthChecks { + chk.CreateIndex, chk.ModifyIndex = 0, 0 switch chk.CheckID { case "mysql": if !reflect.DeepEqual(chk, chk1) { diff --git a/consul/acl_endpoint.go b/consul/acl_endpoint.go index f3c162b9895b..7191f057b1a4 100644 --- a/consul/acl_endpoint.go +++ b/consul/acl_endpoint.go @@ -123,16 +123,20 @@ func (a *ACL) Get(args *structs.ACLSpecificRequest, state := a.srv.fsm.State() return a.srv.blockingRPC(&args.QueryOptions, &reply.QueryMeta, - state.QueryTables("ACLGet"), + state.GetQueryWatch("ACLGet"), func() error { index, acl, err := state.ACLGet(args.ACL) + if err != nil { + return err + } + reply.Index = index if acl != nil { reply.ACLs = structs.ACLs{acl} } else { reply.ACLs = nil } - return err + return nil }) } @@ -194,10 +198,14 @@ func (a *ACL) List(args *structs.DCSpecificRequest, state := a.srv.fsm.State() return a.srv.blockingRPC(&args.QueryOptions, &reply.QueryMeta, - state.QueryTables("ACLList"), + state.GetQueryWatch("ACLList"), func() error { - var err error - reply.Index, reply.ACLs, err = state.ACLList() - return err + index, acls, err := state.ACLList() + if err != nil { + return err + } + + reply.Index, reply.ACLs = index, acls + return nil }) } diff --git a/consul/acl_test.go b/consul/acl_test.go index 8983493699bd..702be3f95d53 100644 --- a/consul/acl_test.go +++ b/consul/acl_test.go @@ -724,7 +724,7 @@ func TestACL_filterServices(t *testing.T) { func TestACL_filterServiceNodes(t *testing.T) { // Create some service nodes nodes := structs.ServiceNodes{ - structs.ServiceNode{ + &structs.ServiceNode{ Node: "node1", ServiceName: "foo", }, @@ -748,7 +748,7 @@ func TestACL_filterServiceNodes(t *testing.T) { func TestACL_filterNodeServices(t *testing.T) { // Create some node services services := structs.NodeServices{ - Node: structs.Node{ + Node: &structs.Node{ Node: "node1", }, Services: map[string]*structs.NodeService{ @@ -778,10 +778,10 @@ func TestACL_filterCheckServiceNodes(t *testing.T) { // Create some nodes nodes := structs.CheckServiceNodes{ structs.CheckServiceNode{ - Node: structs.Node{ + Node: &structs.Node{ Node: "node1", }, - Service: structs.NodeService{ + Service: &structs.NodeService{ ID: "foo", Service: "foo", }, diff --git a/consul/catalog_endpoint.go b/consul/catalog_endpoint.go index 17d3f5d5c396..28aa4c813848 100644 --- a/consul/catalog_endpoint.go +++ b/consul/catalog_endpoint.go @@ -119,13 +119,19 @@ func (c *Catalog) ListNodes(args 
*structs.DCSpecificRequest, reply *structs.Inde return err } - // Get the local state + // Get the list of nodes. state := c.srv.fsm.State() - return c.srv.blockingRPC(&args.QueryOptions, + return c.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("Nodes"), + state.GetQueryWatch("Nodes"), func() error { - reply.Index, reply.Nodes = state.Nodes() + index, nodes, err := state.Nodes() + if err != nil { + return err + } + + reply.Index, reply.Nodes = index, nodes return nil }) } @@ -136,13 +142,19 @@ func (c *Catalog) ListServices(args *structs.DCSpecificRequest, reply *structs.I return err } - // Get the current nodes + // Get the list of services and their tags. state := c.srv.fsm.State() - return c.srv.blockingRPC(&args.QueryOptions, + return c.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("Services"), + state.GetQueryWatch("Services"), func() error { - reply.Index, reply.Services = state.Services() + index, services, err := state.Services() + if err != nil { + return err + } + + reply.Index, reply.Services = index, services return c.srv.filterACL(args.Token, reply) }) } @@ -160,15 +172,23 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru // Get the nodes state := c.srv.fsm.State() - err := c.srv.blockingRPC(&args.QueryOptions, + err := c.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("ServiceNodes"), + state.GetQueryWatch("ServiceNodes"), func() error { + var index uint64 + var services structs.ServiceNodes + var err error if args.TagFilter { - reply.Index, reply.ServiceNodes = state.ServiceTagNodes(args.ServiceName, args.ServiceTag) + index, services, err = state.ServiceTagNodes(args.ServiceName, args.ServiceTag) } else { - reply.Index, reply.ServiceNodes = state.ServiceNodes(args.ServiceName) + index, services, err = state.ServiceNodes(args.ServiceName) + } + if err != nil { + return err } + reply.Index, reply.ServiceNodes = index, services return c.srv.filterACL(args.Token, reply) }) @@ -198,11 +218,16 @@ func (c *Catalog) NodeServices(args *structs.NodeSpecificRequest, reply *structs // Get the node services state := c.srv.fsm.State() - return c.srv.blockingRPC(&args.QueryOptions, + return c.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("NodeServices"), + state.GetQueryWatch("NodeServices"), func() error { - reply.Index, reply.NodeServices = state.NodeServices(args.Node) + index, services, err := state.NodeServices(args.Node) + if err != nil { + return err + } + reply.Index, reply.NodeServices = index, services return c.srv.filterACL(args.Token, reply) }) } diff --git a/consul/catalog_endpoint_test.go b/consul/catalog_endpoint_test.go index d85dfc4f464b..b1ec44c36a16 100644 --- a/consul/catalog_endpoint_test.go +++ b/consul/catalog_endpoint_test.go @@ -267,7 +267,9 @@ func TestCatalogListNodes(t *testing.T) { testutil.WaitForLeader(t, s1.RPC, "dc1") // Just add a node - s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + if err := s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } testutil.WaitForResult(func() (bool, error) { msgpackrpc.CallWithCodec(codec, "Catalog.ListNodes", &args, &out) @@ -317,12 +319,16 @@ func TestCatalogListNodes_StaleRaad(t *testing.T) { codec = codec1 // Inject fake data on the follower! 
- s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + if err := s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } } else { codec = codec2 // Inject fake data on the follower! - s2.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + if err := s2.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } } args := structs.DCSpecificRequest{ @@ -458,7 +464,9 @@ func BenchmarkCatalogListNodes(t *testing.B) { defer codec.Close() // Just add a node - s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + if err := s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } args := structs.DCSpecificRequest{ Datacenter: "dc1", @@ -490,8 +498,12 @@ func TestCatalogListServices(t *testing.T) { testutil.WaitForLeader(t, s1.RPC, "dc1") // Just add a node - s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) - s1.fsm.State().EnsureService(2, "foo", &structs.NodeService{"db", "db", []string{"primary"}, "127.0.0.1", 5000, false}) + if err := s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + if err := s1.fsm.State().EnsureService(2, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"primary"}, Address: "127.0.0.1", Port: 5000}); err != nil { + t.Fatalf("err: %v", err) + } if err := msgpackrpc.CallWithCodec(codec, "Catalog.ListServices", &args, &out); err != nil { t.Fatalf("err: %v", err) @@ -541,11 +553,16 @@ func TestCatalogListServices_Blocking(t *testing.T) { args.MaxQueryTime = time.Second // Async cause a change + idx := out.Index start := time.Now() go func() { time.Sleep(100 * time.Millisecond) - s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) - s1.fsm.State().EnsureService(2, "foo", &structs.NodeService{"db", "db", []string{"primary"}, "127.0.0.1", 5000, false}) + if err := s1.fsm.State().EnsureNode(idx+1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + if err := s1.fsm.State().EnsureService(idx+2, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"primary"}, Address: "127.0.0.1", Port: 5000}); err != nil { + t.Fatalf("err: %v", err) + } }() // Re-run the query @@ -560,7 +577,7 @@ func TestCatalogListServices_Blocking(t *testing.T) { } // Check the indexes - if out.Index != 2 { + if out.Index != idx+2 { t.Fatalf("bad: %v", out) } @@ -625,8 +642,12 @@ func TestCatalogListServices_Stale(t *testing.T) { var out structs.IndexedServices // Inject a fake service - s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) - s1.fsm.State().EnsureService(2, "foo", &structs.NodeService{"db", "db", []string{"primary"}, "127.0.0.1", 5000, false}) + if err := s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + if err := s1.fsm.State().EnsureService(2, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"primary"}, Address: "127.0.0.1", Port: 5000}); err != nil { + t.Fatalf("err: %v", err) + } // Run the query, do not wait for leader! 
if err := msgpackrpc.CallWithCodec(codec, "Catalog.ListServices", &args, &out); err != nil { @@ -666,8 +687,12 @@ func TestCatalogListServiceNodes(t *testing.T) { testutil.WaitForLeader(t, s1.RPC, "dc1") // Just add a node - s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) - s1.fsm.State().EnsureService(2, "foo", &structs.NodeService{"db", "db", []string{"primary"}, "127.0.0.1", 5000, false}) + if err := s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + if err := s1.fsm.State().EnsureService(2, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"primary"}, Address: "127.0.0.1", Port: 5000}); err != nil { + t.Fatalf("err: %v", err) + } if err := msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &args, &out); err != nil { t.Fatalf("err: %v", err) @@ -709,9 +734,15 @@ func TestCatalogNodeServices(t *testing.T) { testutil.WaitForLeader(t, s1.RPC, "dc1") // Just add a node - s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) - s1.fsm.State().EnsureService(2, "foo", &structs.NodeService{"db", "db", []string{"primary"}, "127.0.0.1", 5000, false}) - s1.fsm.State().EnsureService(3, "foo", &structs.NodeService{"web", "web", nil, "127.0.0.1", 80, false}) + if err := s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + if err := s1.fsm.State().EnsureService(2, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"primary"}, Address: "127.0.0.1", Port: 5000}); err != nil { + t.Fatalf("err: %v", err) + } + if err := s1.fsm.State().EnsureService(3, "foo", &structs.NodeService{ID: "web", Service: "web", Tags: nil, Address: "127.0.0.1", Port: 80}); err != nil { + t.Fatalf("err: %v", err) + } if err := msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &args, &out); err != nil { t.Fatalf("err: %v", err) diff --git a/consul/fsm.go b/consul/fsm.go index 58f07d45525f..89cfd5725f7d 100644 --- a/consul/fsm.go +++ b/consul/fsm.go @@ -4,11 +4,11 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "time" "github.com/armon/go-metrics" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/go-msgpack/codec" "github.com/hashicorp/raft" @@ -24,15 +24,15 @@ type consulFSM struct { logOutput io.Writer logger *log.Logger path string - state *StateStore - gc *TombstoneGC + state *state.StateStore + gc *state.TombstoneGC } // consulSnapshot is used to provide a snapshot of the current // state in a way that can be accessed concurrently with operations // that may modify the live state. 
type consulSnapshot struct { - state *StateSnapshot + state *state.StateSnapshot } // snapshotHeader is the first entry in our snapshot @@ -43,15 +43,8 @@ type snapshotHeader struct { } // NewFSM is used to construct a new FSM with a blank state -func NewFSM(gc *TombstoneGC, path string, logOutput io.Writer) (*consulFSM, error) { - // Create a temporary path for the state store - tmpPath, err := ioutil.TempDir(path, "state") - if err != nil { - return nil, err - } - - // Create a state store - state, err := NewStateStorePath(gc, tmpPath, logOutput) +func NewFSM(gc *state.TombstoneGC, logOutput io.Writer) (*consulFSM, error) { + stateNew, err := state.NewStateStore(gc) if err != nil { return nil, err } @@ -59,20 +52,14 @@ func NewFSM(gc *TombstoneGC, path string, logOutput io.Writer) (*consulFSM, erro fsm := &consulFSM{ logOutput: logOutput, logger: log.New(logOutput, "", log.LstdFlags), - path: path, - state: state, + state: stateNew, gc: gc, } return fsm, nil } -// Close is used to cleanup resources associated with the FSM -func (c *consulFSM) Close() error { - return c.state.Close() -} - // State is used to return a handle to the current state -func (c *consulFSM) State() *StateStore { +func (c *consulFSM) State() *state.StateStore { return c.state } @@ -91,7 +78,7 @@ func (c *consulFSM) Apply(log *raft.Log) interface{} { switch msgType { case structs.RegisterRequestType: - return c.decodeRegister(buf[1:], log.Index) + return c.applyRegister(buf[1:], log.Index) case structs.DeregisterRequestType: return c.applyDeregister(buf[1:], log.Index) case structs.KVSRequestType: @@ -112,18 +99,15 @@ func (c *consulFSM) Apply(log *raft.Log) interface{} { } } -func (c *consulFSM) decodeRegister(buf []byte, index uint64) interface{} { +func (c *consulFSM) applyRegister(buf []byte, index uint64) interface{} { + defer metrics.MeasureSince([]string{"consul", "fsm", "register"}, time.Now()) var req structs.RegisterRequest if err := structs.Decode(buf, &req); err != nil { panic(fmt.Errorf("failed to decode request: %v", err)) } - return c.applyRegister(&req, index) -} -func (c *consulFSM) applyRegister(req *structs.RegisterRequest, index uint64) interface{} { - defer metrics.MeasureSince([]string{"consul", "fsm", "register"}, time.Now()) // Apply all updates in a single transaction - if err := c.state.EnsureRegistration(index, req); err != nil { + if err := c.state.EnsureRegistration(index, &req); err != nil { c.logger.Printf("[INFO] consul.fsm: EnsureRegistration failed: %v", err) return err } @@ -139,12 +123,12 @@ func (c *consulFSM) applyDeregister(buf []byte, index uint64) interface{} { // Either remove the service entry or the whole node if req.ServiceID != "" { - if err := c.state.DeleteNodeService(index, req.Node, req.ServiceID); err != nil { + if err := c.state.DeleteService(index, req.Node, req.ServiceID); err != nil { c.logger.Printf("[INFO] consul.fsm: DeleteNodeService failed: %v", err) return err } } else if req.CheckID != "" { - if err := c.state.DeleteNodeCheck(index, req.Node, req.CheckID); err != nil { + if err := c.state.DeleteCheck(index, req.Node, req.CheckID); err != nil { c.logger.Printf("[INFO] consul.fsm: DeleteNodeCheck failed: %v", err) return err } @@ -169,7 +153,7 @@ func (c *consulFSM) applyKVSOperation(buf []byte, index uint64) interface{} { case structs.KVSDelete: return c.state.KVSDelete(index, req.DirEnt.Key) case structs.KVSDeleteCAS: - act, err := c.state.KVSDeleteCheckAndSet(index, req.DirEnt.Key, req.DirEnt.ModifyIndex) + act, err := c.state.KVSDeleteCAS(index,
req.DirEnt.ModifyIndex, req.DirEnt.Key) if err != nil { return err } else { @@ -178,7 +162,7 @@ func (c *consulFSM) applyKVSOperation(buf []byte, index uint64) interface{} { case structs.KVSDeleteTree: return c.state.KVSDeleteTree(index, req.DirEnt.Key) case structs.KVSCAS: - act, err := c.state.KVSCheckAndSet(index, &req.DirEnt) + act, err := c.state.KVSSetCAS(index, &req.DirEnt) if err != nil { return err } else { @@ -267,30 +251,22 @@ func (c *consulFSM) Snapshot() (raft.FSMSnapshot, error) { c.logger.Printf("[INFO] consul.fsm: snapshot created in %v", time.Now().Sub(start)) }(time.Now()) - // Create a new snapshot - snap, err := c.state.Snapshot() - if err != nil { - return nil, err - } - return &consulSnapshot{snap}, nil + return &consulSnapshot{c.state.Snapshot()}, nil } func (c *consulFSM) Restore(old io.ReadCloser) error { defer old.Close() - // Create a temporary path for the state store - tmpPath, err := ioutil.TempDir(c.path, "state") - if err != nil { - return err - } - // Create a new state store - state, err := NewStateStorePath(c.gc, tmpPath, c.logOutput) + stateNew, err := state.NewStateStore(c.gc) if err != nil { return err } - c.state.Close() - c.state = state + c.state = stateNew + + // Set up a new restore transaction + restore := c.state.Restore() + defer restore.Abort() // Create a decoder dec := codec.NewDecoder(old, msgpackHandle) @@ -319,41 +295,51 @@ func (c *consulFSM) Restore(old io.ReadCloser) error { if err := dec.Decode(&req); err != nil { return err } - c.applyRegister(&req, header.LastIndex) + if err := restore.Registration(header.LastIndex, &req); err != nil { + return err + } case structs.KVSRequestType: var req structs.DirEntry if err := dec.Decode(&req); err != nil { return err } - if err := c.state.KVSRestore(&req); err != nil { + if err := restore.KVS(&req); err != nil { return err } - case structs.SessionRequestType: - var req structs.Session + case structs.TombstoneRequestType: + var req structs.DirEntry if err := dec.Decode(&req); err != nil { return err } - if err := c.state.SessionRestore(&req); err != nil { + + // For historical reasons, these are serialized in the + // snapshots as KV entries. We want to keep the snapshot + // format compatible with pre-0.6 versions for now. 
+ stone := &state.Tombstone{ + Key: req.Key, + Index: req.ModifyIndex, + } + if err := restore.Tombstone(stone); err != nil { return err } - case structs.ACLRequestType: - var req structs.ACL + case structs.SessionRequestType: + var req structs.Session if err := dec.Decode(&req); err != nil { return err } - if err := c.state.ACLRestore(&req); err != nil { + if err := restore.Session(&req); err != nil { return err } - case structs.TombstoneRequestType: - var req structs.DirEntry + case structs.ACLRequestType: + var req structs.ACL if err := dec.Decode(&req); err != nil { return err } - if err := c.state.TombstoneRestore(&req); err != nil { + if err := restore.ACL(&req); err != nil { return err } @@ -362,11 +348,13 @@ func (c *consulFSM) Restore(old io.ReadCloser) error { } } + restore.Commit() return nil } func (s *consulSnapshot) Persist(sink raft.SnapshotSink) error { defer metrics.MeasureSince([]string{"consul", "fsm", "persist"}, time.Now()) + // Register the nodes encoder := codec.NewEncoder(sink, msgpackHandle) @@ -394,7 +382,7 @@ func (s *consulSnapshot) Persist(sink raft.SnapshotSink) error { return err } - if err := s.persistKV(sink, encoder); err != nil { + if err := s.persistKVs(sink, encoder); err != nil { sink.Cancel() return err } @@ -408,15 +396,19 @@ func (s *consulSnapshot) Persist(sink raft.SnapshotSink) error { func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, encoder *codec.Encoder) error { + // Get all the nodes - nodes := s.state.Nodes() + nodes, err := s.state.Nodes() + if err != nil { + return err + } // Register each node - var req structs.RegisterRequest - for i := 0; i < len(nodes); i++ { - req = structs.RegisterRequest{ - Node: nodes[i].Node, - Address: nodes[i].Address, + for node := nodes.Next(); node != nil; node = nodes.Next() { + n := node.(*structs.Node) + req := structs.RegisterRequest{ + Node: n.Node, + Address: n.Address, } // Register the node itself @@ -426,10 +418,13 @@ func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, } // Register each service this node has - services := s.state.NodeServices(nodes[i].Node) - for _, srv := range services.Services { - req.Service = srv + services, err := s.state.Services(n.Node) + if err != nil { + return err + } + for service := services.Next(); service != nil; service = services.Next() { sink.Write([]byte{byte(structs.RegisterRequestType)}) + req.Service = service.(*structs.ServiceNode).ToNodeService() if err := encoder.Encode(&req); err != nil { return err } @@ -437,10 +432,13 @@ func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, // Register each check this node has req.Service = nil - checks := s.state.NodeChecks(nodes[i].Node) - for _, check := range checks { - req.Check = check + checks, err := s.state.Checks(n.Node) + if err != nil { + return err + } + for check := checks.Next(); check != nil; check = checks.Next() { sink.Write([]byte{byte(structs.RegisterRequestType)}) + req.Check = check.(*structs.HealthCheck) if err := encoder.Encode(&req); err != nil { return err } @@ -451,14 +449,14 @@ func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, func (s *consulSnapshot) persistSessions(sink raft.SnapshotSink, encoder *codec.Encoder) error { - sessions, err := s.state.SessionList() + sessions, err := s.state.Sessions() if err != nil { return err } - for _, s := range sessions { + for session := sessions.Next(); session != nil; session = sessions.Next() { sink.Write([]byte{byte(structs.SessionRequestType)}) - if err := encoder.Encode(s); err != nil { + if err := 
encoder.Encode(session.(*structs.Session)); err != nil { return err } } @@ -467,72 +465,61 @@ func (s *consulSnapshot) persistSessions(sink raft.SnapshotSink, func (s *consulSnapshot) persistACLs(sink raft.SnapshotSink, encoder *codec.Encoder) error { - acls, err := s.state.ACLList() + acls, err := s.state.ACLs() if err != nil { return err } - for _, s := range acls { + for acl := acls.Next(); acl != nil; acl = acls.Next() { sink.Write([]byte{byte(structs.ACLRequestType)}) - if err := encoder.Encode(s); err != nil { + if err := encoder.Encode(acl.(*structs.ACL)); err != nil { return err } } return nil } -func (s *consulSnapshot) persistKV(sink raft.SnapshotSink, +func (s *consulSnapshot) persistKVs(sink raft.SnapshotSink, encoder *codec.Encoder) error { - streamCh := make(chan interface{}, 256) - errorCh := make(chan error) - go func() { - if err := s.state.KVSDump(streamCh); err != nil { - errorCh <- err - } - }() - - for { - select { - case raw := <-streamCh: - if raw == nil { - return nil - } - sink.Write([]byte{byte(structs.KVSRequestType)}) - if err := encoder.Encode(raw); err != nil { - return err - } + entries, err := s.state.KVs() + if err != nil { + return err + } - case err := <-errorCh: + for entry := entries.Next(); entry != nil; entry = entries.Next() { + sink.Write([]byte{byte(structs.KVSRequestType)}) + if err := encoder.Encode(entry.(*structs.DirEntry)); err != nil { return err } } + return nil } func (s *consulSnapshot) persistTombstones(sink raft.SnapshotSink, encoder *codec.Encoder) error { - streamCh := make(chan interface{}, 256) - errorCh := make(chan error) - go func() { - if err := s.state.TombstoneDump(streamCh); err != nil { - errorCh <- err - } - }() + stones, err := s.state.Tombstones() + if err != nil { + return err + } - for { - select { - case raw := <-streamCh: - if raw == nil { - return nil - } - sink.Write([]byte{byte(structs.TombstoneRequestType)}) - if err := encoder.Encode(raw); err != nil { - return err - } + for stone := stones.Next(); stone != nil; stone = stones.Next() { + sink.Write([]byte{byte(structs.TombstoneRequestType)}) - case err := <-errorCh: + // For historical reasons, these are serialized in the snapshots + // as KV entries. We want to keep the snapshot format compatible + // with pre-0.6 versions for now. 
+ s := stone.(*state.Tombstone) + fake := &structs.DirEntry{ + Key: s.Key, + RaftIndex: structs.RaftIndex{ + ModifyIndex: s.Index, + }, + } + if err := encoder.Encode(fake); err != nil { return err } } + return nil } func (s *consulSnapshot) Release() { diff --git a/consul/fsm_test.go b/consul/fsm_test.go index 28594de41843..92d66a98931d 100644 --- a/consul/fsm_test.go +++ b/consul/fsm_test.go @@ -2,10 +2,10 @@ package consul import ( "bytes" - "io/ioutil" "os" "testing" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/raft" ) @@ -38,16 +38,10 @@ func makeLog(buf []byte) *raft.Log { } func TestFSM_RegisterNode(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - defer fsm.Close() req := structs.RegisterRequest{ Datacenter: "dc1", @@ -65,30 +59,32 @@ func TestFSM_RegisterNode(t *testing.T) { } // Verify we are registered - if idx, found, _ := fsm.state.GetNode("foo"); !found { + _, node, err := fsm.state.GetNode("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if node == nil { t.Fatalf("not found!") - } else if idx != 1 { - t.Fatalf("bad index: %d", idx) + } + if node.ModifyIndex != 1 { + t.Fatalf("bad index: %d", node.ModifyIndex) } // Verify service registered - _, services := fsm.state.NodeServices("foo") + _, services, err := fsm.state.NodeServices("foo") + if err != nil { + t.Fatalf("err: %s", err) + } if len(services.Services) != 0 { t.Fatalf("Services: %v", services) } } func TestFSM_RegisterNode_Service(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - defer fsm.Close() req := structs.RegisterRequest{ Datacenter: "dc1", @@ -119,34 +115,38 @@ func TestFSM_RegisterNode_Service(t *testing.T) { } // Verify we are registered - if _, found, _ := fsm.state.GetNode("foo"); !found { + _, node, err := fsm.state.GetNode("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if node == nil { t.Fatalf("not found!") } // Verify service registered - _, services := fsm.state.NodeServices("foo") + _, services, err := fsm.state.NodeServices("foo") + if err != nil { + t.Fatalf("err: %s", err) + } if _, ok := services.Services["db"]; !ok { t.Fatalf("not registered!") } // Verify check - _, checks := fsm.state.NodeChecks("foo") + _, checks, err := fsm.state.NodeChecks("foo") + if err != nil { + t.Fatalf("err: %s", err) + } if checks[0].CheckID != "db" { t.Fatalf("not registered!") } } func TestFSM_DeregisterService(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer fsm.Close() req := structs.RegisterRequest{ Datacenter: "dc1", @@ -185,28 +185,29 @@ func TestFSM_DeregisterService(t *testing.T) { } // Verify we are registered - if _, found, _ := fsm.state.GetNode("foo"); !found { + _, node, err := fsm.state.GetNode("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if node == nil { t.Fatalf("not found!") } // Verify service not registered - _, services := fsm.state.NodeServices("foo") + _, services, err := 
fsm.state.NodeServices("foo") + if err != nil { + t.Fatalf("err: %s", err) + } if _, ok := services.Services["db"]; ok { t.Fatalf("db registered!") } } func TestFSM_DeregisterCheck(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - defer fsm.Close() req := structs.RegisterRequest{ Datacenter: "dc1", @@ -245,28 +246,29 @@ func TestFSM_DeregisterCheck(t *testing.T) { } // Verify we are registered - if _, found, _ := fsm.state.GetNode("foo"); !found { + _, node, err := fsm.state.GetNode("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if node == nil { t.Fatalf("not found!") } // Verify check not registered - _, checks := fsm.state.NodeChecks("foo") + _, checks, err := fsm.state.NodeChecks("foo") + if err != nil { + t.Fatalf("err: %s", err) + } if len(checks) != 0 { t.Fatalf("check registered!") } } func TestFSM_DeregisterNode(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - defer fsm.Close() req := structs.RegisterRequest{ Datacenter: "dc1", @@ -310,43 +312,47 @@ func TestFSM_DeregisterNode(t *testing.T) { t.Fatalf("resp: %v", resp) } - // Verify we are registered - if _, found, _ := fsm.state.GetNode("foo"); found { + // Verify we are not registered + _, node, err := fsm.state.GetNode("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if node != nil { t.Fatalf("found!") } // Verify service not registered - _, services := fsm.state.NodeServices("foo") + _, services, err := fsm.state.NodeServices("foo") + if err != nil { + t.Fatalf("err: %s", err) + } if services != nil { t.Fatalf("Services: %v", services) } // Verify checks not registered - _, checks := fsm.state.NodeChecks("foo") + _, checks, err := fsm.state.NodeChecks("foo") + if err != nil { + t.Fatalf("err: %s", err) + } if len(checks) != 0 { t.Fatalf("Services: %v", services) } } func TestFSM_SnapshotRestore(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - defer fsm.Close() // Add some state - fsm.state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) - fsm.state.EnsureNode(2, structs.Node{"baz", "127.0.0.2"}) - fsm.state.EnsureService(3, "foo", &structs.NodeService{"web", "web", nil, "127.0.0.1", 80, false}) - fsm.state.EnsureService(4, "foo", &structs.NodeService{"db", "db", []string{"primary"}, "127.0.0.1", 5000, false}) - fsm.state.EnsureService(5, "baz", &structs.NodeService{"web", "web", nil, "127.0.0.2", 80, false}) - fsm.state.EnsureService(6, "baz", &structs.NodeService{"db", "db", []string{"secondary"}, "127.0.0.2", 5000, false}) + fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) + fsm.state.EnsureNode(2, &structs.Node{Node: "baz", Address: "127.0.0.2"}) + fsm.state.EnsureService(3, "foo", &structs.NodeService{ID: "web", Service: "web", Tags: nil, Address: "127.0.0.1", Port: 80}) + fsm.state.EnsureService(4, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"primary"}, Address: "127.0.0.1", Port: 5000}) + fsm.state.EnsureService(5, "baz", 
&structs.NodeService{ID: "web", Service: "web", Tags: nil, Address: "127.0.0.2", Port: 80}) + fsm.state.EnsureService(6, "baz", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"secondary"}, Address: "127.0.0.2", Port: 5000}) fsm.state.EnsureCheck(7, &structs.HealthCheck{ Node: "foo", CheckID: "web", @@ -368,6 +374,13 @@ func TestFSM_SnapshotRestore(t *testing.T) { Value: []byte("foo"), }) fsm.state.KVSDelete(12, "/remove") + idx, _, err := fsm.state.KVSList("/remove") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 12 { + t.Fatalf("bad index: %d", idx) + } // Snapshot snap, err := fsm.Snapshot() @@ -384,11 +397,10 @@ func TestFSM_SnapshotRestore(t *testing.T) { } // Try to restore on a new FSM - fsm2, err := NewFSM(nil, path, os.Stderr) + fsm2, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer fsm2.Close() // Do a restore if err := fsm2.Restore(sink); err != nil { @@ -396,12 +408,18 @@ func TestFSM_SnapshotRestore(t *testing.T) { } // Verify the contents - _, nodes := fsm2.state.Nodes() + _, nodes, err := fsm2.state.Nodes() + if err != nil { + t.Fatalf("err: %s", err) + } if len(nodes) != 2 { t.Fatalf("Bad: %v", nodes) } - _, fooSrv := fsm2.state.NodeServices("foo") + _, fooSrv, err := fsm2.state.NodeServices("foo") + if err != nil { + t.Fatalf("err: %s", err) + } if len(fooSrv.Services) != 2 { t.Fatalf("Bad: %v", fooSrv) } @@ -412,7 +430,10 @@ func TestFSM_SnapshotRestore(t *testing.T) { t.Fatalf("Bad: %v", fooSrv) } - _, checks := fsm2.state.NodeChecks("foo") + _, checks, err := fsm2.state.NodeChecks("foo") + if err != nil { + t.Fatalf("err: %s", err) + } if len(checks) != 1 { t.Fatalf("Bad: %v", checks) } @@ -426,15 +447,6 @@ func TestFSM_SnapshotRestore(t *testing.T) { t.Fatalf("bad: %v", d) } - // Verify the index is restored - idx, _, err := fsm2.state.KVSListKeys("/blah", "") - if err != nil { - t.Fatalf("err: %v", err) - } - if idx <= 1 { - t.Fatalf("bad index: %d", idx) - } - // Verify session is restored idx, s, err := fsm2.state.SessionGet(session.ID) if err != nil { @@ -448,38 +460,43 @@ func TestFSM_SnapshotRestore(t *testing.T) { } // Verify ACL is restored - idx, a, err := fsm2.state.ACLGet(acl.ID) + _, a, err := fsm2.state.ACLGet(acl.ID) if err != nil { t.Fatalf("err: %v", err) } if a.Name != "User Token" { t.Fatalf("bad: %v", a) } - if idx <= 1 { + if a.ModifyIndex <= 1 { t.Fatalf("bad index: %d", idx) } // Verify tombstones are restored - _, res, err := fsm2.state.tombstoneTable.Get("id", "/remove") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 1 { - t.Fatalf("bad: %v", res) - } + func() { + snap := fsm2.state.Snapshot() + defer snap.Close() + stones, err := snap.Tombstones() + if err != nil { + t.Fatalf("err: %s", err) + } + stone := stones.Next().(*state.Tombstone) + if stone == nil { + t.Fatalf("missing tombstone") + } + if stone.Key != "/remove" || stone.Index != 12 { + t.Fatalf("bad: %v", stone) + } + if stones.Next() != nil { + t.Fatalf("unexpected extra tombstones") + } + }() } func TestFSM_KVSSet(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer fsm.Close() req := structs.KVSRequest{ Datacenter: "dc1", @@ -510,16 +527,10 @@ func TestFSM_KVSSet(t *testing.T) { } func TestFSM_KVSDelete(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") - if err != nil { - t.Fatalf("err: %v", 
err) - } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer fsm.Close() req := structs.KVSRequest{ Datacenter: "dc1", @@ -561,16 +572,10 @@ func TestFSM_KVSDelete(t *testing.T) { } func TestFSM_KVSDeleteTree(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - defer fsm.Close() req := structs.KVSRequest{ Datacenter: "dc1", @@ -613,16 +618,10 @@ func TestFSM_KVSDeleteTree(t *testing.T) { } func TestFSM_KVSDeleteCheckAndSet(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer fsm.Close() req := structs.KVSRequest{ Datacenter: "dc1", @@ -674,16 +673,10 @@ func TestFSM_KVSDeleteCheckAndSet(t *testing.T) { } func TestFSM_KVSCheckAndSet(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - defer fsm.Close() req := structs.KVSRequest{ Datacenter: "dc1", @@ -736,18 +729,12 @@ func TestFSM_KVSCheckAndSet(t *testing.T) { } func TestFSM_SessionCreate_Destroy(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - defer fsm.Close() - fsm.state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) fsm.state.EnsureCheck(2, &structs.HealthCheck{ Node: "foo", CheckID: "web", @@ -821,18 +808,12 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) { } func TestFSM_KVSLock(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - defer fsm.Close() - fsm.state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) session := &structs.Session{ID: generateUUID(), Node: "foo"} fsm.state.SessionCreate(2, session) @@ -871,18 +852,12 @@ func TestFSM_KVSLock(t *testing.T) { } func TestFSM_KVSUnlock(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer fsm.Close() - fsm.state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + fsm.state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) session := &structs.Session{ID: generateUUID(), Node: "foo"} fsm.state.SessionCreate(2, session) @@ -939,16 +914,10 @@ func TestFSM_KVSUnlock(t *testing.T) { } func TestFSM_ACL_Set_Delete(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) + fsm, err := 
NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer fsm.Close() // Create a new ACL req := structs.ACLRequest{ @@ -1017,23 +986,24 @@ func TestFSM_ACL_Set_Delete(t *testing.T) { } func TestFSM_TombstoneReap(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - defer fsm.Close() - // Create some tombstones + // Create some tombstones fsm.state.KVSSet(11, &structs.DirEntry{ Key: "/remove", Value: []byte("foo"), }) fsm.state.KVSDelete(12, "/remove") + idx, _, err := fsm.state.KVSList("/remove") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 12 { + t.Fatalf("bad index: %d", idx) + } // Create a new reap request req := structs.TombstoneRequest{ @@ -1051,26 +1021,22 @@ func TestFSM_TombstoneReap(t *testing.T) { } // Verify the tombstones are gone - _, res, err := fsm.state.tombstoneTable.Get("id") + snap := fsm.state.Snapshot() + defer snap.Close() + stones, err := snap.Tombstones() if err != nil { - t.Fatalf("err: %v", err) + t.Fatalf("err: %s", err) } - if len(res) != 0 { - t.Fatalf("bad: %v", res) + if stones.Next() != nil { + t.Fatalf("unexpected extra tombstones") } } func TestFSM_IgnoreUnknown(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(path) - fsm, err := NewFSM(nil, path, os.Stderr) + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - defer fsm.Close() // Create a new reap request type UnknownRequest struct { diff --git a/consul/health_endpoint.go b/consul/health_endpoint.go index 6accd3aa6434..4bb6404c1178 100644 --- a/consul/health_endpoint.go +++ b/consul/health_endpoint.go @@ -20,11 +20,16 @@ func (h *Health) ChecksInState(args *structs.ChecksInStateRequest, // Get the state specific checks state := h.srv.fsm.State() - return h.srv.blockingRPC(&args.QueryOptions, + return h.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("ChecksInState"), + state.GetQueryWatch("ChecksInState"), func() error { - reply.Index, reply.HealthChecks = state.ChecksInState(args.State) + index, checks, err := state.ChecksInState(args.State) + if err != nil { + return err + } + reply.Index, reply.HealthChecks = index, checks return h.srv.filterACL(args.Token, reply) }) } @@ -38,11 +43,16 @@ func (h *Health) NodeChecks(args *structs.NodeSpecificRequest, // Get the node checks state := h.srv.fsm.State() - return h.srv.blockingRPC(&args.QueryOptions, + return h.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("NodeChecks"), + state.GetQueryWatch("NodeChecks"), func() error { - reply.Index, reply.HealthChecks = state.NodeChecks(args.Node) + index, checks, err := state.NodeChecks(args.Node) + if err != nil { + return err + } + reply.Index, reply.HealthChecks = index, checks return h.srv.filterACL(args.Token, reply) }) } @@ -62,11 +72,16 @@ func (h *Health) ServiceChecks(args *structs.ServiceSpecificRequest, // Get the service checks state := h.srv.fsm.State() - return h.srv.blockingRPC(&args.QueryOptions, + return h.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("ServiceChecks"), + state.GetQueryWatch("ServiceChecks"), func() error { - reply.Index, reply.HealthChecks = state.ServiceChecks(args.ServiceName) + index, checks, err := state.ServiceChecks(args.ServiceName) + if err != nil { 
+ return err + } + reply.Index, reply.HealthChecks = index, checks return h.srv.filterACL(args.Token, reply) }) } @@ -84,15 +99,23 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc // Get the nodes state := h.srv.fsm.State() - err := h.srv.blockingRPC(&args.QueryOptions, + err := h.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("CheckServiceNodes"), + state.GetQueryWatch("CheckServiceNodes"), func() error { + var index uint64 + var nodes structs.CheckServiceNodes + var err error if args.TagFilter { - reply.Index, reply.Nodes = state.CheckServiceTagNodes(args.ServiceName, args.ServiceTag) + index, nodes, err = state.CheckServiceTagNodes(args.ServiceName, args.ServiceTag) } else { - reply.Index, reply.Nodes = state.CheckServiceNodes(args.ServiceName) + index, nodes, err = state.CheckServiceNodes(args.ServiceName) + } + if err != nil { + return err } + reply.Index, reply.Nodes = index, nodes return h.srv.filterACL(args.Token, reply) }) diff --git a/consul/health_endpoint_test.go b/consul/health_endpoint_test.go index 110b1d8b026a..77356983e2e4 100644 --- a/consul/health_endpoint_test.go +++ b/consul/health_endpoint_test.go @@ -46,11 +46,11 @@ func TestHealth_ChecksInState(t *testing.T) { t.Fatalf("Bad: %v", checks) } - // First check is automatically added for the server node - if checks[0].CheckID != SerfCheckID { + // Serf check is automatically added + if checks[0].Name != "memory utilization" { t.Fatalf("Bad: %v", checks[0]) } - if checks[1].Name != "memory utilization" { + if checks[1].CheckID != SerfCheckID { t.Fatalf("Bad: %v", checks[1]) } } @@ -205,22 +205,22 @@ func TestHealth_ServiceNodes(t *testing.T) { if len(nodes) != 2 { t.Fatalf("Bad: %v", nodes) } - if nodes[0].Node.Node != "foo" { + if nodes[0].Node.Node != "bar" { t.Fatalf("Bad: %v", nodes[0]) } - if nodes[1].Node.Node != "bar" { + if nodes[1].Node.Node != "foo" { t.Fatalf("Bad: %v", nodes[1]) } - if !strContains(nodes[0].Service.Tags, "master") { + if !strContains(nodes[0].Service.Tags, "slave") { t.Fatalf("Bad: %v", nodes[0]) } - if !strContains(nodes[1].Service.Tags, "slave") { + if !strContains(nodes[1].Service.Tags, "master") { t.Fatalf("Bad: %v", nodes[1]) } - if nodes[0].Checks[0].Status != structs.HealthPassing { + if nodes[0].Checks[0].Status != structs.HealthWarning { t.Fatalf("Bad: %v", nodes[0]) } - if nodes[1].Checks[0].Status != structs.HealthWarning { + if nodes[1].Checks[0].Status != structs.HealthPassing { t.Fatalf("Bad: %v", nodes[1]) } } diff --git a/consul/internal_endpoint.go b/consul/internal_endpoint.go index 939887a8c489..a30086f94c73 100644 --- a/consul/internal_endpoint.go +++ b/consul/internal_endpoint.go @@ -23,11 +23,17 @@ func (m *Internal) NodeInfo(args *structs.NodeSpecificRequest, // Get the node info state := m.srv.fsm.State() - return m.srv.blockingRPC(&args.QueryOptions, + return m.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("NodeInfo"), + state.GetQueryWatch("NodeInfo"), func() error { - reply.Index, reply.Dump = state.NodeInfo(args.Node) + index, dump, err := state.NodeInfo(args.Node) + if err != nil { + return err + } + + reply.Index, reply.Dump = index, dump return m.srv.filterACL(args.Token, reply) }) } @@ -41,11 +47,17 @@ func (m *Internal) NodeDump(args *structs.DCSpecificRequest, // Get all the node info state := m.srv.fsm.State() - return m.srv.blockingRPC(&args.QueryOptions, + return m.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("NodeDump"), + 
state.GetQueryWatch("NodeDump"), func() error { - reply.Index, reply.Dump = state.NodeDump() + index, dump, err := state.NodeDump() + if err != nil { + return err + } + + reply.Index, reply.Dump = index, dump return m.srv.filterACL(args.Token, reply) }) } diff --git a/consul/issue_test.go b/consul/issue_test.go index 5676c6a1d563..45f9e91c6a59 100644 --- a/consul/issue_test.go +++ b/consul/issue_test.go @@ -1,7 +1,6 @@ package consul import ( - "io/ioutil" "os" "reflect" "testing" @@ -11,15 +10,10 @@ import ( // Testing for GH-300 and GH-279 func TestHealthCheckRace(t *testing.T) { - path, err := ioutil.TempDir("", "fsm") + fsm, err := NewFSM(nil, os.Stderr) if err != nil { t.Fatalf("err: %v", err) } - fsm, err := NewFSM(nil, path, os.Stderr) - if err != nil { - t.Fatalf("err: %v", err) - } - defer fsm.Close() state := fsm.State() req := structs.RegisterRequest{ @@ -51,9 +45,12 @@ func TestHealthCheckRace(t *testing.T) { } // Verify the index - idx, out1 := state.CheckServiceNodes("db") + idx, out1, err := state.CheckServiceNodes("db") + if err != nil { + t.Fatalf("err: %s", err) + } if idx != 10 { - t.Fatalf("Bad index") + t.Fatalf("Bad index: %d", idx) } // Update the check state @@ -71,9 +68,12 @@ func TestHealthCheckRace(t *testing.T) { } // Verify the index changed - idx, out2 := state.CheckServiceNodes("db") + idx, out2, err := state.CheckServiceNodes("db") + if err != nil { + t.Fatalf("err: %s", err) + } if idx != 20 { - t.Fatalf("Bad index") + t.Fatalf("Bad index: %d", idx) } if reflect.DeepEqual(out1, out2) { diff --git a/consul/kvs_endpoint.go b/consul/kvs_endpoint.go index 468ee5f08976..570b7d83b24e 100644 --- a/consul/kvs_endpoint.go +++ b/consul/kvs_endpoint.go @@ -90,12 +90,11 @@ func (k *KVS) Get(args *structs.KeyRequest, reply *structs.IndexedDirEntries) er // Get the local state state := k.srv.fsm.State() - opts := blockingRPCOptions{ - queryOpts: &args.QueryOptions, - queryMeta: &reply.QueryMeta, - kvWatch: true, - kvPrefix: args.Key, - run: func() error { + return k.srv.blockingRPC( + &args.QueryOptions, + &reply.QueryMeta, + state.GetKVSWatch(args.Key), + func() error { index, ent, err := state.KVSGet(args.Key) if err != nil { return err @@ -117,9 +116,7 @@ func (k *KVS) Get(args *structs.KeyRequest, reply *structs.IndexedDirEntries) er reply.Entries = structs.DirEntries{ent} } return nil - }, - } - return k.srv.blockingRPCOpt(&opts) + }) } // List is used to list all keys with a given prefix @@ -135,13 +132,12 @@ func (k *KVS) List(args *structs.KeyRequest, reply *structs.IndexedDirEntries) e // Get the local state state := k.srv.fsm.State() - opts := blockingRPCOptions{ - queryOpts: &args.QueryOptions, - queryMeta: &reply.QueryMeta, - kvWatch: true, - kvPrefix: args.Key, - run: func() error { - tombIndex, index, ent, err := state.KVSList(args.Key) + return k.srv.blockingRPC( + &args.QueryOptions, + &reply.QueryMeta, + state.GetKVSWatch(args.Key), + func() error { + index, ent, err := state.KVSList(args.Key) if err != nil { return err } @@ -158,25 +154,12 @@ func (k *KVS) List(args *structs.KeyRequest, reply *structs.IndexedDirEntries) e reply.Index = index } reply.Entries = nil - } else { - // Determine the maximum affected index - var maxIndex uint64 - for _, e := range ent { - if e.ModifyIndex > maxIndex { - maxIndex = e.ModifyIndex - } - } - if tombIndex > maxIndex { - maxIndex = tombIndex - } - reply.Index = maxIndex + reply.Index = index reply.Entries = ent } return nil - }, - } - return k.srv.blockingRPCOpt(&opts) + }) } // ListKeys is used to list all keys with 
a given prefix to a separator @@ -192,21 +175,28 @@ func (k *KVS) ListKeys(args *structs.KeyListRequest, reply *structs.IndexedKeyLi // Get the local state state := k.srv.fsm.State() - opts := blockingRPCOptions{ - queryOpts: &args.QueryOptions, - queryMeta: &reply.QueryMeta, - kvWatch: true, - kvPrefix: args.Prefix, - run: func() error { + return k.srv.blockingRPC( + &args.QueryOptions, + &reply.QueryMeta, + state.GetKVSWatch(args.Prefix), + func() error { index, keys, err := state.KVSListKeys(args.Prefix, args.Seperator) - reply.Index = index + if err != nil { + return err + } + + // Must provide non-zero index to prevent blocking + // Index 1 is impossible anyways (due to Raft internals) + if index == 0 { + reply.Index = 1 + } else { + reply.Index = index + } + if acl != nil { keys = FilterKeys(acl, keys) } reply.Keys = keys - return err - - }, - } - return k.srv.blockingRPCOpt(&opts) + return nil + }) } diff --git a/consul/kvs_endpoint_test.go b/consul/kvs_endpoint_test.go index cfaee046efae..0ebec469b95c 100644 --- a/consul/kvs_endpoint_test.go +++ b/consul/kvs_endpoint_test.go @@ -278,6 +278,18 @@ func TestKVSEndpoint_List(t *testing.T) { t.Fatalf("bad: %v", d) } } + + // Try listing a nonexistent prefix + getR.Key = "/nope" + if err := msgpackrpc.CallWithCodec(codec, "KVS.List", &getR, &dirent); err != nil { + t.Fatalf("err: %v", err) + } + if dirent.Index == 0 { + t.Fatalf("Bad: %v", dirent) + } + if len(dirent.Entries) != 0 { + t.Fatalf("Bad: %v", dirent.Entries) + } } func TestKVSEndpoint_List_Blocking(t *testing.T) { @@ -514,6 +526,18 @@ func TestKVSEndpoint_ListKeys(t *testing.T) { if dirent.Keys[2] != "/test/sub/" { t.Fatalf("Bad: %v", dirent.Keys) } + + // Try listing a nonexistent prefix + getR.Prefix = "/nope" + if err := msgpackrpc.CallWithCodec(codec, "KVS.ListKeys", &getR, &dirent); err != nil { + t.Fatalf("err: %v", err) + } + if dirent.Index == 0 { + t.Fatalf("Bad: %v", dirent) + } + if len(dirent.Keys) != 0 { + t.Fatalf("Bad: %v", dirent.Keys) + } } func TestKVSEndpoint_ListKeys_ACLDeny(t *testing.T) { @@ -605,7 +629,7 @@ func TestKVS_Apply_LockDelay(t *testing.T) { // Create and invalidate a session with a lock state := s1.fsm.State() - if err := state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}); err != nil { + if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { t.Fatalf("err: %v", err) } session := &structs.Session{ diff --git a/consul/leader.go b/consul/leader.go index 67be5bb59184..55c487b4fdca 100644 --- a/consul/leader.go +++ b/consul/leader.go @@ -260,7 +260,10 @@ func (s *Server) reconcile() (err error) { // a "reap" event to cause the node to be cleaned up. 
func (s *Server) reconcileReaped(known map[string]struct{}) error { state := s.fsm.State() - _, checks := state.ChecksInState(structs.HealthAny) + _, checks, err := state.ChecksInState(structs.HealthAny) + if err != nil { + return err + } for _, check := range checks { // Ignore any non serf checks if check.CheckID != SerfCheckID { @@ -282,7 +285,10 @@ func (s *Server) reconcileReaped(known map[string]struct{}) error { } // Get the node services, look for ConsulServiceID - _, services := state.NodeServices(check.Node) + _, services, err := state.NodeServices(check.Node) + if err != nil { + return err + } serverPort := 0 for _, service := range services.Services { if service.ID == ConsulServiceID { @@ -352,8 +358,6 @@ func (s *Server) shouldHandleMember(member serf.Member) bool { // handleAliveMember is used to ensure the node // is registered, with a passing health check. func (s *Server) handleAliveMember(member serf.Member) error { - state := s.fsm.State() - // Register consul service if a server var service *structs.NodeService if valid, parts := isConsulServer(member); valid { @@ -370,12 +374,19 @@ func (s *Server) handleAliveMember(member serf.Member) error { } // Check if the node exists - _, found, addr := state.GetNode(member.Name) - if found && addr == member.Addr.String() { + state := s.fsm.State() + _, node, err := state.GetNode(member.Name) + if err != nil { + return err + } + if node != nil && node.Address == member.Addr.String() { // Check if the associated service is available if service != nil { match := false - _, services := state.NodeServices(member.Name) + _, services, err := state.NodeServices(member.Name) + if err != nil { + return err + } if services != nil { for id, _ := range services.Services { if id == service.ID { @@ -389,7 +400,10 @@ func (s *Server) handleAliveMember(member serf.Member) error { } // Check if the serfCheck is in the passing state - _, checks := state.NodeChecks(member.Name) + _, checks, err := state.NodeChecks(member.Name) + if err != nil { + return err + } for _, check := range checks { if check.CheckID == SerfCheckID && check.Status == structs.HealthPassing { return nil @@ -421,13 +435,18 @@ AFTER_CHECK: // handleFailedMember is used to mark the node's status // as being critical, along with all checks as unknown. func (s *Server) handleFailedMember(member serf.Member) error { - state := s.fsm.State() - // Check if the node exists - _, found, addr := state.GetNode(member.Name) - if found && addr == member.Addr.String() { + state := s.fsm.State() + _, node, err := state.GetNode(member.Name) + if err != nil { + return err + } + if node != nil && node.Address == member.Addr.String() { // Check if the serfCheck is in the critical state - _, checks := state.NodeChecks(member.Name) + _, checks, err := state.NodeChecks(member.Name) + if err != nil { + return err + } for _, check := range checks { if check.CheckID == SerfCheckID && check.Status == structs.HealthCritical { return nil @@ -468,7 +487,6 @@ func (s *Server) handleReapMember(member serf.Member) error { // handleDeregisterMember is used to deregister a member of a given reason func (s *Server) handleDeregisterMember(reason string, member serf.Member) error { - state := s.fsm.State() // Do not deregister ourself. This can only happen if the current leader // is leaving. Instead, we should allow a follower to take-over and // deregister us later. 
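Aside: nearly every hunk in this patch applies the same conversion — reads from the new memdb-backed state store return (index, value, error) instead of bare values, so callers check the error first and then test the returned value for nil. A minimal sketch of the caller pattern (illustrative only; GetNode is the API shown in the hunks, the surrounding control flow is a placeholder):

    // Old style: no error path, existence signaled by a bool.
    //   _, found, addr := state.GetNode(name)
    // New style: the read itself can fail, so propagate the error,
    // then treat a nil node as "not found".
    _, node, err := state.GetNode(name)
    if err != nil {
        return err // surface state-store failures instead of masking them
    }
    if node == nil {
        return nil // node already gone; nothing to do
    }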
@@ -484,9 +502,13 @@ func (s *Server) handleDeregisterMember(reason string, member serf.Member) error } } - // Check if the node does not exists - _, found, _ := state.GetNode(member.Name) - if !found { + // Check if the node does not exist + state := s.fsm.State() + _, node, err := state.GetNode(member.Name) + if err != nil { + return err + } + if node == nil { return nil } diff --git a/consul/leader_test.go b/consul/leader_test.go index 4155caaf67ba..900f39617d5f 100644 --- a/consul/leader_test.go +++ b/consul/leader_test.go @@ -34,14 +34,20 @@ func TestLeader_RegisterMember(t *testing.T) { // Client should be registered state := s1.fsm.State() testutil.WaitForResult(func() (bool, error) { - _, found, _ := state.GetNode(c1.config.NodeName) - return found == true, nil + _, node, err := state.GetNode(c1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } + return node != nil, nil }, func(err error) { t.Fatalf("client not registered") }) // Should have a check - _, checks := state.NodeChecks(c1.config.NodeName) + _, checks, err := state.NodeChecks(c1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } if len(checks) != 1 { t.Fatalf("client missing check") } @@ -56,13 +62,19 @@ func TestLeader_RegisterMember(t *testing.T) { } // Server should be registered - _, found, _ := state.GetNode(s1.config.NodeName) - if !found { + _, node, err := state.GetNode(s1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } + if node == nil { t.Fatalf("server not registered") } // Service should be registered - _, services := state.NodeServices(s1.config.NodeName) + _, services, err := state.NodeServices(s1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } if _, ok := services.Services["consul"]; !ok { t.Fatalf("consul service not registered: %v", services) } @@ -92,14 +104,20 @@ func TestLeader_FailedMember(t *testing.T) { // Should be registered state := s1.fsm.State() testutil.WaitForResult(func() (bool, error) { - _, found, _ := state.GetNode(c1.config.NodeName) - return found == true, nil + _, node, err := state.GetNode(c1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } + return node != nil, nil }, func(err error) { t.Fatalf("client not registered") }) // Should have a check - _, checks := state.NodeChecks(c1.config.NodeName) + _, checks, err := state.NodeChecks(c1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } if len(checks) != 1 { t.Fatalf("client missing check") } @@ -111,7 +129,10 @@ func TestLeader_FailedMember(t *testing.T) { } testutil.WaitForResult(func() (bool, error) { - _, checks = state.NodeChecks(c1.config.NodeName) + _, checks, err = state.NodeChecks(c1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } return checks[0].Status == structs.HealthCritical, errors.New(checks[0].Status) }, func(err error) { t.Fatalf("check status is %v, should be critical", err) @@ -134,13 +155,15 @@ func TestLeader_LeftMember(t *testing.T) { t.Fatalf("err: %v", err) } - var found bool state := s1.fsm.State() // Should be registered testutil.WaitForResult(func() (bool, error) { - _, found, _ = state.GetNode(c1.config.NodeName) - return found == true, nil + _, node, err := state.GetNode(c1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } + return node != nil, nil }, func(err error) { t.Fatalf("client should be registered") }) @@ -151,8 +174,11 @@ func TestLeader_LeftMember(t *testing.T) { // Should be deregistered testutil.WaitForResult(func() (bool, error) { - _, found, _ = 
state.GetNode(c1.config.NodeName) - return found == false, nil + _, node, err := state.GetNode(c1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } + return node == nil, nil }, func(err error) { t.Fatalf("client should not be registered") }) @@ -174,13 +200,15 @@ func TestLeader_ReapMember(t *testing.T) { t.Fatalf("err: %v", err) } - var found bool state := s1.fsm.State() // Should be registered testutil.WaitForResult(func() (bool, error) { - _, found, _ = state.GetNode(c1.config.NodeName) - return found == true, nil + _, node, err := state.GetNode(c1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } + return node != nil, nil }, func(err error) { t.Fatalf("client should be registered") }) @@ -199,8 +227,11 @@ func TestLeader_ReapMember(t *testing.T) { // Should be deregistered testutil.WaitForResult(func() (bool, error) { - _, found, _ = state.GetNode(c1.config.NodeName) - return found == false, nil + _, node, err := state.GetNode(c1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } + return node == nil, nil }, func(err error) { t.Fatalf("client should not be registered") }) @@ -237,8 +268,11 @@ func TestLeader_Reconcile_ReapMember(t *testing.T) { // Node should be gone state := s1.fsm.State() - _, found, _ := state.GetNode("no-longer-around") - if found { + _, node, err := state.GetNode("no-longer-around") + if err != nil { + t.Fatalf("err: %v", err) + } + if node != nil { t.Fatalf("client registered") } } @@ -261,15 +295,21 @@ func TestLeader_Reconcile(t *testing.T) { // Should not be registered state := s1.fsm.State() - _, found, _ := state.GetNode(c1.config.NodeName) - if found { + _, node, err := state.GetNode(c1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } + if node != nil { t.Fatalf("client registered") } // Should be registered testutil.WaitForResult(func() (bool, error) { - _, found, _ = state.GetNode(c1.config.NodeName) - return found == true, nil + _, node, err = state.GetNode(c1.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } + return node != nil, nil }, func(err error) { t.Fatalf("client should be registered") }) @@ -393,8 +433,11 @@ func TestLeader_LeftLeader(t *testing.T) { // Verify the old leader is deregistered state := remain.fsm.State() testutil.WaitForResult(func() (bool, error) { - _, found, _ := state.GetNode(leader.config.NodeName) - return !found, nil + _, node, err := state.GetNode(leader.config.NodeName) + if err != nil { + t.Fatalf("err: %v", err) + } + return node == nil, nil }, func(err error) { t.Fatalf("leader should be deregistered") }) @@ -536,25 +579,39 @@ func TestLeader_ReapTombstones(t *testing.T) { t.Fatalf("err: %v", err) } - // Delete the KV entry (tombstoned) + // Delete the KV entry (tombstoned). arg.Op = structs.KVSDelete if err := msgpackrpc.CallWithCodec(codec, "KVS.Apply", &arg, &out); err != nil { t.Fatalf("err: %v", err) } - // Ensure we have a tombstone - _, res, err := s1.fsm.State().tombstoneTable.Get("id") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) == 0 { - t.Fatalf("missing tombstones") - } + // Make sure there's a tombstone. 
+ state := s1.fsm.State() + func() { + snap := state.Snapshot() + defer snap.Close() + stones, err := snap.Tombstones() + if err != nil { + t.Fatalf("err: %s", err) + } + if stones.Next() == nil { + t.Fatalf("missing tombstones") + } + if stones.Next() != nil { + t.Fatalf("unexpected extra tombstones") + } + }() - // Check that the new leader has a pending GC expiration + // Check that the new leader has a pending GC expiration by + // watching for the tombstone to get removed. testutil.WaitForResult(func() (bool, error) { - _, res, err := s1.fsm.State().tombstoneTable.Get("id") - return len(res) == 0, err + snap := state.Snapshot() + defer snap.Close() + stones, err := snap.Tombstones() + if err != nil { + return false, err + } + return stones.Next() == nil, nil }, func(err error) { t.Fatalf("err: %v", err) }) diff --git a/consul/mdb_table.go b/consul/mdb_table.go deleted file mode 100644 index 37eb52503842..000000000000 --- a/consul/mdb_table.go +++ /dev/null @@ -1,830 +0,0 @@ -package consul - -import ( - "bytes" - "fmt" - "reflect" - "strings" - "sync/atomic" - "time" - - "github.com/armon/gomdb" -) - -var ( - noIndex = fmt.Errorf("undefined index") - tooManyFields = fmt.Errorf("number of fields exceeds index arity") -) - -const ( - // lastIndexRowID is a special RowID used to represent the - // last Raft index that affected the table. The index value - // is not used by MDBTable, but is stored so that the client can map - // back to the Raft index number - lastIndexRowID = 0 - - // deadlockTimeout is a heuristic to detect a potential MDB deadlock. - // If we have a transaction that is left open indefinitely, it can - // prevent new transactions from making progress and deadlocking - // the system. If we fail to start a transaction after this long, - // assume a potential deadlock and panic. - deadlockTimeout = 30 * time.Second -) - -/* - An MDB table is a logical representation of a table, which is a - generic row store. It provides a simple mechanism to store rows - using a row id, while maintaining any number of secondary indexes. -*/ -type MDBTable struct { - // Last used rowID. Must be first to avoid 64bit alignment issues. 
- lastRowID uint64 - - Env *mdb.Env - Name string // This is the name of the table, must be unique - Indexes map[string]*MDBIndex - Encoder func(interface{}) []byte - Decoder func([]byte) interface{} -} - -// MDBTables is used for when we have a collection of tables -type MDBTables []*MDBTable - -// An Index is named, and uses a series of column values to -// map to the row-id containing the table -type MDBIndex struct { - AllowBlank bool // Can fields be blank - Unique bool // Controls if values are unique - Fields []string // Fields are used to build the index - IdxFunc IndexFunc // Can be used to provide custom indexing - Virtual bool // Virtual index does not exist, but can be used for queries - RealIndex string // Virtual indexes use a RealIndex for iteration - CaseInsensitive bool // Controls if values are case-insensitive - - table *MDBTable - name string - dbiName string - realIndex *MDBIndex -} - -// MDBTxn is used to wrap an underlying transaction -type MDBTxn struct { - readonly bool - tx *mdb.Txn - dbis map[string]mdb.DBI - after []func() -} - -// Abort is used to close the transaction -func (t *MDBTxn) Abort() { - if t != nil && t.tx != nil { - t.tx.Abort() - } -} - -// Commit is used to commit a transaction -func (t *MDBTxn) Commit() error { - if err := t.tx.Commit(); err != nil { - return err - } - for _, f := range t.after { - f() - } - t.after = nil - return nil -} - -// Defer is used to defer a function call until a successful commit -func (t *MDBTxn) Defer(f func()) { - t.after = append(t.after, f) -} - -type IndexFunc func(*MDBIndex, []string) string - -// DefaultIndexFunc is used if no IdxFunc is provided. It joins -// the columns using '||' which is reasonably unlikely to occur. -// We also prefix with a byte to ensure we never have a zero length -// key -func DefaultIndexFunc(idx *MDBIndex, parts []string) string { - if len(parts) == 0 { - return "_" - } - prefix := "_" + strings.Join(parts, "||") + "||" - return prefix -} - -// DefaultIndexPrefixFunc can be used with DefaultIndexFunc to scan -// for index prefix values. This should only be used as part of a -// virtual index. 
-func DefaultIndexPrefixFunc(idx *MDBIndex, parts []string) string { - if len(parts) == 0 { - return "_" - } - prefix := "_" + strings.Join(parts, "||") - return prefix -} - -// Init is used to initialize the MDBTable and ensure it's ready -func (t *MDBTable) Init() error { - if t.Env == nil { - return fmt.Errorf("Missing mdb env") - } - if t.Name == "" { - return fmt.Errorf("Missing table name") - } - if t.Indexes == nil { - return fmt.Errorf("Missing table indexes") - } - - // Ensure we have a unique id index - id, ok := t.Indexes["id"] - if !ok { - return fmt.Errorf("Missing id index") - } - if !id.Unique { - return fmt.Errorf("id index must be unique") - } - if id.AllowBlank { - return fmt.Errorf("id index must not allow blanks") - } - if id.Virtual { - return fmt.Errorf("id index cannot be virtual") - } - - // Create the table - if err := t.createTable(); err != nil { - return fmt.Errorf("table create failed: %v", err) - } - - // Initialize the indexes - for name, index := range t.Indexes { - if err := index.init(t, name); err != nil { - return fmt.Errorf("index %s error: %s", name, err) - } - } - - // Get the maximum row id - if err := t.restoreLastRowID(); err != nil { - return fmt.Errorf("error scanning table: %s", err) - } - - return nil -} - -// createTable is used to ensure the table exists -func (t *MDBTable) createTable() error { - tx, err := t.Env.BeginTxn(nil, 0) - if err != nil { - return err - } - if _, err := tx.DBIOpen(t.Name, mdb.CREATE); err != nil { - tx.Abort() - return err - } - return tx.Commit() -} - -// restoreLastRowID is used to set the last rowID that we've used -func (t *MDBTable) restoreLastRowID() error { - tx, err := t.StartTxn(true, nil) - if err != nil { - return err - } - defer tx.Abort() - - cursor, err := tx.tx.CursorOpen(tx.dbis[t.Name]) - if err != nil { - return err - } - defer cursor.Close() - - key, _, err := cursor.Get(nil, mdb.LAST) - if err == mdb.NotFound { - t.lastRowID = 0 - return nil - } else if err != nil { - return err - } - - // Set the last row id - t.lastRowID = bytesToUint64(key) - return nil -} - -// nextRowID returns the next usable row id -func (t *MDBTable) nextRowID() uint64 { - return atomic.AddUint64(&t.lastRowID, 1) -} - -// startTxn is used to start a transaction -func (t *MDBTable) StartTxn(readonly bool, mdbTxn *MDBTxn) (*MDBTxn, error) { - var txFlags uint = 0 - var tx *mdb.Txn - var err error - - // Panic if we deadlock acquiring a transaction - timeout := time.AfterFunc(deadlockTimeout, func() { - panic("Timeout starting MDB transaction, potential deadlock") - }) - defer timeout.Stop() - - // Ensure the modes agree - if mdbTxn != nil { - if mdbTxn.readonly != readonly { - return nil, fmt.Errorf("Cannot mix read/write transactions") - } - tx = mdbTxn.tx - goto EXTEND - } - - if readonly { - txFlags |= mdb.RDONLY - } - - tx, err = t.Env.BeginTxn(nil, txFlags) - if err != nil { - return nil, err - } - - mdbTxn = &MDBTxn{ - readonly: readonly, - tx: tx, - dbis: make(map[string]mdb.DBI), - } -EXTEND: - dbi, err := tx.DBIOpen(t.Name, 0) - if err != nil { - tx.Abort() - return nil, err - } - mdbTxn.dbis[t.Name] = dbi - - for _, index := range t.Indexes { - if index.Virtual { - continue - } - dbi, err := index.openDBI(tx) - if err != nil { - tx.Abort() - return nil, err - } - mdbTxn.dbis[index.dbiName] = dbi - } - - return mdbTxn, nil -} - -// objIndexKeys builds the indexes for a given object -func (t *MDBTable) objIndexKeys(obj interface{}) (map[string][]byte, error) { - // Construct the indexes keys - indexes := 
make(map[string][]byte) - for name, index := range t.Indexes { - if index.Virtual { - continue - } - key, err := index.keyFromObject(obj) - if err != nil { - return nil, err - } - indexes[name] = key - } - return indexes, nil -} - -// Insert is used to insert or update an object -func (t *MDBTable) Insert(obj interface{}) error { - // Start a new txn - tx, err := t.StartTxn(false, nil) - if err != nil { - return err - } - defer tx.Abort() - - if err := t.InsertTxn(tx, obj); err != nil { - return err - } - return tx.Commit() -} - -// Insert is used to insert or update an object within -// a given transaction -func (t *MDBTable) InsertTxn(tx *MDBTxn, obj interface{}) error { - var n int - // Construct the indexes keys - indexes, err := t.objIndexKeys(obj) - if err != nil { - return err - } - - // Encode the obj - raw := t.Encoder(obj) - - // Scan and check if this primary key already exists - primaryDbi := tx.dbis[t.Indexes["id"].dbiName] - _, err = tx.tx.Get(primaryDbi, indexes["id"]) - if err == mdb.NotFound { - goto AFTER_DELETE - } - - // Delete the existing row - n, err = t.deleteWithIndex(tx, t.Indexes["id"], indexes["id"]) - if err != nil { - return err - } - if n != 1 { - return fmt.Errorf("unexpected number of updates: %d", n) - } - -AFTER_DELETE: - // Insert with a new row ID - rowId := t.nextRowID() - encRowId := uint64ToBytes(rowId) - table := tx.dbis[t.Name] - if err := tx.tx.Put(table, encRowId, raw, 0); err != nil { - return err - } - - // Insert the new indexes - for name, index := range t.Indexes { - if index.Virtual { - continue - } - dbi := tx.dbis[index.dbiName] - if err := tx.tx.Put(dbi, indexes[name], encRowId, 0); err != nil { - return err - } - } - return nil -} - -// Get is used to lookup one or more rows. An index an appropriate -// fields are specified. The fields can be a prefix of the index. -func (t *MDBTable) Get(index string, parts ...string) (uint64, []interface{}, error) { - // Start a readonly txn - tx, err := t.StartTxn(true, nil) - if err != nil { - return 0, nil, err - } - defer tx.Abort() - - // Get the last associated index - idx, err := t.LastIndexTxn(tx) - if err != nil { - return 0, nil, err - } - - // Get the actual results - res, err := t.GetTxn(tx, index, parts...) - return idx, res, err -} - -// GetTxn is like Get but it operates within a specific transaction. -// This can be used for read that span multiple tables -func (t *MDBTable) GetTxn(tx *MDBTxn, index string, parts ...string) ([]interface{}, error) { - // Get the associated index - idx, key, err := t.getIndex(index, parts) - if err != nil { - return nil, err - } - - // Accumulate the results - var results []interface{} - err = idx.iterate(tx, key, func(encRowId, res []byte) (bool, bool) { - obj := t.Decoder(res) - results = append(results, obj) - return false, false - }) - - return results, err -} - -// GetTxnLimit is like GetTxn limits the maximum number of -// rows it will return -func (t *MDBTable) GetTxnLimit(tx *MDBTxn, limit int, index string, parts ...string) ([]interface{}, error) { - // Get the associated index - idx, key, err := t.getIndex(index, parts) - if err != nil { - return nil, err - } - - // Accumulate the results - var results []interface{} - num := 0 - err = idx.iterate(tx, key, func(encRowId, res []byte) (bool, bool) { - num++ - obj := t.Decoder(res) - results = append(results, obj) - return false, num == limit - }) - - return results, err -} - -// StreamTxn is like GetTxn but it streams the results over a channel. 
-// This can be used if the expected data set is very large. The stream -// is always closed on return. -func (t *MDBTable) StreamTxn(stream chan<- interface{}, tx *MDBTxn, index string, parts ...string) error { - // Always close the stream on return - defer close(stream) - - // Get the associated index - idx, key, err := t.getIndex(index, parts) - if err != nil { - return err - } - - // Stream the results - err = idx.iterate(tx, key, func(encRowId, res []byte) (bool, bool) { - obj := t.Decoder(res) - stream <- obj - return false, false - }) - - return err -} - -// getIndex is used to get the proper index, and also check the arity -func (t *MDBTable) getIndex(index string, parts []string) (*MDBIndex, []byte, error) { - // Get the index - idx, ok := t.Indexes[index] - if !ok { - return nil, nil, noIndex - } - - // Check the arity - arity := idx.arity() - if len(parts) > arity { - return nil, nil, tooManyFields - } - - if idx.CaseInsensitive { - parts = ToLowerList(parts) - } - - // Construct the key - key := idx.keyFromParts(parts...) - return idx, key, nil -} - -// Delete is used to delete one or more rows. An index an appropriate -// fields are specified. The fields can be a prefix of the index. -// Returns the rows deleted or an error. -func (t *MDBTable) Delete(index string, parts ...string) (num int, err error) { - // Start a write txn - tx, err := t.StartTxn(false, nil) - if err != nil { - return 0, err - } - defer tx.Abort() - - num, err = t.DeleteTxn(tx, index, parts...) - if err != nil { - return 0, err - } - return num, tx.Commit() -} - -// DeleteTxn is like Delete, but occurs in a specific transaction -// that can span multiple tables. -func (t *MDBTable) DeleteTxn(tx *MDBTxn, index string, parts ...string) (int, error) { - // Get the associated index - idx, key, err := t.getIndex(index, parts) - if err != nil { - return 0, err - } - - // Delete with the index - return t.deleteWithIndex(tx, idx, key) -} - -// deleteWithIndex deletes all associated rows while scanning -// a given index for a key prefix. May perform multiple index traversals. -// This is a hack around a bug in LMDB which can cause a partial delete to -// take place. To fix this, we invoke the innerDelete until all rows are -// removed. This hack can be removed once the LMDB bug is resolved. -func (t *MDBTable) deleteWithIndex(tx *MDBTxn, idx *MDBIndex, key []byte) (int, error) { - var total int - var num int - var err error -DELETE: - num, err = t.innerDeleteWithIndex(tx, idx, key) - total += num - if err != nil { - return total, err - } - if num > 0 { - goto DELETE - } - return total, nil -} - -// innerDeleteWithIndex deletes all associated rows while scanning -// a given index for a key prefix. It only traverses the index a single time. 
-func (t *MDBTable) innerDeleteWithIndex(tx *MDBTxn, idx *MDBIndex, key []byte) (num int, err error) { - // Handle an error while deleting - defer func() { - if r := recover(); r != nil { - num = 0 - err = fmt.Errorf("Panic while deleting: %v", r) - } - }() - - // Delete everything as we iterate - err = idx.iterate(tx, key, func(encRowId, res []byte) (bool, bool) { - // Get the object - obj := t.Decoder(res) - - // Build index values - indexes, err := t.objIndexKeys(obj) - if err != nil { - panic(err) - } - - // Delete the indexes we are not iterating - for name, otherIdx := range t.Indexes { - if name == idx.name { - continue - } - if idx.Virtual && name == idx.RealIndex { - continue - } - if otherIdx.Virtual { - continue - } - dbi := tx.dbis[otherIdx.dbiName] - if err := tx.tx.Del(dbi, indexes[name], encRowId); err != nil { - panic(err) - } - } - - // Delete the data row - if err := tx.tx.Del(tx.dbis[t.Name], encRowId, nil); err != nil { - panic(err) - } - - // Delete the object - num++ - return true, false - }) - if err != nil { - return 0, err - } - - // Return the deleted count - return num, nil -} - -// Initializes an index and returns a potential error -func (i *MDBIndex) init(table *MDBTable, name string) error { - i.table = table - i.name = name - i.dbiName = fmt.Sprintf("%s_%s_idx", i.table.Name, i.name) - if i.IdxFunc == nil { - i.IdxFunc = DefaultIndexFunc - } - if len(i.Fields) == 0 { - return fmt.Errorf("index missing fields") - } - if err := i.createIndex(); err != nil { - return err - } - // Verify real index exists - if i.Virtual { - if realIndex, ok := table.Indexes[i.RealIndex]; !ok { - return fmt.Errorf("real index '%s' missing", i.RealIndex) - } else { - i.realIndex = realIndex - } - } - return nil -} - -// createIndex is used to ensure the index exists -func (i *MDBIndex) createIndex() error { - // Do not create if this is a virtual index - if i.Virtual { - return nil - } - tx, err := i.table.Env.BeginTxn(nil, 0) - if err != nil { - return err - } - var dbFlags uint = mdb.CREATE - if !i.Unique { - dbFlags |= mdb.DUPSORT - } - if _, err := tx.DBIOpen(i.dbiName, dbFlags); err != nil { - tx.Abort() - return err - } - return tx.Commit() -} - -// openDBI is used to open a handle to the index for a transaction -func (i *MDBIndex) openDBI(tx *mdb.Txn) (mdb.DBI, error) { - var dbFlags uint - if !i.Unique { - dbFlags |= mdb.DUPSORT - } - return tx.DBIOpen(i.dbiName, dbFlags) -} - -// Returns the arity of the index -func (i *MDBIndex) arity() int { - return len(i.Fields) -} - -// keyFromObject constructs the index key from the object -func (i *MDBIndex) keyFromObject(obj interface{}) ([]byte, error) { - v := reflect.ValueOf(obj) - v = reflect.Indirect(v) // Derefence the pointer if any - parts := make([]string, 0, i.arity()) - for _, field := range i.Fields { - fv := v.FieldByName(field) - if !fv.IsValid() { - return nil, fmt.Errorf("Field '%s' for %#v is invalid", field, obj) - } - val := fv.String() - if !i.AllowBlank && val == "" { - return nil, fmt.Errorf("Field '%s' must be set: %#v", field, obj) - } - if i.CaseInsensitive { - val = strings.ToLower(val) - } - parts = append(parts, val) - } - key := i.keyFromParts(parts...) - return key, nil -} - -// keyFromParts returns the key from component parts -func (i *MDBIndex) keyFromParts(parts ...string) []byte { - return []byte(i.IdxFunc(i, parts)) -} - -// iterate is used to iterate over keys matching the prefix, -// and invoking the cb with each row. 
We dereference the rowid, -// and only return the object row -func (i *MDBIndex) iterate(tx *MDBTxn, prefix []byte, - cb func(encRowId, res []byte) (bool, bool)) error { - table := tx.dbis[i.table.Name] - - // If virtual, use the correct DBI - var dbi mdb.DBI - if i.Virtual { - dbi = tx.dbis[i.realIndex.dbiName] - } else { - dbi = tx.dbis[i.dbiName] - } - - cursor, err := tx.tx.CursorOpen(dbi) - if err != nil { - return err - } - // Read-only cursors are NOT closed by MDB when a transaction - // either commits or aborts, so must be closed explicitly - if tx.readonly { - defer cursor.Close() - } - - var key, encRowId, objBytes []byte - first := true - shouldStop := false - shouldDelete := false - for !shouldStop { - if first && len(prefix) > 0 { - first = false - key, encRowId, err = cursor.Get(prefix, mdb.SET_RANGE) - } else if shouldDelete { - key, encRowId, err = cursor.Get(nil, mdb.GET_CURRENT) - shouldDelete = false - - // LMDB will return EINVAL(22) for the GET_CURRENT op if - // there is no further keys. We treat this as no more - // keys being found. - if num, ok := err.(mdb.Errno); ok && num == 22 { - err = mdb.NotFound - } - } else if i.Unique { - key, encRowId, err = cursor.Get(nil, mdb.NEXT) - } else { - key, encRowId, err = cursor.Get(nil, mdb.NEXT_DUP) - if err == mdb.NotFound { - key, encRowId, err = cursor.Get(nil, mdb.NEXT) - } - } - if err == mdb.NotFound { - break - } else if err != nil { - return fmt.Errorf("iterate failed: %v", err) - } - - // Bail if this does not match our filter - if len(prefix) > 0 && !bytes.HasPrefix(key, prefix) { - break - } - - // Lookup the actual object - objBytes, err = tx.tx.Get(table, encRowId) - if err != nil { - return fmt.Errorf("rowid lookup failed: %v (%v)", err, encRowId) - } - - // Invoke the cb - shouldDelete, shouldStop = cb(encRowId, objBytes) - if shouldDelete { - if err := cursor.Del(0); err != nil { - return fmt.Errorf("delete failed: %v", err) - } - } - } - return nil -} - -// LastIndex is get the last index that updated the table -func (t *MDBTable) LastIndex() (uint64, error) { - // Start a readonly txn - tx, err := t.StartTxn(true, nil) - if err != nil { - return 0, err - } - defer tx.Abort() - return t.LastIndexTxn(tx) -} - -// LastIndexTxn is like LastIndex but it operates within a specific transaction. 
-func (t *MDBTable) LastIndexTxn(tx *MDBTxn) (uint64, error) { - encRowId := uint64ToBytes(lastIndexRowID) - val, err := tx.tx.Get(tx.dbis[t.Name], encRowId) - if err == mdb.NotFound { - return 0, nil - } else if err != nil { - return 0, err - } - - // Return the last index - return bytesToUint64(val), nil -} - -// SetLastIndex is used to set the last index that updated the table -func (t *MDBTable) SetLastIndex(index uint64) error { - tx, err := t.StartTxn(false, nil) - if err != nil { - return err - } - defer tx.Abort() - - if err := t.SetLastIndexTxn(tx, index); err != nil { - return err - } - return tx.Commit() -} - -// SetLastIndexTxn is used to set the last index within a transaction -func (t *MDBTable) SetLastIndexTxn(tx *MDBTxn, index uint64) error { - encRowId := uint64ToBytes(lastIndexRowID) - encIndex := uint64ToBytes(index) - return tx.tx.Put(tx.dbis[t.Name], encRowId, encIndex, 0) -} - -// SetMaxLastIndexTxn is used to set the last index within a transaction -// if it exceeds the current maximum -func (t *MDBTable) SetMaxLastIndexTxn(tx *MDBTxn, index uint64) error { - current, err := t.LastIndexTxn(tx) - if err != nil { - return err - } - if index > current { - return t.SetLastIndexTxn(tx, index) - } - return nil -} - -// StartTxn is used to create a transaction that spans a list of tables -func (t MDBTables) StartTxn(readonly bool) (*MDBTxn, error) { - var tx *MDBTxn - for _, table := range t { - newTx, err := table.StartTxn(readonly, tx) - if err != nil { - tx.Abort() - return nil, err - } - tx = newTx - } - return tx, nil -} - -// LastIndexTxn is used to get the last transaction from all of the tables -func (t MDBTables) LastIndexTxn(tx *MDBTxn) (uint64, error) { - var index uint64 - for _, table := range t { - idx, err := table.LastIndexTxn(tx) - if err != nil { - return index, err - } - if idx > index { - index = idx - } - } - return index, nil -} diff --git a/consul/mdb_table_test.go b/consul/mdb_table_test.go deleted file mode 100644 index 73e4001d12e1..000000000000 --- a/consul/mdb_table_test.go +++ /dev/null @@ -1,1048 +0,0 @@ -package consul - -import ( - "bytes" - "io/ioutil" - "os" - "reflect" - "testing" - - "github.com/armon/gomdb" - "github.com/hashicorp/go-msgpack/codec" -) - -type MockData struct { - Key string - First string - Last string - Country string -} - -func MockEncoder(obj interface{}) []byte { - buf := bytes.NewBuffer(nil) - encoder := codec.NewEncoder(buf, msgpackHandle) - err := encoder.Encode(obj) - if err != nil { - panic(err) - } - return buf.Bytes() -} - -func MockDecoder(buf []byte) interface{} { - out := new(MockData) - err := codec.NewDecoder(bytes.NewReader(buf), msgpackHandle).Decode(out) - if err != nil { - panic(err) - } - return out -} - -func testMDBEnv(t *testing.T) (string, *mdb.Env) { - // Create a new temp dir - path, err := ioutil.TempDir("", "consul") - if err != nil { - t.Fatalf("err: %v", err) - } - - // Open the env - env, err := mdb.NewEnv() - if err != nil { - t.Fatalf("err: %v", err) - } - - // Setup the Env first - if err := env.SetMaxDBs(mdb.DBI(32)); err != nil { - t.Fatalf("err: %v", err) - } - - // Increase the maximum map size - if err := env.SetMapSize(dbMaxMapSize32bit); err != nil { - t.Fatalf("err: %v", err) - } - - // Open the DB - var flags uint = mdb.NOMETASYNC | mdb.NOSYNC | mdb.NOTLS - if err := env.Open(path, flags, 0755); err != nil { - t.Fatalf("err: %v", err) - } - - return path, env -} - -func TestMDBTableInsert(t *testing.T) { - dir, env := testMDBEnv(t) - defer os.RemoveAll(dir) - defer env.Close() - 
- table := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Key"}, - }, - "name": &MDBIndex{ - Fields: []string{"First", "Last"}, - }, - "country": &MDBIndex{ - Fields: []string{"Country"}, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - objs := []*MockData{ - &MockData{ - Key: "1", - First: "Kevin", - Last: "Smith", - Country: "USA", - }, - &MockData{ - Key: "2", - First: "Kevin", - Last: "Wang", - Country: "USA", - }, - &MockData{ - Key: "3", - First: "Bernardo", - Last: "Torres", - Country: "Mexico", - }, - } - - // Insert some mock objects - for idx, obj := range objs { - if err := table.Insert(obj); err != nil { - t.Fatalf("err: %v", err) - } - if err := table.SetLastIndex(uint64(idx + 1)); err != nil { - t.Fatalf("err: %v", err) - } - } - - // Verify with some gets - idx, res, err := table.Get("id", "1") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 1 { - t.Fatalf("expect 1 result: %#v", res) - } - if !reflect.DeepEqual(res[0], objs[0]) { - t.Fatalf("bad: %#v", res[0]) - } - if idx != 3 { - t.Fatalf("bad index: %d", idx) - } - - idx, res, err = table.Get("name", "Kevin") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 2 { - t.Fatalf("expect 2 result: %#v", res) - } - if !reflect.DeepEqual(res[0], objs[0]) { - t.Fatalf("bad: %#v", res[0]) - } - if !reflect.DeepEqual(res[1], objs[1]) { - t.Fatalf("bad: %#v", res[1]) - } - if idx != 3 { - t.Fatalf("bad index: %d", idx) - } - - idx, res, err = table.Get("country", "Mexico") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 1 { - t.Fatalf("expect 1 result: %#v", res) - } - if !reflect.DeepEqual(res[0], objs[2]) { - t.Fatalf("bad: %#v", res[2]) - } - if idx != 3 { - t.Fatalf("bad index: %d", idx) - } - - idx, res, err = table.Get("id") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 3 { - t.Fatalf("expect 2 result: %#v", res) - } - if !reflect.DeepEqual(res[0], objs[0]) { - t.Fatalf("bad: %#v", res[0]) - } - if !reflect.DeepEqual(res[1], objs[1]) { - t.Fatalf("bad: %#v", res[1]) - } - if !reflect.DeepEqual(res[2], objs[2]) { - t.Fatalf("bad: %#v", res[2]) - } - if idx != 3 { - t.Fatalf("bad index: %d", idx) - } -} - -func TestMDBTableInsert_MissingFields(t *testing.T) { - dir, env := testMDBEnv(t) - defer os.RemoveAll(dir) - defer env.Close() - - table := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Key"}, - }, - "name": &MDBIndex{ - Fields: []string{"First", "Last"}, - }, - "country": &MDBIndex{ - Fields: []string{"Country"}, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - objs := []*MockData{ - &MockData{ - First: "Kevin", - Last: "Smith", - Country: "USA", - }, - &MockData{ - Key: "2", - Last: "Wang", - Country: "USA", - }, - &MockData{ - Key: "2", - First: "Kevin", - Country: "USA", - }, - &MockData{ - Key: "3", - First: "Bernardo", - Last: "Torres", - }, - } - - // Insert some mock objects - for _, obj := range objs { - if err := table.Insert(obj); err == nil { - t.Fatalf("expected err") - } - } -} - -func TestMDBTableInsert_AllowBlank(t *testing.T) { - dir, env := testMDBEnv(t) - defer os.RemoveAll(dir) - defer env.Close() - - table := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - 
Unique: true, - Fields: []string{"Key"}, - }, - "name": &MDBIndex{ - Fields: []string{"First", "Last"}, - }, - "country": &MDBIndex{ - Fields: []string{"Country"}, - AllowBlank: true, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - objs := []*MockData{ - &MockData{ - Key: "1", - First: "Kevin", - Last: "Smith", - Country: "", - }, - } - - // Insert some mock objects - for _, obj := range objs { - if err := table.Insert(obj); err != nil { - t.Fatalf("err: %v", err) - } - } -} - -func TestMDBTableDelete(t *testing.T) { - dir, env := testMDBEnv(t) - defer os.RemoveAll(dir) - defer env.Close() - - table := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Key"}, - }, - "name": &MDBIndex{ - Fields: []string{"First", "Last"}, - }, - "country": &MDBIndex{ - Fields: []string{"Country"}, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - objs := []*MockData{ - &MockData{ - Key: "1", - First: "Kevin", - Last: "Smith", - Country: "USA", - }, - &MockData{ - Key: "2", - First: "Kevin", - Last: "Wang", - Country: "USA", - }, - &MockData{ - Key: "3", - First: "Bernardo", - Last: "Torres", - Country: "Mexico", - }, - } - - // Insert some mock objects - for _, obj := range objs { - if err := table.Insert(obj); err != nil { - t.Fatalf("err: %v", err) - } - } - - _, _, err := table.Get("id", "3") - if err != nil { - t.Fatalf("err: %v", err) - } - - // Verify with some gets - num, err := table.Delete("id", "3") - if err != nil { - t.Fatalf("err: %v", err) - } - if num != 1 { - t.Fatalf("expect 1 delete: %#v", num) - } - _, res, err := table.Get("id", "3") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 0 { - t.Fatalf("expect 0 result: %#v", res) - } - - num, err = table.Delete("name", "Kevin") - if err != nil { - t.Fatalf("err: %v", err) - } - if num != 2 { - t.Fatalf("expect 2 deletes: %#v", num) - } - _, res, err = table.Get("name", "Kevin") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 0 { - t.Fatalf("expect 0 results: %#v", res) - } -} - -func TestMDBTableUpdate(t *testing.T) { - dir, env := testMDBEnv(t) - defer os.RemoveAll(dir) - defer env.Close() - - table := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Key"}, - }, - "name": &MDBIndex{ - Fields: []string{"First", "Last"}, - }, - "country": &MDBIndex{ - Fields: []string{"Country"}, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - objs := []*MockData{ - &MockData{ - Key: "1", - First: "Kevin", - Last: "Smith", - Country: "USA", - }, - &MockData{ - Key: "2", - First: "Kevin", - Last: "Wang", - Country: "USA", - }, - &MockData{ - Key: "3", - First: "Bernardo", - Last: "Torres", - Country: "Mexico", - }, - &MockData{ - Key: "1", - First: "Roger", - Last: "Rodrigez", - Country: "Mexico", - }, - &MockData{ - Key: "2", - First: "Anna", - Last: "Smith", - Country: "UK", - }, - &MockData{ - Key: "3", - First: "Ahmad", - Last: "Badari", - Country: "Iran", - }, - } - - // Insert and update some mock objects - for _, obj := range objs { - if err := table.Insert(obj); err != nil { - t.Fatalf("err: %v", err) - } - } - - // Verify with some gets - _, res, err := table.Get("id", "1") - if err != nil { - t.Fatalf("err: 
%v", err) - } - if len(res) != 1 { - t.Fatalf("expect 1 result: %#v", res) - } - if !reflect.DeepEqual(res[0], objs[3]) { - t.Fatalf("bad: %#v", res[0]) - } - - _, res, err = table.Get("name", "Kevin") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 0 { - t.Fatalf("expect 0 result: %#v", res) - } - - _, res, err = table.Get("name", "Ahmad") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 1 { - t.Fatalf("expect 1 result: %#v", res) - } - if !reflect.DeepEqual(res[0], objs[5]) { - t.Fatalf("bad: %#v", res[0]) - } - - _, res, err = table.Get("country", "Mexico") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 1 { - t.Fatalf("expect 1 result: %#v", res) - } - if !reflect.DeepEqual(res[0], objs[3]) { - t.Fatalf("bad: %#v", res[0]) - } - - _, res, err = table.Get("id") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 3 { - t.Fatalf("expect 3 result: %#v", res) - } - if !reflect.DeepEqual(res[0], objs[3]) { - t.Fatalf("bad: %#v", res[0]) - } - if !reflect.DeepEqual(res[1], objs[4]) { - t.Fatalf("bad: %#v", res[1]) - } - if !reflect.DeepEqual(res[2], objs[5]) { - t.Fatalf("bad: %#v", res[2]) - } -} - -func TestMDBTableLastRowID(t *testing.T) { - dir, env := testMDBEnv(t) - defer os.RemoveAll(dir) - defer env.Close() - - table := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Key"}, - }, - "name": &MDBIndex{ - Fields: []string{"First", "Last"}, - }, - "country": &MDBIndex{ - Fields: []string{"Country"}, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - if table.lastRowID != 0 { - t.Fatalf("bad last row id: %d", table.lastRowID) - } - - objs := []*MockData{ - &MockData{ - Key: "1", - First: "Kevin", - Last: "Smith", - Country: "USA", - }, - &MockData{ - Key: "2", - First: "Kevin", - Last: "Wang", - Country: "USA", - }, - &MockData{ - Key: "3", - First: "Bernardo", - Last: "Torres", - Country: "Mexico", - }, - } - - // Insert some mock objects - for _, obj := range objs { - if err := table.Insert(obj); err != nil { - t.Fatalf("err: %v", err) - } - } - - if table.lastRowID != 3 { - t.Fatalf("bad last row id: %d", table.lastRowID) - } - - // Remount the table - table2 := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Key"}, - }, - "name": &MDBIndex{ - Fields: []string{"First", "Last"}, - }, - "country": &MDBIndex{ - Fields: []string{"Country"}, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table2.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - if table2.lastRowID != 3 { - t.Fatalf("bad last row id: %d", table2.lastRowID) - } -} - -func TestMDBTableIndex(t *testing.T) { - dir, env := testMDBEnv(t) - defer os.RemoveAll(dir) - defer env.Close() - - table := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Key"}, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - if table.lastRowID != 0 { - t.Fatalf("bad last row id: %d", table.lastRowID) - } - - objs := []*MockData{ - &MockData{ - Key: "1", - First: "Kevin", - Last: "Smith", - Country: "USA", - }, - &MockData{ - Key: "2", - First: "Kevin", - Last: "Wang", - Country: "USA", - }, - &MockData{ - Key: "3", - First: "Bernardo", - Last: "Torres", - 
Country: "Mexico", - }, - } - - // Insert some mock objects - for idx, obj := range objs { - if err := table.Insert(obj); err != nil { - t.Fatalf("err: %v", err) - } - if err := table.SetLastIndex(uint64(4 * idx)); err != nil { - t.Fatalf("err: %v", err) - } - } - - if table.lastRowID != 3 { - t.Fatalf("bad last row id: %d", table.lastRowID) - } - - if idx, _ := table.LastIndex(); idx != 8 { - t.Fatalf("bad last idx: %d", idx) - } - - // Remount the table - table2 := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Key"}, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table2.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - if table2.lastRowID != 3 { - t.Fatalf("bad last row id: %d", table2.lastRowID) - } - - if idx, _ := table2.LastIndex(); idx != 8 { - t.Fatalf("bad last idx: %d", idx) - } -} - -func TestMDBTableDelete_Prefix(t *testing.T) { - dir, env := testMDBEnv(t) - defer os.RemoveAll(dir) - defer env.Close() - - table := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"First", "Last"}, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - objs := []*MockData{ - &MockData{ - Key: "1", - First: "James", - Last: "Smith", - Country: "USA", - }, - &MockData{ - Key: "1", - First: "Kevin", - Last: "Smith", - Country: "USA", - }, - &MockData{ - Key: "2", - First: "Kevin", - Last: "Wang", - Country: "USA", - }, - &MockData{ - Key: "3", - First: "Kevin", - Last: "Torres", - Country: "Mexico", - }, - &MockData{ - Key: "1", - First: "Lana", - Last: "Smith", - Country: "USA", - }, - } - - // Insert some mock objects - for _, obj := range objs { - if err := table.Insert(obj); err != nil { - t.Fatalf("err: %v", err) - } - } - - // This should nuke all kevins - num, err := table.Delete("id", "Kevin") - if err != nil { - t.Fatalf("err: %v", err) - } - if num != 3 { - t.Fatalf("expect 3 delete: %#v", num) - } - _, res, err := table.Get("id") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 2 { - t.Fatalf("expect 2 result: %#v", res) - } -} - -func TestMDBTableVirtualIndex(t *testing.T) { - dir, env := testMDBEnv(t) - defer os.RemoveAll(dir) - defer env.Close() - - table := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"First"}, - }, - "id_prefix": &MDBIndex{ - Virtual: true, - RealIndex: "id", - Fields: []string{"First"}, - IdxFunc: DefaultIndexPrefixFunc, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - if table.lastRowID != 0 { - t.Fatalf("bad last row id: %d", table.lastRowID) - } - - objs := []*MockData{ - &MockData{ - Key: "1", - First: "Jack", - Last: "Smith", - Country: "USA", - }, - &MockData{ - Key: "2", - First: "John", - Last: "Wang", - Country: "USA", - }, - &MockData{ - Key: "3", - First: "James", - Last: "Torres", - Country: "Mexico", - }, - } - - // Insert some mock objects - for idx, obj := range objs { - if err := table.Insert(obj); err != nil { - t.Fatalf("err: %v", err) - } - if err := table.SetLastIndex(uint64(4 * idx)); err != nil { - t.Fatalf("err: %v", err) - } - } - - if table.lastRowID != 3 { - t.Fatalf("bad last row id: %d", table.lastRowID) - } - - if idx, _ := table.LastIndex(); idx != 8 { - t.Fatalf("bad last idx: %d", idx) 
- } - - _, res, err := table.Get("id_prefix", "J") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 3 { - t.Fatalf("expect 3 result: %#v", res) - } - - _, res, err = table.Get("id_prefix", "Ja") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 2 { - t.Fatalf("expect 2 result: %#v", res) - } - - num, err := table.Delete("id_prefix", "Ja") - if err != nil { - t.Fatalf("err: %v", err) - } - if num != 2 { - t.Fatalf("expect 2 result: %#v", num) - } - - _, res, err = table.Get("id_prefix", "J") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 1 { - t.Fatalf("expect 1 result: %#v", res) - } -} - -func TestMDBTableStream(t *testing.T) { - dir, env := testMDBEnv(t) - defer os.RemoveAll(dir) - defer env.Close() - - table := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Key"}, - }, - "name": &MDBIndex{ - Fields: []string{"First", "Last"}, - }, - "country": &MDBIndex{ - Fields: []string{"Country"}, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - objs := []*MockData{ - &MockData{ - Key: "1", - First: "Kevin", - Last: "Smith", - Country: "USA", - }, - &MockData{ - Key: "2", - First: "Kevin", - Last: "Wang", - Country: "USA", - }, - &MockData{ - Key: "3", - First: "Bernardo", - Last: "Torres", - Country: "Mexico", - }, - } - - // Insert some mock objects - for idx, obj := range objs { - if err := table.Insert(obj); err != nil { - t.Fatalf("err: %v", err) - } - if err := table.SetLastIndex(uint64(idx + 1)); err != nil { - t.Fatalf("err: %v", err) - } - } - - // Start a readonly txn - tx, err := table.StartTxn(true, nil) - if err != nil { - panic(err) - } - defer tx.Abort() - - // Stream the records - streamCh := make(chan interface{}) - go func() { - if err := table.StreamTxn(streamCh, tx, "id"); err != nil { - t.Fatalf("err: %v", err) - } - }() - - // Verify we get them all - idx := 0 - for obj := range streamCh { - p := obj.(*MockData) - if !reflect.DeepEqual(p, objs[idx]) { - t.Fatalf("bad: %#v %#v", p, objs[idx]) - } - idx++ - } - - if idx != 3 { - t.Fatalf("bad index: %d", idx) - } -} - -func TestMDBTableGetTxnLimit(t *testing.T) { - dir, env := testMDBEnv(t) - defer os.RemoveAll(dir) - defer env.Close() - - table := &MDBTable{ - Env: env, - Name: "test", - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Key"}, - }, - "name": &MDBIndex{ - Fields: []string{"First", "Last"}, - }, - "country": &MDBIndex{ - Fields: []string{"Country"}, - }, - }, - Encoder: MockEncoder, - Decoder: MockDecoder, - } - if err := table.Init(); err != nil { - t.Fatalf("err: %v", err) - } - - objs := []*MockData{ - &MockData{ - Key: "1", - First: "Kevin", - Last: "Smith", - Country: "USA", - }, - &MockData{ - Key: "2", - First: "Kevin", - Last: "Wang", - Country: "USA", - }, - &MockData{ - Key: "3", - First: "Bernardo", - Last: "Torres", - Country: "Mexico", - }, - } - - // Insert some mock objects - for idx, obj := range objs { - if err := table.Insert(obj); err != nil { - t.Fatalf("err: %v", err) - } - if err := table.SetLastIndex(uint64(idx + 1)); err != nil { - t.Fatalf("err: %v", err) - } - } - - // Start a readonly txn - tx, err := table.StartTxn(true, nil) - if err != nil { - panic(err) - } - defer tx.Abort() - - // Verify with some gets - res, err := table.GetTxnLimit(tx, 2, "id") - if err != nil { - t.Fatalf("err: %v", err) - } - if len(res) != 2 { - 
t.Fatalf("expect 2 result: %#v", res) - } -} diff --git a/consul/rpc.go b/consul/rpc.go index 292f71949ecf..3beeb07dac28 100644 --- a/consul/rpc.go +++ b/consul/rpc.go @@ -10,6 +10,7 @@ import ( "time" "github.com/armon/go-metrics" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/yamux" @@ -296,98 +297,67 @@ func (s *Server) raftApply(t structs.MessageType, msg interface{}) (interface{}, return future.Response(), nil } -// blockingRPC is used for queries that need to wait for a -// minimum index. This is used to block and wait for changes. -func (s *Server) blockingRPC(b *structs.QueryOptions, m *structs.QueryMeta, - tables MDBTables, run func() error) error { - opts := blockingRPCOptions{ - queryOpts: b, - queryMeta: m, - tables: tables, - run: run, - } - return s.blockingRPCOpt(&opts) -} - -// blockingRPCOptions is used to parameterize blockingRPCOpt since -// it takes so many options. It should be preferred over blockingRPC. -type blockingRPCOptions struct { - queryOpts *structs.QueryOptions - queryMeta *structs.QueryMeta - tables MDBTables - kvWatch bool - kvPrefix string - run func() error -} - -// blockingRPCOpt is the replacement for blockingRPC as it allows -// for more parameterization easily. It should be preferred over blockingRPC. -func (s *Server) blockingRPCOpt(opts *blockingRPCOptions) error { +// blockingRPC is used for queries that need to wait for a minimum index. This +// is used to block and wait for changes. +func (s *Server) blockingRPC(queryOpts *structs.QueryOptions, queryMeta *structs.QueryMeta, + watch state.Watch, run func() error) error { var timeout *time.Timer var notifyCh chan struct{} - var state *StateStore - // Fast path non-blocking - if opts.queryOpts.MinQueryIndex == 0 { + // Fast path right to the non-blocking query. + if queryOpts.MinQueryIndex == 0 { goto RUN_QUERY } - // Sanity check that we have tables to block on - if len(opts.tables) == 0 && !opts.kvWatch { - panic("no tables to block on") + // Make sure a watch was given if we were asked to block. + if watch == nil { + panic("no watch given for blocking query") } - // Restrict the max query time, and ensure there is always one - if opts.queryOpts.MaxQueryTime > maxQueryTime { - opts.queryOpts.MaxQueryTime = maxQueryTime - } else if opts.queryOpts.MaxQueryTime <= 0 { - opts.queryOpts.MaxQueryTime = defaultQueryTime + // Restrict the max query time, and ensure there is always one. + if queryOpts.MaxQueryTime > maxQueryTime { + queryOpts.MaxQueryTime = maxQueryTime + } else if queryOpts.MaxQueryTime <= 0 { + queryOpts.MaxQueryTime = defaultQueryTime } - // Apply a small amount of jitter to the request - opts.queryOpts.MaxQueryTime += randomStagger(opts.queryOpts.MaxQueryTime / jitterFraction) + // Apply a small amount of jitter to the request. + queryOpts.MaxQueryTime += randomStagger(queryOpts.MaxQueryTime / jitterFraction) - // Setup a query timeout - timeout = time.NewTimer(opts.queryOpts.MaxQueryTime) + // Setup a query timeout. + timeout = time.NewTimer(queryOpts.MaxQueryTime) - // Setup the notify channel + // Setup the notify channel. notifyCh = make(chan struct{}, 1) - // Ensure we tear down any watchers on return - state = s.fsm.State() + // Ensure we tear down any watches on return. 
defer func() { timeout.Stop() - state.StopWatch(opts.tables, notifyCh) - if opts.kvWatch { - state.StopWatchKV(opts.kvPrefix, notifyCh) - } + watch.Clear(notifyCh) }() REGISTER_NOTIFY: - // Register the notification channel. This may be done - // multiple times if we have not reached the target wait index. - state.Watch(opts.tables, notifyCh) - if opts.kvWatch { - state.WatchKV(opts.kvPrefix, notifyCh) - } + // Register the notification channel. This may be done multiple times if + // we haven't reached the target wait index. + watch.Wait(notifyCh) RUN_QUERY: - // Update the query meta data - s.setQueryMeta(opts.queryMeta) + // Update the query metadata. + s.setQueryMeta(queryMeta) - // Check if query must be consistent - if opts.queryOpts.RequireConsistent { + // If the read must be consistent we verify that we are still the leader. + if queryOpts.RequireConsistent { if err := s.consistentRead(); err != nil { return err } } - // Run the query function + // Run the query. metrics.IncrCounter([]string{"consul", "rpc", "query"}, 1) - err := opts.run() + err := run() - // Check for minimum query time - if err == nil && opts.queryMeta.Index > 0 && opts.queryMeta.Index <= opts.queryOpts.MinQueryIndex { + // Check for minimum query time. + if err == nil && queryMeta.Index > 0 && queryMeta.Index <= queryOpts.MinQueryIndex { select { case <-notifyCh: goto REGISTER_NOTIFY diff --git a/consul/server.go b/consul/server.go index 34ec75224b65..528b6f16ac40 100644 --- a/consul/server.go +++ b/consul/server.go @@ -15,6 +15,7 @@ import ( "time" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/consul/state" "github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/raft" "github.com/hashicorp/raft-boltdb" @@ -33,7 +34,6 @@ const ( serfLANSnapshot = "serf/local.snapshot" serfWANSnapshot = "serf/remote.snapshot" raftState = "raft/" - tmpStatePath = "tmp/" snapshotsRetained = 2 // serverRPCCache controls how long we keep an idle connection @@ -135,7 +135,7 @@ type Server struct { // tombstoneGC is used to track the pending GC invocations // for the KV tombstones - tombstoneGC *TombstoneGC + tombstoneGC *state.TombstoneGC shutdown bool shutdownCh chan struct{} @@ -193,7 +193,7 @@ func NewServer(config *Config) (*Server, error) { logger := log.New(config.LogOutput, "", log.LstdFlags) // Create the tombstone GC - gc, err := NewTombstoneGC(config.TombstoneTTL, config.TombstoneTTLGranularity) + gc, err := state.NewTombstoneGC(config.TombstoneTTL, config.TombstoneTTLGranularity) if err != nil { return nil, err } @@ -316,18 +316,9 @@ func (s *Server) setupRaft() error { s.config.RaftConfig.EnableSingleNode = true } - // Create the base state path - statePath := filepath.Join(s.config.DataDir, tmpStatePath) - if err := os.RemoveAll(statePath); err != nil { - return err - } - if err := ensurePath(statePath, true); err != nil { - return err - } - // Create the FSM var err error - s.fsm, err = NewFSM(s.tombstoneGC, statePath, s.config.LogOutput) + s.fsm, err = NewFSM(s.tombstoneGC, s.config.LogOutput) if err != nil { return err } @@ -490,11 +481,6 @@ func (s *Server) Shutdown() error { // Close the connection pool s.connPool.Shutdown() - // Close the fsm - if s.fsm != nil { - s.fsm.Close() - } - return nil } diff --git a/consul/session_endpoint.go b/consul/session_endpoint.go index 2d6fe05ada57..df9c2072d2ad 100644 --- a/consul/session_endpoint.go +++ b/consul/session_endpoint.go @@ -109,18 +109,23 @@ func (s *Session) Get(args *structs.SessionSpecificRequest, // Get the local state state := 
s.srv.fsm.State() - return s.srv.blockingRPC(&args.QueryOptions, + return s.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("SessionGet"), + state.GetQueryWatch("SessionGet"), func() error { index, session, err := state.SessionGet(args.Session) + if err != nil { + return err + } + reply.Index = index if session != nil { reply.Sessions = structs.Sessions{session} } else { reply.Sessions = nil } - return err + return nil }) } @@ -133,13 +138,18 @@ func (s *Session) List(args *structs.DCSpecificRequest, // Get the local state state := s.srv.fsm.State() - return s.srv.blockingRPC(&args.QueryOptions, + return s.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("SessionList"), + state.GetQueryWatch("SessionList"), func() error { - var err error - reply.Index, reply.Sessions, err = state.SessionList() - return err + index, sessions, err := state.SessionList() + if err != nil { + return err + } + + reply.Index, reply.Sessions = index, sessions + return nil }) } @@ -152,13 +162,18 @@ func (s *Session) NodeSessions(args *structs.NodeSpecificRequest, // Get the local state state := s.srv.fsm.State() - return s.srv.blockingRPC(&args.QueryOptions, + return s.srv.blockingRPC( + &args.QueryOptions, &reply.QueryMeta, - state.QueryTables("NodeSessions"), + state.GetQueryWatch("NodeSessions"), func() error { - var err error - reply.Index, reply.Sessions, err = state.NodeSessions(args.Node) - return err + index, sessions, err := state.NodeSessions(args.Node) + if err != nil { + return err + } + + reply.Index, reply.Sessions = index, sessions + return nil }) } diff --git a/consul/session_endpoint_test.go b/consul/session_endpoint_test.go index 5e76da093800..db8035f3b0cf 100644 --- a/consul/session_endpoint_test.go +++ b/consul/session_endpoint_test.go @@ -20,7 +20,7 @@ func TestSessionEndpoint_Apply(t *testing.T) { testutil.WaitForLeader(t, s1.RPC, "dc1") // Just add a node - s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) arg := structs.SessionRequest{ Datacenter: "dc1", @@ -79,7 +79,7 @@ func TestSessionEndpoint_DeleteApply(t *testing.T) { testutil.WaitForLeader(t, s1.RPC, "dc1") // Just add a node - s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) arg := structs.SessionRequest{ Datacenter: "dc1", @@ -141,7 +141,7 @@ func TestSessionEndpoint_Get(t *testing.T) { testutil.WaitForLeader(t, s1.RPC, "dc1") - s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) arg := structs.SessionRequest{ Datacenter: "dc1", Op: structs.SessionCreate, @@ -184,7 +184,7 @@ func TestSessionEndpoint_List(t *testing.T) { testutil.WaitForLeader(t, s1.RPC, "dc1") - s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) ids := []string{} for i := 0; i < 5; i++ { arg := structs.SessionRequest{ @@ -235,7 +235,7 @@ func TestSessionEndpoint_ApplyTimers(t *testing.T) { testutil.WaitForLeader(t, s1.RPC, "dc1") - s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}) arg := structs.SessionRequest{ Datacenter: "dc1", Op: structs.SessionCreate, @@ -278,7 +278,7 @@ func TestSessionEndpoint_Renew(t *testing.T) { TTL := "10s" // the minimum 
allowed ttl
	ttl := 10 * time.Second

-	s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"})
+	s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"})
	ids := []string{}
	for i := 0; i < 5; i++ {
		arg := structs.SessionRequest{
@@ -436,8 +436,8 @@ func TestSessionEndpoint_NodeSessions(t *testing.T) {
	testutil.WaitForLeader(t, s1.RPC, "dc1")

-	s1.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"})
-	s1.fsm.State().EnsureNode(1, structs.Node{"bar", "127.0.0.1"})
+	s1.fsm.State().EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"})
+	s1.fsm.State().EnsureNode(1, &structs.Node{Node: "bar", Address: "127.0.0.1"})
	ids := []string{}
	for i := 0; i < 10; i++ {
		arg := structs.SessionRequest{
diff --git a/consul/session_ttl_test.go b/consul/session_ttl_test.go
index e732b5d01752..aa09b6d70d2d 100644
--- a/consul/session_ttl_test.go
+++ b/consul/session_ttl_test.go
@@ -20,7 +20,9 @@ func TestInitializeSessionTimers(t *testing.T) {
	testutil.WaitForLeader(t, s1.RPC, "dc1")

	state := s1.fsm.State()
-	state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"})
+	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
+		t.Fatalf("err: %s", err)
+	}
	session := &structs.Session{
		ID:   generateUUID(),
		Node: "foo",
@@ -51,14 +53,16 @@ func TestResetSessionTimer_Fault(t *testing.T) {
	testutil.WaitForLeader(t, s1.RPC, "dc1")

	// Should not exist
-	err := s1.resetSessionTimer("nope", nil)
+	err := s1.resetSessionTimer(generateUUID(), nil)
	if err == nil || !strings.Contains(err.Error(), "not found") {
		t.Fatalf("err: %v", err)
	}

	// Create a session
	state := s1.fsm.State()
-	state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"})
+	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
+		t.Fatalf("err: %s", err)
+	}
	session := &structs.Session{
		ID:   generateUUID(),
		Node: "foo",
@@ -90,7 +94,9 @@ func TestResetSessionTimer_NoTTL(t *testing.T) {
	// Create a session
	state := s1.fsm.State()
-	state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"})
+	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
+		t.Fatalf("err: %s", err)
+	}
	session := &structs.Session{
		ID:   generateUUID(),
		Node: "foo",
@@ -201,7 +207,9 @@ func TestInvalidateSession(t *testing.T) {
	// Create a session
	state := s1.fsm.State()
-	state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"})
+	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
+		t.Fatalf("err: %s", err)
+	}
	session := &structs.Session{
		ID:   generateUUID(),
		Node: "foo",
diff --git a/consul/state/delay.go b/consul/state/delay.go
new file mode 100644
index 000000000000..206fe4da6a21
--- /dev/null
+++ b/consul/state/delay.go
@@ -0,0 +1,54 @@
+package state
+
+import (
+	"sync"
+	"time"
+)
+
+// Delay is used to mark certain locks as unacquirable. When a lock is
+// forcefully released (failing health check, destroyed session, etc.), it is
+// subject to the LockDelay imposed by the session. This prevents another
+// session from acquiring the lock for some period of time as a protection
+// against split-brains. This is inspired by the lock-delay in Chubby. Because
+// this relies on wall-time, we cannot assume all peers perceive time as flowing
+// uniformly. This means KVSLock MUST ignore lockDelay, since the lockDelay may
+// have expired on the leader, but not on the follower. Rejecting the lock could
+// result in inconsistencies in the FSMs due to the rate at which time progresses.
+// Instead, only the opinion of the leader is respected, and the Raft log is
+// never questioned.
+type Delay struct {
+	// delay has the set of active delay expiration times, organized by key.
+	delay map[string]time.Time
+
+	// lock protects the delay map.
+	lock sync.RWMutex
+}
+
+// NewDelay returns a new delay manager.
+func NewDelay() *Delay {
+	return &Delay{delay: make(map[string]time.Time)}
+}
+
+// GetExpiration returns the expiration time of a key's lock delay. This must
+// be checked on the leader node, and not in KVSLock, due to the variability
+// of clocks.
+func (d *Delay) GetExpiration(key string) time.Time {
+	d.lock.RLock()
+	expires := d.delay[key]
+	d.lock.RUnlock()
+	return expires
+}
+
+// SetExpiration sets the expiration time of the lock delay for the given key
+// to the given now time plus the given delay.
+func (d *Delay) SetExpiration(key string, now time.Time, delay time.Duration) {
+	d.lock.Lock()
+	defer d.lock.Unlock()
+
+	d.delay[key] = now.Add(delay)
+	time.AfterFunc(delay, func() {
+		d.lock.Lock()
+		delete(d.delay, key)
+		d.lock.Unlock()
+	})
+}
diff --git a/consul/state/delay_test.go b/consul/state/delay_test.go
new file mode 100644
index 000000000000..68f67d3bef4f
--- /dev/null
+++ b/consul/state/delay_test.go
@@ -0,0 +1,29 @@
+package state
+
+import (
+	"testing"
+	"time"
+)
+
+func TestDelay(t *testing.T) {
+	d := NewDelay()
+
+	// An unknown key should have a time in the past.
+	if exp := d.GetExpiration("nope"); !exp.Before(time.Now()) {
+		t.Fatalf("bad: %v", exp)
+	}
+
+	// Add a key and set a short expiration.
+	now := time.Now()
+	delay := 250 * time.Millisecond
+	d.SetExpiration("bye", now, delay)
+	if exp := d.GetExpiration("bye"); !exp.After(now) {
+		t.Fatalf("bad: %v", exp)
+	}
+
+	// Wait for the key to expire and check again.
+	time.Sleep(2 * delay)
+	if exp := d.GetExpiration("bye"); !exp.Before(now) {
+		t.Fatalf("bad: %v", exp)
+	}
+}
diff --git a/consul/state/graveyard.go b/consul/state/graveyard.go
new file mode 100644
index 000000000000..0ecd0974b1cd
--- /dev/null
+++ b/consul/state/graveyard.go
@@ -0,0 +1,114 @@
+package state
+
+import (
+	"fmt"
+
+	"github.com/hashicorp/go-memdb"
+)
+
+// Tombstone is the internal type used to track tombstones.
+type Tombstone struct {
+	Key   string
+	Index uint64
+}
+
+// Graveyard manages a set of tombstones.
+type Graveyard struct {
+	// gc is used when creating tombstones to track their time-to-live.
+	// It is consumed upstream to manage clearing of tombstones.
+	gc *TombstoneGC
+}
+
+// NewGraveyard returns a new graveyard.
+func NewGraveyard(gc *TombstoneGC) *Graveyard {
+	return &Graveyard{gc: gc}
+}
+
+// InsertTxn adds a new tombstone.
+func (g *Graveyard) InsertTxn(tx *memdb.Txn, key string, idx uint64) error {
+	// Insert the tombstone.
+	stone := &Tombstone{Key: key, Index: idx}
+	if err := tx.Insert("tombstones", stone); err != nil {
+		return fmt.Errorf("failed inserting tombstone: %s", err)
+	}
+
+	if err := tx.Insert("index", &IndexEntry{"tombstones", idx}); err != nil {
+		return fmt.Errorf("failed updating index: %s", err)
+	}
+
+	// If GC is configured, then we hint that this index requires reaping.
+	if g.gc != nil {
+		tx.Defer(func() { g.gc.Hint(idx) })
+	}
+	return nil
+}
+
+// GetMaxIndexTxn returns the highest index tombstone whose key matches the
+// given context, using a prefix match.
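+//
+// For example, with tombstones foo/in/the/house at index 2, foo/bar/baz at 5,
+// foo/bar/zoo at 8, and some/other/path at 9 (the fixture used by the
+// lifecycle test below), a prefix of "foo" yields 8, the empty prefix yields
+// 9, and a prefix that matches nothing yields 0.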
+func (g *Graveyard) GetMaxIndexTxn(tx *memdb.Txn, prefix string) (uint64, error) {
+	stones, err := tx.Get("tombstones", "id_prefix", prefix)
+	if err != nil {
+		return 0, fmt.Errorf("failed querying tombstones: %s", err)
+	}
+
+	var lindex uint64
+	for stone := stones.Next(); stone != nil; stone = stones.Next() {
+		s := stone.(*Tombstone)
+		if s.Index > lindex {
+			lindex = s.Index
+		}
+	}
+	return lindex, nil
+}
+
+// DumpTxn returns all the tombstones.
+func (g *Graveyard) DumpTxn(tx *memdb.Txn) (memdb.ResultIterator, error) {
+	iter, err := tx.Get("tombstones", "id")
+	if err != nil {
+		return nil, err
+	}
+
+	return iter, nil
+}
+
+// RestoreTxn is used when restoring from a snapshot. For general inserts, use
+// InsertTxn.
+func (g *Graveyard) RestoreTxn(tx *memdb.Txn, stone *Tombstone) error {
+	if err := tx.Insert("tombstones", stone); err != nil {
+		return fmt.Errorf("failed inserting tombstone: %s", err)
+	}
+
+	if err := indexUpdateMaxTxn(tx, stone.Index, "tombstones"); err != nil {
+		return fmt.Errorf("failed updating index: %s", err)
+	}
+	return nil
+}
+
+// ReapTxn cleans out all tombstones whose index values are less than or equal
+// to the given idx. This prevents unbounded storage growth of the tombstones.
+func (g *Graveyard) ReapTxn(tx *memdb.Txn, idx uint64) error {
+	// This does a full table scan since we currently can't index on a
+	// numeric value. Since this is all in-memory and done infrequently,
+	// this is pretty reasonable.
+	stones, err := tx.Get("tombstones", "id")
+	if err != nil {
+		return fmt.Errorf("failed querying tombstones: %s", err)
+	}
+
+	// Find eligible tombstones.
+	var objs []interface{}
+	for stone := stones.Next(); stone != nil; stone = stones.Next() {
+		if stone.(*Tombstone).Index <= idx {
+			objs = append(objs, stone)
+		}
+	}
+
+	// Delete the tombstones in a separate loop so we don't trash the
+	// iterator.
+	for _, obj := range objs {
+		if err := tx.Delete("tombstones", obj); err != nil {
+			return fmt.Errorf("failed deleting tombstone: %s", err)
+		}
+	}
+	return nil
+}
diff --git a/consul/state/graveyard_test.go b/consul/state/graveyard_test.go
new file mode 100644
index 000000000000..4b7f46e27f86
--- /dev/null
+++ b/consul/state/graveyard_test.go
@@ -0,0 +1,262 @@
+package state
+
+import (
+	"reflect"
+	"testing"
+	"time"
+)
+
+func TestGraveyard_Lifecycle(t *testing.T) {
+	g := NewGraveyard(nil)
+
+	// Make a donor state store to steal its database, all prepared for
+	// tombstones.
+	s := testStateStore(t)
+
+	// Create some tombstones.
+	func() {
+		tx := s.db.Txn(true)
+		defer tx.Abort()
+
+		if err := g.InsertTxn(tx, "foo/in/the/house", 2); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+		if err := g.InsertTxn(tx, "foo/bar/baz", 5); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+		if err := g.InsertTxn(tx, "foo/bar/zoo", 8); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+		if err := g.InsertTxn(tx, "some/other/path", 9); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+		tx.Commit()
+	}()
+
+	// Check some prefixes.
+ func() { + tx := s.db.Txn(false) + defer tx.Abort() + + if idx, err := g.GetMaxIndexTxn(tx, "foo"); idx != 8 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + if idx, err := g.GetMaxIndexTxn(tx, "foo/in/the/house"); idx != 2 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + if idx, err := g.GetMaxIndexTxn(tx, "foo/bar/baz"); idx != 5 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + if idx, err := g.GetMaxIndexTxn(tx, "foo/bar/zoo"); idx != 8 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + if idx, err := g.GetMaxIndexTxn(tx, "some/other/path"); idx != 9 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + if idx, err := g.GetMaxIndexTxn(tx, ""); idx != 9 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + if idx, err := g.GetMaxIndexTxn(tx, "nope"); idx != 0 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + }() + + // Reap some tombstones. + func() { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := g.ReapTxn(tx, 6); err != nil { + t.Fatalf("err: %s", err) + } + tx.Commit() + }() + + // Check prefixes to see that the reap took effect at the right index. + func() { + tx := s.db.Txn(false) + defer tx.Abort() + + if idx, err := g.GetMaxIndexTxn(tx, "foo"); idx != 8 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + if idx, err := g.GetMaxIndexTxn(tx, "foo/in/the/house"); idx != 0 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + if idx, err := g.GetMaxIndexTxn(tx, "foo/bar/baz"); idx != 0 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + if idx, err := g.GetMaxIndexTxn(tx, "foo/bar/zoo"); idx != 8 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + if idx, err := g.GetMaxIndexTxn(tx, "some/other/path"); idx != 9 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + if idx, err := g.GetMaxIndexTxn(tx, ""); idx != 9 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + if idx, err := g.GetMaxIndexTxn(tx, "nope"); idx != 0 || err != nil { + t.Fatalf("bad: %d (%s)", idx, err) + } + }() +} + +func TestGraveyard_GC_Trigger(t *testing.T) { + // Set up a fast-expiring GC. + ttl, granularity := 100*time.Millisecond, 20*time.Millisecond + gc, err := NewTombstoneGC(ttl, granularity) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Make a new graveyard and assign the GC. + g := NewGraveyard(gc) + gc.SetEnabled(true) + + // Make sure there's nothing already expiring. + if gc.PendingExpiration() { + t.Fatalf("should not have any expiring items") + } + + // Create a tombstone but abort the transaction, this should not trigger + // GC. + s := testStateStore(t) + func() { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := g.InsertTxn(tx, "foo/in/the/house", 2); err != nil { + t.Fatalf("err: %s", err) + } + }() + + // Make sure there's nothing already expiring. + if gc.PendingExpiration() { + t.Fatalf("should not have any expiring items") + } + + // Now commit. + func() { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := g.InsertTxn(tx, "foo/in/the/house", 2); err != nil { + t.Fatalf("err: %s", err) + } + tx.Commit() + }() + + // Make sure the GC got hinted. + if !gc.PendingExpiration() { + t.Fatalf("should have a pending expiration") + } + + // Make sure the index looks good. 
+ select { + case idx := <-gc.ExpireCh(): + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + case <-time.After(2 * ttl): + t.Fatalf("should have gotten an expire notice") + } +} + +func TestGraveyard_Snapshot_Restore(t *testing.T) { + g := NewGraveyard(nil) + + // Make a donor state store to steal its database, all prepared for + // tombstones. + s := testStateStore(t) + + // Create some tombstones. + func() { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := g.InsertTxn(tx, "foo/in/the/house", 2); err != nil { + t.Fatalf("err: %s", err) + } + if err := g.InsertTxn(tx, "foo/bar/baz", 5); err != nil { + t.Fatalf("err: %s", err) + } + if err := g.InsertTxn(tx, "foo/bar/zoo", 8); err != nil { + t.Fatalf("err: %s", err) + } + if err := g.InsertTxn(tx, "some/other/path", 9); err != nil { + t.Fatalf("err: %s", err) + } + tx.Commit() + }() + + // Verify the index was set correctly. + if idx := s.maxIndex("tombstones"); idx != 9 { + t.Fatalf("bad index: %d", idx) + } + + // Dump them as if we are doing a snapshot. + dump := func() []*Tombstone { + tx := s.db.Txn(false) + defer tx.Abort() + + iter, err := g.DumpTxn(tx) + if err != nil { + t.Fatalf("err: %s", err) + } + var dump []*Tombstone + for ti := iter.Next(); ti != nil; ti = iter.Next() { + dump = append(dump, ti.(*Tombstone)) + } + return dump + }() + + // Verify the dump, which should be ordered by key. + expected := []*Tombstone{ + &Tombstone{Key: "foo/bar/baz", Index: 5}, + &Tombstone{Key: "foo/bar/zoo", Index: 8}, + &Tombstone{Key: "foo/in/the/house", Index: 2}, + &Tombstone{Key: "some/other/path", Index: 9}, + } + if !reflect.DeepEqual(dump, expected) { + t.Fatalf("bad: %v", dump) + } + + // Make another state store and restore from the dump. + func() { + s := testStateStore(t) + func() { + tx := s.db.Txn(true) + defer tx.Abort() + + for _, stone := range dump { + if err := g.RestoreTxn(tx, stone); err != nil { + t.Fatalf("err: %s", err) + } + } + tx.Commit() + }() + + // Verify that the restore works. + if idx := s.maxIndex("tombstones"); idx != 9 { + t.Fatalf("bad index: %d", idx) + } + + dump := func() []*Tombstone { + tx := s.db.Txn(false) + defer tx.Abort() + + iter, err := g.DumpTxn(tx) + if err != nil { + t.Fatalf("err: %s", err) + } + var dump []*Tombstone + for ti := iter.Next(); ti != nil; ti = iter.Next() { + dump = append(dump, ti.(*Tombstone)) + } + return dump + }() + if !reflect.DeepEqual(dump, expected) { + t.Fatalf("bad: %v", dump) + } + }() +} diff --git a/consul/notify.go b/consul/state/notify.go similarity index 98% rename from consul/notify.go rename to consul/state/notify.go index 2fe5acbe2b79..3b991a656ab7 100644 --- a/consul/notify.go +++ b/consul/state/notify.go @@ -1,4 +1,4 @@ -package consul +package state import ( "sync" diff --git a/consul/notify_test.go b/consul/state/notify_test.go similarity index 98% rename from consul/notify_test.go rename to consul/state/notify_test.go index 2133e9b3125c..34c14f46dbbf 100644 --- a/consul/notify_test.go +++ b/consul/state/notify_test.go @@ -1,4 +1,4 @@ -package consul +package state import ( "testing" diff --git a/consul/state/schema.go b/consul/state/schema.go new file mode 100644 index 000000000000..ac58f15386f6 --- /dev/null +++ b/consul/state/schema.go @@ -0,0 +1,347 @@ +package state + +import ( + "fmt" + + "github.com/hashicorp/go-memdb" +) + +// schemaFn is an interface function used to create and return +// new memdb schema structs for constructing an in-memory db. 
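+//
+// A minimal sketch of a conforming function (the "example" table and its
+// indexed field are hypothetical):
+//
+//	func exampleTableSchema() *memdb.TableSchema {
+//		return &memdb.TableSchema{
+//			Name: "example",
+//			Indexes: map[string]*memdb.IndexSchema{
+//				"id": &memdb.IndexSchema{
+//					Name:    "id",
+//					Unique:  true,
+//					Indexer: &memdb.StringFieldIndex{Field: "ID"},
+//				},
+//			},
+//		}
+//	}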
+type schemaFn func() *memdb.TableSchema + +// stateStoreSchema is used to return the combined schema for +// the state store. +func stateStoreSchema() *memdb.DBSchema { + // Create the root DB schema + db := &memdb.DBSchema{ + Tables: make(map[string]*memdb.TableSchema), + } + + // Collect the needed schemas + schemas := []schemaFn{ + indexTableSchema, + nodesTableSchema, + servicesTableSchema, + checksTableSchema, + kvsTableSchema, + tombstonesTableSchema, + sessionsTableSchema, + sessionChecksTableSchema, + aclsTableSchema, + } + + // Add the tables to the root schema + for _, fn := range schemas { + schema := fn() + if _, ok := db.Tables[schema.Name]; ok { + panic(fmt.Sprintf("duplicate table name: %s", schema.Name)) + } + db.Tables[schema.Name] = schema + } + return db +} + +// indexTableSchema returns a new table schema used for +// tracking various indexes for the Raft log. +func indexTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "index", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Key", + Lowercase: true, + }, + }, + }, + } +} + +// nodesTableSchema returns a new table schema used for +// storing node information. +func nodesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "nodes", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + }, + }, + } +} + +// servicesTableSchema returns a new TableSchema used to +// store information about services. +func servicesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "services", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + &memdb.StringFieldIndex{ + Field: "ServiceID", + Lowercase: true, + }, + }, + }, + }, + "node": &memdb.IndexSchema{ + Name: "node", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + }, + "service": &memdb.IndexSchema{ + Name: "service", + AllowMissing: true, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "ServiceName", + Lowercase: true, + }, + }, + }, + } +} + +// checksTableSchema returns a new table schema used for +// storing and indexing health check information. Health +// checks have a number of different attributes we want to +// filter by, so this table is a bit more complex. 
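+//
+// For example, the compound "node_service" index defined below lets all
+// checks for a single node/service pair be fetched in one query (the IDs
+// here are hypothetical):
+//
+//	checks, err := tx.Get("checks", "node_service", "node1", "redis")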
+func checksTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "checks", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + &memdb.StringFieldIndex{ + Field: "CheckID", + Lowercase: true, + }, + }, + }, + }, + "status": &memdb.IndexSchema{ + Name: "status", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "Status", + Lowercase: false, + }, + }, + "service": &memdb.IndexSchema{ + Name: "service", + AllowMissing: true, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "ServiceName", + Lowercase: true, + }, + }, + "node": &memdb.IndexSchema{ + Name: "node", + AllowMissing: true, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + }, + "node_service": &memdb.IndexSchema{ + Name: "node_service", + AllowMissing: true, + Unique: false, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + &memdb.StringFieldIndex{ + Field: "ServiceID", + Lowercase: true, + }, + }, + }, + }, + }, + } +} + +// kvsTableSchema returns a new table schema used for storing +// key/value data from consul's kv store. +func kvsTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "kvs", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Key", + Lowercase: false, + }, + }, + "session": &memdb.IndexSchema{ + Name: "session", + AllowMissing: true, + Unique: false, + Indexer: &memdb.UUIDFieldIndex{ + Field: "Session", + }, + }, + }, + } +} + +// tombstonesTableSchema returns a new table schema used for +// storing tombstones during KV delete operations to prevent +// the index from sliding backwards. +func tombstonesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "tombstones", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Key", + Lowercase: false, + }, + }, + }, + } +} + +// sessionsTableSchema returns a new TableSchema used for +// storing session information. +func sessionsTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "sessions", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.UUIDFieldIndex{ + Field: "ID", + }, + }, + "node": &memdb.IndexSchema{ + Name: "node", + AllowMissing: true, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "Node", + Lowercase: true, + }, + }, + }, + } +} + +// sessionChecksTableSchema returns a new table schema used +// for storing session checks. 
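+// Each entry maps a (node, check) pair back to the session that registered
+// it, which is how a check that turns critical can invalidate the sessions
+// that depend on it (see the "session_checks" lookups in the state store).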
+func sessionChecksTableSchema() *memdb.TableSchema {
+	return &memdb.TableSchema{
+		Name: "session_checks",
+		Indexes: map[string]*memdb.IndexSchema{
+			"id": &memdb.IndexSchema{
+				Name:         "id",
+				AllowMissing: false,
+				Unique:       true,
+				Indexer: &memdb.CompoundIndex{
+					Indexes: []memdb.Indexer{
+						&memdb.StringFieldIndex{
+							Field:     "Node",
+							Lowercase: true,
+						},
+						&memdb.StringFieldIndex{
+							Field:     "CheckID",
+							Lowercase: true,
+						},
+						&memdb.UUIDFieldIndex{
+							Field: "Session",
+						},
+					},
+				},
+			},
+			"node_check": &memdb.IndexSchema{
+				Name:         "node_check",
+				AllowMissing: false,
+				Unique:       false,
+				Indexer: &memdb.CompoundIndex{
+					Indexes: []memdb.Indexer{
+						&memdb.StringFieldIndex{
+							Field:     "Node",
+							Lowercase: true,
+						},
+						&memdb.StringFieldIndex{
+							Field:     "CheckID",
+							Lowercase: true,
+						},
+					},
+				},
+			},
+			"session": &memdb.IndexSchema{
+				Name:         "session",
+				AllowMissing: false,
+				Unique:       false,
+				Indexer: &memdb.UUIDFieldIndex{
+					Field: "Session",
+				},
+			},
+		},
+	}
+}
+
+// aclsTableSchema returns a new table schema used for
+// storing ACL information.
+func aclsTableSchema() *memdb.TableSchema {
+	return &memdb.TableSchema{
+		Name: "acls",
+		Indexes: map[string]*memdb.IndexSchema{
+			"id": &memdb.IndexSchema{
+				Name:         "id",
+				AllowMissing: false,
+				Unique:       true,
+				Indexer: &memdb.StringFieldIndex{
+					Field:     "ID",
+					Lowercase: false,
+				},
+			},
+		},
+	}
+}
diff --git a/consul/state/schema_test.go b/consul/state/schema_test.go
new file mode 100644
index 000000000000..b96a027c2e71
--- /dev/null
+++ b/consul/state/schema_test.go
@@ -0,0 +1,17 @@
+package state
+
+import (
+	"testing"
+
+	"github.com/hashicorp/go-memdb"
+)
+
+func TestStateStore_Schema(t *testing.T) {
+	// First call the schema creation
+	schema := stateStoreSchema()
+
+	// Try to initialize a new memdb using the schema
+	if _, err := memdb.NewMemDB(schema); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+}
diff --git a/consul/state/state_store.go b/consul/state/state_store.go
new file mode 100644
index 000000000000..402cf64cec56
--- /dev/null
+++ b/consul/state/state_store.go
@@ -0,0 +1,2233 @@
+package state
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+	"time"
+
+	"github.com/hashicorp/consul/consul/structs"
+	"github.com/hashicorp/go-memdb"
+)
+
+var (
+	// ErrMissingNode is the error returned when trying an operation
+	// which requires a node registration but none exists.
+	ErrMissingNode = errors.New("Missing node registration")
+
+	// ErrMissingService is the error we return if trying an
+	// operation which requires a service but none exists.
+	ErrMissingService = errors.New("Missing service registration")
+
+	// ErrMissingSessionID is returned when a session registration
+	// is attempted with an empty session ID.
+	ErrMissingSessionID = errors.New("Missing session ID")
+
+	// ErrMissingACLID is returned when an ACL set is called on
+	// an ACL with an empty ID.
+	ErrMissingACLID = errors.New("Missing ACL ID")
+)
+
+// StateStore is where we store all of Consul's state, including
+// records of node registrations, services, checks, key/value
+// pairs and more. The DB is entirely in-memory and is constructed
+// from the Raft log through the FSM.
+type StateStore struct {
+	schema *memdb.DBSchema
+	db     *memdb.MemDB
+
+	// tableWatches holds all the full table watches, indexed by table name.
+	tableWatches map[string]*FullTableWatch
+
+	// kvsWatch holds the special prefix watch for the key value store.
+	kvsWatch *PrefixWatch
+
+	// kvsGraveyard manages tombstones for the key value store.
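+	// Tombstones keep the indexes of prefix queries from sliding
+	// backwards after deletes; see Graveyard and the tombstones table
+	// schema above.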
+	kvsGraveyard *Graveyard
+
+	// lockDelay holds expiration times for locks associated with keys.
+	lockDelay *Delay
+}
+
+// StateSnapshot is used to provide a point-in-time snapshot. It
+// works by starting a read transaction against the whole state store.
+type StateSnapshot struct {
+	store     *StateStore
+	tx        *memdb.Txn
+	lastIndex uint64
+}
+
+// StateRestore is used to efficiently manage restoring a large amount of
+// data to a state store.
+type StateRestore struct {
+	store   *StateStore
+	tx      *memdb.Txn
+	watches *DumbWatchManager
+}
+
+// IndexEntry keeps a record of the last index per-table.
+type IndexEntry struct {
+	Key   string
+	Value uint64
+}
+
+// sessionCheck is used to create a many-to-one table such that
+// each check registered by a session can be mapped back to the
+// session table. This is only used internally in the state
+// store and thus it is not exported.
+type sessionCheck struct {
+	Node    string
+	CheckID string
+	Session string
+}
+
+// NewStateStore creates a new in-memory state storage layer.
+func NewStateStore(gc *TombstoneGC) (*StateStore, error) {
+	// Create the in-memory DB.
+	schema := stateStoreSchema()
+	db, err := memdb.NewMemDB(schema)
+	if err != nil {
+		return nil, fmt.Errorf("Failed setting up state store: %s", err)
+	}
+
+	// Build up the all-table watches.
+	tableWatches := make(map[string]*FullTableWatch)
+	for table, _ := range schema.Tables {
+		if table == "kvs" || table == "tombstones" {
+			continue
+		}
+
+		tableWatches[table] = NewFullTableWatch()
+	}
+
+	// Create and return the state store.
+	s := &StateStore{
+		schema:       schema,
+		db:           db,
+		tableWatches: tableWatches,
+		kvsWatch:     NewPrefixWatch(),
+		kvsGraveyard: NewGraveyard(gc),
+		lockDelay:    NewDelay(),
+	}
+	return s, nil
+}
+
+// Snapshot is used to create a point-in-time snapshot of the entire db.
+func (s *StateStore) Snapshot() *StateSnapshot {
+	tx := s.db.Txn(false)
+
+	var tables []string
+	for table, _ := range s.schema.Tables {
+		tables = append(tables, table)
+	}
+	idx := maxIndexTxn(tx, tables...)
+
+	return &StateSnapshot{s, tx, idx}
+}
+
+// LastIndex returns the last index that affects the snapshotted data.
+func (s *StateSnapshot) LastIndex() uint64 {
+	return s.lastIndex
+}
+
+// Close performs cleanup of a state snapshot.
+func (s *StateSnapshot) Close() {
+	s.tx.Abort()
+}
+
+// Nodes is used to pull the full list of nodes for use during snapshots.
+func (s *StateSnapshot) Nodes() (memdb.ResultIterator, error) {
+	iter, err := s.tx.Get("nodes", "id")
+	if err != nil {
+		return nil, err
+	}
+	return iter, nil
+}
+
+// Services is used to pull the full list of services for a given node for use
+// during snapshots.
+func (s *StateSnapshot) Services(node string) (memdb.ResultIterator, error) {
+	iter, err := s.tx.Get("services", "node", node)
+	if err != nil {
+		return nil, err
+	}
+	return iter, nil
+}
+
+// Checks is used to pull the full list of checks for a given node for use
+// during snapshots.
+func (s *StateSnapshot) Checks(node string) (memdb.ResultIterator, error) {
+	iter, err := s.tx.Get("checks", "node", node)
+	if err != nil {
+		return nil, err
+	}
+	return iter, nil
+}
+
+// KVs is used to pull the full list of KVS entries for use during snapshots.
+func (s *StateSnapshot) KVs() (memdb.ResultIterator, error) {
+	iter, err := s.tx.Get("kvs", "id_prefix")
+	if err != nil {
+		return nil, err
+	}
+	return iter, nil
+}
+
+// Tombstones is used to pull all the tombstones from the graveyard.
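+//
+// A snapshot consumer might drain it as follows (a sketch with error
+// handling elided; "snap" is a hypothetical *StateSnapshot):
+//
+//	iter, _ := snap.Tombstones()
+//	for t := iter.Next(); t != nil; t = iter.Next() {
+//		stone := t.(*Tombstone)
+//		// persist stone.Key and stone.Index
+//	}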
+func (s *StateSnapshot) Tombstones() (memdb.ResultIterator, error) { + return s.store.kvsGraveyard.DumpTxn(s.tx) +} + +// Sessions is used to pull the full list of sessions for use during snapshots. +func (s *StateSnapshot) Sessions() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("sessions", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +// ACLs is used to pull all the ACLs from the snapshot. +func (s *StateSnapshot) ACLs() (memdb.ResultIterator, error) { + iter, err := s.tx.Get("acls", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +// Restore is used to efficiently manage restoring a large amount of data into +// the state store. It works by doing all the restores inside of a single +// transaction. +func (s *StateStore) Restore() *StateRestore { + tx := s.db.Txn(true) + watches := NewDumbWatchManager(s.tableWatches) + return &StateRestore{s, tx, watches} +} + +// Abort abandons the changes made by a restore. This or Commit should always be +// called. +func (s *StateRestore) Abort() { + s.tx.Abort() +} + +// Commit commits the changes made by a restore. This or Abort should always be +// called. +func (s *StateRestore) Commit() { + // Fire off a single KVS watch instead of a zillion prefix ones, and use + // a dumb watch manager to single-fire all the full table watches. + s.tx.Defer(func() { s.store.kvsWatch.Notify("", true) }) + s.tx.Defer(func() { s.watches.Notify() }) + + s.tx.Commit() +} + +// Registration is used to make sure a node, service, and check registration is +// performed within a single transaction to avoid race conditions on state +// updates. +func (s *StateRestore) Registration(idx uint64, req *structs.RegisterRequest) error { + if err := s.store.ensureRegistrationTxn(s.tx, idx, s.watches, req); err != nil { + return err + } + return nil +} + +// KVS is used when restoring from a snapshot. Use KVSSet for general inserts. +func (s *StateRestore) KVS(entry *structs.DirEntry) error { + if err := s.tx.Insert("kvs", entry); err != nil { + return fmt.Errorf("failed inserting kvs entry: %s", err) + } + + if err := indexUpdateMaxTxn(s.tx, entry.ModifyIndex, "kvs"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + // We have a single top-level KVS watch trigger instead of doing + // tons of prefix watches. + return nil +} + +// Tombstone is used when restoring from a snapshot. For general inserts, use +// Graveyard.InsertTxn. +func (s *StateRestore) Tombstone(stone *Tombstone) error { + if err := s.store.kvsGraveyard.RestoreTxn(s.tx, stone); err != nil { + return fmt.Errorf("failed restoring tombstone: %s", err) + } + return nil +} + +// Session is used when restoring from a snapshot. For general inserts, use +// SessionCreate. +func (s *StateRestore) Session(sess *structs.Session) error { + // Insert the session. + if err := s.tx.Insert("sessions", sess); err != nil { + return fmt.Errorf("failed inserting session: %s", err) + } + + // Insert the check mappings. + for _, checkID := range sess.Checks { + mapping := &sessionCheck{ + Node: sess.Node, + CheckID: checkID, + Session: sess.ID, + } + if err := s.tx.Insert("session_checks", mapping); err != nil { + return fmt.Errorf("failed inserting session check mapping: %s", err) + } + } + + // Update the index. + if err := indexUpdateMaxTxn(s.tx, sess.ModifyIndex, "sessions"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + s.watches.Arm("sessions") + return nil +} + +// ACL is used when restoring from a snapshot. 
For general inserts, use ACLSet. +func (s *StateRestore) ACL(acl *structs.ACL) error { + if err := s.tx.Insert("acls", acl); err != nil { + return fmt.Errorf("failed restoring acl: %s", err) + } + + if err := indexUpdateMaxTxn(s.tx, acl.ModifyIndex, "acls"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + s.watches.Arm("acls") + return nil +} + +// maxIndex is a helper used to retrieve the highest known index +// amongst a set of tables in the db. +func (s *StateStore) maxIndex(tables ...string) uint64 { + tx := s.db.Txn(false) + defer tx.Abort() + return maxIndexTxn(tx, tables...) +} + +// maxIndexTxn is a helper used to retrieve the highest known index +// amongst a set of tables in the db. +func maxIndexTxn(tx *memdb.Txn, tables ...string) uint64 { + var lindex uint64 + for _, table := range tables { + ti, err := tx.First("index", "id", table) + if err != nil { + panic(fmt.Sprintf("unknown index: %s err: %s", table, err)) + } + if idx, ok := ti.(*IndexEntry); ok && idx.Value > lindex { + lindex = idx.Value + } + } + return lindex +} + +// indexUpdateMaxTxn is used when restoring entries and sets the table's index to +// the given idx only if it's greater than the current index. +func indexUpdateMaxTxn(tx *memdb.Txn, idx uint64, table string) error { + ti, err := tx.First("index", "id", table) + if err != nil { + return fmt.Errorf("failed to retrieve existing index: %s", err) + } + + // Always take the first update, otherwise do the > check. + if ti == nil { + if err := tx.Insert("index", &IndexEntry{table, idx}); err != nil { + return fmt.Errorf("failed updating index %s", err) + } + } else if cur, ok := ti.(*IndexEntry); ok && idx > cur.Value { + if err := tx.Insert("index", &IndexEntry{table, idx}); err != nil { + return fmt.Errorf("failed updating index %s", err) + } + } + + return nil +} + +// ReapTombstones is used to delete all the tombstones with an index +// less than or equal to the given index. This is used to prevent +// unbounded storage growth of the tombstones. +func (s *StateStore) ReapTombstones(index uint64) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.kvsGraveyard.ReapTxn(tx, index); err != nil { + return fmt.Errorf("failed to reap kvs tombstones: %s", err) + } + + tx.Commit() + return nil +} + +// getWatchTables returns the list of tables that should be watched and used for +// max index calculations for the given query method. This is used for all +// methods except for KVS. This will panic if the method is unknown. +func (s *StateStore) getWatchTables(method string) []string { + switch method { + case "GetNode", "Nodes": + return []string{"nodes"} + case "Services": + return []string{"services"} + case "ServiceNodes", "NodeServices": + return []string{"nodes", "services"} + case "NodeChecks", "ServiceChecks", "ChecksInState": + return []string{"checks"} + case "CheckServiceNodes", "NodeInfo", "NodeDump": + return []string{"nodes", "services", "checks"} + case "SessionGet", "SessionList", "NodeSessions": + return []string{"sessions"} + case "ACLGet", "ACLList": + return []string{"acls"} + } + + panic(fmt.Sprintf("Unknown method %s", method)) +} + +// getTableWatch returns a full table watch for the given table. This will panic +// if the table doesn't have a full table watch. 
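+//
+// Callers normally arrive here through GetQueryWatch below: the blocking
+// RPC endpoints pass a query method name such as "SessionGet" (see the
+// session_endpoint.go changes above) and block on the returned Watch.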
+func (s *StateStore) getTableWatch(table string) Watch { + if watch, ok := s.tableWatches[table]; ok { + return watch + } + + panic(fmt.Sprintf("Unknown watch for table %s", table)) +} + +// GetQueryWatch returns a watch for the given query method. This is +// used for all methods except for KV; you should call GetKVSWatch instead. +// This will panic if the method is unknown. +func (s *StateStore) GetQueryWatch(method string) Watch { + tables := s.getWatchTables(method) + if len(tables) == 1 { + return s.getTableWatch(tables[0]) + } + + var watches []Watch + for _, table := range tables { + watches = append(watches, s.getTableWatch(table)) + } + return NewMultiWatch(watches...) +} + +// GetKVSWatch returns a watch for the given prefix in the key value store. +func (s *StateStore) GetKVSWatch(prefix string) Watch { + return s.kvsWatch.GetSubwatch(prefix) +} + +// EnsureRegistration is used to make sure a node, service, and check +// registration is performed within a single transaction to avoid race +// conditions on state updates. +func (s *StateStore) EnsureRegistration(idx uint64, req *structs.RegisterRequest) error { + tx := s.db.Txn(true) + defer tx.Abort() + + watches := NewDumbWatchManager(s.tableWatches) + if err := s.ensureRegistrationTxn(tx, idx, watches, req); err != nil { + return err + } + + tx.Defer(func() { watches.Notify() }) + tx.Commit() + return nil +} + +// ensureRegistrationTxn is used to make sure a node, service, and check +// registration is performed within a single transaction to avoid race +// conditions on state updates. +func (s *StateStore) ensureRegistrationTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, + req *structs.RegisterRequest) error { + // Add the node. + node := &structs.Node{Node: req.Node, Address: req.Address} + if err := s.ensureNodeTxn(tx, idx, watches, node); err != nil { + return fmt.Errorf("failed inserting node: %s", err) + } + + // Add the service, if any. + if req.Service != nil { + if err := s.ensureServiceTxn(tx, idx, watches, req.Node, req.Service); err != nil { + return fmt.Errorf("failed inserting service: %s", err) + } + } + + // Add the checks, if any. + if req.Check != nil { + if err := s.ensureCheckTxn(tx, idx, watches, req.Check); err != nil { + return fmt.Errorf("failed inserting check: %s", err) + } + } + for _, check := range req.Checks { + if err := s.ensureCheckTxn(tx, idx, watches, check); err != nil { + return fmt.Errorf("failed inserting check: %s", err) + } + } + + return nil +} + +// EnsureNode is used to upsert node registration or modification. +func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the node upsert + watches := NewDumbWatchManager(s.tableWatches) + if err := s.ensureNodeTxn(tx, idx, watches, node); err != nil { + return err + } + + tx.Defer(func() { watches.Notify() }) + tx.Commit() + return nil +} + +// ensureNodeTxn is the inner function called to actually create a node +// registration or modify an existing one in the state store. It allows +// passing in a memdb transaction so it may be part of a larger txn. 
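+//
+// Index semantics in brief: a brand-new node gets CreateIndex == ModifyIndex
+// == idx, while re-registering an existing node preserves its CreateIndex
+// and bumps only ModifyIndex, as the lookup below shows.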
+func (s *StateStore) ensureNodeTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, + node *structs.Node) error { + // Check for an existing node + existing, err := tx.First("nodes", "id", node.Node) + if err != nil { + return fmt.Errorf("node lookup failed: %s", err) + } + + // Get the indexes + if existing != nil { + node.CreateIndex = existing.(*structs.Node).CreateIndex + node.ModifyIndex = idx + } else { + node.CreateIndex = idx + node.ModifyIndex = idx + } + + // Insert the node and update the index + if err := tx.Insert("nodes", node); err != nil { + return fmt.Errorf("failed inserting node: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"nodes", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + watches.Arm("nodes") + return nil +} + +// GetNode is used to retrieve a node registration by node ID. +func (s *StateStore) GetNode(id string) (uint64, *structs.Node, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("GetNode")...) + + // Retrieve the node from the state store + node, err := tx.First("nodes", "id", id) + if err != nil { + return 0, nil, fmt.Errorf("node lookup failed: %s", err) + } + if node != nil { + return idx, node.(*structs.Node), nil + } + return idx, nil, nil +} + +// Nodes is used to return all of the known nodes. +func (s *StateStore) Nodes() (uint64, structs.Nodes, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("Nodes")...) + + // Retrieve all of the nodes + nodes, err := tx.Get("nodes", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed nodes lookup: %s", err) + } + + // Create and return the nodes list. + var results structs.Nodes + for node := nodes.Next(); node != nil; node = nodes.Next() { + results = append(results, node.(*structs.Node)) + } + return idx, results, nil +} + +// DeleteNode is used to delete a given node by its ID. +func (s *StateStore) DeleteNode(idx uint64, nodeID string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the node deletion. + if err := s.deleteNodeTxn(tx, idx, nodeID); err != nil { + return err + } + + tx.Commit() + return nil +} + +// deleteNodeTxn is the inner method used for removing a node from +// the store within a given transaction. +func (s *StateStore) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeID string) error { + // Look up the node. + node, err := tx.First("nodes", "id", nodeID) + if err != nil { + return fmt.Errorf("node lookup failed: %s", err) + } + if node == nil { + return nil + } + + // Use a watch manager since the inner functions can perform multiple + // ops per table. + watches := NewDumbWatchManager(s.tableWatches) + watches.Arm("nodes") + + // Delete all services associated with the node and update the service index. + services, err := tx.Get("services", "node", nodeID) + if err != nil { + return fmt.Errorf("failed service lookup: %s", err) + } + var sids []string + for service := services.Next(); service != nil; service = services.Next() { + sids = append(sids, service.(*structs.ServiceNode).ServiceID) + } + + // Do the delete in a separate loop so we don't trash the iterator. + for _, sid := range sids { + if err := s.deleteServiceTxn(tx, idx, watches, nodeID, sid); err != nil { + return err + } + } + + // Delete all checks associated with the node. This will invalidate + // sessions as necessary. 
+ checks, err := tx.Get("checks", "node", nodeID) + if err != nil { + return fmt.Errorf("failed check lookup: %s", err) + } + var cids []string + for check := checks.Next(); check != nil; check = checks.Next() { + cids = append(cids, check.(*structs.HealthCheck).CheckID) + } + + // Do the delete in a separate loop so we don't trash the iterator. + for _, cid := range cids { + if err := s.deleteCheckTxn(tx, idx, watches, nodeID, cid); err != nil { + return err + } + } + + // Delete the node and update the index. + if err := tx.Delete("nodes", node); err != nil { + return fmt.Errorf("failed deleting node: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"nodes", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + // Invalidate any sessions for this node. + sessions, err := tx.Get("sessions", "node", nodeID) + if err != nil { + return fmt.Errorf("failed session lookup: %s", err) + } + var ids []string + for sess := sessions.Next(); sess != nil; sess = sessions.Next() { + ids = append(ids, sess.(*structs.Session).ID) + } + + // Do the delete in a separate loop so we don't trash the iterator. + for _, id := range ids { + if err := s.deleteSessionTxn(tx, idx, watches, id); err != nil { + return fmt.Errorf("failed session delete: %s", err) + } + } + + tx.Defer(func() { watches.Notify() }) + return nil +} + +// EnsureService is called to upsert creation of a given NodeService. +func (s *StateStore) EnsureService(idx uint64, node string, svc *structs.NodeService) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the service registration upsert + watches := NewDumbWatchManager(s.tableWatches) + if err := s.ensureServiceTxn(tx, idx, watches, node, svc); err != nil { + return err + } + + tx.Defer(func() { watches.Notify() }) + tx.Commit() + return nil +} + +// ensureServiceTxn is used to upsert a service registration within an +// existing memdb transaction. +func (s *StateStore) ensureServiceTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, + node string, svc *structs.NodeService) error { + // Check for existing service + existing, err := tx.First("services", "id", node, svc.ID) + if err != nil { + return fmt.Errorf("failed service lookup: %s", err) + } + + // Create the service node entry and populate the indexes. We leave the + // address blank and fill that in on the way out during queries. + entry := svc.ToServiceNode(node, "") + if existing != nil { + entry.CreateIndex = existing.(*structs.ServiceNode).CreateIndex + entry.ModifyIndex = idx + } else { + entry.CreateIndex = idx + entry.ModifyIndex = idx + } + + // Get the node + n, err := tx.First("nodes", "id", node) + if err != nil { + return fmt.Errorf("failed node lookup: %s", err) + } + if n == nil { + return ErrMissingNode + } + + // Insert the service and update the index + if err := tx.Insert("services", entry); err != nil { + return fmt.Errorf("failed inserting service: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"services", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + watches.Arm("services") + return nil +} + +// Services returns all services along with a list of associated tags. +func (s *StateStore) Services() (uint64, structs.Services, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("Services")...) + + // List all the services. 
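+	// (The final result is a map of service name to the union of its tags
+	// across all nodes, e.g. a hypothetical {"redis": ["primary"]}.)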
+ services, err := tx.Get("services", "id") + if err != nil { + return 0, nil, fmt.Errorf("failed querying services: %s", err) + } + + // Rip through the services and enumerate them and their unique set of + // tags. + unique := make(map[string]map[string]struct{}) + for service := services.Next(); service != nil; service = services.Next() { + svc := service.(*structs.ServiceNode) + tags, ok := unique[svc.ServiceName] + if !ok { + unique[svc.ServiceName] = make(map[string]struct{}) + tags = unique[svc.ServiceName] + } + for _, tag := range svc.ServiceTags { + tags[tag] = struct{}{} + } + } + + // Generate the output structure. + var results = make(structs.Services) + for service, tags := range unique { + results[service] = make([]string, 0) + for tag, _ := range tags { + results[service] = append(results[service], tag) + } + } + return idx, results, nil +} + +// ServiceNodes returns the nodes associated with a given service name. +func (s *StateStore) ServiceNodes(serviceName string) (uint64, structs.ServiceNodes, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("ServiceNodes")...) + + // List all the services. + services, err := tx.Get("services", "service", serviceName) + if err != nil { + return 0, nil, fmt.Errorf("failed service lookup: %s", err) + } + var results structs.ServiceNodes + for service := services.Next(); service != nil; service = services.Next() { + results = append(results, service.(*structs.ServiceNode)) + } + + // Fill in the address details. + results, err = s.parseServiceNodes(tx, results) + if err != nil { + return 0, nil, fmt.Errorf("failed parsing service nodes: %s", err) + } + return idx, results, nil +} + +// ServiceTagNodes returns the nodes associated with a given service, filtering +// out services that don't contain the given tag. +func (s *StateStore) ServiceTagNodes(service, tag string) (uint64, structs.ServiceNodes, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("ServiceNodes")...) + + // List all the services. + services, err := tx.Get("services", "service", service) + if err != nil { + return 0, nil, fmt.Errorf("failed service lookup: %s", err) + } + + // Gather all the services and apply the tag filter. + var results structs.ServiceNodes + for service := services.Next(); service != nil; service = services.Next() { + svc := service.(*structs.ServiceNode) + if !serviceTagFilter(svc, tag) { + results = append(results, svc) + } + } + + // Fill in the address details. + results, err = s.parseServiceNodes(tx, results) + if err != nil { + return 0, nil, fmt.Errorf("failed parsing service nodes: %s", err) + } + return idx, results, nil +} + +// serviceTagFilter returns true (should filter) if the given service node +// doesn't contain the given tag. +func serviceTagFilter(sn *structs.ServiceNode, tag string) bool { + tag = strings.ToLower(tag) + + // Look for the lower cased version of the tag. + for _, t := range sn.ServiceTags { + if strings.ToLower(t) == tag { + return false + } + } + + // If we didn't hit the tag above then we should filter. + return true +} + +// parseServiceNodes iterates over a services query and fills in the node details, +// returning a ServiceNodes slice. 
+func (s *StateStore) parseServiceNodes(tx *memdb.Txn, services structs.ServiceNodes) (structs.ServiceNodes, error) { + var results structs.ServiceNodes + for _, sn := range services { + // Note that we have to clone here because we don't want to + // modify the address field on the object in the database, + // which is what we are referencing. + s := sn.Clone() + + // Fill in the address of the node. + n, err := tx.First("nodes", "id", sn.Node) + if err != nil { + return nil, fmt.Errorf("failed node lookup: %s", err) + } + s.Address = n.(*structs.Node).Address + results = append(results, s) + } + return results, nil +} + +// NodeServices is used to query service registrations by node ID. +func (s *StateStore) NodeServices(nodeID string) (uint64, *structs.NodeServices, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + // Get the table index. + idx := maxIndexTxn(tx, s.getWatchTables("NodeServices")...) + + // Query the node + n, err := tx.First("nodes", "id", nodeID) + if err != nil { + return 0, nil, fmt.Errorf("node lookup failed: %s", err) + } + if n == nil { + return 0, nil, nil + } + node := n.(*structs.Node) + + // Read all of the services + services, err := tx.Get("services", "node", nodeID) + if err != nil { + return 0, nil, fmt.Errorf("failed querying services for node %q: %s", nodeID, err) + } + + // Initialize the node services struct + ns := &structs.NodeServices{ + Node: node, + Services: make(map[string]*structs.NodeService), + } + + // Add all of the services to the map. + for service := services.Next(); service != nil; service = services.Next() { + svc := service.(*structs.ServiceNode).ToNodeService() + ns.Services[svc.ID] = svc + } + + return idx, ns, nil +} + +// DeleteService is used to delete a given service associated with a node. +func (s *StateStore) DeleteService(idx uint64, nodeID, serviceID string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the service deletion + watches := NewDumbWatchManager(s.tableWatches) + if err := s.deleteServiceTxn(tx, idx, watches, nodeID, serviceID); err != nil { + return err + } + + tx.Defer(func() { watches.Notify() }) + tx.Commit() + return nil +} + +// deleteServiceTxn is the inner method called to remove a service +// registration within an existing transaction. +func (s *StateStore) deleteServiceTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, nodeID, serviceID string) error { + // Look up the service. + service, err := tx.First("services", "id", nodeID, serviceID) + if err != nil { + return fmt.Errorf("failed service lookup: %s", err) + } + if service == nil { + return nil + } + + // Delete any checks associated with the service. This will invalidate + // sessions as necessary. + checks, err := tx.Get("checks", "node_service", nodeID, serviceID) + if err != nil { + return fmt.Errorf("failed service check lookup: %s", err) + } + var cids []string + for check := checks.Next(); check != nil; check = checks.Next() { + cids = append(cids, check.(*structs.HealthCheck).CheckID) + } + + // Do the delete in a separate loop so we don't trash the iterator. + for _, cid := range cids { + if err := s.deleteCheckTxn(tx, idx, watches, nodeID, cid); err != nil { + return err + } + } + + // Update the index. 
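+	// The "checks" index is bumped here because the loop above may have
+	// deleted checks that were attached to this service.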
+	if err := tx.Insert("index", &IndexEntry{"checks", idx}); err != nil {
+		return fmt.Errorf("failed updating index: %s", err)
+	}
+
+	// Delete the service and update the index
+	if err := tx.Delete("services", service); err != nil {
+		return fmt.Errorf("failed deleting service: %s", err)
+	}
+	if err := tx.Insert("index", &IndexEntry{"services", idx}); err != nil {
+		return fmt.Errorf("failed updating index: %s", err)
+	}
+
+	watches.Arm("services")
+	return nil
+}
+
+// EnsureCheck is used to store a check registration in the db.
+func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error {
+	tx := s.db.Txn(true)
+	defer tx.Abort()
+
+	// Call the check registration
+	watches := NewDumbWatchManager(s.tableWatches)
+	if err := s.ensureCheckTxn(tx, idx, watches, hc); err != nil {
+		return err
+	}
+
+	tx.Defer(func() { watches.Notify() })
+	tx.Commit()
+	return nil
+}
+
+// ensureCheckTxn is used as the inner method to handle inserting
+// a health check into the state store. It ensures safety against inserting
+// checks with no matching node or service.
+func (s *StateStore) ensureCheckTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager,
+	hc *structs.HealthCheck) error {
+	// Check if we have an existing health check
+	existing, err := tx.First("checks", "id", hc.Node, hc.CheckID)
+	if err != nil {
+		return fmt.Errorf("failed health check lookup: %s", err)
+	}
+
+	// Set the indexes
+	if existing != nil {
+		hc.CreateIndex = existing.(*structs.HealthCheck).CreateIndex
+		hc.ModifyIndex = idx
+	} else {
+		hc.CreateIndex = idx
+		hc.ModifyIndex = idx
+	}
+
+	// Use the default check status if none was provided
+	if hc.Status == "" {
+		hc.Status = structs.HealthCritical
+	}
+
+	// Get the node
+	node, err := tx.First("nodes", "id", hc.Node)
+	if err != nil {
+		return fmt.Errorf("failed node lookup: %s", err)
+	}
+	if node == nil {
+		return ErrMissingNode
+	}
+
+	// If the check is associated with a service, check that we have
+	// a registration for the service.
+	if hc.ServiceID != "" {
+		service, err := tx.First("services", "id", hc.Node, hc.ServiceID)
+		if err != nil {
+			return fmt.Errorf("failed service lookup: %s", err)
+		}
+		if service == nil {
+			return ErrMissingService
+		}
+
+		// Copy in the service name
+		hc.ServiceName = service.(*structs.ServiceNode).ServiceName
+	}
+
+	// Delete any sessions for this check if the health is critical.
+	if hc.Status == structs.HealthCritical {
+		mappings, err := tx.Get("session_checks", "node_check", hc.Node, hc.CheckID)
+		if err != nil {
+			return fmt.Errorf("failed session checks lookup: %s", err)
+		}
+
+		var ids []string
+		for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() {
+			ids = append(ids, mapping.(*sessionCheck).Session)
+		}
+
+		// Delete the session in a separate loop so we don't trash the
+		// iterator.
+		watches := NewDumbWatchManager(s.tableWatches)
+		for _, id := range ids {
+			if err := s.deleteSessionTxn(tx, idx, watches, id); err != nil {
+				return fmt.Errorf("failed deleting session: %s", err)
+			}
+		}
+		tx.Defer(func() { watches.Notify() })
+	}
+
+	// Persist the check registration in the db.
+	if err := tx.Insert("checks", hc); err != nil {
+		return fmt.Errorf("failed inserting check: %s", err)
+	}
+	if err := tx.Insert("index", &IndexEntry{"checks", idx}); err != nil {
+		return fmt.Errorf("failed updating index: %s", err)
+	}
+
+	watches.Arm("checks")
+	return nil
+}
+
+// NodeChecks is used to retrieve checks associated with the
+// given node from the state store.
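+//
+// A minimal read sketch (hypothetical store variable and node name):
+//
+//	idx, checks, err := store.NodeChecks("node1")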
+func (s *StateStore) NodeChecks(nodeID string) (uint64, structs.HealthChecks, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, s.getWatchTables("NodeChecks")...)
+
+	// Return the checks.
+	checks, err := tx.Get("checks", "node", nodeID)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed check lookup: %s", err)
+	}
+	return s.parseChecks(idx, checks)
+}
+
+// ServiceChecks is used to get all checks associated with a
+// given service. Note that the query is performed against the
+// service _name_ rather than the service ID.
+func (s *StateStore) ServiceChecks(serviceName string) (uint64, structs.HealthChecks, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, s.getWatchTables("ServiceChecks")...)
+
+	// Return the checks.
+	checks, err := tx.Get("checks", "service", serviceName)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed check lookup: %s", err)
+	}
+	return s.parseChecks(idx, checks)
+}
+
+// ChecksInState is used to query the state store for all checks
+// which are in the provided state.
+func (s *StateStore) ChecksInState(state string) (uint64, structs.HealthChecks, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, s.getWatchTables("ChecksInState")...)
+
+	// Query all checks if HealthAny is passed
+	if state == structs.HealthAny {
+		checks, err := tx.Get("checks", "status")
+		if err != nil {
+			return 0, nil, fmt.Errorf("failed check lookup: %s", err)
+		}
+		return s.parseChecks(idx, checks)
+	}
+
+	// Any other state we need to query for explicitly
+	checks, err := tx.Get("checks", "status", state)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed check lookup: %s", err)
+	}
+	return s.parseChecks(idx, checks)
+}
+
+// parseChecks is a helper function used to deduplicate some
+// repetitive code for returning health checks.
+func (s *StateStore) parseChecks(idx uint64, iter memdb.ResultIterator) (uint64, structs.HealthChecks, error) {
+	// Gather the health checks and return them properly type-cast.
+	var results structs.HealthChecks
+	for check := iter.Next(); check != nil; check = iter.Next() {
+		results = append(results, check.(*structs.HealthCheck))
+	}
+	return idx, results, nil
+}
+
+// DeleteCheck is used to delete a health check registration.
+func (s *StateStore) DeleteCheck(idx uint64, node, id string) error {
+	tx := s.db.Txn(true)
+	defer tx.Abort()
+
+	// Call the check deletion
+	watches := NewDumbWatchManager(s.tableWatches)
+	if err := s.deleteCheckTxn(tx, idx, watches, node, id); err != nil {
+		return err
+	}
+
+	tx.Defer(func() { watches.Notify() })
+	tx.Commit()
+	return nil
+}
+
+// deleteCheckTxn is the inner method used to call a health
+// check deletion within an existing transaction.
+func (s *StateStore) deleteCheckTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, node, id string) error {
+	// Try to retrieve the existing health check.
+	hc, err := tx.First("checks", "id", node, id)
+	if err != nil {
+		return fmt.Errorf("check lookup failed: %s", err)
+	}
+	if hc == nil {
+		return nil
+	}
+
+	// Delete the check from the DB and update the index.
+	if err := tx.Delete("checks", hc); err != nil {
+		return fmt.Errorf("failed removing check: %s", err)
+	}
+	if err := tx.Insert("index", &IndexEntry{"checks", idx}); err != nil {
+		return fmt.Errorf("failed updating index: %s", err)
+	}
+
+	// Delete any sessions for this check.
+	mappings, err := tx.Get("session_checks", "node_check", node, id)
+	if err != nil {
+		return fmt.Errorf("failed session checks lookup: %s", err)
+	}
+	var ids []string
+	for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() {
+		ids = append(ids, mapping.(*sessionCheck).Session)
+	}
+
+	// Do the delete in a separate loop so we don't trash the iterator.
+	for _, id := range ids {
+		if err := s.deleteSessionTxn(tx, idx, watches, id); err != nil {
+			return fmt.Errorf("failed deleting session: %s", err)
+		}
+	}
+
+	watches.Arm("checks")
+	return nil
+}
+
+// CheckServiceNodes is used to query all nodes and checks for a given service.
+// The results are compounded into a CheckServiceNodes, and the index returned
+// is the maximum index observed over any node, check, or service in the result
+// set.
+func (s *StateStore) CheckServiceNodes(serviceName string) (uint64, structs.CheckServiceNodes, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, s.getWatchTables("CheckServiceNodes")...)
+
+	// Query the state store for the service.
+	services, err := tx.Get("services", "service", serviceName)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed service lookup: %s", err)
+	}
+
+	// Return the results.
+	var results structs.ServiceNodes
+	for service := services.Next(); service != nil; service = services.Next() {
+		results = append(results, service.(*structs.ServiceNode))
+	}
+	return s.parseCheckServiceNodes(tx, idx, results, err)
+}
+
+// CheckServiceTagNodes is used to query all nodes and checks for a given
+// service, filtering out services that don't contain the given tag. The results
+// are compounded into a CheckServiceNodes, and the index returned is the maximum
+// index observed over any node, check, or service in the result set.
+func (s *StateStore) CheckServiceTagNodes(serviceName, tag string) (uint64, structs.CheckServiceNodes, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, s.getWatchTables("CheckServiceNodes")...)
+
+	// Query the state store for the service.
+	services, err := tx.Get("services", "service", serviceName)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed service lookup: %s", err)
+	}
+
+	// Return the results, filtering by tag.
+	var results structs.ServiceNodes
+	for service := services.Next(); service != nil; service = services.Next() {
+		svc := service.(*structs.ServiceNode)
+		if !serviceTagFilter(svc, tag) {
+			results = append(results, svc)
+		}
+	}
+	return s.parseCheckServiceNodes(tx, idx, results, err)
+}
+
+// parseCheckServiceNodes is used to parse through a given set of services,
+// and query for an associated node and a set of checks. This is the inner
+// method used to return a rich set of results from a simpler query.
+func (s *StateStore) parseCheckServiceNodes(
+	tx *memdb.Txn, idx uint64, services structs.ServiceNodes,
+	err error) (uint64, structs.CheckServiceNodes, error) {
+	if err != nil {
+		return 0, nil, err
+	}
+
+	var results structs.CheckServiceNodes
+	for _, sn := range services {
+		// Retrieve the node.
+		n, err := tx.First("nodes", "id", sn.Node)
+		if err != nil {
+			return 0, nil, fmt.Errorf("failed node lookup: %s", err)
+		}
+		if n == nil {
+			return 0, nil, ErrMissingNode
+		}
+		node := n.(*structs.Node)
+
+		// We need to return the checks specific to the given service
+		// as well as the node itself.
Unfortunately, memdb won't let
+		// us use the index to do the latter query so we have to pull
+		// them all and filter.
+		var checks structs.HealthChecks
+		iter, err := tx.Get("checks", "node", sn.Node)
+		if err != nil {
+			return 0, nil, err
+		}
+		for check := iter.Next(); check != nil; check = iter.Next() {
+			hc := check.(*structs.HealthCheck)
+			if hc.ServiceID == "" || hc.ServiceID == sn.ServiceID {
+				checks = append(checks, hc)
+			}
+		}
+
+		// Append to the results.
+		results = append(results, structs.CheckServiceNode{
+			Node:    node,
+			Service: sn.ToNodeService(),
+			Checks:  checks,
+		})
+	}
+
+	return idx, results, nil
+}
+
+// NodeInfo is used to generate a dump of a single node. The dump includes
+// all services and checks which are registered against the node.
+func (s *StateStore) NodeInfo(node string) (uint64, structs.NodeDump, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, s.getWatchTables("NodeInfo")...)
+
+	// Query the node by the passed node
+	nodes, err := tx.Get("nodes", "id", node)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed node lookup: %s", err)
+	}
+	return s.parseNodes(tx, idx, nodes)
+}
+
+// NodeDump is used to generate a dump of all nodes. This call is expensive
+// as it has to query every node, service, and check. The response can also
+// be quite large since there is currently no filtering applied.
+func (s *StateStore) NodeDump() (uint64, structs.NodeDump, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, s.getWatchTables("NodeDump")...)
+
+	// Fetch all of the registered nodes
+	nodes, err := tx.Get("nodes", "id")
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed node lookup: %s", err)
+	}
+	return s.parseNodes(tx, idx, nodes)
+}
+
+// parseNodes takes an iterator over a set of nodes and returns a struct
+// containing the nodes along with all of their associated services
+// and/or health checks.
+func (s *StateStore) parseNodes(tx *memdb.Txn, idx uint64,
+	iter memdb.ResultIterator) (uint64, structs.NodeDump, error) {
+
+	var results structs.NodeDump
+	for n := iter.Next(); n != nil; n = iter.Next() {
+		node := n.(*structs.Node)
+
+		// Create the wrapped node
+		dump := &structs.NodeInfo{
+			Node:    node.Node,
+			Address: node.Address,
+		}
+
+		// Query the node services
+		services, err := tx.Get("services", "node", node.Node)
+		if err != nil {
+			return 0, nil, fmt.Errorf("failed services lookup: %s", err)
+		}
+		for service := services.Next(); service != nil; service = services.Next() {
+			ns := service.(*structs.ServiceNode).ToNodeService()
+			dump.Services = append(dump.Services, ns)
+		}
+
+		// Query the node checks
+		checks, err := tx.Get("checks", "node", node.Node)
+		if err != nil {
+			return 0, nil, fmt.Errorf("failed checks lookup: %s", err)
+		}
+		for check := checks.Next(); check != nil; check = checks.Next() {
+			hc := check.(*structs.HealthCheck)
+			dump.Checks = append(dump.Checks, hc)
+		}
+
+		// Add the result to the slice
+		results = append(results, dump)
+	}
+	return idx, results, nil
+}
+
+// KVSSet is used to store a key/value pair.
+func (s *StateStore) KVSSet(idx uint64, entry *structs.DirEntry) error {
+	tx := s.db.Txn(true)
+	defer tx.Abort()
+
+	// Perform the actual set.
+	if err := s.kvsSetTxn(tx, idx, entry, false); err != nil {
+		return err
+	}
+
+	tx.Commit()
+	return nil
+}
+
+// kvsSetTxn is used to insert or update a key/value pair in the state
+// store. It is the inner method used and handles only the actual storage.
+// If updateSession is true, then the incoming entry will set the new
+// session (should be validated before calling this). Otherwise, we will keep
+// whatever the existing session is.
+func (s *StateStore) kvsSetTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry, updateSession bool) error {
+	// Retrieve an existing KV pair
+	existing, err := tx.First("kvs", "id", entry.Key)
+	if err != nil {
+		return fmt.Errorf("failed kvs lookup: %s", err)
+	}
+
+	// Set the indexes.
+	if existing != nil {
+		entry.CreateIndex = existing.(*structs.DirEntry).CreateIndex
+	} else {
+		entry.CreateIndex = idx
+	}
+	entry.ModifyIndex = idx
+
+	// Preserve the existing session unless told otherwise. The "existing"
+	// session for a new entry is "no session".
+	if !updateSession {
+		if existing != nil {
+			entry.Session = existing.(*structs.DirEntry).Session
+		} else {
+			entry.Session = ""
+		}
+	}
+
+	// Store the kv pair in the state store and update the index.
+	if err := tx.Insert("kvs", entry); err != nil {
+		return fmt.Errorf("failed inserting kvs entry: %s", err)
+	}
+	if err := tx.Insert("index", &IndexEntry{"kvs", idx}); err != nil {
+		return fmt.Errorf("failed updating index: %s", err)
+	}
+
+	tx.Defer(func() { s.kvsWatch.Notify(entry.Key, false) })
+	return nil
+}
+
+// KVSGet is used to retrieve a key/value pair from the state store.
+func (s *StateStore) KVSGet(key string) (uint64, *structs.DirEntry, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, "kvs", "tombstones")
+
+	// Retrieve the key.
+	entry, err := tx.First("kvs", "id", key)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed kvs lookup: %s", err)
+	}
+	if entry != nil {
+		return idx, entry.(*structs.DirEntry), nil
+	}
+	return idx, nil, nil
+}
+
+// KVSList is used to list out all keys under a given prefix. If the
+// prefix is left empty, all keys in the KVS will be returned. The returned
+// index is the max index of the returned kvs entries or applicable tombstones;
+// otherwise it's the full table index for kvs and tombstones.
+func (s *StateStore) KVSList(prefix string) (uint64, structs.DirEntries, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table indexes.
+	idx := maxIndexTxn(tx, "kvs", "tombstones")
+
+	// Query the prefix and list the available keys
+	entries, err := tx.Get("kvs", "id_prefix", prefix)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed kvs lookup: %s", err)
+	}
+
+	// Gather all of the keys found in the store
+	var ents structs.DirEntries
+	var lindex uint64
+	for entry := entries.Next(); entry != nil; entry = entries.Next() {
+		e := entry.(*structs.DirEntry)
+		ents = append(ents, e)
+		if e.ModifyIndex > lindex {
+			lindex = e.ModifyIndex
+		}
+	}
+
+	// Check for the highest index in the graveyard. If the prefix is empty
+	// then just use the full table indexes since we are listing everything.
+	if prefix != "" {
+		gindex, err := s.kvsGraveyard.GetMaxIndexTxn(tx, prefix)
+		if err != nil {
+			return 0, nil, fmt.Errorf("failed graveyard lookup: %s", err)
+		}
+		if gindex > lindex {
+			lindex = gindex
+		}
+	} else {
+		lindex = idx
+	}
+
+	// Use the sub index if it was set and there are entries, otherwise use
+	// the full table index from above.
+	if lindex != 0 {
+		idx = lindex
+	}
+	return idx, ents, nil
+}
+
+// KVSListKeys is used to query the KV store for keys matching the given prefix.
+// An optional separator may be specified, which can be used to slice off a part
+// of the response so that only a subset of the prefix is returned. In this
+// mode, the keys which are omitted are still counted in the returned index.
+func (s *StateStore) KVSListKeys(prefix, sep string) (uint64, []string, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table indexes.
+	idx := maxIndexTxn(tx, "kvs", "tombstones")
+
+	// Fetch keys using the specified prefix
+	entries, err := tx.Get("kvs", "id_prefix", prefix)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed kvs lookup: %s", err)
+	}
+
+	prefixLen := len(prefix)
+	sepLen := len(sep)
+
+	var keys []string
+	var lindex uint64
+	var last string
+	for entry := entries.Next(); entry != nil; entry = entries.Next() {
+		e := entry.(*structs.DirEntry)
+
+		// Accumulate the high index
+		if e.ModifyIndex > lindex {
+			lindex = e.ModifyIndex
+		}
+
+		// Always accumulate if no separator provided
+		if sepLen == 0 {
+			keys = append(keys, e.Key)
+			continue
+		}
+
+		// Parse and de-duplicate the returned keys based on the
+		// key separator, if provided.
+		after := e.Key[prefixLen:]
+		sepIdx := strings.Index(after, sep)
+		if sepIdx > -1 {
+			key := e.Key[:prefixLen+sepIdx+sepLen]
+			if key != last {
+				keys = append(keys, key)
+				last = key
+			}
+		} else {
+			keys = append(keys, e.Key)
+		}
+	}
+
+	// Check for the highest index in the graveyard. If the prefix is empty
+	// then just use the full table indexes since we are listing everything.
+	if prefix != "" {
+		gindex, err := s.kvsGraveyard.GetMaxIndexTxn(tx, prefix)
+		if err != nil {
+			return 0, nil, fmt.Errorf("failed graveyard lookup: %s", err)
+		}
+		if gindex > lindex {
+			lindex = gindex
+		}
+	} else {
+		lindex = idx
+	}
+
+	// Use the sub index if it was set and there are entries, otherwise use
+	// the full table index from above.
+	if lindex != 0 {
+		idx = lindex
+	}
+	return idx, keys, nil
+}
+
+// KVSDelete is used to perform a shallow delete on a single key in the
+// state store.
+func (s *StateStore) KVSDelete(idx uint64, key string) error {
+	tx := s.db.Txn(true)
+	defer tx.Abort()
+
+	// Perform the actual delete
+	if err := s.kvsDeleteTxn(tx, idx, key); err != nil {
+		return err
+	}
+
+	tx.Commit()
+	return nil
+}
+
+// kvsDeleteTxn is the inner method used to perform the actual deletion
+// of a key/value pair within an existing transaction.
+func (s *StateStore) kvsDeleteTxn(tx *memdb.Txn, idx uint64, key string) error {
+	// Look up the entry in the state store.
+	entry, err := tx.First("kvs", "id", key)
+	if err != nil {
+		return fmt.Errorf("failed kvs lookup: %s", err)
+	}
+	if entry == nil {
+		return nil
+	}
+
+	// Create a tombstone.
+	if err := s.kvsGraveyard.InsertTxn(tx, key, idx); err != nil {
+		return fmt.Errorf("failed adding to graveyard: %s", err)
+	}
+
+	// Delete the entry and update the index.
+	if err := tx.Delete("kvs", entry); err != nil {
+		return fmt.Errorf("failed deleting kvs entry: %s", err)
+	}
+	if err := tx.Insert("index", &IndexEntry{"kvs", idx}); err != nil {
+		return fmt.Errorf("failed updating index: %s", err)
+	}
+
+	tx.Defer(func() { s.kvsWatch.Notify(key, false) })
+	return nil
+}
+
+// KVSDeleteCAS is used to try doing a KV delete operation with a given
+// raft index. If the CAS index specified is not equal to the last
+// observed index for the given key, then the call is a noop, otherwise
+// a normal KV delete is invoked.
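+//
+// A minimal sketch of the intended use (the key and index values are
+// illustrative):
+//
+//	ok, err := s.KVSDeleteCAS(12, 5, "foo/bar")
+//	// ok is true only if "foo/bar" was last modified at index 5
+//	// (or did not exist), in which case the delete was applied at
+//	// index 12.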
+func (s *StateStore) KVSDeleteCAS(idx, cidx uint64, key string) (bool, error) {
+	tx := s.db.Txn(true)
+	defer tx.Abort()
+
+	// Retrieve the existing kvs entry, if any exists.
+	entry, err := tx.First("kvs", "id", key)
+	if err != nil {
+		return false, fmt.Errorf("failed kvs lookup: %s", err)
+	}
+
+	// If the existing index does not match the provided CAS
+	// index arg, then we shouldn't update anything and can safely
+	// return early here.
+	e, ok := entry.(*structs.DirEntry)
+	if !ok || e.ModifyIndex != cidx {
+		return entry == nil, nil
+	}
+
+	// Call the actual deletion if the above passed.
+	if err := s.kvsDeleteTxn(tx, idx, key); err != nil {
+		return false, err
+	}
+
+	tx.Commit()
+	return true, nil
+}
+
+// KVSSetCAS is used to do a check-and-set operation on a KV entry. The
+// ModifyIndex in the provided entry is used to determine if we should
+// write the entry to the state store or bail. Returns a bool indicating
+// if a write happened and any error.
+func (s *StateStore) KVSSetCAS(idx uint64, entry *structs.DirEntry) (bool, error) {
+	tx := s.db.Txn(true)
+	defer tx.Abort()
+
+	// Retrieve the existing entry.
+	existing, err := tx.First("kvs", "id", entry.Key)
+	if err != nil {
+		return false, fmt.Errorf("failed kvs lookup: %s", err)
+	}
+
+	// Check whether we should do the set. A ModifyIndex of 0 means that
+	// we are doing a set-if-not-exists.
+	if entry.ModifyIndex == 0 && existing != nil {
+		return false, nil
+	}
+	if entry.ModifyIndex != 0 && existing == nil {
+		return false, nil
+	}
+	e, ok := existing.(*structs.DirEntry)
+	if ok && entry.ModifyIndex != 0 && entry.ModifyIndex != e.ModifyIndex {
+		return false, nil
+	}
+
+	// If we made it this far, we should perform the set.
+	if err := s.kvsSetTxn(tx, idx, entry, false); err != nil {
+		return false, err
+	}
+
+	tx.Commit()
+	return true, nil
+}
+
+// KVSDeleteTree is used to do a recursive delete on a key prefix
+// in the state store. If any keys are modified, the last index is
+// set, otherwise this is a no-op.
+func (s *StateStore) KVSDeleteTree(idx uint64, prefix string) error {
+	tx := s.db.Txn(true)
+	defer tx.Abort()
+
+	// Get an iterator over all of the keys with the given prefix.
+	entries, err := tx.Get("kvs", "id_prefix", prefix)
+	if err != nil {
+		return fmt.Errorf("failed kvs lookup: %s", err)
+	}
+
+	// Go over all of the keys and remove them. We call the delete
+	// directly so that we only update the index once. We also add
+	// tombstones as we go.
+	var modified bool
+	var objs []interface{}
+	for entry := entries.Next(); entry != nil; entry = entries.Next() {
+		e := entry.(*structs.DirEntry)
+		if err := s.kvsGraveyard.InsertTxn(tx, e.Key, idx); err != nil {
+			return fmt.Errorf("failed adding to graveyard: %s", err)
+		}
+		objs = append(objs, entry)
+		modified = true
+	}
+
+	// Do the actual deletes in a separate loop so we don't trash the
+	// iterator as we go.
+	for _, obj := range objs {
+		if err := tx.Delete("kvs", obj); err != nil {
+			return fmt.Errorf("failed deleting kvs entry: %s", err)
+		}
+	}
+
+	// Update the index
+	if modified {
+		tx.Defer(func() { s.kvsWatch.Notify(prefix, true) })
+		if err := tx.Insert("index", &IndexEntry{"kvs", idx}); err != nil {
+			return fmt.Errorf("failed updating index: %s", err)
+		}
+	}
+
+	tx.Commit()
+	return nil
+}
+
+// KVSLockDelay returns the expiration time for any lock delay associated with
+// the given key.
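+//
+// A zero time presumably means no delay is in effect; a caller would
+// compare against the current time before granting an acquisition
+// (the key name here is illustrative):
+//
+//	if exp := s.KVSLockDelay("service/leader"); time.Now().Before(exp) {
+//		// deny the acquisition until the delay expires
+//	}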
+func (s *StateStore) KVSLockDelay(key string) time.Time { + return s.lockDelay.GetExpiration(key) +} + +// KVSLock is similar to KVSSet but only performs the set if the lock can be +// acquired. +func (s *StateStore) KVSLock(idx uint64, entry *structs.DirEntry) (bool, error) { + tx := s.db.Txn(true) + defer tx.Abort() + + // Verify that a session is present. + if entry.Session == "" { + return false, fmt.Errorf("missing session") + } + + // Verify that the session exists. + sess, err := tx.First("sessions", "id", entry.Session) + if err != nil { + return false, fmt.Errorf("failed session lookup: %s", err) + } + if sess == nil { + return false, fmt.Errorf("invalid session %#v", entry.Session) + } + + // Retrieve the existing entry. + existing, err := tx.First("kvs", "id", entry.Key) + if err != nil { + return false, fmt.Errorf("failed kvs lookup: %s", err) + } + + // Set up the entry, using the existing entry if present. + if existing != nil { + e := existing.(*structs.DirEntry) + if e.Session == entry.Session { + // We already hold this lock, good to go. + entry.CreateIndex = e.CreateIndex + entry.LockIndex = e.LockIndex + } else if e.Session != "" { + // Bail out, someone else holds this lock. + return false, nil + } else { + // Set up a new lock with this session. + entry.CreateIndex = e.CreateIndex + entry.LockIndex = e.LockIndex + 1 + } + } else { + entry.CreateIndex = idx + entry.LockIndex = 1 + } + entry.ModifyIndex = idx + + // If we made it this far, we should perform the set. + if err := s.kvsSetTxn(tx, idx, entry, true); err != nil { + return false, err + } + + tx.Commit() + return true, nil +} + +// KVSUnlock is similar to KVSSet but only performs the set if the lock can be +// unlocked (the key must already exist and be locked). +func (s *StateStore) KVSUnlock(idx uint64, entry *structs.DirEntry) (bool, error) { + tx := s.db.Txn(true) + defer tx.Abort() + + // Verify that a session is present. + if entry.Session == "" { + return false, fmt.Errorf("missing session") + } + + // Retrieve the existing entry. + existing, err := tx.First("kvs", "id", entry.Key) + if err != nil { + return false, fmt.Errorf("failed kvs lookup: %s", err) + } + + // Bail if there's no existing key. + if existing == nil { + return false, nil + } + + // Make sure the given session is the lock holder. + e := existing.(*structs.DirEntry) + if e.Session != entry.Session { + return false, nil + } + + // Clear the lock and update the entry. + entry.Session = "" + entry.LockIndex = e.LockIndex + entry.CreateIndex = e.CreateIndex + entry.ModifyIndex = idx + + // If we made it this far, we should perform the set. + if err := s.kvsSetTxn(tx, idx, entry, true); err != nil { + return false, err + } + + tx.Commit() + return true, nil +} + +// SessionCreate is used to register a new session in the state store. +func (s *StateStore) SessionCreate(idx uint64, sess *structs.Session) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // This code is technically able to (incorrectly) update an existing + // session but we never do that in practice. The upstream endpoint code + // always adds a unique ID when doing a create operation so we never hit + // an existing session again. It isn't worth the overhead to verify + // that here, but it's worth noting that we should never do this in the + // future. 
+
+	// Call the session creation
+	if err := s.sessionCreateTxn(tx, idx, sess); err != nil {
+		return err
+	}
+
+	tx.Commit()
+	return nil
+}
+
+// sessionCreateTxn is the inner method used for creating session entries in
+// an open transaction. Any health checks registered with the session will be
+// checked for failing status. Returns any error encountered.
+func (s *StateStore) sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Session) error {
+	// Check that we have a session ID
+	if sess.ID == "" {
+		return ErrMissingSessionID
+	}
+
+	// Verify the session behavior is valid
+	switch sess.Behavior {
+	case "":
+		// Release by default to preserve backwards compatibility
+		sess.Behavior = structs.SessionKeysRelease
+	case structs.SessionKeysRelease:
+	case structs.SessionKeysDelete:
+	default:
+		return fmt.Errorf("Invalid session behavior: %s", sess.Behavior)
+	}
+
+	// Assign the indexes. ModifyIndex likely will not be used but
+	// we set it here anyway for sanity.
+	sess.CreateIndex = idx
+	sess.ModifyIndex = idx
+
+	// Check that the node exists
+	node, err := tx.First("nodes", "id", sess.Node)
+	if err != nil {
+		return fmt.Errorf("failed node lookup: %s", err)
+	}
+	if node == nil {
+		return ErrMissingNode
+	}
+
+	// Go over the session checks and ensure they exist.
+	for _, checkID := range sess.Checks {
+		check, err := tx.First("checks", "id", sess.Node, checkID)
+		if err != nil {
+			return fmt.Errorf("failed check lookup: %s", err)
+		}
+		if check == nil {
+			return fmt.Errorf("Missing check '%s' registration", checkID)
+		}
+
+		// Check that the check is not in critical state
+		status := check.(*structs.HealthCheck).Status
+		if status == structs.HealthCritical {
+			return fmt.Errorf("Check '%s' is in %s state", checkID, status)
+		}
+	}
+
+	// Insert the session
+	if err := tx.Insert("sessions", sess); err != nil {
+		return fmt.Errorf("failed inserting session: %s", err)
+	}
+
+	// Insert the check mappings
+	for _, checkID := range sess.Checks {
+		mapping := &sessionCheck{
+			Node:    sess.Node,
+			CheckID: checkID,
+			Session: sess.ID,
+		}
+		if err := tx.Insert("session_checks", mapping); err != nil {
+			return fmt.Errorf("failed inserting session check mapping: %s", err)
+		}
+	}
+
+	// Update the index
+	if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil {
+		return fmt.Errorf("failed updating index: %s", err)
+	}
+
+	tx.Defer(func() { s.tableWatches["sessions"].Notify() })
+	return nil
+}
+
+// SessionGet is used to retrieve an active session from the state store.
+func (s *StateStore) SessionGet(sessionID string) (uint64, *structs.Session, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, s.getWatchTables("SessionGet")...)
+
+	// Look up the session by its ID
+	session, err := tx.First("sessions", "id", sessionID)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed session lookup: %s", err)
+	}
+	if session != nil {
+		return idx, session.(*structs.Session), nil
+	}
+	return idx, nil, nil
+}
+
+// SessionList returns a slice containing all of the active sessions.
+func (s *StateStore) SessionList() (uint64, structs.Sessions, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, s.getWatchTables("SessionList")...)
+
+	// Query all of the active sessions.
+	sessions, err := tx.Get("sessions", "id")
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed session lookup: %s", err)
+	}
+
+	// Go over the sessions and create a slice of them.
+	var result structs.Sessions
+	for session := sessions.Next(); session != nil; session = sessions.Next() {
+		result = append(result, session.(*structs.Session))
+	}
+	return idx, result, nil
+}
+
+// NodeSessions returns a set of active sessions associated
+// with the given node ID. The returned index is the highest
+// index seen from the result set.
+func (s *StateStore) NodeSessions(nodeID string) (uint64, structs.Sessions, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, s.getWatchTables("NodeSessions")...)
+
+	// Get all of the sessions which belong to the node
+	sessions, err := tx.Get("sessions", "node", nodeID)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed session lookup: %s", err)
+	}
+
+	// Go over all of the sessions and return them as a slice
+	var result structs.Sessions
+	for session := sessions.Next(); session != nil; session = sessions.Next() {
+		result = append(result, session.(*structs.Session))
+	}
+	return idx, result, nil
+}
+
+// SessionDestroy is used to remove an active session. This will
+// implicitly invalidate the session and invoke the specified
+// session destroy behavior.
+func (s *StateStore) SessionDestroy(idx uint64, sessionID string) error {
+	tx := s.db.Txn(true)
+	defer tx.Abort()
+
+	// Call the session deletion.
+	watches := NewDumbWatchManager(s.tableWatches)
+	if err := s.deleteSessionTxn(tx, idx, watches, sessionID); err != nil {
+		return err
+	}
+
+	tx.Defer(func() { watches.Notify() })
+	tx.Commit()
+	return nil
+}
+
+// deleteSessionTxn is the inner method, which is used to do the actual
+// session deletion and handle session invalidation, watch triggers, etc.
+func (s *StateStore) deleteSessionTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, sessionID string) error {
+	// Look up the session.
+	sess, err := tx.First("sessions", "id", sessionID)
+	if err != nil {
+		return fmt.Errorf("failed session lookup: %s", err)
+	}
+	if sess == nil {
+		return nil
+	}
+
+	// Delete the session and write the new index.
+	if err := tx.Delete("sessions", sess); err != nil {
+		return fmt.Errorf("failed deleting session: %s", err)
+	}
+	if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil {
+		return fmt.Errorf("failed updating index: %s", err)
+	}
+
+	// Enforce the max lock delay.
+	session := sess.(*structs.Session)
+	delay := session.LockDelay
+	if delay > structs.MaxLockDelay {
+		delay = structs.MaxLockDelay
+	}
+
+	// Grab the current time so that all the expirations get calculated
+	// the same way.
+	now := time.Now()
+
+	// Get an iterator over all of the keys with the given session.
+	entries, err := tx.Get("kvs", "session", sessionID)
+	if err != nil {
+		return fmt.Errorf("failed kvs lookup: %s", err)
+	}
+	var kvs []interface{}
+	for entry := entries.Next(); entry != nil; entry = entries.Next() {
+		kvs = append(kvs, entry)
+	}
+
+	// Invalidate any held locks.
+	switch session.Behavior {
+	case structs.SessionKeysRelease:
+		for _, obj := range kvs {
+			// Note that we clone here since we are modifying the
+			// returned object and want to make sure our set op
+			// respects the transaction we are in.
+			e := obj.(*structs.DirEntry).Clone()
+			e.Session = ""
+			if err := s.kvsSetTxn(tx, idx, e, true); err != nil {
+				return fmt.Errorf("failed kvs update: %s", err)
+			}
+
+			// Apply the lock delay if present.
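+			// (The delay prevents another session from immediately
+			// re-acquiring this lock; see KVSLockDelay above.)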
+			if delay > 0 {
+				s.lockDelay.SetExpiration(e.Key, now, delay)
+			}
+		}
+	case structs.SessionKeysDelete:
+		for _, obj := range kvs {
+			e := obj.(*structs.DirEntry)
+			if err := s.kvsDeleteTxn(tx, idx, e.Key); err != nil {
+				return fmt.Errorf("failed kvs delete: %s", err)
+			}
+
+			// Apply the lock delay if present.
+			if delay > 0 {
+				s.lockDelay.SetExpiration(e.Key, now, delay)
+			}
+		}
+	default:
+		return fmt.Errorf("unknown session behavior %#v", session.Behavior)
+	}
+
+	// Delete any check mappings.
+	mappings, err := tx.Get("session_checks", "session", sessionID)
+	if err != nil {
+		return fmt.Errorf("failed session checks lookup: %s", err)
+	}
+	var objs []interface{}
+	for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() {
+		objs = append(objs, mapping)
+	}
+
+	// Do the delete in a separate loop so we don't trash the iterator.
+	for _, obj := range objs {
+		if err := tx.Delete("session_checks", obj); err != nil {
+			return fmt.Errorf("failed deleting session check: %s", err)
+		}
+	}
+
+	watches.Arm("sessions")
+	return nil
+}
+
+// ACLSet is used to insert an ACL rule into the state store.
+func (s *StateStore) ACLSet(idx uint64, acl *structs.ACL) error {
+	tx := s.db.Txn(true)
+	defer tx.Abort()
+
+	// Call set on the ACL
+	if err := s.aclSetTxn(tx, idx, acl); err != nil {
+		return err
+	}
+
+	tx.Commit()
+	return nil
+}
+
+// aclSetTxn is the inner method used to insert an ACL rule with the
+// proper indexes into the state store.
+func (s *StateStore) aclSetTxn(tx *memdb.Txn, idx uint64, acl *structs.ACL) error {
+	// Check that the ID is set
+	if acl.ID == "" {
+		return ErrMissingACLID
+	}
+
+	// Check for an existing ACL
+	existing, err := tx.First("acls", "id", acl.ID)
+	if err != nil {
+		return fmt.Errorf("failed acl lookup: %s", err)
+	}
+
+	// Set the indexes
+	if existing != nil {
+		acl.CreateIndex = existing.(*structs.ACL).CreateIndex
+		acl.ModifyIndex = idx
+	} else {
+		acl.CreateIndex = idx
+		acl.ModifyIndex = idx
+	}
+
+	// Insert the ACL
+	if err := tx.Insert("acls", acl); err != nil {
+		return fmt.Errorf("failed inserting acl: %s", err)
+	}
+	if err := tx.Insert("index", &IndexEntry{"acls", idx}); err != nil {
+		return fmt.Errorf("failed updating index: %s", err)
+	}
+
+	tx.Defer(func() { s.tableWatches["acls"].Notify() })
+	return nil
+}
+
+// ACLGet is used to look up an existing ACL by ID.
+func (s *StateStore) ACLGet(aclID string) (uint64, *structs.ACL, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, s.getWatchTables("ACLGet")...)
+
+	// Query for the existing ACL
+	acl, err := tx.First("acls", "id", aclID)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed acl lookup: %s", err)
+	}
+	if acl != nil {
+		return idx, acl.(*structs.ACL), nil
+	}
+	return idx, nil, nil
+}
+
+// ACLList is used to list out all of the ACLs in the state store.
+func (s *StateStore) ACLList() (uint64, structs.ACLs, error) {
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+
+	// Get the table index.
+	idx := maxIndexTxn(tx, s.getWatchTables("ACLList")...)
+
+	// Return the ACLs.
+	acls, err := s.aclListTxn(tx)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed acl lookup: %s", err)
+	}
+	return idx, acls, nil
+}
+
+// aclListTxn is the inner method used to list out all of the ACLs in the
+// state store. It takes the transaction as an argument so that the
+// snapshotter can call it with its own transaction as well.
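+//
+// For example, snapshot code can run it against its own read transaction
+// (sketch; tx would come from the snapshot):
+//
+//	acls, err := s.aclListTxn(tx)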
+func (s *StateStore) aclListTxn(tx *memdb.Txn) (structs.ACLs, error) { + // Query all of the ACLs in the state store + acls, err := tx.Get("acls", "id") + if err != nil { + return nil, fmt.Errorf("failed acl lookup: %s", err) + } + + // Go over all of the ACLs and build the response + var result structs.ACLs + for acl := acls.Next(); acl != nil; acl = acls.Next() { + a := acl.(*structs.ACL) + result = append(result, a) + } + return result, nil +} + +// ACLDelete is used to remove an existing ACL from the state store. If +// the ACL does not exist this is a no-op and no error is returned. +func (s *StateStore) ACLDelete(idx uint64, aclID string) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Call the ACL delete + if err := s.aclDeleteTxn(tx, idx, aclID); err != nil { + return err + } + + tx.Commit() + return nil +} + +// aclDeleteTxn is used to delete an ACL from the state store within +// an existing transaction. +func (s *StateStore) aclDeleteTxn(tx *memdb.Txn, idx uint64, aclID string) error { + // Look up the existing ACL + acl, err := tx.First("acls", "id", aclID) + if err != nil { + return fmt.Errorf("failed acl lookup: %s", err) + } + if acl == nil { + return nil + } + + // Delete the ACL from the state store and update indexes + if err := tx.Delete("acls", acl); err != nil { + return fmt.Errorf("failed deleting acl: %s", err) + } + if err := tx.Insert("index", &IndexEntry{"acls", idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Defer(func() { s.tableWatches["acls"].Notify() }) + return nil +} diff --git a/consul/state/state_store_test.go b/consul/state/state_store_test.go new file mode 100644 index 000000000000..ae2b82dd692a --- /dev/null +++ b/consul/state/state_store_test.go @@ -0,0 +1,4723 @@ +package state + +import ( + crand "crypto/rand" + "fmt" + "reflect" + "sort" + "strings" + "testing" + "time" + + "github.com/hashicorp/consul/consul/structs" +) + +func testUUID() string { + buf := make([]byte, 16) + if _, err := crand.Read(buf); err != nil { + panic(fmt.Errorf("failed to read random bytes: %v", err)) + } + + return fmt.Sprintf("%08x-%04x-%04x-%04x-%12x", + buf[0:4], + buf[4:6], + buf[6:8], + buf[8:10], + buf[10:16]) +} + +func testStateStore(t *testing.T) *StateStore { + s, err := NewStateStore(nil) + if err != nil { + t.Fatalf("err: %s", err) + } + if s == nil { + t.Fatalf("missing state store") + } + return s +} + +func testRegisterNode(t *testing.T, s *StateStore, idx uint64, nodeID string) { + node := &structs.Node{Node: nodeID} + if err := s.EnsureNode(idx, node); err != nil { + t.Fatalf("err: %s", err) + } + + tx := s.db.Txn(false) + defer tx.Abort() + n, err := tx.First("nodes", "id", nodeID) + if err != nil { + t.Fatalf("err: %s", err) + } + if result, ok := n.(*structs.Node); !ok || result.Node != nodeID { + t.Fatalf("bad node: %#v", result) + } +} + +func testRegisterService(t *testing.T, s *StateStore, idx uint64, nodeID, serviceID string) { + svc := &structs.NodeService{ + ID: serviceID, + Service: serviceID, + Address: "1.1.1.1", + Port: 1111, + } + if err := s.EnsureService(idx, nodeID, svc); err != nil { + t.Fatalf("err: %s", err) + } + + tx := s.db.Txn(false) + defer tx.Abort() + service, err := tx.First("services", "id", nodeID, serviceID) + if err != nil { + t.Fatalf("err: %s", err) + } + if result, ok := service.(*structs.ServiceNode); !ok || + result.Node != nodeID || + result.ServiceID != serviceID { + t.Fatalf("bad service: %#v", result) + } +} + +func testRegisterCheck(t *testing.T, s *StateStore, idx 
uint64, + nodeID, serviceID, checkID, state string) { + chk := &structs.HealthCheck{ + Node: nodeID, + CheckID: checkID, + ServiceID: serviceID, + Status: state, + } + if err := s.EnsureCheck(idx, chk); err != nil { + t.Fatalf("err: %s", err) + } + + tx := s.db.Txn(false) + defer tx.Abort() + c, err := tx.First("checks", "id", nodeID, checkID) + if err != nil { + t.Fatalf("err: %s", err) + } + if result, ok := c.(*structs.HealthCheck); !ok || + result.Node != nodeID || + result.ServiceID != serviceID || + result.CheckID != checkID { + t.Fatalf("bad check: %#v", result) + } +} + +func testSetKey(t *testing.T, s *StateStore, idx uint64, key, value string) { + entry := &structs.DirEntry{Key: key, Value: []byte(value)} + if err := s.KVSSet(idx, entry); err != nil { + t.Fatalf("err: %s", err) + } + + tx := s.db.Txn(false) + defer tx.Abort() + e, err := tx.First("kvs", "id", key) + if err != nil { + t.Fatalf("err: %s", err) + } + if result, ok := e.(*structs.DirEntry); !ok || result.Key != key { + t.Fatalf("bad kvs entry: %#v", result) + } +} + +func TestStateStore_Restore_Abort(t *testing.T) { + s := testStateStore(t) + + // The detailed restore functions are tested below, this just checks + // that abort works. + restore := s.Restore() + entry := &structs.DirEntry{ + Key: "foo", + Value: []byte("bar"), + RaftIndex: structs.RaftIndex{ + ModifyIndex: 5, + }, + } + if err := restore.KVS(entry); err != nil { + t.Fatalf("err: %s", err) + } + restore.Abort() + + idx, entries, err := s.KVSList("") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 0 { + t.Fatalf("bad index: %d", idx) + } + if len(entries) != 0 { + t.Fatalf("bad: %#v", entries) + } +} + +func TestStateStore_maxIndex(t *testing.T) { + s := testStateStore(t) + + testRegisterNode(t, s, 0, "foo") + testRegisterNode(t, s, 1, "bar") + testRegisterService(t, s, 2, "foo", "consul") + + if max := s.maxIndex("nodes", "services"); max != 2 { + t.Fatalf("bad max: %d", max) + } +} + +func TestStateStore_indexUpdateMaxTxn(t *testing.T) { + s := testStateStore(t) + + testRegisterNode(t, s, 0, "foo") + testRegisterNode(t, s, 1, "bar") + + tx := s.db.Txn(true) + if err := indexUpdateMaxTxn(tx, 3, "nodes"); err != nil { + t.Fatalf("err: %s", err) + } + tx.Commit() + + if max := s.maxIndex("nodes"); max != 3 { + t.Fatalf("bad max: %d", max) + } +} + +func TestStateStore_GC(t *testing.T) { + // Build up a fast GC. + ttl := 10 * time.Millisecond + gran := 5 * time.Millisecond + gc, err := NewTombstoneGC(ttl, gran) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Enable it and attach it to the state store. + gc.SetEnabled(true) + s, err := NewStateStore(gc) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Create some KV pairs. + testSetKey(t, s, 1, "foo", "foo") + testSetKey(t, s, 2, "foo/bar", "bar") + testSetKey(t, s, 3, "foo/baz", "bar") + testSetKey(t, s, 4, "foo/moo", "bar") + testSetKey(t, s, 5, "foo/zoo", "bar") + + // Delete a key and make sure the GC sees it. + if err := s.KVSDelete(6, "foo/zoo"); err != nil { + t.Fatalf("err: %s", err) + } + select { + case idx := <-gc.ExpireCh(): + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + case <-time.After(2 * ttl): + t.Fatalf("GC never fired") + } + + // Check for the same behavior with a tree delete. 
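+	// (KVSDeleteTree writes a tombstone for every key it removes, so the
+	// GC should see index 7 from this operation as well.)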
+ if err := s.KVSDeleteTree(7, "foo/moo"); err != nil { + t.Fatalf("err: %s", err) + } + select { + case idx := <-gc.ExpireCh(): + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + case <-time.After(2 * ttl): + t.Fatalf("GC never fired") + } + + // Check for the same behavior with a CAS delete. + if ok, err := s.KVSDeleteCAS(8, 3, "foo/baz"); !ok || err != nil { + t.Fatalf("err: %s", err) + } + select { + case idx := <-gc.ExpireCh(): + if idx != 8 { + t.Fatalf("bad index: %d", idx) + } + case <-time.After(2 * ttl): + t.Fatalf("GC never fired") + } + + // Finally, try it with an expiring session. + testRegisterNode(t, s, 9, "node1") + session := &structs.Session{ + ID: testUUID(), + Node: "node1", + Behavior: structs.SessionKeysDelete, + } + if err := s.SessionCreate(10, session); err != nil { + t.Fatalf("err: %s", err) + } + d := &structs.DirEntry{ + Key: "lock", + Session: session.ID, + } + if ok, err := s.KVSLock(11, d); !ok || err != nil { + t.Fatalf("err: %v", err) + } + if err := s.SessionDestroy(12, session.ID); err != nil { + t.Fatalf("err: %s", err) + } + select { + case idx := <-gc.ExpireCh(): + if idx != 12 { + t.Fatalf("bad index: %d", idx) + } + case <-time.After(2 * ttl): + t.Fatalf("GC never fired") + } +} + +func TestStateStore_ReapTombstones(t *testing.T) { + s := testStateStore(t) + + // Create some KV pairs. + testSetKey(t, s, 1, "foo", "foo") + testSetKey(t, s, 2, "foo/bar", "bar") + testSetKey(t, s, 3, "foo/baz", "bar") + testSetKey(t, s, 4, "foo/moo", "bar") + testSetKey(t, s, 5, "foo/zoo", "bar") + + // Call a delete on some specific keys. + if err := s.KVSDelete(6, "foo/baz"); err != nil { + t.Fatalf("err: %s", err) + } + if err := s.KVSDelete(7, "foo/moo"); err != nil { + t.Fatalf("err: %s", err) + } + + // Pull out the list and check the index, which should come from the + // tombstones. + idx, _, err := s.KVSList("foo/") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // Reap the tombstones <= 6. + if err := s.ReapTombstones(6); err != nil { + t.Fatalf("err: %s", err) + } + + // Should still be good because 7 is in there. + idx, _, err = s.KVSList("foo/") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // Now reap them all. + if err := s.ReapTombstones(7); err != nil { + t.Fatalf("err: %s", err) + } + + // At this point the sub index will slide backwards. + idx, _, err = s.KVSList("foo/") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 5 { + t.Fatalf("bad index: %d", idx) + } + + // Make sure the tombstones are actually gone. + snap := s.Snapshot() + defer snap.Close() + stones, err := snap.Tombstones() + if err != nil { + t.Fatalf("err: %s", err) + } + if stones.Next() != nil { + t.Fatalf("unexpected extra tombstones") + } +} + +func TestStateStore_GetWatches(t *testing.T) { + s := testStateStore(t) + + // This test does two things - it makes sure there's no full table + // watch for KVS, and it makes sure that asking for a watch that + // doesn't exist causes a panic. + func() { + defer func() { + if r := recover(); r == nil { + t.Fatalf("didn't get expected panic") + } + }() + s.getTableWatch("kvs") + }() + + // Similar for tombstones; those don't support watches at all. + func() { + defer func() { + if r := recover(); r == nil { + t.Fatalf("didn't get expected panic") + } + }() + s.getTableWatch("tombstones") + }() + + // Make sure requesting a bogus method causes a panic. 
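+	// (GetQueryWatch is what the RPC endpoints use to find the tables
+	// backing a query, so a bogus method name should fail loudly rather
+	// than silently return no watch.)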
+ func() { + defer func() { + if r := recover(); r == nil { + t.Fatalf("didn't get expected panic") + } + }() + s.GetQueryWatch("dogs") + }() + + // Request valid watches. + if w := s.GetQueryWatch("Nodes"); w == nil { + t.Fatalf("didn't get a watch") + } + if w := s.GetQueryWatch("NodeDump"); w == nil { + t.Fatalf("didn't get a watch") + } + if w := s.GetKVSWatch("/dogs"); w == nil { + t.Fatalf("didn't get a watch") + } +} + +func TestStateStore_EnsureRegistration(t *testing.T) { + s := testStateStore(t) + + // Start with just a node. + req := &structs.RegisterRequest{ + Node: "node1", + Address: "1.2.3.4", + } + if err := s.EnsureRegistration(1, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Retrieve the node and verify its contents. + verifyNode := func(created, modified uint64) { + _, out, err := s.GetNode("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if out.Node != "node1" || out.Address != "1.2.3.4" || + out.CreateIndex != created || out.ModifyIndex != modified { + t.Fatalf("bad node returned: %#v", out) + } + } + verifyNode(1, 1) + + // Add in a service definition. + req.Service = &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + } + if err := s.EnsureRegistration(2, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify that the service got registered. + verifyService := func(created, modified uint64) { + idx, out, err := s.NodeServices("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != modified { + t.Fatalf("bad index: %d", idx) + } + if len(out.Services) != 1 { + t.Fatalf("bad: %#v", out.Services) + } + s := out.Services["redis1"] + if s.ID != "redis1" || s.Service != "redis" || + s.Address != "1.1.1.1" || s.Port != 8080 || + s.CreateIndex != created || s.ModifyIndex != modified { + t.Fatalf("bad service returned: %#v", s) + } + } + verifyNode(1, 2) + verifyService(2, 2) + + // Add in a top-level check. + req.Check = &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Name: "check", + } + if err := s.EnsureRegistration(3, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify that the check got registered. + verifyCheck := func(created, modified uint64) { + idx, out, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != modified { + t.Fatalf("bad index: %d", idx) + } + if len(out) != 1 { + t.Fatalf("bad: %#v", out) + } + c := out[0] + if c.Node != "node1" || c.CheckID != "check1" || c.Name != "check" || + c.CreateIndex != created || c.ModifyIndex != modified { + t.Fatalf("bad check returned: %#v", c) + } + } + verifyNode(1, 3) + verifyService(2, 3) + verifyCheck(3, 3) + + // Add in another check via the slice. + req.Checks = structs.HealthChecks{ + &structs.HealthCheck{ + Node: "node1", + CheckID: "check2", + Name: "check", + }, + } + if err := s.EnsureRegistration(4, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify that the additional check got registered. 
+ verifyNode(1, 4) + verifyService(2, 4) + func() { + idx, out, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + if len(out) != 2 { + t.Fatalf("bad: %#v", out) + } + c1 := out[0] + if c1.Node != "node1" || c1.CheckID != "check1" || c1.Name != "check" || + c1.CreateIndex != 3 || c1.ModifyIndex != 4 { + t.Fatalf("bad check returned: %#v", c1) + } + + c2 := out[1] + if c2.Node != "node1" || c2.CheckID != "check2" || c2.Name != "check" || + c2.CreateIndex != 4 || c2.ModifyIndex != 4 { + t.Fatalf("bad check returned: %#v", c2) + } + }() +} + +func TestStateStore_EnsureRegistration_Restore(t *testing.T) { + s := testStateStore(t) + + // Start with just a node. + req := &structs.RegisterRequest{ + Node: "node1", + Address: "1.2.3.4", + } + restore := s.Restore() + if err := restore.Registration(1, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + + // Retrieve the node and verify its contents. + verifyNode := func(created, modified uint64) { + _, out, err := s.GetNode("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if out.Node != "node1" || out.Address != "1.2.3.4" || + out.CreateIndex != created || out.ModifyIndex != modified { + t.Fatalf("bad node returned: %#v", out) + } + } + verifyNode(1, 1) + + // Add in a service definition. + req.Service = &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + } + restore = s.Restore() + if err := restore.Registration(2, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + + // Verify that the service got registered. + verifyService := func(created, modified uint64) { + idx, out, err := s.NodeServices("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != modified { + t.Fatalf("bad index: %d", idx) + } + if len(out.Services) != 1 { + t.Fatalf("bad: %#v", out.Services) + } + s := out.Services["redis1"] + if s.ID != "redis1" || s.Service != "redis" || + s.Address != "1.1.1.1" || s.Port != 8080 || + s.CreateIndex != created || s.ModifyIndex != modified { + t.Fatalf("bad service returned: %#v", s) + } + } + verifyNode(1, 2) + verifyService(2, 2) + + // Add in a top-level check. + req.Check = &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Name: "check", + } + restore = s.Restore() + if err := restore.Registration(3, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + + // Verify that the check got registered. + verifyCheck := func(created, modified uint64) { + idx, out, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != modified { + t.Fatalf("bad index: %d", idx) + } + if len(out) != 1 { + t.Fatalf("bad: %#v", out) + } + c := out[0] + if c.Node != "node1" || c.CheckID != "check1" || c.Name != "check" || + c.CreateIndex != created || c.ModifyIndex != modified { + t.Fatalf("bad check returned: %#v", c) + } + } + verifyNode(1, 3) + verifyService(2, 3) + verifyCheck(3, 3) + + // Add in another check via the slice. + req.Checks = structs.HealthChecks{ + &structs.HealthCheck{ + Node: "node1", + CheckID: "check2", + Name: "check", + }, + } + restore = s.Restore() + if err := restore.Registration(4, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + + // Verify that the additional check got registered. 
+ verifyNode(1, 4) + verifyService(2, 4) + func() { + idx, out, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + if len(out) != 2 { + t.Fatalf("bad: %#v", out) + } + c1 := out[0] + if c1.Node != "node1" || c1.CheckID != "check1" || c1.Name != "check" || + c1.CreateIndex != 3 || c1.ModifyIndex != 4 { + t.Fatalf("bad check returned: %#v", c1) + } + + c2 := out[1] + if c2.Node != "node1" || c2.CheckID != "check2" || c2.Name != "check" || + c2.CreateIndex != 4 || c2.ModifyIndex != 4 { + t.Fatalf("bad check returned: %#v", c2) + } + }() +} + +func TestStateStore_EnsureRegistration_Watches(t *testing.T) { + s := testStateStore(t) + + req := &structs.RegisterRequest{ + Node: "node1", + Address: "1.2.3.4", + } + + // The nodes watch should fire for this one. + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyNoWatch(t, s.getTableWatch("services"), func() { + verifyNoWatch(t, s.getTableWatch("checks"), func() { + if err := s.EnsureRegistration(1, req); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + // The nodes watch should fire for this one. + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyNoWatch(t, s.getTableWatch("services"), func() { + verifyNoWatch(t, s.getTableWatch("checks"), func() { + restore := s.Restore() + if err := restore.Registration(1, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + }) + }) + }) + + // With a service definition added it should fire nodes and + // services. + req.Service = &structs.NodeService{ + ID: "redis1", + Service: "redis", + Address: "1.1.1.1", + Port: 8080, + } + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyNoWatch(t, s.getTableWatch("checks"), func() { + if err := s.EnsureRegistration(2, req); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyNoWatch(t, s.getTableWatch("checks"), func() { + restore := s.Restore() + if err := restore.Registration(2, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + }) + }) + }) + + // Now with a check it should hit all three. 
+ req.Check = &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Name: "check", + } + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { + if err := s.EnsureRegistration(3, req); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.getTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { + restore := s.Restore() + if err := restore.Registration(3, req); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + }) + }) + }) +} + +func TestStateStore_EnsureNode(t *testing.T) { + s := testStateStore(t) + + // Fetching a non-existent node returns nil + if _, node, err := s.GetNode("node1"); node != nil || err != nil { + t.Fatalf("expected (nil, nil), got: (%#v, %#v)", node, err) + } + + // Create a node registration request + in := &structs.Node{ + Node: "node1", + Address: "1.1.1.1", + } + + // Ensure the node is registered in the db + if err := s.EnsureNode(1, in); err != nil { + t.Fatalf("err: %s", err) + } + + // Retrieve the node again + idx, out, err := s.GetNode("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Correct node was returned + if out.Node != "node1" || out.Address != "1.1.1.1" { + t.Fatalf("bad node returned: %#v", out) + } + + // Indexes are set properly + if out.CreateIndex != 1 || out.ModifyIndex != 1 { + t.Fatalf("bad node index: %#v", out) + } + if idx != 1 { + t.Fatalf("bad index: %d", idx) + } + + // Update the node registration + in.Address = "1.1.1.2" + if err := s.EnsureNode(2, in); err != nil { + t.Fatalf("err: %s", err) + } + + // Retrieve the node + idx, out, err = s.GetNode("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Node and indexes were updated + if out.CreateIndex != 1 || out.ModifyIndex != 2 || out.Address != "1.1.1.2" { + t.Fatalf("bad: %#v", out) + } + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + // Node upsert preserves the create index + if err := s.EnsureNode(3, in); err != nil { + t.Fatalf("err: %s", err) + } + idx, out, err = s.GetNode("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if out.CreateIndex != 1 || out.ModifyIndex != 3 || out.Address != "1.1.1.2" { + t.Fatalf("node was modified: %#v", out) + } + if idx != 3 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_GetNodes(t *testing.T) { + s := testStateStore(t) + + // Listing with no results returns nil + idx, res, err := s.Nodes() + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Create some nodes in the state store + testRegisterNode(t, s, 0, "node0") + testRegisterNode(t, s, 1, "node1") + testRegisterNode(t, s, 2, "node2") + + // Retrieve the nodes + idx, nodes, err := s.Nodes() + if err != nil { + t.Fatalf("err: %s", err) + } + + // Highest index was returned + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + // All nodes were returned + if n := len(nodes); n != 3 { + t.Fatalf("bad node count: %d", n) + } + + // Make sure the nodes match + for i, node := range nodes { + if node.CreateIndex != uint64(i) || node.ModifyIndex != uint64(i) { + t.Fatalf("bad node index: %d, %d", node.CreateIndex, node.ModifyIndex) + } + name := fmt.Sprintf("node%d", i) + if node.Node != name { + t.Fatalf("bad: %#v", node) + } + } +} + +func BenchmarkGetNodes(b *testing.B) { + s, err := NewStateStore(nil) + if err != nil { + 
b.Fatalf("err: %s", err)
+	}
+
+	if err := s.EnsureNode(100, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
+		b.Fatalf("err: %v", err)
+	}
+	if err := s.EnsureNode(101, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil {
+		b.Fatalf("err: %v", err)
+	}
+
+	for i := 0; i < b.N; i++ {
+		s.Nodes()
+	}
+}
+
+func TestStateStore_DeleteNode(t *testing.T) {
+	s := testStateStore(t)
+
+	// Create a node and register a service and health check with it.
+	testRegisterNode(t, s, 0, "node1")
+	testRegisterService(t, s, 1, "node1", "service1")
+	testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing)
+
+	// Delete the node
+	if err := s.DeleteNode(3, "node1"); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	// The node was removed
+	if idx, n, err := s.GetNode("node1"); err != nil || n != nil || idx != 3 {
+		t.Fatalf("bad: %#v %d (err: %#v)", n, idx, err)
+	}
+
+	// Associated service was removed. Need to query this directly out of
+	// the DB to make sure it is actually gone.
+	tx := s.db.Txn(false)
+	defer tx.Abort()
+	services, err := tx.Get("services", "id", "node1", "service1")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if service := services.Next(); service != nil {
+		t.Fatalf("bad: %#v", service)
+	}
+
+	// Associated health check was removed.
+	checks, err := tx.Get("checks", "id", "node1", "check1")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if check := checks.Next(); check != nil {
+		t.Fatalf("bad: %#v", check)
+	}
+
+	// Indexes were updated.
+	for _, tbl := range []string{"nodes", "services", "checks"} {
+		if idx := s.maxIndex(tbl); idx != 3 {
+			t.Fatalf("bad index: %d (%s)", idx, tbl)
+		}
+	}
+
+	// Deleting a nonexistent node should be idempotent and not return
+	// an error
+	if err := s.DeleteNode(4, "node1"); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if idx := s.maxIndex("nodes"); idx != 3 {
+		t.Fatalf("bad index: %d", idx)
+	}
+}
+
+func TestStateStore_Node_Snapshot(t *testing.T) {
+	s := testStateStore(t)
+
+	// Create some nodes in the state store.
+	testRegisterNode(t, s, 0, "node0")
+	testRegisterNode(t, s, 1, "node1")
+	testRegisterNode(t, s, 2, "node2")
+
+	// Snapshot the nodes.
+	snap := s.Snapshot()
+	defer snap.Close()
+
+	// Alter the real state store.
+	testRegisterNode(t, s, 3, "node3")
+
+	// Verify the snapshot.
+	if idx := snap.LastIndex(); idx != 2 {
+		t.Fatalf("bad index: %d", idx)
+	}
+	nodes, err := snap.Nodes()
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	for i := 0; i < 3; i++ {
+		// Check for nil before the type assertion, since asserting on
+		// a nil interface would panic before the test could report it.
+		n := nodes.Next()
+		if n == nil {
+			t.Fatalf("unexpected end of nodes")
+		}
+		node := n.(*structs.Node)
+
+		if node.CreateIndex != uint64(i) || node.ModifyIndex != uint64(i) {
+			t.Fatalf("bad node index: %d, %d", node.CreateIndex, node.ModifyIndex)
+		}
+		if node.Node != fmt.Sprintf("node%d", i) {
+			t.Fatalf("bad: %#v", node)
+		}
+	}
+	if nodes.Next() != nil {
+		t.Fatalf("unexpected extra nodes")
+	}
+}
+
+func TestStateStore_Node_Watches(t *testing.T) {
+	s := testStateStore(t)
+
+	// Call functions that update the nodes table and make sure a watch fires
+	// each time.
+	verifyWatch(t, s.getTableWatch("nodes"), func() {
+		req := &structs.RegisterRequest{
+			Node: "node1",
+		}
+		if err := s.EnsureRegistration(1, req); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+	})
+	verifyWatch(t, s.getTableWatch("nodes"), func() {
+		node := &structs.Node{Node: "node2"}
+		if err := s.EnsureNode(2, node); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+	})
+	verifyWatch(t, s.getTableWatch("nodes"), func() {
+		if err := s.DeleteNode(3, "node2"); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+	})
+
+	// Check that a delete of a node + service + check triggers all three
+	// tables in one shot.
+	testRegisterNode(t, s, 4, "node1")
+	testRegisterService(t, s, 5, "node1", "service1")
+	testRegisterCheck(t, s, 6, "node1", "service1", "check3", structs.HealthPassing)
+	verifyWatch(t, s.getTableWatch("nodes"), func() {
+		verifyWatch(t, s.getTableWatch("services"), func() {
+			verifyWatch(t, s.getTableWatch("checks"), func() {
+				if err := s.DeleteNode(7, "node1"); err != nil {
+					t.Fatalf("err: %s", err)
+				}
+			})
+		})
+	})
+}
+
+func TestStateStore_EnsureService(t *testing.T) {
+	s := testStateStore(t)
+
+	// Fetching services for a node with none returns nil
+	idx, res, err := s.NodeServices("node1")
+	if err != nil || res != nil || idx != 0 {
+		t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err)
+	}
+
+	// Create the service registration
+	ns1 := &structs.NodeService{
+		ID:      "service1",
+		Service: "redis",
+		Tags:    []string{"prod"},
+		Address: "1.1.1.1",
+		Port:    1111,
+	}
+
+	// Creating a service without a node returns an error
+	if err := s.EnsureService(1, "node1", ns1); err != ErrMissingNode {
+		t.Fatalf("expected %#v, got: %#v", ErrMissingNode, err)
+	}
+
+	// Register the nodes
+	testRegisterNode(t, s, 0, "node1")
+	testRegisterNode(t, s, 1, "node2")
+
+	// Service successfully registers into the state store
+	if err = s.EnsureService(10, "node1", ns1); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	// Register a similar service against both nodes
+	ns2 := *ns1
+	ns2.ID = "service2"
+	for _, n := range []string{"node1", "node2"} {
+		if err := s.EnsureService(20, n, &ns2); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+	}
+
+	// Register a different service on the second node
+	ns3 := *ns1
+	ns3.ID = "service3"
+	if err := s.EnsureService(30, "node2", &ns3); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	// Retrieve the services
+	idx, out, err := s.NodeServices("node1")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if idx != 30 {
+		t.Fatalf("bad index: %d", idx)
+	}
+
+	// Only the services for the requested node are returned
+	if out == nil || len(out.Services) != 2 {
+		t.Fatalf("bad services: %#v", out)
+	}
+
+	// Results match the inserted services and have the proper indexes set
+	expect1 := *ns1
+	expect1.CreateIndex, expect1.ModifyIndex = 10, 10
+	if svc := out.Services["service1"]; !reflect.DeepEqual(&expect1, svc) {
+		t.Fatalf("bad: %#v", svc)
+	}
+
+	expect2 := ns2
+	expect2.CreateIndex, expect2.ModifyIndex = 20, 20
+	if svc := out.Services["service2"]; !reflect.DeepEqual(&expect2, svc) {
+		t.Fatalf("bad: %#v %#v", ns2, svc)
+	}
+
+	// Index tables were updated
+	if idx := s.maxIndex("services"); idx != 30 {
+		t.Fatalf("bad index: %d", idx)
+	}
+
+	// Update a service registration
+	ns1.Address = "1.1.1.2"
+	if err := s.EnsureService(40, "node1", ns1); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	// Retrieve the service again and ensure it matches
+	idx, out, err = s.NodeServices("node1")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if idx != 
40 { + t.Fatalf("bad index: %d", idx) + } + if out == nil || len(out.Services) != 2 { + t.Fatalf("bad: %#v", out) + } + expect1.Address = "1.1.1.2" + expect1.ModifyIndex = 40 + if svc := out.Services["service1"]; !reflect.DeepEqual(&expect1, svc) { + t.Fatalf("bad: %#v", svc) + } + + // Index tables were updated + if idx := s.maxIndex("services"); idx != 40 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_Services(t *testing.T) { + s := testStateStore(t) + + // Register several nodes and services. + testRegisterNode(t, s, 1, "node1") + ns1 := &structs.NodeService{ + ID: "service1", + Service: "redis", + Tags: []string{"prod", "master"}, + Address: "1.1.1.1", + Port: 1111, + } + if err := s.EnsureService(2, "node1", ns1); err != nil { + t.Fatalf("err: %s", err) + } + testRegisterService(t, s, 3, "node1", "dogs") + testRegisterNode(t, s, 4, "node2") + ns2 := &structs.NodeService{ + ID: "service3", + Service: "redis", + Tags: []string{"prod", "slave"}, + Address: "1.1.1.1", + Port: 1111, + } + if err := s.EnsureService(5, "node2", ns2); err != nil { + t.Fatalf("err: %s", err) + } + + // Pull all the services. + idx, services, err := s.Services() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 5 { + t.Fatalf("bad index: %d", idx) + } + + // Verify the result. We sort the lists since the order is + // non-deterministic (it's built using a map internally). + expected := structs.Services{ + "redis": []string{"prod", "master", "slave"}, + "dogs": []string{}, + } + sort.Strings(expected["redis"]) + for _, tags := range services { + sort.Strings(tags) + } + if !reflect.DeepEqual(expected, services) { + t.Fatalf("bad: %#v", services) + } +} + +// strContains checks if a list contains a string +func strContains(l []string, s string) bool { + for _, v := range l { + if v == s { + return true + } + } + return false +} + +func TestStateStore_ServiceNodes(t *testing.T) { + s := testStateStore(t) + + if err := s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(12, "foo", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(14, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(15, "bar", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}); err != nil { + t.Fatalf("err: %v", err) + } + + idx, nodes, err := s.ServiceNodes("db") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 16 { + t.Fatalf("bad: %v", idx) + } + if len(nodes) != 3 { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Node != "bar" { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Address != "127.0.0.2" { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].ServiceID != "db" { + t.Fatalf("bad: %v", nodes) + } + if !strContains(nodes[0].ServiceTags, "slave") { + t.Fatalf("bad: 
%v", nodes) + } + if nodes[0].ServicePort != 8000 { + t.Fatalf("bad: %v", nodes) + } + + if nodes[1].Node != "bar" { + t.Fatalf("bad: %v", nodes) + } + if nodes[1].Address != "127.0.0.2" { + t.Fatalf("bad: %v", nodes) + } + if nodes[1].ServiceID != "db2" { + t.Fatalf("bad: %v", nodes) + } + if !strContains(nodes[1].ServiceTags, "slave") { + t.Fatalf("bad: %v", nodes) + } + if nodes[1].ServicePort != 8001 { + t.Fatalf("bad: %v", nodes) + } + + if nodes[2].Node != "foo" { + t.Fatalf("bad: %v", nodes) + } + if nodes[2].Address != "127.0.0.1" { + t.Fatalf("bad: %v", nodes) + } + if nodes[2].ServiceID != "db" { + t.Fatalf("bad: %v", nodes) + } + if !strContains(nodes[2].ServiceTags, "master") { + t.Fatalf("bad: %v", nodes) + } + if nodes[2].ServicePort != 8000 { + t.Fatalf("bad: %v", nodes) + } +} + +func TestStateStore_ServiceTagNodes(t *testing.T) { + s := testStateStore(t) + + if err := s.EnsureNode(15, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureNode(16, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(17, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(18, "foo", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(19, "bar", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + + idx, nodes, err := s.ServiceTagNodes("db", "master") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 19 { + t.Fatalf("bad: %v", idx) + } + if len(nodes) != 1 { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Node != "foo" { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Address != "127.0.0.1" { + t.Fatalf("bad: %v", nodes) + } + if !strContains(nodes[0].ServiceTags, "master") { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].ServicePort != 8000 { + t.Fatalf("bad: %v", nodes) + } +} + +func TestStateStore_ServiceTagNodes_MultipleTags(t *testing.T) { + s := testStateStore(t) + + if err := s.EnsureNode(15, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureNode(16, &structs.Node{Node: "bar", Address: "127.0.0.2"}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(17, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"master", "v2"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(18, "foo", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave", "v2", "dev"}, Address: "", Port: 8001}); err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.EnsureService(19, "bar", &structs.NodeService{ID: "db", Service: "db", Tags: []string{"slave", "v2"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + + idx, nodes, err := s.ServiceTagNodes("db", "master") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 19 { + t.Fatalf("bad: %v", idx) + } + if len(nodes) != 1 { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Node != "foo" { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Address != "127.0.0.1" { + t.Fatalf("bad: %v", nodes) + } + if !strContains(nodes[0].ServiceTags, "master") { + t.Fatalf("bad: %v", nodes) + } + if 
nodes[0].ServicePort != 8000 { + t.Fatalf("bad: %v", nodes) + } + + idx, nodes, err = s.ServiceTagNodes("db", "v2") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 19 { + t.Fatalf("bad: %v", idx) + } + if len(nodes) != 3 { + t.Fatalf("bad: %v", nodes) + } + + idx, nodes, err = s.ServiceTagNodes("db", "dev") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 19 { + t.Fatalf("bad: %v", idx) + } + if len(nodes) != 1 { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Node != "foo" { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].Address != "127.0.0.1" { + t.Fatalf("bad: %v", nodes) + } + if !strContains(nodes[0].ServiceTags, "dev") { + t.Fatalf("bad: %v", nodes) + } + if nodes[0].ServicePort != 8001 { + t.Fatalf("bad: %v", nodes) + } +} + +func TestStateStore_DeleteService(t *testing.T) { + s := testStateStore(t) + + // Register a node with one service and a check + testRegisterNode(t, s, 1, "node1") + testRegisterService(t, s, 2, "node1", "service1") + testRegisterCheck(t, s, 3, "node1", "service1", "check1", structs.HealthPassing) + + // Delete the service + if err := s.DeleteService(4, "node1", "service1"); err != nil { + t.Fatalf("err: %s", err) + } + + // Service doesn't exist. + _, ns, err := s.NodeServices("node1") + if err != nil || ns == nil || len(ns.Services) != 0 { + t.Fatalf("bad: %#v (err: %#v)", ns, err) + } + + // Check doesn't exist. Check using the raw DB so we can test + // that it actually is removed in the state store. + tx := s.db.Txn(false) + defer tx.Abort() + check, err := tx.First("checks", "id", "node1", "check1") + if err != nil || check != nil { + t.Fatalf("bad: %#v (err: %s)", check, err) + } + + // Index tables were updated + if idx := s.maxIndex("services"); idx != 4 { + t.Fatalf("bad index: %d", idx) + } + if idx := s.maxIndex("checks"); idx != 4 { + t.Fatalf("bad index: %d", idx) + } + + // Deleting a nonexistent service should be idempotent and not return an + // error + if err := s.DeleteService(5, "node1", "service1"); err != nil { + t.Fatalf("err: %s", err) + } + if idx := s.maxIndex("services"); idx != 4 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_Service_Snapshot(t *testing.T) { + s := testStateStore(t) + + // Register a node with two services. + testRegisterNode(t, s, 0, "node1") + ns := []*structs.NodeService{ + &structs.NodeService{ + ID: "service1", + Service: "redis", + Tags: []string{"prod"}, + Address: "1.1.1.1", + Port: 1111, + }, + &structs.NodeService{ + ID: "service2", + Service: "nomad", + Tags: []string{"dev"}, + Address: "1.1.1.2", + Port: 1112, + }, + } + for i, svc := range ns { + if err := s.EnsureService(uint64(i+1), "node1", svc); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Create a second node/service to make sure node filtering works. This + // will affect the index but not the dump. + testRegisterNode(t, s, 3, "node2") + testRegisterService(t, s, 4, "node2", "service2") + + // Snapshot the service. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + testRegisterService(t, s, 5, "node2", "service3") + + // Verify the snapshot. 
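+ // The snapshot should be a frozen view of the store: service3 was + // registered after the snapshot was taken, so it must not show up here + // and the last index should still be 4.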
+ if idx := snap.LastIndex(); idx != 4 { + t.Fatalf("bad index: %d", idx) + } + services, err := snap.Services("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + for i := 0; i < len(ns); i++ { + svc := services.Next().(*structs.ServiceNode) + if svc == nil { + t.Fatalf("unexpected end of services") + } + + ns[i].CreateIndex, ns[i].ModifyIndex = uint64(i+1), uint64(i+1) + if !reflect.DeepEqual(ns[i], svc.ToNodeService()) { + t.Fatalf("bad: %#v != %#v", svc, ns[i]) + } + } + if services.Next() != nil { + t.Fatalf("unexpected extra services") + } +} + +func TestStateStore_Service_Watches(t *testing.T) { + s := testStateStore(t) + + testRegisterNode(t, s, 0, "node1") + ns := &structs.NodeService{ + ID: "service2", + Service: "nomad", + Address: "1.1.1.2", + Port: 8000, + } + + // Call functions that update the services table and make sure a watch + // fires each time. + verifyWatch(t, s.getTableWatch("services"), func() { + if err := s.EnsureService(2, "node1", ns); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("services"), func() { + if err := s.DeleteService(3, "node1", "service2"); err != nil { + t.Fatalf("err: %s", err) + } + }) + + // Check that a delete of a service + check triggers both tables in one + // shot. + testRegisterService(t, s, 4, "node1", "service1") + testRegisterCheck(t, s, 5, "node1", "service1", "check3", structs.HealthPassing) + verifyWatch(t, s.getTableWatch("services"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { + if err := s.DeleteService(6, "node1", "service1"); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) +} + +func TestStateStore_EnsureCheck(t *testing.T) { + s := testStateStore(t) + + // Create a check associated with the node + check := &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Name: "redis check", + Status: structs.HealthPassing, + Notes: "test check", + Output: "aaa", + ServiceID: "service1", + ServiceName: "redis", + } + + // Creating a check without a node returns an error + if err := s.EnsureCheck(1, check); err != ErrMissingNode { + t.Fatalf("expected %#v, got: %#v", ErrMissingNode, err) + } + + // Register the node + testRegisterNode(t, s, 1, "node1") + + // Creating a check that references a missing service returns an error + if err := s.EnsureCheck(1, check); err != ErrMissingService { + t.Fatalf("expected: %#v, got: %#v", ErrMissingService, err) + } + + // Register the service + testRegisterService(t, s, 2, "node1", "service1") + + // Inserting the check with the prerequisites succeeds + if err := s.EnsureCheck(3, check); err != nil { + t.Fatalf("err: %s", err) + } + + // Retrieve the check and make sure it matches + idx, checks, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 3 { + t.Fatalf("bad index: %d", idx) + } + if len(checks) != 1 { + t.Fatalf("wrong number of checks: %d", len(checks)) + } + if !reflect.DeepEqual(checks[0], check) { + t.Fatalf("bad: %#v", checks[0]) + } + + // Modify the health check + check.Output = "bbb" + if err := s.EnsureCheck(4, check); err != nil { + t.Fatalf("err: %s", err) + } + + // Check that we successfully updated + idx, checks, err = s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + if len(checks) != 1 { + t.Fatalf("wrong number of checks: %d", len(checks)) + } + if checks[0].Output != "bbb" { + t.Fatalf("wrong check output: %#v", checks[0]) + } + if checks[0].CreateIndex != 3 || checks[0].ModifyIndex != 4 { + t.Fatalf("bad 
index: %#v", checks[0]) + } + + // Index tables were updated + if idx := s.maxIndex("checks"); idx != 4 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_EnsureCheck_defaultStatus(t *testing.T) { + s := testStateStore(t) + + // Register a node + testRegisterNode(t, s, 1, "node1") + + // Create and register a check with no health status + check := &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Status: "", + } + if err := s.EnsureCheck(2, check); err != nil { + t.Fatalf("err: %s", err) + } + + // Get the check again + _, result, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Check that the status was set to the proper default + if len(result) != 1 || result[0].Status != structs.HealthCritical { + t.Fatalf("bad: %#v", result) + } +} + +func TestStateStore_NodeChecks(t *testing.T) { + s := testStateStore(t) + + // Create the first node and service with some checks + testRegisterNode(t, s, 0, "node1") + testRegisterService(t, s, 1, "node1", "service1") + testRegisterCheck(t, s, 2, "node1", "service1", "check1", structs.HealthPassing) + testRegisterCheck(t, s, 3, "node1", "service1", "check2", structs.HealthPassing) + + // Create a second node/service with a different set of checks + testRegisterNode(t, s, 4, "node2") + testRegisterService(t, s, 5, "node2", "service2") + testRegisterCheck(t, s, 6, "node2", "service2", "check3", structs.HealthPassing) + + // Try querying for all checks associated with node1 + idx, checks, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + if len(checks) != 2 || checks[0].CheckID != "check1" || checks[1].CheckID != "check2" { + t.Fatalf("bad checks: %#v", checks) + } + + // Try querying for all checks associated with node2 + idx, checks, err = s.NodeChecks("node2") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + if len(checks) != 1 || checks[0].CheckID != "check3" { + t.Fatalf("bad checks: %#v", checks) + } +} + +func TestStateStore_ServiceChecks(t *testing.T) { + s := testStateStore(t) + + // Create the first node and service with some checks + testRegisterNode(t, s, 0, "node1") + testRegisterService(t, s, 1, "node1", "service1") + testRegisterCheck(t, s, 2, "node1", "service1", "check1", structs.HealthPassing) + testRegisterCheck(t, s, 3, "node1", "service1", "check2", structs.HealthPassing) + + // Create a second node/service with a different set of checks + testRegisterNode(t, s, 4, "node2") + testRegisterService(t, s, 5, "node2", "service2") + testRegisterCheck(t, s, 6, "node2", "service2", "check3", structs.HealthPassing) + + // Try querying for all checks associated with service1 + idx, checks, err := s.ServiceChecks("service1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + if len(checks) != 2 || checks[0].CheckID != "check1" || checks[1].CheckID != "check2" { + t.Fatalf("bad checks: %#v", checks) + } +} + +func TestStateStore_ChecksInState(t *testing.T) { + s := testStateStore(t) + + // Querying with no results returns nil + idx, res, err := s.ChecksInState(structs.HealthPassing) + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Register a node with checks in varied states + testRegisterNode(t, s, 0, "node1") + testRegisterCheck(t, s, 1, "node1", "", "check1", structs.HealthPassing) + testRegisterCheck(t, s, 2, "node1", 
"", "check2", structs.HealthCritical) + testRegisterCheck(t, s, 3, "node1", "", "check3", structs.HealthPassing) + + // Query the state store for passing checks. + _, checks, err := s.ChecksInState(structs.HealthPassing) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Make sure we only get the checks which match the state + if n := len(checks); n != 2 { + t.Fatalf("expected 2 checks, got: %d", n) + } + if checks[0].CheckID != "check1" || checks[1].CheckID != "check3" { + t.Fatalf("bad: %#v", checks) + } + + // HealthAny just returns everything. + _, checks, err = s.ChecksInState(structs.HealthAny) + if err != nil { + t.Fatalf("err: %s", err) + } + if n := len(checks); n != 3 { + t.Fatalf("expected 3 checks, got: %d", n) + } +} + +func TestStateStore_DeleteCheck(t *testing.T) { + s := testStateStore(t) + + // Register a node and a node-level health check + testRegisterNode(t, s, 1, "node1") + testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing) + + // Delete the check + if err := s.DeleteCheck(3, "node1", "check1"); err != nil { + t.Fatalf("err: %s", err) + } + + // Check is gone + _, checks, err := s.NodeChecks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(checks) != 0 { + t.Fatalf("bad: %#v", checks) + } + + // Index tables were updated + if idx := s.maxIndex("checks"); idx != 3 { + t.Fatalf("bad index: %d", idx) + } + + // Deleting a nonexistent check should be idempotent and not return an + // error + if err := s.DeleteCheck(4, "node1", "check1"); err != nil { + t.Fatalf("err: %s", err) + } + if idx := s.maxIndex("checks"); idx != 3 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_CheckServiceNodes(t *testing.T) { + s := testStateStore(t) + + // Querying with no matches gives an empty response + idx, res, err := s.CheckServiceNodes("service1") + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Register some nodes + testRegisterNode(t, s, 0, "node1") + testRegisterNode(t, s, 1, "node2") + + // Register node-level checks. These should not be returned + // in the final result. + testRegisterCheck(t, s, 2, "node1", "", "check1", structs.HealthPassing) + testRegisterCheck(t, s, 3, "node2", "", "check2", structs.HealthPassing) + + // Register a service against the nodes + testRegisterService(t, s, 4, "node1", "service1") + testRegisterService(t, s, 5, "node2", "service2") + + // Register checks against the services + testRegisterCheck(t, s, 6, "node1", "service1", "check3", structs.HealthPassing) + testRegisterCheck(t, s, 7, "node2", "service2", "check4", structs.HealthPassing) + + // Query the state store for nodes and checks which + // have been registered with a specific service. 
+ idx, results, err := s.CheckServiceNodes("service1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // Make sure we get the expected result (service check + node check) + if n := len(results); n != 1 { + t.Fatalf("expected 1 result, got: %d", n) + } + csn := results[0] + if csn.Node == nil || csn.Service == nil || len(csn.Checks) != 2 { + t.Fatalf("bad output: %#v", csn) + } + + // Node updates alter the returned index + testRegisterNode(t, s, 8, "node1") + idx, results, err = s.CheckServiceNodes("service1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 8 { + t.Fatalf("bad index: %d", idx) + } + + // Service updates alter the returned index + testRegisterService(t, s, 9, "node1", "service1") + idx, results, err = s.CheckServiceNodes("service1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 9 { + t.Fatalf("bad index: %d", idx) + } + + // Check updates alter the returned index + testRegisterCheck(t, s, 10, "node1", "service1", "check1", structs.HealthCritical) + idx, results, err = s.CheckServiceNodes("service1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 10 { + t.Fatalf("bad index: %d", idx) + } +} + +func BenchmarkCheckServiceNodes(b *testing.B) { + s, err := NewStateStore(nil) + if err != nil { + b.Fatalf("err: %s", err) + } + + if err := s.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + b.Fatalf("err: %v", err) + } + if err := s.EnsureService(2, "foo", &structs.NodeService{ID: "db1", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { + b.Fatalf("err: %v", err) + } + check := &structs.HealthCheck{ + Node: "foo", + CheckID: "db", + Name: "can connect", + Status: structs.HealthPassing, + ServiceID: "db1", + } + if err := s.EnsureCheck(3, check); err != nil { + b.Fatalf("err: %v", err) + } + check = &structs.HealthCheck{ + Node: "foo", + CheckID: "check1", + Name: "check1", + Status: structs.HealthPassing, + } + if err := s.EnsureCheck(4, check); err != nil { + b.Fatalf("err: %v", err) + } + + for i := 0; i < b.N; i++ { + s.CheckServiceNodes("db") + } +} + +func TestStateStore_CheckServiceTagNodes(t *testing.T) { + s := testStateStore(t) + + if err := s.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + if err := s.EnsureService(2, "foo", &structs.NodeService{ID: "db1", Service: "db", Tags: []string{"master"}, Address: "", Port: 8000}); err != nil { + t.Fatalf("err: %v", err) + } + check := &structs.HealthCheck{ + Node: "foo", + CheckID: "db", + Name: "can connect", + Status: structs.HealthPassing, + ServiceID: "db1", + } + if err := s.EnsureCheck(3, check); err != nil { + t.Fatalf("err: %v", err) + } + check = &structs.HealthCheck{ + Node: "foo", + CheckID: "check1", + Name: "another check", + Status: structs.HealthPassing, + } + if err := s.EnsureCheck(4, check); err != nil { + t.Fatalf("err: %v", err) + } + + idx, nodes, err := s.CheckServiceTagNodes("db", "master") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 4 { + t.Fatalf("bad: %v", idx) + } + if len(nodes) != 1 { + t.Fatalf("Bad: %v", nodes) + } + if nodes[0].Node.Node != "foo" { + t.Fatalf("Bad: %v", nodes[0]) + } + if nodes[0].Service.ID != "db1" { + t.Fatalf("Bad: %v", nodes[0]) + } + if len(nodes[0].Checks) != 2 { + t.Fatalf("Bad: %v", nodes[0]) + } + if nodes[0].Checks[0].CheckID != "check1" { + t.Fatalf("Bad: %v", nodes[0]) + } + if nodes[0].Checks[1].CheckID != "db" { + t.Fatalf("Bad: 
%v", nodes[0]) + } +} + +func TestStateStore_Check_Snapshot(t *testing.T) { + s := testStateStore(t) + + // Create a node, a service, and a service check as well as a node check. + testRegisterNode(t, s, 0, "node1") + testRegisterService(t, s, 1, "node1", "service1") + checks := structs.HealthChecks{ + &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Name: "node check", + Status: structs.HealthPassing, + }, + &structs.HealthCheck{ + Node: "node1", + CheckID: "check2", + Name: "service check", + Status: structs.HealthCritical, + ServiceID: "service1", + }, + } + for i, hc := range checks { + if err := s.EnsureCheck(uint64(i+1), hc); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Create a second node/service to make sure node filtering works. This + // will affect the index but not the dump. + testRegisterNode(t, s, 3, "node2") + testRegisterService(t, s, 4, "node2", "service2") + testRegisterCheck(t, s, 5, "node2", "service2", "check3", structs.HealthPassing) + + // Snapshot the checks. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + testRegisterCheck(t, s, 6, "node2", "service2", "check4", structs.HealthPassing) + + // Verify the snapshot. + if idx := snap.LastIndex(); idx != 5 { + t.Fatalf("bad index: %d", idx) + } + iter, err := snap.Checks("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + for i := 0; i < len(checks); i++ { + check := iter.Next().(*structs.HealthCheck) + if check == nil { + t.Fatalf("unexpected end of checks") + } + + checks[i].CreateIndex, checks[i].ModifyIndex = uint64(i+1), uint64(i+1) + if !reflect.DeepEqual(check, checks[i]) { + t.Fatalf("bad: %#v != %#v", check, checks[i]) + } + } + if iter.Next() != nil { + t.Fatalf("unexpected extra checks") + } +} + +func TestStateStore_Check_Watches(t *testing.T) { + s := testStateStore(t) + + testRegisterNode(t, s, 0, "node1") + hc := &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + Status: structs.HealthPassing, + } + + // Call functions that update the checks table and make sure a watch fires + // each time. 
+ verifyWatch(t, s.getTableWatch("checks"), func() { + if err := s.EnsureCheck(1, hc); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("checks"), func() { + hc.Status = structs.HealthCritical + if err := s.EnsureCheck(2, hc); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("checks"), func() { + if err := s.DeleteCheck(3, "node1", "check1"); err != nil { + t.Fatalf("err: %s", err) + } + }) +} + +func TestStateStore_NodeInfo_NodeDump(t *testing.T) { + s := testStateStore(t) + + // Generating a node dump that matches nothing returns empty + idx, dump, err := s.NodeInfo("node1") + if idx != 0 || dump != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, dump, err) + } + idx, dump, err = s.NodeDump() + if idx != 0 || dump != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, dump, err) + } + + // Register some nodes + testRegisterNode(t, s, 0, "node1") + testRegisterNode(t, s, 1, "node2") + + // Register services against them + testRegisterService(t, s, 2, "node1", "service1") + testRegisterService(t, s, 3, "node1", "service2") + testRegisterService(t, s, 4, "node2", "service1") + testRegisterService(t, s, 5, "node2", "service2") + + // Register service-level checks + testRegisterCheck(t, s, 6, "node1", "service1", "check1", structs.HealthPassing) + testRegisterCheck(t, s, 7, "node2", "service1", "check1", structs.HealthPassing) + + // Register node-level checks + testRegisterCheck(t, s, 8, "node1", "", "check2", structs.HealthPassing) + testRegisterCheck(t, s, 9, "node2", "", "check2", structs.HealthPassing) + + // Check that our result matches what we expect. + expect := structs.NodeDump{ + &structs.NodeInfo{ + Node: "node1", + Checks: structs.HealthChecks{ + &structs.HealthCheck{ + Node: "node1", + CheckID: "check1", + ServiceID: "service1", + ServiceName: "service1", + Status: structs.HealthPassing, + RaftIndex: structs.RaftIndex{ + CreateIndex: 6, + ModifyIndex: 6, + }, + }, + &structs.HealthCheck{ + Node: "node1", + CheckID: "check2", + ServiceID: "", + ServiceName: "", + Status: structs.HealthPassing, + RaftIndex: structs.RaftIndex{ + CreateIndex: 8, + ModifyIndex: 8, + }, + }, + }, + Services: []*structs.NodeService{ + &structs.NodeService{ + ID: "service1", + Service: "service1", + Address: "1.1.1.1", + Port: 1111, + RaftIndex: structs.RaftIndex{ + CreateIndex: 2, + ModifyIndex: 2, + }, + }, + &structs.NodeService{ + ID: "service2", + Service: "service2", + Address: "1.1.1.1", + Port: 1111, + RaftIndex: structs.RaftIndex{ + CreateIndex: 3, + ModifyIndex: 3, + }, + }, + }, + }, + &structs.NodeInfo{ + Node: "node2", + Checks: structs.HealthChecks{ + &structs.HealthCheck{ + Node: "node2", + CheckID: "check1", + ServiceID: "service1", + ServiceName: "service1", + Status: structs.HealthPassing, + RaftIndex: structs.RaftIndex{ + CreateIndex: 7, + ModifyIndex: 7, + }, + }, + &structs.HealthCheck{ + Node: "node2", + CheckID: "check2", + ServiceID: "", + ServiceName: "", + Status: structs.HealthPassing, + RaftIndex: structs.RaftIndex{ + CreateIndex: 9, + ModifyIndex: 9, + }, + }, + }, + Services: []*structs.NodeService{ + &structs.NodeService{ + ID: "service1", + Service: "service1", + Address: "1.1.1.1", + Port: 1111, + RaftIndex: structs.RaftIndex{ + CreateIndex: 4, + ModifyIndex: 4, + }, + }, + &structs.NodeService{ + ID: "service2", + Service: "service2", + Address: "1.1.1.1", + Port: 1111, + RaftIndex: structs.RaftIndex{ + CreateIndex: 5, + ModifyIndex: 5, + 
}, + }, + }, + }, + } + + // Get a dump of just a single node + idx, dump, err = s.NodeInfo("node1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 9 { + t.Fatalf("bad index: %d", idx) + } + if len(dump) != 1 || !reflect.DeepEqual(dump[0], expect[0]) { + t.Fatalf("bad: %#v", dump) + } + + // Generate a dump of all the nodes + idx, dump, err = s.NodeDump() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 9 { + t.Fatalf("bad index: %d", idx) + } + if !reflect.DeepEqual(dump, expect) { + t.Fatalf("bad: %#v", dump[0].Services[0]) + } +} + +func TestStateStore_KVSSet_KVSGet(t *testing.T) { + s := testStateStore(t) + + // Get on a nonexistent key returns nil. + idx, result, err := s.KVSGet("foo") + if result != nil || err != nil || idx != 0 { + t.Fatalf("expected (0, nil, nil), got: (%#v, %#v, %#v)", idx, result, err) + } + + // Write a new K/V entry to the store. + entry := &structs.DirEntry{ + Key: "foo", + Value: []byte("bar"), + } + if err := s.KVSSet(1, entry); err != nil { + t.Fatalf("err: %s", err) + } + + // Retrieve the K/V entry again. + idx, result, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if result == nil { + t.Fatalf("expected k/v pair, got nothing") + } + if idx != 1 { + t.Fatalf("bad index: %d", idx) + } + + // Check that the index was injected into the result. + if result.CreateIndex != 1 || result.ModifyIndex != 1 { + t.Fatalf("bad index: %d, %d", result.CreateIndex, result.ModifyIndex) + } + + // Check that the value matches. + if v := string(result.Value); v != "bar" { + t.Fatalf("expected 'bar', got: '%s'", v) + } + + // Updating the entry works and changes the index. + update := &structs.DirEntry{ + Key: "foo", + Value: []byte("baz"), + } + if err := s.KVSSet(2, update); err != nil { + t.Fatalf("err: %s", err) + } + + // Fetch the kv pair and check. + idx, result, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.CreateIndex != 1 || result.ModifyIndex != 2 { + t.Fatalf("bad index: %d, %d", result.CreateIndex, result.ModifyIndex) + } + if v := string(result.Value); v != "baz" { + t.Fatalf("expected 'baz', got '%s'", v) + } + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + // Attempt to set the session during an update. + update = &structs.DirEntry{ + Key: "foo", + Value: []byte("zoo"), + Session: "nope", + } + if err := s.KVSSet(3, update); err != nil { + t.Fatalf("err: %s", err) + } + + // Fetch the kv pair and check. + idx, result, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.CreateIndex != 1 || result.ModifyIndex != 3 { + t.Fatalf("bad index: %d, %d", result.CreateIndex, result.ModifyIndex) + } + if v := string(result.Value); v != "zoo" { + t.Fatalf("expected 'zoo', got '%s'", v) + } + if result.Session != "" { + t.Fatalf("expected empty session, got '%s'", result.Session) + } + if idx != 3 { + t.Fatalf("bad index: %d", idx) + } + + // Make a real session and then lock the key to set the session. + testRegisterNode(t, s, 4, "node1") + session := testUUID() + if err := s.SessionCreate(5, &structs.Session{ID: session, Node: "node1"}); err != nil { + t.Fatalf("err: %s", err) + } + update = &structs.DirEntry{ + Key: "foo", + Value: []byte("locked"), + Session: session, + } + ok, err := s.KVSLock(6, update) + if !ok || err != nil { + t.Fatalf("didn't get the lock: %v %s", ok, err) + } + + // Fetch the kv pair and check. 
+ idx, result, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.CreateIndex != 1 || result.ModifyIndex != 6 { + t.Fatalf("bad index: %d, %d", result.CreateIndex, result.ModifyIndex) + } + if v := string(result.Value); v != "locked" { + t.Fatalf("expected 'zoo', got '%s'", v) + } + if result.Session != session { + t.Fatalf("expected session, got '%s", result.Session) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Now make an update without the session and make sure it gets applied + // and doesn't take away the session (it is allowed to change the value). + update = &structs.DirEntry{ + Key: "foo", + Value: []byte("stoleit"), + } + if err := s.KVSSet(7, update); err != nil { + t.Fatalf("err: %s", err) + } + + // Fetch the kv pair and check. + idx, result, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.CreateIndex != 1 || result.ModifyIndex != 7 { + t.Fatalf("bad index: %d, %d", result.CreateIndex, result.ModifyIndex) + } + if v := string(result.Value); v != "stoleit" { + t.Fatalf("expected 'zoo', got '%s'", v) + } + if result.Session != session { + t.Fatalf("expected session, got '%s", result.Session) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // Fetch a key that doesn't exist and make sure we get the right + // response. + idx, result, err = s.KVSGet("nope") + if result != nil || err != nil || idx != 7 { + t.Fatalf("expected (7, nil, nil), got : (%#v, %#v, %#v)", idx, result, err) + } +} + +func TestStateStore_KVSList(t *testing.T) { + s := testStateStore(t) + + // Listing an empty KVS returns nothing + idx, entries, err := s.KVSList("") + if idx != 0 || entries != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, entries, err) + } + + // Create some KVS entries + testSetKey(t, s, 1, "foo", "foo") + testSetKey(t, s, 2, "foo/bar", "bar") + testSetKey(t, s, 3, "foo/bar/zip", "zip") + testSetKey(t, s, 4, "foo/bar/zip/zorp", "zorp") + testSetKey(t, s, 5, "foo/bar/baz", "baz") + + // List out all of the keys + idx, entries, err = s.KVSList("") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 5 { + t.Fatalf("bad index: %d", idx) + } + + // Check that all of the keys were returned + if n := len(entries); n != 5 { + t.Fatalf("expected 5 kvs entries, got: %d", n) + } + + // Try listing with a provided prefix + idx, entries, err = s.KVSList("foo/bar/zip") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + + // Check that only the keys in the prefix were returned + if n := len(entries); n != 2 { + t.Fatalf("expected 2 kvs entries, got: %d", n) + } + if entries[0].Key != "foo/bar/zip" || entries[1].Key != "foo/bar/zip/zorp" { + t.Fatalf("bad: %#v", entries) + } + + // Delete a key and make sure the index comes from the tombstone. + if err := s.KVSDelete(6, "foo/bar/baz"); err != nil { + t.Fatalf("err: %s", err) + } + idx, _, err = s.KVSList("foo/bar/baz") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Set a different key to bump the index. + testSetKey(t, s, 7, "some/other/key", "") + + // Make sure we get the right index from the tombstone. + idx, _, err = s.KVSList("foo/bar/baz") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Now reap the tombstones and make sure we get the latest index + // since there are no matching keys. 
+ if err := s.ReapTombstones(6); err != nil { + t.Fatalf("err: %s", err) + } + idx, _, err = s.KVSList("foo/bar/baz") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // List all the keys to make sure the index is also correct. + idx, _, err = s.KVSList("") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_KVSListKeys(t *testing.T) { + s := testStateStore(t) + + // Listing keys with no results returns nil. + idx, keys, err := s.KVSListKeys("", "") + if idx != 0 || keys != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, keys, err) + } + + // Create some keys. + testSetKey(t, s, 1, "foo", "foo") + testSetKey(t, s, 2, "foo/bar", "bar") + testSetKey(t, s, 3, "foo/bar/baz", "baz") + testSetKey(t, s, 4, "foo/bar/zip", "zip") + testSetKey(t, s, 5, "foo/bar/zip/zam", "zam") + testSetKey(t, s, 6, "foo/bar/zip/zorp", "zorp") + testSetKey(t, s, 7, "some/other/prefix", "nack") + + // List all the keys. + idx, keys, err = s.KVSListKeys("", "") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(keys) != 7 { + t.Fatalf("bad keys: %#v", keys) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // Query using a prefix and pass a separator. + idx, keys, err = s.KVSListKeys("foo/bar/", "/") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(keys) != 3 { + t.Fatalf("bad keys: %#v", keys) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Only the subset of keys under the prefix was returned. + expect := []string{"foo/bar/baz", "foo/bar/zip", "foo/bar/zip/"} + if !reflect.DeepEqual(keys, expect) { + t.Fatalf("bad keys: %#v", keys) + } + + // Listing keys with no separator returns everything. + idx, keys, err = s.KVSListKeys("foo", "") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + expect = []string{"foo", "foo/bar", "foo/bar/baz", "foo/bar/zip", + "foo/bar/zip/zam", "foo/bar/zip/zorp"} + if !reflect.DeepEqual(keys, expect) { + t.Fatalf("bad keys: %#v", keys) + } + + // Delete a key and make sure the index comes from the tombstone. + if err := s.KVSDelete(8, "foo/bar/baz"); err != nil { + t.Fatalf("err: %s", err) + } + idx, _, err = s.KVSListKeys("foo/bar/baz", "") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 8 { + t.Fatalf("bad index: %d", idx) + } + + // Set a different key to bump the index. + testSetKey(t, s, 9, "some/other/key", "") + + // Make sure the index still comes from the tombstone. + idx, _, err = s.KVSListKeys("foo/bar/baz", "") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 8 { + t.Fatalf("bad index: %d", idx) + } + + // Now reap the tombstones and make sure we get the latest index + // since there are no matching keys. + if err := s.ReapTombstones(8); err != nil { + t.Fatalf("err: %s", err) + } + idx, _, err = s.KVSListKeys("foo/bar/baz", "") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 9 { + t.Fatalf("bad index: %d", idx) + } + + // List all the keys to make sure the index is also correct. 
+ idx, _, err = s.KVSListKeys("", "") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 9 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_KVSDelete(t *testing.T) { + s := testStateStore(t) + + // Create some KV pairs + testSetKey(t, s, 1, "foo", "foo") + testSetKey(t, s, 2, "foo/bar", "bar") + + // Call a delete on a specific key + if err := s.KVSDelete(3, "foo"); err != nil { + t.Fatalf("err: %s", err) + } + + // The entry was removed from the state store + tx := s.db.Txn(false) + defer tx.Abort() + e, err := tx.First("kvs", "id", "foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if e != nil { + t.Fatalf("expected kvs entry to be deleted, got: %#v", e) + } + + // Try fetching the other keys to ensure they still exist + e, err = tx.First("kvs", "id", "foo/bar") + if err != nil { + t.Fatalf("err: %s", err) + } + if e == nil || string(e.(*structs.DirEntry).Value) != "bar" { + t.Fatalf("bad kvs entry: %#v", e) + } + + // Check that the index table was updated + if idx := s.maxIndex("kvs"); idx != 3 { + t.Fatalf("bad index: %d", idx) + } + + // Check that the tombstone was created and that prevents the index + // from sliding backwards. + idx, _, err := s.KVSList("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 3 { + t.Fatalf("bad index: %d", idx) + } + + // Now reap the tombstone and watch the index revert to the remaining + // foo/bar key's index. + if err := s.ReapTombstones(3); err != nil { + t.Fatalf("err: %s", err) + } + idx, _, err = s.KVSList("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + // Deleting a nonexistent key should be idempotent and not return an + // error + if err := s.KVSDelete(4, "foo"); err != nil { + t.Fatalf("err: %s", err) + } + if idx := s.maxIndex("kvs"); idx != 3 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_KVSDeleteCAS(t *testing.T) { + s := testStateStore(t) + + // Create some KV entries + testSetKey(t, s, 1, "foo", "foo") + testSetKey(t, s, 2, "bar", "bar") + testSetKey(t, s, 3, "baz", "baz") + + // Do a CAS delete with an index lower than the entry + ok, err := s.KVSDeleteCAS(4, 1, "bar") + if ok || err != nil { + t.Fatalf("expected (false, nil), got: (%v, %#v)", ok, err) + } + + // Check that the index is untouched and the entry + // has not been deleted. + idx, e, err := s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if e == nil { + t.Fatalf("expected a kvs entry, got nil") + } + if idx != 3 { + t.Fatalf("bad index: %d", idx) + } + + // Do another CAS delete, this time with the correct index + // which should cause the delete to take place. + ok, err = s.KVSDeleteCAS(4, 2, "bar") + if !ok || err != nil { + t.Fatalf("expected (true, nil), got: (%v, %#v)", ok, err) + } + + // Entry was deleted and index was updated + idx, e, err = s.KVSGet("bar") + if err != nil { + t.Fatalf("err: %s", err) + } + if e != nil { + t.Fatalf("entry should be deleted") + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + + // Add another key to bump the index. + testSetKey(t, s, 5, "some/other/key", "baz") + + // Check that the tombstone was created and that prevents the index + // from sliding backwards. + idx, _, err = s.KVSList("bar") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + + // Now reap the tombstone and watch the index move up to the table + // index since there are no matching keys. 
+ if err := s.ReapTombstones(4); err != nil { + t.Fatalf("err: %s", err) + } + idx, _, err = s.KVSList("bar") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 5 { + t.Fatalf("bad index: %d", idx) + } + + // A delete on a nonexistent key should be idempotent and not return an + // error + ok, err = s.KVSDeleteCAS(6, 2, "bar") + if !ok || err != nil { + t.Fatalf("expected (true, nil), got: (%v, %#v)", ok, err) + } + if idx := s.maxIndex("kvs"); idx != 5 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_KVSSetCAS(t *testing.T) { + s := testStateStore(t) + + // Doing a CAS with ModifyIndex != 0 and no existing entry + // is a no-op. + entry := &structs.DirEntry{ + Key: "foo", + Value: []byte("foo"), + RaftIndex: structs.RaftIndex{ + CreateIndex: 1, + ModifyIndex: 1, + }, + } + ok, err := s.KVSSetCAS(2, entry) + if ok || err != nil { + t.Fatalf("expected (false, nil), got: (%#v, %#v)", ok, err) + } + + // Check that nothing was actually stored + tx := s.db.Txn(false) + if e, err := tx.First("kvs", "id", "foo"); e != nil || err != nil { + t.Fatalf("expected (nil, nil), got: (%#v, %#v)", e, err) + } + tx.Abort() + + // Index was not updated + if idx := s.maxIndex("kvs"); idx != 0 { + t.Fatalf("bad index: %d", idx) + } + + // Doing a CAS with a ModifyIndex of zero when no entry exists + // performs the set and saves into the state store. + entry = &structs.DirEntry{ + Key: "foo", + Value: []byte("foo"), + RaftIndex: structs.RaftIndex{ + CreateIndex: 0, + ModifyIndex: 0, + }, + } + ok, err = s.KVSSetCAS(2, entry) + if !ok || err != nil { + t.Fatalf("expected (true, nil), got: (%#v, %#v)", ok, err) + } + + // Entry was inserted + idx, entry, err := s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if string(entry.Value) != "foo" || entry.CreateIndex != 2 || entry.ModifyIndex != 2 { + t.Fatalf("bad entry: %#v", entry) + } + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + // Doing a CAS with a ModifyIndex of zero when an entry exists does + // not do anything. + entry = &structs.DirEntry{ + Key: "foo", + Value: []byte("foo"), + RaftIndex: structs.RaftIndex{ + CreateIndex: 0, + ModifyIndex: 0, + }, + } + ok, err = s.KVSSetCAS(3, entry) + if ok || err != nil { + t.Fatalf("expected (false, nil), got: (%#v, %#v)", ok, err) + } + + // Doing a CAS with a ModifyIndex which does not match the current + // index does not do anything. + entry = &structs.DirEntry{ + Key: "foo", + Value: []byte("bar"), + RaftIndex: structs.RaftIndex{ + CreateIndex: 3, + ModifyIndex: 3, + }, + } + ok, err = s.KVSSetCAS(3, entry) + if ok || err != nil { + t.Fatalf("expected (false, nil), got: (%#v, %#v)", ok, err) + } + + // Entry was not updated in the store + idx, entry, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if string(entry.Value) != "foo" || entry.CreateIndex != 2 || entry.ModifyIndex != 2 { + t.Fatalf("bad entry: %#v", entry) + } + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + // Doing a CAS with the proper current index should make the + // modification. 
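+ // The stored entry still has ModifyIndex 2 from its creation, so a CAS + // carrying ModifyIndex 2 matches and the write is applied at index 3.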
+ entry = &structs.DirEntry{ + Key: "foo", + Value: []byte("bar"), + RaftIndex: structs.RaftIndex{ + CreateIndex: 2, + ModifyIndex: 2, + }, + } + ok, err = s.KVSSetCAS(3, entry) + if !ok || err != nil { + t.Fatalf("expected (true, nil), got: (%#v, %#v)", ok, err) + } + + // Entry was updated + idx, entry, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if string(entry.Value) != "bar" || entry.CreateIndex != 2 || entry.ModifyIndex != 3 { + t.Fatalf("bad entry: %#v", entry) + } + if idx != 3 { + t.Fatalf("bad index: %d", idx) + } + + // Attempt to update the session during the CAS. + entry = &structs.DirEntry{ + Key: "foo", + Value: []byte("zoo"), + Session: "nope", + RaftIndex: structs.RaftIndex{ + CreateIndex: 2, + ModifyIndex: 3, + }, + } + ok, err = s.KVSSetCAS(4, entry) + if !ok || err != nil { + t.Fatalf("expected (true, nil), got: (%#v, %#v)", ok, err) + } + + // Entry was updated, but the session should have been ignored. + idx, entry, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if string(entry.Value) != "zoo" || entry.CreateIndex != 2 || entry.ModifyIndex != 4 || + entry.Session != "" { + t.Fatalf("bad entry: %#v", entry) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + + // Now lock it and try the update, which should keep the session. + testRegisterNode(t, s, 5, "node1") + session := testUUID() + if err := s.SessionCreate(6, &structs.Session{ID: session, Node: "node1"}); err != nil { + t.Fatalf("err: %s", err) + } + entry = &structs.DirEntry{ + Key: "foo", + Value: []byte("locked"), + Session: session, + RaftIndex: structs.RaftIndex{ + CreateIndex: 2, + ModifyIndex: 4, + }, + } + ok, err = s.KVSLock(6, entry) + if !ok || err != nil { + t.Fatalf("didn't get the lock: %v %s", ok, err) + } + entry = &structs.DirEntry{ + Key: "foo", + Value: []byte("locked"), + RaftIndex: structs.RaftIndex{ + CreateIndex: 2, + ModifyIndex: 6, + }, + } + ok, err = s.KVSSetCAS(7, entry) + if !ok || err != nil { + t.Fatalf("expected (true, nil), got: (%#v, %#v)", ok, err) + } + + // Entry was updated, and the lock status should have stayed the same. + idx, entry, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if string(entry.Value) != "locked" || entry.CreateIndex != 2 || entry.ModifyIndex != 7 || + entry.Session != session { + t.Fatalf("bad entry: %#v", entry) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_KVSDeleteTree(t *testing.T) { + s := testStateStore(t) + + // Create kvs entries in the state store + testSetKey(t, s, 1, "foo/bar", "bar") + testSetKey(t, s, 2, "foo/bar/baz", "baz") + testSetKey(t, s, 3, "foo/bar/zip", "zip") + testSetKey(t, s, 4, "foo/zorp", "zorp") + + // Calling tree deletion which affects nothing does not + // modify the table index. + if err := s.KVSDeleteTree(9, "bar"); err != nil { + t.Fatalf("err: %s", err) + } + if idx := s.maxIndex("kvs"); idx != 4 { + t.Fatalf("bad index: %d", idx) + } + + // Call tree deletion with a nested prefix. 
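+ // This should remove foo/bar and everything under it, leaving foo/zorp + // as the only surviving entry.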
+ if err := s.KVSDeleteTree(5, "foo/bar"); err != nil { + t.Fatalf("err: %s", err) + } + + // Check that all the matching keys were deleted + tx := s.db.Txn(false) + defer tx.Abort() + + entries, err := tx.Get("kvs", "id") + if err != nil { + t.Fatalf("err: %s", err) + } + + num := 0 + for entry := entries.Next(); entry != nil; entry = entries.Next() { + if entry.(*structs.DirEntry).Key != "foo/zorp" { + t.Fatalf("unexpected kvs entry: %#v", entry) + } + num++ + } + + if num != 1 { + t.Fatalf("expected 1 key, got: %d", num) + } + + // Index should be updated if modifications are made + if idx := s.maxIndex("kvs"); idx != 5 { + t.Fatalf("bad index: %d", idx) + } + + // Check that the tombstones ware created and that prevents the index + // from sliding backwards. + idx, _, err := s.KVSList("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 5 { + t.Fatalf("bad index: %d", idx) + } + + // Now reap the tombstones and watch the index revert to the remaining + // foo/zorp key's index. + if err := s.ReapTombstones(5); err != nil { + t.Fatalf("err: %s", err) + } + idx, _, err = s.KVSList("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_KVSLockDelay(t *testing.T) { + s := testStateStore(t) + + // KVSLockDelay is exercised in the lock/unlock and session invalidation + // cases below, so we just do a basic check on a nonexistent key here. + expires := s.KVSLockDelay("/not/there") + if expires.After(time.Now()) { + t.Fatalf("bad: %v", expires) + } +} + +func TestStateStore_KVSLock(t *testing.T) { + s := testStateStore(t) + + // Lock with no session should fail. + ok, err := s.KVSLock(0, &structs.DirEntry{Key: "foo", Value: []byte("foo")}) + if ok || err == nil || !strings.Contains(err.Error(), "missing session") { + t.Fatalf("didn't detect missing session: %v %s", ok, err) + } + + // Now try with a bogus session. + ok, err = s.KVSLock(1, &structs.DirEntry{Key: "foo", Value: []byte("foo"), Session: testUUID()}) + if ok || err == nil || !strings.Contains(err.Error(), "invalid session") { + t.Fatalf("didn't detect invalid session: %v %s", ok, err) + } + + // Make a real session. + testRegisterNode(t, s, 2, "node1") + session1 := testUUID() + if err := s.SessionCreate(3, &structs.Session{ID: session1, Node: "node1"}); err != nil { + t.Fatalf("err: %s", err) + } + + // Lock and make the key at the same time. + ok, err = s.KVSLock(4, &structs.DirEntry{Key: "foo", Value: []byte("foo"), Session: session1}) + if !ok || err != nil { + t.Fatalf("didn't get the lock: %v %s", ok, err) + } + + // Make sure the indexes got set properly. + idx, result, err := s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.LockIndex != 1 || result.CreateIndex != 4 || result.ModifyIndex != 4 || + string(result.Value) != "foo" { + t.Fatalf("bad entry: %#v", result) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + + // Re-locking with the same session should update the value and report + // success. + ok, err = s.KVSLock(5, &structs.DirEntry{Key: "foo", Value: []byte("bar"), Session: session1}) + if !ok || err != nil { + t.Fatalf("didn't handle locking an already-locked key: %v %s", ok, err) + } + + // Make sure the indexes got set properly, note that the lock index + // won't go up since we didn't lock it again. 
+ idx, result, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.LockIndex != 1 || result.CreateIndex != 4 || result.ModifyIndex != 5 || + string(result.Value) != "bar" { + t.Fatalf("bad entry: %#v", result) + } + if idx != 5 { + t.Fatalf("bad index: %d", idx) + } + + // Unlock and the re-lock. + ok, err = s.KVSUnlock(6, &structs.DirEntry{Key: "foo", Value: []byte("baz"), Session: session1}) + if !ok || err != nil { + t.Fatalf("didn't handle unlocking a locked key: %v %s", ok, err) + } + ok, err = s.KVSLock(7, &structs.DirEntry{Key: "foo", Value: []byte("zoo"), Session: session1}) + if !ok || err != nil { + t.Fatalf("didn't get the lock: %v %s", ok, err) + } + + // Make sure the indexes got set properly. + idx, result, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.LockIndex != 2 || result.CreateIndex != 4 || result.ModifyIndex != 7 || + string(result.Value) != "zoo" { + t.Fatalf("bad entry: %#v", result) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // Lock an existing key. + testSetKey(t, s, 8, "bar", "bar") + ok, err = s.KVSLock(9, &structs.DirEntry{Key: "bar", Value: []byte("xxx"), Session: session1}) + if !ok || err != nil { + t.Fatalf("didn't get the lock: %v %s", ok, err) + } + + // Make sure the indexes got set properly. + idx, result, err = s.KVSGet("bar") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.LockIndex != 1 || result.CreateIndex != 8 || result.ModifyIndex != 9 || + string(result.Value) != "xxx" { + t.Fatalf("bad entry: %#v", result) + } + if idx != 9 { + t.Fatalf("bad index: %d", idx) + } + + // Attempting a re-lock with a different session should also fail. + session2 := testUUID() + if err := s.SessionCreate(10, &structs.Session{ID: session2, Node: "node1"}); err != nil { + t.Fatalf("err: %s", err) + } + + // Re-locking should not return an error, but will report that it didn't + // get the lock. + ok, err = s.KVSLock(11, &structs.DirEntry{Key: "bar", Value: []byte("nope"), Session: session2}) + if ok || err != nil { + t.Fatalf("didn't handle locking an already-locked key: %v %s", ok, err) + } + + // Make sure the indexes didn't update. + idx, result, err = s.KVSGet("bar") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.LockIndex != 1 || result.CreateIndex != 8 || result.ModifyIndex != 9 || + string(result.Value) != "xxx" { + t.Fatalf("bad entry: %#v", result) + } + if idx != 9 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_KVSUnlock(t *testing.T) { + s := testStateStore(t) + + // Unlock with no session should fail. + ok, err := s.KVSUnlock(0, &structs.DirEntry{Key: "foo", Value: []byte("bar")}) + if ok || err == nil || !strings.Contains(err.Error(), "missing session") { + t.Fatalf("didn't detect missing session: %v %s", ok, err) + } + + // Make a real session. + testRegisterNode(t, s, 1, "node1") + session1 := testUUID() + if err := s.SessionCreate(2, &structs.Session{ID: session1, Node: "node1"}); err != nil { + t.Fatalf("err: %s", err) + } + + // Unlock with a real session but no key should not return an error, but + // will report it didn't unlock anything. + ok, err = s.KVSUnlock(3, &structs.DirEntry{Key: "foo", Value: []byte("bar"), Session: session1}) + if ok || err != nil { + t.Fatalf("didn't handle unlocking a missing key: %v %s", ok, err) + } + + // Make a key and unlock it, without it being locked. 
+ testSetKey(t, s, 4, "foo", "bar") + ok, err = s.KVSUnlock(5, &structs.DirEntry{Key: "foo", Value: []byte("baz"), Session: session1}) + if ok || err != nil { + t.Fatalf("didn't handle unlocking a non-locked key: %v %s", ok, err) + } + + // Make sure the indexes didn't update. + idx, result, err := s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.LockIndex != 0 || result.CreateIndex != 4 || result.ModifyIndex != 4 || + string(result.Value) != "bar" { + t.Fatalf("bad entry: %#v", result) + } + if idx != 4 { + t.Fatalf("bad index: %d", idx) + } + + // Lock it with the first session. + ok, err = s.KVSLock(6, &structs.DirEntry{Key: "foo", Value: []byte("bar"), Session: session1}) + if !ok || err != nil { + t.Fatalf("didn't get the lock: %v %s", ok, err) + } + + // Attempt an unlock with another session. + session2 := testUUID() + if err := s.SessionCreate(7, &structs.Session{ID: session2, Node: "node1"}); err != nil { + t.Fatalf("err: %s", err) + } + ok, err = s.KVSUnlock(8, &structs.DirEntry{Key: "foo", Value: []byte("zoo"), Session: session2}) + if ok || err != nil { + t.Fatalf("didn't handle unlocking with the wrong session: %v %s", ok, err) + } + + // Make sure the indexes didn't update. + idx, result, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.LockIndex != 1 || result.CreateIndex != 4 || result.ModifyIndex != 6 || + string(result.Value) != "bar" { + t.Fatalf("bad entry: %#v", result) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Now do the unlock with the correct session. + ok, err = s.KVSUnlock(9, &structs.DirEntry{Key: "foo", Value: []byte("zoo"), Session: session1}) + if !ok || err != nil { + t.Fatalf("didn't handle unlocking with the correct session: %v %s", ok, err) + } + + // Make sure the indexes got set properly. + idx, result, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.LockIndex != 1 || result.CreateIndex != 4 || result.ModifyIndex != 9 || + string(result.Value) != "zoo" { + t.Fatalf("bad entry: %#v", result) + } + if idx != 9 { + t.Fatalf("bad index: %d", idx) + } + + // Unlocking again should fail and not change anything. + ok, err = s.KVSUnlock(10, &structs.DirEntry{Key: "foo", Value: []byte("nope"), Session: session1}) + if ok || err != nil { + t.Fatalf("didn't handle unlocking with the previous session: %v %s", ok, err) + } + + // Make sure the indexes didn't update. + idx, result, err = s.KVSGet("foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if result.LockIndex != 1 || result.CreateIndex != 4 || result.ModifyIndex != 9 || + string(result.Value) != "zoo" { + t.Fatalf("bad entry: %#v", result) + } + if idx != 9 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_KVS_Snapshot_Restore(t *testing.T) { + s := testStateStore(t) + + // Build up some entries to seed. + entries := structs.DirEntries{ + &structs.DirEntry{ + Key: "aaa", + Flags: 23, + Value: []byte("hello"), + }, + &structs.DirEntry{ + Key: "bar/a", + Value: []byte("one"), + }, + &structs.DirEntry{ + Key: "bar/b", + Value: []byte("two"), + }, + &structs.DirEntry{ + Key: "bar/c", + Value: []byte("three"), + }, + } + for i, entry := range entries { + if err := s.KVSSet(uint64(i+1), entry); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Make a node and session so we can test a locked key. 
+ testRegisterNode(t, s, 5, "node1") + session := testUUID() + if err := s.SessionCreate(6, &structs.Session{ID: session, Node: "node1"}); err != nil { + t.Fatalf("err: %s", err) + } + entries[3].Session = session + if ok, err := s.KVSLock(7, entries[3]); !ok || err != nil { + t.Fatalf("didn't get the lock: %v %s", ok, err) + } + + // This is required for the compare later. + entries[3].LockIndex = 1 + + // Snapshot the keys. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + if err := s.KVSSet(8, &structs.DirEntry{Key: "aaa", Value: []byte("nope")}); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the snapshot. + if idx := snap.LastIndex(); idx != 7 { + t.Fatalf("bad index: %d", idx) + } + iter, err := snap.KVs() + if err != nil { + t.Fatalf("err: %s", err) + } + var dump structs.DirEntries + for entry := iter.Next(); entry != nil; entry = iter.Next() { + dump = append(dump, entry.(*structs.DirEntry)) + } + if !reflect.DeepEqual(dump, entries) { + t.Fatalf("bad: %#v", dump) + } + + // Restore the values into a new state store. + func() { + s := testStateStore(t) + restore := s.Restore() + for _, entry := range dump { + if err := restore.KVS(entry); err != nil { + t.Fatalf("err: %s", err) + } + } + restore.Commit() + + // Read the restored keys back out and verify they match. + idx, res, err := s.KVSList("") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + if !reflect.DeepEqual(res, entries) { + t.Fatalf("bad: %#v", res) + } + + // Check that the index was updated. + if idx := s.maxIndex("kvs"); idx != 7 { + t.Fatalf("bad index: %d", idx) + } + }() +} + +func TestStateStore_KVS_Watches(t *testing.T) { + s := testStateStore(t) + + // This is used when locking down below. + testRegisterNode(t, s, 1, "node1") + session := testUUID() + if err := s.SessionCreate(2, &structs.Session{ID: session, Node: "node1"}); err != nil { + t.Fatalf("err: %s", err) + } + + // An empty prefix watch should hit on all KVS ops, and some other + // prefix should not be affected ever. We also add a positive prefix + // match. + verifyWatch(t, s.GetKVSWatch(""), func() { + verifyWatch(t, s.GetKVSWatch("a"), func() { + verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { + if err := s.KVSSet(1, &structs.DirEntry{Key: "aaa"}); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + verifyWatch(t, s.GetKVSWatch(""), func() { + verifyWatch(t, s.GetKVSWatch("a"), func() { + verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { + if err := s.KVSSet(2, &structs.DirEntry{Key: "aaa"}); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + + // Restore just fires off a top-level watch, so we should get hits on + // any prefix, including ones for keys that aren't in there. 
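+	// (For reference, the restore path being exercised, roughly:
+	//
+	//	restore := s.Restore()
+	//	restore.KVS(&structs.DirEntry{Key: "bbb"}) // batch-load entries
+	//	restore.Commit()                           // fires the watches
+	//
+	// Commit is the step that triggers the notifications verified below.)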
+ verifyWatch(t, s.GetKVSWatch(""), func() { + verifyWatch(t, s.GetKVSWatch("b"), func() { + verifyWatch(t, s.GetKVSWatch("/nope"), func() { + restore := s.Restore() + if err := restore.KVS(&structs.DirEntry{Key: "bbb"}); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + }) + }) + }) + + verifyWatch(t, s.GetKVSWatch(""), func() { + verifyWatch(t, s.GetKVSWatch("a"), func() { + verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { + if err := s.KVSDelete(3, "aaa"); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + verifyWatch(t, s.GetKVSWatch(""), func() { + verifyWatch(t, s.GetKVSWatch("a"), func() { + verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { + if ok, err := s.KVSSetCAS(4, &structs.DirEntry{Key: "aaa"}); !ok || err != nil { + t.Fatalf("ok: %v err: %s", ok, err) + } + }) + }) + }) + verifyWatch(t, s.GetKVSWatch(""), func() { + verifyWatch(t, s.GetKVSWatch("a"), func() { + verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { + if ok, err := s.KVSLock(5, &structs.DirEntry{Key: "aaa", Session: session}); !ok || err != nil { + t.Fatalf("ok: %v err: %s", ok, err) + } + }) + }) + }) + verifyWatch(t, s.GetKVSWatch(""), func() { + verifyWatch(t, s.GetKVSWatch("a"), func() { + verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { + if ok, err := s.KVSUnlock(6, &structs.DirEntry{Key: "aaa", Session: session}); !ok || err != nil { + t.Fatalf("ok: %v err: %s", ok, err) + } + }) + }) + }) + verifyWatch(t, s.GetKVSWatch(""), func() { + verifyWatch(t, s.GetKVSWatch("a"), func() { + verifyNoWatch(t, s.GetKVSWatch("/nope"), func() { + if err := s.KVSDeleteTree(7, "aaa"); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + + // A delete tree operation at the top level will notify all the watches. + verifyWatch(t, s.GetKVSWatch(""), func() { + verifyWatch(t, s.GetKVSWatch("a"), func() { + verifyWatch(t, s.GetKVSWatch("/nope"), func() { + if err := s.KVSDeleteTree(8, ""); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + + // Create a more interesting tree. + testSetKey(t, s, 9, "foo/bar", "bar") + testSetKey(t, s, 10, "foo/bar/baz", "baz") + testSetKey(t, s, 11, "foo/bar/zip", "zip") + testSetKey(t, s, 12, "foo/zorp", "zorp") + + // Deleting just the foo/bar key should not trigger watches on the + // children. + verifyWatch(t, s.GetKVSWatch("foo/bar"), func() { + verifyNoWatch(t, s.GetKVSWatch("foo/bar/baz"), func() { + verifyNoWatch(t, s.GetKVSWatch("foo/bar/zip"), func() { + if err := s.KVSDelete(13, "foo/bar"); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + + // But a delete tree from that point should notify the whole subtree, + // even for keys that don't exist. + verifyWatch(t, s.GetKVSWatch("foo/bar"), func() { + verifyWatch(t, s.GetKVSWatch("foo/bar/baz"), func() { + verifyWatch(t, s.GetKVSWatch("foo/bar/zip"), func() { + verifyWatch(t, s.GetKVSWatch("foo/bar/uh/nope"), func() { + if err := s.KVSDeleteTree(14, "foo/bar"); err != nil { + t.Fatalf("err: %s", err) + } + }) + }) + }) + }) +} + +func TestStateStore_Tombstone_Snapshot_Restore(t *testing.T) { + s := testStateStore(t) + + // Insert a key and then delete it to create a tombstone. + testSetKey(t, s, 1, "foo/bar", "bar") + testSetKey(t, s, 2, "foo/bar/baz", "bar") + testSetKey(t, s, 3, "foo/bar/zoo", "bar") + if err := s.KVSDelete(4, "foo/bar"); err != nil { + t.Fatalf("err: %s", err) + } + + // Snapshot the Tombstones. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. 
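+	// (Changing the live store after the snapshot is taken is what makes
+	// this a real isolation check: the reap below must not be visible when
+	// snap.Tombstones() is iterated afterwards.)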
+	if err := s.ReapTombstones(4); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	idx, _, err := s.KVSList("foo/bar")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if idx != 3 {
+		t.Fatalf("bad index: %d", idx)
+	}
+
+	// Verify the snapshot.
+	stones, err := snap.Tombstones()
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	var dump []*Tombstone
+	for stone := stones.Next(); stone != nil; stone = stones.Next() {
+		dump = append(dump, stone.(*Tombstone))
+	}
+	if len(dump) != 1 {
+		t.Fatalf("bad: %#v", dump)
+	}
+	stone := dump[0]
+	if stone.Key != "foo/bar" || stone.Index != 4 {
+		t.Fatalf("bad: %#v", stone)
+	}
+
+	// Restore the values into a new state store.
+	func() {
+		s := testStateStore(t)
+		restore := s.Restore()
+		for _, stone := range dump {
+			if err := restore.Tombstone(stone); err != nil {
+				t.Fatalf("err: %s", err)
+			}
+		}
+		restore.Commit()
+
+		// See if the stone works properly in a list query.
+		idx, _, err := s.KVSList("foo/bar")
+		if err != nil {
+			t.Fatalf("err: %s", err)
+		}
+		if idx != 4 {
+			t.Fatalf("bad index: %d", idx)
+		}
+
+		// Make sure it reaps correctly. We should still get a 4 for
+		// the index here because it will be using the last index from
+		// the tombstone table.
+		if err := s.ReapTombstones(4); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+		idx, _, err = s.KVSList("foo/bar")
+		if err != nil {
+			t.Fatalf("err: %s", err)
+		}
+		if idx != 4 {
+			t.Fatalf("bad index: %d", idx)
+		}
+
+		// But make sure the tombstone is actually gone.
+		snap := s.Snapshot()
+		defer snap.Close()
+		stones, err := snap.Tombstones()
+		if err != nil {
+			t.Fatalf("err: %s", err)
+		}
+		if stones.Next() != nil {
+			t.Fatalf("unexpected extra tombstones")
+		}
+	}()
+}
+
+func TestStateStore_SessionCreate_SessionGet(t *testing.T) {
+	s := testStateStore(t)
+
+	// SessionGet returns nil if the session doesn't exist
+	idx, session, err := s.SessionGet(testUUID())
+	if session != nil || err != nil {
+		t.Fatalf("expected (nil, nil), got: (%#v, %#v)", session, err)
+	}
+	if idx != 0 {
+		t.Fatalf("bad index: %d", idx)
+	}
+
+	// Registering without a session ID is disallowed
+	err = s.SessionCreate(1, &structs.Session{})
+	if err != ErrMissingSessionID {
+		t.Fatalf("expected %#v, got: %#v", ErrMissingSessionID, err)
+	}
+
+	// An invalid session behavior returns an error
+	sess := &structs.Session{
+		ID:       testUUID(),
+		Behavior: "nope",
+	}
+	err = s.SessionCreate(1, sess)
+	if err == nil || !strings.Contains(err.Error(), "session behavior") {
+		t.Fatalf("expected session behavior error, got: %#v", err)
+	}
+
+	// Registering with an unknown node is disallowed
+	sess = &structs.Session{ID: testUUID()}
+	if err := s.SessionCreate(1, sess); err != ErrMissingNode {
+		t.Fatalf("expected %#v, got: %#v", ErrMissingNode, err)
+	}
+
+	// None of the errored operations modified the index
+	if idx := s.maxIndex("sessions"); idx != 0 {
+		t.Fatalf("bad index: %d", idx)
+	}
+
+	// Valid session is able to register
+	testRegisterNode(t, s, 1, "node1")
+	sess = &structs.Session{
+		ID:   testUUID(),
+		Node: "node1",
+	}
+	if err := s.SessionCreate(2, sess); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if idx := s.maxIndex("sessions"); idx != 2 {
+		t.Fatalf("bad index: %d", idx)
+	}
+
+	// Retrieve the session again
+	idx, session, err = s.SessionGet(sess.ID)
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if idx != 2 {
+		t.Fatalf("bad index: %d", idx)
+	}
+
+	// Ensure the session looks correct and was assigned the
+	// proper default value for session behavior.
+ expect := &structs.Session{ + ID: sess.ID, + Behavior: structs.SessionKeysRelease, + Node: "node1", + RaftIndex: structs.RaftIndex{ + CreateIndex: 2, + ModifyIndex: 2, + }, + } + if !reflect.DeepEqual(expect, session) { + t.Fatalf("bad session: %#v", session) + } + + // Registering with a non-existent check is disallowed + sess = &structs.Session{ + ID: testUUID(), + Node: "node1", + Checks: []string{"check1"}, + } + err = s.SessionCreate(3, sess) + if err == nil || !strings.Contains(err.Error(), "Missing check") { + t.Fatalf("expected missing check error, got: %#v", err) + } + + // Registering with a critical check is disallowed + testRegisterCheck(t, s, 3, "node1", "", "check1", structs.HealthCritical) + err = s.SessionCreate(4, sess) + if err == nil || !strings.Contains(err.Error(), structs.HealthCritical) { + t.Fatalf("expected critical state error, got: %#v", err) + } + + // Registering with a healthy check succeeds + testRegisterCheck(t, s, 4, "node1", "", "check1", structs.HealthPassing) + if err := s.SessionCreate(5, sess); err != nil { + t.Fatalf("err: %s", err) + } + + // Register a session against two checks. + testRegisterCheck(t, s, 5, "node1", "", "check2", structs.HealthPassing) + sess2 := &structs.Session{ + ID: testUUID(), + Node: "node1", + Checks: []string{"check1", "check2"}, + } + if err := s.SessionCreate(6, sess2); err != nil { + t.Fatalf("err: %s", err) + } + + tx := s.db.Txn(false) + defer tx.Abort() + + // Check mappings were inserted + { + check, err := tx.First("session_checks", "session", sess.ID) + if err != nil { + t.Fatalf("err: %s", err) + } + if check == nil { + t.Fatalf("missing session check") + } + expectCheck := &sessionCheck{ + Node: "node1", + CheckID: "check1", + Session: sess.ID, + } + if actual := check.(*sessionCheck); !reflect.DeepEqual(actual, expectCheck) { + t.Fatalf("expected %#v, got: %#v", expectCheck, actual) + } + } + checks, err := tx.Get("session_checks", "session", sess2.ID) + if err != nil { + t.Fatalf("err: %s", err) + } + for i, check := 0, checks.Next(); check != nil; i, check = i+1, checks.Next() { + expectCheck := &sessionCheck{ + Node: "node1", + CheckID: fmt.Sprintf("check%d", i+1), + Session: sess2.ID, + } + if actual := check.(*sessionCheck); !reflect.DeepEqual(actual, expectCheck) { + t.Fatalf("expected %#v, got: %#v", expectCheck, actual) + } + } + + // Pulling a nonexistent session gives the table index. 
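+	// (Returning the table's max index on a miss is what lets a blocking
+	// query wait for a session that doesn't exist yet. Caller-side sketch,
+	// names hypothetical:
+	//
+	//	idx, sess, _ := s.SessionGet(id) // sess == nil, idx == table index
+	//	// a blocking query waits for the sessions table to move past idx,
+	//	// then retries the lookup
+	// )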
+	idx, session, err = s.SessionGet(testUUID())
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if session != nil {
+		t.Fatalf("expected not to get a session: %v", session)
+	}
+	if idx != 6 {
+		t.Fatalf("bad index: %d", idx)
+	}
+}
+
+func TestStateStore_SessionList(t *testing.T) {
+	s := testStateStore(t)
+
+	// Listing when no sessions exist returns nil
+	idx, res, err := s.SessionList()
+	if idx != 0 || res != nil || err != nil {
+		t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err)
+	}
+
+	// Register some nodes
+	testRegisterNode(t, s, 1, "node1")
+	testRegisterNode(t, s, 2, "node2")
+	testRegisterNode(t, s, 3, "node3")
+
+	// Create some sessions in the state store
+	sessions := structs.Sessions{
+		&structs.Session{
+			ID:       testUUID(),
+			Node:     "node1",
+			Behavior: structs.SessionKeysDelete,
+		},
+		&structs.Session{
+			ID:       testUUID(),
+			Node:     "node2",
+			Behavior: structs.SessionKeysRelease,
+		},
+		&structs.Session{
+			ID:       testUUID(),
+			Node:     "node3",
+			Behavior: structs.SessionKeysDelete,
+		},
+	}
+	for i, session := range sessions {
+		if err := s.SessionCreate(uint64(4+i), session); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+	}
+
+	// List out all of the sessions
+	idx, sessionList, err := s.SessionList()
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if idx != 6 {
+		t.Fatalf("bad index: %d", idx)
+	}
+	if !reflect.DeepEqual(sessionList, sessions) {
+		t.Fatalf("bad: %#v", sessionList)
+	}
+}
+
+func TestStateStore_NodeSessions(t *testing.T) {
+	s := testStateStore(t)
+
+	// Listing sessions with no results returns nil
+	idx, res, err := s.NodeSessions("node1")
+	if idx != 0 || res != nil || err != nil {
+		t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err)
+	}
+
+	// Create the nodes
+	testRegisterNode(t, s, 1, "node1")
+	testRegisterNode(t, s, 2, "node2")
+
+	// Register some sessions with the nodes
+	sessions1 := structs.Sessions{
+		&structs.Session{
+			ID:   testUUID(),
+			Node: "node1",
+		},
+		&structs.Session{
+			ID:   testUUID(),
+			Node: "node1",
+		},
+	}
+	sessions2 := []*structs.Session{
+		&structs.Session{
+			ID:   testUUID(),
+			Node: "node2",
+		},
+		&structs.Session{
+			ID:   testUUID(),
+			Node: "node2",
+		},
+	}
+	for i, sess := range append(sessions1, sessions2...) {
+		if err := s.SessionCreate(uint64(3+i), sess); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+	}
+
+	// Query all of the sessions associated with a specific
+	// node in the state store.
+	idx, res, err = s.NodeSessions("node1")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if len(res) != len(sessions1) {
+		t.Fatalf("bad: %#v", res)
+	}
+	if idx != 6 {
+		t.Fatalf("bad index: %d", idx)
+	}
+
+	idx, res, err = s.NodeSessions("node2")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if len(res) != len(sessions2) {
+		t.Fatalf("bad: %#v", res)
+	}
+	if idx != 6 {
+		t.Fatalf("bad index: %d", idx)
+	}
+}
+
+func TestStateStore_SessionDestroy(t *testing.T) {
+	s := testStateStore(t)
+
+	// Session destroy is idempotent and returns no error
+	// if the session doesn't exist.
+	if err := s.SessionDestroy(1, testUUID()); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	// Ensure the index was not updated if nothing was destroyed.
+	if idx := s.maxIndex("sessions"); idx != 0 {
+		t.Fatalf("bad index: %d", idx)
+	}
+
+	// Register a node.
+	testRegisterNode(t, s, 1, "node1")
+
+	// Register a new session
+	sess := &structs.Session{
+		ID:   testUUID(),
+		Node: "node1",
+	}
+	if err := s.SessionCreate(2, sess); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	// Destroy the session.
+ if err := s.SessionDestroy(3, sess.ID); err != nil { + t.Fatalf("err: %s", err) + } + + // Check that the index was updated + if idx := s.maxIndex("sessions"); idx != 3 { + t.Fatalf("bad index: %d", idx) + } + + // Make sure the session is really gone. + tx := s.db.Txn(false) + sessions, err := tx.Get("sessions", "id") + if err != nil || sessions.Next() != nil { + t.Fatalf("session should not exist") + } + tx.Abort() +} + +func TestStateStore_Session_Snapshot_Restore(t *testing.T) { + s := testStateStore(t) + + // Register some nodes and checks. + testRegisterNode(t, s, 1, "node1") + testRegisterNode(t, s, 2, "node2") + testRegisterNode(t, s, 3, "node3") + testRegisterCheck(t, s, 4, "node1", "", "check1", structs.HealthPassing) + + // Create some sessions in the state store. + session1 := testUUID() + sessions := structs.Sessions{ + &structs.Session{ + ID: session1, + Node: "node1", + Behavior: structs.SessionKeysDelete, + Checks: []string{"check1"}, + }, + &structs.Session{ + ID: testUUID(), + Node: "node2", + Behavior: structs.SessionKeysRelease, + LockDelay: 10 * time.Second, + }, + &structs.Session{ + ID: testUUID(), + Node: "node3", + Behavior: structs.SessionKeysDelete, + TTL: "1.5s", + }, + } + for i, session := range sessions { + if err := s.SessionCreate(uint64(5+i), session); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Snapshot the sessions. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + if err := s.SessionDestroy(8, session1); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the snapshot. + if idx := snap.LastIndex(); idx != 7 { + t.Fatalf("bad index: %d", idx) + } + iter, err := snap.Sessions() + if err != nil { + t.Fatalf("err: %s", err) + } + var dump structs.Sessions + for session := iter.Next(); session != nil; session = iter.Next() { + sess := session.(*structs.Session) + dump = append(dump, sess) + + found := false + for i, _ := range sessions { + if sess.ID == sessions[i].ID { + if !reflect.DeepEqual(sess, sessions[i]) { + t.Fatalf("bad: %#v", sess) + } + found = true + } + } + if !found { + t.Fatalf("bad: %#v", sess) + } + } + + // Restore the sessions into a new state store. + func() { + s := testStateStore(t) + restore := s.Restore() + for _, session := range dump { + if err := restore.Session(session); err != nil { + t.Fatalf("err: %s", err) + } + } + restore.Commit() + + // Read the restored sessions back out and verify that they + // match. + idx, res, err := s.SessionList() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + for _, sess := range res { + found := false + for i, _ := range sessions { + if sess.ID == sessions[i].ID { + if !reflect.DeepEqual(sess, sessions[i]) { + t.Fatalf("bad: %#v", sess) + } + found = true + } + } + if !found { + t.Fatalf("bad: %#v", sess) + } + } + + // Check that the index was updated. + if idx := s.maxIndex("sessions"); idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // Manually verify that the session check mapping got restored. 
+		tx := s.db.Txn(false)
+		defer tx.Abort()
+
+		check, err := tx.First("session_checks", "session", session1)
+		if err != nil {
+			t.Fatalf("err: %s", err)
+		}
+		if check == nil {
+			t.Fatalf("missing session check")
+		}
+		expectCheck := &sessionCheck{
+			Node:    "node1",
+			CheckID: "check1",
+			Session: session1,
+		}
+		if actual := check.(*sessionCheck); !reflect.DeepEqual(actual, expectCheck) {
+			t.Fatalf("expected %#v, got: %#v", expectCheck, actual)
+		}
+	}()
+}
+
+func TestStateStore_Session_Watches(t *testing.T) {
+	s := testStateStore(t)
+
+	// Register a test node.
+	testRegisterNode(t, s, 1, "node1")
+
+	// This just covers the basics. The session invalidation tests below
+	// cover the more nuanced multiple table watches.
+	session := testUUID()
+	verifyWatch(t, s.getTableWatch("sessions"), func() {
+		sess := &structs.Session{
+			ID:       session,
+			Node:     "node1",
+			Behavior: structs.SessionKeysDelete,
+		}
+		if err := s.SessionCreate(2, sess); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+	})
+	verifyWatch(t, s.getTableWatch("sessions"), func() {
+		if err := s.SessionDestroy(3, session); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+	})
+	verifyWatch(t, s.getTableWatch("sessions"), func() {
+		restore := s.Restore()
+		sess := &structs.Session{
+			ID:       session,
+			Node:     "node1",
+			Behavior: structs.SessionKeysDelete,
+		}
+		if err := restore.Session(sess); err != nil {
+			t.Fatalf("err: %s", err)
+		}
+		restore.Commit()
+	})
+}
+
+func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) {
+	s := testStateStore(t)
+
+	// Set up our test environment.
+	if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	session := &structs.Session{
+		ID:   testUUID(),
+		Node: "foo",
+	}
+	if err := s.SessionCreate(14, session); err != nil {
+		t.Fatalf("err: %v", err)
+	}
+
+	// Delete the node and make sure the watches fire.
+	verifyWatch(t, s.getTableWatch("sessions"), func() {
+		verifyWatch(t, s.getTableWatch("nodes"), func() {
+			if err := s.DeleteNode(15, "foo"); err != nil {
+				t.Fatalf("err: %v", err)
+			}
+		})
+	})
+
+	// Lookup by ID, should be nil.
+	idx, s2, err := s.SessionGet(session.ID)
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if s2 != nil {
+		t.Fatalf("session should be invalidated")
+	}
+	if idx != 15 {
+		t.Fatalf("bad index: %d", idx)
+	}
+}
+
+func TestStateStore_Session_Invalidate_DeleteService(t *testing.T) {
+	s := testStateStore(t)
+
+	// Set up our test environment.
+	if err := s.EnsureNode(11, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if err := s.EnsureService(12, "foo", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}); err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	check := &structs.HealthCheck{
+		Node:      "foo",
+		CheckID:   "api",
+		Name:      "Can connect",
+		Status:    structs.HealthPassing,
+		ServiceID: "api",
+	}
+	if err := s.EnsureCheck(13, check); err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	session := &structs.Session{
+		ID:     testUUID(),
+		Node:   "foo",
+		Checks: []string{"api"},
+	}
+	if err := s.SessionCreate(14, session); err != nil {
+		t.Fatalf("err: %v", err)
+	}
+
+	// Delete the service and make sure the watches fire.
+	verifyWatch(t, s.getTableWatch("sessions"), func() {
+		verifyWatch(t, s.getTableWatch("services"), func() {
+			verifyWatch(t, s.getTableWatch("checks"), func() {
+				if err := s.DeleteService(15, "foo", "api"); err != nil {
+					t.Fatalf("err: %v", err)
+				}
+			})
+		})
+	})
+
+	// Lookup by ID, should be nil.
+ idx, s2, err := s.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } + if idx != 15 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_Session_Invalidate_Critical_Check(t *testing.T) { + s := testStateStore(t) + + // Set up our test environment. + if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + check := &structs.HealthCheck{ + Node: "foo", + CheckID: "bar", + Status: structs.HealthPassing, + } + if err := s.EnsureCheck(13, check); err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ + ID: testUUID(), + Node: "foo", + Checks: []string{"bar"}, + } + if err := s.SessionCreate(14, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Invalidate the check and make sure the watches fire. + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { + check.Status = structs.HealthCritical + if err := s.EnsureCheck(15, check); err != nil { + t.Fatalf("err: %v", err) + } + }) + }) + + // Lookup by ID, should be nil. + idx, s2, err := s.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } + if idx != 15 { + t.Fatalf("bad index: %d", idx) + } +} + +func TestStateStore_Session_Invalidate_DeleteCheck(t *testing.T) { + s := testStateStore(t) + + // Set up our test environment. + if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + check := &structs.HealthCheck{ + Node: "foo", + CheckID: "bar", + Status: structs.HealthPassing, + } + if err := s.EnsureCheck(13, check); err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ + ID: testUUID(), + Node: "foo", + Checks: []string{"bar"}, + } + if err := s.SessionCreate(14, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Delete the check and make sure the watches fire. + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("checks"), func() { + if err := s.DeleteCheck(15, "foo", "bar"); err != nil { + t.Fatalf("err: %v", err) + } + }) + }) + + // Lookup by ID, should be nil. + idx, s2, err := s.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } + if idx != 15 { + t.Fatalf("bad index: %d", idx) + } + + // Manually make sure the session checks mapping is clear. + tx := s.db.Txn(false) + mapping, err := tx.First("session_checks", "session", session.ID) + if err != nil { + t.Fatalf("err: %s", err) + } + if mapping != nil { + t.Fatalf("unexpected session check") + } + tx.Abort() +} + +func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) { + s := testStateStore(t) + + // Set up our test environment. + if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ + ID: testUUID(), + Node: "foo", + LockDelay: 50 * time.Millisecond, + } + if err := s.SessionCreate(4, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Lock a key with the session. 
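+	// (This test exercises the default structs.SessionKeysRelease behavior:
+	// invalidating the session clears the lock's Session field but keeps
+	// the key. The companion test below uses structs.SessionKeysDelete,
+	// where the key itself is removed.)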
+ d := &structs.DirEntry{ + Key: "/foo", + Flags: 42, + Value: []byte("test"), + Session: session.ID, + } + ok, err := s.KVSLock(5, d) + if err != nil { + t.Fatalf("err: %v", err) + } + if !ok { + t.Fatalf("unexpected fail") + } + + // Delete the node and make sure the watches fire. + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.GetKVSWatch("/f"), func() { + if err := s.DeleteNode(6, "foo"); err != nil { + t.Fatalf("err: %v", err) + } + }) + }) + }) + + // Lookup by ID, should be nil. + idx, s2, err := s.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Key should be unlocked. + idx, d2, err := s.KVSGet("/foo") + if err != nil { + t.Fatalf("err: %s", err) + } + if d2.ModifyIndex != 6 { + t.Fatalf("bad index: %v", d2.ModifyIndex) + } + if d2.LockIndex != 1 { + t.Fatalf("bad: %v", *d2) + } + if d2.Session != "" { + t.Fatalf("bad: %v", *d2) + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Key should have a lock delay. + expires := s.KVSLockDelay("/foo") + if expires.Before(time.Now().Add(30 * time.Millisecond)) { + t.Fatalf("Bad: %v", expires) + } +} + +func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) { + s := testStateStore(t) + + // Set up our test environment. + if err := s.EnsureNode(3, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + session := &structs.Session{ + ID: testUUID(), + Node: "foo", + LockDelay: 50 * time.Millisecond, + Behavior: structs.SessionKeysDelete, + } + if err := s.SessionCreate(4, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Lock a key with the session. + d := &structs.DirEntry{ + Key: "/bar", + Flags: 42, + Value: []byte("test"), + Session: session.ID, + } + ok, err := s.KVSLock(5, d) + if err != nil { + t.Fatalf("err: %v", err) + } + if !ok { + t.Fatalf("unexpected fail") + } + + // Delete the node and make sure the watches fire. + verifyWatch(t, s.getTableWatch("sessions"), func() { + verifyWatch(t, s.getTableWatch("nodes"), func() { + verifyWatch(t, s.GetKVSWatch("/b"), func() { + if err := s.DeleteNode(6, "foo"); err != nil { + t.Fatalf("err: %v", err) + } + }) + }) + }) + + // Lookup by ID, should be nil. + idx, s2, err := s.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Key should be deleted. + idx, d2, err := s.KVSGet("/bar") + if err != nil { + t.Fatalf("err: %s", err) + } + if d2 != nil { + t.Fatalf("unexpected deleted key") + } + if idx != 6 { + t.Fatalf("bad index: %d", idx) + } + + // Key should have a lock delay. 
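+	// (The session was created with LockDelay: 50 * time.Millisecond, so
+	// immediately after invalidation the expiry from KVSLockDelay should
+	// still be at least ~30ms away. The delay is tracked by key name, so
+	// it applies even though the key itself was deleted; this mirrors the
+	// split-brain protection described for the old state store.)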
+ expires := s.KVSLockDelay("/bar") + if expires.Before(time.Now().Add(30 * time.Millisecond)) { + t.Fatalf("Bad: %v", expires) + } +} + +func TestStateStore_ACLSet_ACLGet(t *testing.T) { + s := testStateStore(t) + + // Querying ACLs with no results returns nil + idx, res, err := s.ACLGet("nope") + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Inserting an ACL with empty ID is disallowed + if err := s.ACLSet(1, &structs.ACL{}); err == nil { + t.Fatalf("expected %#v, got: %#v", ErrMissingACLID, err) + } + + // Index is not updated if nothing is saved + if idx := s.maxIndex("acls"); idx != 0 { + t.Fatalf("bad index: %d", idx) + } + + // Inserting valid ACL works + acl := &structs.ACL{ + ID: "acl1", + Name: "First ACL", + Type: structs.ACLTypeClient, + Rules: "rules1", + } + if err := s.ACLSet(1, acl); err != nil { + t.Fatalf("err: %s", err) + } + + // Check that the index was updated + if idx := s.maxIndex("acls"); idx != 1 { + t.Fatalf("bad index: %d", idx) + } + + // Retrieve the ACL again + idx, result, err := s.ACLGet("acl1") + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 1 { + t.Fatalf("bad index: %d", idx) + } + + // Check that the ACL matches the result + expect := &structs.ACL{ + ID: "acl1", + Name: "First ACL", + Type: structs.ACLTypeClient, + Rules: "rules1", + RaftIndex: structs.RaftIndex{ + CreateIndex: 1, + ModifyIndex: 1, + }, + } + if !reflect.DeepEqual(result, expect) { + t.Fatalf("bad: %#v", result) + } + + // Update the ACL + acl = &structs.ACL{ + ID: "acl1", + Name: "First ACL", + Type: structs.ACLTypeClient, + Rules: "rules2", + } + if err := s.ACLSet(2, acl); err != nil { + t.Fatalf("err: %s", err) + } + + // Index was updated + if idx := s.maxIndex("acls"); idx != 2 { + t.Fatalf("bad: %d", idx) + } + + // ACL was updated and matches expected value + expect = &structs.ACL{ + ID: "acl1", + Name: "First ACL", + Type: structs.ACLTypeClient, + Rules: "rules2", + RaftIndex: structs.RaftIndex{ + CreateIndex: 1, + ModifyIndex: 2, + }, + } + if !reflect.DeepEqual(acl, expect) { + t.Fatalf("bad: %#v", acl) + } +} + +func TestStateStore_ACLList(t *testing.T) { + s := testStateStore(t) + + // Listing when no ACLs exist returns nil + idx, res, err := s.ACLList() + if idx != 0 || res != nil || err != nil { + t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err) + } + + // Insert some ACLs + acls := structs.ACLs{ + &structs.ACL{ + ID: "acl1", + Type: structs.ACLTypeClient, + Rules: "rules1", + RaftIndex: structs.RaftIndex{ + CreateIndex: 1, + ModifyIndex: 1, + }, + }, + &structs.ACL{ + ID: "acl2", + Type: structs.ACLTypeClient, + Rules: "rules2", + RaftIndex: structs.RaftIndex{ + CreateIndex: 2, + ModifyIndex: 2, + }, + }, + } + for _, acl := range acls { + if err := s.ACLSet(acl.ModifyIndex, acl); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Query the ACLs + idx, res, err = s.ACLList() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + // Check that the result matches + if !reflect.DeepEqual(res, acls) { + t.Fatalf("bad: %#v", res) + } +} + +func TestStateStore_ACLDelete(t *testing.T) { + s := testStateStore(t) + + // Calling delete on an ACL which doesn't exist returns nil + if err := s.ACLDelete(1, "nope"); err != nil { + t.Fatalf("err: %s", err) + } + + // Index isn't updated if nothing is deleted + if idx := s.maxIndex("acls"); idx != 0 { + t.Fatalf("bad index: %d", idx) + } + + // Insert an ACL + if err := 
s.ACLSet(1, &structs.ACL{ID: "acl1"}); err != nil { + t.Fatalf("err: %s", err) + } + + // Delete the ACL and check that the index was updated + if err := s.ACLDelete(2, "acl1"); err != nil { + t.Fatalf("err: %s", err) + } + if idx := s.maxIndex("acls"); idx != 2 { + t.Fatalf("bad index: %d", idx) + } + + tx := s.db.Txn(false) + defer tx.Abort() + + // Check that the ACL was really deleted + result, err := tx.First("acls", "id", "acl1") + if err != nil { + t.Fatalf("err: %s", err) + } + if result != nil { + t.Fatalf("expected nil, got: %#v", result) + } +} + +func TestStateStore_ACL_Snapshot_Restore(t *testing.T) { + s := testStateStore(t) + + // Insert some ACLs. + acls := structs.ACLs{ + &structs.ACL{ + ID: "acl1", + Type: structs.ACLTypeClient, + Rules: "rules1", + RaftIndex: structs.RaftIndex{ + CreateIndex: 1, + ModifyIndex: 1, + }, + }, + &structs.ACL{ + ID: "acl2", + Type: structs.ACLTypeClient, + Rules: "rules2", + RaftIndex: structs.RaftIndex{ + CreateIndex: 2, + ModifyIndex: 2, + }, + }, + } + for _, acl := range acls { + if err := s.ACLSet(acl.ModifyIndex, acl); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Snapshot the ACLs. + snap := s.Snapshot() + defer snap.Close() + + // Alter the real state store. + if err := s.ACLDelete(3, "acl1"); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the snapshot. + if idx := snap.LastIndex(); idx != 2 { + t.Fatalf("bad index: %d", idx) + } + iter, err := snap.ACLs() + if err != nil { + t.Fatalf("err: %s", err) + } + var dump structs.ACLs + for acl := iter.Next(); acl != nil; acl = iter.Next() { + dump = append(dump, acl.(*structs.ACL)) + } + if !reflect.DeepEqual(dump, acls) { + t.Fatalf("bad: %#v", dump) + } + + // Restore the values into a new state store. + func() { + s := testStateStore(t) + restore := s.Restore() + for _, acl := range dump { + if err := restore.ACL(acl); err != nil { + t.Fatalf("err: %s", err) + } + } + restore.Commit() + + // Read the restored ACLs back out and verify that they match. + idx, res, err := s.ACLList() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 2 { + t.Fatalf("bad index: %d", idx) + } + if !reflect.DeepEqual(res, acls) { + t.Fatalf("bad: %#v", res) + } + + // Check that the index was updated. + if idx := s.maxIndex("acls"); idx != 2 { + t.Fatalf("bad index: %d", idx) + } + }() +} + +func TestStateStore_ACL_Watches(t *testing.T) { + s := testStateStore(t) + + // Call functions that update the acls table and make sure a watch fires + // each time. 
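+	// (verifyWatch registers a buffered channel on the watch, runs the
+	// given function, and asserts that a notification arrived; roughly:
+	//
+	//	ch := make(chan struct{}, 1)
+	//	s.getTableWatch("acls").Wait(ch)
+	//	fn() // any write to the acls table should notify ch
+	// )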
+ verifyWatch(t, s.getTableWatch("acls"), func() { + if err := s.ACLSet(1, &structs.ACL{ID: "acl1"}); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("acls"), func() { + if err := s.ACLDelete(2, "acl1"); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.getTableWatch("acls"), func() { + restore := s.Restore() + if err := restore.ACL(&structs.ACL{ID: "acl1"}); err != nil { + t.Fatalf("err: %s", err) + } + restore.Commit() + }) +} diff --git a/consul/tombstone_gc.go b/consul/state/tombstone_gc.go similarity index 99% rename from consul/tombstone_gc.go rename to consul/state/tombstone_gc.go index 8dd2e1a5aa07..0d530eb696f5 100644 --- a/consul/tombstone_gc.go +++ b/consul/state/tombstone_gc.go @@ -1,4 +1,4 @@ -package consul +package state import ( "fmt" diff --git a/consul/tombstone_gc_test.go b/consul/state/tombstone_gc_test.go similarity index 99% rename from consul/tombstone_gc_test.go rename to consul/state/tombstone_gc_test.go index ac51e6418d4d..44ca19874ac2 100644 --- a/consul/tombstone_gc_test.go +++ b/consul/state/tombstone_gc_test.go @@ -1,4 +1,4 @@ -package consul +package state import ( "testing" diff --git a/consul/state/watch.go b/consul/state/watch.go new file mode 100644 index 000000000000..f8aa273a8256 --- /dev/null +++ b/consul/state/watch.go @@ -0,0 +1,177 @@ +package state + +import ( + "fmt" + "sync" + + "github.com/armon/go-radix" +) + +// Watch is the external interface that's common to all the different flavors. +type Watch interface { + // Wait registers the given channel and calls it back when the watch + // fires. + Wait(notifyCh chan struct{}) + + // Clear deregisters the given channel. + Clear(notifyCh chan struct{}) +} + +// FullTableWatch implements a single notify group for a table. +type FullTableWatch struct { + group NotifyGroup +} + +// NewFullTableWatch returns a new full table watch. +func NewFullTableWatch() *FullTableWatch { + return &FullTableWatch{} +} + +// See Watch. +func (w *FullTableWatch) Wait(notifyCh chan struct{}) { + w.group.Wait(notifyCh) +} + +// See Watch. +func (w *FullTableWatch) Clear(notifyCh chan struct{}) { + w.group.Clear(notifyCh) +} + +// Notify wakes up all the watchers registered for this table. +func (w *FullTableWatch) Notify() { + w.group.Notify() +} + +// DumbWatchManager is a wrapper that allows nested code to arm full table +// watches multiple times but fire them only once. This doesn't have any +// way to clear the state, and it's not thread-safe, so it should be used once +// and thrown away inside the context of a single thread. +type DumbWatchManager struct { + // tableWatches holds the full table watches. + tableWatches map[string]*FullTableWatch + + // armed tracks whether the table should be notified. + armed map[string]bool +} + +// NewDumbWatchManager returns a new dumb watch manager. +func NewDumbWatchManager(tableWatches map[string]*FullTableWatch) *DumbWatchManager { + return &DumbWatchManager{ + tableWatches: tableWatches, + armed: make(map[string]bool), + } +} + +// Arm arms the given table's watch. +func (d *DumbWatchManager) Arm(table string) { + if _, ok := d.tableWatches[table]; !ok { + panic(fmt.Sprintf("unknown table: %s", table)) + } + + if _, ok := d.armed[table]; !ok { + d.armed[table] = true + } +} + +// Notify fires watches for all the armed tables. 
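+// A sketch of the intended call pattern (the real wiring lives in the
+// state store code, not shown here): arm tables as a transaction touches
+// them, then notify once at commit time:
+//
+//	watches := NewDumbWatchManager(tableWatches)
+//	watches.Arm("nodes")
+//	watches.Arm("sessions")
+//	// ...at commit time:
+//	watches.Notify()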
+func (d *DumbWatchManager) Notify() {
+	for table := range d.armed {
+		d.tableWatches[table].Notify()
+	}
+}
+
+// PrefixWatch maintains a notify group for each prefix, allowing for much more
+// fine-grained watches.
+type PrefixWatch struct {
+	// watches has the set of notify groups, organized by prefix.
+	watches *radix.Tree
+
+	// lock protects the watches tree.
+	lock sync.Mutex
+}
+
+// NewPrefixWatch returns a new prefix watch.
+func NewPrefixWatch() *PrefixWatch {
+	return &PrefixWatch{
+		watches: radix.New(),
+	}
+}
+
+// GetSubwatch returns the notify group for the given prefix.
+func (w *PrefixWatch) GetSubwatch(prefix string) *NotifyGroup {
+	w.lock.Lock()
+	defer w.lock.Unlock()
+
+	if raw, ok := w.watches.Get(prefix); ok {
+		return raw.(*NotifyGroup)
+	}
+
+	group := &NotifyGroup{}
+	w.watches.Insert(prefix, group)
+	return group
+}
+
+// Notify wakes up all the watchers associated with the given prefix. If subtree
+// is true then we will also notify all the tree under the prefix, such as when
+// a key is being deleted.
+func (w *PrefixWatch) Notify(prefix string, subtree bool) {
+	w.lock.Lock()
+	defer w.lock.Unlock()
+
+	var cleanup []string
+	fn := func(k string, v interface{}) bool {
+		group := v.(*NotifyGroup)
+		group.Notify()
+		if k != "" {
+			cleanup = append(cleanup, k)
+		}
+		return false
+	}
+
+	// Invoke any watcher on the path downward to the key.
+	w.watches.WalkPath(prefix, fn)
+
+	// If the entire prefix may be affected (e.g. delete tree),
+	// invoke the entire prefix.
+	if subtree {
+		w.watches.WalkPrefix(prefix, fn)
+	}
+
+	// Delete the old notify groups.
+	for i := len(cleanup) - 1; i >= 0; i-- {
+		w.watches.Delete(cleanup[i])
+	}
+
+	// TODO (slackpad) If a watch never fires then we will never clear it
+	// out of the tree. The old state store had the same behavior, so this
+	// has been around for a while. We should probably add a prefix scan
+	// with a function that clears out any notify groups that are empty.
+}
+
+// MultiWatch wraps several watches and allows any of them to trigger the
+// caller.
+type MultiWatch struct {
+	// watches holds the list of subordinate watches to forward events to.
+	watches []Watch
+}
+
+// NewMultiWatch returns a new multi watch over the given set of watches.
+func NewMultiWatch(watches ...Watch) *MultiWatch {
+	return &MultiWatch{
+		watches: watches,
+	}
+}
+
+// See Watch.
+func (w *MultiWatch) Wait(notifyCh chan struct{}) {
+	for _, watch := range w.watches {
+		watch.Wait(notifyCh)
+	}
+}
+
+// See Watch.
+func (w *MultiWatch) Clear(notifyCh chan struct{}) {
+	for _, watch := range w.watches {
+		watch.Clear(notifyCh)
+	}
+}
diff --git a/consul/state/watch_test.go b/consul/state/watch_test.go
new file mode 100644
index 000000000000..64f08df06e54
--- /dev/null
+++ b/consul/state/watch_test.go
@@ -0,0 +1,279 @@
+package state
+
+import (
+	"testing"
+)
+
+// verifyWatch will set up a watch channel, call the given function, and then
+// make sure the watch fires.
+func verifyWatch(t *testing.T, watch Watch, fn func()) {
+	ch := make(chan struct{}, 1)
+	watch.Wait(ch)
+
+	fn()
+
+	select {
+	case <-ch:
+	default:
+		t.Fatalf("watch should have been notified")
+	}
+}
+
+// verifyNoWatch will set up a watch channel, call the given function, and then
+// make sure the watch never fires.
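+// Both helpers rely on the notification being delivered into the buffered
+// channel synchronously while fn runs, so a plain select with a default
+// case is enough to detect whether the watch fired; no sleeps or polling
+// are needed for these in-process watches.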
+func verifyNoWatch(t *testing.T, watch Watch, fn func()) {
+	ch := make(chan struct{}, 1)
+	watch.Wait(ch)
+
+	fn()
+
+	select {
+	case <-ch:
+		t.Fatalf("watch should not have been notified")
+	default:
+	}
+}
+
+func TestWatch_FullTableWatch(t *testing.T) {
+	w := NewFullTableWatch()
+
+	// Test the basic trigger with a single watcher.
+	verifyWatch(t, w, func() {
+		w.Notify()
+	})
+
+	// Run multiple watchers and make sure they both fire.
+	verifyWatch(t, w, func() {
+		verifyWatch(t, w, func() {
+			w.Notify()
+		})
+	})
+
+	// Make sure clear works.
+	ch := make(chan struct{}, 1)
+	w.Wait(ch)
+	w.Clear(ch)
+	w.Notify()
+	select {
+	case <-ch:
+		t.Fatalf("watch should not have been notified")
+	default:
+	}
+
+	// Make sure notify is a one-shot.
+	w.Wait(ch)
+	w.Notify()
+	select {
+	case <-ch:
+	default:
+		t.Fatalf("watch should have been notified")
+	}
+	w.Notify()
+	select {
+	case <-ch:
+		t.Fatalf("watch should not have been notified")
+	default:
+	}
+}
+
+func TestWatch_DumbWatchManager(t *testing.T) {
+	watches := map[string]*FullTableWatch{
+		"alice": NewFullTableWatch(),
+		"bob":   NewFullTableWatch(),
+		"carol": NewFullTableWatch(),
+	}
+
+	// Notify with nothing armed and make sure nothing triggers.
+	func() {
+		w := NewDumbWatchManager(watches)
+		verifyNoWatch(t, watches["alice"], func() {
+			verifyNoWatch(t, watches["bob"], func() {
+				verifyNoWatch(t, watches["carol"], func() {
+					w.Notify()
+				})
+			})
+		})
+	}()
+
+	// Trigger one watch.
+	func() {
+		w := NewDumbWatchManager(watches)
+		verifyWatch(t, watches["alice"], func() {
+			verifyNoWatch(t, watches["bob"], func() {
+				verifyNoWatch(t, watches["carol"], func() {
+					w.Arm("alice")
+					w.Notify()
+				})
+			})
+		})
+	}()
+
+	// Trigger two watches.
+	func() {
+		w := NewDumbWatchManager(watches)
+		verifyWatch(t, watches["alice"], func() {
+			verifyNoWatch(t, watches["bob"], func() {
+				verifyWatch(t, watches["carol"], func() {
+					w.Arm("alice")
+					w.Arm("carol")
+					w.Notify()
+				})
+			})
+		})
+	}()
+
+	// Trigger all three watches.
+	func() {
+		w := NewDumbWatchManager(watches)
+		verifyWatch(t, watches["alice"], func() {
+			verifyWatch(t, watches["bob"], func() {
+				verifyWatch(t, watches["carol"], func() {
+					w.Arm("alice")
+					w.Arm("bob")
+					w.Arm("carol")
+					w.Notify()
+				})
+			})
+		})
+	}()
+
+	// Arming the same table multiple times still fires it only once.
+	func() {
+		w := NewDumbWatchManager(watches)
+		verifyWatch(t, watches["alice"], func() {
+			verifyNoWatch(t, watches["bob"], func() {
+				verifyNoWatch(t, watches["carol"], func() {
+					w.Arm("alice")
+					w.Arm("alice")
+					w.Notify()
+				})
+			})
+		})
+	}()
+
+	// Make sure it panics when asked to arm an unknown table.
+	func() {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Fatalf("didn't get expected panic")
+			}
+		}()
+		w := NewDumbWatchManager(watches)
+		w.Arm("nope")
+	}()
+}
+
+func TestWatch_PrefixWatch(t *testing.T) {
+	w := NewPrefixWatch()
+
+	// Hit a specific key.
+	verifyWatch(t, w.GetSubwatch(""), func() {
+		verifyWatch(t, w.GetSubwatch("foo/bar/baz"), func() {
+			verifyNoWatch(t, w.GetSubwatch("foo/bar/zoo"), func() {
+				verifyNoWatch(t, w.GetSubwatch("nope"), func() {
+					w.Notify("foo/bar/baz", false)
+				})
+			})
+		})
+	})
+
+	// Make sure cleanup is happening. All that should be left is the
+	// full-table watch and the un-fired watches.
+	fn := func(k string, v interface{}) bool {
+		if k != "" && k != "foo/bar/zoo" && k != "nope" {
+			t.Fatalf("unexpected watch: %s", k)
+		}
+		return false
+	}
+	w.watches.WalkPrefix("", fn)
+
+	// Delete a subtree.
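+	// (Notify(prefix, true) walks both the path down to the prefix and the
+	// whole subtree under it, so the "foo/bar/zoo" watch fires here even
+	// though only "foo/" is passed in.)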
+ verifyWatch(t, w.GetSubwatch(""), func() { + verifyWatch(t, w.GetSubwatch("foo/bar/baz"), func() { + verifyWatch(t, w.GetSubwatch("foo/bar/zoo"), func() { + verifyNoWatch(t, w.GetSubwatch("nope"), func() { + w.Notify("foo/", true) + }) + }) + }) + }) + + // Hit an unknown key. + verifyWatch(t, w.GetSubwatch(""), func() { + verifyNoWatch(t, w.GetSubwatch("foo/bar/baz"), func() { + verifyNoWatch(t, w.GetSubwatch("foo/bar/zoo"), func() { + verifyNoWatch(t, w.GetSubwatch("nope"), func() { + w.Notify("not/in/there", false) + }) + }) + }) + }) +} + +type MockWatch struct { + Waits map[chan struct{}]int + Clears map[chan struct{}]int +} + +func NewMockWatch() *MockWatch { + return &MockWatch{ + Waits: make(map[chan struct{}]int), + Clears: make(map[chan struct{}]int), + } +} + +func (m *MockWatch) Wait(notifyCh chan struct{}) { + if _, ok := m.Waits[notifyCh]; ok { + m.Waits[notifyCh]++ + } else { + m.Waits[notifyCh] = 1 + } +} + +func (m *MockWatch) Clear(notifyCh chan struct{}) { + if _, ok := m.Clears[notifyCh]; ok { + m.Clears[notifyCh]++ + } else { + m.Clears[notifyCh] = 1 + } +} + +func TestWatch_MultiWatch(t *testing.T) { + w1, w2 := NewMockWatch(), NewMockWatch() + w := NewMultiWatch(w1, w2) + + // Do some activity. + c1, c2 := make(chan struct{}), make(chan struct{}) + w.Wait(c1) + w.Clear(c1) + w.Wait(c1) + w.Wait(c2) + w.Clear(c1) + w.Clear(c2) + + // Make sure all the events were forwarded. + if cnt, ok := w1.Waits[c1]; !ok || cnt != 2 { + t.Fatalf("bad: %d", w1.Waits[c1]) + } + if cnt, ok := w1.Clears[c1]; !ok || cnt != 2 { + t.Fatalf("bad: %d", w1.Clears[c1]) + } + if cnt, ok := w1.Waits[c2]; !ok || cnt != 1 { + t.Fatalf("bad: %d", w1.Waits[c2]) + } + if cnt, ok := w1.Clears[c2]; !ok || cnt != 1 { + t.Fatalf("bad: %d", w1.Clears[c2]) + } + if cnt, ok := w2.Waits[c1]; !ok || cnt != 2 { + t.Fatalf("bad: %d", w2.Waits[c1]) + } + if cnt, ok := w2.Clears[c1]; !ok || cnt != 2 { + t.Fatalf("bad: %d", w2.Clears[c1]) + } + if cnt, ok := w2.Waits[c2]; !ok || cnt != 1 { + t.Fatalf("bad: %d", w2.Waits[c2]) + } + if cnt, ok := w2.Clears[c2]; !ok || cnt != 1 { + t.Fatalf("bad: %d", w2.Clears[c2]) + } +} diff --git a/consul/state_store.go b/consul/state_store.go deleted file mode 100644 index 038ae212bcd2..000000000000 --- a/consul/state_store.go +++ /dev/null @@ -1,2140 +0,0 @@ -package consul - -import ( - "fmt" - "io" - "io/ioutil" - "log" - "os" - "runtime" - "strings" - "sync" - "time" - - "github.com/armon/go-radix" - "github.com/armon/gomdb" - "github.com/hashicorp/consul/consul/structs" -) - -const ( - dbNodes = "nodes" - dbServices = "services" - dbChecks = "checks" - dbKVS = "kvs" - dbTombstone = "tombstones" - dbSessions = "sessions" - dbSessionChecks = "sessionChecks" - dbACLs = "acls" - dbMaxMapSize32bit uint64 = 128 * 1024 * 1024 // 128MB maximum size - dbMaxMapSize64bit uint64 = 32 * 1024 * 1024 * 1024 // 32GB maximum size - dbMaxReaders uint = 4096 // 4K, default is 126 -) - -// kvMode is used internally to control which type of set -// operation we are performing -type kvMode int - -const ( - kvSet kvMode = iota - kvCAS - kvLock - kvUnlock -) - -// The StateStore is responsible for maintaining all the Consul -// state. It is manipulated by the FSM which maintains consistency -// through the use of Raft. The goals of the StateStore are to provide -// high concurrency for read operations without blocking writes, and -// to provide write availability in the face of reads. The current -// implementation uses the Lightning Memory-Mapped Database (MDB). 
-// This gives us Multi-Version Concurrency Control for "free" -type StateStore struct { - logger *log.Logger - path string - env *mdb.Env - nodeTable *MDBTable - serviceTable *MDBTable - checkTable *MDBTable - kvsTable *MDBTable - tombstoneTable *MDBTable - sessionTable *MDBTable - sessionCheckTable *MDBTable - aclTable *MDBTable - tables MDBTables - watch map[*MDBTable]*NotifyGroup - queryTables map[string]MDBTables - - // kvWatch is a more optimized way of watching for KV changes. - // Instead of just using a NotifyGroup for the entire table, - // a watcher is instantiated on a given prefix. When a change happens, - // only the relevant watchers are woken up. This reduces the cost of - // watching for KV changes. - kvWatch *radix.Tree - kvWatchLock sync.Mutex - - // lockDelay is used to mark certain locks as unacquirable. - // When a lock is forcefully released (failing health - // check, destroyed session, etc), it is subject to the LockDelay - // imposed by the session. This prevents another session from - // acquiring the lock for some period of time as a protection against - // split-brains. This is inspired by the lock-delay in Chubby. - // Because this relies on wall-time, we cannot assume all peers - // perceive time as flowing uniformly. This means KVSLock MUST ignore - // lockDelay, since the lockDelay may have expired on the leader, - // but not on the follower. Rejecting the lock could result in - // inconsistencies in the FSMs due to the rate time progresses. Instead, - // only the opinion of the leader is respected, and the Raft log - // is never questioned. - lockDelay map[string]time.Time - lockDelayLock sync.RWMutex - - // GC is when we create tombstones to track their time-to-live. - // The GC is consumed upstream to manage clearing of tombstones. - gc *TombstoneGC -} - -// StateSnapshot is used to provide a point-in-time snapshot -// It works by starting a readonly transaction against all tables. -type StateSnapshot struct { - store *StateStore - tx *MDBTxn - lastIndex uint64 -} - -// sessionCheck is used to create a many-to-one table such -// that each check registered by a session can be mapped back -// to the session row. -type sessionCheck struct { - Node string - CheckID string - Session string -} - -// Close is used to abort the transaction and allow for cleanup -func (s *StateSnapshot) Close() error { - s.tx.Abort() - return nil -} - -// NewStateStore is used to create a new state store -func NewStateStore(gc *TombstoneGC, logOutput io.Writer) (*StateStore, error) { - // Create a new temp dir - path, err := ioutil.TempDir("", "consul") - if err != nil { - return nil, err - } - return NewStateStorePath(gc, path, logOutput) -} - -// NewStateStorePath is used to create a new state store at a given path -// The path is cleared on closing. 
-func NewStateStorePath(gc *TombstoneGC, path string, logOutput io.Writer) (*StateStore, error) { - // Open the env - env, err := mdb.NewEnv() - if err != nil { - return nil, err - } - - s := &StateStore{ - logger: log.New(logOutput, "", log.LstdFlags), - path: path, - env: env, - watch: make(map[*MDBTable]*NotifyGroup), - kvWatch: radix.New(), - lockDelay: make(map[string]time.Time), - gc: gc, - } - - // Ensure we can initialize - if err := s.initialize(); err != nil { - env.Close() - os.RemoveAll(path) - return nil, err - } - return s, nil -} - -// Close is used to safely shutdown the state store -func (s *StateStore) Close() error { - s.env.Close() - os.RemoveAll(s.path) - return nil -} - -// initialize is used to setup the store for use -func (s *StateStore) initialize() error { - // Setup the Env first - if err := s.env.SetMaxDBs(mdb.DBI(32)); err != nil { - return err - } - - // Set the maximum db size based on 32/64bit. Since we are - // doing an mmap underneath, we need to limit our use of virtual - // address space on 32bit, but don't have to care on 64bit. - dbSize := dbMaxMapSize32bit - if runtime.GOARCH == "amd64" { - dbSize = dbMaxMapSize64bit - } - - // Increase the maximum map size - if err := s.env.SetMapSize(dbSize); err != nil { - return err - } - - // Increase the maximum number of concurrent readers - // TODO: Block transactions if we could exceed dbMaxReaders - if err := s.env.SetMaxReaders(dbMaxReaders); err != nil { - return err - } - - // Optimize our flags for speed over safety, since the Raft log + snapshots - // are durable. We treat this as an ephemeral in-memory DB, since we nuke - // the data anyways. - var flags uint = mdb.NOMETASYNC | mdb.NOSYNC | mdb.NOTLS - if err := s.env.Open(s.path, flags, 0755); err != nil { - return err - } - - // Tables use a generic struct encoder - encoder := func(obj interface{}) []byte { - buf, err := structs.Encode(255, obj) - if err != nil { - panic(err) - } - return buf[1:] - } - - // Setup our tables - s.nodeTable = &MDBTable{ - Name: dbNodes, - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Node"}, - CaseInsensitive: true, - }, - }, - Decoder: func(buf []byte) interface{} { - out := new(structs.Node) - if err := structs.Decode(buf, out); err != nil { - panic(err) - } - return out - }, - } - - s.serviceTable = &MDBTable{ - Name: dbServices, - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Node", "ServiceID"}, - }, - "service": &MDBIndex{ - AllowBlank: true, - Fields: []string{"ServiceName"}, - CaseInsensitive: true, - }, - }, - Decoder: func(buf []byte) interface{} { - out := new(structs.ServiceNode) - if err := structs.Decode(buf, out); err != nil { - panic(err) - } - return out - }, - } - - s.checkTable = &MDBTable{ - Name: dbChecks, - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Node", "CheckID"}, - }, - "status": &MDBIndex{ - Fields: []string{"Status"}, - }, - "service": &MDBIndex{ - AllowBlank: true, - Fields: []string{"ServiceName"}, - }, - "node": &MDBIndex{ - AllowBlank: true, - Fields: []string{"Node", "ServiceID"}, - }, - }, - Decoder: func(buf []byte) interface{} { - out := new(structs.HealthCheck) - if err := structs.Decode(buf, out); err != nil { - panic(err) - } - return out - }, - } - - s.kvsTable = &MDBTable{ - Name: dbKVS, - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Key"}, - }, - "id_prefix": &MDBIndex{ - Virtual: true, - RealIndex: "id", - Fields: 
[]string{"Key"}, - IdxFunc: DefaultIndexPrefixFunc, - }, - "session": &MDBIndex{ - AllowBlank: true, - Fields: []string{"Session"}, - }, - }, - Decoder: func(buf []byte) interface{} { - out := new(structs.DirEntry) - if err := structs.Decode(buf, out); err != nil { - panic(err) - } - return out - }, - } - - s.tombstoneTable = &MDBTable{ - Name: dbTombstone, - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Key"}, - }, - "id_prefix": &MDBIndex{ - Virtual: true, - RealIndex: "id", - Fields: []string{"Key"}, - IdxFunc: DefaultIndexPrefixFunc, - }, - }, - Decoder: func(buf []byte) interface{} { - out := new(structs.DirEntry) - if err := structs.Decode(buf, out); err != nil { - panic(err) - } - return out - }, - } - - s.sessionTable = &MDBTable{ - Name: dbSessions, - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"ID"}, - }, - "node": &MDBIndex{ - AllowBlank: true, - Fields: []string{"Node"}, - }, - }, - Decoder: func(buf []byte) interface{} { - out := new(structs.Session) - if err := structs.Decode(buf, out); err != nil { - panic(err) - } - return out - }, - } - - s.sessionCheckTable = &MDBTable{ - Name: dbSessionChecks, - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"Node", "CheckID", "Session"}, - }, - }, - Decoder: func(buf []byte) interface{} { - out := new(sessionCheck) - if err := structs.Decode(buf, out); err != nil { - panic(err) - } - return out - }, - } - - s.aclTable = &MDBTable{ - Name: dbACLs, - Indexes: map[string]*MDBIndex{ - "id": &MDBIndex{ - Unique: true, - Fields: []string{"ID"}, - }, - }, - Decoder: func(buf []byte) interface{} { - out := new(structs.ACL) - if err := structs.Decode(buf, out); err != nil { - panic(err) - } - return out - }, - } - - // Store the set of tables - s.tables = []*MDBTable{s.nodeTable, s.serviceTable, s.checkTable, - s.kvsTable, s.tombstoneTable, s.sessionTable, s.sessionCheckTable, - s.aclTable} - for _, table := range s.tables { - table.Env = s.env - table.Encoder = encoder - if err := table.Init(); err != nil { - return err - } - - // Setup a notification group per table - s.watch[table] = &NotifyGroup{} - } - - // Setup the query tables - s.queryTables = map[string]MDBTables{ - "Nodes": MDBTables{s.nodeTable}, - "Services": MDBTables{s.serviceTable}, - "ServiceNodes": MDBTables{s.nodeTable, s.serviceTable}, - "NodeServices": MDBTables{s.nodeTable, s.serviceTable}, - "ChecksInState": MDBTables{s.checkTable}, - "NodeChecks": MDBTables{s.checkTable}, - "ServiceChecks": MDBTables{s.checkTable}, - "CheckServiceNodes": MDBTables{s.nodeTable, s.serviceTable, s.checkTable}, - "NodeInfo": MDBTables{s.nodeTable, s.serviceTable, s.checkTable}, - "NodeDump": MDBTables{s.nodeTable, s.serviceTable, s.checkTable}, - "SessionGet": MDBTables{s.sessionTable}, - "SessionList": MDBTables{s.sessionTable}, - "NodeSessions": MDBTables{s.sessionTable}, - "ACLGet": MDBTables{s.aclTable}, - "ACLList": MDBTables{s.aclTable}, - } - return nil -} - -// Watch is used to subscribe a channel to a set of MDBTables -func (s *StateStore) Watch(tables MDBTables, notify chan struct{}) { - for _, t := range tables { - s.watch[t].Wait(notify) - } -} - -// StopWatch is used to unsubscribe a channel to a set of MDBTables -func (s *StateStore) StopWatch(tables MDBTables, notify chan struct{}) { - for _, t := range tables { - s.watch[t].Clear(notify) - } -} - -// WatchKV is used to subscribe a channel to changes in KV data -func (s *StateStore) WatchKV(prefix string, 
notify chan struct{}) {
- s.kvWatchLock.Lock()
- defer s.kvWatchLock.Unlock()
-
- // Check for an existing notify group
- if raw, ok := s.kvWatch.Get(prefix); ok {
- grp := raw.(*NotifyGroup)
- grp.Wait(notify)
- return
- }
-
- // Create new notify group
- grp := &NotifyGroup{}
- grp.Wait(notify)
- s.kvWatch.Insert(prefix, grp)
-}
-
-// StopWatchKV is used to unsubscribe a channel from changes in KV data
-func (s *StateStore) StopWatchKV(prefix string, notify chan struct{}) {
- s.kvWatchLock.Lock()
- defer s.kvWatchLock.Unlock()
-
- // Check for an existing notify group
- if raw, ok := s.kvWatch.Get(prefix); ok {
- grp := raw.(*NotifyGroup)
- grp.Clear(notify)
- }
-}
-
-// notifyKV is used to notify any KV listeners of a change
-// on a prefix
-func (s *StateStore) notifyKV(path string, prefix bool) {
- s.kvWatchLock.Lock()
- defer s.kvWatchLock.Unlock()
-
- var toDelete []string
- fn := func(s string, v interface{}) bool {
- group := v.(*NotifyGroup)
- group.Notify()
- if s != "" {
- toDelete = append(toDelete, s)
- }
- return false
- }
-
- // Invoke any watcher on the path downward to the key.
- s.kvWatch.WalkPath(path, fn)
-
- // If the entire prefix may be affected (e.g. delete tree),
- // invoke the entire prefix
- if prefix {
- s.kvWatch.WalkPrefix(path, fn)
- }
-
- // Delete the old watch groups
- for i := len(toDelete) - 1; i >= 0; i-- {
- s.kvWatch.Delete(toDelete[i])
- }
-}
-
-// QueryTables returns the Tables that are queried for a given query
-func (s *StateStore) QueryTables(q string) MDBTables {
- return s.queryTables[q]
-}
-
-// EnsureRegistration is used to make sure a node, service, and check registration
-// is performed within a single transaction to avoid race conditions on state updates.
-func (s *StateStore) EnsureRegistration(index uint64, req *structs.RegisterRequest) error {
- tx, err := s.tables.StartTxn(false)
- if err != nil {
- panic(fmt.Errorf("Failed to start txn: %v", err))
- }
- defer tx.Abort()
-
- // Ensure the node
- node := structs.Node{req.Node, req.Address}
- if err := s.ensureNodeTxn(index, node, tx); err != nil {
- return err
- }
-
- // Ensure the service if provided
- if req.Service != nil {
- if err := s.ensureServiceTxn(index, req.Node, req.Service, tx); err != nil {
- return err
- }
- }
-
- // Ensure the check(s), if provided
- if req.Check != nil {
- if err := s.ensureCheckTxn(index, req.Check, tx); err != nil {
- return err
- }
- }
- for _, check := range req.Checks {
- if err := s.ensureCheckTxn(index, check, tx); err != nil {
- return err
- }
- }
-
- // Commit as one unit
- return tx.Commit()
-}
-
-// EnsureNode is used to ensure a given node exists, with the provided address
-func (s *StateStore) EnsureNode(index uint64, node structs.Node) error {
- tx, err := s.nodeTable.StartTxn(false, nil)
- if err != nil {
- return err
- }
- defer tx.Abort()
- if err := s.ensureNodeTxn(index, node, tx); err != nil {
- return err
- }
- return tx.Commit()
-}
-
-// ensureNodeTxn is used to ensure a given node exists, with the provided address,
-// within a given txn
-func (s *StateStore) ensureNodeTxn(index uint64, node structs.Node, tx *MDBTxn) error {
- if err := s.nodeTable.InsertTxn(tx, node); err != nil {
- return err
- }
- if err := s.nodeTable.SetLastIndexTxn(tx, index); err != nil {
- return err
- }
- tx.Defer(func() { s.watch[s.nodeTable].Notify() })
- return nil
-}
-
-// GetNode returns the address of the named node, and whether it was found
-func (s *StateStore) GetNode(name string) (uint64, bool, string) {
- idx, res, err := s.nodeTable.Get("id", name)
- if err != nil {
- s.logger.Printf("[ERR] consul.state: Error during node lookup: %v", err)
- return 0, false, ""
- }
- if len(res) == 0 {
- return idx, false, ""
- }
- return idx, true, res[0].(*structs.Node).Address
-}
-
-// Nodes returns all the known nodes
-func (s *StateStore) Nodes() (uint64, structs.Nodes) {
- idx, res, err := s.nodeTable.Get("id")
- if err != nil {
- s.logger.Printf("[ERR] consul.state: Error getting nodes: %v", err)
- }
- results := make([]structs.Node, len(res))
- for i, r := range res {
- results[i] = *r.(*structs.Node)
- }
- return idx, results
-}
-
-// EnsureService is used to ensure a given node exposes a service
-func (s *StateStore) EnsureService(index uint64, node string, ns *structs.NodeService) error {
- tx, err := s.tables.StartTxn(false)
- if err != nil {
- panic(fmt.Errorf("Failed to start txn: %v", err))
- }
- defer tx.Abort()
- if err := s.ensureServiceTxn(index, node, ns, tx); err != nil {
- return err
- }
- return tx.Commit()
-}
-
-// ensureServiceTxn is used to ensure a given node exposes a service in a transaction
-func (s *StateStore) ensureServiceTxn(index uint64, node string, ns *structs.NodeService, tx *MDBTxn) error {
- // Ensure the node exists
- res, err := s.nodeTable.GetTxn(tx, "id", node)
- if err != nil {
- return err
- }
- if len(res) == 0 {
- return fmt.Errorf("Missing node registration")
- }
-
- // Create the entry
- entry := structs.ServiceNode{
- Node: node,
- ServiceID: ns.ID,
- ServiceName: ns.Service,
- ServiceTags: ns.Tags,
- ServiceAddress: ns.Address,
- ServicePort: ns.Port,
- }
-
- // Ensure the service entry is set
- if err := s.serviceTable.InsertTxn(tx, &entry); err != nil {
- return err
- }
- if err := s.serviceTable.SetLastIndexTxn(tx, index); err != nil {
- return err
- }
- tx.Defer(func() { s.watch[s.serviceTable].Notify() })
- return nil
-}
-
-// NodeServices is used to return all the services of a given node
-func (s *StateStore) NodeServices(name string) (uint64, *structs.NodeServices) {
- tables := s.queryTables["NodeServices"]
- tx, err := tables.StartTxn(true)
- if err != nil {
- panic(fmt.Errorf("Failed to start txn: %v", err))
- }
- defer tx.Abort()
- return s.parseNodeServices(tables, tx, name)
-}
-
-// parseNodeServices is used to get the services belonging to a
-// node, using a given txn
-func (s *StateStore) parseNodeServices(tables MDBTables, tx *MDBTxn, name string) (uint64, *structs.NodeServices) {
- ns := &structs.NodeServices{
- Services: make(map[string]*structs.NodeService),
- }
-
- // Get the maximum index
- index, err := tables.LastIndexTxn(tx)
- if err != nil {
- panic(fmt.Errorf("Failed to get last index: %v", err))
- }
-
- // Get the node first
- res, err := s.nodeTable.GetTxn(tx, "id", name)
- if err != nil {
- s.logger.Printf("[ERR] consul.state: Failed to get node: %v", err)
- }
- if len(res) == 0 {
- return index, nil
- }
-
- // Set the address
- node := res[0].(*structs.Node)
- ns.Node = *node
-
- // Get the services
- res, err = s.serviceTable.GetTxn(tx, "id", name)
- if err != nil {
- s.logger.Printf("[ERR] consul.state: Failed to get node '%s' services: %v", name, err)
- }
-
- // Add each service
- for _, r := range res {
- service := r.(*structs.ServiceNode)
- srv := &structs.NodeService{
- ID: service.ServiceID,
- Service: service.ServiceName,
- Tags: service.ServiceTags,
- Address: service.ServiceAddress,
- Port: service.ServicePort,
- }
- ns.Services[srv.ID] = srv
- }
- return index, ns
-}
-
-// DeleteNodeService is
used to delete a node service -func (s *StateStore) DeleteNodeService(index uint64, node, id string) error { - tx, err := s.tables.StartTxn(false) - if err != nil { - panic(fmt.Errorf("Failed to start txn: %v", err)) - } - defer tx.Abort() - - if n, err := s.serviceTable.DeleteTxn(tx, "id", node, id); err != nil { - return err - } else if n > 0 { - if err := s.serviceTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { s.watch[s.serviceTable].Notify() }) - } - - // Invalidate any sessions using these checks - checks, err := s.checkTable.GetTxn(tx, "node", node, id) - if err != nil { - return err - } - for _, c := range checks { - check := c.(*structs.HealthCheck) - if err := s.invalidateCheck(index, tx, node, check.CheckID); err != nil { - return err - } - } - - if n, err := s.checkTable.DeleteTxn(tx, "node", node, id); err != nil { - return err - } else if n > 0 { - if err := s.checkTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { s.watch[s.checkTable].Notify() }) - } - return tx.Commit() -} - -// DeleteNode is used to delete a node and all it's services -func (s *StateStore) DeleteNode(index uint64, node string) error { - tx, err := s.tables.StartTxn(false) - if err != nil { - panic(fmt.Errorf("Failed to start txn: %v", err)) - } - defer tx.Abort() - - // Invalidate any sessions held by the node - if err := s.invalidateNode(index, tx, node); err != nil { - return err - } - - if n, err := s.serviceTable.DeleteTxn(tx, "id", node); err != nil { - return err - } else if n > 0 { - if err := s.serviceTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { s.watch[s.serviceTable].Notify() }) - } - if n, err := s.checkTable.DeleteTxn(tx, "id", node); err != nil { - return err - } else if n > 0 { - if err := s.checkTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { s.watch[s.checkTable].Notify() }) - } - if n, err := s.nodeTable.DeleteTxn(tx, "id", node); err != nil { - return err - } else if n > 0 { - if err := s.nodeTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { s.watch[s.nodeTable].Notify() }) - } - return tx.Commit() -} - -// Services is used to return all the services with a list of associated tags -func (s *StateStore) Services() (uint64, map[string][]string) { - services := make(map[string][]string) - idx, res, err := s.serviceTable.Get("id") - if err != nil { - s.logger.Printf("[ERR] consul.state: Failed to get services: %v", err) - return idx, services - } - for _, r := range res { - srv := r.(*structs.ServiceNode) - tags, ok := services[srv.ServiceName] - if !ok { - services[srv.ServiceName] = make([]string, 0) - } - - for _, tag := range srv.ServiceTags { - if !strContains(tags, tag) { - tags = append(tags, tag) - services[srv.ServiceName] = tags - } - } - } - return idx, services -} - -// ServiceNodes returns the nodes associated with a given service -func (s *StateStore) ServiceNodes(service string) (uint64, structs.ServiceNodes) { - tables := s.queryTables["ServiceNodes"] - tx, err := tables.StartTxn(true) - if err != nil { - panic(fmt.Errorf("Failed to start txn: %v", err)) - } - defer tx.Abort() - - idx, err := tables.LastIndexTxn(tx) - if err != nil { - panic(fmt.Errorf("Failed to get last index: %v", err)) - } - - res, err := s.serviceTable.GetTxn(tx, "service", service) - return idx, s.parseServiceNodes(tx, s.nodeTable, res, err) -} - -// ServiceTagNodes returns the nodes associated with a given service matching a 
tag
-func (s *StateStore) ServiceTagNodes(service, tag string) (uint64, structs.ServiceNodes) {
- tables := s.queryTables["ServiceNodes"]
- tx, err := tables.StartTxn(true)
- if err != nil {
- panic(fmt.Errorf("Failed to start txn: %v", err))
- }
- defer tx.Abort()
-
- idx, err := tables.LastIndexTxn(tx)
- if err != nil {
- panic(fmt.Errorf("Failed to get last index: %v", err))
- }
-
- res, err := s.serviceTable.GetTxn(tx, "service", service)
- res = serviceTagFilter(res, tag)
- return idx, s.parseServiceNodes(tx, s.nodeTable, res, err)
-}
-
-// serviceTagFilter is used to filter a list of *structs.ServiceNode,
-// removing the entries that do not have the specified tag
-func serviceTagFilter(l []interface{}, tag string) []interface{} {
- n := len(l)
- for i := 0; i < n; i++ {
- srv := l[i].(*structs.ServiceNode)
- if !strContains(ToLowerList(srv.ServiceTags), strings.ToLower(tag)) {
- l[i], l[n-1] = l[n-1], nil
- i--
- n--
- }
- }
- return l[:n]
-}
-
-// parseServiceNodes parses the results for ServiceNodes and ServiceTagNodes
-func (s *StateStore) parseServiceNodes(tx *MDBTxn, table *MDBTable, res []interface{}, err error) structs.ServiceNodes {
- nodes := make(structs.ServiceNodes, len(res))
- if err != nil {
- s.logger.Printf("[ERR] consul.state: Failed to get service nodes: %v", err)
- return nodes
- }
-
- for i, r := range res {
- srv := r.(*structs.ServiceNode)
-
- // Get the address of the node
- nodeRes, err := table.GetTxn(tx, "id", srv.Node)
- if err != nil || len(nodeRes) != 1 {
- s.logger.Printf("[ERR] consul.state: Failed to join service node %#v with node: %v", *srv, err)
- continue
- }
- srv.Address = nodeRes[0].(*structs.Node).Address
-
- nodes[i] = *srv
- }
-
- return nodes
-}
-
-// EnsureCheck is used to create a check or update its state
-func (s *StateStore) EnsureCheck(index uint64, check *structs.HealthCheck) error {
- tx, err := s.tables.StartTxn(false)
- if err != nil {
- panic(fmt.Errorf("Failed to start txn: %v", err))
- }
- defer tx.Abort()
- if err := s.ensureCheckTxn(index, check, tx); err != nil {
- return err
- }
- return tx.Commit()
-}
-
-// ensureCheckTxn is used to create a check or update its state in a transaction
-func (s *StateStore) ensureCheckTxn(index uint64, check *structs.HealthCheck, tx *MDBTxn) error {
- // Ensure we have a status
- if check.Status == "" {
- check.Status = structs.HealthCritical
- }
-
- // Ensure the node exists
- res, err := s.nodeTable.GetTxn(tx, "id", check.Node)
- if err != nil {
- return err
- }
- if len(res) == 0 {
- return fmt.Errorf("Missing node registration")
- }
-
- // Ensure the service exists if specified
- if check.ServiceID != "" {
- res, err = s.serviceTable.GetTxn(tx, "id", check.Node, check.ServiceID)
- if err != nil {
- return err
- }
- if len(res) == 0 {
- return fmt.Errorf("Missing service registration")
- }
- // Ensure we set the correct service
- srv := res[0].(*structs.ServiceNode)
- check.ServiceName = srv.ServiceName
- }
-
- // Invalidate any sessions if status is critical
- if check.Status == structs.HealthCritical {
- err := s.invalidateCheck(index, tx, check.Node, check.CheckID)
- if err != nil {
- return err
- }
- }
-
- // Ensure the check is set
- if err := s.checkTable.InsertTxn(tx, check); err != nil {
- return err
- }
- if err := s.checkTable.SetLastIndexTxn(tx, index); err != nil {
- return err
- }
- tx.Defer(func() { s.watch[s.checkTable].Notify() })
- return nil
-}
-
-// DeleteNodeCheck is used to delete a node health check
-func (s *StateStore) DeleteNodeCheck(index uint64, node, id string) error {
- tx, err := s.tables.StartTxn(false)
- if err != nil {
- return err
- }
- defer tx.Abort()
-
- // Invalidate any sessions held by this check
- if err := s.invalidateCheck(index, tx, node, id); err != nil {
- return err
- }
-
- if n, err := s.checkTable.DeleteTxn(tx, "id", node, id); err != nil {
- return err
- } else if n > 0 {
- if err := s.checkTable.SetLastIndexTxn(tx, index); err != nil {
- return err
- }
- tx.Defer(func() { s.watch[s.checkTable].Notify() })
- }
- return tx.Commit()
-}
-
-// NodeChecks is used to get all the checks for a node
-func (s *StateStore) NodeChecks(node string) (uint64, structs.HealthChecks) {
- return s.parseHealthChecks(s.checkTable.Get("id", node))
-}
-
-// ServiceChecks is used to get all the checks for a service
-func (s *StateStore) ServiceChecks(service string) (uint64, structs.HealthChecks) {
- return s.parseHealthChecks(s.checkTable.Get("service", service))
-}
-
-// ChecksInState is used to get all the checks that are in a given state
-func (s *StateStore) ChecksInState(state string) (uint64, structs.HealthChecks) {
- var idx uint64
- var res []interface{}
- var err error
- if state == structs.HealthAny {
- idx, res, err = s.checkTable.Get("id")
- } else {
- idx, res, err = s.checkTable.Get("status", state)
- }
- return s.parseHealthChecks(idx, res, err)
-}
-
-// parseHealthChecks is used to handle the results of a Get against
-// the checkTable
-func (s *StateStore) parseHealthChecks(idx uint64, res []interface{}, err error) (uint64, structs.HealthChecks) {
- results := make([]*structs.HealthCheck, len(res))
- if err != nil {
- s.logger.Printf("[ERR] consul.state: Failed to get health checks: %v", err)
- return idx, results
- }
- for i, r := range res {
- results[i] = r.(*structs.HealthCheck)
- }
- return idx, results
-}
-
-// CheckServiceNodes returns the nodes associated with a given service, along
-// with any associated checks
-func (s *StateStore) CheckServiceNodes(service string) (uint64, structs.CheckServiceNodes) {
- tables := s.queryTables["CheckServiceNodes"]
- tx, err := tables.StartTxn(true)
- if err != nil {
- panic(fmt.Errorf("Failed to start txn: %v", err))
- }
- defer tx.Abort()
-
- idx, err := tables.LastIndexTxn(tx)
- if err != nil {
- panic(fmt.Errorf("Failed to get last index: %v", err))
- }
-
- res, err := s.serviceTable.GetTxn(tx, "service", service)
- return idx, s.parseCheckServiceNodes(tx, res, err)
-}
-
-// CheckServiceTagNodes returns the nodes associated with a given service and
-// tag, along with any associated checks
-func (s *StateStore) CheckServiceTagNodes(service, tag string) (uint64, structs.CheckServiceNodes) {
- tables := s.queryTables["CheckServiceNodes"]
- tx, err := tables.StartTxn(true)
- if err != nil {
- panic(fmt.Errorf("Failed to start txn: %v", err))
- }
- defer tx.Abort()
-
- idx, err := tables.LastIndexTxn(tx)
- if err != nil {
- panic(fmt.Errorf("Failed to get last index: %v", err))
- }
-
- res, err := s.serviceTable.GetTxn(tx, "service", service)
- res = serviceTagFilter(res, tag)
- return idx, s.parseCheckServiceNodes(tx, res, err)
-}
-
-// parseCheckServiceNodes parses the results for CheckServiceNodes and CheckServiceTagNodes
-func (s *StateStore) parseCheckServiceNodes(tx *MDBTxn, res []interface{}, err error) structs.CheckServiceNodes {
- nodes := make(structs.CheckServiceNodes, len(res))
- if err != nil {
- s.logger.Printf("[ERR] consul.state: Failed to get service nodes: %v", err)
- return nodes
- }
-
- for i, r := range res {
- srv := r.(*structs.ServiceNode)
-
- // Get the node
- nodeRes, err := s.nodeTable.GetTxn(tx,
"id", srv.Node) - if err != nil || len(nodeRes) != 1 { - s.logger.Printf("[ERR] consul.state: Failed to join service node %#v with node: %v", *srv, err) - continue - } - - // Get any associated checks of the service - res, err := s.checkTable.GetTxn(tx, "node", srv.Node, srv.ServiceID) - _, checks := s.parseHealthChecks(0, res, err) - - // Get any checks of the node, not associated with any service - res, err = s.checkTable.GetTxn(tx, "node", srv.Node, "") - _, nodeChecks := s.parseHealthChecks(0, res, err) - checks = append(checks, nodeChecks...) - - // Setup the node - nodes[i].Node = *nodeRes[0].(*structs.Node) - nodes[i].Service = structs.NodeService{ - ID: srv.ServiceID, - Service: srv.ServiceName, - Tags: srv.ServiceTags, - Address: srv.ServiceAddress, - Port: srv.ServicePort, - } - nodes[i].Checks = checks - } - - return nodes -} - -// NodeInfo is used to generate the full info about a node. -func (s *StateStore) NodeInfo(node string) (uint64, structs.NodeDump) { - tables := s.queryTables["NodeInfo"] - tx, err := tables.StartTxn(true) - if err != nil { - panic(fmt.Errorf("Failed to start txn: %v", err)) - } - defer tx.Abort() - - idx, err := tables.LastIndexTxn(tx) - if err != nil { - panic(fmt.Errorf("Failed to get last index: %v", err)) - } - - res, err := s.nodeTable.GetTxn(tx, "id", node) - return idx, s.parseNodeInfo(tx, res, err) -} - -// NodeDump is used to generate the NodeInfo for all nodes. This is very expensive, -// and should generally be avoided for programmatic access. -func (s *StateStore) NodeDump() (uint64, structs.NodeDump) { - tables := s.queryTables["NodeDump"] - tx, err := tables.StartTxn(true) - if err != nil { - panic(fmt.Errorf("Failed to start txn: %v", err)) - } - defer tx.Abort() - - idx, err := tables.LastIndexTxn(tx) - if err != nil { - panic(fmt.Errorf("Failed to get last index: %v", err)) - } - - res, err := s.nodeTable.GetTxn(tx, "id") - return idx, s.parseNodeInfo(tx, res, err) -} - -// parseNodeInfo is used to scan over the results of a node -// iteration and generate a NodeDump -func (s *StateStore) parseNodeInfo(tx *MDBTxn, res []interface{}, err error) structs.NodeDump { - dump := make(structs.NodeDump, 0, len(res)) - if err != nil { - s.logger.Printf("[ERR] consul.state: Failed to get nodes: %v", err) - return dump - } - - for _, r := range res { - // Copy the address and node - node := r.(*structs.Node) - info := &structs.NodeInfo{ - Node: node.Node, - Address: node.Address, - } - - // Get any services of the node - res, err = s.serviceTable.GetTxn(tx, "id", node.Node) - if err != nil { - s.logger.Printf("[ERR] consul.state: Failed to get node services: %v", err) - } - info.Services = make([]*structs.NodeService, 0, len(res)) - for _, r := range res { - service := r.(*structs.ServiceNode) - srv := &structs.NodeService{ - ID: service.ServiceID, - Service: service.ServiceName, - Tags: service.ServiceTags, - Address: service.ServiceAddress, - Port: service.ServicePort, - } - info.Services = append(info.Services, srv) - } - - // Get any checks of the node - res, err = s.checkTable.GetTxn(tx, "node", node.Node) - if err != nil { - s.logger.Printf("[ERR] consul.state: Failed to get node checks: %v", err) - } - info.Checks = make([]*structs.HealthCheck, 0, len(res)) - for _, r := range res { - chk := r.(*structs.HealthCheck) - info.Checks = append(info.Checks, chk) - } - - // Add the node info - dump = append(dump, info) - } - return dump -} - -// KVSSet is used to create or update a KV entry -func (s *StateStore) KVSSet(index uint64, d 
*structs.DirEntry) error { - _, err := s.kvsSet(index, d, kvSet) - return err -} - -// KVSRestore is used to restore a DirEntry. It should only be used when -// doing a restore, otherwise KVSSet should be used. -func (s *StateStore) KVSRestore(d *structs.DirEntry) error { - // Start a new txn - tx, err := s.kvsTable.StartTxn(false, nil) - if err != nil { - return err - } - defer tx.Abort() - - if err := s.kvsTable.InsertTxn(tx, d); err != nil { - return err - } - if err := s.kvsTable.SetMaxLastIndexTxn(tx, d.ModifyIndex); err != nil { - return err - } - return tx.Commit() -} - -// KVSGet is used to get a KV entry -func (s *StateStore) KVSGet(key string) (uint64, *structs.DirEntry, error) { - idx, res, err := s.kvsTable.Get("id", key) - var d *structs.DirEntry - if len(res) > 0 { - d = res[0].(*structs.DirEntry) - } - return idx, d, err -} - -// KVSList is used to list all KV entries with a prefix -func (s *StateStore) KVSList(prefix string) (uint64, uint64, structs.DirEntries, error) { - tables := MDBTables{s.kvsTable, s.tombstoneTable} - tx, err := tables.StartTxn(true) - if err != nil { - return 0, 0, nil, err - } - defer tx.Abort() - - idx, err := tables.LastIndexTxn(tx) - if err != nil { - return 0, 0, nil, err - } - - res, err := s.kvsTable.GetTxn(tx, "id_prefix", prefix) - if err != nil { - return 0, 0, nil, err - } - ents := make(structs.DirEntries, len(res)) - for idx, r := range res { - ents[idx] = r.(*structs.DirEntry) - } - - // Check for the highest index in the tombstone table - var maxIndex uint64 - res, err = s.tombstoneTable.GetTxn(tx, "id_prefix", prefix) - for _, r := range res { - ent := r.(*structs.DirEntry) - if ent.ModifyIndex > maxIndex { - maxIndex = ent.ModifyIndex - } - } - - return maxIndex, idx, ents, err -} - -// KVSListKeys is used to list keys with a prefix, and up to a given separator -func (s *StateStore) KVSListKeys(prefix, seperator string) (uint64, []string, error) { - tables := MDBTables{s.kvsTable, s.tombstoneTable} - tx, err := tables.StartTxn(true) - if err != nil { - return 0, nil, err - } - defer tx.Abort() - - idx, err := s.kvsTable.LastIndexTxn(tx) - if err != nil { - return 0, nil, err - } - - // Ensure a non-zero index - if idx == 0 { - // Must provide non-zero index to prevent blocking - // Index 1 is impossible anyways (due to Raft internals) - idx = 1 - } - - // Aggregate the stream - stream := make(chan interface{}, 128) - streamTomb := make(chan interface{}, 128) - done := make(chan struct{}) - var keys []string - var maxIndex uint64 - go func() { - prefixLen := len(prefix) - sepLen := len(seperator) - last := "" - for raw := range stream { - ent := raw.(*structs.DirEntry) - after := ent.Key[prefixLen:] - - // Update the highest index we've seen - if ent.ModifyIndex > maxIndex { - maxIndex = ent.ModifyIndex - } - - // If there is no separator, always accumulate - if sepLen == 0 { - keys = append(keys, ent.Key) - continue - } - - // Check for the separator - if idx := strings.Index(after, seperator); idx >= 0 { - toSep := ent.Key[:prefixLen+idx+sepLen] - if last != toSep { - keys = append(keys, toSep) - last = toSep - } - } else { - keys = append(keys, ent.Key) - } - } - - // Handle the tombstones for any index updates - for raw := range streamTomb { - ent := raw.(*structs.DirEntry) - if ent.ModifyIndex > maxIndex { - maxIndex = ent.ModifyIndex - } - } - close(done) - }() - - // Start the stream, and wait for completion - if err = s.kvsTable.StreamTxn(stream, tx, "id_prefix", prefix); err != nil { - return 0, nil, err - } - if err := 
s.tombstoneTable.StreamTxn(streamTomb, tx, "id_prefix", prefix); err != nil { - return 0, nil, err - } - <-done - - // Use the maxIndex if we have any keys - if maxIndex != 0 { - idx = maxIndex - } - return idx, keys, nil -} - -// KVSDelete is used to delete a KVS entry -func (s *StateStore) KVSDelete(index uint64, key string) error { - return s.kvsDeleteWithIndex(index, "id", key) -} - -// KVSDeleteCheckAndSet is used to perform an atomic delete check-and-set -func (s *StateStore) KVSDeleteCheckAndSet(index uint64, key string, casIndex uint64) (bool, error) { - tx, err := s.tables.StartTxn(false) - if err != nil { - return false, err - } - defer tx.Abort() - - // Get the existing node - res, err := s.kvsTable.GetTxn(tx, "id", key) - if err != nil { - return false, err - } - - // Get the existing node if any - var exist *structs.DirEntry - if len(res) > 0 { - exist = res[0].(*structs.DirEntry) - } - - // Use the casIndex as the constraint. A modify time of 0 means - // we are doing a delete-if-not-exists (odd...), while any other - // value means we expect that modify time. - if casIndex == 0 { - return exist == nil, nil - } else if casIndex > 0 && (exist == nil || exist.ModifyIndex != casIndex) { - return false, nil - } - - // Do the actual delete - if err := s.kvsDeleteWithIndexTxn(index, tx, "id", key); err != nil { - return false, err - } - return true, tx.Commit() -} - -// KVSDeleteTree is used to delete all keys with a given prefix -func (s *StateStore) KVSDeleteTree(index uint64, prefix string) error { - if prefix == "" { - return s.kvsDeleteWithIndex(index, "id") - } - return s.kvsDeleteWithIndex(index, "id_prefix", prefix) -} - -// kvsDeleteWithIndex does a delete with either the id or id_prefix -func (s *StateStore) kvsDeleteWithIndex(index uint64, tableIndex string, parts ...string) error { - tx, err := s.tables.StartTxn(false) - if err != nil { - return err - } - defer tx.Abort() - if err := s.kvsDeleteWithIndexTxn(index, tx, tableIndex, parts...); err != nil { - return err - } - return tx.Commit() -} - -// kvsDeleteWithIndexTxn does a delete within an existing transaction -func (s *StateStore) kvsDeleteWithIndexTxn(index uint64, tx *MDBTxn, tableIndex string, parts ...string) error { - num := 0 - for { - // Get some number of entries to delete - pairs, err := s.kvsTable.GetTxnLimit(tx, 128, tableIndex, parts...) - if err != nil { - return err - } - - // Create the tombstones and delete - for _, raw := range pairs { - ent := raw.(*structs.DirEntry) - ent.ModifyIndex = index // Update the index - ent.Value = nil // Reduce storage required - ent.Session = "" - if err := s.tombstoneTable.InsertTxn(tx, ent); err != nil { - return err - } - if num, err := s.kvsTable.DeleteTxn(tx, "id", ent.Key); err != nil { - return err - } else if num != 1 { - return fmt.Errorf("Failed to delete key '%s'", ent.Key) - } - } - - // Increment the total number - num += len(pairs) - if len(pairs) == 0 { - break - } - } - - if num > 0 { - if err := s.kvsTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { - // Trigger the most fine grained notifications if possible - switch { - case len(parts) == 0: - s.notifyKV("", true) - case tableIndex == "id": - s.notifyKV(parts[0], false) - case tableIndex == "id_prefix": - s.notifyKV(parts[0], true) - default: - s.notifyKV("", true) - } - if s.gc != nil { - // If GC is configured, then we hint that this index - // required expiration. 
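// The hint registers this index with the TombstoneGC, which reaps the
// tombstones created above once they expire; until then, blocking queries
// against this prefix still observe the index advance from the delete.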
- s.gc.Hint(index) - } - }) - } - return nil -} - -// KVSCheckAndSet is used to perform an atomic check-and-set -func (s *StateStore) KVSCheckAndSet(index uint64, d *structs.DirEntry) (bool, error) { - return s.kvsSet(index, d, kvCAS) -} - -// KVSLock works like KVSSet but only writes if the lock can be acquired -func (s *StateStore) KVSLock(index uint64, d *structs.DirEntry) (bool, error) { - return s.kvsSet(index, d, kvLock) -} - -// KVSUnlock works like KVSSet but only writes if the lock can be unlocked -func (s *StateStore) KVSUnlock(index uint64, d *structs.DirEntry) (bool, error) { - return s.kvsSet(index, d, kvUnlock) -} - -// KVSLockDelay returns the expiration time of a key lock delay. A key may -// have a lock delay if it was unlocked due to a session invalidation instead -// of a graceful unlock. This must be checked on the leader node, and not in -// KVSLock due to the variability of clocks. -func (s *StateStore) KVSLockDelay(key string) time.Time { - s.lockDelayLock.RLock() - expires := s.lockDelay[key] - s.lockDelayLock.RUnlock() - return expires -} - -// kvsSet is the internal setter -func (s *StateStore) kvsSet( - index uint64, - d *structs.DirEntry, - mode kvMode) (bool, error) { - // Start a new txn - tx, err := s.tables.StartTxn(false) - if err != nil { - return false, err - } - defer tx.Abort() - - // Get the existing node - res, err := s.kvsTable.GetTxn(tx, "id", d.Key) - if err != nil { - return false, err - } - - // Get the existing node if any - var exist *structs.DirEntry - if len(res) > 0 { - exist = res[0].(*structs.DirEntry) - } - - // Use the ModifyIndex as the constraint. A modify of time of 0 - // means we are doing a set-if-not-exists, while any other value - // means we expect that modify time. - if mode == kvCAS { - if d.ModifyIndex == 0 && exist != nil { - return false, nil - } else if d.ModifyIndex > 0 && (exist == nil || exist.ModifyIndex != d.ModifyIndex) { - return false, nil - } - } - - // If attempting to lock, check this is possible - if mode == kvLock { - // Verify we have a session - if d.Session == "" { - return false, fmt.Errorf("Missing session") - } - - // Bail if it is already locked - if exist != nil && exist.Session != "" { - return false, nil - } - - // Verify the session exists - res, err := s.sessionTable.GetTxn(tx, "id", d.Session) - if err != nil { - return false, err - } - if len(res) == 0 { - return false, fmt.Errorf("Invalid session") - } - - // Update the lock index - if exist != nil { - exist.LockIndex++ - exist.Session = d.Session - } else { - d.LockIndex = 1 - } - } - - // If attempting to unlock, verify the key exists and is held - if mode == kvUnlock { - if exist == nil || exist.Session != d.Session { - return false, nil - } - // Clear the session to unlock - exist.Session = "" - } - - // Set the create and modify times - if exist == nil { - d.CreateIndex = index - } else { - d.CreateIndex = exist.CreateIndex - d.LockIndex = exist.LockIndex - d.Session = exist.Session - - } - d.ModifyIndex = index - - if err := s.kvsTable.InsertTxn(tx, d); err != nil { - return false, err - } - if err := s.kvsTable.SetLastIndexTxn(tx, index); err != nil { - return false, err - } - tx.Defer(func() { s.notifyKV(d.Key, false) }) - return true, tx.Commit() -} - -// ReapTombstones is used to delete all the tombstones with a ModifyTime -// less than or equal to the given index. This is used to prevent unbounded -// storage growth of the tombstones. 
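A brief usage sketch of the kvsSet modes above; this is annotation rather than part of the patch, and `store` (a *StateStore), the indexes, and the session ID are made up:

d := &structs.DirEntry{Key: "service/config", Value: []byte("v1")}

// kvCAS: a ModifyIndex of 0 means set-if-not-exists; any non-zero value
// must match the existing entry's ModifyIndex exactly.
if ok, err := store.KVSCheckAndSet(10, d); err == nil && ok {
    // entry created at index 10
}

// kvLock: requires a valid session and bumps LockIndex on each acquisition.
d.Session = sessionID // assumed: a session created earlier via SessionCreate
if ok, _ := store.KVSLock(11, d); ok {
    // kvUnlock: a graceful release clears the session with no lock delay
    store.KVSUnlock(12, d)
}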
-func (s *StateStore) ReapTombstones(index uint64) error { - tx, err := s.tombstoneTable.StartTxn(false, nil) - if err != nil { - return fmt.Errorf("failed to start txn: %v", err) - } - defer tx.Abort() - - // Scan the tombstone table for all the entries that are - // eligible for GC. This could be improved by indexing on - // ModifyTime and doing a less-than-equals scan, however - // we don't currently support numeric indexes internally. - // Luckily, this is a low frequency operation. - var toDelete []string - streamCh := make(chan interface{}, 128) - doneCh := make(chan struct{}) - go func() { - defer close(doneCh) - for raw := range streamCh { - ent := raw.(*structs.DirEntry) - if ent.ModifyIndex <= index { - toDelete = append(toDelete, ent.Key) - } - } - }() - if err := s.tombstoneTable.StreamTxn(streamCh, tx, "id"); err != nil { - s.logger.Printf("[ERR] consul.state: failed to scan tombstones: %v", err) - return fmt.Errorf("failed to scan tombstones: %v", err) - } - <-doneCh - - // Delete each tombstone - if len(toDelete) > 0 { - s.logger.Printf("[DEBUG] consul.state: reaping %d tombstones up to %d", len(toDelete), index) - } - for _, key := range toDelete { - num, err := s.tombstoneTable.DeleteTxn(tx, "id", key) - if err != nil { - s.logger.Printf("[ERR] consul.state: failed to delete tombstone: %v", err) - return fmt.Errorf("failed to delete tombstone: %v", err) - } - if num != 1 { - return fmt.Errorf("failed to delete tombstone '%s'", key) - } - } - return tx.Commit() -} - -// TombstoneRestore is used to restore a tombstone. -// It should only be used when doing a restore. -func (s *StateStore) TombstoneRestore(d *structs.DirEntry) error { - // Start a new txn - tx, err := s.tombstoneTable.StartTxn(false, nil) - if err != nil { - return err - } - defer tx.Abort() - - if err := s.tombstoneTable.InsertTxn(tx, d); err != nil { - return err - } - return tx.Commit() -} - -// SessionCreate is used to create a new session. 
The -// ID will be populated on a successful return -func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error { - // Verify a Session ID is generated - if session.ID == "" { - return fmt.Errorf("Missing Session ID") - } - - switch session.Behavior { - case "": - // Default behavior is Release for backwards compatibility - session.Behavior = structs.SessionKeysRelease - case structs.SessionKeysRelease: - case structs.SessionKeysDelete: - default: - return fmt.Errorf("Invalid Session Behavior setting '%s'", session.Behavior) - } - - // Assign the create index - session.CreateIndex = index - - // Start the transaction - tx, err := s.tables.StartTxn(false) - if err != nil { - panic(fmt.Errorf("Failed to start txn: %v", err)) - } - defer tx.Abort() - - // Verify that the node exists - res, err := s.nodeTable.GetTxn(tx, "id", session.Node) - if err != nil { - return err - } - if len(res) == 0 { - return fmt.Errorf("Missing node registration") - } - - // Verify that the checks exist and are not critical - for _, checkId := range session.Checks { - res, err := s.checkTable.GetTxn(tx, "id", session.Node, checkId) - if err != nil { - return err - } - if len(res) == 0 { - return fmt.Errorf("Missing check '%s' registration", checkId) - } - chk := res[0].(*structs.HealthCheck) - if chk.Status == structs.HealthCritical { - return fmt.Errorf("Check '%s' is in %s state", checkId, chk.Status) - } - } - - // Insert the session - if err := s.sessionTable.InsertTxn(tx, session); err != nil { - return err - } - - // Insert the check mappings - sCheck := sessionCheck{Node: session.Node, Session: session.ID} - for _, checkID := range session.Checks { - sCheck.CheckID = checkID - if err := s.sessionCheckTable.InsertTxn(tx, &sCheck); err != nil { - return err - } - } - - // Trigger the update notifications - if err := s.sessionTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { s.watch[s.sessionTable].Notify() }) - return tx.Commit() -} - -// SessionRestore is used to restore a session. It should only be used when -// doing a restore, otherwise SessionCreate should be used. 
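A sketch of session creation against the rules enforced above, assuming a *StateStore named store; sessionID, the index, and the node and check names are illustrative:

sess := &structs.Session{
    ID:       sessionID,               // caller-generated; must be non-empty
    Node:     "foo",                   // must be a registered node
    Checks:   []string{"serfHealth"},  // must exist and be non-critical
    Behavior: structs.SessionKeysRelease, // also the default when left empty
}
if err := store.SessionCreate(42, sess); err != nil {
    // e.g. "Missing node registration" or a bound check in critical state
}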
-func (s *StateStore) SessionRestore(session *structs.Session) error { - // Start the transaction - tx, err := s.tables.StartTxn(false) - if err != nil { - panic(fmt.Errorf("Failed to start txn: %v", err)) - } - defer tx.Abort() - - // Insert the session - if err := s.sessionTable.InsertTxn(tx, session); err != nil { - return err - } - - // Insert the check mappings - sCheck := sessionCheck{Node: session.Node, Session: session.ID} - for _, checkID := range session.Checks { - sCheck.CheckID = checkID - if err := s.sessionCheckTable.InsertTxn(tx, &sCheck); err != nil { - return err - } - } - - // Trigger the update notifications - index := session.CreateIndex - if err := s.sessionTable.SetMaxLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { s.watch[s.sessionTable].Notify() }) - return tx.Commit() -} - -// SessionGet is used to get a session entry -func (s *StateStore) SessionGet(id string) (uint64, *structs.Session, error) { - idx, res, err := s.sessionTable.Get("id", id) - var d *structs.Session - if len(res) > 0 { - d = res[0].(*structs.Session) - } - return idx, d, err -} - -// SessionList is used to list all the open sessions -func (s *StateStore) SessionList() (uint64, []*structs.Session, error) { - idx, res, err := s.sessionTable.Get("id") - out := make([]*structs.Session, len(res)) - for i, raw := range res { - out[i] = raw.(*structs.Session) - } - return idx, out, err -} - -// NodeSessions is used to list all the open sessions for a node -func (s *StateStore) NodeSessions(node string) (uint64, []*structs.Session, error) { - idx, res, err := s.sessionTable.Get("node", node) - out := make([]*structs.Session, len(res)) - for i, raw := range res { - out[i] = raw.(*structs.Session) - } - return idx, out, err -} - -// SessionDestroy is used to destroy a session. -func (s *StateStore) SessionDestroy(index uint64, id string) error { - tx, err := s.tables.StartTxn(false) - if err != nil { - panic(fmt.Errorf("Failed to start txn: %v", err)) - } - defer tx.Abort() - - s.logger.Printf("[DEBUG] consul.state: Invalidating session %s due to session destroy", - id) - if err := s.invalidateSession(index, tx, id); err != nil { - return err - } - return tx.Commit() -} - -// invalidateNode is used to invalidate all sessions belonging to a node -// All tables should be locked in the tx. -func (s *StateStore) invalidateNode(index uint64, tx *MDBTxn, node string) error { - sessions, err := s.sessionTable.GetTxn(tx, "node", node) - if err != nil { - return err - } - for _, sess := range sessions { - session := sess.(*structs.Session).ID - s.logger.Printf("[DEBUG] consul.state: Invalidating session %s due to node '%s' invalidation", - session, node) - if err := s.invalidateSession(index, tx, session); err != nil { - return err - } - } - return nil -} - -// invalidateCheck is used to invalidate all sessions belonging to a check -// All tables should be locked in the tx. -func (s *StateStore) invalidateCheck(index uint64, tx *MDBTxn, node, check string) error { - sessionChecks, err := s.sessionCheckTable.GetTxn(tx, "id", node, check) - if err != nil { - return err - } - for _, sc := range sessionChecks { - session := sc.(*sessionCheck).Session - s.logger.Printf("[DEBUG] consul.state: Invalidating session %s due to check '%s' invalidation", - session, check) - if err := s.invalidateSession(index, tx, session); err != nil { - return err - } - } - return nil -} - -// invalidateSession is used to invalidate a session within a given txn -// All tables should be locked in the tx. 
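For orientation, a sketch of the three paths that funnel into invalidateSession below; store, the indexes, sessionID, and criticalCheck (a *structs.HealthCheck with Status set to structs.HealthCritical) are illustrative:

store.SessionDestroy(50, sessionID)  // explicit destroy
store.DeleteNode(51, "foo")          // invalidateNode: every session held by the node
store.EnsureCheck(52, criticalCheck) // critical status: invalidateCheck on bound sessions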
-func (s *StateStore) invalidateSession(index uint64, tx *MDBTxn, id string) error { - // Get the session - res, err := s.sessionTable.GetTxn(tx, "id", id) - if err != nil { - return err - } - - // Quit if this session does not exist - if len(res) == 0 { - return nil - } - session := res[0].(*structs.Session) - - // Enforce the MaxLockDelay - delay := session.LockDelay - if delay > structs.MaxLockDelay { - delay = structs.MaxLockDelay - } - - // Invalidate any held locks - if session.Behavior == structs.SessionKeysDelete { - if err := s.deleteLocks(index, tx, delay, id); err != nil { - return err - } - } else if err := s.invalidateLocks(index, tx, delay, id); err != nil { - return err - } - - // Nuke the session - if _, err := s.sessionTable.DeleteTxn(tx, "id", id); err != nil { - return err - } - - // Delete the check mappings - for _, checkID := range session.Checks { - if _, err := s.sessionCheckTable.DeleteTxn(tx, "id", - session.Node, checkID, id); err != nil { - return err - } - } - - // Trigger the update notifications - if err := s.sessionTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { s.watch[s.sessionTable].Notify() }) - return nil -} - -// invalidateLocks is used to invalidate all the locks held by a session -// within a given txn. All tables should be locked in the tx. -func (s *StateStore) invalidateLocks(index uint64, tx *MDBTxn, - lockDelay time.Duration, id string) error { - pairs, err := s.kvsTable.GetTxn(tx, "session", id) - if err != nil { - return err - } - - var expires time.Time - if lockDelay > 0 { - s.lockDelayLock.Lock() - defer s.lockDelayLock.Unlock() - expires = time.Now().Add(lockDelay) - } - - for _, pair := range pairs { - kv := pair.(*structs.DirEntry) - kv.Session = "" // Clear the lock - kv.ModifyIndex = index // Update the modified time - if err := s.kvsTable.InsertTxn(tx, kv); err != nil { - return err - } - // If there is a lock delay, prevent acquisition - // for at least lockDelay period - if lockDelay > 0 { - s.lockDelay[kv.Key] = expires - time.AfterFunc(lockDelay, func() { - s.lockDelayLock.Lock() - delete(s.lockDelay, kv.Key) - s.lockDelayLock.Unlock() - }) - } - tx.Defer(func() { s.notifyKV(kv.Key, false) }) - } - if len(pairs) > 0 { - if err := s.kvsTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - } - return nil -} - -// deleteLocks is used to delete all the locks held by a session -// within a given txn. All tables should be locked in the tx. 
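A sketch of how the lock-delay map populated above is meant to be consulted on the leader; the key name is illustrative:

if expires := store.KVSLockDelay("service/config/lock"); time.Now().Before(expires) {
    // the lock was force-released by invalidation; deny acquisition
    // until the delay window has passed
}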
-func (s *StateStore) deleteLocks(index uint64, tx *MDBTxn, - lockDelay time.Duration, id string) error { - pairs, err := s.kvsTable.GetTxn(tx, "session", id) - if err != nil { - return err - } - - var expires time.Time - if lockDelay > 0 { - s.lockDelayLock.Lock() - defer s.lockDelayLock.Unlock() - expires = time.Now().Add(lockDelay) - } - - for _, pair := range pairs { - kv := pair.(*structs.DirEntry) - if err := s.kvsDeleteWithIndexTxn(index, tx, "id", kv.Key); err != nil { - return err - } - - // If there is a lock delay, prevent acquisition - // for at least lockDelay period - if lockDelay > 0 { - s.lockDelay[kv.Key] = expires - time.AfterFunc(lockDelay, func() { - s.lockDelayLock.Lock() - delete(s.lockDelay, kv.Key) - s.lockDelayLock.Unlock() - }) - } - } - return nil -} - -// ACLSet is used to create or update an ACL entry -func (s *StateStore) ACLSet(index uint64, acl *structs.ACL) error { - // Check for an ID - if acl.ID == "" { - return fmt.Errorf("Missing ACL ID") - } - - // Start a new txn - tx, err := s.tables.StartTxn(false) - if err != nil { - return err - } - defer tx.Abort() - - // Look for the existing node - res, err := s.aclTable.GetTxn(tx, "id", acl.ID) - if err != nil { - return err - } - - switch len(res) { - case 0: - acl.CreateIndex = index - acl.ModifyIndex = index - case 1: - exist := res[0].(*structs.ACL) - acl.CreateIndex = exist.CreateIndex - acl.ModifyIndex = index - default: - panic(fmt.Errorf("Duplicate ACL definition. Internal error")) - } - - // Insert the ACL - if err := s.aclTable.InsertTxn(tx, acl); err != nil { - return err - } - - // Trigger the update notifications - if err := s.aclTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { s.watch[s.aclTable].Notify() }) - return tx.Commit() -} - -// ACLRestore is used to restore an ACL. It should only be used when -// doing a restore, otherwise ACLSet should be used. 
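A sketch of the create-versus-update index handling in ACLSet above; the ID, type, rules, and indexes are illustrative:

acl := &structs.ACL{ID: "abcd1234", Type: "client"}
store.ACLSet(5, acl) // new entry: CreateIndex = ModifyIndex = 5

acl.Rules = `key "" { policy = "read" }`
store.ACLSet(9, acl) // update: CreateIndex stays 5, ModifyIndex becomes 9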
-func (s *StateStore) ACLRestore(acl *structs.ACL) error { - // Start a new txn - tx, err := s.aclTable.StartTxn(false, nil) - if err != nil { - return err - } - defer tx.Abort() - - if err := s.aclTable.InsertTxn(tx, acl); err != nil { - return err - } - if err := s.aclTable.SetMaxLastIndexTxn(tx, acl.ModifyIndex); err != nil { - return err - } - return tx.Commit() -} - -// ACLGet is used to get an ACL by ID -func (s *StateStore) ACLGet(id string) (uint64, *structs.ACL, error) { - idx, res, err := s.aclTable.Get("id", id) - var d *structs.ACL - if len(res) > 0 { - d = res[0].(*structs.ACL) - } - return idx, d, err -} - -// ACLList is used to list all the acls -func (s *StateStore) ACLList() (uint64, []*structs.ACL, error) { - idx, res, err := s.aclTable.Get("id") - out := make([]*structs.ACL, len(res)) - for i, raw := range res { - out[i] = raw.(*structs.ACL) - } - return idx, out, err -} - -// ACLDelete is used to remove an ACL -func (s *StateStore) ACLDelete(index uint64, id string) error { - tx, err := s.tables.StartTxn(false) - if err != nil { - panic(fmt.Errorf("Failed to start txn: %v", err)) - } - defer tx.Abort() - - if n, err := s.aclTable.DeleteTxn(tx, "id", id); err != nil { - return err - } else if n > 0 { - if err := s.aclTable.SetLastIndexTxn(tx, index); err != nil { - return err - } - tx.Defer(func() { s.watch[s.aclTable].Notify() }) - } - return tx.Commit() -} - -// Snapshot is used to create a point in time snapshot -func (s *StateStore) Snapshot() (*StateSnapshot, error) { - // Begin a new txn on all tables - tx, err := s.tables.StartTxn(true) - if err != nil { - return nil, err - } - - // Determine the max index - index, err := s.tables.LastIndexTxn(tx) - if err != nil { - tx.Abort() - return nil, err - } - - // Return the snapshot - snap := &StateSnapshot{ - store: s, - tx: tx, - lastIndex: index, - } - return snap, nil -} - -// LastIndex returns the last index that affects the snapshotted data -func (s *StateSnapshot) LastIndex() uint64 { - return s.lastIndex -} - -// Nodes returns all the known nodes, the slice alternates between -// the node name and address -func (s *StateSnapshot) Nodes() structs.Nodes { - res, err := s.store.nodeTable.GetTxn(s.tx, "id") - if err != nil { - s.store.logger.Printf("[ERR] consul.state: Failed to get nodes: %v", err) - return nil - } - results := make([]structs.Node, len(res)) - for i, r := range res { - results[i] = *r.(*structs.Node) - } - return results -} - -// NodeServices is used to return all the services of a given node -func (s *StateSnapshot) NodeServices(name string) *structs.NodeServices { - _, res := s.store.parseNodeServices(s.store.tables, s.tx, name) - return res -} - -// NodeChecks is used to return all the checks of a given node -func (s *StateSnapshot) NodeChecks(node string) structs.HealthChecks { - res, err := s.store.checkTable.GetTxn(s.tx, "id", node) - _, checks := s.store.parseHealthChecks(s.lastIndex, res, err) - return checks -} - -// KVSDump is used to list all KV entries. It takes a channel and streams -// back *struct.DirEntry objects. This will block and should be invoked -// in a goroutine. -func (s *StateSnapshot) KVSDump(stream chan<- interface{}) error { - return s.store.kvsTable.StreamTxn(stream, s.tx, "id") -} - -// TombstoneDump is used to dump all tombstone entries. It takes a channel and streams -// back *struct.DirEntry objects. This will block and should be invoked -// in a goroutine. 
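A consumption sketch for the streaming dumps, assuming snap is a *StateSnapshot and that the table streamer closes the channel when the scan completes, as the KVSListKeys implementation above relies on:

stream := make(chan interface{}, 128)
go func() {
    if err := snap.KVSDump(stream); err != nil {
        // surface the error; the channel is closed by the streamer either way
    }
}()
for raw := range stream {
    ent := raw.(*structs.DirEntry)
    _ = ent // e.g. encode the entry into a Raft snapshot sink
}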
-func (s *StateSnapshot) TombstoneDump(stream chan<- interface{}) error { - return s.store.tombstoneTable.StreamTxn(stream, s.tx, "id") -} - -// SessionList is used to list all the open sessions -func (s *StateSnapshot) SessionList() ([]*structs.Session, error) { - res, err := s.store.sessionTable.GetTxn(s.tx, "id") - out := make([]*structs.Session, len(res)) - for i, raw := range res { - out[i] = raw.(*structs.Session) - } - return out, err -} - -// ACLList is used to list all of the ACLs -func (s *StateSnapshot) ACLList() ([]*structs.ACL, error) { - res, err := s.store.aclTable.GetTxn(s.tx, "id") - out := make([]*structs.ACL, len(res)) - for i, raw := range res { - out[i] = raw.(*structs.ACL) - } - return out, err -} diff --git a/consul/state_store_test.go b/consul/state_store_test.go deleted file mode 100644 index 13a0044ba57b..000000000000 --- a/consul/state_store_test.go +++ /dev/null @@ -1,3024 +0,0 @@ -package consul - -import ( - "os" - "reflect" - "sort" - "testing" - "time" - - "github.com/hashicorp/consul/consul/structs" -) - -func testStateStore() (*StateStore, error) { - return NewStateStore(nil, os.Stderr) -} - -func TestEnsureRegistration(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - - reg := &structs.RegisterRequest{ - Node: "foo", - Address: "127.0.0.1", - Service: &structs.NodeService{"api", "api", nil, "", 5000, false}, - Check: &structs.HealthCheck{ - Node: "foo", - CheckID: "api", - Name: "Can connect", - Status: structs.HealthPassing, - ServiceID: "api", - }, - Checks: structs.HealthChecks{ - &structs.HealthCheck{ - Node: "foo", - CheckID: "api-cache", - Name: "Can cache stuff", - Status: structs.HealthPassing, - ServiceID: "api", - }, - }, - } - - if err := store.EnsureRegistration(13, reg); err != nil { - t.Fatalf("err: %v", err) - } - - idx, found, addr := store.GetNode("foo") - if idx != 13 || !found || addr != "127.0.0.1" { - t.Fatalf("Bad: %v %v %v", idx, found, addr) - } - - idx, services := store.NodeServices("foo") - if idx != 13 { - t.Fatalf("bad: %v", idx) - } - - entry, ok := services.Services["api"] - if !ok { - t.Fatalf("missing api: %#v", services) - } - if len(entry.Tags) != 0 || entry.Port != 5000 { - t.Fatalf("Bad entry: %#v", entry) - } - - idx, checks := store.NodeChecks("foo") - if idx != 13 { - t.Fatalf("bad: %v", idx) - } - if len(checks) != 2 { - t.Fatalf("check: %#v", checks) - } -} - -func TestEnsureNode(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - - if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - - idx, found, addr := store.GetNode("foo") - if idx != 3 || !found || addr != "127.0.0.1" { - t.Fatalf("Bad: %v %v %v", idx, found, addr) - } - - if err := store.EnsureNode(4, structs.Node{"foo", "127.0.0.2"}); err != nil { - t.Fatalf("err: %v", err) - } - - idx, found, addr = store.GetNode("foo") - if idx != 4 || !found || addr != "127.0.0.2" { - t.Fatalf("Bad: %v %v %v", idx, found, addr) - } -} - -func TestGetNodes(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - - if err := store.EnsureNode(40, structs.Node{"foo", "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - - if err := store.EnsureNode(41, structs.Node{"bar", "127.0.0.2"}); err != nil { - t.Fatalf("err: %v", err) - } - - idx, nodes := store.Nodes() - if idx != 41 { - t.Fatalf("idx: %v", idx) - } 
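// The "id" index iterates in key order, so "bar" sorts ahead of "foo";
// the assertions below depend on that ordering.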
-	if len(nodes) != 2 {
-		t.Fatalf("Bad: %v", nodes)
-	}
-	if nodes[1].Node != "foo" && nodes[0].Node != "bar" {
-		t.Fatalf("Bad: %v", nodes)
-	}
-}
-
-func TestGetNodes_Watch_StopWatch(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	notify1 := make(chan struct{}, 1)
-	notify2 := make(chan struct{}, 1)
-
-	store.Watch(store.QueryTables("Nodes"), notify1)
-	store.Watch(store.QueryTables("Nodes"), notify2)
-	store.StopWatch(store.QueryTables("Nodes"), notify2)
-
-	if err := store.EnsureNode(40, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	select {
-	case <-notify1:
-	default:
-		t.Fatalf("should be notified")
-	}
-
-	select {
-	case <-notify2:
-		t.Fatalf("should not be notified")
-	default:
-	}
-}
-
-func BenchmarkGetNodes(b *testing.B) {
-	store, err := testStateStore()
-	if err != nil {
-		b.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(100, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		b.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureNode(101, structs.Node{"bar", "127.0.0.2"}); err != nil {
-		b.Fatalf("err: %v", err)
-	}
-
-	for i := 0; i < b.N; i++ {
-		store.Nodes()
-	}
-}
-
-func TestEnsureService(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(10, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(11, "foo", &structs.NodeService{"api", "api", nil, "", 5000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(12, "foo", &structs.NodeService{"api", "api", nil, "", 5001, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(13, "foo", &structs.NodeService{"db", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, services := store.NodeServices("foo")
-	if idx != 13 {
-		t.Fatalf("bad: %v", idx)
-	}
-
-	entry, ok := services.Services["api"]
-	if !ok {
-		t.Fatalf("missing api: %#v", services)
-	}
-	if len(entry.Tags) != 0 || entry.Port != 5001 {
-		t.Fatalf("Bad entry: %#v", entry)
-	}
-
-	entry, ok = services.Services["db"]
-	if !ok {
-		t.Fatalf("missing db: %#v", services)
-	}
-	if !strContains(entry.Tags, "master") || entry.Port != 8000 {
-		t.Fatalf("Bad entry: %#v", entry)
-	}
-}
-
-func TestEnsureService_DuplicateNode(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(10, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(11, "foo", &structs.NodeService{"api1", "api", nil, "", 5000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(12, "foo", &structs.NodeService{"api2", "api", nil, "", 5001, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(13, "foo", &structs.NodeService{"api3", "api", nil, "", 5002, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, services := store.NodeServices("foo")
-	if idx != 13 {
-		t.Fatalf("bad: %v", idx)
-	}
-
-	entry, ok := services.Services["api1"]
-	if !ok {
-		t.Fatalf("missing api: %#v", services)
-	}
-	if len(entry.Tags) != 0 || entry.Port != 5000 {
-		t.Fatalf("Bad entry: %#v", entry)
-	}
-
-	entry, ok = services.Services["api2"]
-	if !ok {
-		t.Fatalf("missing api: %#v", services)
-	}
-	if len(entry.Tags) != 0 || entry.Port != 5001 {
-		t.Fatalf("Bad entry: %#v", entry)
-	}
-
-	entry, ok = services.Services["api3"]
-	if !ok {
-		t.Fatalf("missing api: %#v", services)
-	}
-	if len(entry.Tags) != 0 || entry.Port != 5002 {
-		t.Fatalf("Bad entry: %#v", entry)
-	}
-}
-
-func TestDeleteNodeService(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(11, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(12, "foo", &structs.NodeService{"api", "api", nil, "", 5000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	check := &structs.HealthCheck{
-		Node:      "foo",
-		CheckID:   "api",
-		Name:      "Can connect",
-		Status:    structs.HealthPassing,
-		ServiceID: "api",
-	}
-	if err := store.EnsureCheck(13, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.DeleteNodeService(14, "foo", "api"); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, services := store.NodeServices("foo")
-	if idx != 14 {
-		t.Fatalf("bad: %v", idx)
-	}
-	_, ok := services.Services["api"]
-	if ok {
-		t.Fatalf("has api: %#v", services)
-	}
-
-	idx, checks := store.NodeChecks("foo")
-	if idx != 14 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(checks) != 0 {
-		t.Fatalf("has check: %#v", checks)
-	}
-}
-
-func TestDeleteNodeService_One(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(11, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(12, "foo", &structs.NodeService{"api", "api", nil, "", 5000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(13, "foo", &structs.NodeService{"api2", "api", nil, "", 5001, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.DeleteNodeService(14, "foo", "api"); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, services := store.NodeServices("foo")
-	if idx != 14 {
-		t.Fatalf("bad: %v", idx)
-	}
-	_, ok := services.Services["api"]
-	if ok {
-		t.Fatalf("has api: %#v", services)
-	}
-	_, ok = services.Services["api2"]
-	if !ok {
-		t.Fatalf("does not have api2: %#v", services)
-	}
-}
-
-func TestDeleteNode(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(20, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(21, "foo", &structs.NodeService{"api", "api", nil, "", 5000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	check := &structs.HealthCheck{
-		Node:      "foo",
-		CheckID:   "db",
-		Name:      "Can connect",
-		Status:    structs.HealthPassing,
-		ServiceID: "api",
-	}
-	if err := store.EnsureCheck(22, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.DeleteNode(23, "foo"); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, services := store.NodeServices("foo")
-	if idx != 23 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if services != nil {
-		t.Fatalf("has services: %#v", services)
-	}
-
-	idx, checks := store.NodeChecks("foo")
-	if idx != 23 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(checks) > 0 {
-		t.Fatalf("has checks: %v", checks)
-	}
-
-	idx, found, _ := store.GetNode("foo")
-	if idx != 23 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if found {
-		t.Fatalf("found node")
-	}
-}
-
-func TestGetServices(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(30, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureNode(31, structs.Node{"bar", "127.0.0.2"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(32, "foo", &structs.NodeService{"api", "api", nil, "", 5000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(33, "foo", &structs.NodeService{"db", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(34, "bar", &structs.NodeService{"db", "db", []string{"slave"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, services := store.Services()
-	if idx != 34 {
-		t.Fatalf("bad: %v", idx)
-	}
-
-	tags, ok := services["api"]
-	if !ok {
-		t.Fatalf("missing api: %#v", services)
-	}
-	if len(tags) != 0 {
-		t.Fatalf("Bad entry: %#v", tags)
-	}
-
-	tags, ok = services["db"]
-	sort.Strings(tags)
-	if !ok {
-		t.Fatalf("missing db: %#v", services)
-	}
-	if len(tags) != 2 || tags[0] != "master" || tags[1] != "slave" {
-		t.Fatalf("Bad entry: %#v", tags)
-	}
-}
-
-func TestServiceNodes(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(10, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureNode(11, structs.Node{"bar", "127.0.0.2"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(12, "foo", &structs.NodeService{"api", "api", nil, "", 5000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(13, "bar", &structs.NodeService{"api", "api", nil, "", 5000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(14, "foo", &structs.NodeService{"db", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(15, "bar", &structs.NodeService{"db", "db", []string{"slave"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(16, "bar", &structs.NodeService{"db2", "db", []string{"slave"}, "", 8001, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, nodes := store.ServiceNodes("db")
-	if idx != 16 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(nodes) != 3 {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].Node != "foo" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].Address != "127.0.0.1" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].ServiceID != "db" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if !strContains(nodes[0].ServiceTags, "master") {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].ServicePort != 8000 {
-		t.Fatalf("bad: %v", nodes)
-	}
-
-	if nodes[1].Node != "bar" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[1].Address != "127.0.0.2" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[1].ServiceID != "db" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if !strContains(nodes[1].ServiceTags, "slave") {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[1].ServicePort != 8000 {
-		t.Fatalf("bad: %v", nodes)
-	}
-
-	if nodes[2].Node != "bar" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[2].Address != "127.0.0.2" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[2].ServiceID != "db2" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if !strContains(nodes[2].ServiceTags, "slave") {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[2].ServicePort != 8001 {
-		t.Fatalf("bad: %v", nodes)
-	}
-}
-
-func TestServiceTagNodes(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(15, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureNode(16, structs.Node{"bar", "127.0.0.2"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(17, "foo", &structs.NodeService{"db", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(18, "foo", &structs.NodeService{"db2", "db", []string{"slave"}, "", 8001, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(19, "bar", &structs.NodeService{"db", "db", []string{"slave"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, nodes := store.ServiceTagNodes("db", "master")
-	if idx != 19 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(nodes) != 1 {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].Node != "foo" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].Address != "127.0.0.1" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if !strContains(nodes[0].ServiceTags, "master") {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].ServicePort != 8000 {
-		t.Fatalf("bad: %v", nodes)
-	}
-}
-
-func TestServiceTagNodes_MultipleTags(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(15, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureNode(16, structs.Node{"bar", "127.0.0.2"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(17, "foo", &structs.NodeService{"db", "db", []string{"master", "v2"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(18, "foo", &structs.NodeService{"db2", "db", []string{"slave", "v2", "dev"}, "", 8001, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(19, "bar", &structs.NodeService{"db", "db", []string{"slave", "v2"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, nodes := store.ServiceTagNodes("db", "master")
-	if idx != 19 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(nodes) != 1 {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].Node != "foo" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].Address != "127.0.0.1" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if !strContains(nodes[0].ServiceTags, "master") {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].ServicePort != 8000 {
-		t.Fatalf("bad: %v", nodes)
-	}
-
-	idx, nodes = store.ServiceTagNodes("db", "v2")
-	if idx != 19 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(nodes) != 3 {
-		t.Fatalf("bad: %v", nodes)
-	}
-
-	idx, nodes = store.ServiceTagNodes("db", "dev")
-	if idx != 19 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(nodes) != 1 {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].Node != "foo" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].Address != "127.0.0.1" {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if !strContains(nodes[0].ServiceTags, "dev") {
-		t.Fatalf("bad: %v", nodes)
-	}
-	if nodes[0].ServicePort != 8001 {
-		t.Fatalf("bad: %v", nodes)
-	}
-}
-
-func TestStoreSnapshot(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(8, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureNode(9, structs.Node{"bar", "127.0.0.2"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(10, "foo", &structs.NodeService{"db", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(11, "foo", &structs.NodeService{"db2", "db", []string{"slave"}, "", 8001, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.EnsureService(12, "bar", &structs.NodeService{"db", "db", []string{"slave"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	check := &structs.HealthCheck{
-		Node:      "foo",
-		CheckID:   "db",
-		Name:      "Can connect",
-		Status:    structs.HealthPassing,
-		ServiceID: "db",
-	}
-	if err := store.EnsureCheck(13, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Add some KVS entries
-	d := &structs.DirEntry{Key: "/web/a", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(14, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/web/b", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(15, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/web/c", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(16, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	// Create a tombstone
-	// TODO: Change to /web/c causes failure?
-	if err := store.KVSDelete(17, "/web/a"); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Add some sessions
-	session := &structs.Session{ID: generateUUID(), Node: "foo"}
-	if err := store.SessionCreate(18, session); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	session = &structs.Session{ID: generateUUID(), Node: "bar"}
-	if err := store.SessionCreate(19, session); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d.Session = session.ID
-	if ok, err := store.KVSLock(20, d); err != nil || !ok {
-		t.Fatalf("err: %v", err)
-	}
-	session = &structs.Session{ID: generateUUID(), Node: "bar", TTL: "60s"}
-	if err := store.SessionCreate(21, session); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	a1 := &structs.ACL{
-		ID:   generateUUID(),
-		Name: "User token",
-		Type: structs.ACLTypeClient,
-	}
-	if err := store.ACLSet(21, a1); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	a2 := &structs.ACL{
-		ID:   generateUUID(),
-		Name: "User token",
-		Type: structs.ACLTypeClient,
-	}
-	if err := store.ACLSet(22, a2); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Take a snapshot
-	snap, err := store.Snapshot()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer snap.Close()
-
-	// Check the last nodes
-	if idx := snap.LastIndex(); idx != 22 {
-		t.Fatalf("bad: %v", idx)
-	}
-
-	// Check snapshot has old values
-	nodes := snap.Nodes()
-	if len(nodes) != 2 {
-		t.Fatalf("bad: %v", nodes)
-	}
-
-	// Ensure we get the service entries
-	services := snap.NodeServices("foo")
-	if !strContains(services.Services["db"].Tags, "master") {
-		t.Fatalf("bad: %v", services)
-	}
-	if !strContains(services.Services["db2"].Tags, "slave") {
-		t.Fatalf("bad: %v", services)
-	}
-
-	services = snap.NodeServices("bar")
-	if !strContains(services.Services["db"].Tags, "slave") {
-		t.Fatalf("bad: %v", services)
-	}
-
-	// Ensure we get the checks
-	checks := snap.NodeChecks("foo")
-	if len(checks) != 1 {
-		t.Fatalf("bad: %v", checks)
-	}
-	if !reflect.DeepEqual(checks[0], check) {
-		t.Fatalf("bad: %v", checks[0])
-	}
-
-	// Check we have the entries
-	streamCh := make(chan interface{}, 64)
-	doneCh := make(chan struct{})
-	var ents []*structs.DirEntry
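-	// The dump methods below stream snapshot entries over streamCh; a
-	// nil receive marks the end of the stream, so the consumer
-	// goroutine drains the channel until it sees nil and then signals
-	// completion via doneCh.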
-	go func() {
-		for {
-			obj := <-streamCh
-			if obj == nil {
-				close(doneCh)
-				return
-			}
-			ents = append(ents, obj.(*structs.DirEntry))
-		}
-	}()
-	if err := snap.KVSDump(streamCh); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	<-doneCh
-	if len(ents) != 2 {
-		t.Fatalf("missing KVS entries! %#v", ents)
-	}
-
-	// Check we have the tombstone entries
-	streamCh = make(chan interface{}, 64)
-	doneCh = make(chan struct{})
-	ents = nil
-	go func() {
-		for {
-			obj := <-streamCh
-			if obj == nil {
-				close(doneCh)
-				return
-			}
-			ents = append(ents, obj.(*structs.DirEntry))
-		}
-	}()
-	if err := snap.TombstoneDump(streamCh); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	<-doneCh
-	if len(ents) != 1 {
-		t.Fatalf("missing tombstone entries!")
-	}
-
-	// Check there are 3 sessions
-	sessions, err := snap.SessionList()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if len(sessions) != 3 {
-		t.Fatalf("missing sessions")
-	}
-
-	ttls := 0
-	for _, session := range sessions {
-		if session.TTL != "" {
-			ttls++
-		}
-	}
-	if ttls != 1 {
-		t.Fatalf("Wrong number of sessions with TTL")
-	}
-
-	// Check for an acl
-	acls, err := snap.ACLList()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if len(acls) != 2 {
-		t.Fatalf("missing acls")
-	}
-
-	// Make some changes!
-	if err := store.EnsureService(23, "foo", &structs.NodeService{"db", "db", []string{"slave"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.EnsureService(24, "bar", &structs.NodeService{"db", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.EnsureNode(25, structs.Node{"baz", "127.0.0.3"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	checkAfter := &structs.HealthCheck{
-		Node:      "foo",
-		CheckID:   "db",
-		Name:      "Can connect",
-		Status:    structs.HealthCritical,
-		ServiceID: "db",
-	}
-	if err := store.EnsureCheck(27, checkAfter); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.KVSDelete(28, "/web/b"); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Nuke an ACL
-	if err := store.ACLDelete(29, a1.ID); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Check snapshot has old values
-	nodes = snap.Nodes()
-	if len(nodes) != 2 {
-		t.Fatalf("bad: %v", nodes)
-	}
-
-	// Ensure old service entries
-	services = snap.NodeServices("foo")
-	if !strContains(services.Services["db"].Tags, "master") {
-		t.Fatalf("bad: %v", services)
-	}
-	if !strContains(services.Services["db2"].Tags, "slave") {
-		t.Fatalf("bad: %v", services)
-	}
-
-	services = snap.NodeServices("bar")
-	if !strContains(services.Services["db"].Tags, "slave") {
-		t.Fatalf("bad: %v", services)
-	}
-
-	checks = snap.NodeChecks("foo")
-	if len(checks) != 1 {
-		t.Fatalf("bad: %v", checks)
-	}
-	if !reflect.DeepEqual(checks[0], check) {
-		t.Fatalf("bad: %v", checks[0])
-	}
-
-	// Check we have the entries
-	streamCh = make(chan interface{}, 64)
-	doneCh = make(chan struct{})
-	ents = nil
-	go func() {
-		for {
-			obj := <-streamCh
-			if obj == nil {
-				close(doneCh)
-				return
-			}
-			ents = append(ents, obj.(*structs.DirEntry))
-		}
-	}()
-	if err := snap.KVSDump(streamCh); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	<-doneCh
-	if len(ents) != 2 {
-		t.Fatalf("missing KVS entries!")
-	}
-
-	// Check we have the tombstone entries
-	streamCh = make(chan interface{}, 64)
-	doneCh = make(chan struct{})
-	ents = nil
-	go func() {
-		for {
-			obj := <-streamCh
-			if obj == nil {
-				close(doneCh)
-				return
-			}
-			ents = append(ents, obj.(*structs.DirEntry))
-		}
-	}()
-	if err := snap.TombstoneDump(streamCh); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	<-doneCh
-	if len(ents) != 1 {
-		t.Fatalf("missing tombstone entries!")
-	}
-
-	// Check there are 3 sessions
-	sessions, err = snap.SessionList()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if len(sessions) != 3 {
-		t.Fatalf("missing sessions")
-	}
-
-	// Check for an acl
-	acls, err = snap.ACLList()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if len(acls) != 2 {
-		t.Fatalf("missing acls")
-	}
-}
-
-func TestEnsureCheck(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.EnsureService(2, "foo", &structs.NodeService{"db1", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	check := &structs.HealthCheck{
-		Node:      "foo",
-		CheckID:   "db",
-		Name:      "Can connect",
-		Status:    structs.HealthPassing,
-		ServiceID: "db1",
-	}
-	if err := store.EnsureCheck(3, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	check2 := &structs.HealthCheck{
-		Node:    "foo",
-		CheckID: "memory",
-		Name:    "memory utilization",
-		Status:  structs.HealthWarning,
-	}
-	if err := store.EnsureCheck(4, check2); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, checks := store.NodeChecks("foo")
-	if idx != 4 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(checks) != 2 {
-		t.Fatalf("bad: %v", checks)
-	}
-	if !reflect.DeepEqual(checks[0], check) {
-		t.Fatalf("bad: %v", checks[0])
-	}
-	if !reflect.DeepEqual(checks[1], check2) {
-		t.Fatalf("bad: %v", checks[1])
-	}
-
-	idx, checks = store.ServiceChecks("db")
-	if idx != 4 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(checks) != 1 {
-		t.Fatalf("bad: %v", checks)
-	}
-	if !reflect.DeepEqual(checks[0], check) {
-		t.Fatalf("bad: %v", checks[0])
-	}
-
-	idx, checks = store.ChecksInState(structs.HealthPassing)
-	if idx != 4 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(checks) != 1 {
-		t.Fatalf("bad: %v", checks)
-	}
-	if !reflect.DeepEqual(checks[0], check) {
-		t.Fatalf("bad: %v", checks[0])
-	}
-
-	idx, checks = store.ChecksInState(structs.HealthWarning)
-	if idx != 4 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(checks) != 1 {
-		t.Fatalf("bad: %v", checks)
-	}
-	if !reflect.DeepEqual(checks[0], check2) {
-		t.Fatalf("bad: %v", checks[0])
-	}
-
-	idx, checks = store.ChecksInState(structs.HealthAny)
-	if idx != 4 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(checks) != 2 {
-		t.Fatalf("bad: %v", checks)
-	}
-	if !reflect.DeepEqual(checks[0], check) {
-		t.Fatalf("bad: %v", checks[0])
-	}
-	if !reflect.DeepEqual(checks[1], check2) {
-		t.Fatalf("bad: %v", checks[1])
-	}
-}
-
-func TestDeleteNodeCheck(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.EnsureService(2, "foo", &structs.NodeService{"db1", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	check := &structs.HealthCheck{
-		Node:      "foo",
-		CheckID:   "db",
-		Name:      "Can connect",
-		Status:    structs.HealthPassing,
-		ServiceID: "db1",
-	}
-	if err := store.EnsureCheck(3, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	check2 := &structs.HealthCheck{
-		Node:    "foo",
-		CheckID: "memory",
-		Name:    "memory utilization",
-		Status:  structs.HealthWarning,
-	}
-	if err := store.EnsureCheck(4, check2); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.DeleteNodeCheck(5, "foo", "db"); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, checks := store.NodeChecks("foo")
-	if idx != 5 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(checks) != 1 {
-		t.Fatalf("bad: %v", checks)
-	}
-	if !reflect.DeepEqual(checks[0], check2) {
-		t.Fatalf("bad: %v", checks[0])
-	}
-}
-
-func TestCheckServiceNodes(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.EnsureService(2, "foo", &structs.NodeService{"db1", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	check := &structs.HealthCheck{
-		Node:      "foo",
-		CheckID:   "db",
-		Name:      "Can connect",
-		Status:    structs.HealthPassing,
-		ServiceID: "db1",
-	}
-	if err := store.EnsureCheck(3, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	check = &structs.HealthCheck{
-		Node:    "foo",
-		CheckID: SerfCheckID,
-		Name:    SerfCheckName,
-		Status:  structs.HealthPassing,
-	}
-	if err := store.EnsureCheck(4, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, nodes := store.CheckServiceNodes("db")
-	if idx != 4 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(nodes) != 1 {
-		t.Fatalf("Bad: %v", nodes)
-	}
-
-	if nodes[0].Node.Node != "foo" {
-		t.Fatalf("Bad: %v", nodes[0])
-	}
-	if nodes[0].Service.ID != "db1" {
-		t.Fatalf("Bad: %v", nodes[0])
-	}
-	if len(nodes[0].Checks) != 2 {
-		t.Fatalf("Bad: %v", nodes[0])
-	}
-	if nodes[0].Checks[0].CheckID != "db" {
-		t.Fatalf("Bad: %v", nodes[0])
-	}
-	if nodes[0].Checks[1].CheckID != SerfCheckID {
-		t.Fatalf("Bad: %v", nodes[0])
-	}
-
-	idx, nodes = store.CheckServiceTagNodes("db", "master")
-	if idx != 4 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(nodes) != 1 {
-		t.Fatalf("Bad: %v", nodes)
-	}
-
-	if nodes[0].Node.Node != "foo" {
-		t.Fatalf("Bad: %v", nodes[0])
-	}
-	if nodes[0].Service.ID != "db1" {
-		t.Fatalf("Bad: %v", nodes[0])
-	}
-	if len(nodes[0].Checks) != 2 {
-		t.Fatalf("Bad: %v", nodes[0])
-	}
-	if nodes[0].Checks[0].CheckID != "db" {
-		t.Fatalf("Bad: %v", nodes[0])
-	}
-	if nodes[0].Checks[1].CheckID != SerfCheckID {
-		t.Fatalf("Bad: %v", nodes[0])
-	}
-}
-func BenchmarkCheckServiceNodes(t *testing.B) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.EnsureService(2, "foo", &structs.NodeService{"db1", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	check := &structs.HealthCheck{
-		Node:      "foo",
-		CheckID:   "db",
-		Name:      "Can connect",
-		Status:    structs.HealthPassing,
-		ServiceID: "db1",
-	}
-	if err := store.EnsureCheck(3, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	check = &structs.HealthCheck{
-		Node:    "foo",
-		CheckID: SerfCheckID,
-		Name:    SerfCheckName,
-		Status:  structs.HealthPassing,
-	}
-	if err := store.EnsureCheck(4, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	for i := 0; i < t.N; i++ {
-		store.CheckServiceNodes("db")
-	}
-}
-
-func TestSS_Register_Deregister_Query(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	srv := &structs.NodeService{
-		"statsite-box-stats",
-		"statsite-box-stats",
-		nil,
-		"",
-		0,
-		false}
-	if err := store.EnsureService(2, "foo", srv); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	srv = &structs.NodeService{
-		"statsite-share-stats",
-		"statsite-share-stats",
-		nil,
-		"",
-		0,
-		false}
-	if err := store.EnsureService(3, "foo", srv); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if err := store.DeleteNode(4, "foo"); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, nodes := store.CheckServiceNodes("statsite-share-stats")
-	if idx != 4 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(nodes) != 0 {
-		t.Fatalf("Bad: %v", nodes)
-	}
-}
-
-func TestNodeInfo(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.EnsureService(2, "foo", &structs.NodeService{"db1", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	check := &structs.HealthCheck{
-		Node:      "foo",
-		CheckID:   "db",
-		Name:      "Can connect",
-		Status:    structs.HealthPassing,
-		ServiceID: "db1",
-	}
-	if err := store.EnsureCheck(3, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	check = &structs.HealthCheck{
-		Node:    "foo",
-		CheckID: SerfCheckID,
-		Name:    SerfCheckName,
-		Status:  structs.HealthPassing,
-	}
-	if err := store.EnsureCheck(4, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, dump := store.NodeInfo("foo")
-	if idx != 4 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(dump) != 1 {
-		t.Fatalf("Bad: %v", dump)
-	}
-
-	info := dump[0]
-	if info.Node != "foo" {
-		t.Fatalf("Bad: %v", info)
-	}
-	if info.Services[0].ID != "db1" {
-		t.Fatalf("Bad: %v", info)
-	}
-	if len(info.Checks) != 2 {
-		t.Fatalf("Bad: %v", info)
-	}
-	if info.Checks[0].CheckID != "db" {
-		t.Fatalf("Bad: %v", info)
-	}
-	if info.Checks[1].CheckID != SerfCheckID {
-		t.Fatalf("Bad: %v", info)
-	}
-}
-
-func TestNodeDump(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.EnsureService(2, "foo", &structs.NodeService{"db1", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.EnsureNode(3, structs.Node{"baz", "127.0.0.2"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.EnsureService(4, "baz", &structs.NodeService{"db1", "db", []string{"master"}, "", 8000, false}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, dump := store.NodeDump()
-	if idx != 4 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(dump) != 2 {
-		t.Fatalf("Bad: %v", dump)
-	}
-
-	info := dump[0]
-	if info.Node != "baz" {
-		t.Fatalf("Bad: %v", info)
-	}
-	if info.Services[0].ID != "db1" {
-		t.Fatalf("Bad: %v", info)
-	}
-	info = dump[1]
-	if info.Node != "foo" {
-		t.Fatalf("Bad: %v", info)
-	}
-	if info.Services[0].ID != "db1" {
-		t.Fatalf("Bad: %v", info)
-	}
-}
-
-func TestKVSSet_Watch(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	notify1 := make(chan struct{}, 1)
-	notify2 := make(chan struct{}, 1)
-	notify3 := make(chan struct{}, 1)
-
-	store.WatchKV("", notify1)
-	store.WatchKV("foo/", notify2)
-	store.WatchKV("foo/bar", notify3)
-
-	// Create the entry
-	d := &structs.DirEntry{Key: "foo/baz", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1000, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Check that we've fired notify1 and notify2
-	select {
-	case <-notify1:
-	default:
-		t.Fatalf("should notify root")
-	}
-	select {
-	case <-notify2:
-	default:
-		t.Fatalf("should notify foo/")
-	}
-	select {
-	case <-notify3:
-		t.Fatalf("should not notify foo/bar")
-	default:
-	}
-}
-
-func TestKVSSet_Watch_Stop(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	notify1 := make(chan struct{}, 1)
-
-	store.WatchKV("", notify1)
-	store.StopWatchKV("", notify1)
-
-	// Create the entry
-	d := &structs.DirEntry{Key: "foo/baz", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1000, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Check that we've not fired notify1
-	select {
-	case <-notify1:
-		t.Fatalf("should not notify")
-	default:
-	}
-}
-
-func TestKVSSet_Get(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	// Should not exist
-	idx, d, err := store.KVSGet("/foo")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 0 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if d != nil {
-		t.Fatalf("bad: %v", d)
-	}
-
-	// Create the entry
-	d = &structs.DirEntry{Key: "/foo", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1000, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Should exist
-	idx, d, err = store.KVSGet("/foo")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1000 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if d.CreateIndex != 1000 {
-		t.Fatalf("bad: %v", d)
-	}
-	if d.ModifyIndex != 1000 {
-		t.Fatalf("bad: %v", d)
-	}
-	if d.Key != "/foo" {
-		t.Fatalf("bad: %v", d)
-	}
-	if d.Flags != 42 {
-		t.Fatalf("bad: %v", d)
-	}
-	if string(d.Value) != "test" {
-		t.Fatalf("bad: %v", d)
-	}
-
-	// Update the entry
-	d = &structs.DirEntry{Key: "/foo", Flags: 43, Value: []byte("zip")}
-	if err := store.KVSSet(1010, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Should update
-	idx, d, err = store.KVSGet("/foo")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1010 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if d.CreateIndex != 1000 {
-		t.Fatalf("bad: %v", d)
-	}
-	if d.ModifyIndex != 1010 {
-		t.Fatalf("bad: %v", d)
-	}
-	if d.Key != "/foo" {
-		t.Fatalf("bad: %v", d)
-	}
-	if d.Flags != 43 {
-		t.Fatalf("bad: %v", d)
-	}
-	if string(d.Value) != "zip" {
-		t.Fatalf("bad: %v", d)
-	}
-}
-
-func TestKVSDelete(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	ttl := 10 * time.Millisecond
-	gran := 5 * time.Millisecond
-	gc, err := NewTombstoneGC(ttl, gran)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	gc.SetEnabled(true)
-	store.gc = gc
-
-	// Create the entry
-	d := &structs.DirEntry{Key: "/foo", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1000, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	notify1 := make(chan struct{}, 1)
-	store.WatchKV("/", notify1)
-
-	// Delete the entry
-	if err := store.KVSDelete(1020, "/foo"); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Check that we've fired notify1
-	select {
-	case <-notify1:
-	default:
-		t.Fatalf("should notify /")
-	}
-
-	// Should not exist
-	idx, d, err := store.KVSGet("/foo")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1020 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if d != nil {
-		t.Fatalf("bad: %v", d)
-	}
-
-	// Check tombstone exists
-	_, res, err := store.tombstoneTable.Get("id", "/foo")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if res == nil || res[0].(*structs.DirEntry).ModifyIndex != 1020 {
-		t.Fatalf("bad: %#v", res)
-	}
-
-	// Check that we get a delete
-	select {
-	case idx := <-gc.ExpireCh():
-		if idx != 1020 {
-			t.Fatalf("bad %d", idx)
-		}
-	case <-time.After(20 * time.Millisecond):
-		t.Fatalf("should expire")
-	}
-}
-
-func TestKVSDeleteCheckAndSet(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	// CAS should fail, no entry
-	ok, err := store.KVSDeleteCheckAndSet(1000, "/foo", 100)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if ok {
-		t.Fatalf("unexpected commit")
-	}
-
-	// CAS should work, no entry
-	ok, err = store.KVSDeleteCheckAndSet(1000, "/foo", 0)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if !ok {
-		t.Fatalf("unexpected failure")
-	}
-
-	// Make an entry
-	d := &structs.DirEntry{Key: "/foo"}
-	if err := store.KVSSet(1000, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Constrain on a wrong modify time
-	ok, err = store.KVSDeleteCheckAndSet(1001, "/foo", 42)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if ok {
-		t.Fatalf("unexpected commit")
-	}
-
-	// Constrain on a correct modify time
-	ok, err = store.KVSDeleteCheckAndSet(1002, "/foo", 1000)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if !ok {
-		t.Fatalf("expected commit")
-	}
-}
-
-func TestKVSCheckAndSet(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	// CAS should fail, no entry
-	d := &structs.DirEntry{
-		ModifyIndex: 100,
-		Key:         "/foo",
-		Flags:       42,
-		Value:       []byte("test"),
-	}
-	ok, err := store.KVSCheckAndSet(1000, d)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if ok {
-		t.Fatalf("unexpected commit")
-	}
-
-	// Constrain on not-exist, should work
-	d.ModifyIndex = 0
-	ok, err = store.KVSCheckAndSet(1001, d)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if !ok {
-		t.Fatalf("expected commit")
-	}
-
-	// Constrain on not-exist, should fail
-	d.ModifyIndex = 0
-	ok, err = store.KVSCheckAndSet(1002, d)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if ok {
-		t.Fatalf("unexpected commit")
-	}
-
-	// Constrain on a wrong modify time
-	d.ModifyIndex = 1000
-	ok, err = store.KVSCheckAndSet(1003, d)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if ok {
-		t.Fatalf("unexpected commit")
-	}
-
-	// Constrain on a correct modify time
-	d.ModifyIndex = 1001
-	ok, err = store.KVSCheckAndSet(1004, d)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if !ok {
-		t.Fatalf("expected commit")
-	}
-}
-
-func TestKVS_List(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	// Should not exist
-	_, idx, ents, err := store.KVSList("/web")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 0 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(ents) != 0 {
-		t.Fatalf("bad: %v", ents)
-	}
-
-	// Create the entries
-	d := &structs.DirEntry{Key: "/web/a", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1000, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/web/b", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1001, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/web/sub/c", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1002, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Should list
-	_, idx, ents, err = store.KVSList("/web")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1002 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(ents) != 3 {
-		t.Fatalf("bad: %v", ents)
-	}
-
-	if ents[0].Key != "/web/a" {
-		t.Fatalf("bad: %v", ents[0])
-	}
-	if ents[1].Key != "/web/b" {
-		t.Fatalf("bad: %v", ents[1])
-	}
-	if ents[2].Key != "/web/sub/c" {
-		t.Fatalf("bad: %v", ents[2])
-	}
-}
-
-func TestKVSList_TombstoneIndex(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	// Create the entries
-	d := &structs.DirEntry{Key: "/web/a", Value: []byte("test")}
-	if err := store.KVSSet(1000, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/web/b", Value: []byte("test")}
-	if err := store.KVSSet(1001, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/web/c", Value: []byte("test")}
-	if err := store.KVSSet(1002, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Nuke the last node
-	err = store.KVSDeleteTree(1003, "/web/c")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Add another node
-	d = &structs.DirEntry{Key: "/other", Value: []byte("test")}
-	if err := store.KVSSet(1004, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// List should properly reflect tombstoned value
-	tombIdx, idx, ents, err := store.KVSList("/web")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1004 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if tombIdx != 1003 {
-		t.Fatalf("bad: %v", tombIdx)
-	}
-	if len(ents) != 2 {
-		t.Fatalf("bad: %v", ents)
-	}
-}
-
-func TestKVS_ListKeys(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	// Should not exist
-	idx, keys, err := store.KVSListKeys("", "/")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(keys) != 0 {
-		t.Fatalf("bad: %v", keys)
-	}
-
-	// Create the entries
-	d := &structs.DirEntry{Key: "/web/a", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1000, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/web/b", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1001, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/web/sub/c", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1002, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Should list
-	idx, keys, err = store.KVSListKeys("", "/")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1002 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(keys) != 1 {
-		t.Fatalf("bad: %v", keys)
-	}
-	if keys[0] != "/" {
-		t.Fatalf("bad: %v", keys)
-	}
-
-	// Should list just web
-	idx, keys, err = store.KVSListKeys("/", "/")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1002 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(keys) != 1 {
-		t.Fatalf("bad: %v", keys)
-	}
-	if keys[0] != "/web/" {
-		t.Fatalf("bad: %v", keys)
-	}
-
-	// Should list a, b, sub/
-	idx, keys, err = store.KVSListKeys("/web/", "/")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1002 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(keys) != 3 {
-		t.Fatalf("bad: %v", keys)
-	}
-	if keys[0] != "/web/a" {
-		t.Fatalf("bad: %v", keys)
-	}
-	if keys[1] != "/web/b" {
-		t.Fatalf("bad: %v", keys)
-	}
-	if keys[2] != "/web/sub/" {
-		t.Fatalf("bad: %v", keys)
-	}
-
-	// Should list c
-	idx, keys, err = store.KVSListKeys("/web/sub/", "/")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1002 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(keys) != 1 {
-		t.Fatalf("bad: %v", keys)
-	}
-	if keys[0] != "/web/sub/c" {
-		t.Fatalf("bad: %v", keys)
-	}
-
-	// Should list all
-	idx, keys, err = store.KVSListKeys("/web/", "")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1002 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(keys) != 3 {
-		t.Fatalf("bad: %v", keys)
-	}
-	if keys[2] != "/web/sub/c" {
-		t.Fatalf("bad: %v", keys)
-	}
-}
-
-func TestKVS_ListKeys_Index(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	// Create the entries
-	d := &structs.DirEntry{Key: "/foo/a", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1000, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/bar/b", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1001, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/baz/c", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1002, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/other/d", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1003, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, keys, err := store.KVSListKeys("/foo", "")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1000 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(keys) != 1 {
-		t.Fatalf("bad: %v", keys)
-	}
-
-	idx, keys, err = store.KVSListKeys("/ba", "")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1002 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(keys) != 2 {
-		t.Fatalf("bad: %v", keys)
-	}
-
-	idx, keys, err = store.KVSListKeys("/nope", "")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1003 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(keys) != 0 {
-		t.Fatalf("bad: %v", keys)
-	}
-}
-
-func TestKVS_ListKeys_TombstoneIndex(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	// Create the entries
-	d := &structs.DirEntry{Key: "/foo/a", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1000, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/bar/b", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1001, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/baz/c", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1002, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/other/d", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1003, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.KVSDelete(1004, "/baz/c"); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	idx, keys, err := store.KVSListKeys("/foo", "")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1000 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(keys) != 1 {
-		t.Fatalf("bad: %v", keys)
-	}
-
-	idx, keys, err = store.KVSListKeys("/ba", "")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1004 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(keys) != 1 {
-		t.Fatalf("bad: %v", keys)
-	}
-
-	idx, keys, err = store.KVSListKeys("/nope", "")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1004 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if len(keys) != 0 {
-		t.Fatalf("bad: %v", keys)
-	}
-}
-
-func TestKVSDeleteTree(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	ttl := 10 * time.Millisecond
-	gran := 5 * time.Millisecond
-	gc, err := NewTombstoneGC(ttl, gran)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	gc.SetEnabled(true)
-	store.gc = gc
-
-	notify1 := make(chan struct{}, 1)
-	notify2 := make(chan struct{}, 1)
-	notify3 := make(chan struct{}, 1)
-
-	store.WatchKV("", notify1)
-	store.WatchKV("/web/sub", notify2)
-	store.WatchKV("/other", notify3)
-
-	// Should not exist
-	err = store.KVSDeleteTree(1000, "/web")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Create the entries
-	d := &structs.DirEntry{Key: "/web/a", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1000, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/web/b", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1001, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/web/sub/c", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1002, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Nuke the web tree
-	err = store.KVSDeleteTree(1010, "/web")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Nothing should list
-	tombIdx, idx, ents, err := store.KVSList("/web")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1010 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if tombIdx != 1010 {
-		t.Fatalf("bad: %v", tombIdx)
-	}
-	if len(ents) != 0 {
-		t.Fatalf("bad: %v", ents)
-	}
-
-	// Check tombstones exist
-	_, res, err := store.tombstoneTable.Get("id_prefix", "/web")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if len(res) != 3 {
-		t.Fatalf("bad: %#v", res)
-	}
-	for _, r := range res {
-		if r.(*structs.DirEntry).ModifyIndex != 1010 {
-			t.Fatalf("bad: %#v", r)
-		}
-	}
-
-	// Check that we've fired notify1 and notify2
-	select {
-	case <-notify1:
-	default:
-		t.Fatalf("should notify root")
-	}
-	select {
-	case <-notify2:
-	default:
-		t.Fatalf("should notify /web/sub")
-	}
-	select {
-	case <-notify3:
-		t.Fatalf("should not notify /other")
-	default:
-	}
-
-	// Check that we get a delete
-	select {
-	case idx := <-gc.ExpireCh():
-		if idx != 1010 {
-			t.Fatalf("bad %d", idx)
-		}
-	case <-time.After(20 * time.Millisecond):
-		t.Fatalf("should expire")
-	}
-}
-
-func TestReapTombstones(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	ttl := 10 * time.Millisecond
-	gran := 5 * time.Millisecond
-	gc, err := NewTombstoneGC(ttl, gran)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	gc.SetEnabled(true)
-	store.gc = gc
-
-	// Should not exist
-	err = store.KVSDeleteTree(1000, "/web")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Create the entries
-	d := &structs.DirEntry{Key: "/web/a", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1000, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/web/b", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1001, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	d = &structs.DirEntry{Key: "/web/sub/c", Flags: 42, Value: []byte("test")}
-	if err := store.KVSSet(1002, d); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Nuke just a
-	err = store.KVSDelete(1010, "/web/a")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Nuke the web tree
-	err = store.KVSDeleteTree(1020, "/web")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Do a reap, should be a noop
-	if err := store.ReapTombstones(1000); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Check tombstones exist
-	_, res, err := store.tombstoneTable.Get("id_prefix", "/web")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if len(res) != 3 {
-		t.Fatalf("bad: %#v", res)
-	}
-
-	// Do a reap, should remove just /web/a
-	if err := store.ReapTombstones(1010); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Check tombstones exist
-	_, res, err = store.tombstoneTable.Get("id_prefix", "/web")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if len(res) != 2 {
-		t.Fatalf("bad: %#v", res)
-	}
-
-	// Do a reap, should remove them all
-	if err := store.ReapTombstones(1025); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Check no tombstones exist
-	_, res, err = store.tombstoneTable.Get("id_prefix", "/web")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if len(res) != 0 {
-		t.Fatalf("bad: %#v", res)
-	}
-}
-
-func TestSessionCreate(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	check := &structs.HealthCheck{
-		Node:    "foo",
-		CheckID: "bar",
-		Status:  structs.HealthPassing,
-	}
-	if err := store.EnsureCheck(13, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	session := &structs.Session{
-		ID:     generateUUID(),
-		Node:   "foo",
-		Checks: []string{"bar"},
-	}
-
-	if err := store.SessionCreate(1000, session); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	if session.CreateIndex != 1000 {
-		t.Fatalf("bad: %v", session)
-	}
-}
-
-func TestSessionCreate_Invalid(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	// No node registered
-	session := &structs.Session{
-		ID:     generateUUID(),
-		Node:   "foo",
-		Checks: []string{"bar"},
-	}
-	if err := store.SessionCreate(1000, session); err.Error() != "Missing node registration" {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Check not registered
-	if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.SessionCreate(1000, session); err.Error() != "Missing check 'bar' registration" {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Unhealthy check
-	check := &structs.HealthCheck{
-		Node:    "foo",
-		CheckID: "bar",
-		Status:  structs.HealthCritical,
-	}
-	if err := store.EnsureCheck(13, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if err := store.SessionCreate(1000, session); err.Error() != "Check 'bar' is in critical state" {
-		t.Fatalf("err: %v", err)
-	}
-}
-
-func TestSession_Lookups(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	// Create a session
-	if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	session := &structs.Session{
-		ID:     generateUUID(),
-		Node:   "foo",
-		Checks: []string{},
-	}
-	if err := store.SessionCreate(1000, session); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Lookup by ID
-	idx, s2, err := store.SessionGet(session.ID)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1000 {
-		t.Fatalf("bad: %v", idx)
-	}
-	if !reflect.DeepEqual(s2, session) {
-		t.Fatalf("bad: %#v %#v", s2, session)
-	}
-
-	// Create many sessions
-	ids := []string{session.ID}
-	for i := 0; i < 10; i++ {
-		session := &structs.Session{
-			ID:   generateUUID(),
-			Node: "foo",
-		}
-		if err := store.SessionCreate(uint64(1000+i), session); err != nil {
-			t.Fatalf("err: %v", err)
-		}
-		ids = append(ids, session.ID)
-	}
-
-	// List all
-	idx, all, err := store.SessionList()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1009 {
-		t.Fatalf("bad: %v", idx)
-	}
-
-	// Retrieve the ids
-	var out []string
-	for _, s := range all {
-		out = append(out, s.ID)
-	}
-
-	sort.Strings(ids)
-	sort.Strings(out)
-	if !reflect.DeepEqual(ids, out) {
-		t.Fatalf("bad: %v %v", ids, out)
-	}
-
-	// List by node
-	idx, nodes, err := store.NodeSessions("foo")
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if idx != 1009 {
-		t.Fatalf("bad: %v", idx)
-	}
-
-	// Check again for the node list
-	out = nil
-	for _, s := range nodes {
-		out = append(out, s.ID)
-	}
-	sort.Strings(out)
-	if !reflect.DeepEqual(ids, out) {
-		t.Fatalf("bad: %v %v", ids, out)
-	}
-}
-
-func TestSessionInvalidate_CriticalHealthCheck(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	check := &structs.HealthCheck{
-		Node:    "foo",
-		CheckID: "bar",
-		Status:  structs.HealthPassing,
-	}
-	if err := store.EnsureCheck(13, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	session := &structs.Session{
-		ID:     generateUUID(),
-		Node:   "foo",
-		Checks: []string{"bar"},
-	}
-	if err := store.SessionCreate(14, session); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Invalidate the check
-	check.Status = structs.HealthCritical
-	if err := store.EnsureCheck(15, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Lookup by ID, should be nil
-	_, s2, err := store.SessionGet(session.ID)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if s2 != nil {
-		t.Fatalf("session should be invalidated")
-	}
-}
-
-func TestSessionInvalidate_DeleteHealthCheck(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	check := &structs.HealthCheck{
-		Node:    "foo",
-		CheckID: "bar",
-		Status:  structs.HealthPassing,
-	}
-	if err := store.EnsureCheck(13, check); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	session := &structs.Session{
-		ID:     generateUUID(),
-		Node:   "foo",
-		Checks: []string{"bar"},
-	}
-	if err := store.SessionCreate(14, session); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Delete the check
-	if err := store.DeleteNodeCheck(15, "foo", "bar"); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Lookup by ID, should be nil
-	_, s2, err := store.SessionGet(session.ID)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if s2 != nil {
-		t.Fatalf("session should be invalidated")
-	}
-}
-
-func TestSessionInvalidate_DeleteNode(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer store.Close()
-
-	if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	session := &structs.Session{
-		ID:   generateUUID(),
-		Node: "foo",
-	}
-	if err := store.SessionCreate(14, session); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Delete the node
-	if err := store.DeleteNode(15, "foo"); err != nil {
-		t.Fatalf("err: %v", err)
-	}
-
-	// Lookup by ID, should be nil
-	_, s2, err := store.SessionGet(session.ID)
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	if s2 != nil {
-		t.Fatalf("session should be invalidated")
-	}
-}
-
-func TestSessionInvalidate_DeleteNodeService(t *testing.T) {
-	store, err := testStateStore()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
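-	// The session created below is tied to the "api" service's check,
-	// so deregistering that service should invalidate the session.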
defer store.Close() - - if err := store.EnsureNode(11, structs.Node{"foo", "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - if err := store.EnsureService(12, "foo", &structs.NodeService{"api", "api", nil, "", 5000, false}); err != nil { - t.Fatalf("err: %v", err) - } - check := &structs.HealthCheck{ - Node: "foo", - CheckID: "api", - Name: "Can connect", - Status: structs.HealthPassing, - ServiceID: "api", - } - if err := store.EnsureCheck(13, check); err != nil { - t.Fatalf("err: %v", err) - } - - session := &structs.Session{ - ID: generateUUID(), - Node: "foo", - Checks: []string{"api"}, - } - if err := store.SessionCreate(14, session); err != nil { - t.Fatalf("err: %v", err) - } - - // Should invalidate the session - if err := store.DeleteNodeService(15, "foo", "api"); err != nil { - t.Fatalf("err: %v", err) - } - - // Lookup by ID, should be nil - _, s2, err := store.SessionGet(session.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if s2 != nil { - t.Fatalf("session should be invalidated") - } -} - -func TestKVSLock(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - - if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - session := &structs.Session{ID: generateUUID(), Node: "foo"} - if err := store.SessionCreate(4, session); err != nil { - t.Fatalf("err: %v", err) - } - - // Lock with a non-existing keys should work - d := &structs.DirEntry{ - Key: "/foo", - Flags: 42, - Value: []byte("test"), - Session: session.ID, - } - ok, err := store.KVSLock(5, d) - if err != nil { - t.Fatalf("err: %v", err) - } - if !ok { - t.Fatalf("unexpected fail") - } - if d.LockIndex != 1 { - t.Fatalf("bad: %v", d) - } - - // Re-locking should fail - ok, err = store.KVSLock(6, d) - if err != nil { - t.Fatalf("err: %v", err) - } - if ok { - t.Fatalf("expected fail") - } - - // Set a normal key - k1 := &structs.DirEntry{ - Key: "/bar", - Flags: 0, - Value: []byte("asdf"), - } - if err := store.KVSSet(7, k1); err != nil { - t.Fatalf("err: %v", err) - } - - // Should acquire the lock - k1.Session = session.ID - ok, err = store.KVSLock(8, k1) - if err != nil { - t.Fatalf("err: %v", err) - } - if !ok { - t.Fatalf("unexpected fail") - } - - // Re-acquire should fail - ok, err = store.KVSLock(9, k1) - if err != nil { - t.Fatalf("err: %v", err) - } - if ok { - t.Fatalf("expected fail") - } - -} - -func TestKVSUnlock(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - - if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - session := &structs.Session{ID: generateUUID(), Node: "foo"} - if err := store.SessionCreate(4, session); err != nil { - t.Fatalf("err: %v", err) - } - - // Unlock with a non-existing keys should fail - d := &structs.DirEntry{ - Key: "/foo", - Flags: 42, - Value: []byte("test"), - Session: session.ID, - } - ok, err := store.KVSUnlock(5, d) - if err != nil { - t.Fatalf("err: %v", err) - } - if ok { - t.Fatalf("expected fail") - } - - // Lock should work - d.Session = session.ID - if ok, _ := store.KVSLock(6, d); !ok { - t.Fatalf("expected lock") - } - - // Unlock should work - d.Session = session.ID - ok, err = store.KVSUnlock(7, d) - if err != nil { - t.Fatalf("err: %v", err) - } - if !ok { - t.Fatalf("unexpected fail") - } - - // Re-lock should work - d.Session = session.ID - if ok, err := store.KVSLock(8, d); err != nil { - 
t.Fatalf("err: %v", err) - } else if !ok { - t.Fatalf("expected lock") - } - if d.LockIndex != 2 { - t.Fatalf("bad: %v", d) - } -} - -func TestSessionInvalidate_KeyUnlock(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - session := &structs.Session{ - ID: generateUUID(), - Node: "foo", - LockDelay: 50 * time.Millisecond, - } - if err := store.SessionCreate(4, session); err != nil { - t.Fatalf("err: %v", err) - } - - // Lock a key with the session - d := &structs.DirEntry{ - Key: "/foo", - Flags: 42, - Value: []byte("test"), - Session: session.ID, - } - ok, err := store.KVSLock(5, d) - if err != nil { - t.Fatalf("err: %v", err) - } - if !ok { - t.Fatalf("unexpected fail") - } - - notify1 := make(chan struct{}, 1) - store.WatchKV("/f", notify1) - - // Delete the node - if err := store.DeleteNode(6, "foo"); err != nil { - t.Fatalf("err: %v", err) - } - - // Key should be unlocked - idx, d2, err := store.KVSGet("/foo") - if idx != 6 { - t.Fatalf("bad: %v", idx) - } - if d2.LockIndex != 1 { - t.Fatalf("bad: %v", *d2) - } - if d2.Session != "" { - t.Fatalf("bad: %v", *d2) - } - - // Should notify of update - select { - case <-notify1: - default: - t.Fatalf("should notify /f") - } - - // Key should have a lock delay - expires := store.KVSLockDelay("/foo") - if expires.Before(time.Now().Add(30 * time.Millisecond)) { - t.Fatalf("Bad: %v", expires) - } -} - -func TestSessionInvalidate_KeyDelete(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - - if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil { - t.Fatalf("err: %v", err) - } - session := &structs.Session{ - ID: generateUUID(), - Node: "foo", - LockDelay: 50 * time.Millisecond, - Behavior: structs.SessionKeysDelete, - } - if err := store.SessionCreate(4, session); err != nil { - t.Fatalf("err: %v", err) - } - - // Lock a key with the session - d := &structs.DirEntry{ - Key: "/bar", - Flags: 42, - Value: []byte("test"), - Session: session.ID, - } - ok, err := store.KVSLock(5, d) - if err != nil { - t.Fatalf("err: %v", err) - } - if !ok { - t.Fatalf("unexpected fail") - } - - notify1 := make(chan struct{}, 1) - store.WatchKV("/b", notify1) - - // Delete the node - if err := store.DeleteNode(6, "foo"); err != nil { - t.Fatalf("err: %v", err) - } - - // Key should be deleted - _, d2, err := store.KVSGet("/bar") - if d2 != nil { - t.Fatalf("unexpected undeleted key") - } - - // Should notify of update - select { - case <-notify1: - default: - t.Fatalf("should notify /b") - } - - // Key should have a lock delay - expires := store.KVSLockDelay("/bar") - if expires.Before(time.Now().Add(30 * time.Millisecond)) { - t.Fatalf("Bad: %v", expires) - } -} - -func TestACLSet_Get(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - - idx, out, err := store.ACLGet("1234") - if err != nil { - t.Fatalf("err: %v", err) - } - if idx != 0 { - t.Fatalf("bad: %v", idx) - } - if out != nil { - t.Fatalf("bad: %v", out) - } - - a := &structs.ACL{ - ID: generateUUID(), - Name: "User token", - Type: structs.ACLTypeClient, - Rules: "", - } - if err := store.ACLSet(50, a); err != nil { - t.Fatalf("err: %v", err) - } - if a.CreateIndex != 50 { - t.Fatalf("Bad: %v", a) - } - if a.ModifyIndex != 50 { - t.Fatalf("Bad: %v", a) - } - if a.ID 
== "" { - t.Fatalf("Bad: %v", a) - } - - idx, out, err = store.ACLGet(a.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if idx != 50 { - t.Fatalf("bad: %v", idx) - } - if !reflect.DeepEqual(out, a) { - t.Fatalf("bad: %v", out) - } - - // Update - a.Rules = "foo bar baz" - if err := store.ACLSet(52, a); err != nil { - t.Fatalf("err: %v", err) - } - if a.CreateIndex != 50 { - t.Fatalf("Bad: %v", a) - } - if a.ModifyIndex != 52 { - t.Fatalf("Bad: %v", a) - } - - idx, out, err = store.ACLGet(a.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if idx != 52 { - t.Fatalf("bad: %v", idx) - } - if !reflect.DeepEqual(out, a) { - t.Fatalf("bad: %v", out) - } -} - -func TestACLDelete(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - - a := &structs.ACL{ - ID: generateUUID(), - Name: "User token", - Type: structs.ACLTypeClient, - Rules: "", - } - if err := store.ACLSet(50, a); err != nil { - t.Fatalf("err: %v", err) - } - - if err := store.ACLDelete(52, a.ID); err != nil { - t.Fatalf("err: %v", err) - } - if err := store.ACLDelete(53, a.ID); err != nil { - t.Fatalf("err: %v", err) - } - - idx, out, err := store.ACLGet(a.ID) - if err != nil { - t.Fatalf("err: %v", err) - } - if idx != 52 { - t.Fatalf("bad: %v", idx) - } - if out != nil { - t.Fatalf("bad: %v", out) - } -} - -func TestACLList(t *testing.T) { - store, err := testStateStore() - if err != nil { - t.Fatalf("err: %v", err) - } - defer store.Close() - - a1 := &structs.ACL{ - ID: generateUUID(), - Name: "User token", - Type: structs.ACLTypeClient, - } - if err := store.ACLSet(50, a1); err != nil { - t.Fatalf("err: %v", err) - } - - a2 := &structs.ACL{ - ID: generateUUID(), - Name: "User token", - Type: structs.ACLTypeClient, - } - if err := store.ACLSet(51, a2); err != nil { - t.Fatalf("err: %v", err) - } - - idx, out, err := store.ACLList() - if err != nil { - t.Fatalf("err: %v", err) - } - if idx != 51 { - t.Fatalf("bad: %v", idx) - } - if len(out) != 2 { - t.Fatalf("bad: %v", out) - } -} diff --git a/consul/structs/structs.go b/consul/structs/structs.go index 3f2faf5d568c..e884ff6b43ec 100644 --- a/consul/structs/structs.go +++ b/consul/structs/structs.go @@ -17,6 +17,13 @@ var ( type MessageType uint8 +// RaftIndex is used to track the index used while creating +// or modifying a given struct type. +type RaftIndex struct { + CreateIndex uint64 + ModifyIndex uint64 +} + const ( RegisterRequestType MessageType = iota DeregisterRequestType @@ -224,8 +231,10 @@ func (r *ChecksInStateRequest) RequestDatacenter() string { type Node struct { Node string Address string + + RaftIndex } -type Nodes []Node +type Nodes []*Node // Used to return information about a provided services. // Maps service name to available tags @@ -233,15 +242,56 @@ type Services map[string][]string // ServiceNode represents a node that is part of a service type ServiceNode struct { - Node string - Address string - ServiceID string - ServiceName string - ServiceTags []string - ServiceAddress string - ServicePort int -} -type ServiceNodes []ServiceNode + Node string + Address string + ServiceID string + ServiceName string + ServiceTags []string + ServiceAddress string + ServicePort int + ServiceEnableTagOverride bool + + RaftIndex +} + +// Clone returns a clone of the given service node. 
+func (s *ServiceNode) Clone() *ServiceNode {
+	tags := make([]string, len(s.ServiceTags))
+	copy(tags, s.ServiceTags)
+
+	return &ServiceNode{
+		Node:                     s.Node,
+		Address:                  s.Address,
+		ServiceID:                s.ServiceID,
+		ServiceName:              s.ServiceName,
+		ServiceTags:              tags,
+		ServiceAddress:           s.ServiceAddress,
+		ServicePort:              s.ServicePort,
+		ServiceEnableTagOverride: s.ServiceEnableTagOverride,
+		RaftIndex: RaftIndex{
+			CreateIndex: s.CreateIndex,
+			ModifyIndex: s.ModifyIndex,
+		},
+	}
+}
+
+// ToNodeService converts the given service node to a node service.
+func (s *ServiceNode) ToNodeService() *NodeService {
+	return &NodeService{
+		ID:                s.ServiceID,
+		Service:           s.ServiceName,
+		Tags:              s.ServiceTags,
+		Address:           s.ServiceAddress,
+		Port:              s.ServicePort,
+		EnableTagOverride: s.ServiceEnableTagOverride,
+		RaftIndex: RaftIndex{
+			CreateIndex: s.CreateIndex,
+			ModifyIndex: s.ModifyIndex,
+		},
+	}
+}
+
+type ServiceNodes []*ServiceNode
 
 // NodeService is a service provided by a node
 type NodeService struct {
@@ -251,9 +301,30 @@ type NodeService struct {
 	Address           string
 	Port              int
 	EnableTagOverride bool
+
+	RaftIndex
 }
+
+// ToServiceNode converts the given node service to a service node.
+func (s *NodeService) ToServiceNode(node, address string) *ServiceNode {
+	return &ServiceNode{
+		Node:                     node,
+		Address:                  address,
+		ServiceID:                s.ID,
+		ServiceName:              s.Service,
+		ServiceTags:              s.Tags,
+		ServiceAddress:           s.Address,
+		ServicePort:              s.Port,
+		ServiceEnableTagOverride: s.EnableTagOverride,
+		RaftIndex: RaftIndex{
+			CreateIndex: s.CreateIndex,
+			ModifyIndex: s.ModifyIndex,
+		},
+	}
+}
+
 type NodeServices struct {
-	Node     Node
+	Node     *Node
 	Services map[string]*NodeService
 }
@@ -267,14 +338,16 @@ type HealthCheck struct {
 	Output      string // Holds output of script runs
 	ServiceID   string // optional associated service
 	ServiceName string // optional service name
+
+	RaftIndex
 }
 type HealthChecks []*HealthCheck
 
-// CheckServiceNode is used to provide the node, it's service
+// CheckServiceNode is used to provide the node, its service
 // definition, as well as a HealthCheck that is associated
 type CheckServiceNode struct {
-	Node    Node
-	Service NodeService
+	Node    *Node
+	Service *NodeService
 	Checks  HealthChecks
 }
 type CheckServiceNodes []CheckServiceNode
@@ -332,14 +405,30 @@ type IndexedNodeDump struct {
 
 // DirEntry is used to represent a directory entry. This is
 // used for values in our Key-Value store.
 type DirEntry struct {
-	CreateIndex uint64
-	ModifyIndex uint64
-	LockIndex   uint64
-	Key         string
-	Flags       uint64
-	Value       []byte
-	Session     string `json:",omitempty"`
+	LockIndex uint64
+	Key       string
+	Flags     uint64
+	Value     []byte
+	Session   string `json:",omitempty"`
+
+	RaftIndex
+}
+
+// Returns a clone of the given directory entry.
+func (d *DirEntry) Clone() *DirEntry {
+	return &DirEntry{
+		LockIndex: d.LockIndex,
+		Key:       d.Key,
+		Flags:     d.Flags,
+		Value:     d.Value,
+		Session:   d.Session,
+		RaftIndex: RaftIndex{
+			CreateIndex: d.CreateIndex,
+			ModifyIndex: d.ModifyIndex,
+		},
+	}
 }
+
 type DirEntries []*DirEntry
 
 type KVSOp string
@@ -414,14 +503,15 @@ const (
 
 // Session is used to represent an open session in the KV store.
 // This issued to associate node checks with acquired locks.
 type Session struct {
-	CreateIndex uint64
-	ID          string
-	Name        string
-	Node        string
-	Checks      []string
-	LockDelay   time.Duration
-	Behavior    SessionBehavior // What to do when session is invalidated
-	TTL         string
+	ID        string
+	Name      string
+	Node      string
+	Checks    []string
+	LockDelay time.Duration
+	Behavior  SessionBehavior // What to do when session is invalidated
+	TTL       string
+
+	RaftIndex
 }
 type Sessions []*Session
 
@@ -462,12 +552,12 @@ type IndexedSessions struct {
 
 // ACL is used to represent a token and it's rules
 type ACL struct {
-	CreateIndex uint64
-	ModifyIndex uint64
-	ID          string
-	Name        string
-	Type        string
-	Rules       string
+	ID    string
+	Name  string
+	Type  string
+	Rules string
+
+	RaftIndex
 }
 
 type ACLs []*ACL
diff --git a/consul/structs/structs_test.go b/consul/structs/structs_test.go
index 4a2215cf8b4c..a89123efd0e8 100644
--- a/consul/structs/structs_test.go
+++ b/consul/structs/structs_test.go
@@ -53,3 +53,68 @@ func TestStructs_Implements(t *testing.T) {
 		_ CompoundResponse = &KeyringResponses{}
 	)
 }
+
+// testServiceNode gives a fully filled out ServiceNode instance.
+func testServiceNode() *ServiceNode {
+	return &ServiceNode{
+		Node:                     "node1",
+		Address:                  "127.0.0.1",
+		ServiceID:                "service1",
+		ServiceName:              "dogs",
+		ServiceTags:              []string{"prod", "v1"},
+		ServiceAddress:           "127.0.0.2",
+		ServicePort:              8080,
+		ServiceEnableTagOverride: true,
+		RaftIndex: RaftIndex{
+			CreateIndex: 1,
+			ModifyIndex: 2,
+		},
+	}
+}
+
+func TestStructs_ServiceNode_Clone(t *testing.T) {
+	sn := testServiceNode()
+
+	clone := sn.Clone()
+	if !reflect.DeepEqual(sn, clone) {
+		t.Fatalf("bad: %v", clone)
+	}
+
+	sn.ServiceTags = append(sn.ServiceTags, "hello")
+	if reflect.DeepEqual(sn, clone) {
+		t.Fatalf("clone wasn't independent of the original")
+	}
+}
+
+func TestStructs_ServiceNode_Conversions(t *testing.T) {
+	sn := testServiceNode()
+
+	sn2 := sn.ToNodeService().ToServiceNode("node1", "127.0.0.1")
+	if !reflect.DeepEqual(sn, sn2) {
+		t.Fatalf("bad: %v", sn2)
+	}
+}
+
+func TestStructs_DirEntry_Clone(t *testing.T) {
+	e := &DirEntry{
+		LockIndex: 5,
+		Key:       "hello",
+		Flags:     23,
+		Value:     []byte("this is a test"),
+		Session:   "session1",
+		RaftIndex: RaftIndex{
+			CreateIndex: 1,
+			ModifyIndex: 2,
+		},
+	}
+
+	clone := e.Clone()
+	if !reflect.DeepEqual(e, clone) {
+		t.Fatalf("bad: %v", clone)
+	}
+
+	e.Value = []byte("a new value")
+	if reflect.DeepEqual(e, clone) {
+		t.Fatalf("clone wasn't independent of the original")
+	}
+}
diff --git a/scripts/verify_no_uuid.sh b/scripts/verify_no_uuid.sh
index 4e9edf2ece19..5f67ef14d6c2 100755
--- a/scripts/verify_no_uuid.sh
+++ b/scripts/verify_no_uuid.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
-grep generateUUID consul/state_store.go
+grep generateUUID consul/state/state_store.go
 RESULT=$?
 if [ $RESULT -eq 0 ]; then
     exit 1
diff --git a/scripts/windows/verify_no_uuid.bat b/scripts/windows/verify_no_uuid.bat
index bcbc5fa28c47..a1d5b8ec4aaa 100644
--- a/scripts/windows/verify_no_uuid.bat
+++ b/scripts/windows/verify_no_uuid.bat
@@ -2,10 +2,10 @@
 
 setlocal
 
-if not exist %1\consul\state_store.go exit /B 1
+if not exist %1\consul\state\state_store.go exit /B 1
 if not exist %1\consul\fsm.go exit /B 1
 
-findstr /R generateUUID %1\consul\state_store.go 1>nul
+findstr /R generateUUID %1\consul\state\state_store.go 1>nul
 if not %ERRORLEVEL% EQU 1 exit /B 1
 
 findstr generateUUID %1\consul\fsm.go 1>nul
diff --git a/website/source/docs/agent/http/kv.html.markdown b/website/source/docs/agent/http/kv.html.markdown
index 3e8a18f174c7..30abd31c0880 100644
--- a/website/source/docs/agent/http/kv.html.markdown
+++ b/website/source/docs/agent/http/kv.html.markdown
@@ -61,8 +61,9 @@ if "?recurse" is provided, the returned `X-Consul-Index` corresponds
 to the latest `ModifyIndex` within the prefix, and a blocking query using that
 "?index" will wait until any key within that prefix is updated.
 
-`LockIndex` is the last index of a successful lock acquisition. If the lock is
-held, the `Session` key provides the session that owns the lock.
+`LockIndex` is the number of times this key has successfully been acquired in
+a lock. If the lock is held, the `Session` key provides the session that owns
+the lock.
 
 `Key` is simply the full path of the entry.
 
@@ -114,7 +115,10 @@ be used with a PUT request:
   operation. This is useful as it allows leader election to be built on top
   of Consul. If the lock is not held and the session is valid, this increments
   the `LockIndex` and sets the `Session` value of the key in addition to updating
-  the key contents. A key does not need to exist to be acquired.
+  the key contents. A key does not need to exist to be acquired. If the lock is
+  already held by the given session, then the `LockIndex` is not incremented but
+  the key contents are updated. This lets the current lock holder update the key
+  contents without having to give up the lock and reacquire it.
 
 * ?release=\<session\> : This flag is used to turn the `PUT` into a lock
   release operation. This is useful when paired with "?acquire=" as it allows clients to
diff --git a/website/source/docs/faq.html.markdown b/website/source/docs/faq.html.markdown
index e99644d6ec05..96b869798608 100644
--- a/website/source/docs/faq.html.markdown
+++ b/website/source/docs/faq.html.markdown
@@ -6,16 +6,6 @@ sidebar_current: "docs-faq"
 
 # Frequently Asked Questions
 
-## Q: Why is virtual memory usage high?
-
-Consul makes use of [LMDB](http://symas.com/mdb/) internally for various data
-storage purposes. LMDB relies on using memory-mapping, a technique in which
-a sparse file is represented as a contiguous range of memory. Consul configures
-high limits for these file sizes and as a result relies on large chunks of
-virtual memory to be allocated. However, in practice, the limits are much larger
-than any realistic deployment of Consul would ever use, and the resident memory or
-physical memory used is much lower.
-
 ## Q: What is Checkpoint? / Does Consul call home?
 
 Consul makes use of a HashiCorp service called [Checkpoint](http://checkpoint.hashicorp.com)
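
The `RaftIndex` embedding in structs.go is also why the agent tests in this patch zero out `CreateIndex` and `ModifyIndex` before `reflect.DeepEqual` comparisons: the promoted index fields are stamped by the state store, so hand-built fixtures never carry them. A minimal standalone sketch of that interaction follows; the `Node` shape mirrors the patched struct, but this program is illustrative only and is not part of the patch.

```go
package main

import (
	"fmt"
	"reflect"
)

// RaftIndex mirrors the embedded struct added in structs.go.
type RaftIndex struct {
	CreateIndex uint64
	ModifyIndex uint64
}

// Node mirrors the patched structs.Node shape for illustration.
type Node struct {
	Node    string
	Address string

	RaftIndex
}

func main() {
	// An object returned by the state store carries Raft indexes...
	stored := &Node{
		Node:      "foo",
		Address:   "127.0.0.1",
		RaftIndex: RaftIndex{CreateIndex: 3, ModifyIndex: 5},
	}
	// ...while a test fixture built by hand does not.
	fixture := &Node{Node: "foo", Address: "127.0.0.1"}

	fmt.Println(reflect.DeepEqual(stored, fixture)) // false

	// Zeroing the promoted fields, as local_test.go does, makes the
	// comparison ignore the indexes.
	stored.CreateIndex, stored.ModifyIndex = 0, 0
	fmt.Println(reflect.DeepEqual(stored, fixture)) // true
}
```

Embedding the pair once, rather than repeating two `uint64` fields in every struct, keeps the index bookkeeping uniform across `Node`, `ServiceNode`, `NodeService`, `HealthCheck`, `DirEntry`, `Session`, and `ACL`.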
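The kv.html.markdown change documents that the current lock holder may re-acquire to update a key's contents in place without bumping `LockIndex`. A sketch of that flow using the Go API client from this repository; it assumes a Consul agent reachable at the default local address, and the error handling is trimmed to `log.Fatal` for brevity.

```go
package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/consul/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	// Create a session to own the lock.
	sessionID, _, err := client.Session().Create(nil, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer client.Session().Destroy(sessionID, nil)

	pair := &api.KVPair{Key: "test/lock", Value: []byte("v1"), Session: sessionID}

	// First acquire: increments LockIndex and sets Session on the key.
	acquired, _, err := client.KV().Acquire(pair, nil)
	if err != nil || !acquired {
		log.Fatalf("acquire failed: %v", err)
	}

	// Re-acquire with the same session: per the doc change above, this
	// updates the value in place without incrementing LockIndex or
	// giving up the lock.
	pair.Value = []byte("v2")
	if _, _, err := client.KV().Acquire(pair, nil); err != nil {
		log.Fatal(err)
	}

	entry, _, err := client.KV().Get("test/lock", nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("LockIndex=%d Session=%s Value=%s\n",
		entry.LockIndex, entry.Session, entry.Value)
}
```

After the second `Acquire`, `Session` should still name the original owner and only the value should have moved from `v1` to `v2`, matching the re-acquire semantics the documentation hunk describes.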