From 5b1baf18d9b7546bd174d7740dd7f654d25445b9 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 5 Feb 2024 16:18:56 +0800 Subject: [PATCH] Refine the PD client initialization Signed-off-by: JmPotato --- tools/pd-ctl/pdctl/command/cluster_command.go | 7 +-- tools/pd-ctl/pdctl/command/global.go | 54 ++++++++++++------- tools/pd-ctl/pdctl/ctl.go | 11 ---- 3 files changed, 40 insertions(+), 32 deletions(-) diff --git a/tools/pd-ctl/pdctl/command/cluster_command.go b/tools/pd-ctl/pdctl/command/cluster_command.go index 47df92f1e3d5..631397a2dcfb 100644 --- a/tools/pd-ctl/pdctl/command/cluster_command.go +++ b/tools/pd-ctl/pdctl/command/cluster_command.go @@ -19,9 +19,10 @@ import "github.com/spf13/cobra" // NewClusterCommand return a cluster subcommand of rootCmd func NewClusterCommand() *cobra.Command { cmd := &cobra.Command{ - Use: "cluster", - Short: "show the cluster information", - Run: showClusterCommandFunc, + Use: "cluster", + Short: "show the cluster information", + Run: showClusterCommandFunc, + PersistentPreRunE: requirePDClient, } cmd.AddCommand(NewClusterStatusCommand()) return cmd diff --git a/tools/pd-ctl/pdctl/command/global.go b/tools/pd-ctl/pdctl/command/global.go index 61b60dfa9466..1fee5f42f213 100644 --- a/tools/pd-ctl/pdctl/command/global.go +++ b/tools/pd-ctl/pdctl/command/global.go @@ -54,17 +54,30 @@ func initTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) { // PDCli is a pd HTTP client var PDCli pd.Client +func requirePDClient(cmd *cobra.Command, _ []string) error { + var ( + caPath string + err error + ) + caPath, err = cmd.Flags().GetString("cacert") + if err == nil && len(caPath) != 0 { + var certPath, keyPath string + certPath, err = cmd.Flags().GetString("cert") + if err != nil { + return err + } + keyPath, err = cmd.Flags().GetString("key") + if err != nil { + return err + } + return InitNewPDClientWithTLS(cmd, caPath, certPath, keyPath) + } + return InitNewPDClient(cmd) +} + // shouldInitPDClient checks whether we should create a new PD client according to the cluster information. func shouldInitPDClient(cmd *cobra.Command) (bool, error) { - if PDCli == nil { - return true, nil - } - // Use PD client to get the current cluster information. - currentClusterInfo, err := PDCli.GetCluster(cmd.Context()) - if err != nil { - return false, err - } - // Use HTTP request to get the new cluster information. + // Get the cluster information the current command assigned to. newClusterInfoJSON, err := doRequest(cmd, clusterPrefix, http.MethodGet, http.Header{}) if err != nil { return false, err @@ -74,34 +87,39 @@ func shouldInitPDClient(cmd *cobra.Command) (bool, error) { if err != nil { return false, err } + // If the PD client is nil and we get the cluster information successfully, + // we should initialize the PD client directly. + if PDCli == nil { + return true, nil + } + // Get current cluster information that the PD client connects to. + currentClusterInfo, err := PDCli.GetCluster(cmd.Context()) + if err != nil { + return true, nil + } + // Compare the cluster ID to determine whether we should re-initialize the PD client. return currentClusterInfo.GetId() == 0 || newClusterInfo.GetId() != currentClusterInfo.GetId(), nil } // InitNewPDClient creates a PD HTTP client with the given PD addresses. -func InitNewPDClient(cmd *cobra.Command) error { +func InitNewPDClient(cmd *cobra.Command, opts ...pd.ClientOption) error { if should, err := shouldInitPDClient(cmd); !should || err != nil { return err } if PDCli != nil { PDCli.Close() } - PDCli = pd.NewClient(pdControlCallerID, getEndpoints(cmd)) + PDCli = pd.NewClient(pdControlCallerID, getEndpoints(cmd), opts...) return nil } // InitNewPDClientWithTLS creates a PD HTTP client with the given PD addresses and TLS config. func InitNewPDClientWithTLS(cmd *cobra.Command, caPath, certPath, keyPath string) error { - if should, err := shouldInitPDClient(cmd); !should || err != nil { - return err - } - if PDCli != nil { - PDCli.Close() - } tlsConfig, err := initTLSConfig(caPath, certPath, keyPath) if err != nil { return err } - PDCli = pd.NewClient(pdControlCallerID, getEndpoints(cmd), pd.WithTLSConfig(tlsConfig)) + InitNewPDClient(cmd, pd.WithTLSConfig(tlsConfig)) return nil } diff --git a/tools/pd-ctl/pdctl/ctl.go b/tools/pd-ctl/pdctl/ctl.go index 53b1e0492ecb..ee78f886c0f7 100644 --- a/tools/pd-ctl/pdctl/ctl.go +++ b/tools/pd-ctl/pdctl/ctl.go @@ -89,17 +89,6 @@ func GetRootCmd() *cobra.Command { rootCmd.Println(err) return err } - err = command.InitNewPDClientWithTLS(cmd, caPath, certPath, keyPath) - if err != nil { - rootCmd.Println(err) - return err - } - } else { - err = command.InitNewPDClient(cmd) - if err != nil { - rootCmd.Println(err) - return err - } } return nil }