Skip to content

Commit

Permalink
Merge pull request #52 from gojoy/feat/remote
Browse files Browse the repository at this point in the history
filter local endpoints when set remote
  • Loading branch information
ahrtr authored Sep 14, 2024
2 parents 6915f58 + 612b447 commit 163fc0e
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 1 deletion.
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
type globalConfig struct {
endpoints []string
useClusterEndpoints bool
excludeLocalhost bool

dialTimeout time.Duration
commandTimeout time.Duration
Expand Down
61 changes: 60 additions & 1 deletion endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package main
import (
"errors"
"fmt"
"net"
"net/url"
"os"
"strings"

"go.etcd.io/etcd/client/pkg/v3/srv"
"golang.org/x/exp/slices"
)

var errBadScheme = errors.New("url scheme must be http or https")

func endpointsWithLeaderAtEnd(gcfg globalConfig, statusList []epStatus) ([]string, error) {
eps, err := endpoints(gcfg)
if err != nil || len(eps) <= 1 {
Expand Down Expand Up @@ -40,6 +44,49 @@ func endpoints(gcfg globalConfig) ([]string, error) {
return endpointsFromCluster(gcfg)
}

func IsLocalEndpoint(ep string) (bool, error) {

if strings.HasPrefix(ep, "unix:") || strings.HasPrefix(ep, "unixs:") {
return true, nil
}

if strings.Contains(ep, "://") {
url, err := url.Parse(ep)
if err != nil {
return false, err
}
if url.Scheme != "http" && url.Scheme != "https" {
return false, errBadScheme
}

return isLocalEndpoint(url.Host)
}

return isLocalEndpoint(ep)
}

func isLocalEndpoint(ep string) (bool, error) {

hostStr, _, err := net.SplitHostPort(ep)
if err != nil {
return false, err
}

// lookup localhost
addrs, err := net.LookupHost(hostStr)
if err != nil {
return false, nil
}

for _, addr := range addrs {
ip := net.ParseIP(addr)
if ip != nil && ip.IsLoopback() {
return true, nil
}
}
return false, nil
}

func endpointsFromCluster(gcfg globalConfig) ([]string, error) {
memberlistResp, err := memberList(gcfg)
if err != nil {
Expand All @@ -50,7 +97,19 @@ func endpointsFromCluster(gcfg globalConfig) ([]string, error) {
for _, m := range memberlistResp.Members {
// learner member only serves Status and SerializableRead requests, just ignore it
if !m.GetIsLearner() {
eps = append(eps, m.ClientURLs...)
for _, ep := range m.ClientURLs {
// not append local endpoint when set excludeLocalhost
if gcfg.excludeLocalhost {
ok, err := IsLocalEndpoint(ep)
if err != nil {
return nil, err
}
if ok {
continue
}
}
eps = append(eps, ep)
}
}
}

Expand Down
149 changes: 149 additions & 0 deletions endpoints_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"

"testing"

"go.etcd.io/etcd/api/v3/etcdserverpb"
Expand Down Expand Up @@ -111,6 +112,70 @@ func TestEndpointDedup(t *testing.T) {
}
}

func TestEndpointExcludeLocalhost(t *testing.T) {
oldCreateClient := createClient
t.Cleanup(func() {
createClient = oldCreateClient
})

fakeClient := fakeClientURLClient{}
createClient = func(cfgSpec *clientv3.ConfigSpec) (EtcdCluster, error) {
return &fakeClient, nil
}

testcases := []struct {
name string
returnedMemberList *clientv3.MemberListResponse
expectedEndpoints []string
excludeLocalhost bool
}{
{
"normal",
&clientv3.MemberListResponse{
Members: []*etcdserverpb.Member{
{
ClientURLs: []string{"etcd2.example.com:2379", "127.0.0.1:2379"},
},
{
ClientURLs: []string{"etcd3.example.com:2379"},
},
},
},
[]string{"127.0.0.1:2379", "etcd2.example.com:2379", "etcd3.example.com:2379"},
false,
},
{
"excludeLocalhost",
&clientv3.MemberListResponse{
Members: []*etcdserverpb.Member{
{
ClientURLs: []string{"etcd2.example.com:2379", "127.0.0.1:2379"},
},
{
ClientURLs: []string{"etcd3.example.com:2379"},
},
},
},
[]string{"etcd2.example.com:2379", "etcd3.example.com:2379"},
true,
},
}

for _, testcase := range testcases {
t.Run(testcase.name, func(t *testing.T) {
fakeClient.memberListResp = testcase.returnedMemberList
ep, err := endpointsFromCluster(globalConfig{endpoints: []string{"https://localhost:2379"}, excludeLocalhost: testcase.excludeLocalhost})
if err != nil {
t.Error(err)
}

if !slices.Equal(testcase.expectedEndpoints, ep) {
t.Errorf("endpoints didn't match. Expected %v got %v", testcase.expectedEndpoints, ep)
}
})
}
}

type fakeClientURLClient struct {
*clientv3.Client
memberListResp *clientv3.MemberListResponse
Expand All @@ -124,3 +189,87 @@ func (f fakeClientURLClient) MemberList(ctx context.Context, opts ...clientv3.Op
func (fakeClientURLClient) Close() error {
return nil
}

func TestIsLocalEp(t *testing.T) {
testcases := []struct {
name string
ep string
desire bool
err error
}{
{
"ipv4 loopback address",
"127.0.0.1:2379",
true,
nil,
},
{
"ipv4 non-loopback address",
"10.7.7.7:2379",
false,
nil,
},
{
"http url with ipv4 loopback address",
"http://127.0.0.1:2379",
true,
nil,
},
{
"http url with ipv4 non-loopback address",
"http://10.7.7.7:2379",
false,
nil,
},
{
"https url with hostname",
"https://abc-0.ns1-etcd.ns1.svc.cluster.local.:2379",
false,
nil,
},
{
"ipv6 abbreviated loopback address",
"[::1]:2379",
true,
nil,
},
{
"ipv6 loopback address",
"[0:0:0:0:0:0:0:1]:2379",
true,
nil,
},
{
"ipv6 non-loopback address",
"[2007:0db8:3c4d:0015:0000:0000:1a2f:1a2b]:2379",
false,
nil,
},
{
"localhost hostname",
"localhost:2379",
true,
nil,
},
{
"https url with localhost hostname",
"https://localhost:2379",
true,
nil,
},
{
"url with bad scheme",
"abc://localhost:2379",
false,
errBadScheme,
},
}

for _, testcase := range testcases {
t.Run(testcase.name, func(t *testing.T) {
if ok, err := IsLocalEndpoint(testcase.ep); err != testcase.err || ok != testcase.desire {
t.Errorf("expected %v, got err: %v result: %v", testcase.desire, err, ok)
}
})
}
}
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func newDefragCommand() *cobra.Command {

defragCmd.Flags().StringSliceVar(&globalCfg.endpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd endpoints")
defragCmd.Flags().BoolVar(&globalCfg.useClusterEndpoints, "cluster", false, "use all endpoints from the cluster member list")
defragCmd.Flags().BoolVar(&globalCfg.excludeLocalhost, "exclude-localhost", true, "whether to exclude localhost endpoints")

defragCmd.Flags().DurationVar(&globalCfg.dialTimeout, "dial-timeout", 2*time.Second, "dial timeout for client connections")
defragCmd.Flags().DurationVar(&globalCfg.commandTimeout, "command-timeout", 30*time.Second, "command timeout (excluding dial timeout)")
Expand Down

0 comments on commit 163fc0e

Please sign in to comment.