Skip to content

Commit

Permalink
Prefer AsyncMock to as_future (#1366)
Browse files Browse the repository at this point in the history
In Python, calling an async function `f()` can be done in two steps:

 1. call the function using regular function call `f()`
 2. `await` the result, possibly later

The driver tests took advantage of this detail to make `MagicMock` work with async functions by using `as_future`. With that trick, mocked functions return an awaitable result, which are then awaited in the tested code.

However, coroutines are best thought as an implementation detail, and it's better to always call an async function `f` using `await f()`, never separating the two steps mentioned above. Thankfully, Python 3.8 introduced `AsyncMock` that allows removing `as_future` by just specifying a return value, which avoids thinking about coroutines, which is what we want.

There's just one wrinkle: while `mock.patch()` replaces the target with an `AsyncMock`, it does not work recursively. So while we would like to rewrite

```
es.cluster.health.return_value = as_future({"status": "green", "relocating_shards": 0})
```

to

```
es.cluster.health.return_value = {"status": "green", "relocating_shards": 0}
```

we need to use

```
es.cluster.health = mock.AsyncMock(return_value={"status": "green", "relocating_shards": 0})
```

which is still an improvement as it avoids the `as_future` code smell.
  • Loading branch information
pquentin authored Oct 25, 2021
1 parent e90597e commit 24dee9e
Show file tree
Hide file tree
Showing 3 changed files with 446 additions and 465 deletions.
19 changes: 0 additions & 19 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,3 @@ def async_wrapper(*args, **kwargs):
asyncio.run(t(*args, **kwargs), debug=True)

return async_wrapper


def as_future(result=None, exception=None):
"""
Helper to create a future that completes immediately either with a result or exceptionally.
:param result: Regular result.
:param exception: Exceptional result.
:return: The corresponding future.
"""
f = asyncio.get_running_loop().create_future()
if exception and result:
raise AssertionError("Specify a result or an exception but not both")
if exception:
f.set_exception(exception)
else:
f.set_result(result)
return f
33 changes: 13 additions & 20 deletions tests/driver/driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from esrally import config, exceptions, metrics, track
from esrally.driver import driver, runner, scheduler
from esrally.track import params
from tests import as_future, run_async
from tests import run_async


class DriverTestParamSource:
Expand Down Expand Up @@ -1369,7 +1369,7 @@ async def test_execute_schedule_in_throughput_mode(self, es):
task_start = time.perf_counter()
es.new_request_context.return_value = AsyncExecutorTests.StaticRequestTiming(task_start=task_start)

es.bulk.return_value = as_future(io.StringIO('{"errors": false, "took": 8}'))
es.bulk = mock.AsyncMock(return_value=io.StringIO('{"errors": false, "took": 8}'))

params.register_param_source_for_name("driver-test-param-source", DriverTestParamSource)
test_track = track.Track(name="unittest", description="unittest track", indices=None, challenges=None)
Expand Down Expand Up @@ -1566,8 +1566,8 @@ async def test_execute_schedule_runner_overrides_times(self, es):
@mock.patch("elasticsearch.Elasticsearch")
@run_async
async def test_execute_schedule_throughput_throttled(self, es):
def perform_request(*args, **kwargs):
return as_future()
async def perform_request(*args, **kwargs):
return None

es.init_request_context.return_value = {"request_start": 0, "request_end": 10}
# as this method is called several times we need to return a fresh instance every time as the previous
Expand Down Expand Up @@ -1635,7 +1635,7 @@ def perform_request(*args, **kwargs):
@run_async
async def test_cancel_execute_schedule(self, es):
es.init_request_context.return_value = {"request_start": 0, "request_end": 10}
es.bulk.return_value = as_future(io.StringIO('{"errors": false, "took": 8}'))
es.bulk = mock.AsyncMock(return_value=io.StringIO('{"errors": false, "took": 8}'))

params.register_param_source_for_name("driver-test-param-source", DriverTestParamSource)
test_track = track.Track(name="unittest", description="unittest track", indices=None, challenges=None)
Expand Down Expand Up @@ -1742,8 +1742,7 @@ async def __call__(self):
async def test_execute_single_no_return_value(self):
es = None
params = None
runner = mock.Mock()
runner.return_value = as_future()
runner = mock.AsyncMock()

ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params, on_error="continue")

Expand All @@ -1755,8 +1754,7 @@ async def test_execute_single_no_return_value(self):
async def test_execute_single_tuple(self):
es = None
params = None
runner = mock.Mock()
runner.return_value = as_future(result=(500, "MB"))
runner = mock.AsyncMock(return_value=(500, "MB"))

ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params, on_error="continue")

Expand All @@ -1768,9 +1766,8 @@ async def test_execute_single_tuple(self):
async def test_execute_single_dict(self):
es = None
params = None
runner = mock.Mock()
runner.return_value = as_future(
{
runner = mock.AsyncMock(
return_value={
"weight": 50,
"unit": "docs",
"some-custom-meta-data": "valid",
Expand Down Expand Up @@ -1798,7 +1795,7 @@ async def test_execute_single_with_connection_error_always_aborts(self):
es = None
params = None
# ES client uses pseudo-status "N/A" in this case...
runner = mock.Mock(side_effect=as_future(exception=elasticsearch.ConnectionError("N/A", "no route to host", None)))
runner = mock.AsyncMock(side_effect=elasticsearch.ConnectionError("N/A", "no route to host", None))

with self.assertRaises(exceptions.RallyAssertionError) as ctx:
await driver.execute_single(self.context_managed(runner), es, params, on_error=on_error)
Expand All @@ -1808,9 +1805,7 @@ async def test_execute_single_with_connection_error_always_aborts(self):
async def test_execute_single_with_http_400_aborts_when_specified(self):
es = None
params = None
runner = mock.Mock(
side_effect=as_future(exception=elasticsearch.NotFoundError(404, "not found", "the requested document could not be found"))
)
runner = mock.AsyncMock(side_effect=elasticsearch.NotFoundError(404, "not found", "the requested document could not be found"))

with self.assertRaises(exceptions.RallyAssertionError) as ctx:
await driver.execute_single(self.context_managed(runner), es, params, on_error="abort")
Expand All @@ -1823,9 +1818,7 @@ async def test_execute_single_with_http_400_aborts_when_specified(self):
async def test_execute_single_with_http_400(self):
es = None
params = None
runner = mock.Mock(
side_effect=as_future(exception=elasticsearch.NotFoundError(404, "not found", "the requested document could not be found"))
)
runner = mock.AsyncMock(side_effect=elasticsearch.NotFoundError(404, "not found", "the requested document could not be found"))

ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params, on_error="continue")

Expand All @@ -1845,7 +1838,7 @@ async def test_execute_single_with_http_400(self):
async def test_execute_single_with_http_413(self):
es = None
params = None
runner = mock.Mock(side_effect=as_future(exception=elasticsearch.NotFoundError(413, b"", b"")))
runner = mock.AsyncMock(side_effect=elasticsearch.NotFoundError(413, b"", b""))

ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params, on_error="continue")

Expand Down
Loading

0 comments on commit 24dee9e

Please sign in to comment.