From d59bd7a24e273e58737c3efa832adabc57495bed Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Tue, 18 Jun 2024 18:31:23 +0300 Subject: [PATCH] client: add tests --- internal/client/index_internal_test.go | 1 + internal/client/storage.go | 59 +++++- internal/client/storage_test.go | 273 +++++++++++++++++++++++++ 3 files changed, 326 insertions(+), 7 deletions(-) diff --git a/internal/client/index_internal_test.go b/internal/client/index_internal_test.go index 38c0df15300..f51f461cec7 100644 --- a/internal/client/index_internal_test.go +++ b/internal/client/index_internal_test.go @@ -22,6 +22,7 @@ func newIDIndex(m []*Persistent) (ci *Index) { return ci } +// TODO(s.chzhen): Remove. func TestClientIndex_Find(t *testing.T) { const ( cliIPNone = "1.2.3.4" diff --git a/internal/client/storage.go b/internal/client/storage.go index bc317d2b162..23ab770b87f 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -1,6 +1,7 @@ package client import ( + "fmt" "net/netip" "sync" @@ -53,9 +54,45 @@ func (s *Storage) Add(p *Persistent) (err error) { return nil } +// Find finds persistent client by string representation of the client ID, IP +// address, or MAC. +func (s *Storage) Find(id string) (p *Persistent, ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.index.Find(id) +} + +// FindLoose is like [Storage.Find] but it also tries to find a persistent +// client by IP address without zone. It strips the IPv6 zone index from the +// stored IP addresses before comparing, because querylog entries don't have it. +// See TODO on [querylog.logEntry.IP]. +// +// Note that multiple clients can have the same IP address with different zones. +// Therefore, the result of this method is indeterminate. +func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + + p, ok = s.index.Find(id) + if ok { + return p, ok + } + + p = s.index.FindByIPWithoutZone(ip) + if p != nil { + return p, true + } + + return nil, false +} + // RemoveByName removes persistent client information. ok is false if no such // client exists by that name. func (s *Storage) RemoveByName(name string) (ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + p, ok := s.index.FindByName(name) if !ok { return false @@ -66,23 +103,31 @@ func (s *Storage) RemoveByName(name string) (ok bool) { return true } -// Update updates stored persistent client information p with new information n -// or returns an error. p and n must have the same UID. -func (s *Storage) Update(p, n *Persistent) (err error) { +// Update finds the stored persistent client by its name and updates its +// information from n. +func (s *Storage) Update(name string, n *Persistent) (err error) { defer func() { err = errors.Annotate(err, "updating client: %w") }() - if err != nil { - // Don't wrap the error since there is already an annotation deferred. - return err + s.mu.Lock() + defer s.mu.Unlock() + + stored, ok := s.index.FindByName(name) + if !ok { + return fmt.Errorf("client %q is not found", name) } + // Client n has a newly generated UID, so replace it with the stored one. + // + // TODO(s.chzhen): Remove when frontend starts handling UIDs. + n.UID = stored.UID + err = s.index.Clashes(n) if err != nil { // Don't wrap the error since there is already an annotation deferred. return err } - s.index.Delete(p) + s.index.Delete(stored) s.index.Add(n) return nil diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index 897472b96c5..fde65cc480c 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -1,6 +1,7 @@ package client_test import ( + "net" "net/netip" "testing" @@ -10,6 +11,32 @@ import ( "github.com/stretchr/testify/require" ) +// newStorage is a helper function that returns a client storage filled with +// persistent clients from the m. It also generates a UID for each client. +func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { + tb.Helper() + + s = client.NewStorage() + + for _, c := range m { + c.UID = client.MustNewUID() + require.NoError(tb, s.Add(c)) + } + + return s +} + +// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an +// error. +func mustParseMAC(s string) (mac net.HardwareAddr) { + mac, err := net.ParseMAC(s) + if err != nil { + panic(err) + } + + return mac +} + func TestStorage_Add(t *testing.T) { const ( existingName = "existing_name" @@ -143,3 +170,249 @@ func TestStorage_RemoveByName(t *testing.T) { assert.False(t, s.RemoveByName(existingName)) }) } + +func TestStorage_Find(t *testing.T) { + const ( + cliIPNone = "1.2.3.4" + cliIP1 = "1.1.1.1" + cliIP2 = "2.2.2.2" + + cliIPv6 = "1:2:3::4" + + cliSubnet = "2.2.2.0/24" + cliSubnetIP = "2.2.2.222" + + cliID = "client-id" + cliMAC = "11:11:11:11:11:11" + + linkLocalIP = "fe80::abcd:abcd:abcd:ab%eth0" + linkLocalSubnet = "fe80::/16" + ) + + var ( + clientWithBothFams = &client.Persistent{ + Name: "client1", + IPs: []netip.Addr{ + netip.MustParseAddr(cliIP1), + netip.MustParseAddr(cliIPv6), + }, + } + + clientWithSubnet = &client.Persistent{ + Name: "client2", + IPs: []netip.Addr{netip.MustParseAddr(cliIP2)}, + Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)}, + } + + clientWithMAC = &client.Persistent{ + Name: "client_with_mac", + MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, + } + + clientWithID = &client.Persistent{ + Name: "client_with_id", + ClientIDs: []string{cliID}, + } + + clientLinkLocal = &client.Persistent{ + Name: "client_link_local", + Subnets: []netip.Prefix{netip.MustParsePrefix(linkLocalSubnet)}, + } + ) + + clients := []*client.Persistent{ + clientWithBothFams, + clientWithSubnet, + clientWithMAC, + clientWithID, + clientLinkLocal, + } + s := newStorage(t, clients) + + testCases := []struct { + want *client.Persistent + name string + ids []string + }{{ + name: "ipv4_ipv6", + ids: []string{cliIP1, cliIPv6}, + want: clientWithBothFams, + }, { + name: "ipv4_subnet", + ids: []string{cliIP2, cliSubnetIP}, + want: clientWithSubnet, + }, { + name: "mac", + ids: []string{cliMAC}, + want: clientWithMAC, + }, { + name: "client_id", + ids: []string{cliID}, + want: clientWithID, + }, { + name: "client_link_local_subnet", + ids: []string{linkLocalIP}, + want: clientLinkLocal, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, id := range tc.ids { + c, ok := s.Find(id) + require.True(t, ok) + + assert.Equal(t, tc.want, c) + } + }) + } + + t.Run("not_found", func(t *testing.T) { + _, ok := s.Find(cliIPNone) + assert.False(t, ok) + }) +} + +func TestStorage_FindLoose(t *testing.T) { + const ( + nonExistingClientID = "client_id" + ) + + var ( + ip = netip.MustParseAddr("fe80::a098:7654:32ef:ff1") + ipWithZone = netip.MustParseAddr("fe80::1ff:fe23:4567:890a%eth2") + ) + + var ( + clientNoZone = &client.Persistent{ + Name: "client", + IPs: []netip.Addr{ip}, + } + + clientWithZone = &client.Persistent{ + Name: "client_with_zone", + IPs: []netip.Addr{ipWithZone}, + } + ) + + ci := newStorage( + t, + []*client.Persistent{ + clientNoZone, + clientWithZone, + }, + ) + + testCases := []struct { + ip netip.Addr + want assert.BoolAssertionFunc + wantCli *client.Persistent + name string + }{{ + name: "without_zone", + ip: ip, + wantCli: clientNoZone, + want: assert.True, + }, { + name: "with_zone", + ip: ipWithZone, + wantCli: clientWithZone, + want: assert.True, + }, { + name: "zero_address", + ip: netip.Addr{}, + wantCli: nil, + want: assert.False, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c, ok := ci.FindLoose(tc.ip.WithZone(""), nonExistingClientID) + assert.Equal(t, tc.wantCli, c) + tc.want(t, ok) + }) + } +} + +func TestStorage_Update(t *testing.T) { + const ( + clientName = "client_name" + obstructingName = "obstructing_name" + obstructingClientID = "obstructing_client_id" + ) + + var ( + obstructingIP = netip.MustParseAddr("1.2.3.4") + obstructingSubnet = netip.MustParsePrefix("1.2.3.0/24") + ) + + obstructingClient := &client.Persistent{ + Name: obstructingName, + IPs: []netip.Addr{obstructingIP}, + Subnets: []netip.Prefix{obstructingSubnet}, + ClientIDs: []string{obstructingClientID}, + } + + clientToUpdate := &client.Persistent{ + Name: clientName, + IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + } + + testCases := []struct { + name string + cli *client.Persistent + wantErrMsg string + }{{ + name: "basic", + cli: &client.Persistent{ + Name: "basic", + IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + UID: client.MustNewUID(), + }, + wantErrMsg: "", + }, { + name: "duplicate_name", + cli: &client.Persistent{ + Name: obstructingName, + IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")}, + }, + wantErrMsg: `updating client: another client uses the same name "obstructing_name"`, + }, { + name: "duplicate_ip", + cli: &client.Persistent{ + Name: "duplicate_ip", + IPs: []netip.Addr{obstructingIP}, + }, + wantErrMsg: `updating client: another client "obstructing_name" uses the same IP "1.2.3.4"`, + }, { + name: "duplicate_subnet", + cli: &client.Persistent{ + Name: "duplicate_subnet", + Subnets: []netip.Prefix{obstructingSubnet}, + }, + wantErrMsg: `updating client: another client "obstructing_name" ` + + `uses the same subnet "1.2.3.0/24"`, + }, { + name: "duplicate_client_id", + cli: &client.Persistent{ + Name: "duplicate_client_id", + ClientIDs: []string{obstructingClientID}, + }, + wantErrMsg: `updating client: another client "obstructing_name" ` + + `uses the same ClientID "obstructing_client_id"`, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := newStorage( + t, + []*client.Persistent{ + clientToUpdate, + obstructingClient, + }, + ) + + err := s.Update(clientName, tc.cli) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + }) + } +}