Skip to content

Commit

Permalink
Major improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
TheRawMeatball committed Feb 15, 2021
1 parent f1073f5 commit 49adf8e
Showing 1 changed file with 171 additions and 24 deletions.
195 changes: 171 additions & 24 deletions crates/bevy_ecs/src/schedule/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

use std::{
any::TypeId,
marker::PhantomData,
mem::{discriminant, Discriminant},
};

Expand All @@ -12,6 +11,13 @@ use crate::{
};
use thiserror::Error;

/// ### Stack based state machine
/// Important: `T` must be an enum!
///
/// This state machine has three operations: Next, Pop, and Push.
/// * Push pushes a new state to the state stack, pausing the previous state
/// * Pop removes the current state, and unpauses the last paused state.
/// * Next unwinds the state stack, and replaces the entire stack with a single new state
#[derive(Debug)]
pub struct State<T: Clone> {
transition: Option<StateTransition<T>>,
Expand Down Expand Up @@ -61,6 +67,9 @@ impl<T: Clone + Resource> State<T> {
Wrapper::<T, OnExit>::new(discriminant(&d))
}

/// Creates a driver set for the State.
///
/// Important note: this set must be inserted **before** all other state-dependant sets to work properly!
pub fn make_driver() -> SystemSet {
SystemSet::default().with_run_criteria(state_cleaner::<T>.system())
}
Expand Down Expand Up @@ -165,24 +174,37 @@ pub enum StateError {
}

trait Comparer<T: Clone> {
fn compare(d: Discriminant<T>, s: &State<T>) -> bool;
fn compare(&mut self, d: Discriminant<T>, s: &State<T>) -> bool;
}

#[derive(Default)]
struct OnUpdate;
impl<T: Clone> Comparer<T> for OnUpdate {
fn compare(d: Discriminant<T>, s: &State<T>) -> bool {
fn compare(&mut self, d: Discriminant<T>, s: &State<T>) -> bool {
discriminant(s.stack.last().unwrap()) == d && s.transition.is_none()
}
}
struct OnInactiveUpdate;
#[derive(Default)]
struct OnInactiveUpdate(bool /*tracks if the relevant state has been entered*/);
impl<T: Clone> Comparer<T> for OnInactiveUpdate {
fn compare(d: Discriminant<T>, s: &State<T>) -> bool {
s.stack.iter().map(discriminant).any(|s| s == d) && s.transition.is_none()
fn compare(&mut self, d: Discriminant<T>, s: &State<T>) -> bool {
match &s.transition {
Some(StateTransition::Pausing(ref relevant, _))
| Some(StateTransition::Resuming(_, ref relevant)) => {
if discriminant(relevant) == d {
self.0 = !self.0;
}
false
}
Some(_) => false,
None => self.0,
}
}
}
#[derive(Default)]
struct OnEnter;
impl<T: Clone> Comparer<T> for OnEnter {
fn compare(d: Discriminant<T>, s: &State<T>) -> bool {
fn compare(&mut self, d: Discriminant<T>, s: &State<T>) -> bool {
s.transition
.as_ref()
.map_or(false, |transition| match transition {
Expand All @@ -192,9 +214,10 @@ impl<T: Clone> Comparer<T> for OnEnter {
})
}
}
#[derive(Default)]
struct OnExit;
impl<T: Clone> Comparer<T> for OnExit {
fn compare(d: Discriminant<T>, s: &State<T>) -> bool {
fn compare(&mut self, d: Discriminant<T>, s: &State<T>) -> bool {
s.transition
.as_ref()
.map_or(false, |transition| match transition {
Expand All @@ -204,9 +227,10 @@ impl<T: Clone> Comparer<T> for OnExit {
})
}
}
#[derive(Default)]
struct OnPause;
impl<T: Clone> Comparer<T> for OnPause {
fn compare(d: Discriminant<T>, s: &State<T>) -> bool {
fn compare(&mut self, d: Discriminant<T>, s: &State<T>) -> bool {
s.transition
.as_ref()
.map_or(false, |transition| match transition {
Expand All @@ -215,9 +239,10 @@ impl<T: Clone> Comparer<T> for OnPause {
})
}
}
#[derive(Default)]
struct OnResume;
impl<T: Clone> Comparer<T> for OnResume {
fn compare(d: Discriminant<T>, s: &State<T>) -> bool {
fn compare(&mut self, d: Discriminant<T>, s: &State<T>) -> bool {
s.transition
.as_ref()
.map_or(false, |transition| match transition {
Expand All @@ -227,7 +252,7 @@ impl<T: Clone> Comparer<T> for OnResume {
}
}

impl<T: Clone + Resource, C: Comparer<T>> Wrapper<T, C> {
impl<T: Clone + Resource, C: Comparer<T> + Default> Wrapper<T, C> {
fn new(discriminant: Discriminant<T>) -> Self {
let mut resource_access = TypeAccess::default();
resource_access.add_read(std::any::TypeId::of::<State<T>>());
Expand All @@ -238,22 +263,22 @@ impl<T: Clone + Resource, C: Comparer<T>> Wrapper<T, C> {
id: SystemId::new(),
archetype_access: Default::default(),
component_access: Default::default(),
marker: Default::default(),
comparer: Default::default(),
}
}
}

struct Wrapper<T: Clone + Resource, C: Comparer<T>> {
struct Wrapper<T: Clone + Resource, C: Comparer<T> + Default> {
discriminant: Discriminant<T>,
exit_flag: bool,
resource_access: TypeAccess<TypeId>,
id: SystemId,
archetype_access: TypeAccess<ArchetypeComponent>,
component_access: TypeAccess<TypeId>,
marker: PhantomData<C>,
comparer: C,
}

impl<T: Clone + Resource, C: Comparer<T> + Resource> System for Wrapper<T, C> {
impl<T: Clone + Resource, C: Comparer<T> + Resource + Default> System for Wrapper<T, C> {
type In = ();
type Out = ShouldRun;

Expand Down Expand Up @@ -299,7 +324,7 @@ impl<T: Clone + Resource, C: Comparer<T> + Resource> System for Wrapper<T, C> {
Some(ShouldRun::No)
} else {
self.exit_flag = true;
Some(if C::compare(self.discriminant, state) {
Some(if self.comparer.compare(self.discriminant, state) {
ShouldRun::YesAndCheckAgain
} else {
ShouldRun::NoAndCheckAgain
Expand Down Expand Up @@ -364,28 +389,150 @@ fn state_cleaner<T: Clone + Resource>(mut state: ResMut<State<T>>) -> ShouldRun
mod test {
use crate::prelude::*;

#[derive(Clone, Copy)]
enum StateEnum {
#[derive(Clone, Copy, PartialEq, Eq)]
enum MyState {
S1,
S2,
S3,
S4,
S5,
S6,
Final,
}

#[test]
fn state_test() {
let mut world = World::default();
let mut resources = Resources::default();

resources.insert(Vec::<i32>::new());
resources.insert(Vec::<&'static str>::new());
resources.insert(State::new(MyState::S1));

let mut stage = SystemStage::parallel();

stage.add_system_set(State::<StateEnum>::make_driver());
stage.add_system_set(
SystemSet::default().with_run_criteria(State::<StateEnum>::on_update(StateEnum::S1)),
);
stage.add_system_set(State::<MyState>::make_driver());
stage
.add_system_set(
SystemSet::default()
.with_system(
(|mut r: ResMut<Vec<&'static str>>, mut s: ResMut<State<MyState>>| {
r.push("update S1");
s.overwrite_next(MyState::S2).unwrap();
})
.system(),
)
.with_run_criteria(State::<MyState>::on_update(MyState::S1)),
)
.add_system_set(
SystemSet::default()
.with_system((|mut r: ResMut<Vec<&'static str>>| r.push("enter S2")).system())
.with_run_criteria(State::<MyState>::on_enter(MyState::S2)),
)
.add_system_set(
SystemSet::default()
.with_system(
(|mut r: ResMut<Vec<&'static str>>, mut s: ResMut<State<MyState>>| {
r.push("update S2");
s.overwrite_next(MyState::S3).unwrap();
})
.system(),
)
.with_run_criteria(State::<MyState>::on_update(MyState::S2)),
)
.add_system_set(
SystemSet::default()
.with_system((|mut r: ResMut<Vec<&'static str>>| r.push("exit S2")).system())
.with_run_criteria(State::<MyState>::on_exit(MyState::S2)),
)
.add_system_set(
SystemSet::default()
.with_system((|mut r: ResMut<Vec<&'static str>>| r.push("enter S3")).system())
.with_run_criteria(State::<MyState>::on_enter(MyState::S3)),
)
.add_system_set(
SystemSet::default()
.with_system(
(|mut r: ResMut<Vec<&'static str>>, mut s: ResMut<State<MyState>>| {
r.push("update S3");
s.overwrite_push(MyState::S4).unwrap();
})
.system(),
)
.with_run_criteria(State::<MyState>::on_update(MyState::S3)),
)
.add_system_set(
SystemSet::default()
.with_system((|mut r: ResMut<Vec<&'static str>>| r.push("pause S3")).system())
.with_run_criteria(State::<MyState>::on_pause(MyState::S3)),
)
.add_system_set(
SystemSet::default()
.with_system(
(|mut r: ResMut<Vec<&'static str>>, mut s: ResMut<State<MyState>>| {
r.push("update S4");
s.overwrite_push(MyState::S5).unwrap();
})
.system(),
)
.with_run_criteria(State::<MyState>::on_update(MyState::S4)),
)
.add_system_set(
SystemSet::default()
.with_system(
(|mut r: ResMut<Vec<&'static str>>| r.push("inactive S4")).system(),
)
.with_run_criteria(State::<MyState>::on_inactive_update(MyState::S4)),
)
.add_system_set(
SystemSet::default()
.with_system(
(|mut r: ResMut<Vec<&'static str>>, mut s: ResMut<State<MyState>>| {
r.push("update S5");
s.overwrite_push(MyState::S6).unwrap();
})
.system(),
)
.with_run_criteria(State::<MyState>::on_update(MyState::S5)),
)
.add_system_set(
SystemSet::default()
.with_system(
(|mut r: ResMut<Vec<&'static str>>| r.push("inactive S5")).system(),
)
.with_run_criteria(State::<MyState>::on_inactive_update(MyState::S5)),
)
.add_system_set(
SystemSet::default()
.with_system(
(|mut r: ResMut<Vec<&'static str>>, mut s: ResMut<State<MyState>>| {
r.push("update S6");
s.overwrite_push(MyState::Final).unwrap();
})
.system(),
)
.with_run_criteria(State::<MyState>::on_update(MyState::S6)),
)
.add_system_set(
SystemSet::default()
.with_system((|mut r: ResMut<Vec<&'static str>>| r.push("resume S4")).system())
.with_run_criteria(State::<MyState>::on_resume(MyState::S4)),
)
.add_system_set(
SystemSet::default()
.with_system((|mut r: ResMut<Vec<&'static str>>| r.push("exit S4")).system())
.with_run_criteria(State::<MyState>::on_exit(MyState::S4)),
);

loop {
println!("new run!");
stage.run(&mut world, &mut resources);
for s in resources.get_mut::<Vec<&'static str>>().unwrap().drain(..) {
println!("{}", s);
}

stage.run(&mut world, &mut resources);
if resources.get::<State<MyState>>().unwrap().current() == &MyState::Final {
break;
}
}
}
}

0 comments on commit 49adf8e

Please sign in to comment.