Skip to content

Commit

Permalink
feat: async person segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
mosure committed Mar 9, 2024
1 parent 0e3fd44 commit ba8a5c0
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 17 deletions.
32 changes: 28 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,35 @@ jobs:
with:
toolchain: ${{ matrix.rust-toolchain }}

- name: Install ONNX Runtime on Windows
if: matrix.os == 'windows-latest'
run: |
Invoke-WebRequest -Uri "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-win-x64-1.17.1.zip" -OutFile "onnxruntime.zip"
Expand-Archive -Path "onnxruntime.zip" -DestinationPath "$env:RUNNER_TEMP"
echo "ONNXRUNTIME_DIR=$env:RUNNER_TEMP\onnxruntime-win-x64-1.17.1" | Out-File -Append -Encoding ascii $env:GITHUB_ENV
- name: Install ONNX Runtime on macOS
if: matrix.os == 'macos-latest'
run: |
curl -L "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-osx-x86_64-1.17.1.tgz" -o "onnxruntime.tgz"
mkdir -p $HOME/onnxruntime
tar -xzf onnxruntime.tgz -C $HOME/onnxruntime
echo "ONNXRUNTIME_DIR=$HOME/onnxruntime/onnxruntime-osx-x86_64-1.17.1" >> $GITHUB_ENV
- name: Set ONNX Runtime library path for macOS
if: matrix.os == 'macos-latest'
run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/libonnxruntime.dylib" >> $GITHUB_ENV

- name: Set ONNX Runtime library path for Windows
if: matrix.os == 'windows-latest'
run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/onnxruntime.dll" >> $GITHUB_ENV


- name: lint
run: cargo clippy -- -Dwarnings

- name: build
run: cargo build

# - name: build (web)
# run: cargo build --example=minimal --target wasm32-unknown-unknown --release
run: cargo build --features "ort/load-dynamic"
env:
ORT_DYLIB_PATH: ${{ env.ORT_DYLIB_PATH }}
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ www/assets/
*.mp4

mediamtx/
onnxruntime/
23 changes: 22 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,21 @@ exclude = [

default-run = "viewer"

[features]
default = [
"person_matting",
]

person_matting = ["bevy_ort", "ort", "ndarray"]


[dependencies]
anyhow = "1.0"
async-compat = "0.2"
bytes = "1.5.0"
bevy_ort = { version = "0.5", optional = true }
bytes = "1.5"
futures = "0.3"
ndarray = { version = "0.15", optional = true }
openh264 = "0.5"
retina = "0.4"
tokio = { version = "1.36", features = ["full"] }
Expand All @@ -50,6 +59,18 @@ features = [
]


[dependencies.ort]
version = "2.0.0-alpha.4"
optional = true
default-features = false
features = [
"cuda",
"load-dynamic",
"ndarray",
"openvino",
]


[profile.dev.package."*"]
opt-level = 3

Expand Down
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ rust bevy light field camera array tooling

- [X] grid view of light field camera array
- [X] stream to files with recording controls
- [X] person segmentation post-process (batch across streams)
- [X] async segmentation model inference
- [ ] foreground extraction post-process and visualization mode
- [ ] playback nersemble recordings with annotations
- [ ] person segmentation post-process (batch across streams)
- [ ] camera array calibration
- [ ] 3d reconstruction dataset preparation
- [ ] real-time 3d reconstruction viewer
Expand All @@ -27,6 +29,11 @@ rust bevy light field camera array tooling

the viewer opens a window and displays the light field camera array, with post-process options

> see the [bevy_ort documentation](https://github.com/mosure/bevy_ort?tab=readme-ov-file#run-the-example-person-segmentation-model-modnet) on execution providers for better performance
- windows: `cargo run --release --features "ort/cuda"`


### controls

- `r` to start recording
Expand Down Expand Up @@ -156,5 +163,6 @@ it is useful to test the light field viewer with emulated camera streams
## credits
- [bevy_video](https://github.com/PortalCloudInc/bevy_video)
- [gaussian_avatars](https://github.com/ShenhanQian/GaussianAvatars)
- [modnet](https://github.com/ZHKKKe/MODNet)
- [nersemble](https://github.com/tobias-kirschstein/nersemble)
- [paddle_seg_matting](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.9/Matting/docs/quick_start_en.md)
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
#[cfg(feature = "person_matting")]
pub mod matting;

pub mod mp4;
pub mod stream;
147 changes: 147 additions & 0 deletions src/matting.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
use bevy::{
prelude::*,
ecs::system::CommandQueue,
tasks::{block_on, futures_lite::future, AsyncComputeTaskPool, Task},
};
use bevy_ort::{
BevyOrtPlugin,
inputs,
models::modnet::{
images_to_modnet_input,
modnet_output_to_luma_images,
},
Onnx,
};

use crate::stream::StreamId;


/// Component linking one camera stream to its person-matting image pair.
#[derive(Component, Clone, Debug, Reflect)]
pub struct MattedStream {
    // Identifier of the source camera stream this matting pass belongs to.
    pub stream_id: StreamId,
    // Source frame the MODNet model reads from.
    pub input: Handle<Image>,
    // Destination image that receives the predicted luma/alpha mask.
    pub output: Handle<Image>,
}


/// Plugin wiring up MODNet person matting: registers the ORT integration,
/// the `MattedStream` reflection type, the model-handle resource, and the
/// load/inference systems.
pub struct MattingPlugin;

impl Plugin for MattingPlugin {
    fn build(&self, app: &mut App) {
        app
            .add_plugins(BevyOrtPlugin)
            .register_type::<MattedStream>()
            .init_resource::<Modnet>()
            .add_systems(Startup, load_modnet)
            .add_systems(Update, matting_inference);
    }
}


/// Resource holding the handle to the MODNet ONNX model asset.
#[derive(Resource, Default)]
pub struct Modnet {
    // Handle to the `modnet_photographic_portrait_matting.onnx` asset
    // loaded by `load_modnet`.
    pub onnx: Handle<Onnx>,
}


/// Startup system: begins loading the MODNet ONNX model asset and stores
/// its handle in the [`Modnet`] resource for later lookup at inference time.
fn load_modnet(
    asset_server: Res<AssetServer>,
    mut modnet: ResMut<Modnet>,
) {
    modnet.onnx = asset_server.load("modnet_photographic_portrait_matting.onnx");
}


/// Per-system local state: at most one in-flight inference task whose
/// result is a `CommandQueue` to apply on the main world when it completes.
#[derive(Default)]
struct ModnetComputePipeline(Option<Task<CommandQueue>>);


/// System that batches every `MattedStream` input image through the MODNet
/// model on the async compute pool and writes the predicted masks back to
/// the corresponding output image handles.
///
/// At most one inference task runs at a time: while a task is in flight the
/// system only polls it, applying its `CommandQueue` on completion.
fn matting_inference(
    mut commands: Commands,
    images: Res<Assets<Image>>,
    modnet: Res<Modnet>,
    matted_streams: Query<
        (
            Entity,
            &MattedStream,
        )
    >,
    onnx_assets: Res<Assets<Onnx>>,
    mut pipeline_local: Local<ModnetComputePipeline>,
) {
    // Poll the in-flight task, if any; never spawn a second one concurrently.
    if let Some(pipeline) = pipeline_local.0.as_mut() {
        if let Some(mut commands_queue) = block_on(future::poll_once(pipeline)) {
            commands.append(&mut commands_queue);
            pipeline_local.0 = None;
        }

        return;
    }

    // Wait for the model asset before doing any per-frame batching work
    // (the original built the input tensor before this check).
    let Some(onnx) = onnx_assets.get(&modnet.onnx) else {
        return;
    };

    let inputs = matted_streams.iter()
        .map(|(_, matted_stream)| images.get(&matted_stream.input).unwrap())
        .collect::<Vec<_>>();

    // Nothing to matte this frame.
    if inputs.is_empty() {
        return;
    }

    // Skip until every stream has delivered a real frame; 32x32 is
    // presumably the placeholder size used before the first frame arrives
    // — TODO confirm against the stream setup code.
    let uninitialized = inputs.iter().any(|image| image.size() == (32, 32).into());
    if uninitialized {
        return;
    }

    // Batch all stream frames into one MODNet input tensor, capped at
    // 256x256 to bound inference cost.
    let max_inference_size = (256, 256).into();
    let input = images_to_modnet_input(
        inputs,
        max_inference_size,
    );

    let session_arc = onnx.session.clone();

    // Output handles in the same order as `inputs`, moved into the task.
    let outputs = matted_streams.iter()
        .map(|(_, matted_stream)| matted_stream.output.clone())
        .collect::<Vec<_>>();

    let thread_pool = AsyncComputeTaskPool::get();
    let task = thread_pool.spawn(async move {
        let mask_images: Result<Vec<Image>, String> = (|| {
            let session_lock = session_arc.lock().map_err(|e| e.to_string())?;
            let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?;

            let input_values = inputs!["input" => input.view()].map_err(|e| e.to_string())?;
            // Propagate inference errors instead of `.ok().unwrap()`, which
            // panicked the task on any session failure.
            let outputs = session.run(input_values).map_err(|e| e.to_string())?;

            let output_value: &ort::Value = outputs.get("output")
                .ok_or("inference output 'output' not found")?;

            Ok(modnet_output_to_luma_images(output_value))
        })();

        match mask_images {
            Ok(mask_images) => {
                let mut command_queue = CommandQueue::default();

                command_queue.push(move |world: &mut World| {
                    let mut images = world.get_resource_mut::<Assets<Image>>().unwrap();

                    // Pair each output handle with its mask by index so stream
                    // order is preserved (a pop()-based loop writes in reverse).
                    outputs.iter()
                        .zip(mask_images)
                        .for_each(|(output, mask)| {
                            images.insert(output, mask);
                        });
                });

                command_queue
            },
            Err(error) => {
                // Log and return an empty queue so the system can retry on a
                // later frame rather than poisoning the pipeline.
                eprintln!("inference failed: {}", error);
                CommandQueue::default()
            }
        }
    });

    *pipeline_local = ModnetComputePipeline(Some(task));
}
Loading

0 comments on commit ba8a5c0

Please sign in to comment.