Skip to content

Commit

Permalink
fix: sftp retry on connection lost
Browse files Browse the repository at this point in the history
  • Loading branch information
Gauravudia committed May 6, 2024
1 parent af06f04 commit 8a8a9ff
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 55 deletions.
4 changes: 2 additions & 2 deletions sftp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func sshClientConfig(config *SSHConfig) (*ssh.ClientConfig, error) {
return sshConfig, nil
}

// NewSSHClient establishes an SSH connection and returns an SSH client
func NewSSHClient(config *SSHConfig) (*ssh.Client, error) {
// newSSHClient establishes an SSH connection and returns an SSH client
func newSSHClient(config *SSHConfig) (*ssh.Client, error) {
sshConfig, err := sshClientConfig(config)
if err != nil {
return nil, fmt.Errorf("cannot configure SSH client: %w", err)
Expand Down
160 changes: 107 additions & 53 deletions sftp/sftp.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package sftp

import (
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"

"golang.org/x/crypto/ssh"
)
Expand All @@ -25,82 +27,134 @@ type FileManager interface {

// fileManagerImpl is a real implementation of FileManager
type fileManagerImpl struct {
client Client
}

func NewFileManager(sshClient *ssh.Client) (FileManager, error) {
sftpClient, err := newSFTPClient(sshClient)
if err != nil {
return nil, fmt.Errorf("cannot create SFTP client: %w", err)
}
return &fileManagerImpl{client: sftpClient}, nil
client Client
config *SSHConfig
sshClient *ssh.Client
}

// Upload uploads a file to the remote server
func (fm *fileManagerImpl) Upload(localFilePath, remoteFilePath string) error {
localFile, err := os.Open(localFilePath)
if err != nil {
return fmt.Errorf("cannot open local file: %w", err)
}
defer func() {
_ = localFile.Close()
}()

// Create the directory if it does not exist
remoteDir := filepath.Dir(remoteFilePath)
if err := fm.client.MkdirAll(remoteDir); err != nil {
return fmt.Errorf("cannot create remote directory: %w", err)
fileOperation := func() error {
localFile, err := os.Open(localFilePath)
if err != nil {
return fmt.Errorf("cannot open local file: %w", err)
}
defer func() {
_ = localFile.Close()
}()

// Create the directory if it does not exist
remoteDir := filepath.Dir(remoteFilePath)
if err := fm.client.MkdirAll(remoteDir); err != nil {
return fmt.Errorf("cannot create remote directory: %w", err)
}

remoteFile, err := fm.client.OpenFile(remoteFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC)
if err != nil {
return fmt.Errorf("cannot create remote file: %w", err)
}
defer func() {
_ = remoteFile.Close()
}()

_, err = io.Copy(remoteFile, localFile)
if err != nil {
return fmt.Errorf("error copying file: %w", err)
}

return nil
}

remoteFile, err := fm.client.OpenFile(remoteFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC)
if err != nil {
return fmt.Errorf("cannot create remote file: %w", err)
return fm.retryOnConnectionLost(fileOperation)
}

// Download downloads a file from the remote server
func (fm *fileManagerImpl) Download(remoteFilePath, localDir string) error {
fileOperation := func() error {
remoteFile, err := fm.client.OpenFile(remoteFilePath, os.O_RDONLY)
if err != nil {
return fmt.Errorf("cannot open remote file: %w", err)
}
defer func() {
_ = remoteFile.Close()
}()

localFileName := filepath.Join(localDir, filepath.Base(remoteFilePath))
localFile, err := os.Create(localFileName)
if err != nil {
return fmt.Errorf("cannot create local file: %w", err)
}
defer func() {
_ = localFile.Close()
}()

_, err = io.Copy(localFile, remoteFile)
if err != nil {
return fmt.Errorf("cannot copy remote file content to local file: %w", err)
}

return nil
}
defer func() {
_ = remoteFile.Close()
}()
return fm.retryOnConnectionLost(fileOperation)

_, err = io.Copy(remoteFile, localFile)
if err != nil {
return fmt.Errorf("error copying file: %w", err)
}

// Delete deletes a file on the remote server
func (fm *fileManagerImpl) Delete(remoteFilePath string) error {
fileOperation := func() error {
err := fm.client.Remove(remoteFilePath)
if err != nil {
return fmt.Errorf("cannot delete file: %w", err)
}

return nil
}

return nil
return fm.retryOnConnectionLost(fileOperation)

}

// Download downloads a file from the remote server
func (fm *fileManagerImpl) Download(remoteFilePath, localDir string) error {
remoteFile, err := fm.client.OpenFile(remoteFilePath, os.O_RDONLY)
func NewFileManager(config *SSHConfig) (FileManager, error) {
sshClient, err := newSSHClient(config)
if err != nil {
return fmt.Errorf("cannot open remote file: %w", err)
return nil, fmt.Errorf("creating SSH client: %w", err)
}
defer func() {
_ = remoteFile.Close()
}()

localFileName := filepath.Join(localDir, filepath.Base(remoteFilePath))
localFile, err := os.Create(localFileName)
sftpClient, err := newSFTPClient(sshClient)
if err != nil {
return fmt.Errorf("cannot create local file: %w", err)
return nil, fmt.Errorf("cannot create SFTP client: %w", err)
}
return &fileManagerImpl{client: sftpClient, config: config, sshClient: sshClient}, nil
}

func isConnectionLostError(err error) bool {
// Implement the logic to check if the error indicates a "connection lost" condition
return strings.Contains(err.Error(), "connection lost")
}

func (fm *fileManagerImpl) retryOnConnectionLost(fileOperation func() error) error {
err := fileOperation()
if err == nil || !isConnectionLostError(err) {
return err // Operation successful or non-retryable error
}
defer func() {
_ = localFile.Close()
}()

_, err = io.Copy(localFile, remoteFile)
if err != nil {
return fmt.Errorf("cannot copy remote file content to local file: %w", err)
if err := fm.recreateSFTPClient(); err != nil {
return err // Error recreating the SFTP client
}

return nil
// Retry the operation
return fileOperation()
}

// Delete deletes a file on the remote server
func (fm *fileManagerImpl) Delete(remoteFilePath string) error {
err := fm.client.Remove(remoteFilePath)
func (fm *fileManagerImpl) recreateSFTPClient() error {
newFileManager, err := NewFileManager(fm.config)
if err != nil {
return fmt.Errorf("cannot delete file: %w", err)
return err // Error recreating the SFTP client
}
newFM, ok := newFileManager.(*fileManagerImpl)
if !ok {
return errors.New("error while typecasting")
}

fm.client = newFM.client // Update the SFTP client
return nil
}

0 comments on commit 8a8a9ff

Please sign in to comment.