From 49f19aede633d0dac18e4a849cad22649c2cb77e Mon Sep 17 00:00:00 2001 From: Allen Zhong Date: Mon, 24 May 2021 16:11:34 +0800 Subject: [PATCH 1/3] cluster/executor: implement native SCP download instead of cat --- pkg/cluster/executor/scp.go | 128 ++++++++++++++++++++++++++++++++++++ pkg/cluster/executor/ssh.go | 13 +--- 2 files changed, 129 insertions(+), 12 deletions(-) create mode 100644 pkg/cluster/executor/scp.go diff --git a/pkg/cluster/executor/scp.go b/pkg/cluster/executor/scp.go new file mode 100644 index 0000000000..9e8cf5da1e --- /dev/null +++ b/pkg/cluster/executor/scp.go @@ -0,0 +1,128 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "bufio" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/pingcap/tiup/pkg/utils" + "golang.org/x/crypto/ssh" +) + +// ScpDownload downloads a file from remote with SCP +// The implementation is partially inspired by github.com/dtylman/scp +func ScpDownload(session *ssh.Session, client *ssh.Client, src, dst string) error { + // prepare dst file + targetPath := filepath.Dir(dst) + if err := utils.CreateDir(targetPath); err != nil { + return err + } + targetFile, err := os.Create(dst) + if err != nil { + return err + } + + r, err := session.StdoutPipe() + if err != nil { + return err + } + bufr := bufio.NewReader(r) + + w, err := session.StdinPipe() + if err != nil { + return err + } + + copyF := func() error { + // parse SCP command + line, _, err := bufr.ReadLine() + if err != nil { + return err + } + if line[0] != byte('C') { + return fmt.Errorf("incorrect scp command '%b', should be 'C'", line[0]) + } + + mode, err := strconv.ParseUint(string(line[1:5]), 0, 32) + if err != nil { + return fmt.Errorf("error parsing file mode; %s", err) + } + if err := targetFile.Chmod(fs.FileMode(mode)); err != nil { + return fmt.Errorf("error setting file mode; %s", err) + } + + size, err := strconv.Atoi(strings.Fields(string(line))[1]) + if err != nil { + return err + } + + if err := ack(w); err != nil { + return err + } + + // transferring data + n, err := io.CopyN(targetFile, bufr, int64(size)) + if err != nil { + return err + } + if n < int64(size) { + return fmt.Errorf("error downloading via scp, file size mismatch") + } + if err := targetFile.Sync(); err != nil { + return err + } + + return ack(w) + } + + copyErrC := make(chan error, 1) + go func() { + defer w.Close() + defer targetFile.Close() + copyErrC <- copyF() + }() + + err = session.Start(fmt.Sprintf("scp -f %s", src)) + if err != nil { + return err + } + if err := ack(w); err != nil { // send an empty byte to start transfer + return err + } + + err = <-copyErrC + if err != nil { + return err + } + return session.Wait() +} + +func ack(w io.Writer) error { + msg := []byte("\x00") + n, err := w.Write(msg) + if err != nil { + return fmt.Errorf("fail to send response to remote: %s", err) + } + if n < len(msg) { + return fmt.Errorf("fail to send response to remote, size mismatch") + } + return nil +} diff --git a/pkg/cluster/executor/ssh.go b/pkg/cluster/executor/ssh.go index 6010800286..1beb281323 100644 --- a/pkg/cluster/executor/ssh.go +++ b/pkg/cluster/executor/ssh.go @@ -197,18 +197,7 @@ func (e *EasySSHExecutor) Transfer(ctx context.Context, src string, dst string, defer client.Close() defer session.Close() - targetPath := filepath.Dir(dst) - if err = utils.CreateDir(targetPath); err != nil { - return err - } - targetFile, err := os.Create(dst) - if err != nil { - return err - } - - session.Stdout = targetFile - - return session.Run(fmt.Sprintf("cat %s", src)) + return ScpDownload(session, client, src, dst) } func (e *NativeSSHExecutor) prompt(def string) string { From 8c2777368977024902dc2f4aa90b3bf1b13bbfe0 Mon Sep 17 00:00:00 2001 From: Allen Zhong Date: Tue, 25 May 2021 16:27:31 +0800 Subject: [PATCH 2/3] executor/ssh: implement rate limit for scp download --- components/cluster/command/transfer.go | 1 + components/dm/ansible/import.go | 4 ++-- components/dm/ansible/import_test.go | 4 ++-- components/dm/spec/logic.go | 6 +++--- pkg/cluster/ansible/config.go | 9 ++++++--- pkg/cluster/ctxt/context.go | 2 +- pkg/cluster/executor/checkpoint.go | 6 ++++-- pkg/cluster/executor/local.go | 2 +- pkg/cluster/executor/local_test.go | 2 +- pkg/cluster/executor/scp.go | 8 ++++++-- pkg/cluster/executor/ssh.go | 9 ++++++--- pkg/cluster/manager/transfer.go | 5 +++-- pkg/cluster/operation/telemetry.go | 2 +- pkg/cluster/spec/alertmanager.go | 2 +- pkg/cluster/spec/cdc.go | 2 +- pkg/cluster/spec/drainer.go | 2 +- pkg/cluster/spec/grafana.go | 6 +++--- pkg/cluster/spec/instance.go | 8 ++++---- pkg/cluster/spec/pd.go | 4 ++-- pkg/cluster/spec/prometheus.go | 6 +++--- pkg/cluster/spec/pump.go | 2 +- pkg/cluster/spec/tidb.go | 2 +- pkg/cluster/spec/tiflash.go | 2 +- pkg/cluster/spec/tikv.go | 2 +- pkg/cluster/spec/tispark.go | 18 +++++++++--------- pkg/cluster/task/builder.go | 3 ++- pkg/cluster/task/copy_file.go | 3 ++- pkg/cluster/task/init_config_test.go | 2 +- pkg/cluster/task/install_package.go | 2 +- pkg/cluster/task/monitored_config.go | 6 +++--- pkg/cluster/task/tls.go | 9 ++++++--- 31 files changed, 80 insertions(+), 61 deletions(-) diff --git a/components/cluster/command/transfer.go b/components/cluster/command/transfer.go index 3ced0e4da4..b6c027ef95 100644 --- a/components/cluster/command/transfer.go +++ b/components/cluster/command/transfer.go @@ -48,6 +48,7 @@ func newPullCmd() *cobra.Command { cmd.Flags().StringSliceVarP(&gOpt.Roles, "role", "R", nil, "Only exec on host with specified roles") cmd.Flags().StringSliceVarP(&gOpt.Nodes, "node", "N", nil, "Only exec on host with specified nodes") + cmd.Flags().IntVarP(&opt.Limit, "limit", "l", 0, "Limits the used bandwidth, specified in Kbit/s") return cmd } diff --git a/components/dm/ansible/import.go b/components/dm/ansible/import.go index e0fd50dc21..5c9a6233c7 100644 --- a/components/dm/ansible/import.go +++ b/components/dm/ansible/import.go @@ -180,7 +180,7 @@ func (im *Importer) fetchFile(ctx context.Context, host string, port int, fname tmp = filepath.Join(tmp, filepath.Base(fname)) - err = e.Transfer(ctx, fname, tmp, true /*download*/) + err = e.Transfer(ctx, fname, tmp, true /*download*/, 0) if err != nil { return nil, errors.Annotatef(err, "transfer %s from %s:%d", fname, host, port) } @@ -254,7 +254,7 @@ func (im *Importer) ScpSourceToMaster(ctx context.Context, topo *spec.Specificat return errors.AddStack(err) } - err = e.Transfer(ctx, f.Name(), filepath.Join(target, addr+".yml"), false) + err = e.Transfer(ctx, f.Name(), filepath.Join(target, addr+".yml"), false, 0) if err != nil { return err } diff --git a/components/dm/ansible/import_test.go b/components/dm/ansible/import_test.go index 51711695f6..d57cfe45f0 100644 --- a/components/dm/ansible/import_test.go +++ b/components/dm/ansible/import_test.go @@ -48,13 +48,13 @@ func (g *executorGetter) Get(host string) ctxt.Executor { // Transfer implements executor interface. // Replace the deploy directory as the local one in testdata, so we can fetch it. -func (l *localExecutor) Transfer(ctx context.Context, src string, target string, download bool) error { +func (l *localExecutor) Transfer(ctx context.Context, src, target string, download bool, limit int) error { mydeploy, err := filepath.Abs("./testdata/deploy_dir/" + l.host) if err != nil { return errors.AddStack(err) } src = strings.Replace(src, "/home/tidb/deploy", mydeploy, 1) - return l.Local.Transfer(ctx, src, target, download) + return l.Local.Transfer(ctx, src, target, download, 0) } func TestParseRunScript(t *testing.T) { diff --git a/components/dm/spec/logic.go b/components/dm/spec/logic.go index 189154e830..daee4c31a6 100644 --- a/components/dm/spec/logic.go +++ b/components/dm/spec/logic.go @@ -134,7 +134,7 @@ func (i *MasterInstance) InitConfig( return err } dst := filepath.Join(paths.Deploy, "scripts", "run_dm-master.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } if _, _, err := e.Execute(ctx, "chmod +x "+dst, false); err != nil { @@ -176,7 +176,7 @@ func (i *MasterInstance) ScaleConfig( } dst := filepath.Join(paths.Deploy, "scripts", "run_dm-master.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } if _, _, err := e.Execute(ctx, "chmod +x "+dst, false); err != nil { @@ -265,7 +265,7 @@ func (i *WorkerInstance) InitConfig( } dst := filepath.Join(paths.Deploy, "scripts", "run_dm-worker.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } diff --git a/pkg/cluster/ansible/config.go b/pkg/cluster/ansible/config.go index 3aab27aef3..062e9a8d20 100644 --- a/pkg/cluster/ansible/config.go +++ b/pkg/cluster/ansible/config.go @@ -54,7 +54,8 @@ func ImportConfig(name string, clsMeta *spec.ClusterMeta, sshTimeout uint64, ssh inst.GetHost(), inst.GetPort())), inst.GetHost(), - true). + true, + 0). Build() copyFileTasks = append(copyFileTasks, t) case spec.ComponentTiFlash: @@ -71,7 +72,8 @@ func ImportConfig(name string, clsMeta *spec.ClusterMeta, sshTimeout uint64, ssh inst.GetHost(), inst.GetPort())), inst.GetHost(), - true). + true, + 0). CopyFile(filepath.Join(inst.DeployDir(), "conf", inst.ComponentName()+"-learner.toml"), spec.ClusterPath(name, spec.AnsibleImportedConfigPath, @@ -80,7 +82,8 @@ func ImportConfig(name string, clsMeta *spec.ClusterMeta, sshTimeout uint64, ssh inst.GetHost(), inst.GetPort())), inst.GetHost(), - true). + true, + 0). Build() copyFileTasks = append(copyFileTasks, t) default: diff --git a/pkg/cluster/ctxt/context.go b/pkg/cluster/ctxt/context.go index 963fc7676d..13913a6363 100644 --- a/pkg/cluster/ctxt/context.go +++ b/pkg/cluster/ctxt/context.go @@ -45,7 +45,7 @@ type ( Execute(ctx context.Context, cmd string, sudo bool, timeout ...time.Duration) (stdout []byte, stderr []byte, err error) // Transfer copies files from or to a target - Transfer(ctx context.Context, src string, dst string, download bool) error + Transfer(ctx context.Context, src, dst string, download bool, limit int) error } // ExecutorGetter get the executor by host. diff --git a/pkg/cluster/executor/checkpoint.go b/pkg/cluster/executor/checkpoint.go index 7869304f6b..36d22c45b2 100644 --- a/pkg/cluster/executor/checkpoint.go +++ b/pkg/cluster/executor/checkpoint.go @@ -46,6 +46,7 @@ var ( checkpoint.Field("src", reflect.DeepEqual), checkpoint.Field("dst", reflect.DeepEqual), checkpoint.Field("download", reflect.DeepEqual), + checkpoint.Field("limit", reflect.DeepEqual), ) ) @@ -86,7 +87,7 @@ func (c *CheckPointExecutor) Execute(ctx context.Context, cmd string, sudo bool, } // Transfer implements Executer interface. -func (c *CheckPointExecutor) Transfer(ctx context.Context, src string, dst string, download bool) (err error) { +func (c *CheckPointExecutor) Transfer(ctx context.Context, src, dst string, download bool, limit int) (err error) { point := checkpoint.Acquire(ctx, scpPoint, map[string]interface{}{ "host": c.config.Host, "port": c.config.Port, @@ -94,6 +95,7 @@ func (c *CheckPointExecutor) Transfer(ctx context.Context, src string, dst strin "src": src, "dst": dst, "download": download, + "limit": limit, }) defer func() { point.Release(err, @@ -108,5 +110,5 @@ func (c *CheckPointExecutor) Transfer(ctx context.Context, src string, dst strin return nil } - return c.Executor.Transfer(ctx, src, dst, download) + return c.Executor.Transfer(ctx, src, dst, download, limit) } diff --git a/pkg/cluster/executor/local.go b/pkg/cluster/executor/local.go index a1feb15a93..c66b9f160c 100644 --- a/pkg/cluster/executor/local.go +++ b/pkg/cluster/executor/local.go @@ -100,7 +100,7 @@ func (l *Local) Execute(ctx context.Context, cmd string, sudo bool, timeout ...t } // Transfer implements Executer interface. -func (l *Local) Transfer(ctx context.Context, src string, dst string, download bool) error { +func (l *Local) Transfer(ctx context.Context, src, dst string, download bool, limit int) error { targetPath := filepath.Dir(dst) if err := utils.CreateDir(targetPath); err != nil { return err diff --git a/pkg/cluster/executor/local_test.go b/pkg/cluster/executor/local_test.go index 99d6fd1816..cc37eff030 100644 --- a/pkg/cluster/executor/local_test.go +++ b/pkg/cluster/executor/local_test.go @@ -54,7 +54,7 @@ func TestLocal(t *testing.T) { defer os.Remove(dst.Name()) // Transfer src to dst and check it. - err = local.Transfer(ctx, src.Name(), dst.Name(), false) + err = local.Transfer(ctx, src.Name(), dst.Name(), false, 0) assert.Nil(err) data, err := os.ReadFile(dst.Name()) diff --git a/pkg/cluster/executor/scp.go b/pkg/cluster/executor/scp.go index 9e8cf5da1e..590c3299b4 100644 --- a/pkg/cluster/executor/scp.go +++ b/pkg/cluster/executor/scp.go @@ -29,7 +29,7 @@ import ( // ScpDownload downloads a file from remote with SCP // The implementation is partially inspired by github.com/dtylman/scp -func ScpDownload(session *ssh.Session, client *ssh.Client, src, dst string) error { +func ScpDownload(session *ssh.Session, client *ssh.Client, src, dst string, limit int) error { // prepare dst file targetPath := filepath.Dir(dst) if err := utils.CreateDir(targetPath); err != nil { @@ -100,7 +100,11 @@ func ScpDownload(session *ssh.Session, client *ssh.Client, src, dst string) erro copyErrC <- copyF() }() - err = session.Start(fmt.Sprintf("scp -f %s", src)) + remoteCmd := fmt.Sprintf("scp -f %s", src) + if limit > 0 { + remoteCmd = fmt.Sprintf("scp -l %d -f %s", limit, src) + } + err = session.Start(remoteCmd) if err != nil { return err } diff --git a/pkg/cluster/executor/ssh.go b/pkg/cluster/executor/ssh.go index 1beb281323..84626458d2 100644 --- a/pkg/cluster/executor/ssh.go +++ b/pkg/cluster/executor/ssh.go @@ -180,7 +180,7 @@ func (e *EasySSHExecutor) Execute(ctx context.Context, cmd string, sudo bool, ti // This function depends on `scp` (a tool from OpenSSH or other SSH implementation) // This function is based on easyssh.MakeConfig.Scp() but with support of copying // file from remote to local. -func (e *EasySSHExecutor) Transfer(ctx context.Context, src string, dst string, download bool) error { +func (e *EasySSHExecutor) Transfer(ctx context.Context, src, dst string, download bool, limit int) error { if !download { err := e.Config.Scp(src, dst) if err != nil { @@ -197,7 +197,7 @@ func (e *EasySSHExecutor) Transfer(ctx context.Context, src string, dst string, defer client.Close() defer session.Close() - return ScpDownload(session, client, src, dst) + return ScpDownload(session, client, src, dst, limit) } func (e *NativeSSHExecutor) prompt(def string) string { @@ -308,7 +308,7 @@ func (e *NativeSSHExecutor) Execute(ctx context.Context, cmd string, sudo bool, // Transfer copies files via SCP // This function depends on `scp` (a tool from OpenSSH or other SSH implementation) -func (e *NativeSSHExecutor) Transfer(ctx context.Context, src string, dst string, download bool) error { +func (e *NativeSSHExecutor) Transfer(ctx context.Context, src, dst string, download bool, limit int) error { if e.ConnectionTestResult != nil { return e.ConnectionTestResult } @@ -323,6 +323,9 @@ func (e *NativeSSHExecutor) Transfer(ctx context.Context, src string, dst string } args := []string{scp, "-r", "-o", "StrictHostKeyChecking=no"} + if limit > 0 { + args = append(args, "-l", fmt.Sprint(limit)) + } args = e.configArgs(args) // prefix and postfix args if download { diff --git a/pkg/cluster/manager/transfer.go b/pkg/cluster/manager/transfer.go index 1c4169864c..1e8a5084f6 100644 --- a/pkg/cluster/manager/transfer.go +++ b/pkg/cluster/manager/transfer.go @@ -38,6 +38,7 @@ type TransferOptions struct { Local string Remote string Pull bool // default to push + Limit int // rate limit in Kbit/s } // Transfer copies files from or to host in the tidb cluster. @@ -93,9 +94,9 @@ func (m *Manager) Transfer(name string, opt TransferOptions, gOpt operator.Optio for _, p := range i.Slice() { t := task.NewBuilder() if opt.Pull { - t.CopyFile(p, srcPath, host, opt.Pull) + t.CopyFile(p, srcPath, host, opt.Pull, opt.Limit) } else { - t.CopyFile(srcPath, p, host, opt.Pull) + t.CopyFile(srcPath, p, host, opt.Pull, opt.Limit) } shellTasks = append(shellTasks, t.Build()) } diff --git a/pkg/cluster/operation/telemetry.go b/pkg/cluster/operation/telemetry.go index e1b0ad659f..df4c451ad5 100644 --- a/pkg/cluster/operation/telemetry.go +++ b/pkg/cluster/operation/telemetry.go @@ -87,7 +87,7 @@ func GetNodeInfo( dstDir := filepath.Join(dir, "bin") dstPath := filepath.Join(dstDir, path.Base(srcPath)) - err = exec.Transfer(nctx, srcPath, dstPath, false) + err = exec.Transfer(nctx, srcPath, dstPath, false, 0) if err != nil { return err } diff --git a/pkg/cluster/spec/alertmanager.go b/pkg/cluster/spec/alertmanager.go index 2ea74c6ac3..cadf533474 100644 --- a/pkg/cluster/spec/alertmanager.go +++ b/pkg/cluster/spec/alertmanager.go @@ -148,7 +148,7 @@ func (i *AlertManagerInstance) InitConfig( } dst := filepath.Join(paths.Deploy, "scripts", "run_alertmanager.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } if _, _, err := e.Execute(ctx, "chmod +x "+dst, false); err != nil { diff --git a/pkg/cluster/spec/cdc.go b/pkg/cluster/spec/cdc.go index 6c53bb60e3..bb241b07dc 100644 --- a/pkg/cluster/spec/cdc.go +++ b/pkg/cluster/spec/cdc.go @@ -177,7 +177,7 @@ func (i *CDCInstance) InitConfig( return err } dst := filepath.Join(paths.Deploy, "scripts", "run_cdc.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } diff --git a/pkg/cluster/spec/drainer.go b/pkg/cluster/spec/drainer.go index 858b9a1e85..e3ba1e51c5 100644 --- a/pkg/cluster/spec/drainer.go +++ b/pkg/cluster/spec/drainer.go @@ -171,7 +171,7 @@ func (i *DrainerInstance) InitConfig( return err } dst := filepath.Join(paths.Deploy, "scripts", "run_drainer.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } diff --git a/pkg/cluster/spec/grafana.go b/pkg/cluster/spec/grafana.go index 836cf8d42a..3528085077 100644 --- a/pkg/cluster/spec/grafana.go +++ b/pkg/cluster/spec/grafana.go @@ -142,7 +142,7 @@ func (i *GrafanaInstance) InitConfig( } dst := filepath.Join(paths.Deploy, "scripts", "run_grafana.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } @@ -164,7 +164,7 @@ func (i *GrafanaInstance) InitConfig( return err } dst = filepath.Join(paths.Deploy, "conf", "grafana.ini") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } @@ -265,7 +265,7 @@ func (i *GrafanaInstance) installDashboards(ctx context.Context, e ctxt.Executor srcPath := PackagePath(ComponentDMMaster, clusterVersion, i.OS(), i.Arch()) dstPath := filepath.Join(tmp, filepath.Base(srcPath)) - err = e.Transfer(ctx, srcPath, dstPath, false) + err = e.Transfer(ctx, srcPath, dstPath, false, 0) if err != nil { return err } diff --git a/pkg/cluster/spec/instance.go b/pkg/cluster/spec/instance.go index 2866d0f299..5e77b41914 100644 --- a/pkg/cluster/spec/instance.go +++ b/pkg/cluster/spec/instance.go @@ -182,7 +182,7 @@ func (i *BaseInstance) InitConfig(ctx context.Context, e ctxt.Executor, opt Glob return errors.Trace(err) } tgt := filepath.Join("/tmp", comp+"_"+uuid.New().String()+".service") - if err := e.Transfer(ctx, sysCfg, tgt, false); err != nil { + if err := e.Transfer(ctx, sysCfg, tgt, false, 0); err != nil { return errors.Annotatef(err, "transfer from %s to %s failed", sysCfg, tgt) } cmd := fmt.Sprintf("mv %s /etc/systemd/system/%s-%d.service", tgt, comp, port) @@ -203,7 +203,7 @@ func (i *BaseInstance) TransferLocalConfigFile(ctx context.Context, e ctxt.Execu return errors.Annotatef(err, "execute: %s", cmd) } - if err := e.Transfer(ctx, local, remote, false); err != nil { + if err := e.Transfer(ctx, local, remote, false, 0); err != nil { return errors.Annotatef(err, "transfer from %s to %s failed", local, remote) } @@ -255,7 +255,7 @@ func (i *BaseInstance) MergeServerConfig(ctx context.Context, e ctxt.Executor, g } dst := filepath.Join(paths.Deploy, "conf", fmt.Sprintf("%s.toml", i.ComponentName())) // transfer config - return e.Transfer(ctx, fp, dst, false) + return e.Transfer(ctx, fp, dst, false, 0) } // mergeTiFlashLearnerServerConfig merges the server configuration and overwrite the global configuration @@ -271,7 +271,7 @@ func (i *BaseInstance) mergeTiFlashLearnerServerConfig(ctx context.Context, e ct } dst := filepath.Join(paths.Deploy, "conf", fmt.Sprintf("%s-learner.toml", i.ComponentName())) // transfer config - return e.Transfer(ctx, fp, dst, false) + return e.Transfer(ctx, fp, dst, false, 0) } // ID returns the identifier of this instance, the ID is constructed by host:port diff --git a/pkg/cluster/spec/pd.go b/pkg/cluster/spec/pd.go index 25f2429128..cd953cd814 100644 --- a/pkg/cluster/spec/pd.go +++ b/pkg/cluster/spec/pd.go @@ -184,7 +184,7 @@ func (i *PDInstance) InitConfig( return err } dst := filepath.Join(paths.Deploy, "scripts", "run_pd.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } if _, _, err := e.Execute(ctx, "chmod +x "+dst, false); err != nil { @@ -285,7 +285,7 @@ func (i *PDInstance) ScaleConfig( } dst := filepath.Join(paths.Deploy, "scripts", "run_pd.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } if _, _, err := e.Execute(ctx, "chmod +x "+dst, false); err != nil { diff --git a/pkg/cluster/spec/prometheus.go b/pkg/cluster/spec/prometheus.go index 3b6703b11d..258c6c3e9e 100644 --- a/pkg/cluster/spec/prometheus.go +++ b/pkg/cluster/spec/prometheus.go @@ -164,7 +164,7 @@ func (i *MonitorInstance) InitConfig( } dst := filepath.Join(paths.Deploy, "scripts", "run_prometheus.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } @@ -306,7 +306,7 @@ func (i *MonitorInstance) InitConfig( return err } dst = filepath.Join(paths.Deploy, "conf", "prometheus.yml") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } @@ -332,7 +332,7 @@ func (i *MonitorInstance) installRules(ctx context.Context, e ctxt.Executor, dep srcPath := PackagePath(ComponentDMMaster, clusterVersion, i.OS(), i.Arch()) dstPath := filepath.Join(tmp, filepath.Base(srcPath)) - err = e.Transfer(ctx, srcPath, dstPath, false) + err = e.Transfer(ctx, srcPath, dstPath, false, 0) if err != nil { return err } diff --git a/pkg/cluster/spec/pump.go b/pkg/cluster/spec/pump.go index de87fc6d29..9bae4a5eb6 100644 --- a/pkg/cluster/spec/pump.go +++ b/pkg/cluster/spec/pump.go @@ -166,7 +166,7 @@ func (i *PumpInstance) InitConfig( return err } dst := filepath.Join(paths.Deploy, "scripts", "run_pump.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } diff --git a/pkg/cluster/spec/tidb.go b/pkg/cluster/spec/tidb.go index d0b08a9383..5235e17196 100644 --- a/pkg/cluster/spec/tidb.go +++ b/pkg/cluster/spec/tidb.go @@ -145,7 +145,7 @@ func (i *TiDBInstance) InitConfig( } dst := filepath.Join(paths.Deploy, "scripts", "run_tidb.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } if _, _, err := e.Execute(ctx, "chmod +x "+dst, false); err != nil { diff --git a/pkg/cluster/spec/tiflash.go b/pkg/cluster/spec/tiflash.go index 909abc24e9..f18fad74a4 100644 --- a/pkg/cluster/spec/tiflash.go +++ b/pkg/cluster/spec/tiflash.go @@ -544,7 +544,7 @@ func (i *TiFlashInstance) InitConfig( } dst := filepath.Join(paths.Deploy, "scripts", "run_tiflash.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } diff --git a/pkg/cluster/spec/tikv.go b/pkg/cluster/spec/tikv.go index 5e9c00fb2c..4521fdbb50 100644 --- a/pkg/cluster/spec/tikv.go +++ b/pkg/cluster/spec/tikv.go @@ -219,7 +219,7 @@ func (i *TiKVInstance) InitConfig( } dst := filepath.Join(paths.Deploy, "scripts", "run_tikv.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } diff --git a/pkg/cluster/spec/tispark.go b/pkg/cluster/spec/tispark.go index 4ebcfac882..025983d9c2 100644 --- a/pkg/cluster/spec/tispark.go +++ b/pkg/cluster/spec/tispark.go @@ -211,7 +211,7 @@ func (i *TiSparkMasterInstance) InitConfig( return errors.Trace(err) } tgt := filepath.Join("/tmp", comp+"_"+uuid.New().String()+".service") - if err := e.Transfer(ctx, sysCfg, tgt, false); err != nil { + if err := e.Transfer(ctx, sysCfg, tgt, false, 0); err != nil { return errors.Annotatef(err, "transfer from %s to %s failed", sysCfg, tgt) } cmd := fmt.Sprintf("mv %s /etc/systemd/system/%s-%d.service", tgt, comp, port) @@ -237,7 +237,7 @@ func (i *TiSparkMasterInstance) InitConfig( return err } dst := filepath.Join(paths.Deploy, "conf", "spark-defaults.conf") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } @@ -253,7 +253,7 @@ func (i *TiSparkMasterInstance) InitConfig( } // tispark files are all in a "spark" sub-directory of deploy dir dst = filepath.Join(paths.Deploy, "conf", "spark-env.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } @@ -267,7 +267,7 @@ func (i *TiSparkMasterInstance) InitConfig( return err } dst = filepath.Join(paths.Deploy, "conf", "log4j.properties") - return e.Transfer(ctx, fp, dst, false) + return e.Transfer(ctx, fp, dst, false, 0) } // ScaleConfig deploy temporary config on scaling @@ -374,7 +374,7 @@ func (i *TiSparkWorkerInstance) InitConfig( return errors.Trace(err) } tgt := filepath.Join("/tmp", comp+"_"+uuid.New().String()+".service") - if err := e.Transfer(ctx, sysCfg, tgt, false); err != nil { + if err := e.Transfer(ctx, sysCfg, tgt, false, 0); err != nil { return errors.Annotatef(err, "transfer from %s to %s failed", sysCfg, tgt) } cmd := fmt.Sprintf("mv %s /etc/systemd/system/%s-%d.service", tgt, comp, port) @@ -400,7 +400,7 @@ func (i *TiSparkWorkerInstance) InitConfig( return err } dst := filepath.Join(paths.Deploy, "conf", "spark-defaults.conf") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } @@ -417,7 +417,7 @@ func (i *TiSparkWorkerInstance) InitConfig( } // tispark files are all in a "spark" sub-directory of deploy dir dst = filepath.Join(paths.Deploy, "conf", "spark-env.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } @@ -431,7 +431,7 @@ func (i *TiSparkWorkerInstance) InitConfig( return err } dst = filepath.Join(paths.Deploy, "sbin", "start-slave.sh") - if err := e.Transfer(ctx, fp, dst, false); err != nil { + if err := e.Transfer(ctx, fp, dst, false, 0); err != nil { return err } @@ -445,7 +445,7 @@ func (i *TiSparkWorkerInstance) InitConfig( return err } dst = filepath.Join(paths.Deploy, "conf", "log4j.properties") - return e.Transfer(ctx, fp, dst, false) + return e.Transfer(ctx, fp, dst, false, 0) } // ScaleConfig deploy temporary config on scaling diff --git a/pkg/cluster/task/builder.go b/pkg/cluster/task/builder.go index a8bdb5006a..37c88022a5 100644 --- a/pkg/cluster/task/builder.go +++ b/pkg/cluster/task/builder.go @@ -130,12 +130,13 @@ func (b *Builder) UpdateTopology(cluster, profile string, metadata *spec.Cluster } // CopyFile appends a CopyFile task to the current task collection -func (b *Builder) CopyFile(src, dst, server string, download bool) *Builder { +func (b *Builder) CopyFile(src, dst, server string, download bool, limit int) *Builder { b.tasks = append(b.tasks, &CopyFile{ src: src, dst: dst, remote: server, download: download, + limit: limit, }) return b } diff --git a/pkg/cluster/task/copy_file.go b/pkg/cluster/task/copy_file.go index 8da709c0e9..dc85db8a2f 100644 --- a/pkg/cluster/task/copy_file.go +++ b/pkg/cluster/task/copy_file.go @@ -27,6 +27,7 @@ type CopyFile struct { dst string remote string download bool + limit int } // Execute implements the Task interface @@ -36,7 +37,7 @@ func (c *CopyFile) Execute(ctx context.Context) error { return ErrNoExecutor } - err := e.Transfer(ctx, c.src, c.dst, c.download) + err := e.Transfer(ctx, c.src, c.dst, c.download, c.limit) if err != nil { return errors.Annotate(err, "failed to transfer file") } diff --git a/pkg/cluster/task/init_config_test.go b/pkg/cluster/task/init_config_test.go index 705538a680..b33bd9923b 100644 --- a/pkg/cluster/task/init_config_test.go +++ b/pkg/cluster/task/init_config_test.go @@ -37,7 +37,7 @@ func (e *fakeExecutor) Execute(ctx context.Context, cmd string, sudo bool, timeo return []byte{}, []byte{}, nil } -func (e *fakeExecutor) Transfer(ctx context.Context, src string, dst string, download bool) error { +func (e *fakeExecutor) Transfer(ctx context.Context, src, dst string, download bool, limit int) error { return nil } diff --git a/pkg/cluster/task/install_package.go b/pkg/cluster/task/install_package.go index 58118a9dc6..9c14d67e4a 100644 --- a/pkg/cluster/task/install_package.go +++ b/pkg/cluster/task/install_package.go @@ -42,7 +42,7 @@ func (c *InstallPackage) Execute(ctx context.Context) error { dstDir := filepath.Join(c.dstDir, "bin") dstPath := filepath.Join(dstDir, path.Base(c.srcPath)) - err := exec.Transfer(ctx, c.srcPath, dstPath, false) + err := exec.Transfer(ctx, c.srcPath, dstPath, false, 0) if err != nil { return errors.Annotatef(err, "failed to scp %s to %s:%s", c.srcPath, c.host, dstPath) } diff --git a/pkg/cluster/task/monitored_config.go b/pkg/cluster/task/monitored_config.go index d6d0974595..006c593f55 100644 --- a/pkg/cluster/task/monitored_config.go +++ b/pkg/cluster/task/monitored_config.go @@ -114,7 +114,7 @@ func (m *MonitoredConfig) syncMonitoredSystemConfig(ctx context.Context, exec ct return err } tgt := filepath.Join("/tmp", comp+"_"+uuid.New().String()+".service") - if err := exec.Transfer(ctx, sysCfg, tgt, false); err != nil { + if err := exec.Transfer(ctx, sysCfg, tgt, false, 0); err != nil { return err } if outp, errp, err := exec.Execute(ctx, fmt.Sprintf("mv %s /etc/systemd/system/%s-%d.service", tgt, comp, port), true); err != nil { @@ -135,7 +135,7 @@ func (m *MonitoredConfig) syncMonitoredScript(ctx context.Context, exec ctxt.Exe return err } dst := filepath.Join(m.paths.Deploy, "scripts", fmt.Sprintf("run_%s.sh", comp)) - if err := exec.Transfer(ctx, fp, dst, false); err != nil { + if err := exec.Transfer(ctx, fp, dst, false, 0); err != nil { return err } if _, _, err := exec.Execute(ctx, "chmod +x "+dst, false); err != nil { @@ -151,7 +151,7 @@ func (m *MonitoredConfig) syncBlackboxConfig(ctx context.Context, exec ctxt.Exec return err } dst := filepath.Join(m.paths.Deploy, "conf", "blackbox.yml") - return exec.Transfer(ctx, fp, dst, false) + return exec.Transfer(ctx, fp, dst, false, 0) } // Rollback implements the Task interface diff --git a/pkg/cluster/task/tls.go b/pkg/cluster/task/tls.go index 09ba4fea3c..4148946c6f 100644 --- a/pkg/cluster/task/tls.go +++ b/pkg/cluster/task/tls.go @@ -91,17 +91,20 @@ func (c *TLSCert) Execute(ctx context.Context) error { } if err := e.Transfer(ctx, caFile, filepath.Join(c.paths.Deploy, "tls", spec.TLSCACert), - false /* download */); err != nil { + false, /* download */ + 0 /* limit */); err != nil { return errors.Annotate(err, "failed to transfer CA cert to server") } if err := e.Transfer(ctx, keyFile, filepath.Join(c.paths.Deploy, "tls", fmt.Sprintf("%s.pem", c.inst.Role())), - false /* download */); err != nil { + false, /* download */ + 0 /* limit */); err != nil { return errors.Annotate(err, "failed to transfer TLS private key to server") } if err := e.Transfer(ctx, certFile, filepath.Join(c.paths.Deploy, "tls", fmt.Sprintf("%s.crt", c.inst.Role())), - false /* download */); err != nil { + false, /* download */ + 0 /* limit */); err != nil { return errors.Annotate(err, "failed to transfer TLS cert to server") } From bb5d8ffc106b20c2efd23c5643992174b2957d73 Mon Sep 17 00:00:00 2001 From: Allen Zhong Date: Tue, 25 May 2021 17:02:30 +0800 Subject: [PATCH 3/3] executor/ssh: adjust as per PR comments --- pkg/cluster/executor/checkpoint.go | 1 - pkg/cluster/executor/scp.go | 23 ++++++++++------------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/pkg/cluster/executor/checkpoint.go b/pkg/cluster/executor/checkpoint.go index 36d22c45b2..4fa5d828a5 100644 --- a/pkg/cluster/executor/checkpoint.go +++ b/pkg/cluster/executor/checkpoint.go @@ -46,7 +46,6 @@ var ( checkpoint.Field("src", reflect.DeepEqual), checkpoint.Field("dst", reflect.DeepEqual), checkpoint.Field("download", reflect.DeepEqual), - checkpoint.Field("limit", reflect.DeepEqual), ) ) diff --git a/pkg/cluster/executor/scp.go b/pkg/cluster/executor/scp.go index 590c3299b4..84e1f209ac 100644 --- a/pkg/cluster/executor/scp.go +++ b/pkg/cluster/executor/scp.go @@ -30,16 +30,6 @@ import ( // ScpDownload downloads a file from remote with SCP // The implementation is partially inspired by github.com/dtylman/scp func ScpDownload(session *ssh.Session, client *ssh.Client, src, dst string, limit int) error { - // prepare dst file - targetPath := filepath.Dir(dst) - if err := utils.CreateDir(targetPath); err != nil { - return err - } - targetFile, err := os.Create(dst) - if err != nil { - return err - } - r, err := session.StdoutPipe() if err != nil { return err @@ -65,9 +55,17 @@ func ScpDownload(session *ssh.Session, client *ssh.Client, src, dst string, limi if err != nil { return fmt.Errorf("error parsing file mode; %s", err) } - if err := targetFile.Chmod(fs.FileMode(mode)); err != nil { - return fmt.Errorf("error setting file mode; %s", err) + + // prepare dst file + targetPath := filepath.Dir(dst) + if err := utils.CreateDir(targetPath); err != nil { + return err + } + targetFile, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fs.FileMode(mode)) + if err != nil { + return err } + defer targetFile.Close() size, err := strconv.Atoi(strings.Fields(string(line))[1]) if err != nil { @@ -96,7 +94,6 @@ func ScpDownload(session *ssh.Session, client *ssh.Client, src, dst string, limi copyErrC := make(chan error, 1) go func() { defer w.Close() - defer targetFile.Close() copyErrC <- copyF() }()