diff --git a/README.md b/README.md index c11dc9d..fbf81ca 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ Flags: --dry-run evaluate whether or not endpoints require defragmentation, but don't actually perform it --endpoints strings comma separated etcd endpoints (default [127.0.0.1:2379]) --etcd-storage-quota-bytes int etcd storage quota in bytes (the value passed to etcd instance by flag --quota-backend-bytes) (default 2147483648) + --exclude-localhost whether to exclude localhost endpoints -h, --help help for etcd-defrag --insecure-discovery accept insecure SRV records describing cluster endpoints (default true) --insecure-skip-tls-verify skip server certificate verification (CAUTION: this option should be enabled only for testing purposes) diff --git a/endpoints.go b/endpoints.go index cd2e495..266671e 100644 --- a/endpoints.go +++ b/endpoints.go @@ -44,12 +44,12 @@ func endpoints(gcfg globalConfig) ([]string, error) { return endpointsFromCluster(gcfg) } -func IsLocalEndpoint(ep string) (bool, error) { - +func isLocalEndpoint(ep string) (bool, error) { if strings.HasPrefix(ep, "unix:") || strings.HasPrefix(ep, "unixs:") { return true, nil } + hostPort := ep if strings.Contains(ep, "://") { url, err := url.Parse(ep) if err != nil { @@ -59,31 +59,23 @@ func IsLocalEndpoint(ep string) (bool, error) { return false, errBadScheme } - return isLocalEndpoint(url.Host) + hostPort = url.Host } - return isLocalEndpoint(ep) -} - -func isLocalEndpoint(ep string) (bool, error) { - - hostStr, _, err := net.SplitHostPort(ep) + hostname, _, err := net.SplitHostPort(hostPort) if err != nil { return false, err } - // lookup localhost - addrs, err := net.LookupHost(hostStr) - if err != nil { - return false, nil + if strings.EqualFold(hostname, "localhost") { + return true, nil } - for _, addr := range addrs { - ip := net.ParseIP(addr) - if ip != nil && ip.IsLoopback() { - return true, nil - } + ip := net.ParseIP(hostname) + if ip != nil && ip.IsLoopback() { + return true, nil } + return false, nil } @@ -98,9 +90,9 @@ func endpointsFromCluster(gcfg globalConfig) ([]string, error) { // learner member only serves Status and SerializableRead requests, just ignore it if !m.GetIsLearner() { for _, ep := range m.ClientURLs { - // not append local endpoint when set excludeLocalhost + // Do not append loopback endpoints when `--exclude-localhost` is set. if gcfg.excludeLocalhost { - ok, err := IsLocalEndpoint(ep) + ok, err := isLocalEndpoint(ep) if err != nil { return nil, err } diff --git a/endpoints_test.go b/endpoints_test.go index f684ef2..79e7a80 100644 --- a/endpoints_test.go +++ b/endpoints_test.go @@ -267,7 +267,7 @@ func TestIsLocalEp(t *testing.T) { for _, testcase := range testcases { t.Run(testcase.name, func(t *testing.T) { - if ok, err := IsLocalEndpoint(testcase.ep); err != testcase.err || ok != testcase.desire { + if ok, err := isLocalEndpoint(testcase.ep); err != testcase.err || ok != testcase.desire { t.Errorf("expected %v, got err: %v result: %v", testcase.desire, err, ok) } }) diff --git a/main.go b/main.go index d9aeb3d..c7045b6 100644 --- a/main.go +++ b/main.go @@ -23,7 +23,7 @@ func newDefragCommand() *cobra.Command { defragCmd.Flags().StringSliceVar(&globalCfg.endpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd endpoints") defragCmd.Flags().BoolVar(&globalCfg.useClusterEndpoints, "cluster", false, "use all endpoints from the cluster member list") - defragCmd.Flags().BoolVar(&globalCfg.excludeLocalhost, "exclude-localhost", true, "whether to exclude localhost endpoints") + defragCmd.Flags().BoolVar(&globalCfg.excludeLocalhost, "exclude-localhost", false, "whether to exclude localhost endpoints") defragCmd.Flags().DurationVar(&globalCfg.dialTimeout, "dial-timeout", 2*time.Second, "dial timeout for client connections") defragCmd.Flags().DurationVar(&globalCfg.commandTimeout, "command-timeout", 30*time.Second, "command timeout (excluding dial timeout)")