Skip to content

Commit

Permalink
add before/after step hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
stepantubanov committed May 27, 2024
1 parent 04bb3ec commit 97ab21f
Showing 1 changed file with 67 additions and 49 deletions.
116 changes: 67 additions & 49 deletions src/enabled/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
};

use fastrand::Rng;
use futures_util::{join, FutureExt};
use futures_util::{future::BoxFuture, join, FutureExt};

use crate::enabled::{
controller::{Controller, TaskState},
Expand All @@ -22,66 +22,70 @@ pub fn runner() -> Runner {
}

pub struct Runner {
config: Config,
iteration_config: IterationConfig,
on_panic: Option<PanicHandler>,
before_step: Option<BeforeStep>,
after_step: Option<AfterStep>,
}

enum Config {
Iterate(IterateConfig),
enum IterationConfig {
Replay { trace: Trace },
}

struct IterateConfig {
max_iterations: u64,
on_panic: Option<PanicHandler>,
Iterate { max_iterations: u64 },
}

pub type PanicHandler = Box<dyn FnOnce(&Trace)>;
pub type BeforeStep = Box<dyn FnMut() -> BoxFuture<'static, ()>>;
pub type AfterStep = Box<dyn FnMut() -> BoxFuture<'static, ()>>;

impl Default for Runner {
fn default() -> Self {
Self {
iteration_config: IterationConfig::Iterate {
max_iterations: u64::MAX,
},
on_panic: None,
before_step: None,
after_step: None,
}
}
}

impl Runner {
pub fn from_env() -> Self {
let mut runner = Self::default();

if let Ok(trace) = env::var("PARCHECK_REPLAY") {
let trace = trace.parse().expect("can't parse PARCHECK_REPLAY");
return Self {
config: Config::Replay { trace },
runner.iteration_config = IterationConfig::Replay { trace };
} else if let Ok(max_iterations) = env::var("PARCHECK_MAX_ITERATIONS") {
runner.iteration_config = IterationConfig::Iterate {
max_iterations: max_iterations
.parse()
.expect("failed to parse PARCHECK_MAX_ITERATIONS"),
};
}

let mut config = IterateConfig {
max_iterations: u64::MAX,
on_panic: None,
};
if let Ok(max_iterations) = env::var("PARCHECK_MAX_ITERATIONS") {
config.max_iterations = max_iterations
.parse()
.expect("failed to parse PARCHECK_MAX_ITERATIONS");
}
Self {
config: Config::Iterate(config),
}
runner
}

pub fn max_iterations(self, max_iterations: u64) -> Self {
match self.config {
Config::Iterate(config) => Self {
config: Config::Iterate(IterateConfig {
max_iterations,
..config
}),
},
Config::Replay { .. } => self,
}
pub fn max_iterations(mut self, max_iterations: u64) -> Self {
self.iteration_config = IterationConfig::Iterate { max_iterations };
self
}

pub fn on_panic(self, on_panic: PanicHandler) -> Self {
match self.config {
Config::Iterate(config) => Self {
config: Config::Iterate(IterateConfig {
on_panic: Some(on_panic),
..config
}),
},
Config::Replay { .. } => self,
}
pub fn on_panic(mut self, on_panic: PanicHandler) -> Self {
self.on_panic = Some(on_panic);
self
}

pub fn before_step(mut self, before_step: BeforeStep) -> Self {
self.before_step = Some(before_step);
self
}

pub fn after_step(mut self, after_step: AfterStep) -> Self {
self.after_step = Some(after_step);
self
}

pub async fn run<'a, I, F, Fut>(self, initial_tasks: I, mut f: F)
Expand All @@ -95,7 +99,7 @@ impl Runner {
}

pub async fn run_with_state<'a, T, I, F, Fut>(
self,
mut self,
initial_tasks: I,
mut state: T,
mut f: F,
Expand All @@ -114,27 +118,35 @@ impl Runner {
.map(|name| TaskName(name.into()))
.collect();

let config = match self.config {
Config::Replay { trace } => {
let max_iterations = match self.iteration_config {
IterationConfig::Replay { trace } => {
let mut controller = Controller::register(&initial_tasks);

let control = async {
for &task_id in &trace.task_ids {
let _tasks = controller.ready(WAIT_TIMEOUT).await;
if let Some(before_step) = &mut self.before_step {
before_step().await;
}
controller.step_forward(task_id).await;
if let Some(after_step) = &mut self.after_step {
after_step().await;
}
}
};

(state, _) = join!(f(state), control);
return state;
}
Config::Iterate(config) => config,
IterationConfig::Iterate {
max_iterations: max_iter,
} => max_iter,
};

let mut schedule_tree = ScheduleTree::new(&initial_tasks);
let mut iter = 0;

while schedule_tree.has_unfinished_paths() && iter < config.max_iterations {
while schedule_tree.has_unfinished_paths() && iter < max_iterations {
let mut controller = Controller::register(&initial_tasks);
let mut trace = Trace::new();

Expand All @@ -149,7 +161,13 @@ impl Runner {
};

trace.task_ids.push(task_id);
if let Some(before_step) = &mut self.before_step {
before_step().await;
}
controller.step_forward(task_id).await;
if let Some(after_step) = &mut self.after_step {
after_step().await;
}
}

// TODO: move this logic
Expand Down Expand Up @@ -203,7 +221,7 @@ impl Runner {
state = match result {
Ok(v) => v,
Err(error) => {
if let Some(on_panic) = config.on_panic {
if let Some(on_panic) = self.on_panic {
on_panic(&trace);
} else {
eprintln!(
Expand Down

0 comments on commit 97ab21f

Please sign in to comment.