Skip to content

Commit

Permalink
Add typing information to mock_device (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
cottsay authored Dec 14, 2023
1 parent af94e4e commit 0e2c984
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions test/mock_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,31 @@
# Licensed under the Apache License, Version 2.0

import asyncio
from collections.abc import AsyncIterator
from collections.abc import Awaitable
from collections.abc import MutableMapping
from contextlib import asynccontextmanager
import sys
from typing import Optional
from typing import SupportsIndex
from typing import TypeVar
from typing import Union

import serial_asyncio


class _NoPopList(list):
_T = TypeVar('_T')

def pop(self, index):

class _NoPopList(list[_T]):

def pop(self, index: SupportsIndex = -1) -> _T:
return self[index]


DEFAULT_RESPONSES = {
_ResponseMapping = MutableMapping[bytes, Optional[Union[list[bytes], bytes]]]

DEFAULT_RESPONSES: _ResponseMapping = {
b'<Command>'
b'<Name>get_current_price</Name>'
b'<MeterMacId>0xFEDCBA9876543210</MeterMacId>'
Expand Down Expand Up @@ -116,7 +128,11 @@ def pop(self, index):
}


async def _device_loop(reader, writer, responses):
async def _device_loop(
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
responses: _ResponseMapping
) -> None:
buffer = b''
while True:
try:
Expand Down Expand Up @@ -148,7 +164,10 @@ async def _device_loop(reader, writer, responses):


@asynccontextmanager
async def mock_device(responses=None, initial_buffer=None):
async def mock_device(
responses: Optional[_ResponseMapping] = None,
initial_buffer: Optional[bytes] = None
) -> AsyncIterator[None]:
"""
Create a mock device at a TCP endpoint.
Expand All @@ -165,7 +184,10 @@ async def mock_device(responses=None, initial_buffer=None):

connections = []

def client_connected(reader, writer):
def client_connected(
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter
) -> Awaitable[None]:
if initial_buffer is not None:
writer.write(initial_buffer)
task = asyncio.create_task(_device_loop(reader, writer, responses))
Expand All @@ -179,14 +201,14 @@ def client_connected(reader, writer):
if not connections:
return

_, connections = await asyncio.wait(connections, timeout=0.1)
for task in connections:
_, pending = await asyncio.wait(connections, timeout=0.1)
for task in pending:
task.cancel()
if connections:
await asyncio.wait(connections)
if pending:
await asyncio.wait(pending)


async def main(argv=sys.argv):
async def main(argv: list[str] = sys.argv) -> None:
assert len(argv) == 2
reader, writer = await serial_asyncio.open_serial_connection(url=argv[1])
await _device_loop(reader, writer, DEFAULT_RESPONSES)
Expand Down

0 comments on commit 0e2c984

Please sign in to comment.