diff --git a/dockerpty/pty.py b/dockerpty/pty.py index 25cb788..969fbd8 100644 --- a/dockerpty/pty.py +++ b/dockerpty/pty.py @@ -137,7 +137,9 @@ def start(self, sockets=None, **kwargs): pumps = [] if pty_stdin and self.interactive: - pumps.append(io.Pump(io.Stream(self.stdin), pty_stdin, wait_for_output=False)) + # if stdin isn't a TTY then this is probably an SSH session + # so wait for the EOF to happen before considering the Pump closed. + pumps.append(io.Pump(io.Stream(self.stdin), pty_stdin, wait_for_output=not sys.stdin.isatty())) if pty_stdout: pumps.append(io.Pump(pty_stdout, io.Stream(self.stdout), propagate_close=False)) @@ -360,11 +362,13 @@ def resize(self, size=None): def _hijack_tty(self, pumps): with tty.Terminal(self.operation.stdin, raw=self.operation.israw()): self.resize() - while True: + keep_running = True + stdin_stream = self._get_stdin_pump(pumps) + while keep_running: read_pumps = [p for p in pumps if not p.eof] write_streams = [p.to_stream for p in pumps if p.to_stream.needs_write()] - read_ready, write_ready = io.select(read_pumps, write_streams, timeout=60) + read_ready, write_ready = io.select(read_pumps, write_streams, timeout=2) try: for write_stream in write_ready: write_stream.do_write() @@ -372,9 +376,29 @@ def _hijack_tty(self, pumps): for pump in read_ready: pump.flush() - if all([p.is_done() for p in pumps]): + if sys.stdin.isatty(): + if all([p.is_done() for p in pumps]): + keep_running = False + elif stdin_stream.is_done(): + # If stdin isn't a TTY, this is probably an SSH session. + # The most common use case for an SSH session without a + # TTY is SCP/SFTP; like, someone coping a file to a remote + # server. Those file transfer clients mark the end of + # the session by sending an empty packet, then waiting + # for the TCP session to terminate, We need to break out + # of this loop to return control to the calling application + # so it can tear down the SCP/SFTP process running inside + # the container. + keep_running = False break except SSLError as e: if 'The operation did not complete' not in e.strerror: raise e + + @staticmethod + def _get_stdin_pump(pumps): + """Find the Pump connected to stdin, and return it""" + for pump in pumps: + if pump.from_stream.fd.name == '': + return pump