Skip to content

Commit

Permalink
Nest authored test functions for safety
Browse files Browse the repository at this point in the history
Like `tokio::test` et. al., nest the authored test functions and call them, which helps maintain code safety by separating contexts as recommended for all macros.
  • Loading branch information
heaths committed Dec 9, 2024
1 parent 3e0ab7a commit e72135f
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 23 deletions.
3 changes: 3 additions & 0 deletions sdk/core/azure_core_test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ edition.workspace = true
rust-version.workspace = true

[dependencies]
async-trait.workspace = true
azure_core = { workspace = true, features = ["test"] }
azure_core_test_macros.workspace = true
serde.workspace = true
tracing.workspace = true

[dev-dependencies]
tokio.workspace = true
57 changes: 55 additions & 2 deletions sdk/core/azure_core_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@
pub mod recorded {
pub use azure_core_test_macros::test;
}
mod sanitizers;
mod transport;

pub use azure_core::test::TestMode;
use azure_core::{ClientOptions, TransportOptions};
pub use sanitizers::*;
use std::sync::Arc;

/// Context information required by recorded client library tests.
///
Expand All @@ -17,24 +22,64 @@ pub use azure_core::test::TestMode;
#[derive(Clone, Debug)]
pub struct TestContext {
test_mode: TestMode,
crate_dir: &'static str,
test_name: &'static str,
}

impl TestContext {
/// Not intended for use outside the `azure_core` crates.
#[doc(hidden)]
pub fn new(test_mode: TestMode, test_name: &'static str) -> Self {
pub fn new(test_mode: TestMode, crate_dir: &'static str, test_name: &'static str) -> Self {
Self {
test_mode,
crate_dir,
test_name,
}
}

/// Instruments the [`ClientOptions`] to support recording and playing back of session records.
///
/// # Examples
///
/// ```no_run
/// use azure_core_test::{recorded, TestContext};
///
/// # struct MyClient;
/// # #[derive(Default)]
/// # struct MyClientOptions { client_options: azure_core::ClientOptions };
/// # impl MyClient {
/// # fn new(endpoint: impl AsRef<str>, options: Option<MyClientOptions>) -> Self { todo!() }
/// # async fn invoke(&self) -> azure_core::Result<()> { todo!() }
/// # }
/// #[recorded::test]
/// async fn test_invoke(ctx: TestContext) -> azure_core::Result<()> {
/// let mut options = MyClientOptions::default();
/// ctx.instrument(&mut options.client_options);
///
/// let client = MyClient::new("https://azure.net", Some(options));
/// client.invoke().await
/// }
/// ```
pub fn instrument(&self, options: &mut ClientOptions) {
let transport = options.transport.clone().unwrap_or_default();
options.transport = Some(TransportOptions::new_custom_policy(Arc::new(
transport::ProxyTransportPolicy {
inner: transport,
mode: self.test_mode,
},
)));
}

/// Gets the current [`TestMode`].
pub fn test_mode(&self) -> TestMode {
self.test_mode
}

/// Gets the root directory of the crate under test.
pub fn crate_dir(&self) -> &'static str {
self.crate_dir
}

/// Gets the current test function name.
pub fn test_name(&self) -> &'static str {
self.test_name
Expand All @@ -47,8 +92,16 @@ mod tests {

#[test]
fn test_content_new() {
let ctx = TestContext::new(TestMode::default(), "test_content_new");
let ctx = TestContext::new(
TestMode::default(),
env!("CARGO_MANIFEST_DIR"),
"test_content_new",
);
assert_eq!(ctx.test_mode(), TestMode::Playback);
assert!(ctx
.crate_dir()
.replace("\\", "/")
.ends_with("sdk/core/azure_core_test"));
assert_eq!(ctx.test_name(), "test_content_new");
}
}
118 changes: 118 additions & 0 deletions sdk/core/azure_core_test/src/sanitizers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

use azure_core::headers::{AsHeaders, HeaderName, HeaderValue};
use serde::Serialize;
use std::{
convert::Infallible,
fmt,
iter::{once, Once},
};

/// Default sanitization replacement value, "Sanitized";
pub const SANITIZED_VALUE: &str = "Sanitized";
const ABSTRACTION_IDENTIFIER: HeaderName = HeaderName::from_static("x-abstraction-identifier");

/// Represents a sanitizer.
pub trait Sanitizer: AsHeaders + fmt::Debug + Serialize {}

macro_rules! impl_sanitizer {
($name:ident) => {
impl Sanitizer for $name {}

impl AsHeaders for $name {
type Error = Infallible;
type Iter = Once<(HeaderName, HeaderValue)>;
fn as_headers(&self) -> Result<Self::Iter, Self::Error> {
Ok(once((
ABSTRACTION_IDENTIFIER,
HeaderValue::from_static(stringify!($name)),
)))
}
}
};

($($name:ident),+) => {
$(impl_sanitizer!($name))*

};
}

/// This sanitizer offers regular expression replacements within a returned JSON body for a specific JSONPath.
///
/// This sanitizer only applies to JSON bodies.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct BodyKeySanitizer {
/// The JSONPath that will be checked for replacements.
pub json_path: String,

/// The substitution value. The default is [`SANITIZED_VALUE`].
#[serde(skip_serializing_if = "Option::is_none")]
pub value: Option<String>,

/// The regular expression to search for.
///
/// Can be defined as a simple regular expression replacement or, if [`BodyKeySanitizer::group_for_replace`] is set, a substitution operation.
/// Defaults to replacing the entire string.
#[serde(skip_serializing_if = "Option::is_none")]
pub regex: Option<String>,

/// The regular expression capture group to substitute.
///
/// Do not set if you're invoking a simple replacement operation.
#[serde(skip_serializing_if = "Option::is_none")]
pub group_for_replace: Option<String>,
}
impl_sanitizer!(BodyKeySanitizer);

#[test]
fn test_body_key_sanitizer_as_headers() {
let sut = BodyKeySanitizer {
json_path: String::from("$.values"),
value: None,
regex: None,
group_for_replace: None,
};
let headers = sut.as_headers().expect("expect headers");
headers.for_each(|(h, v)| {
assert_eq!(h.as_str(), "x-abstraction-identifier");
assert_eq!(v.as_str(), "BodyKeySanitizer");
});
}

/// This sanitizer offers regular expression replacements within raw request and response bodies.
///
/// Specifically, this means the regular expression applies to the raw JSON.
/// If you are attempting to simply replace a specific JSON key, the [`BodyKeySanitizer`] is probably what you want to use.
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct BodyRegexSanitizer {
/// The substitution value. The default is [`SANITIZED_VALUE`].
#[serde(skip_serializing_if = "Option::is_none")]
pub value: Option<String>,

/// The regular expression to search for or the entire body if `None`.
///
/// Can be defined as a simple regular expression replacement or, if [`BodyRegexSanitizer::group_for_replace`] is set, a substitution operation.
/// Defaults to replacing the entire string.
#[serde(skip_serializing_if = "Option::is_none")]
pub regex: Option<String>,

/// The regular expression capture group to substitute.
///
/// Do not set if you're invoking a simple replacement operation.
#[serde(skip_serializing_if = "Option::is_none")]
pub group_for_replace: Option<String>,
}
impl_sanitizer!(BodyRegexSanitizer);

#[test]
fn test_body_regex_sanitizer_as_headers() {
let sut = BodyRegexSanitizer::default();
let headers = sut.as_headers().expect("expect headers");
headers.for_each(|(h, v)| {
assert_eq!(h.as_str(), "x-abstraction-identifier");
assert_eq!(v.as_str(), "BodyRegexSanitizer");
});
}
33 changes: 33 additions & 0 deletions sdk/core/azure_core_test/src/transport.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

use async_trait::async_trait;
use azure_core::{test::TestMode, Context, Policy, PolicyResult, Request, TransportOptions};
use std::sync::Arc;
use tracing::{debug_span, Instrument};

/// Wraps the original [`TransportOptions`] and records or plays back session records for testing.
#[derive(Debug)]
pub struct ProxyTransportPolicy {
pub(crate) inner: TransportOptions,
pub(crate) mode: TestMode,
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl Policy for ProxyTransportPolicy {
async fn send(
&self,
ctx: &Context,
request: &mut Request,
next: &[Arc<dyn Policy>],
) -> PolicyResult {
// There must be no other policies since we're encapsulating the original TransportPolicy.
assert_eq!(0, next.len());

let span = debug_span!("test-proxy", mode = ?self.mode);
async move { { self.inner.send(ctx, request) }.await }
.instrument(span)
.await
}
}
50 changes: 30 additions & 20 deletions sdk/core/azure_core_test_macros/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,27 @@ use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Meta, PatType, Result,

const INVALID_RECORDED_ATTRIBUTE_MESSAGE: &str =
"expected `#[recorded::test]` or `#[recorded::test(live)]`";
const INVALID_RECORDED_FUNCTION_MESSAGE: &str = "expected `fn(TestContext)` function signature";
const INVALID_RECORDED_FUNCTION_MESSAGE: &str =
"expected `async fn(TestContext)` function signature with optional `Result<T, E>` return";

// cspell:ignore asyncness
pub fn parse_test(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
let recorded_attrs: Attributes = syn::parse2(attr)?;
let ItemFn {
attrs,
vis,
mut sig,
sig: original_sig,
block,
} = syn::parse2(item)?;

// Use #[tokio::test] for async functions; otherwise, #[test].
let mut test_attr: TokenStream = if sig.asyncness.is_some() {
quote! { #[::tokio::test] }
} else {
quote! { #[::core::prelude::v1::test] }
let mut test_attr: TokenStream = match original_sig.asyncness {
Some(_) => quote! { #[::tokio::test] },
None => {
return Err(syn::Error::new(
original_sig.span(),
INVALID_RECORDED_FUNCTION_MESSAGE,
))
}
};

// Ignore live-only tests if not running live tests.
Expand All @@ -36,21 +40,23 @@ pub fn parse_test(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
});
}

let mut inputs = sig.inputs.iter();
let preamble = match inputs.next() {
None if recorded_attrs.live => TokenStream::new(),
Some(FnArg::Typed(PatType { pat, ty, .. })) if is_test_context(ty.as_ref()) => {
let fn_name = &original_sig.ident;
let mut inputs = original_sig.inputs.iter();
let setup = match inputs.next() {
None if recorded_attrs.live => quote! {
#fn_name().await
},
Some(FnArg::Typed(PatType { ty, .. })) if is_test_context(ty.as_ref()) => {
let test_mode = test_mode_to_tokens(test_mode);
let fn_name = &sig.ident;

quote! {
#[allow(dead_code)]
let #pat = #ty::new(#test_mode, stringify!(#fn_name));
let ctx = #ty::new(#test_mode, env!("CARGO_MANIFEST_DIR"), stringify!(#fn_name));
#fn_name(ctx).await
}
}
_ => {
return Err(syn::Error::new(
sig.ident.span(),
original_sig.ident.span(),
INVALID_RECORDED_FUNCTION_MESSAGE,
))
}
Expand All @@ -63,14 +69,18 @@ pub fn parse_test(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
));
}

// Empty the parameters and return our rewritten test function.
sig.inputs.clear();
// Clear the actual test method parameters.
let mut outer_sig = original_sig.clone();
outer_sig.inputs.clear();

Ok(quote! {
#test_attr
#(#attrs)*
#vis #sig {
#preamble
#block
#vis #outer_sig {
#original_sig {
#block
}
#setup
}
})
}
Expand Down
2 changes: 1 addition & 1 deletion sdk/eventhubs/azure_messaging_eventhubs/tests/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async fn test_get_partition_properties() {
}

#[recorded::test(live)]
fn test_create_eventdata() {
async fn test_create_eventdata() {
common::setup();
let data = b"hello world";
let ed1 = azure_messaging_eventhubs::models::EventData::builder()
Expand Down

0 comments on commit e72135f

Please sign in to comment.