diff --git a/sftp/client.go b/sftp/client.go index 448de0b1..84ffba76 100644 --- a/sftp/client.go +++ b/sftp/client.go @@ -87,6 +87,7 @@ type Client interface { Create(path string) (io.WriteCloser, error) Open(path string) (io.ReadCloser, error) Remove(path string) error + MkdirAll(path string) error } // newSFTPClient creates an SFTP client with existing SSH client @@ -111,3 +112,7 @@ func (c *clientImpl) Open(path string) (io.ReadCloser, error) { func (c *clientImpl) Remove(path string) error { return c.client.Remove(path) } + +func (c *clientImpl) MkdirAll(path string) error { + return c.client.MkdirAll(path) +} diff --git a/sftp/mock_sftp/mock_sftp_client.go b/sftp/mock_sftp/mock_sftp_client.go index 8cba806c..3e5eeaec 100644 --- a/sftp/mock_sftp/mock_sftp_client.go +++ b/sftp/mock_sftp/mock_sftp_client.go @@ -49,6 +49,20 @@ func (mr *MockClientMockRecorder) Create(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockClient)(nil).Create), arg0) } +// MkdirAll mocks base method. +func (m *MockClient) MkdirAll(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MkdirAll", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// MkdirAll indicates an expected call of MkdirAll. +func (mr *MockClientMockRecorder) MkdirAll(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MkdirAll", reflect.TypeOf((*MockClient)(nil).MkdirAll), arg0) +} + // Open mocks base method. func (m *MockClient) Open(arg0 string) (io.ReadCloser, error) { m.ctrl.T.Helper() diff --git a/sftp/sftp.go b/sftp/sftp.go index 1bfb9786..37f96128 100644 --- a/sftp/sftp.go +++ b/sftp/sftp.go @@ -37,7 +37,7 @@ func NewFileManager(sshClient *ssh.Client) (FileManager, error) { } // Upload uploads a file to the remote server -func (fm *fileManagerImpl) Upload(localFilePath, remoteDir string) error { +func (fm *fileManagerImpl) Upload(localFilePath, remoteFilePath string) error { localFile, err := os.Open(localFilePath) if err != nil { return fmt.Errorf("cannot open local file: %w", err) @@ -46,8 +46,13 @@ func (fm *fileManagerImpl) Upload(localFilePath, remoteDir string) error { _ = localFile.Close() }() - remoteFileName := filepath.Join(remoteDir, filepath.Base(localFilePath)) - remoteFile, err := fm.client.Create(remoteFileName) + // Create the directory if it does not exist + remoteDir := filepath.Dir(remoteFilePath) + if err := fm.client.MkdirAll(remoteDir); err != nil { + return fmt.Errorf("cannot create remote directory: %w", err) + } + + remoteFile, err := fm.client.Create(remoteFilePath) if err != nil { return fmt.Errorf("cannot create remote file: %w", err) } diff --git a/sftp/sftp_test.go b/sftp/sftp_test.go index 32692851..ae0b561c 100644 --- a/sftp/sftp_test.go +++ b/sftp/sftp_test.go @@ -138,10 +138,11 @@ func TestUpload(t *testing.T) { mockSFTPClient := mock_sftp.NewMockClient(ctrl) mockSFTPClient.EXPECT().Create(gomock.Any()).Return(&nopWriteCloser{remoteBuf}, nil) + mockSFTPClient.EXPECT().MkdirAll(gomock.Any()).Return(nil) fileManager := &fileManagerImpl{client: mockSFTPClient} - err = fileManager.Upload(localFilePath, "someRemoteDir") + err = fileManager.Upload(localFilePath, "someRemotePath") require.NoError(t, err) require.Equal(t, data, remoteBuf.Bytes()) } @@ -226,7 +227,7 @@ func TestSFTP(t *testing.T) { require.NoError(t, err) defer func() { _ = session.Close() }() - remoteDir := filepath.Join("/tmp", "remote") + remoteDir := filepath.Join("/tmp", "remote", "data") err = session.Run(fmt.Sprintf("mkdir -p %s", remoteDir)) require.NoError(t, err) @@ -252,7 +253,7 @@ func TestSFTP(t *testing.T) { err = os.WriteFile(localFilePath, data, 0o644) require.NoError(t, err) - err = sftpManger.Upload(localFilePath, remoteDir) + err = sftpManger.Upload(localFilePath, remoteFilePath) require.NoError(t, err) err = sftpManger.Download(remoteFilePath, baseDir)