Skip to content

Commit

Permalink
Mock out all connection stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
bartv committed Jun 24, 2024
1 parent 04bd445 commit 75ec544
Showing 1 changed file with 24 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest import mock
import asyncio

from asyncpg import Connection, cursor
from asyncpg import Connection, cursor, Record
from wrapt import ObjectProxy
import pytest

Expand Down Expand Up @@ -60,30 +60,43 @@ def assert_wrapped(assert_fnc):
def test_cursor_span_creation(self):
""" Test the cursor wrapper if it creates spans correctly.
"""
async def mock_fn(*args, **kwargs):
pass
# Mock out all interaction with postgres
async def bind_mock(*args, **kwargs):
return []

async def mock_fn_stop(*args, **kwargs):
raise StopAsyncIteration()
async def exec_mock(*args, **kwargs):
return [], None, True

cursor_mock = mock.Mock()
cursor_mock._query = "SELECT * FROM test"
conn = mock.Mock()
conn.is_closed = lambda: False

conn._protocol = mock.Mock()
conn._protocol.bind = bind_mock
conn._protocol.execute = exec_mock
conn._protocol.bind_execute = exec_mock
conn._protocol.close_portal = bind_mock

state = mock.Mock()
state.closed = False

apg = AsyncPGInstrumentor()
apg.instrument(tracer_provider=self.tracer_provider)

# Call the wrapper function directly. They only way to be able to do this on the real classes is to mock all of the
# methods. In that case we are only testing the instrumentation of mocked functions. This makes it explicit.
asyncio.run(apg._do_cursor_execute(mock_fn, cursor_mock, [], {}))
# init the cursor and fetch a single record
crs = cursor.Cursor(conn, "SELECT * FROM test", state, [], Record)
asyncio.run(crs._init(1))
asyncio.run(crs.fetch(1))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].name, "CURSOR: SELECT")
self.assertTrue(spans[0].status.is_ok)

# Now test that the StopAsyncIteration of the cursor does not get recorded as an ERROR
crs_iter = cursor.CursorIterator(conn, "SELECT * FROM test", state, [], Record, 1, 1)

with pytest.raises(StopAsyncIteration):
asyncio.run(apg._do_cursor_execute(mock_fn_stop, cursor_mock, [], {}))
asyncio.run(anext(crs_iter))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 2)
Expand Down

0 comments on commit 75ec544

Please sign in to comment.