From fc290aa1e938413ae7c539f33d0d354a5e6bf189 Mon Sep 17 00:00:00 2001 From: Lonny Wong Date: Sun, 3 Dec 2023 23:20:37 +0800 Subject: [PATCH] support zmodem lrzsz ( rz / sz ) --- README.md | 10 ++ go.mod | 16 +- go.sum | 32 ++-- trzsz/comm.go | 6 + trzsz/filter.go | 55 ++++++- trzsz/progress.go | 72 ++++---- trzsz/pty_unix.go | 12 +- trzsz/pty_windows.go | 88 ++++------ trzsz/transfer.go | 3 +- trzsz/trz.go | 11 +- trzsz/trzsz.go | 22 ++- trzsz/tsz.go | 11 +- trzsz/zmodem.go | 384 +++++++++++++++++++++++++++++++++++++++++++ trzsz/zmodem_test.go | 70 ++++++++ 14 files changed, 637 insertions(+), 155 deletions(-) create mode 100644 trzsz/zmodem.go create mode 100644 trzsz/zmodem_test.go diff --git a/README.md b/README.md index 16d418f..254efb9 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,16 @@ DefaultDownloadPath = /Users/username/Downloads/ - If the `DefaultDownloadPath` is not empty, downloading files will be saved to the path automatically instead of asking each time. +## Zmodem support + +- Use `-z` or `--zmodem` to enable the `rz / sz` feature. e.g., `trzsz -z ssh remote_server`. + +- `lrzsz` needs to be installed on the client ( local computer ). e.g., `brew install lrzsz`, `apt install lrzsz`, etc. + +- `trzsz --zmodem ssh xxx` is not supported on Windows. You can use [trzsz-ssh ( tssh )](https://trzsz.github.io/ssh) instead, `tssh --zmodem xxx`. + +- About the progress, the transferred and speed are not precise. They appear larger than reality. It just indicating that the transfer is in progress. + ## Trouble shooting - If using [MSYS2](https://www.msys2.org/) or [Git Bash](https://www.atlassian.com/git/tutorials/git-bash) on windows, and getting an error `The handle is invalid`. diff --git a/go.mod b/go.mod index 6e9c569..ba2b947 100644 --- a/go.mod +++ b/go.mod @@ -3,16 +3,16 @@ module github.com/trzsz/trzsz-go go 1.20 require ( - github.com/UserExistsError/conpty v0.1.1 - github.com/creack/pty v1.1.18 - github.com/klauspost/compress v1.17.1 + github.com/UserExistsError/conpty v0.1.2 + github.com/creack/pty v1.1.21 + github.com/klauspost/compress v1.17.4 github.com/ncruces/zenity v0.10.10 github.com/stretchr/testify v1.8.4 github.com/trzsz/go-arg v1.5.2 - github.com/trzsz/promptui v0.10.3 - golang.org/x/sys v0.13.0 - golang.org/x/term v0.13.0 - golang.org/x/text v0.13.0 + github.com/trzsz/promptui v0.10.5 + golang.org/x/sys v0.15.0 + golang.org/x/term v0.15.0 + golang.org/x/text v0.14.0 ) require ( @@ -24,6 +24,6 @@ require ( github.com/josephspurrier/goversioninfo v1.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/randall77/makefat v0.0.0-20210315173500-7ddd0e42c844 // indirect - golang.org/x/image v0.13.0 // indirect + golang.org/x/image v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index c187069..a01d8c1 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/UserExistsError/conpty v0.1.1 h1:cHDsU/XeoeDAQmVvCTV53SrXLG39YJ4++Pp3iAi1gXE= -github.com/UserExistsError/conpty v0.1.1/go.mod h1:PDglKIkX3O/2xVk0MV9a6bCWxRmPVfxqZoTG/5sSd9I= +github.com/UserExistsError/conpty v0.1.2 h1:ikx+zk1ekB8Agiajun6Cpg4Ju/cEaU/mnRZQYT21naI= +github.com/UserExistsError/conpty v0.1.2/go.mod h1:PDglKIkX3O/2xVk0MV9a6bCWxRmPVfxqZoTG/5sSd9I= github.com/akavel/rsrc v0.10.2 h1:Zxm8V5eI1hW4gGaYsJQUhxpjkENuG91ki8B4zCrvEsw= github.com/akavel/rsrc v0.10.2/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c= github.com/alexflint/go-scalar v1.2.0 h1:WR7JPKkeNpnYIOfHRa7ivM21aWAdHD0gEWHCx+WQBRw= @@ -10,8 +10,8 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= -github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= -github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0= +github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -19,8 +19,8 @@ github.com/dchest/jsmin v0.0.0-20220218165748-59f39799265f h1:OGqDDftRTwrvUoL6pO github.com/dchest/jsmin v0.0.0-20220218165748-59f39799265f/go.mod h1:Dv9D0NUlAsaQcGQZa5kc5mqR9ua72SmA8VXi4cd+cBw= github.com/josephspurrier/goversioninfo v1.4.0 h1:Puhl12NSHUSALHSuzYwPYQkqa2E1+7SrtAPJorKK0C8= github.com/josephspurrier/goversioninfo v1.4.0/go.mod h1:JWzv5rKQr+MmW+LvM412ToT/IkYDZjaclF2pKDss8IY= -github.com/klauspost/compress v1.17.1 h1:NE3C767s2ak2bweCZo3+rdP4U/HoyVXLv/X9f2gPS5g= -github.com/klauspost/compress v1.17.1/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= +github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/ncruces/zenity v0.10.10 h1:V/rtAhr5QLdDThahOkm7EYlnw4RuEsf7oN+Xb6lz1j0= github.com/ncruces/zenity v0.10.10/go.mod h1:k3k4hJ4Wt1MUbeV48y+Gbl7Fp9skfGszN/xtKmuvhZk= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -34,20 +34,20 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/trzsz/go-arg v1.5.2 h1:zGxCuTKvtC3jBf7HbvNk0HooUjv8uKAy2mY+bHVhRas= github.com/trzsz/go-arg v1.5.2/go.mod h1:IC6Z/FiVH7uYvcbp1/gJhDYCFPS/GkL0APYakVvgY4I= -github.com/trzsz/promptui v0.10.3 h1:uhcLQsLZqMxEtGiYoeM2lR/Hd4pSxoYsd2eFctH8MCs= -github.com/trzsz/promptui v0.10.3/go.mod h1:GMZtu6ZTzU73CBFkzGtmB4wnTROIAbv4GFA74fV8V8g= +github.com/trzsz/promptui v0.10.5 h1:tlzJkx+JOeE0sqKWmqgaoToZiYqj5G1Mz+QDV97VFu8= +github.com/trzsz/promptui v0.10.5/go.mod h1:GMZtu6ZTzU73CBFkzGtmB4wnTROIAbv4GFA74fV8V8g= go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= -golang.org/x/image v0.13.0 h1:3cge/F/QTkNLauhf2QoE9zp+7sr+ZcL4HnoZmdwg9sg= -golang.org/x/image v0.13.0/go.mod h1:6mmbMOeV28HuMTgA6OSRkdXKYw/t5W9Uwn2Yv1r3Yxk= +golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= +golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= +golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/trzsz/comm.go b/trzsz/comm.go index c5028cf..f9c889f 100644 --- a/trzsz/comm.go +++ b/trzsz/comm.go @@ -50,6 +50,12 @@ import ( var onExitFuncs []func() +func cleanupOnExit() { + for i := len(onExitFuncs) - 1; i >= 0; i-- { + onExitFuncs[i]() + } +} + var timeNowFunc = time.Now var linuxRuntime bool = (runtime.GOOS == "linux") diff --git a/trzsz/filter.go b/trzsz/filter.go index 0db2874..3e5f5a4 100644 --- a/trzsz/filter.go +++ b/trzsz/filter.go @@ -53,6 +53,8 @@ type TrzszOptions struct { // DetectTraceLog is for debugging. // If DetectTraceLog is true, will detect the server output to determine whether to enable trace logging. DetectTraceLog bool + // EnableZmodem enable zmodem lrzsz ( rz / sz ) feature. + EnableZmodem bool } // TrzszFilter is a filter that supports trzsz ( trz / tsz ). @@ -63,6 +65,7 @@ type TrzszFilter struct { serverOut io.Reader options TrzszOptions transfer atomic.Pointer[trzszTransfer] + zmodem atomic.Pointer[zmodemTransfer] progress atomic.Pointer[textProgressBar] promptPipe atomic.Pointer[io.PipeWriter] trigger *trzszTrigger @@ -227,6 +230,8 @@ func (filter *TrzszFilter) getDefaultDownloadPath() string { return resolveHomeDir(path) } +var errUserCanceled = fmt.Errorf("Cancelled") + var parentWindowID = getParentWindowID() func zenityErrorWithTips(err error) error { @@ -265,12 +270,12 @@ func (filter *TrzszFilter) chooseDownloadPath() (string, error) { options = append(options, zenity.Attach(parentWindowID)) } path, err := zenity.SelectFile(options...) + if err == zenity.ErrCanceled || len(path) == 0 { + return "", errUserCanceled + } if err != nil { return "", zenityErrorWithTips(err) } - if len(path) == 0 { - return "", zenity.ErrCanceled - } return path, nil } @@ -294,18 +299,18 @@ func (filter *TrzszFilter) chooseUploadPaths(directory bool) ([]string, error) { options = append(options, zenity.Attach(parentWindowID)) } files, err := zenity.SelectFileMultiple(options...) + if err == zenity.ErrCanceled || len(files) == 0 { + return nil, errUserCanceled + } if err != nil { return nil, zenityErrorWithTips(err) } - if len(files) == 0 { - return nil, zenity.ErrCanceled - } return files, nil } func (filter *TrzszFilter) downloadFiles(transfer *trzszTransfer) error { path, err := filter.chooseDownloadPath() - if err == zenity.ErrCanceled { + if err == errUserCanceled { return transfer.sendAction(false, filter.trigger.version, filter.trigger.winServer) } if err != nil { @@ -344,7 +349,7 @@ func (filter *TrzszFilter) downloadFiles(transfer *trzszTransfer) error { func (filter *TrzszFilter) uploadFiles(transfer *trzszTransfer, directory bool) error { paths, err := filter.chooseUploadPaths(directory) - if err == zenity.ErrCanceled { + if err == errUserCanceled { return transfer.sendAction(false, filter.trigger.version, filter.trigger.winServer) } if err != nil { @@ -602,6 +607,16 @@ func (filter *TrzszFilter) sendInput(buf []byte, detectDragFile *atomic.Bool) { } return } + if filter.options.EnableZmodem { + if zmodem := filter.zmodem.Load(); zmodem != nil { + if len(buf) == 1 && buf[0] == '\x03' { + zmodem.stopTransferringFiles() // `ctrl + c` to stop transferring files + } + if zmodem.isTransferringFiles() { + return + } + } + } if detectDragFile.Load() { dragFiles, hasDir, ignore := detectDragFiles(buf) if dragFiles != nil { @@ -657,6 +672,16 @@ func (filter *TrzszFilter) wrapOutput() { if filter.logger != nil { buf = filter.logger.writeTraceLog(buf, "svrout") } + if filter.options.EnableZmodem { + if zmodem := filter.zmodem.Load(); zmodem != nil { + if zmodem.handleServerOutput(buf) { + continue + } else { + filter.zmodem.Store(nil) + } + } + } + var trigger *trzszTrigger buf, trigger = detector.detectTrzsz(buf, filter.tunnelConnector.Load() != nil) if trigger != nil { @@ -676,6 +701,20 @@ func (filter *TrzszFilter) wrapOutput() { continue } } + + if filter.options.EnableZmodem { + if zmodem := detectZmodem(buf); zmodem != nil { + _ = writeAll(filter.clientOut, buf) + filter.zmodem.Store(zmodem) + go zmodem.handleZmodemEvent(filter.logger, filter.serverIn, filter.clientOut, + func() ([]string, error) { + return filter.chooseUploadPaths(false) + }, + filter.chooseDownloadPath) + continue + } + } + _ = writeAll(filter.clientOut, buf) } if err == io.EOF { diff --git a/trzsz/progress.go b/trzsz/progress.go index 02438ef..c02280b 100644 --- a/trzsz/progress.go +++ b/trzsz/progress.go @@ -136,6 +136,43 @@ func convertTimeToString(seconds float64) string { const kSpeedArraySize = 30 +type recentSpeed struct { + speedCnt int + speedIdx int + timeArray [kSpeedArraySize]*time.Time + stepArray [kSpeedArraySize]int64 +} + +func (s *recentSpeed) initFirstStep(now *time.Time) { + s.timeArray[0] = now + s.stepArray[0] = 0 + s.speedCnt = 1 + s.speedIdx = 1 +} + +func (s *recentSpeed) getSpeed(step int64, now *time.Time) float64 { + var speed float64 + if s.speedCnt <= kSpeedArraySize { + s.speedCnt++ + speed = float64(step-s.stepArray[0]) / (float64(now.Sub(*s.timeArray[0])) / float64(time.Second)) + } else { + speed = float64(step-s.stepArray[s.speedIdx]) / (float64(now.Sub(*s.timeArray[s.speedIdx])) / float64(time.Second)) + } + + s.timeArray[s.speedIdx] = now + s.stepArray[s.speedIdx] = step + + s.speedIdx++ + if s.speedIdx >= kSpeedArraySize { + s.speedIdx = 0 + } + + if math.IsNaN(speed) { + return -1 + } + return speed +} + type textProgressBar struct { writer io.Writer columns atomic.Int32 @@ -149,10 +186,7 @@ type textProgressBar struct { startTime *time.Time lastUpdateTime *time.Time firstWrite bool - speedCnt int - speedIdx int - timeArray [kSpeedArraySize]*time.Time - stepArray [kSpeedArraySize]int64 + recentSpeed recentSpeed pausing atomic.Bool tmuxPrefix string } @@ -193,10 +227,7 @@ func (p *textProgressBar) onName(name string) { p.fileIdx++ now := timeNowFunc() p.startTime = &now - p.timeArray[0] = p.startTime - p.stepArray[0] = 0 - p.speedCnt = 1 - p.speedIdx = 1 + p.recentSpeed.initFirstStep(&now) p.preSize = 0 p.fileStep = -1 } @@ -268,7 +299,7 @@ func (p *textProgressBar) showProgress() { percentage = fmt.Sprintf("%.0f%%", math.Round(float64(p.fileStep)*100.0/float64(p.fileSize))) } total := convertSizeToString(float64(p.fileStep)) - speed := p.getSpeed(&now) + speed := p.recentSpeed.getSpeed(p.fileStep, &now) speedStr := "--- B/s" etaStr := "--- ETA" if speed > 0 { @@ -290,29 +321,6 @@ func (p *textProgressBar) showProgress() { } } -func (p *textProgressBar) getSpeed(now *time.Time) float64 { - var speed float64 - if p.speedCnt <= kSpeedArraySize { - p.speedCnt++ - speed = float64(p.fileStep-p.stepArray[0]) / (float64(now.Sub(*p.timeArray[0])) / float64(time.Second)) - } else { - speed = float64(p.fileStep-p.stepArray[p.speedIdx]) / (float64(now.Sub(*p.timeArray[p.speedIdx])) / float64(time.Second)) - } - - p.timeArray[p.speedIdx] = now - p.stepArray[p.speedIdx] = p.fileStep - - p.speedIdx++ - if p.speedIdx >= kSpeedArraySize { - p.speedIdx %= kSpeedArraySize - } - - if math.IsNaN(speed) { - return -1 - } - return speed -} - func (p *textProgressBar) getProgressText(percentage, total, speed, eta string) string { const barMinLength = 24 diff --git a/trzsz/pty_unix.go b/trzsz/pty_unix.go index b63efcc..580c168 100644 --- a/trzsz/pty_unix.go +++ b/trzsz/pty_unix.go @@ -48,6 +48,10 @@ type trzszPty struct { closed atomic.Bool } +func setupVirtualTerminal() error { + return nil +} + func spawn(name string, arg ...string) (*trzszPty, error) { // spawn a pty cmd := exec.Command(name, arg...) @@ -125,13 +129,5 @@ func syscallAccessRok(path string) error { return syscall.Access(path, unix.R_OK) } -func enableVirtualTerminal() (uint32, uint32, error) { - return 0, 0, nil -} - -func resetVirtualTerminal(inMode, outMode uint32) error { - return nil -} - func setupConsoleOutput() { } diff --git a/trzsz/pty_windows.go b/trzsz/pty_windows.go index 49fea5a..65d3149 100644 --- a/trzsz/pty_windows.go +++ b/trzsz/pty_windows.go @@ -28,7 +28,6 @@ import ( "context" "io" "os" - "os/exec" "strings" "sync/atomic" "syscall" @@ -42,10 +41,6 @@ type trzszPty struct { Stdin io.ReadWriteCloser Stdout io.ReadWriteCloser cpty *conpty.ConPty - inCP uint32 - outCP uint32 - inMode uint32 - outMode uint32 width int height int closed atomic.Bool @@ -87,50 +82,55 @@ func getConsoleSize() (int, int, error) { return int(info.Window.Right-info.Window.Left) + 1, int(info.Window.Bottom-info.Window.Top) + 1, nil } -func enableVirtualTerminal() (uint32, uint32, error) { +func enableVirtualTerminal() error { var inMode, outMode uint32 inHandle, err := syscall.GetStdHandle(syscall.STD_INPUT_HANDLE) if err != nil { - return 0, 0, err + return err } if err := windows.GetConsoleMode(windows.Handle(inHandle), &inMode); err != nil { - return 0, 0, err + return err } + onExitFuncs = append(onExitFuncs, func() { + windows.SetConsoleMode(windows.Handle(inHandle), inMode) + }) if err := windows.SetConsoleMode(windows.Handle(inHandle), inMode|windows.ENABLE_VIRTUAL_TERMINAL_INPUT); err != nil { - return 0, 0, err + return err } outHandle, err := syscall.GetStdHandle(syscall.STD_OUTPUT_HANDLE) if err != nil { - return 0, 0, err + return err } if err := windows.GetConsoleMode(windows.Handle(outHandle), &outMode); err != nil { - return 0, 0, err + return err } + onExitFuncs = append(onExitFuncs, func() { + windows.SetConsoleMode(windows.Handle(outHandle), outMode) + }) if err := windows.SetConsoleMode(windows.Handle(outHandle), outMode|windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING|windows.DISABLE_NEWLINE_AUTO_RETURN); err != nil { - return 0, 0, err + return err } - return inMode, outMode, nil + return nil } -func resetVirtualTerminal(inMode, outMode uint32) error { - inHandle, err := syscall.GetStdHandle(syscall.STD_INPUT_HANDLE) - if err != nil { - return err - } - if err := windows.SetConsoleMode(windows.Handle(inHandle), inMode); err != nil { +func setupVirtualTerminal() error { + // enable virtual terminal + if err := enableVirtualTerminal(); err != nil { return err } - outHandle, err := syscall.GetStdHandle(syscall.STD_OUTPUT_HANDLE) - if err != nil { - return err - } - if err := windows.SetConsoleMode(windows.Handle(outHandle), outMode); err != nil { - return err - } + // set code page to UTF8 + inCP := getConsoleCP() + outCP := getConsoleOutputCP() + setConsoleCP(CP_UTF8) + setConsoleOutputCP(CP_UTF8) + onExitFuncs = append(onExitFuncs, func() { + setConsoleCP(inCP) + setConsoleOutputCP(outCP) + }) return nil } @@ -142,31 +142,16 @@ func spawn(name string, args ...string) (*trzszPty, error) { return nil, err } - // enable virtual terminal - inMode, outMode, err := enableVirtualTerminal() - if err != nil { - return nil, err - } - - // set code page to UTF8 - inCP := getConsoleCP() - outCP := getConsoleOutputCP() - setConsoleCP(CP_UTF8) - setConsoleOutputCP(CP_UTF8) - // spawn a pty var cmdLine strings.Builder - cmdLine.WriteString(strings.ReplaceAll(name, "\"", "\"\"\"")) + cmdLine.WriteString(windows.EscapeArg(name)) for _, arg := range args { - cmdLine.WriteString(" \"") - cmdLine.WriteString(strings.ReplaceAll(arg, "\"", "\"\"\"")) - cmdLine.WriteString("\"") + cmdLine.WriteString(" ") + cmdLine.WriteString(windows.EscapeArg(arg)) } + cpty, err := conpty.Start(cmdLine.String(), conpty.ConPtyDimensions(width, height)) if err != nil { - setConsoleCP(inCP) - setConsoleOutputCP(outCP) - resetVirtualTerminal(inMode, outMode) return nil, err } @@ -174,10 +159,6 @@ func spawn(name string, args ...string) (*trzszPty, error) { Stdin: cpty, Stdout: cpty, cpty: cpty, - inCP: inCP, - outCP: outCP, - inMode: inMode, - outMode: outMode, width: width, height: height, startTime: time.Now(), @@ -217,15 +198,6 @@ func (t *trzszPty) Close() { } t.closed.Store(true) t.cpty.Close() - setConsoleCP(t.inCP) - setConsoleOutputCP(t.outCP) - resetVirtualTerminal(t.inMode, t.outMode) - if time.Now().Sub(t.startTime) > 10*time.Second { - time.Sleep(100 * time.Millisecond) - cmd := exec.Command("cmd", "/c", "cls") - cmd.Stdout = os.Stdout - cmd.Run() - } } func (t *trzszPty) Wait() { diff --git a/trzsz/transfer.go b/trzsz/transfer.go index 6b3451c..7e5a0bc 100644 --- a/trzsz/transfer.go +++ b/trzsz/transfer.go @@ -262,11 +262,10 @@ func (t *trzszTransfer) addReceivedData(buf []byte, tunnel bool) { } func (t *trzszTransfer) stopTransferringFiles(stopAndDelete bool) { - if t.stopped.Load() { + if !t.stopped.CompareAndSwap(false, true) { return } t.stopAndDelete.Store(stopAndDelete) - t.stopped.Store(true) t.buffer.stopBuffer() if !t.tunnelConnected { diff --git a/trzsz/trz.go b/trzsz/trz.go index c9af21b..a2b3af8 100644 --- a/trzsz/trz.go +++ b/trzsz/trz.go @@ -109,11 +109,8 @@ func recvFiles(transfer *trzszTransfer, args *trzArgs, tmuxMode tmuxModeType, tm func TrzMain() int { args := parseTrzArgs(os.Args) - defer func() { - for i := len(onExitFuncs) - 1; i >= 0; i-- { - onExitFuncs[i]() - } - }() + // cleanup on exit + defer cleanupOnExit() var err error args.Path, err = filepath.Abs(args.Path) @@ -143,9 +140,7 @@ func TrzMain() int { uniqueID := (time.Now().UnixMilli() % 10e10) * 100 if isRunningOnWindows() { - if inMode, outMode, err := enableVirtualTerminal(); err == nil { - defer resetVirtualTerminal(inMode, outMode) // nolint:all - } + _ = setupVirtualTerminal() setupConsoleOutput() uniqueID += 10 } else if tmuxMode == tmuxNormalMode { diff --git a/trzsz/trzsz.go b/trzsz/trzsz.go index 04a1994..2355148 100644 --- a/trzsz/trzsz.go +++ b/trzsz/trzsz.go @@ -39,6 +39,7 @@ type trzszArgs struct { Relay bool TraceLog bool DragFile bool + Zmodem bool Name string Args []string } @@ -48,7 +49,7 @@ func printVersion() { } func printHelp() { - fmt.Print("usage: trzsz [-h] [-v] [-r] [-t] [-d] command line\n\n" + + fmt.Print("usage: trzsz [-h] [-v] [-r] [-t] [-d] [-z] command line\n\n" + "Wrapping command line to support trzsz ( trz / tsz ).\n\n" + "positional arguments:\n" + " command line the original command line\n\n" + @@ -57,7 +58,8 @@ func printHelp() { " -v, --version show version number and exit\n" + " -r, --relay run as a trzsz relay server\n" + " -t, --tracelog eanble trace log for debugging\n" + - " -d, --dragfile enable drag file(s) to upload\n") + " -d, --dragfile enable drag file(s) to upload\n" + + " -z, --zmodem enable zmodem lrzsz ( rz / sz )\n") } func parseTrzszArgs() *trzszArgs { @@ -76,6 +78,8 @@ func parseTrzszArgs() *trzszArgs { args.TraceLog = true } else if os.Args[i] == "-d" || os.Args[i] == "--dragfile" { args.DragFile = true + } else if os.Args[i] == "-z" || os.Args[i] == "--zmodem" { + args.Zmodem = true } else { break } @@ -122,11 +126,14 @@ func TrzszMain() int { return 0 } - defer func() { - for i := len(onExitFuncs) - 1; i >= 0; i-- { - onExitFuncs[i]() - } - }() + // cleanup on exit + defer cleanupOnExit() + + // setup virtual terminal on Windows + if err := setupVirtualTerminal(); err != nil { + fmt.Fprintf(os.Stderr, "setup virtual terminal failed: %v\r\n", err) + return -1 + } // spawn a pty pty, err := spawn(args.Name, args.Args...) @@ -165,6 +172,7 @@ func TrzszMain() int { TerminalColumns: columns, DetectDragFile: args.DragFile, DetectTraceLog: args.TraceLog, + EnableZmodem: args.Zmodem, }) pty.OnResize(filter.SetTerminalColumns) // handle signal diff --git a/trzsz/tsz.go b/trzsz/tsz.go index 3308229..3d64524 100644 --- a/trzsz/tsz.go +++ b/trzsz/tsz.go @@ -108,11 +108,8 @@ func sendFiles(transfer *trzszTransfer, files []*sourceFile, args *tszArgs, tmux func TszMain() int { args := parseTszArgs(os.Args) - defer func() { - for i := len(onExitFuncs) - 1; i >= 0; i-- { - onExitFuncs[i]() - } - }() + // cleanup on exit + defer cleanupOnExit() files, err := checkPathsReadable(args.File, args.Directory) if err != nil { @@ -143,9 +140,7 @@ func TszMain() int { uniqueID := (time.Now().UnixMilli() % 10e10) * 100 if isRunningOnWindows() { - if inMode, outMode, err := enableVirtualTerminal(); err == nil { - defer resetVirtualTerminal(inMode, outMode) // nolint:all - } + _ = setupVirtualTerminal() setupConsoleOutput() uniqueID += 10 } else if tmuxMode == tmuxNormalMode { diff --git a/trzsz/zmodem.go b/trzsz/zmodem.go new file mode 100644 index 0000000..96a10cb --- /dev/null +++ b/trzsz/zmodem.go @@ -0,0 +1,384 @@ +/* +MIT License + +Copyright (c) 2023 Lonny Wong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +package trzsz + +import ( + "bytes" + "fmt" + "io" + "os/exec" + "path/filepath" + "regexp" + "sync/atomic" + "time" +) + +type zmodemTransfer struct { + upload bool + logger *traceLogger + serverIn io.Writer + clientOut io.Writer + cmd atomic.Pointer[exec.Cmd] + stdin io.WriteCloser + stdout io.ReadCloser + clientFinished atomic.Bool + serverFinished atomic.Bool + errorOccurred atomic.Bool + stopped atomic.Bool + cleaned atomic.Bool + cleanupTimer *time.Timer + clientTimer *time.Timer + serverTimer *time.Timer + lastUpdateTime *time.Time + totalSize int64 + recentSpeed recentSpeed +} + +var zmodemOverAndOut = []byte("OO\x08\x08") +var zmodemCanNotOpenFile = []byte("cannot open ") +var zmodemCancelSequence = []byte("\x18\x18\x18\x18\x18\x18\x18\x18\x18\x18\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08") + +var zmodemInitRegexp = regexp.MustCompile(`\*\*\x18B0(0|1)[0-9a-f]{12}`) +var zmodemFinishRegexp = regexp.MustCompile(`\*\*\x18B08[0-9a-f]{12}`) + +func detectZmodem(buf []byte) *zmodemTransfer { + match := zmodemInitRegexp.FindSubmatch(buf) + if len(match) < 2 || bytes.Contains(buf, zmodemCancelSequence) || bytes.Contains(buf, zmodemCanNotOpenFile) { + return nil + } + if match[1][0] == '1' { + return &zmodemTransfer{upload: true} + } else { + return &zmodemTransfer{upload: false} + } +} + +func (z *zmodemTransfer) writeMessage(msg string) { + _ = writeAll(z.clientOut, []byte(fmt.Sprintf("\r\x1b[2K%s\r\n", msg))) +} + +func (z *zmodemTransfer) updateProgress(delta int) { + if delta < 0 { + _ = writeAll(z.clientOut, []byte("\r\n")) + return + } + z.totalSize += int64(delta) + now := timeNowFunc() + if z.lastUpdateTime != nil && now.Sub(*z.lastUpdateTime) < 200*time.Millisecond { + return + } + z.lastUpdateTime = &now + _ = writeAll(z.clientOut, []byte(fmt.Sprintf("\r\x1b[2KTransferred %s, Speed %s.", + convertSizeToString(float64(z.totalSize)), z.getSpeed(&now)))) +} + +func (z *zmodemTransfer) getSpeed(now *time.Time) string { + if z.recentSpeed.speedCnt == 0 { + z.recentSpeed.initFirstStep(now) + return "N/A" + } + + speed := z.recentSpeed.getSpeed(z.totalSize, now) + if speed < 0 { + return "N/A" + } + return convertSizeToString(speed) + "/s" +} + +func (z *zmodemTransfer) resetCleanupTimer() { + if z.cleanupTimer != nil { + z.cleanupTimer.Stop() + } + z.cleanupTimer = time.AfterFunc(500*time.Millisecond, func() { + z.cleaned.Store(true) + _, _ = z.serverIn.Write([]byte("\r")) // enter for shell prompt + }) +} + +func (z *zmodemTransfer) resetClientTimer() { + if !z.upload { + return + } + if z.clientTimer != nil { + z.clientTimer.Stop() + } + z.clientTimer = time.AfterFunc(20*time.Second, func() { + z.handleZmodemError("client timeout") + }) +} + +func (z *zmodemTransfer) resetServerTimer() { + if z.upload { + return + } + if z.serverTimer != nil { + z.serverTimer.Stop() + } + z.serverTimer = time.AfterFunc(20*time.Second, func() { + z.handleZmodemError("server timeout") + }) +} + +func (z *zmodemTransfer) isTransferringFiles() bool { + return !z.stopped.Load() || !z.cleaned.Load() +} + +func (z *zmodemTransfer) stopTransferringFiles() { + z.handleZmodemError("Stopped") +} + +func (z *zmodemTransfer) handleZmodemError(msg string) { + if !z.stopped.CompareAndSwap(false, true) { + return + } + + z.errorOccurred.Store(true) + + if z.logger != nil { + z.logger.writeTraceLog([]byte(msg), "debug") + } + + _ = writeAll(z.serverIn, zmodemCancelSequence) + + if cmd := z.cmd.Load(); cmd != nil { + _ = writeAll(z.stdin, zmodemCancelSequence) + z.ensureClientExit(cmd) + } + + z.writeMessage(msg) +} + +func (z *zmodemTransfer) handleServerOutput(buf []byte) bool { + if z.stopped.Load() { + if z.cleaned.Load() { + return false + } + z.resetCleanupTimer() + return true + } + + // forward server output to the client + if cmd := z.cmd.Load(); cmd != nil { + z.resetServerTimer() + if len(buf) < 50 && zmodemFinishRegexp.Match(buf) { + if z.serverFinished.CompareAndSwap(false, true) { + z.ensureOverAndOut() + } + } + err := writeAll(z.stdin, buf) + if err == nil && !z.upload { + z.updateProgress(len(buf)) + } + return true + } + + // server canceled before the client startup + if bytes.Contains(buf, zmodemCancelSequence) || bytes.Contains(buf, zmodemCanNotOpenFile) { + z.cleaned.Store(true) + z.stopped.Store(true) + return false + } + + // skip it and wait for the client to start + return true +} + +func (z *zmodemTransfer) handleZmodemStream(cmd *exec.Cmd) { + if z.logger != nil { + z.logger.writeTraceLog([]byte("zmodem begin"), "debug") + } + z.cmd.Store(cmd) + z.resetClientTimer() + z.resetServerTimer() + + // async check if the client has exited + go z.checkClientExited(cmd) + + // forward client output to the server + buffer := make([]byte, 32*1024) + for { + n, err := z.stdout.Read(buffer) + z.resetClientTimer() + if n > 0 { + buf := buffer[:n] + if z.logger != nil { + z.logger.writeTraceLog(buf, "zmodem") + } + if z.errorOccurred.Load() || z.serverFinished.Load() && z.clientFinished.Load() { + if z.logger != nil { + z.logger.writeTraceLog([]byte("ignore zmodem output"), "debug") + } + break + } + if len(buf) < 50 && zmodemFinishRegexp.Match(buf) { + if z.clientFinished.CompareAndSwap(false, true) { + z.ensureOverAndOut() + } + } + if err := writeAll(z.serverIn, buf); err != nil { + z.handleZmodemError(fmt.Sprintf("write to server failed: %v", err)) + break + } + if z.upload { + z.updateProgress(n) + } + } + if err == io.EOF { + break + } + if err != nil { + z.handleZmodemError(fmt.Sprintf("read from client failed: %v", err)) + break + } + } + + if z.clientTimer != nil { + z.clientTimer.Stop() + } + + z.ensureClientExit(cmd) +} + +func (z *zmodemTransfer) ensureOverAndOut() { + if !z.serverFinished.Load() || !z.clientFinished.Load() { + return + } + if z.logger != nil { + z.logger.writeTraceLog([]byte("Over and Out"), "debug") + } + if z.upload { + _ = writeAll(z.serverIn, zmodemOverAndOut) + } else { + _ = writeAll(z.stdin, zmodemOverAndOut) + } +} + +func (z *zmodemTransfer) ensureClientExit(cmd *exec.Cmd) { + go func() { + time.Sleep(500 * time.Millisecond) + _ = cmd.Process.Kill() + }() +} + +func (z *zmodemTransfer) checkClientExited(cmd *exec.Cmd) { + _ = cmd.Wait() + z.stopped.Store(true) + + if z.serverTimer != nil { + z.serverTimer.Stop() + } + + z.updateProgress(-1) // -1 means finished + + if z.logger != nil { + z.logger.writeTraceLog([]byte("zmodem end"), "debug") + } + + if code := cmd.ProcessState.ExitCode(); code != 0 { + z.writeMessage(fmt.Sprintf("client exit with %d", code)) + } else { + z.writeMessage("\033[1;32mSuccess!!\033[0m") + } + + z.resetCleanupTimer() + + // make sure the server exit + _ = writeAll(z.serverIn, zmodemCancelSequence) +} + +func (z *zmodemTransfer) launchZmodemCmd(dir, name string, args ...string) (*exec.Cmd, error) { + var err error + cmd := exec.Command(name, args...) + cmd.Dir = dir + z.stdin, err = cmd.StdinPipe() + if err != nil { + return nil, err + } + z.stdout, err = cmd.StdoutPipe() + if err != nil { + return nil, err + } + if err := cmd.Start(); err != nil { + return nil, err + } + return cmd, nil +} + +func (z *zmodemTransfer) uploadFiles(files []string) { + workDir := "" + for i := 0; i < len(files); i++ { + fileDir := filepath.Dir(files[i]) + if i == 0 { + workDir = fileDir + } + if fileDir == workDir { + files[i] = filepath.Base(files[i]) + } + } + cmd, err := z.launchZmodemCmd(workDir, "sz", append([]string{"-e", "-b", "-B", "32768"}, files...)...) + if err != nil { + z.handleZmodemError(fmt.Sprintf("run sz client failed: %v", err)) + return + } + z.handleZmodemStream(cmd) +} + +func (z *zmodemTransfer) downloadFiles(path string) { + cmd, err := z.launchZmodemCmd(path, "rz", "-E", "-e", "-b", "-B", "32768") + if err != nil { + z.handleZmodemError(fmt.Sprintf("run rz client failed: %v", err)) + return + } + z.handleZmodemStream(cmd) +} + +func (z *zmodemTransfer) handleZmodemEvent(logger *traceLogger, serverIn io.Writer, clientOut io.Writer, + chooseUploadFiles func() ([]string, error), chooseDownloadPath func() (string, error)) { + z.logger = logger + z.serverIn = serverIn + z.clientOut = clientOut + + // the server may fail immediately + time.Sleep(100 * time.Millisecond) + if z.stopped.Load() { + return + } + + if z.upload { + files, err := chooseUploadFiles() + if err != nil { + z.handleZmodemError(err.Error()) + return + } + z.uploadFiles(files) + } else { + path, err := chooseDownloadPath() + if err != nil { + z.handleZmodemError(err.Error()) + return + } + z.downloadFiles(path) + } +} diff --git a/trzsz/zmodem_test.go b/trzsz/zmodem_test.go new file mode 100644 index 0000000..6085c62 --- /dev/null +++ b/trzsz/zmodem_test.go @@ -0,0 +1,70 @@ +/* +MIT License + +Copyright (c) 2023 Lonny Wong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +package trzsz + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDetectZmodem(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + upload := true + download := false + assertDetectZmodem := func(buf string, expect *bool) { + t.Helper() + zmodem := detectZmodem([]byte(buf)) + if expect == nil { + assert.Nil(zmodem) + } else { + require.NotNil(zmodem) + assert.Equal(*expect, zmodem.upload) + } + } + + assertDetectZmodem("**\x18B0100000023be50", &upload) + assertDetectZmodem("**\x18B0100000063f694", &upload) + assertDetectZmodem("**\x18B00000000000000", &download) + assertDetectZmodem("rz\x0d**\x18B00000000000000", &download) + + assertDetectZmodem("**\x18B0100000023be50\x0d\x8a\x11", &upload) + assertDetectZmodem("**\x18B0100000063f694\x0d\x8a\x11", &upload) + assertDetectZmodem("**\x18B00000000000000\x0d\x8a\x11", &download) + assertDetectZmodem("rz\x0d**\x18B00000000000000\x0d\x8a\x11", &download) + + assertDetectZmodem("**\x19B0100000023be50\x0d\x8a\x11", nil) + assertDetectZmodem("**\x18B0100000023BE50\x0d\x8a\x11", nil) + assertDetectZmodem(" *\x18B0100000023be50\x0d\x8a\x11", nil) + assertDetectZmodem("**\x18B0100000023be5", nil) + + assertDetectZmodem("**\x18B0100000023be50"+string(zmodemCanNotOpenFile), nil) + assertDetectZmodem("**\x18B0100000063f694\x0d\x8a\x11"+string(zmodemCanNotOpenFile), nil) + assertDetectZmodem("**\x18B0100000023be50"+string(zmodemCancelSequence), nil) + assertDetectZmodem("**\x18B0100000063f694\x0d\x8a\x11"+string(zmodemCancelSequence), nil) +}