Skip to content

Commit

Permalink
Fix timeout for recvonce and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
ptr-yudai committed Jan 6, 2024
1 parent f4a5d3d commit da58f18
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 8 deletions.
15 changes: 11 additions & 4 deletions ptrlib/connection/tube.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ def _recv(self, size: int, timeout: Union[int, float]) -> Optional[bytes]:
pass

def unget(self, data: Union[str, bytes]):
"""Revert data to socket
Return data to socket.
Args:
data (bytes): Data to return
"""
if isinstance(data, str):
data = str2bytes(data)
self.buf = data + self.buf
Expand Down Expand Up @@ -100,10 +107,10 @@ def recvonce(self, size: int, timeout: Optional[Union[int, float]]=None) -> byte
timer_start = time.time()

while len(data) < size:
if timeout is not None and time.time() - timer_start > timeout:
raise TimeoutError("`recvonce` timeout", data)

data += self.recv(size - len(data), timeout=-1)
try:
data += self.recv(size - len(data), timeout=-1)
except TimeoutError as err:
raise TimeoutError("`recvonce` timeout", data + err.args[1])
time.sleep(0.01)

if len(data) > size:
Expand Down
50 changes: 46 additions & 4 deletions tests/connection/test_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_basic(self):

p = Process("./tests/test.bin/test_echo.x64")

# send / recv
# sendline / recvline
p.sendline(b"Message : " + msg)
self.assertEqual(p.recvlineafter(" : "), msg)

Expand All @@ -41,6 +41,13 @@ def test_basic(self):
self.assertEqual(int(r[0], 16), a)
self.assertEqual(int(r[1], 16), b)

# sendlineafter
a, b = os.urandom(16).hex(), os.urandom(16).hex()
p.sendline(a)
v = p.sendlineafter(a + "\n", b)
self.assertEqual(v.strip(), a.encode())
self.assertEqual(p.recvline().strip(), b.encode())

# shutdown
p.send(msg[::-1])
p.shutdown('write')
Expand All @@ -54,16 +61,51 @@ def test_basic(self):
def test_timeout(self):
p = Process("./tests/test.bin/test_echo.x64")
data = os.urandom(16).hex()

# recv
try:
p.recv(timeout=0.5)
result = False
except TimeoutError as err:
self.assertEqual(err.args[1], b"")
result = True
except:
result = False
self.assertEqual(result, True)

# recvonce
p.sendline(data)
try:
p.recvuntil("*** never expected ***", timeout=1)
p.recvonce(len(data) + 1 + 1, timeout=0.5)
result = False
except TimeoutError as err:
self.assertEqual(err.args[1].decode().strip(), data)
result = True
except:
result = False
finally:
p.close()
self.assertEqual(result, True)

# recvuntil
p.sendline(data)
try:
p.recvuntil("*** never expected ***", timeout=0.5)
result = False
except TimeoutError as err:
self.assertEqual(err.args[1].decode().strip(), data)
result = True
except:
result = False
self.assertEqual(result, True)

# sendlineafter
a, b = os.urandom(16).hex(), os.urandom(16).hex()
p.sendline(a)
try:
p.sendlineafter(b"neko", b, timeout=0.5)
except TimeoutError as err:
self.assertEqual(err.args[1].decode().strip(), a)
result = True
except:
result = False
self.assertEqual(result, True)

0 comments on commit da58f18

Please sign in to comment.