Skip to content

Commit

Permalink
support transfer in background
Browse files Browse the repository at this point in the history
  • Loading branch information
lonnywong committed Jul 14, 2024
1 parent 615fd3d commit 1a1c1eb
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 33 deletions.
41 changes: 41 additions & 0 deletions trzsz/comm.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import (
"time"

"github.com/klauspost/compress/zstd"
"golang.org/x/term"
)

var onExitFuncs []func()
Expand Down Expand Up @@ -115,6 +116,7 @@ type baseArgs struct {
Escape bool `arg:"-e" help:"escape all known control characters"`
Directory bool `arg:"-d" help:"transfer directories and files"`
Recursive bool `arg:"-r" help:"transfer directories and files, same as -d"`
Fork bool `arg:"-f" help:"fork to transfer in background (implies -q)"`
Bufsize bufferSize `arg:"-B" placeholder:"N" default:"10M" help:"max buffer chunk size (1K<=N<=1G). (default: 10M)"`
Timeout int `arg:"-t" placeholder:"N" default:"20" help:"timeout ( N seconds ) for each buffer chunk.\nN <= 0 means never timeout. (default: 20)"`
Compress compressType `arg:"-c" placeholder:"yes/no/auto" default:"auto" help:"compress type (default: auto)"`
Expand Down Expand Up @@ -1028,3 +1030,42 @@ func (w *promptWriter) Write(p []byte) (int, error) {
func (w *promptWriter) Close() error {
return nil
}

func forkToBackground() (bool, error) {
if v := os.Getenv("TRZSZ-FORK-BACKGROUND"); v == "TRUE" {
return false, nil
}

cmd := exec.Command(os.Args[0], os.Args[1:]...)
cmd.Env = append(os.Environ(), "TRZSZ-FORK-BACKGROUND=TRUE")
cmd.SysProcAttr = getSysProcAttr()
cmd.Stdout = os.Stdout
stdin, err := cmd.StdinPipe()
if err != nil {
return true, fmt.Errorf("fork stdin pipe failed: %v", err)
}
defer stdin.Close()
stderr, err := cmd.StderrPipe()
if err != nil {
return true, fmt.Errorf("fork stderr pipe failed: %v", err)
}
defer stderr.Close()
if err := cmd.Start(); err != nil {
return true, fmt.Errorf("fork start failed: %v", err)
}

fd := int(os.Stdin.Fd())
if term.IsTerminal(fd) {
state, err := term.MakeRaw(fd)
if err != nil {
return true, fmt.Errorf("make stdin raw failed: %v\r\n", err)
}
defer func() { _ = term.Restore(fd, state) }()
}
go func() {
_, _ = io.Copy(stdin, os.Stdin)
}()

_, _ = io.Copy(os.Stderr, stderr)
return true, nil
}
43 changes: 25 additions & 18 deletions trzsz/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,28 +449,35 @@ func (filter *TrzszFilter) handleTrzsz() {
transfer.connectToTunnel(*connector, filter.trigger.uniqueID, filter.trigger.tunnelPort)
}

defer func() {
transfer.cleanup()
filter.transfer.CompareAndSwap(transfer, nil)
}()
defer filter.transfer.CompareAndSwap(transfer, nil)

defer func() {
if err := recover(); err != nil {
transfer.clientError(newTrzszError(fmt.Sprintf("%v", err), "panic", true))
done := make(chan struct{}, 1)
go func() {
defer close(done)
defer func() {
if err := recover(); err != nil {
transfer.clientError(newTrzszError(fmt.Sprintf("%v", err), "panic", true))
}
}()
var err error
switch filter.trigger.mode {
case 'S':
err = filter.downloadFiles(transfer)
case 'R':
err = filter.uploadFiles(transfer, false)
case 'D':
err = filter.uploadFiles(transfer, true)
}
if err != nil {
transfer.clientError(err)
}
transfer.cleanup()
done <- struct{}{}
}()

var err error
switch filter.trigger.mode {
case 'S':
err = filter.downloadFiles(transfer)
case 'R':
err = filter.uploadFiles(transfer, false)
case 'D':
err = filter.uploadFiles(transfer, true)
}
if err != nil {
transfer.clientError(err)
select {
case <-done:
case <-transfer.background():
}
}

Expand Down
39 changes: 38 additions & 1 deletion trzsz/transfer.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type transferAction struct {
SupportBinary bool `json:"binary"`
SupportDirectory bool `json:"support_dir"`
TunnelConnected bool `json:"tunnel"`
SupportFork bool `json:"fork"`
}

type transferConfig struct {
Expand All @@ -77,6 +78,7 @@ type transferConfig struct {
TmuxPaneColumns int32 `json:"tmux_pane_width"`
TmuxOutputJunk bool `json:"tmux_output_junk"`
CompressType compressType `json:"compress"`
Fork bool `json:"fork"`
}

type trzszTransfer struct {
Expand Down Expand Up @@ -106,6 +108,8 @@ type trzszTransfer struct {
tunnelConnected bool
tunnelConn atomic.Pointer[net.Conn]
tunnelInitWG sync.WaitGroup
bgChan chan struct{}
termReseted atomic.Bool
}

func maxDuration(a, b time.Duration) time.Duration {
Expand Down Expand Up @@ -143,6 +147,7 @@ func newTransfer(writer io.Writer, stdinState *term.State, flushInTime bool, log
MaxBufSize: 10 * 1024 * 1024,
},
logger: logger,
bgChan: make(chan struct{}, 1),
}
t.bufInitPhase.Store(true)
t.bufferSize.Store(10240)
Expand Down Expand Up @@ -245,6 +250,19 @@ func (t *trzszTransfer) cleanup() {
}
}

func (t *trzszTransfer) background() <-chan struct{} {
return t.bgChan
}

func (t *trzszTransfer) switchToBackground() {
os.Stdin.Close()
go func() {
time.Sleep(500 * time.Millisecond) // wait for client switch to background
t.resetTerm("Switch to transfer in background.", true)
os.Stderr.Close()
}()
}

func (t *trzszTransfer) addReceivedData(buf []byte, tunnel bool) {
if t.tunnelConnected && !tunnel {
if t.logger != nil {
Expand All @@ -253,7 +271,7 @@ func (t *trzszTransfer) addReceivedData(buf []byte, tunnel bool) {
return
}
if t.logger != nil {
t.logger.writeTraceLog(buf, "svrout")
t.logger.writeTraceLog(buf, "rcvbuf")
}
if !t.stopped.Load() {
t.buffer.addBuffer(buf)
Expand Down Expand Up @@ -582,6 +600,7 @@ func (t *trzszTransfer) sendAction(confirm bool, serverVersion *trzszVersion, re
t.writer = *conn
t.tunnelConnected = true
action.TunnelConnected = true
action.SupportFork = true
}

if !t.tunnelConnected && (isWindowsEnvironment() || remoteIsWindows) {
Expand Down Expand Up @@ -638,6 +657,11 @@ func (t *trzszTransfer) sendConfig(args *baseArgs, action *transferAction, escap
}
if action.TunnelConnected {
cfgMap["binary"] = true
if args.Fork {
t.switchToBackground()
cfgMap["fork"] = true
cfgMap["quiet"] = true
}
} else if args.Binary {
cfgMap["binary"] = true
cfgMap["escape_chars"] = escapeChars
Expand Down Expand Up @@ -680,6 +704,9 @@ func (t *trzszTransfer) recvConfig() (*transferConfig, error) {
if err := json.Unmarshal([]byte(cfgStr), &t.transferConfig); err != nil {
return nil, err
}
if t.transferConfig.Fork {
t.bgChan <- struct{}{}
}
return &t.transferConfig, nil
}

Expand All @@ -693,6 +720,16 @@ func (t *trzszTransfer) recvExit() (string, error) {

func (t *trzszTransfer) serverExit(msg string) {
t.cleanInput(500 * time.Millisecond)
t.resetTerm(msg, false)
}

func (t *trzszTransfer) resetTerm(msg string, ignorable bool) {
if !t.termReseted.CompareAndSwap(false, true) {
if !ignorable {
os.Stdout.WriteString(fmt.Sprintf("\x1b7\r\n%s\r\n\x1b8", msg))
}
return
}
if t.stdinState != nil {
_ = term.Restore(int(os.Stdin.Fd()), t.stdinState)
}
Expand Down
34 changes: 30 additions & 4 deletions trzsz/trz.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ func recvFiles(transfer *trzszTransfer, args *trzArgs, tmuxMode tmuxModeType, tm
args.Binary = false
}

// check if the client doesn't support fork to background
if args.Fork && !action.SupportFork {
return simpleTrzszError("The client doesn't support fork to background")
}

// check if the client doesn't support transfer directory
if args.Directory && !action.SupportDirectory {
return simpleTrzszError("The client doesn't support transfer directory")
Expand Down Expand Up @@ -109,6 +114,18 @@ func recvFiles(transfer *trzszTransfer, args *trzArgs, tmuxMode tmuxModeType, tm
func TrzMain() int {
args := parseTrzArgs(os.Args)

// fork to background
if args.Fork {
parent, err := forkToBackground()
if err != nil {
fmt.Fprintln(os.Stderr, err)
return 1
}
if parent {
return 0
}
}

// cleanup on exit
defer cleanupOnExit()

Expand Down Expand Up @@ -187,11 +204,20 @@ func TrzMain() int {
wrapTransferInput(transfer, os.Stdin, false)
handleServerSignal(transfer)

if err := recvFiles(transfer, args, tmuxMode, tmuxPaneWidth); err != nil {
transfer.serverError(err)
}
done := make(chan struct{}, 1)
go func() {
defer close(done)
if err := recvFiles(transfer, args, tmuxMode, tmuxPaneWidth); err != nil {
transfer.serverError(err)
}
transfer.cleanup()
done <- struct{}{}
}()

transfer.cleanup()
select {
case <-done:
case <-transfer.background():
}

return 0
}
4 changes: 4 additions & 0 deletions trzsz/trz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ func TestTrzArgs(t *testing.T) {
assertArgsEqual("-b", newTrzArgs(baseArgs{Binary: true}, "."))
assertArgsEqual("-e", newTrzArgs(baseArgs{Escape: true}, "."))
assertArgsEqual("-d", newTrzArgs(baseArgs{Directory: true}, "."))
assertArgsEqual("-d -d", newTrzArgs(baseArgs{Directory: true}, "."))
assertArgsEqual("-r", newTrzArgs(baseArgs{Directory: true, Recursive: true}, "."))
assertArgsEqual("-f", newTrzArgs(baseArgs{Fork: true}, "."))
assertArgsEqual("-B 2k", newTrzArgs(baseArgs{Bufsize: bufferSize{2 * 1024}}, "."))
assertArgsEqual("-t 3", newTrzArgs(baseArgs{Timeout: 3}, "."))
assertArgsEqual("-cNo", newTrzArgs(baseArgs{Compress: kCompressNo}, "."))
Expand All @@ -73,7 +75,9 @@ func TestTrzArgs(t *testing.T) {
assertArgsEqual("--binary", newTrzArgs(baseArgs{Binary: true}, "."))
assertArgsEqual("--escape", newTrzArgs(baseArgs{Escape: true}, "."))
assertArgsEqual("--directory", newTrzArgs(baseArgs{Directory: true}, "."))
assertArgsEqual("--directory -d", newTrzArgs(baseArgs{Directory: true}, "."))
assertArgsEqual("--recursive", newTrzArgs(baseArgs{Directory: true, Recursive: true}, "."))
assertArgsEqual("--fork", newTrzArgs(baseArgs{Fork: true}, "."))
assertArgsEqual("--bufsize 2M", newTrzArgs(baseArgs{Bufsize: bufferSize{2 * 1024 * 1024}}, "."))
assertArgsEqual("--timeout 55", newTrzArgs(baseArgs{Timeout: 55}, "."))
assertArgsEqual("--compress No", newTrzArgs(baseArgs{Compress: kCompressNo}, "."))
Expand Down
34 changes: 30 additions & 4 deletions trzsz/tsz.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ func sendFiles(transfer *trzszTransfer, files []*sourceFile, args *tszArgs, tmux
args.Binary = false
}

// check if the client doesn't support fork to background
if args.Fork && !action.SupportFork {
return simpleTrzszError("The client doesn't support fork to background")
}

// check if the client doesn't support transfer directory
if args.Directory && !action.SupportDirectory {
return simpleTrzszError("The client doesn't support transfer directory")
Expand Down Expand Up @@ -108,6 +113,18 @@ func sendFiles(transfer *trzszTransfer, files []*sourceFile, args *tszArgs, tmux
func TszMain() int {
args := parseTszArgs(os.Args)

// fork to background
if args.Fork {
parent, err := forkToBackground()
if err != nil {
fmt.Fprintln(os.Stderr, err)
return 1
}
if parent {
return 0
}
}

// cleanup on exit
defer cleanupOnExit()

Expand Down Expand Up @@ -183,11 +200,20 @@ func TszMain() int {
wrapTransferInput(transfer, os.Stdin, false)
handleServerSignal(transfer)

if err := sendFiles(transfer, files, args, tmuxMode, tmuxPaneWidth); err != nil {
transfer.serverError(err)
}
done := make(chan struct{}, 1)
go func() {
defer close(done)
if err := sendFiles(transfer, files, args, tmuxMode, tmuxPaneWidth); err != nil {
transfer.serverError(err)
}
transfer.cleanup()
done <- struct{}{}
}()

transfer.cleanup()
select {
case <-done:
case <-transfer.background():
}

return 0
}
2 changes: 2 additions & 0 deletions trzsz/tsz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func TestTszArgs(t *testing.T) {
assertArgsEqual("-e a", newTszArgs(baseArgs{Escape: true}, []string{"a"}))
assertArgsEqual("-d a", newTszArgs(baseArgs{Directory: true}, []string{"a"}))
assertArgsEqual("-r a", newTszArgs(baseArgs{Directory: true, Recursive: true}, []string{"a"}))
assertArgsEqual("-f a", newTszArgs(baseArgs{Fork: true}, []string{"a"}))
assertArgsEqual("-B 2k a", newTszArgs(baseArgs{Bufsize: bufferSize{2 * 1024}}, []string{"a"}))
assertArgsEqual("-t 3 a", newTszArgs(baseArgs{Timeout: 3}, []string{"a"}))
assertArgsEqual("-cno a", newTszArgs(baseArgs{Compress: kCompressNo}, []string{"a"}))
Expand All @@ -74,6 +75,7 @@ func TestTszArgs(t *testing.T) {
assertArgsEqual("--escape a", newTszArgs(baseArgs{Escape: true}, []string{"a"}))
assertArgsEqual("--directory a", newTszArgs(baseArgs{Directory: true}, []string{"a"}))
assertArgsEqual("--recursive a", newTszArgs(baseArgs{Directory: true, Recursive: true}, []string{"a"}))
assertArgsEqual("--fork a", newTszArgs(baseArgs{Fork: true}, []string{"a"}))
assertArgsEqual("--bufsize 2M a", newTszArgs(baseArgs{Bufsize: bufferSize{2 * 1024 * 1024}}, []string{"a"}))
assertArgsEqual("--timeout 55 a", newTszArgs(baseArgs{Timeout: 55}, []string{"a"}))
assertArgsEqual("--compress NO a", newTszArgs(baseArgs{Compress: kCompressNo}, []string{"a"}))
Expand Down
Loading

0 comments on commit 1a1c1eb

Please sign in to comment.