From 612b447292473ee455fee440592cf8bafba8248d Mon Sep 17 00:00:00 2001 From: Guo Jix <729324352@qq.com> Date: Thu, 22 Aug 2024 07:12:46 +0000 Subject: [PATCH] filter local endpoints when set excludeLocalhost Signed-off-by: GitHub --- config.go | 1 + endpoints.go | 61 ++++++++++++++++++- endpoints_test.go | 149 ++++++++++++++++++++++++++++++++++++++++++++++ main.go | 1 + 4 files changed, 211 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 9eec59b..62c0768 100644 --- a/config.go +++ b/config.go @@ -10,6 +10,7 @@ import ( type globalConfig struct { endpoints []string useClusterEndpoints bool + excludeLocalhost bool dialTimeout time.Duration commandTimeout time.Duration diff --git a/endpoints.go b/endpoints.go index 280c265..cd2e495 100644 --- a/endpoints.go +++ b/endpoints.go @@ -3,6 +3,8 @@ package main import ( "errors" "fmt" + "net" + "net/url" "os" "strings" @@ -10,6 +12,8 @@ import ( "golang.org/x/exp/slices" ) +var errBadScheme = errors.New("url scheme must be http or https") + func endpointsWithLeaderAtEnd(gcfg globalConfig, statusList []epStatus) ([]string, error) { eps, err := endpoints(gcfg) if err != nil || len(eps) <= 1 { @@ -40,6 +44,49 @@ func endpoints(gcfg globalConfig) ([]string, error) { return endpointsFromCluster(gcfg) } +func IsLocalEndpoint(ep string) (bool, error) { + + if strings.HasPrefix(ep, "unix:") || strings.HasPrefix(ep, "unixs:") { + return true, nil + } + + if strings.Contains(ep, "://") { + url, err := url.Parse(ep) + if err != nil { + return false, err + } + if url.Scheme != "http" && url.Scheme != "https" { + return false, errBadScheme + } + + return isLocalEndpoint(url.Host) + } + + return isLocalEndpoint(ep) +} + +func isLocalEndpoint(ep string) (bool, error) { + + hostStr, _, err := net.SplitHostPort(ep) + if err != nil { + return false, err + } + + // lookup localhost + addrs, err := net.LookupHost(hostStr) + if err != nil { + return false, nil + } + + for _, addr := range addrs { + ip := net.ParseIP(addr) + if ip != nil && ip.IsLoopback() { + return true, nil + } + } + return false, nil +} + func endpointsFromCluster(gcfg globalConfig) ([]string, error) { memberlistResp, err := memberList(gcfg) if err != nil { @@ -50,7 +97,19 @@ func endpointsFromCluster(gcfg globalConfig) ([]string, error) { for _, m := range memberlistResp.Members { // learner member only serves Status and SerializableRead requests, just ignore it if !m.GetIsLearner() { - eps = append(eps, m.ClientURLs...) + for _, ep := range m.ClientURLs { + // not append local endpoint when set excludeLocalhost + if gcfg.excludeLocalhost { + ok, err := IsLocalEndpoint(ep) + if err != nil { + return nil, err + } + if ok { + continue + } + } + eps = append(eps, ep) + } } } diff --git a/endpoints_test.go b/endpoints_test.go index 34bd4a0..f684ef2 100644 --- a/endpoints_test.go +++ b/endpoints_test.go @@ -2,6 +2,7 @@ package main import ( "context" + "testing" "go.etcd.io/etcd/api/v3/etcdserverpb" @@ -111,6 +112,70 @@ func TestEndpointDedup(t *testing.T) { } } +func TestEndpointExcludeLocalhost(t *testing.T) { + oldCreateClient := createClient + t.Cleanup(func() { + createClient = oldCreateClient + }) + + fakeClient := fakeClientURLClient{} + createClient = func(cfgSpec *clientv3.ConfigSpec) (EtcdCluster, error) { + return &fakeClient, nil + } + + testcases := []struct { + name string + returnedMemberList *clientv3.MemberListResponse + expectedEndpoints []string + excludeLocalhost bool + }{ + { + "normal", + &clientv3.MemberListResponse{ + Members: []*etcdserverpb.Member{ + { + ClientURLs: []string{"etcd2.example.com:2379", "127.0.0.1:2379"}, + }, + { + ClientURLs: []string{"etcd3.example.com:2379"}, + }, + }, + }, + []string{"127.0.0.1:2379", "etcd2.example.com:2379", "etcd3.example.com:2379"}, + false, + }, + { + "excludeLocalhost", + &clientv3.MemberListResponse{ + Members: []*etcdserverpb.Member{ + { + ClientURLs: []string{"etcd2.example.com:2379", "127.0.0.1:2379"}, + }, + { + ClientURLs: []string{"etcd3.example.com:2379"}, + }, + }, + }, + []string{"etcd2.example.com:2379", "etcd3.example.com:2379"}, + true, + }, + } + + for _, testcase := range testcases { + t.Run(testcase.name, func(t *testing.T) { + fakeClient.memberListResp = testcase.returnedMemberList + ep, err := endpointsFromCluster(globalConfig{endpoints: []string{"https://localhost:2379"}, excludeLocalhost: testcase.excludeLocalhost}) + if err != nil { + t.Error(err) + } + + if !slices.Equal(testcase.expectedEndpoints, ep) { + t.Errorf("endpoints didn't match. Expected %v got %v", testcase.expectedEndpoints, ep) + } + }) + } +} + type fakeClientURLClient struct { *clientv3.Client memberListResp *clientv3.MemberListResponse @@ -124,3 +189,87 @@ func (f fakeClientURLClient) MemberList(ctx context.Context, opts ...clientv3.Op func (fakeClientURLClient) Close() error { return nil } + +func TestIsLocalEp(t *testing.T) { + testcases := []struct { + name string + ep string + desire bool + err error + }{ + { + "ipv4 loopback address", + "127.0.0.1:2379", + true, + nil, + }, + { + "ipv4 non-loopback address", + "10.7.7.7:2379", + false, + nil, + }, + { + "http url with ipv4 loopback address", + "http://127.0.0.1:2379", + true, + nil, + }, + { + "http url with ipv4 non-loopback address", + "http://10.7.7.7:2379", + false, + nil, + }, + { + "https url with hostname", + "https://abc-0.ns1-etcd.ns1.svc.cluster.local.:2379", + false, + nil, + }, + { + "ipv6 abbreviated loopback address", + "[::1]:2379", + true, + nil, + }, + { + "ipv6 loopback address", + "[0:0:0:0:0:0:0:1]:2379", + true, + nil, + }, + { + "ipv6 non-loopback address", + "[2007:0db8:3c4d:0015:0000:0000:1a2f:1a2b]:2379", + false, + nil, + }, + { + "localhost hostname", + "localhost:2379", + true, + nil, + }, + { + "https url with localhost hostname", + "https://localhost:2379", + true, + nil, + }, + { + "url with bad scheme", + "abc://localhost:2379", + false, + errBadScheme, + }, + } + + for _, testcase := range testcases { + t.Run(testcase.name, func(t *testing.T) { + if ok, err := IsLocalEndpoint(testcase.ep); err != testcase.err || ok != testcase.desire { + t.Errorf("expected %v, got err: %v result: %v", testcase.desire, err, ok) + } + }) + } +} diff --git a/main.go b/main.go index 2b2d9ce..56cdb8a 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ func newDefragCommand() *cobra.Command { defragCmd.Flags().StringSliceVar(&globalCfg.endpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd endpoints") defragCmd.Flags().BoolVar(&globalCfg.useClusterEndpoints, "cluster", false, "use all endpoints from the cluster member list") + defragCmd.Flags().BoolVar(&globalCfg.excludeLocalhost, "exclude-localhost", true, "whether to exclude localhost endpoints") defragCmd.Flags().DurationVar(&globalCfg.dialTimeout, "dial-timeout", 2*time.Second, "dial timeout for client connections") defragCmd.Flags().DurationVar(&globalCfg.commandTimeout, "command-timeout", 30*time.Second, "command timeout (excluding dial timeout)")