-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Dilip Kola <kdilipkola@gmail.com> Co-authored-by: Akash Chetty <achetty.iitr@gmail.com>
- Loading branch information
1 parent
c788d93
commit f0b67e9
Showing
7 changed files
with
607 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
//go:generate mockgen -destination=mock_sftp/mock_sftp_client.go -package mock_sftp github.com/rudderlabs/rudder-go-kit/sftp Client | ||
package sftp | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"io" | ||
"time" | ||
|
||
"github.com/pkg/sftp" | ||
"golang.org/x/crypto/ssh" | ||
) | ||
|
||
// SSHConfig represents the configuration for SSH connection | ||
type SSHConfig struct { | ||
HostName string | ||
Port int | ||
User string | ||
AuthMethod string | ||
PrivateKey string | ||
Password string // Password for password-based authentication | ||
DialTimeout time.Duration | ||
} | ||
|
||
// sshClientConfig constructs an SSH client configuration based on the provided SSHConfig. | ||
func sshClientConfig(config *SSHConfig) (*ssh.ClientConfig, error) { | ||
if config == nil { | ||
return nil, errors.New("config should not be nil") | ||
} | ||
|
||
if config.HostName == "" { | ||
return nil, errors.New("hostname should not be empty") | ||
} | ||
|
||
if config.Port == 0 { | ||
return nil, errors.New("port should not be empty") | ||
} | ||
|
||
if config.User == "" { | ||
return nil, errors.New("user should not be empty") | ||
} | ||
|
||
var authMethods ssh.AuthMethod | ||
|
||
switch config.AuthMethod { | ||
case PasswordAuth: | ||
authMethods = ssh.Password(config.Password) | ||
case KeyAuth: | ||
privateKey, err := ssh.ParsePrivateKey([]byte(config.PrivateKey)) | ||
if err != nil { | ||
return nil, fmt.Errorf("cannot parse private key: %w", err) | ||
} | ||
authMethods = ssh.PublicKeys(privateKey) | ||
default: | ||
return nil, errors.New("unsupported authentication method") | ||
} | ||
|
||
sshConfig := &ssh.ClientConfig{ | ||
User: config.User, | ||
Auth: []ssh.AuthMethod{authMethods}, | ||
Timeout: config.DialTimeout, | ||
HostKeyCallback: ssh.InsecureIgnoreHostKey(), | ||
} | ||
|
||
return sshConfig, nil | ||
} | ||
|
||
// NewSSHClient establishes an SSH connection and returns an SSH client | ||
func NewSSHClient(config *SSHConfig) (*ssh.Client, error) { | ||
sshConfig, err := sshClientConfig(config) | ||
if err != nil { | ||
return nil, fmt.Errorf("cannot configure SSH client: %w", err) | ||
} | ||
|
||
sshClient, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", config.HostName, config.Port), sshConfig) | ||
if err != nil { | ||
return nil, fmt.Errorf("cannot dial SSH host %q:%d: %w", config.HostName, config.Port, err) | ||
} | ||
return sshClient, nil | ||
} | ||
|
||
type clientImpl struct { | ||
client *sftp.Client | ||
} | ||
|
||
type Client interface { | ||
Create(path string) (io.WriteCloser, error) | ||
Open(path string) (io.ReadCloser, error) | ||
Remove(path string) error | ||
} | ||
|
||
// newSFTPClient creates an SFTP client with existing SSH client | ||
func newSFTPClient(client *ssh.Client) (Client, error) { | ||
sftpClient, err := sftp.NewClient(client) | ||
if err != nil { | ||
return nil, fmt.Errorf("cannot create SFTP client: %w", err) | ||
} | ||
return &clientImpl{ | ||
client: sftpClient, | ||
}, nil | ||
} | ||
|
||
func (c *clientImpl) Create(path string) (io.WriteCloser, error) { | ||
return c.client.Create(path) | ||
} | ||
|
||
func (c *clientImpl) Open(path string) (io.ReadCloser, error) { | ||
return c.client.Open(path) | ||
} | ||
|
||
func (c *clientImpl) Remove(path string) error { | ||
return c.client.Remove(path) | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
package sftp | ||
|
||
import ( | ||
"fmt" | ||
"io" | ||
"os" | ||
"path/filepath" | ||
|
||
"golang.org/x/crypto/ssh" | ||
) | ||
|
||
const ( | ||
// PasswordAuth indicates password-based authentication | ||
PasswordAuth = "passwordAuth" | ||
// KeyAuth indicates key-based authentication | ||
KeyAuth = "keyAuth" | ||
) | ||
|
||
// FileManager is an interface for managing files on a remote server | ||
type FileManager interface { | ||
Upload(localFilePath, remoteDir string) error | ||
Download(remoteFilePath, localDir string) error | ||
Delete(remoteFilePath string) error | ||
} | ||
|
||
// fileManagerImpl is a real implementation of FileManager | ||
type fileManagerImpl struct { | ||
client Client | ||
} | ||
|
||
func NewFileManager(sshClient *ssh.Client) (FileManager, error) { | ||
sftpClient, err := newSFTPClient(sshClient) | ||
if err != nil { | ||
return nil, fmt.Errorf("cannot create SFTP client: %w", err) | ||
} | ||
return &fileManagerImpl{client: sftpClient}, nil | ||
} | ||
|
||
// Upload uploads a file to the remote server | ||
func (fm *fileManagerImpl) Upload(localFilePath, remoteDir string) error { | ||
localFile, err := os.Open(localFilePath) | ||
if err != nil { | ||
return fmt.Errorf("cannot open local file: %w", err) | ||
} | ||
defer func() { | ||
_ = localFile.Close() | ||
}() | ||
|
||
remoteFileName := filepath.Join(remoteDir, filepath.Base(localFilePath)) | ||
remoteFile, err := fm.client.Create(remoteFileName) | ||
if err != nil { | ||
return fmt.Errorf("cannot create remote file: %w", err) | ||
} | ||
defer func() { | ||
_ = remoteFile.Close() | ||
}() | ||
|
||
_, err = io.Copy(remoteFile, localFile) | ||
if err != nil { | ||
return fmt.Errorf("error copying file: %w", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// Download downloads a file from the remote server | ||
func (fm *fileManagerImpl) Download(remoteFilePath, localDir string) error { | ||
remoteFile, err := fm.client.Open(remoteFilePath) | ||
if err != nil { | ||
return fmt.Errorf("cannot open remote file: %w", err) | ||
} | ||
defer func() { | ||
_ = remoteFile.Close() | ||
}() | ||
|
||
localFileName := filepath.Join(localDir, filepath.Base(remoteFilePath)) | ||
localFile, err := os.Create(localFileName) | ||
if err != nil { | ||
return fmt.Errorf("cannot create local file: %w", err) | ||
} | ||
defer func() { | ||
_ = localFile.Close() | ||
}() | ||
|
||
_, err = io.Copy(localFile, remoteFile) | ||
if err != nil { | ||
return fmt.Errorf("cannot copy remote file content to local file: %w", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// Delete deletes a file on the remote server | ||
func (fm *fileManagerImpl) Delete(remoteFilePath string) error { | ||
err := fm.client.Remove(remoteFilePath) | ||
if err != nil { | ||
return fmt.Errorf("cannot delete file: %w", err) | ||
} | ||
|
||
return nil | ||
} |
Oops, something went wrong.