Skip to content

Commit

Permalink
Support async tests w/ same fixture
Browse files Browse the repository at this point in the history
  • Loading branch information
mccoyp committed Jul 16, 2022
1 parent 71206d6 commit c03954c
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 29 deletions.
19 changes: 1 addition & 18 deletions sdk/keyvault/azure-keyvault-keys/tests/test_keys_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def emit(self, record):
self.messages.append(record)


@pytest.mark.usefixtures("recorded_test", "variable_recorder")
class TestKeyVaultKey(KeyVaultTestCase, KeysTestCase):

def _assert_jwks_equal(self, jwk1, jwk2):
Expand Down Expand Up @@ -175,7 +176,6 @@ def _to_bytes(hex):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",all_api_versions)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_key_crud_operations(self, client, is_hsm, **kwargs):
assert client is not None

Expand Down Expand Up @@ -242,7 +242,6 @@ async def test_key_crud_operations(self, client, is_hsm, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",only_hsm)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_rsa_public_exponent(self, client, **kwargs):
"""The public exponent of a Managed HSM RSA key can be specified during creation"""
assert client is not None
Expand All @@ -255,7 +254,6 @@ async def test_rsa_public_exponent(self, client, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",all_api_versions)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_backup_restore(self, client, is_hsm, **kwargs):
assert client is not None

Expand Down Expand Up @@ -283,7 +281,6 @@ async def test_backup_restore(self, client, is_hsm, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",all_api_versions)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_key_list(self, client, is_hsm, **kwargs):
assert client is not None

Expand All @@ -307,7 +304,6 @@ async def test_key_list(self, client, is_hsm, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",all_api_versions)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_list_versions(self, client, is_hsm, **kwargs):
assert client is not None

Expand All @@ -334,7 +330,6 @@ async def test_list_versions(self, client, is_hsm, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",all_api_versions)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_list_deleted_keys(self, client, is_hsm, **kwargs):
assert client is not None

Expand Down Expand Up @@ -366,7 +361,6 @@ async def test_list_deleted_keys(self, client, is_hsm, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",all_api_versions)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_recover(self, client, is_hsm, **kwargs):
assert client is not None

Expand Down Expand Up @@ -397,7 +391,6 @@ async def test_recover(self, client, is_hsm, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",all_api_versions)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_purge(self, client, is_hsm, **kwargs):
assert client is not None

Expand Down Expand Up @@ -425,7 +418,6 @@ async def test_purge(self, client, is_hsm, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",logging_enabled)
@AsyncKeysClientPreparer(logging_enable = True)
@recorded_by_proxy_async
async def test_logging_enabled(self, client, is_hsm, **kwargs):
mock_handler = MockHandler()

Expand Down Expand Up @@ -461,7 +453,6 @@ async def test_logging_enabled(self, client, is_hsm, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",logging_disabled)
@AsyncKeysClientPreparer(logging_enable = False)
@recorded_by_proxy_async
async def test_logging_disabled(self, client, is_hsm, **kwargs):
mock_handler = MockHandler()

Expand Down Expand Up @@ -496,7 +487,6 @@ async def test_logging_disabled(self, client, is_hsm, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",only_hsm_7_3)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_get_random_bytes(self, client, **kwargs):
assert client

Expand All @@ -513,7 +503,6 @@ async def test_get_random_bytes(self, client, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",only_7_3)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_key_release(self, client, **kwargs):
set_bodiless_matcher()
attestation_uri = self._get_attestation_uri()
Expand All @@ -534,7 +523,6 @@ async def test_key_release(self, client, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",only_hsm_7_3)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_imported_key_release(self, client, **kwargs):
set_bodiless_matcher()
attestation_uri = self._get_attestation_uri()
Expand All @@ -555,7 +543,6 @@ async def test_imported_key_release(self, client, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",only_7_3)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_update_release_policy(self, client, **kwargs):
set_bodiless_matcher()
attestation_uri = self._get_attestation_uri()
Expand Down Expand Up @@ -598,7 +585,6 @@ async def test_update_release_policy(self, client, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",only_vault_7_3)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_immutable_release_policy(self, client, **kwargs):
set_bodiless_matcher()
attestation_uri = self._get_attestation_uri()
Expand Down Expand Up @@ -633,7 +619,6 @@ async def test_immutable_release_policy(self, client, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",only_vault_7_3)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_key_rotation(self, client, **kwargs):
set_bodiless_matcher()
if (not is_public_cloud() and self.is_live):
Expand All @@ -651,7 +636,6 @@ async def test_key_rotation(self, client, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",only_vault_7_3)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_key_rotation_policy(self, client, **kwargs):
set_bodiless_matcher()
if (not is_public_cloud() and self.is_live):
Expand Down Expand Up @@ -724,7 +708,6 @@ async def test_key_rotation_policy(self, client, **kwargs):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_version,is_hsm",all_api_versions)
@AsyncKeysClientPreparer()
@recorded_by_proxy_async
async def test_get_cryptography_client(self, client, is_hsm, **kwargs):
key_name = self.get_resource_name("key-name")
key = await self._create_rsa_key(client, key_name, hardware_protected=is_hsm)
Expand Down
114 changes: 103 additions & 11 deletions tools/azure-sdk-tools/devtools_testutils/proxy_testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from inspect import iscoroutinefunction
import logging
import requests
import six
Expand All @@ -24,7 +25,7 @@
from .proxy_startup import test_proxy

if TYPE_CHECKING:
from typing import Any, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple
from azure.core.pipeline.transport import HttpRequest

# To learn about how to migrate SDK tests to the test proxy, please refer to the migration guide at
Expand Down Expand Up @@ -126,7 +127,7 @@ def transform_request(request: "HttpRequest", recording_id: str) -> None:
request.url = updated_target


def recorded_by_proxy(test_func) -> None:
def recorded_by_proxy(test_func: "Callable") -> None:
"""Decorator that redirects network requests to target the azure-sdk-tools test proxy. Use with recorded tests.
For more details and usage examples, refer to
Expand Down Expand Up @@ -212,18 +213,79 @@ def start_proxy_session() -> "Optional[Tuple[str, str, Dict[str, str]]]":


@pytest.fixture
def recorded_test(test_proxy, request) -> "Dict[str, Any]":
"""Fixture that redirects network requests to target the azure-sdk-tools test proxy. Use with recorded tests.
async def recorded_test(test_proxy: None, request: pytest.FixtureRequest) -> "Dict[str, Any]":
"""Fixture that redirects network requests to target the azure-sdk-tools test proxy.
For more details and usage examples, refer to
Use with recorded tests. For more details and usage examples, refer to
https://github.com/Azure/azure-sdk-for-python/blob/main/doc/dev/test_proxy_migration_guide.md.
:param function test_proxy: The fixture responsible for starting up the test proxy server.
:param function request: The built-in `request` fixture.
:yields: A dictionary containing information relevant to the currently executing test.
"""

test_id, recording_id, variables = start_proxy_session()

# True if the function requesting the fixture is an async test
is_async_test = iscoroutinefunction(request._pyfuncitem.function)
if is_async_test:
original_transport_func = await redirect_async_traffic(recording_id)
yield {"variables": variables} # yield relevant test info and allow tests to run
restore_async_traffic(original_transport_func, request)
else:
original_transport_func = redirect_traffic(recording_id)
yield {"variables": variables} # yield relevant test info and allow tests to run
restore_traffic(original_transport_func, request)

stop_record_or_playback(test_id, recording_id, variables)


async def redirect_async_traffic(recording_id: str) -> "Callable":
"""Redirects asynchronous network requests to target the test proxy.
:param str recording_id: Recording ID of the currently executing test.
:returns: The original transport function used by the currently executing test.
"""
from azure.core.pipeline.transport import AioHttpTransport

original_transport_func = AioHttpTransport.send

def transform_args(*args, **kwargs):
copied_positional_args = list(args)
request = copied_positional_args[1]

transform_request(request, recording_id)

return tuple(copied_positional_args), kwargs

async def combined_call(*args, **kwargs):
adjusted_args, adjusted_kwargs = transform_args(*args, **kwargs)
result = await original_transport_func(*adjusted_args, **adjusted_kwargs)

# make the x-recording-upstream-base-uri the URL of the request
# this makes the request look like it was made to the original endpoint instead of to the proxy
# without this, things like LROPollers can get broken by polling the wrong endpoint
parsed_result = url_parse.urlparse(result.request.url)
upstream_uri = url_parse.urlparse(result.request.headers["x-recording-upstream-base-uri"])
upstream_uri_dict = {"scheme": upstream_uri.scheme, "netloc": upstream_uri.netloc}
original_target = parsed_result._replace(**upstream_uri_dict).geturl()

result.request.url = original_target
return result

AioHttpTransport.send = combined_call
return original_transport_func


def redirect_traffic(recording_id: str) -> "Callable":
"""Redirects network requests to target the test proxy.
:param str recording_id: Recording ID of the currently executing test.
:returns: The original transport function used by the currently executing test.
"""
original_transport_func = RequestsTransport.send

def transform_args(*args, **kwargs):
Expand All @@ -250,12 +312,20 @@ def combined_call(*args, **kwargs):
return result

RequestsTransport.send = combined_call
return original_transport_func

# store info pertinent to the test in a dictionary that other fixtures can access
test_info = {"variables": variables}
yield test_info # yield and allow test to run

RequestsTransport.send = original_transport_func # test finished running -- tear down
def restore_async_traffic(original_transport_func: "Callable", request: pytest.FixtureRequest) -> None:
"""Resets asynchronous network traffic to no longer target the test proxy.
:param original_transport_func: The original transport function used by the currently executing test.
:type original_transport_func: Callable
:param request: The built-in `request` pytest fixture.
:type request: ~pytest.FixtureRequest
"""
from azure.core.pipeline.transport import AioHttpTransport

AioHttpTransport.send = original_transport_func # test finished running -- tear down

if hasattr(request.node, "test_error"):
# Exceptions are logged here instead of being raised because of how pytest handles error raising from inside
Expand All @@ -270,11 +340,33 @@ def combined_call(*args, **kwargs):
logger = logging.getLogger()
logger.error(f"\n\n-----Test proxy playback error:-----\n\n{message}")

stop_record_or_playback(test_id, recording_id, variables)

def restore_traffic(original_transport_func: "Callable", request: pytest.FixtureRequest) -> None:
"""Resets network traffic to no longer target the test proxy.
:param original_transport_func: The original transport function used by the currently executing test.
:type original_transport_func: Callable
:param request: The built-in `request` pytest fixture.
:type request: ~pytest.FixtureRequest
"""
RequestsTransport.send = original_transport_func # test finished running -- tear down

if hasattr(request.node, "test_error"):
# Exceptions are logged here instead of being raised because of how pytest handles error raising from inside
# fixtures and hooks. Raising from a fixture raises an error in addition to the test failure report, and the
# test proxy error is logged before the test failure output (making it difficult to find in pytest output).
# Raising from a hook isn't allowed, and produces an internal error that disrupts test execution.
# ResourceNotFoundErrors during playback indicate a recording mismatch
error = request.node.test_error
if isinstance(error, ResourceNotFoundError):
error_body = ContentDecodePolicy.deserialize_from_http_generics(error.response)
message = error_body.get("message") or error_body.get("Message")
logger = logging.getLogger()
logger.error(f"\n\n-----Test proxy playback error:-----\n\n{message}")


@pytest.fixture
def variable_recorder(recorded_test) -> "Dict[str, str]":
def variable_recorder(recorded_test: "Dict[str, Any]") -> "Dict[str, str]":
"""Fixture that invokes the `recorded_test` fixture and returns a dictionary of recorded test variables.
:param function recorded_test: The fixture responsible for redirecting network traffic to target the test proxy.
Expand Down

0 comments on commit c03954c

Please sign in to comment.