diff --git a/test-harness-macro/src/lib.rs b/test-harness-macro/src/lib.rs index 3c8e6d0b..d2b0edfb 100644 --- a/test-harness-macro/src/lib.rs +++ b/test-harness-macro/src/lib.rs @@ -68,41 +68,17 @@ fn make_test_fn(mut function: ItemFn, ghc_version: &str) -> ItemFn { function.sig.ident = Ident::new(&test_name, function.sig.ident.span()); // Wrap the test code in startup/cleanup code. - // - // Before the user test code, we set the thread-local storage to test-local data so that when - // we construct a `test_harness::GhcidNg` it can use the correct GHC version. - // - // Then we run the user test code. If it errors, we save the logs to `CARGO_TARGET_TMPDIR`. - // - // Finally, we clean up the temporary directory `GhcidNg` created. let new_body = parse::( quote! { { - ::test_harness::internal::GHC_VERSION.with(|tmpdir| { - *tmpdir.borrow_mut() = #ghc_version.to_owned(); - }); - - match ::tokio::task::spawn(async { - #(#stmts);* - }).await { - Err(err) => { - // Copy out temp files - ::test_harness::internal::save_test_logs( - format!("{}::{}", module_path!(), #test_name), - ::std::path::PathBuf::from(env!("CARGO_TARGET_TMPDIR")) - ); - ::test_harness::internal::cleanup().await; - - if err.is_panic() { - ::std::panic::resume_unwind(err.into_panic()); - } else { - panic!("Test cancelled? {err:?}"); - } - } - Ok(()) => { - ::test_harness::internal::cleanup().await; - } - }; + ::test_harness::internal::wrap_test( + async { + #(#stmts);* + }, + #ghc_version, + #test_name, + env!("CARGO_TARGET_TMPDIR"), + ).await; } } .into(), diff --git a/test-harness/src/internal.rs b/test-harness/src/internal.rs index cd03df9a..46367e72 100644 --- a/test-harness/src/internal.rs +++ b/test-harness/src/internal.rs @@ -1,6 +1,7 @@ //! Internal functions, exposed for the `#[test]` attribute macro. use std::cell::RefCell; +use std::future::Future; use std::path::PathBuf; use std::time::Duration; @@ -20,7 +21,7 @@ thread_local! { /// The GHC version to use for this test. This should be a string like `ghc962`. /// This is used to open a corresponding (e.g.) `nix develop .#ghc962` shell to run `ghcid-ng` /// in. - pub static GHC_VERSION: RefCell = RefCell::new(String::new()); + pub(crate) static GHC_VERSION: RefCell = RefCell::new(String::new()); /// The GHC process for this test. /// @@ -28,11 +29,51 @@ thread_local! { pub(crate) static GHC_PROCESS: RefCell> = RefCell::new(None); } +/// Wraps an asynchronous test with startup/cleanup code. +/// +/// Before the user test code, we set the [`GHC_VERSION`] thread-local variable so that when +/// we construct a [`super::GhcidNg`] it can use the correct GHC version. +/// +/// Then we run the user test code. If it errors, we save the logs to `CARGO_TARGET_TMPDIR`. +/// +/// Finally, we wait for the [`GHC_PROCESS`] to exit and clean up the temporary directory `GhcidNg` +/// created. +pub async fn wrap_test( + test: impl Future + Send + 'static, + ghc_version: &'static str, + test_name: &'static str, + cargo_target_tmpdir: &'static str, +) { + GHC_VERSION.with(|tmpdir| { + *tmpdir.borrow_mut() = ghc_version.to_owned(); + }); + + match tokio::task::spawn(test).await { + Err(err) => { + // Copy out temp files + save_test_logs( + format!("{}::{}", module_path!(), test_name), + PathBuf::from(cargo_target_tmpdir), + ); + cleanup().await; + + if err.is_panic() { + std::panic::resume_unwind(err.into_panic()); + } else { + panic!("Test cancelled? {err:?}"); + } + } + Ok(()) => { + cleanup().await; + } + }; +} + /// Save the test logs in `TEMPDIR` to `cargo_target_tmpdir`. /// /// This is called when a `#[test]`-annotated function panics, to persist the logs for further /// analysis. -pub fn save_test_logs(test_name: String, cargo_target_tmpdir: PathBuf) { +fn save_test_logs(test_name: String, cargo_target_tmpdir: PathBuf) { let log_path: PathBuf = TEMPDIR.with(|tempdir| { tempdir .borrow() @@ -62,7 +103,7 @@ pub fn save_test_logs(test_name: String, cargo_target_tmpdir: PathBuf) { /// /// 1. Kill the [`GHC_PROCESS`]. /// 2. Remove the [`TEMPDIR`] from the filesystem. -pub async fn cleanup() { +async fn cleanup() { let child = GHC_PROCESS.with(|child| child.take()); match child { None => {