From 0764716be7d4164665f47138a8a7bc8490bb6563 Mon Sep 17 00:00:00 2001 From: Andrew Moffat Date: Sat, 25 Apr 2020 08:09:11 -0700 Subject: [PATCH] support empty _in param. closes #427 --- sh.py | 15 ++++++++------- test.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/sh.py b/sh.py index 2f298597..dd1203f9 100644 --- a/sh.py +++ b/sh.py @@ -2119,7 +2119,7 @@ def __init__(self, command, parent_log, cmd, stdin, stdout, stderr, # to prevent race conditions self.exit_code = None - self.stdin = stdin or Queue() + self.stdin = stdin # _pipe_queue is used internally to hand off stdout from one process # to another. by default, all stdout from a process gets dumped @@ -2152,15 +2152,10 @@ def __init__(self, command, parent_log, cmd, stdin, stdout, stderr, attr[3] &= ~termios.ECHO termios.tcsetattr(self._stdin_parent_fd, termios.TCSANOW, attr) - # we're only going to create a stdin thread iff we have potential - # for stdin to come in. this would be through a stdout callback or - # through an object we've passed in for stdin - potentially_has_input = callable(stdout) or stdin - # this represents the connection from a Queue object (or whatever # we're using to feed STDIN) to the process's STDIN fd self._stdin_stream = None - if self._stdin_parent_fd and potentially_has_input: + if self._stdin_parent_fd: log = self.log.get_child("streamwriter", "stdin") self._stdin_stream = StreamWriter(log, self._stdin_parent_fd, self.stdin, ca["in_bufsize"], ca["encoding"], @@ -2651,6 +2646,12 @@ def determine_how_to_read_input(input_obj): log_msg = "generator" get_chunk = get_iter_chunk_reader(iter(input_obj)) + elif input_obj is None: + log_msg = "None" + def raise_(): + raise DoneReadingForever + get_chunk = raise_ + else: try: it = iter(input_obj) diff --git a/test.py b/test.py index 06146d55..b5588b09 100644 --- a/test.py +++ b/test.py @@ -287,6 +287,17 @@ def test_number_arg(self): out = python(py.name, 3).strip() self.assertEqual(out, "3") + def test_empty_stdin_no_hang(self): + py = create_tmp_test(""" +import sys +data = sys.stdin.read() +sys.stdout.write("no hang") +""") + out = python(py.name, _in="", _timeout=2) + self.assertEqual(out, "no hang") + + out = python(py.name, _in=None, _timeout=2) + self.assertEqual(out, "no hang") def test_exit_code(self): from sh import ErrorReturnCode