diff --git a/drivers/sftp/driver.go b/drivers/sftp/driver.go index 77f5198457c..1f216598d2d 100644 --- a/drivers/sftp/driver.go +++ b/drivers/sftp/driver.go @@ -16,7 +16,8 @@ import ( type SFTP struct { model.Storage Addition - client *sftp.Client + client *sftp.Client + clientConnectionError error } func (d *SFTP) Config() driver.Config { @@ -39,6 +40,9 @@ func (d *SFTP) Drop(ctx context.Context) error { } func (d *SFTP) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + if err := d.clientReconnectOnConnectionError(); err != nil { + return nil, err + } log.Debugf("[sftp] list dir: %s", dir.GetPath()) files, err := d.client.ReadDir(dir.GetPath()) if err != nil { @@ -51,6 +55,9 @@ func (d *SFTP) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([] } func (d *SFTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if err := d.clientReconnectOnConnectionError(); err != nil { + return nil, err + } remoteFile, err := d.client.Open(file.GetPath()) if err != nil { return nil, err @@ -62,14 +69,23 @@ func (d *SFTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (* } func (d *SFTP) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + if err := d.clientReconnectOnConnectionError(); err != nil { + return err + } return d.client.MkdirAll(path.Join(parentDir.GetPath(), dirName)) } func (d *SFTP) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + if err := d.clientReconnectOnConnectionError(); err != nil { + return err + } return d.client.Rename(srcObj.GetPath(), path.Join(dstDir.GetPath(), srcObj.GetName())) } func (d *SFTP) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + if err := d.clientReconnectOnConnectionError(); err != nil { + return err + } return d.client.Rename(srcObj.GetPath(), path.Join(path.Dir(srcObj.GetPath()), newName)) } @@ -78,10 +94,16 @@ func (d *SFTP) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { } func (d *SFTP) Remove(ctx context.Context, obj model.Obj) error { + if err := d.clientReconnectOnConnectionError(); err != nil { + return err + } return d.remove(obj.GetPath()) } func (d *SFTP) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + if err := d.clientReconnectOnConnectionError(); err != nil { + return err + } dstFile, err := d.client.Create(path.Join(dstDir.GetPath(), stream.GetName())) if err != nil { return err diff --git a/drivers/sftp/util.go b/drivers/sftp/util.go index 3deb8dcf94b..eaeeaff5814 100644 --- a/drivers/sftp/util.go +++ b/drivers/sftp/util.go @@ -4,6 +4,7 @@ import ( "path" "github.com/pkg/sftp" + log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) @@ -30,6 +31,23 @@ func (d *SFTP) initClient() error { return err } d.client, err = sftp.NewClient(conn) + if err == nil { + d.clientConnectionError = nil + go func(d *SFTP) { + d.clientConnectionError = d.client.Wait() + }(d) + } + return err +} + +func (d *SFTP) clientReconnectOnConnectionError() error { + err := d.clientConnectionError + if err == nil { + return nil + } + log.Debugf("[sftp] discarding closed sftp connection: %v", err) + _ = d.client.Close() + err = d.initClient() return err }