Skip to content

Commit

Permalink
Implement process class
Browse files Browse the repository at this point in the history
  • Loading branch information
ptr-yudai committed Apr 23, 2024
1 parent a929ae8 commit 70b130a
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 42 deletions.
47 changes: 32 additions & 15 deletions ptrlib/connection/proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,8 @@ def __init__(self,
args = list(map(bytes2str, args))

# Prepare stdio
if raw:
# TODO
pass
else:
master = self._slave = None
if not raw:
master, self._slave = pty.openpty()
tty.setraw(master)
tty.setraw(self._slave)
Expand Down Expand Up @@ -118,7 +116,7 @@ def __init__(self,
self._current_timeout = self._default_timeout

# Duplicate master
if not raw and master is not None:
if master is not None:
self._proc.stdout = os.fdopen(os.dup(master), 'r+b', 0)
os.close(master)

Expand Down Expand Up @@ -153,8 +151,13 @@ def _recv_impl(self, size: int) -> bytes:
Returns:
bytes: The received data
"""
if self._current_timeout == 0:
timeout = None
else:
timeout = self._current_timeout

ready, [], [] = select.select(
[self._proc.stdout.fileno()], [], [], self._current_timeout
[self._proc.stdout.fileno()], [], [], timeout
)
if len(ready) == 0:
raise TimeoutError("Timeout (_recv_impl)", b'') from None
Expand Down Expand Up @@ -196,14 +199,16 @@ def _shutdown_send_impl(self):
def _close_impl(self):
"""Close process
"""
self._proc.stdin.close()
self._proc.stdout.close()
if self._is_alive_impl():
self._proc.kill()
self._proc.wait()
logger.info(f"{str(self)} killed")
else:
logger.info(f"{str(self)} has already exited")
logger.info(f"{str(self)} killed by `close`")

if self._slave is not None: # PTY mode
os.close(self._slave)

self._proc.stdin.close()
self._proc.stdout.close()

def _is_alive_impl(self) -> bool:
"""Check if the process is alive"""
Expand All @@ -216,7 +221,6 @@ def __str__(self) -> str:
#
# Custom method
#
@tube_is_open
def poll(self) -> Optional[int]:
"""Check if the process has exited
"""
Expand All @@ -228,12 +232,25 @@ def poll(self) -> Optional[int]:
self._returncode = self._proc.returncode
name = signal_name(-self._returncode, detail=True)
if name:
name = '--> ' + name
logger.error(f"{str(self)} stopped with exit code " \
f"{self._returncode} {name}")
name = ' --> ' + name

logger_func = logger.info if self._returncode == 0 else logger.error
logger_func(f"{str(self)} stopped with exit code " \
f"{self._returncode}{name}")

return self._returncode

@tube_is_open
def wait(self, timeout: Optional[Union[int, float]]=None) -> int:
"""Wait until the process dies
Wait until the process exits and get the status code.
Returns:
code (int): Status code of the process
"""
return self._proc.wait(timeout)


Process = WinProcess if _is_windows else UnixProcess
process = Process # alias for the Process
11 changes: 8 additions & 3 deletions ptrlib/connection/sock.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,19 @@ def _recv_impl(self, size: int) -> bytes:
# NOTE: We cannot rely on the blocking behavior of `recv`
# because the socket might be non-blocking mode
# due to `_is_alive_impl` on multi-thread environment.
ready, [], [] = select.select(
[self._sock], [], [], self._current_timeout
)
if self._current_timeout == 0:
timeout = None
else:
timeout = self._current_timeout

ready, [], [] = select.select([self._sock], [], [], timeout)
if len(ready) == 0:
raise TimeoutError("Timeout (_recv_impl)", b'') from None

try:
data = self._sock.recv(size)
if len(data) == 0:
raise ConnectionResetError("Empty reply") from None

except BlockingIOError:
# NOTE: This exception can occur if this method is called
Expand Down
42 changes: 31 additions & 11 deletions ptrlib/connection/tube.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def recvuntil(self,
for i, d in enumerate(delim):
assert isinstance(d, (str, bytes)), \
f"`delim[{i}]` must be either str or bytes"
delim[i] = str2bytes(delim)
delim[i] = str2bytes(delim[i])
else:
delim = [str2bytes(delim)]

Expand All @@ -282,6 +282,9 @@ def recvuntil(self,
data += self.recv(size, timeout)
except TimeoutError as err:
raise TimeoutError("Timeout (recvuntil)", data + err.args[1])
except Exception as err:
err.args = (err.args[0], data)
raise err from None

for d in delim:
if d in data[max(0, prev_len-len(d)):]:
Expand Down Expand Up @@ -549,16 +552,32 @@ def sendafter(self,
tube.sendafter("command: ", 1) # b"1" is sent
```
"""
assert isinstance(data, (int, float, str, bytes)), \
"`data` must be int, float, str, or bytes"
recv_data = self.recvuntil(delim, size, timeout, drop, lookahead)
self.send(data)

if isinstance(data, (int, float)):
data = str(data).encode()
else:
data = str2bytes(data)
return recv_data

def sendlineafter(self,
delim: Union[str, bytes],
data: Union[str, bytes, int],
size: int=4096,
timeout: Optional[Union[int, float]]=None,
drop: bool=False,
lookahead: bool=False) -> bytes:
"""Send raw data after a delimiter
Send raw data with newline after `delim` is received.
Args:
delim (bytes): The delimiter
data (bytes) : Data to send
timeout (int): Timeout (in second)
Returns:
bytes: Received bytes before `delim` comes.
"""
recv_data = self.recvuntil(delim, size, timeout, drop, lookahead)
self.send(data)
self.sendline(data, timeout=timeout)

return recv_data

Expand Down Expand Up @@ -679,9 +698,6 @@ def thread_send(flag: threading.Event):
except (ConnectionResetError, ConnectionAbortedError, OSError):
flag.set()

# Disable timeout
self.settimeout(0)

flag = threading.Event()
th_recv = threading.Thread(target=thread_recv, args=(flag,))
th_send = threading.Thread(target=thread_send, args=(flag,))
Expand Down Expand Up @@ -776,6 +792,10 @@ def __exit__(self, _e_type, _e_value, _traceback):
def __str__(self) -> str:
return "<unknown tube>"

def __del__(self):
if not self._is_closed:
self.close()

#
# Abstract methods
#
Expand Down
6 changes: 3 additions & 3 deletions tests/connection/test_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_basic(self):
with self.assertLogs(module_name) as cm:
p = Process("./tests/test.bin/test_echo.x64")
self.assertEqual(len(cm.output), 1)
self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$')
self.assertEqual(cm.output[0], f'INFO:{module_name}:Successfully created new process {str(p)}')

# sendline / recvline
p.sendline(b"Message : " + msg)
Expand Down Expand Up @@ -70,15 +70,15 @@ def test_basic(self):
with self.assertLogs(module_name) as cm:
p.close()
self.assertEqual(len(cm.output), 1)
self.assertRegex(cm.output[0], fr'^INFO:{module_name}:.+ \(PID=\d+\) has already exited$')
self.assertEqual(cm.output[0], fr'INFO:{module_name}:{str(p)} stopped with exit code {p.poll()}')

def test_timeout(self):
module_name = inspect.getmodule(Process).__name__

with self.assertLogs(module_name) as cm:
p = Process("./tests/test.bin/test_echo.x64")
self.assertEqual(len(cm.output), 1)
self.assertRegex(cm.output[0], fr'^INFO:{module_name}:Successfully created new process \(PID=\d+\)$')
self.assertEqual(cm.output[0], fr'INFO:{module_name}:Successfully created new process {str(p)}')
data = os.urandom(16).hex()

# recv
Expand Down
22 changes: 12 additions & 10 deletions tests/connection/test_sock.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,20 @@ def test_timeout(self):
sock = Socket("www.example.com", 80)
sock.sendline(b'GET / HTTP/1.1\r')
sock.send(b'Host: www.example.com\r\n\r\n')
try:

with self.assertRaises(TimeoutError) as cm:
sock.recvuntil("*** never expected ***", timeout=2)
result = False
except TimeoutError as err:
self.assertEqual(b"200 OK" in err.args[1], True)
result = True
except:
result = False
finally:
sock.close()
self.assertEqual(b"200 OK" in cm.exception.args[1], True)

self.assertEqual(result, True)
def test_reset(self):
sock = Socket("www.example.com", 80)
sock.sendline(b'GET / HTTP/1.1\r')
sock.send(b'Host: www.example.com\r\n')
sock.send(b'Connection: close\r\n\r\n')

with self.assertRaises(ConnectionResetError) as cm:
sock.recvuntil("*** never expected ***", timeout=2)
self.assertEqual(b"200 OK" in cm.exception.args[1], True)

def test_tls(self):
host = "www.example.com"
Expand Down

0 comments on commit 70b130a

Please sign in to comment.