Skip to content

Commit

Permalink
feat: stream foreground extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
mosure committed Mar 10, 2024
1 parent ba8a5c0 commit e575a21
Show file tree
Hide file tree
Showing 12 changed files with 310 additions and 61 deletions.
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,15 @@ person_matting = ["bevy_ort", "ort", "ndarray"]
[dependencies]
anyhow = "1.0"
async-compat = "0.2"
bevy_ort = { version = "0.5", optional = true }
bevy_args = "1.3"
bevy_ort = { version = "0.6", optional = true }
bytes = "1.5"
clap = { version = "4.4", features = ["derive"] }
futures = "0.3"
ndarray = { version = "0.15", optional = true }
openh264 = "0.5"
serde = "1.0"
serde_qs = "0.12"
retina = "0.4"
tokio = { version = "1.36", features = ["full"] }
url = "2.5"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ rust bevy light field camera array tooling
- [X] stream to files with recording controls
- [X] person segmentation post-process (batch across streams)
- [X] async segmentation model inference
- [ ] foreground extraction post-process and visualization mode
- [X] foreground extraction post-process and visualization mode
- [ ] playback nersemble recordings with annotations
- [ ] camera array calibration
- [ ] 3d reconstruction dataset preparation
Expand Down
Binary file added assets/fonts/Caveat-Bold.ttf
Binary file not shown.
Binary file added assets/fonts/Caveat-Medium.ttf
Binary file not shown.
Binary file added assets/fonts/Caveat-Regular.ttf
Binary file not shown.
Binary file added assets/fonts/Caveat-SemiBold.ttf
Binary file not shown.
12 changes: 12 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
use bevy::prelude::*;

#[cfg(feature = "person_matting")]
pub mod matting;

pub mod materials;
pub mod mp4;
pub mod stream;


pub struct LightFieldPlugin;
impl Plugin for LightFieldPlugin {
fn build(&self, app: &mut App) {
app.add_plugins(materials::StreamMaterialsPlugin);
app.add_plugins(stream::RtspStreamPlugin);
}
}
40 changes: 40 additions & 0 deletions src/materials/foreground.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use bevy::{
prelude::*,
asset::load_internal_asset,
render::render_resource::*,
};


const FOREGROUND_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(5231534123);

pub struct ForegroundPlugin;
impl Plugin for ForegroundPlugin {
fn build(&self, app: &mut App) {
load_internal_asset!(
app,
FOREGROUND_SHADER_HANDLE,
"foreground.wgsl",
Shader::from_wgsl
);

app.add_plugins(UiMaterialPlugin::<ForegroundMaterial>::default());
}
}


#[derive(AsBindGroup, Asset, TypePath, Debug, Clone)]
pub struct ForegroundMaterial {
#[texture(0)]
#[sampler(1)]
pub input: Handle<Image>,

#[texture(2)]
#[sampler(3)]
pub mask: Handle<Image>,
}

impl UiMaterial for ForegroundMaterial {
fn fragment_shader() -> ShaderRef {
ShaderRef::Handle(FOREGROUND_SHADER_HANDLE)
}
}
22 changes: 22 additions & 0 deletions src/materials/foreground.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#import bevy_ui::ui_vertex_output::UiVertexOutput


@group(1) @binding(0) var foreground_texture: texture_2d<f32>;
@group(1) @binding(1) var foreground_sampler: sampler;

@group(1) @binding(2) var mask_texture: texture_2d<f32>;
@group(1) @binding(3) var mask_sampler: sampler;


@fragment
fn fragment(in: UiVertexOutput) -> @location(0) vec4<f32> {
return textureSample(
foreground_texture,
foreground_sampler,
in.uv,
) * textureSample(
mask_texture,
mask_sampler,
in.uv,
).x;
}
11 changes: 11 additions & 0 deletions src/materials/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use bevy::prelude::*;

pub mod foreground;


pub struct StreamMaterialsPlugin;
impl Plugin for StreamMaterialsPlugin {
fn build(&self, app: &mut App) {
app.add_plugins(foreground::ForegroundPlugin);
}
}
45 changes: 35 additions & 10 deletions src/matting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,42 @@ use bevy_ort::{
Onnx,
};

use crate::stream::StreamId;
use crate::{
materials::foreground::ForegroundMaterial,
stream::StreamId,
};


#[derive(Component, Clone, Debug, Reflect)]
pub struct MattedStream {
pub stream_id: StreamId,
pub input: Handle<Image>,
pub output: Handle<Image>,
pub material: Handle<ForegroundMaterial>,
}


pub struct MattingPlugin;
#[derive(Resource, Default, Clone)]
pub struct InferenceSize(pub (u32, u32));

pub struct MattingPlugin {
pub max_inference_size: InferenceSize,
}

impl MattingPlugin {
pub fn new(max_inference_size: (u32, u32)) -> Self {
MattingPlugin {
max_inference_size: InferenceSize(max_inference_size),
}
}
}

impl Plugin for MattingPlugin {
fn build(&self, app: &mut App) {
app.add_plugins(BevyOrtPlugin);
app.register_type::<MattedStream>();
app.init_resource::<Modnet>();
app.insert_resource(self.max_inference_size.clone());
app.add_systems(Startup, load_modnet);
app.add_systems(Update, matting_inference);
}
Expand Down Expand Up @@ -67,6 +86,7 @@ fn matting_inference(
>,
onnx_assets: Res<Assets<Onnx>>,
mut pipeline_local: Local<ModnetComputePipeline>,
inference_size: Res<InferenceSize>,
) {
if let Some(pipeline) = pipeline_local.0.as_mut() {
if let Some(mut commands_queue) = block_on(future::poll_once(pipeline)) {
Expand All @@ -77,6 +97,10 @@ fn matting_inference(
return;
}

if matted_streams.is_empty() {
return;
}

let thread_pool = AsyncComputeTaskPool::get();

let inputs = matted_streams.iter()
Expand All @@ -90,10 +114,9 @@ fn matting_inference(
return;
}

let max_inference_size = (256, 256).into();
let input = images_to_modnet_input(
inputs,
max_inference_size,
inputs.as_slice(),
inference_size.0.into(),
);

if onnx_assets.get(&modnet.onnx).is_none() {
Expand All @@ -104,7 +127,7 @@ fn matting_inference(
let session_arc = onnx.session.clone();

let outputs = matted_streams.iter()
.map(|(_, matted_stream)| matted_stream.output.clone())
.map(|(_, matted_stream)| (matted_stream.output.clone(), matted_stream.material.clone()))
.collect::<Vec<_>>();

let task = thread_pool.spawn(async move {
Expand All @@ -126,11 +149,13 @@ fn matting_inference(
let mut command_queue = CommandQueue::default();

command_queue.push(move |world: &mut World| {
let mut images = world.get_resource_mut::<Assets<Image>>().unwrap();

outputs.iter()
.for_each(|output| {
images.insert(output, mask_images.pop().unwrap());
.for_each(|(mask, material)| {
let mut images = world.get_resource_mut::<Assets<Image>>().unwrap();
images.insert(mask, mask_images.pop().unwrap());

let mut foreground_materials = world.get_resource_mut::<Assets<ForegroundMaterial>>().unwrap();
foreground_materials.get_mut(material).unwrap();
});
});

Expand Down
Loading

0 comments on commit e575a21

Please sign in to comment.