Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: sftp retry on connection lost #465

Merged
merged 18 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions sftp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func sshClientConfig(config *SSHConfig) (*ssh.ClientConfig, error) {
return sshConfig, nil
}

// NewSSHClient establishes an SSH connection and returns an SSH client
func NewSSHClient(config *SSHConfig) (*ssh.Client, error) {
// newSSHClient establishes an SSH connection and returns an SSH client
func newSSHClient(config *SSHConfig) (*ssh.Client, error) {
sshConfig, err := sshClientConfig(config)
if err != nil {
return nil, fmt.Errorf("cannot configure SSH client: %w", err)
Expand All @@ -80,34 +80,58 @@ func NewSSHClient(config *SSHConfig) (*ssh.Client, error) {
}

type clientImpl struct {
client *sftp.Client
sftpClient *sftp.Client
config *SSHConfig
}

type Client interface {
fracasula marked this conversation as resolved.
Show resolved Hide resolved
OpenFile(path string, f int) (io.ReadWriteCloser, error)
Remove(path string) error
MkdirAll(path string) error
Reset() error
}

// newSFTPClient creates an SFTP client with existing SSH client
func newSFTPClient(client *ssh.Client) (Client, error) {
sftpClient, err := sftp.NewClient(client)
func newSFTPClient(client *ssh.Client) (*sftp.Client, error) {
return sftp.NewClient(client)
}

func newSFTPClientFromConfig(config *SSHConfig) (*sftp.Client, error) {
sshClient, err := newSSHClient(config)
if err != nil {
return nil, fmt.Errorf("creating SSH client: %w", err)
}
return newSFTPClient(sshClient)
}

func newClient(config *SSHConfig) (Client, error) {
sftpClient, err := newSFTPClientFromConfig(config)
if err != nil {
return nil, fmt.Errorf("cannot create SFTP client: %w", err)
return nil, fmt.Errorf("creating SFTP client: %w", err)
}
return &clientImpl{
client: sftpClient,
sftpClient: sftpClient,
config: config,
}, nil
}

func (c *clientImpl) OpenFile(path string, f int) (io.ReadWriteCloser, error) {
return c.client.OpenFile(path, f)
return c.sftpClient.OpenFile(path, f)
}

func (c *clientImpl) Remove(path string) error {
return c.client.Remove(path)
return c.sftpClient.Remove(path)
}

func (c *clientImpl) MkdirAll(path string) error {
return c.client.MkdirAll(path)
return c.sftpClient.MkdirAll(path)
}

func (c *clientImpl) Reset() error {
newSFTPClient, err := newSFTPClientFromConfig(c.config)
if err != nil {
return err
}
c.sftpClient = newSFTPClient
return nil
}
14 changes: 14 additions & 0 deletions sftp/mock_sftp/mock_sftp_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

87 changes: 77 additions & 10 deletions sftp/sftp.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package sftp

import (
"errors"
"fmt"
"io"
"os"
"path/filepath"

"golang.org/x/crypto/ssh"
"github.com/pkg/sftp"
)

const (
Expand All @@ -23,21 +24,33 @@ type FileManager interface {
Delete(remoteFilePath string) error
}

type Option func(impl *fileManagerImpl)

// WithRetryOnIdleConnection enables retrying the operation once in case of a "connection lost" error due to an idle connection.
func WithRetryOnIdleConnection() Option {
return func(impl *fileManagerImpl) {
impl.retryOnIdleConnection = true
}
}

// fileManagerImpl is a real implementation of FileManager
type fileManagerImpl struct {
client Client
client Client
retryOnIdleConnection bool
}

func NewFileManager(sshClient *ssh.Client) (FileManager, error) {
sftpClient, err := newSFTPClient(sshClient)
if err != nil {
return nil, fmt.Errorf("cannot create SFTP client: %w", err)
// Upload uploads a file to the remote server
func (fm *fileManagerImpl) Upload(localFilePath, remoteFilePath string) error {
if fm.retryOnIdleConnection {
return fm.retryOnConnectionLost(func() error {
return fm.upload(localFilePath, remoteFilePath)
})
}
return &fileManagerImpl{client: sftpClient}, nil

return fm.upload(localFilePath, remoteFilePath)
}

// Upload uploads a file to the remote server
func (fm *fileManagerImpl) Upload(localFilePath, remoteFilePath string) error {
func (fm *fileManagerImpl) upload(localFilePath, remoteFilePath string) error {
localFile, err := os.Open(localFilePath)
if err != nil {
return fmt.Errorf("cannot open local file: %w", err)
Expand Down Expand Up @@ -70,6 +83,16 @@ func (fm *fileManagerImpl) Upload(localFilePath, remoteFilePath string) error {

// Download downloads a file from the remote server
func (fm *fileManagerImpl) Download(remoteFilePath, localDir string) error {
if fm.retryOnIdleConnection {
return fm.retryOnConnectionLost(func() error {
return fm.download(remoteFilePath, localDir)
})
}

return fm.download(remoteFilePath, localDir)
}

func (fm *fileManagerImpl) download(remoteFilePath, localDir string) error {
remoteFile, err := fm.client.OpenFile(remoteFilePath, os.O_RDONLY)
if err != nil {
return fmt.Errorf("cannot open remote file: %w", err)
Expand Down Expand Up @@ -97,10 +120,54 @@ func (fm *fileManagerImpl) Download(remoteFilePath, localDir string) error {

// Delete deletes a file on the remote server
func (fm *fileManagerImpl) Delete(remoteFilePath string) error {
if fm.retryOnIdleConnection {
return fm.retryOnConnectionLost(func() error {
return fm.delete(remoteFilePath)
})
}

return fm.delete(remoteFilePath)
}

func (fm *fileManagerImpl) delete(remoteFilePath string) error {
err := fm.client.Remove(remoteFilePath)
if err != nil {
return fmt.Errorf("cannot delete file: %w", err)
}

return nil
}

func (fm *fileManagerImpl) reset() error {
return fm.client.Reset()
}

// NewFileManager is not concurrent safe. It should not be used from multiple goroutines concurrently without additional synchronization.
func NewFileManager(config *SSHConfig, opts ...Option) (FileManager, error) {
sftpClient, err := newClient(config)
if err != nil {
return nil, err
}
fm := &fileManagerImpl{client: sftpClient}
for _, opt := range opts {
opt(fm)
}
return fm, nil
}

func (fm *fileManagerImpl) retryOnConnectionLost(fileOperation func() error) error {
err := fileOperation()
if err == nil || !isConnectionLostError(err) {
return err // Operation successful or non-retryable error
}

if err := fm.reset(); err != nil {
return err
}

// Retry the operation
return fileOperation()
}

func isConnectionLostError(err error) bool {
return errors.Is(err, sftp.ErrSshFxConnectionLost)
}
57 changes: 43 additions & 14 deletions sftp/sftp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/golang/mock/gomock"
"github.com/ory/dockertest/v3"
"github.com/pkg/sftp"
"github.com/stretchr/testify/require"

"github.com/rudderlabs/rudder-go-kit/sftp/mock_sftp"
Expand Down Expand Up @@ -115,7 +116,7 @@ func TestSSHClientConfig(t *testing.T) {
}
}

func TestUpload(t *testing.T) {
func TestUploadWithRetry(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

Expand All @@ -138,16 +139,24 @@ func TestUpload(t *testing.T) {

mockSFTPClient := mock_sftp.NewMockClient(ctrl)
mockSFTPClient.EXPECT().OpenFile(gomock.Any(), gomock.Any()).Return(&nopReadWriteCloser{remoteBuf}, nil)
mockSFTPClient.EXPECT().MkdirAll(gomock.Any()).Return(nil)

fileManager := &fileManagerImpl{client: mockSFTPClient}
mockSFTPClient.EXPECT().Reset().Return(nil)
callCounter := 0
mockSFTPClient.EXPECT().MkdirAll(gomock.Any()).DoAndReturn(func(_ interface{}) error {
callCounter++
if callCounter == 1 {
return sftp.ErrSshFxConnectionLost
}
return nil
}).Times(2)

fileManager := &fileManagerImpl{client: mockSFTPClient, retryOnIdleConnection: true}

err = fileManager.Upload(localFilePath, "someRemotePath")
require.NoError(t, err)
require.Equal(t, data, remoteBuf.Bytes())
}

func TestDownload(t *testing.T) {
func TestDownloadWithRetry(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

Expand All @@ -162,9 +171,17 @@ func TestDownload(t *testing.T) {
remoteBuf := bytes.NewBuffer(data)

mockSFTPClient := mock_sftp.NewMockClient(ctrl)
mockSFTPClient.EXPECT().OpenFile(gomock.Any(), gomock.Any()).Return(&nopReadWriteCloser{remoteBuf}, nil)

fileManager := &fileManagerImpl{client: mockSFTPClient}
callCounter := 0
mockSFTPClient.EXPECT().OpenFile(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}) (io.ReadWriteCloser, error) {
callCounter++
if callCounter == 1 {
return nil, sftp.ErrSSHFxConnectionLost
}
return &nopReadWriteCloser{remoteBuf}, nil
}).Times(2)
mockSFTPClient.EXPECT().Reset().Return(nil)

fileManager := &fileManagerImpl{client: mockSFTPClient, retryOnIdleConnection: true}

err = fileManager.Download(filepath.Join("someRemoteDir", "test_file.json"), localDir)
require.NoError(t, err)
Expand All @@ -173,15 +190,24 @@ func TestDownload(t *testing.T) {
require.Equal(t, data, localFileContents)
}

func TestDelete(t *testing.T) {
func TestDeleteWithRetry(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

remoteFilePath := "someRemoteFilePath"
mockSFTPClient := mock_sftp.NewMockClient(ctrl)
mockSFTPClient.EXPECT().Remove(remoteFilePath).Return(nil)
callCounter := 0
mockSFTPClient.EXPECT().Remove(gomock.Any()).DoAndReturn(func(_ interface{}) error {
callCounter++
if callCounter == 1 {
return sftp.ErrSSHFxConnectionLost
}
return nil
}).Times(2)

fileManager := &fileManagerImpl{client: mockSFTPClient}
mockSFTPClient.EXPECT().Reset().Return(nil)

fileManager := &fileManagerImpl{client: mockSFTPClient, retryOnIdleConnection: true}

err := fileManager.Delete(remoteFilePath)
require.NoError(t, err)
Expand Down Expand Up @@ -211,14 +237,15 @@ func TestSFTP(t *testing.T) {
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
sshClient, err := NewSSHClient(&SSHConfig{
sshConfig := &SSHConfig{
User: "linuxserver.io",
HostName: hostname,
Port: port,
AuthMethod: "keyAuth",
PrivateKey: string(privateKey),
DialTimeout: 10 * time.Second,
})
}
sshClient, err := newSSHClient(sshConfig)
require.NoError(t, err)

// Create session
Expand All @@ -230,9 +257,11 @@ func TestSFTP(t *testing.T) {
err = session.Run(fmt.Sprintf("mkdir -p %s", remoteDir))
require.NoError(t, err)

sftpManger, err := NewFileManager(sshClient)
sftpClient, err := newSFTPClient(sshClient)
require.NoError(t, err)

sftpManger := &fileManagerImpl{client: &clientImpl{sftpClient: sftpClient}}

// Create local and remote directories within the temporary directory
baseDir := t.TempDir()
localDir := filepath.Join(baseDir, "local")
Expand Down
Loading