diff --git a/pkg/cluster/executor/ssh.go b/pkg/cluster/executor/ssh.go index 84626458d2..e885b82939 100644 --- a/pkg/cluster/executor/ssh.go +++ b/pkg/cluster/executor/ssh.go @@ -96,7 +96,7 @@ type ( var _ ctxt.Executor = &EasySSHExecutor{} var _ ctxt.Executor = &NativeSSHExecutor{} -// Initialize builds and initializes a EasySSHExecutor +// initialize builds and initializes a EasySSHExecutor func (e *EasySSHExecutor) initialize(config SSHConfig) { // build easyssh config e.Config = &easyssh.MakeConfig{ @@ -207,7 +207,14 @@ func (e *NativeSSHExecutor) prompt(def string) string { return def } -func (e *NativeSSHExecutor) configArgs(args []string) []string { +func (e *NativeSSHExecutor) configArgs(args []string, isScp bool) []string { + if e.Config.Port != 0 && e.Config.Port != 22 { + if isScp { + args = append(args, "-P", strconv.Itoa(e.Config.Port)) + } else { + args = append(args, "-p", strconv.Itoa(e.Config.Port)) + } + } if e.Config.Timeout != 0 { args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", int64(e.Config.Timeout.Seconds()))) } @@ -263,7 +270,7 @@ func (e *NativeSSHExecutor) Execute(ctx context.Context, cmd string, sudo bool, args := []string{ssh, "-o", "StrictHostKeyChecking=no"} - args = e.configArgs(args) // prefix and postfix args + args = e.configArgs(args, false) // prefix and postfix args args = append(args, fmt.Sprintf("%s@%s", e.Config.User, e.Config.Host), cmd) command := exec.CommandContext(ctx, args[0], args[1:]...) @@ -326,7 +333,7 @@ func (e *NativeSSHExecutor) Transfer(ctx context.Context, src, dst string, downl if limit > 0 { args = append(args, "-l", fmt.Sprint(limit)) } - args = e.configArgs(args) // prefix and postfix args + args = e.configArgs(args, true) // prefix and postfix args if download { targetPath := filepath.Dir(dst) diff --git a/pkg/cluster/executor/ssh_test.go b/pkg/cluster/executor/ssh_test.go new file mode 100644 index 0000000000..ff8f38a3b3 --- /dev/null +++ b/pkg/cluster/executor/ssh_test.go @@ -0,0 +1,90 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNativeSSHConfigArgs(t *testing.T) { + testcases := []struct { + c *SSHConfig + s bool + e string + }{ + { + &SSHConfig{ + KeyFile: "id_rsa", + }, + false, + "-i id_rsa", + }, + { + &SSHConfig{ + Timeout: 60 * time.Second, + Port: 23, + KeyFile: "id_rsa", + }, + false, + "-p 23 -o ConnectTimeout=60 -i id_rsa", + }, + { + &SSHConfig{ + Timeout: 60 * time.Second, + Port: 23, + KeyFile: "id_rsa", + }, + true, + "-P 23 -o ConnectTimeout=60 -i id_rsa", + }, + { + &SSHConfig{ + Timeout: 60 * time.Second, + KeyFile: "id_rsa", + Port: 23, + Passphrase: "tidb", + }, + false, + "sshpass -p tidb -P passphrase -p 23 -o ConnectTimeout=60 -i id_rsa", + }, + { + &SSHConfig{ + Timeout: 60 * time.Second, + KeyFile: "id_rsa", + Port: 23, + Passphrase: "tidb", + }, + true, + "sshpass -p tidb -P passphrase -P 23 -o ConnectTimeout=60 -i id_rsa", + }, + { + &SSHConfig{ + Timeout: 60 * time.Second, + Password: "tidb", + }, + true, + "sshpass -p tidb -P password -o ConnectTimeout=60", + }, + } + + e := &NativeSSHExecutor{} + for _, tc := range testcases { + e.Config = tc.c + assert.Equal(t, tc.e, strings.Join(e.configArgs([]string{}, tc.s), " ")) + } +}