Skip to content

Commit

Permalink
add before/after iter hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
stepantubanov committed May 31, 2024
1 parent ac10083 commit e07e781
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions src/enabled/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub struct Runner {
on_panic: Option<PanicHandler>,
before_step: Option<BeforeStep>,
after_step: Option<AfterStep>,
before_iter: Option<BeforeIter>,
after_iter: Option<AfterIter>,
}

enum IterationConfig {
Expand All @@ -36,6 +38,8 @@ enum IterationConfig {
pub type PanicHandler = Box<dyn FnOnce(&Trace)>;
pub type BeforeStep = Box<dyn FnMut() -> BoxFuture<'static, ()>>;
pub type AfterStep = Box<dyn FnMut() -> BoxFuture<'static, ()>>;
pub type BeforeIter = Box<dyn FnMut() -> BoxFuture<'static, ()>>;
pub type AfterIter = Box<dyn FnMut() -> BoxFuture<'static, ()>>;

impl Default for Runner {
fn default() -> Self {
Expand All @@ -46,6 +50,8 @@ impl Default for Runner {
on_panic: None,
before_step: None,
after_step: None,
before_iter: None,
after_iter: None,
}
}
}
Expand Down Expand Up @@ -93,6 +99,16 @@ impl Runner {
self
}

pub fn before_iter(mut self, before_iter: BeforeStep) -> Self {
self.before_iter = Some(before_iter);
self
}

pub fn after_iter(mut self, after_iter: AfterStep) -> Self {
self.after_iter = Some(after_iter);
self
}

pub async fn run<'a, I, F, Fut>(self, initial_tasks: I, mut f: F)
where
I: IntoIterator,
Expand Down Expand Up @@ -131,6 +147,10 @@ impl Runner {
let mut task_ids_from_trace = trace.task_ids.into_iter();
let mut rng = Rng::new();

if let Some(before_iter) = &mut self.before_iter {
before_iter().await;
}

loop {
let _tasks = controller.ready(WAIT_TIMEOUT).await;
if let Some(before_step) = &mut self.before_step {
Expand Down Expand Up @@ -165,6 +185,10 @@ impl Runner {

controller.assert_finished();
drop(controller);

if let Some(after_iter) = &mut self.after_iter {
after_iter().await;
}
};

(state, _) = join!(f(state), control);
Expand All @@ -186,6 +210,10 @@ impl Runner {
let mut rng = Rng::new();
let mut cursor = schedule_tree.pick_unfinished_path(&mut rng).unwrap();

if let Some(before_iter) = &mut self.before_iter {
before_iter().await;
}

loop {
let tasks = controller.ready(WAIT_TIMEOUT).await;
let Some(task_id) = cursor.visit_and_pick(tasks, &mut rng) else {
Expand All @@ -204,6 +232,10 @@ impl Runner {

controller.assert_finished();
drop(controller);

if let Some(after_iter) = &mut self.after_iter {
after_iter().await;
}
};

let result = AssertUnwindSafe(async {
Expand Down

0 comments on commit e07e781

Please sign in to comment.