Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/motion matching with kdtree & kmeans #47

Merged
merged 12 commits into from
Nov 17, 2024
53 changes: 53 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ leafwing-input-manager = "0.14"

# debug editor (run with debug feature to enable it)
bevy-inspector-egui = { version = "0.25", optional = true }
kdtree = "0.7.0"
peak_alloc = "0.2.1"
clustering = "0.2.1"

[dev-dependencies]
bevy-inspector-egui = { version = "0.25" }
Expand Down
10 changes: 9 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl Plugin for MotionMatchingAppPlugin {
draw_axes::DrawAxesPlugin,
));

app.init_state::<GameMode>();
app.init_state::<GameMode>().init_state::<Method>();
}
}

Expand All @@ -74,3 +74,11 @@ pub enum MainSet {
MotionMatching,
Animation,
}

#[derive(States, Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Method {
#[default]
BruteForceKNN,
KdTree,
KMeans,
}
84 changes: 72 additions & 12 deletions src/motion_matching.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
use bevy::prelude::*;
use std::time::Instant;

use kdtree_match::KdTreeMatchPlugin;
use kmeans_match::KMeansMatchPlugin;

pub mod kdtree_match;
pub mod kmeans_match;

use crate::bvh_manager::bvh_player::JointMap;
use crate::motion::chunk::ChunkIterator;
Expand All @@ -9,28 +16,51 @@ use crate::motion::motion_player::{
use crate::motion::{MotionData, MotionHandle};
use crate::trajectory::{Trajectory, TrajectoryConfig, TrajectoryDistance, TrajectoryPoint};
use crate::ui::play_mode::MotionMatchingResult;
use crate::{GameMode, MainSet, BVH_SCALE_RATIO};
use crate::{GameMode, MainSet, Method, BVH_SCALE_RATIO};

use peak_alloc::PeakAlloc;
#[global_allocator]
static PEAK_ALLOC: PeakAlloc = PeakAlloc;

pub struct MotionMatchingPlugin;

impl Plugin for MotionMatchingPlugin {
fn build(&self, app: &mut App) {
app.insert_resource(MatchConfig {
max_match_count: 5,
match_threshold: 0.2,
pred_match_threshold: 0.15,
})
.add_event::<TrajectoryMatch>()
.add_event::<PredictionMatch>()
.add_event::<NearestTrajectories>()
.add_systems(PreStartup, load_motion_data)
.add_systems(
app.configure_sets(
Update,
(flow, prediction_match, trajectory_match, pose_match)
(
MotionMatchingSet::Flow,
MotionMatchingSet::PredictionMatch,
MotionMatchingSet::GlobalMatch,
MotionMatchingSet::PoseMatch,
)
.chain()
.in_set(MainSet::MotionMatching)
.run_if(in_state(GameMode::Play)),
);

app.add_plugins(KdTreeMatchPlugin)
.add_plugins(KMeansMatchPlugin)
.insert_resource(MatchConfig {
max_match_count: 5,
match_threshold: 0.2,
pred_match_threshold: 0.15,
})
.add_event::<TrajectoryMatch>()
.add_event::<PredictionMatch>()
.add_event::<NearestTrajectories>()
.add_systems(PreStartup, load_motion_data)
.add_systems(
Update,
(
flow.in_set(MotionMatchingSet::Flow),
prediction_match.in_set(MotionMatchingSet::PredictionMatch),
trajectory_match
.in_set(MotionMatchingSet::GlobalMatch)
.run_if(in_state(Method::BruteForceKNN)),
pose_match,
),
);
}
}

Expand Down Expand Up @@ -171,7 +201,10 @@ fn trajectory_match(
trajectory_config: Res<TrajectoryConfig>,
match_config: Res<MatchConfig>,
mut nearest_trajectories_evw: EventWriter<NearestTrajectories>,
mut motion_matching_result: ResMut<MotionMatchingResult>,
) {
println!("Brute Force KNN Method");
PEAK_ALLOC.reset_peak_usage();
let Some(motion_data) = motion_data.get() else {
return;
};
Expand All @@ -196,8 +229,10 @@ fn trajectory_match(
})
.collect::<Vec<_>>();

// println!("current traj: {:?}", traj);
let mut nearest_trajs = Vec::with_capacity(match_config.max_match_count);

let start_time = Instant::now();
for (chunk_index, chunk) in motion_data.trajectory_data.iter_chunk().enumerate() {
// Number of trajectory in this chunk.
let num_trajectories = chunk.len() - num_segments;
Expand Down Expand Up @@ -250,6 +285,23 @@ fn trajectory_match(
}
}

let knn_search_peak_memory = PEAK_ALLOC.peak_usage_as_mb();
let traj_duration = start_time.elapsed().as_secs_f64() * 1000.0;

let runs = motion_matching_result.matching_result.runs + 1;

motion_matching_result.matching_result.avg_time =
(motion_matching_result.matching_result.avg_time
* motion_matching_result.matching_result.runs as f64
+ traj_duration)
/ runs as f64;
motion_matching_result.matching_result.avg_memory =
(motion_matching_result.matching_result.avg_memory
* motion_matching_result.matching_result.runs as f64
+ knn_search_peak_memory as f64)
/ runs as f64;
motion_matching_result.matching_result.runs = runs;

nearest_trajectories_evw.send(NearestTrajectories {
trajectories: nearest_trajs,
entity,
Expand Down Expand Up @@ -382,3 +434,11 @@ pub struct MatchConfig {
pub match_threshold: f32,
pub pred_match_threshold: f32,
}

#[derive(SystemSet, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MotionMatchingSet {
Flow,
PredictionMatch,
GlobalMatch,
PoseMatch,
}
Loading