Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore sndrcv behaviour from before 53afe84 #4538

Merged
merged 11 commits into from
Sep 24, 2024
Merged
41 changes: 21 additions & 20 deletions scapy/sendrecv.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class debug:
Automatically enabled when a generator is passed as the packet
:param _flood:
:param threaded: if True, packets are sent in a thread and received in another.
defaults to False.
Defaults to True.
:param session: a flow decoder used to handle stream of packets
:param chainEX: if True, exceptions during send will be forwarded
:param stop_filter: Python function applied to each packet to determine if
Expand Down Expand Up @@ -128,7 +128,7 @@ def __init__(self,
rcv_pks=None, # type: Optional[SuperSocket]
prebuild=False, # type: bool
_flood=None, # type: Optional[_FloodGenerator]
threaded=False, # type: bool
threaded=True, # type: bool
session=None, # type: Optional[_GlobSessionType]
chainEX=False, # type: bool
stop_filter=None # type: Optional[Callable[[Packet], bool]]
Expand Down Expand Up @@ -158,7 +158,7 @@ def __init__(self,
self.noans = 0
self._flood = _flood
self.threaded = threaded
self.breakout = False
self.breakout = Event()
# Instantiate packet holders
if prebuild and not self._flood:
self.tobesent = list(pkt) # type: _PacketIterable
Expand All @@ -174,6 +174,7 @@ def __init__(self,
self.timeout = None

while retry >= 0:
self.breakout.clear()
self.hsent = {} # type: Dict[bytes, List[Packet]]

if threaded or self._flood:
Expand All @@ -190,7 +191,7 @@ def __init__(self,
except KeyboardInterrupt as ex:
interrupted = ex

self.breakout = True
self.breakout.set()

# Ended. Let's close gracefully
if self._flood:
Expand Down Expand Up @@ -251,28 +252,33 @@ def results(self):
# type: () -> Tuple[SndRcvList, PacketList]
return self.ans_result, self.unans_result

def _stop_sniffer_if_done(self) -> None:
"""Close the sniffer if all expected answers have been received"""
if self._send_done and self.noans >= self.notans and not self.multi:
if self.sniffer and self.sniffer.running:
self.sniffer.stop(join=False)

def _sndrcv_snd(self):
# type: () -> None
"""Function used in the sending thread of sndrcv()"""
i = 0
p = None
try:
if self.verbose:
print("Begin emission:")
os.write(1, b"Begin emission\n")
for p in self.tobesent:
# Populate the dictionary of _sndrcv_rcv
# _sndrcv_rcv won't miss the answer of a packet that
# has not been sent
self.hsent.setdefault(p.hashret(), []).append(p)
# Send packet
self.pks.send(p)
if self.inter:
time.sleep(self.inter)
if self.breakout:
time.sleep(self.inter)
if self.breakout.is_set():
break
i += 1
if self.verbose:
print("Finished sending %i packets." % i)
os.write(1, b"\nFinished sending %i packets\n" % i)
except SystemExit:
pass
except Exception:
Expand All @@ -291,13 +297,10 @@ def _sndrcv_snd(self):
elif not self._send_done:
self.notans = i
self._send_done = True
# In threaded mode, timeout.
if self.threaded and self.timeout is not None and not self.breakout:
t = time.monotonic() + self.timeout
while time.monotonic() < t:
polybassa marked this conversation as resolved.
Show resolved Hide resolved
if self.breakout:
break
time.sleep(0.1)
self._stop_sniffer_if_done()
# In threaded mode, timeout
if self.threaded and self.timeout is not None and not self.breakout.is_set():
self.breakout.wait(timeout=self.timeout)
if self.sniffer and self.sniffer.running:
self.sniffer.stop()

Expand All @@ -324,9 +327,7 @@ def _process_packet(self, r):
self.noans += 1
sentpkt._answered = 1
break
if self._send_done and self.noans >= self.notans and not self.multi:
polybassa marked this conversation as resolved.
Show resolved Hide resolved
if self.sniffer and self.sniffer.running:
self.sniffer.stop(join=False)
self._stop_sniffer_if_done()
if not ok:
if self.verbose > 1:
os.write(1, b".")
Expand All @@ -342,7 +343,7 @@ def _sndrcv_rcv(self, callback):
self.sniffer = AsyncSniffer()
self.sniffer._run(
prn=self._process_packet,
timeout=None if self.threaded else self.timeout,
timeout=None if self.threaded and not self._flood else self.timeout,
store=False,
opened_socket=self.rcv_pks,
session=self.session,
Expand Down
36 changes: 26 additions & 10 deletions test/contrib/automotive/doip.uts
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ import tempfile
= Test DoIPSocket

server_up = threading.Event()
sniff_up = threading.Event()
def server():
buffer = b'\x02\xfd\x80\x02\x00\x00\x00\x05\x00\x00\x00\x00\x00\x02\xfd\x80\x01\x00\x00\x00\n\x10\x10\x0e\x80P\x03\x002\x01\xf4'
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -426,6 +427,7 @@ def server():
sock.listen(1)
server_up.set()
connection, address = sock.accept()
sniff_up.wait(timeout=1)
connection.send(buffer)
connection.close()
finally:
Expand All @@ -437,7 +439,7 @@ server_thread.start()
server_up.wait(timeout=1)
sock = DoIPSocket(activate_routing=False)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2

Expand All @@ -446,6 +448,7 @@ assert len(pkts) == 2
~ linux

server_up = threading.Event()
sniff_up = threading.Event()
def server():
buffer = b'\x02\xfd\x80\x02\x00\x00\x00\x05\x00\x00\x00\x00\x00\x02\xfd\x80\x01\x00\x00\x00\n\x10\x10\x0e\x80P\x03\x002\x01\xf4'
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -456,6 +459,7 @@ def server():
sock.listen(1)
server_up.set()
connection, address = sock.accept()
sniff_up.wait(timeout=1)
for i in range(len(buffer)):
connection.send(buffer[i:i+1])
time.sleep(0.01)
Expand All @@ -469,13 +473,14 @@ server_thread.start()
server_up.wait(timeout=1)
sock = DoIPSocket(activate_routing=False)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2

= Test DoIPSocket 3

server_up = threading.Event()
sniff_up = threading.Event()
def server():
buffer = b'\x02\xfd\x80\x02\x00\x00\x00\x05\x00\x00\x00\x00\x00\x02\xfd\x80\x01\x00\x00\x00\n\x10\x10\x0e\x80P\x03\x002\x01\xf4'
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -486,6 +491,7 @@ def server():
sock.listen(1)
server_up.set()
connection, address = sock.accept()
sniff_up.wait(timeout=1)
while buffer:
randlen = random.randint(0, len(buffer))
connection.send(buffer[:randlen])
Expand All @@ -501,14 +507,15 @@ server_thread.start()
server_up.wait(timeout=1)
sock = DoIPSocket(activate_routing=False)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2


= Test DoIPSocket6

server_up = threading.Event()
sniff_up = threading.Event()
def server():
buffer = b'\x02\xfd\x80\x02\x00\x00\x00\x05\x00\x00\x00\x00\x00\x02\xfd\x80\x01\x00\x00\x00\n\x10\x10\x0e\x80P\x03\x002\x01\xf4'
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
Expand All @@ -519,6 +526,7 @@ def server():
sock.listen(1)
server_up.set()
connection, address = sock.accept()
sniff_up.wait(timeout=1)
connection.send(buffer)
connection.close()
finally:
Expand All @@ -530,7 +538,7 @@ server_thread.start()
server_up.wait(timeout=1)
sock = DoIPSocket(ip="::1", activate_routing=False)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2

Expand Down Expand Up @@ -604,6 +612,7 @@ def _load_certificate_chain(context) -> None:


server_up = threading.Event()
sniff_up = threading.Event()
def server():
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
_load_certificate_chain(context)
Expand All @@ -619,6 +628,7 @@ def server():
ssock.listen(1)
server_up.set()
connection, address = ssock.accept()
sniff_up.wait(timeout=1)
connection.send(buffer)
connection.close()
finally:
Expand All @@ -633,14 +643,15 @@ context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
sock = DoIPSocket(activate_routing=False, force_tls=True, context=context)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2

= Test DoIPSslSocket6
~ broken_windows

server_up = threading.Event()
sniff_up = threading.Event()
def server():
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
_load_certificate_chain(context)
Expand All @@ -656,6 +667,7 @@ def server():
ssock.listen(1)
server_up.set()
connection, address = ssock.accept()
sniff_up.wait(timeout=1)
connection.send(buffer)
connection.close()
finally:
Expand All @@ -670,14 +682,15 @@ context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
sock = DoIPSocket(ip="::1", activate_routing=False, force_tls=True, context=context)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2

= Test UDS_DoIPSslSocket6
~ broken_windows

server_up = threading.Event()
sniff_up = threading.Event()
def server():
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
_load_certificate_chain(context)
Expand All @@ -693,6 +706,7 @@ def server():
ssock.listen(1)
server_up.set()
connection, address = ssock.accept()
sniff_up.wait(timeout=1)
connection.send(buffer)
connection.close()
finally:
Expand All @@ -707,15 +721,16 @@ context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
sock = UDS_DoIPSocket(ip="::1", activate_routing=False, force_tls=True, context=context)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2

= Test UDS_DualDoIPSslSocket6
~ broken_windows
~ broken_windows not_pypy

server_tcp_up = threading.Event()
server_tls_up = threading.Event()
sniff_up = threading.Event()
def server_tls():
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
_load_certificate_chain(context)
Expand All @@ -732,6 +747,7 @@ def server_tls():
ssock.listen(1)
server_tls_up.set()
connection, address = ssock.accept()
sniff_up.wait(timeout=1)
connection.send(buffer)
connection.close()
finally:
Expand All @@ -748,7 +764,7 @@ def server_tcp():
server_tcp_up.set()
connection, address = sock.accept()
connection.send(buffer)
connection.shutdown()
connection.shutdown(socket.SHUT_RDWR)
connection.close()
finally:
sock.close()
Expand All @@ -767,7 +783,7 @@ context.verify_mode = ssl.CERT_NONE

sock = UDS_DoIPSocket(ip="::1", context=context)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_tcp_thread.join(timeout=1)
server_tls_thread.join(timeout=1)
assert len(pkts) == 2
12 changes: 12 additions & 0 deletions test/contrib/automotive/scanner/enumerator.uts
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,19 @@ class MockISOTPSocket(SuperSocket):
return len(sx)
@staticmethod
def select(sockets, remain=None):
time.sleep(0)
return sockets
def sr(self, *args, **kargs):
from scapy import sendrecv
return sendrecv.sndrcv(self, *args, threaded=False, **kargs)
def sr1(self, *args, **kargs):
from scapy import sendrecv
ans = sendrecv.sndrcv(self, *args, threaded=False, **kargs)[0] # type: SndRcvList
if len(ans) > 0:
pkt = ans[0][1] # type: Packet
return pkt
else:
return None

sock = MockISOTPSocket()
sock.rcvd_queue.put(b"\x41")
Expand Down
6 changes: 3 additions & 3 deletions test/regression.uts
Original file line number Diff line number Diff line change
Expand Up @@ -1832,7 +1832,7 @@ sck = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
ssck = StreamSocket(sck)

try:
r = ssck.sr1(ICMP(type='echo-request'), timeout=0.1, chainEX=True)
r = ssck.sr1(ICMP(type='echo-request'), timeout=0.1, chainEX=True, threaded=False)
assert False
except Exception:
assert True
Expand Down Expand Up @@ -2132,7 +2132,7 @@ retry_test(_test)
~ netaccess needs_root IP ICMP
def _test():
packet = IP(dst="8.8.8.8")/ICMP()
r = srflood(packet, timeout=2)
r = srflood(packet, timeout=0.5)
assert packet.sent_time is not None

retry_test(_test)
Expand All @@ -2142,7 +2142,7 @@ retry_test(_test)
def _test():
packet1 = IP(dst="8.8.8.8")/ICMP()
packet2 = IP(dst="8.8.4.4")/ICMP()
r = srflood([packet1, packet2], timeout=2)
r = srflood([packet1, packet2], timeout=0.5)
assert packet1.sent_time is not None
assert packet2.sent_time is not None

Expand Down
Loading
Loading