Skip to content

Commit

Permalink
feat: modnet inference example
Browse files Browse the repository at this point in the history
  • Loading branch information
mosure committed Mar 9, 2024
1 parent 0c25186 commit a6e01fc
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 36 deletions.
16 changes: 15 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "bevy_ort"
description = "bevy ort (onnxruntime) plugin"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
authors = ["mosure <mitchell@mosure.me>"]
license = "MIT"
Expand All @@ -26,9 +26,17 @@ exclude = [

default-run = "modnet"

[features]
default = [
"modnet",
]

modnet = []


[dependencies]
bevy_args = "1.3"
image = "0.24"
ndarray = "0.15"
ort = "2.0.0-alpha.4"
thiserror = "1.0"
Expand All @@ -40,6 +48,12 @@ version = "0.13"
default-features = false
features = [
"bevy_asset",
"bevy_core_pipeline",
"bevy_render",
"bevy_ui",
"bevy_winit",
"multi-threaded",
"png",
]


Expand Down
110 changes: 94 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,32 @@
a bevy plugin for the [ort](https://docs.rs/ort/latest/ort/) library


![person](assets/person.png)
![mask](assets/mask.png)
*modnet inference example*


## capabilities

- [X] load ONNX models as ORT session assets
- [X] initialize ORT with default execution providers

- [X] modnet bevy image <-> ort tensor IO (with feature `modnet`)
- [ ] compute task pool inference scheduling


## library usage

```rust
use bevy::prelude::*;

use bevy_ort::{
BevyOrtPlugin,
inputs,
models::modnet::{
image_to_modnet_input,
modnet_output_to_luma_image,
},
Onnx,
};


Expand All @@ -31,44 +45,108 @@ fn main() {
DefaultPlugins,
BevyOrtPlugin,
))
.add_systems(Startup, load_model)
.add_system(Update, inference)
.init_resource::<Modnet>()
.add_systems(Startup, load_modnet)
.add_systems(Update, inference)
.run();
}

fn load_model(
#[derive(Resource, Default)]
pub struct Modnet {
pub onnx: Handle<Onnx>,
pub input: Handle<Image>,
}

fn load_modnet(
asset_server: Res<AssetServer>,
mut modnet: ResMut<Modnet>,
) {
let model_handle: Handle<Onnx> = asset_server.load("path/to/model.onnx");
let modnet_handle: Handle<Onnx> = asset_server.load("modnet_photographic_portrait_matting.onnx");
modnet.onnx = modnet_handle;

let input_handle: Handle<Image> = asset_server.load("person.png");
modnet.input = input_handle;
}


fn inference(
asset_server: Res<AssetServer>,
mut models: ResMut<Assets<Onnx>>,
mut commands: Commands,
modnet: Res<Modnet>,
onnx_assets: Res<Assets<Onnx>>,
mut images: ResMut<Assets<Image>>,
mut complete: Local<bool>,
) {
let model_handle: Handle<Onnx> = todo!();

if Some(LoadState::Loaded) == asset_server.get_load_state(model_handle) {
let model: &Onnx = models.get(model_handle).unwrap();
if *complete {
return;
}

if let Some(session) = &model.session {
let input_values = todo!();
let outputs = session.run(input_values).unwrap();
let image = images.get(&modnet.input).expect("failed to get image asset");
let input = image_to_modnet_input(image);

let output: Result<ort::SessionOutputs<'_>, String> = (|| {
let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?;
let session = onnx.session.as_ref().ok_or("failed to get session from ONNX asset")?;

let input_values = inputs!["input" => input.view()].map_err(|e| e.to_string())?;
session.run(input_values).map_err(|e| e.to_string())
})();

match output {
Ok(output) => {
let output_value: &ort::Value = output.get("output").unwrap();

let mask_image = modnet_output_to_luma_image(output_value);
let mask_image = images.add(mask_image);

commands.spawn(NodeBundle {
style: Style {
display: Display::Grid,
width: Val::Percent(100.0),
height: Val::Percent(100.0),
grid_template_columns: RepeatedGridTrack::flex(1, 1.0),
grid_template_rows: RepeatedGridTrack::flex(1, 1.0),
..default()
},
background_color: BackgroundColor(Color::DARK_GRAY),
..default()
})
.with_children(|builder| {
builder.spawn(ImageBundle {
style: Style {
..default()
},
image: UiImage::new(mask_image.clone()),
..default()
});
});

commands.spawn(Camera2dBundle::default());

*complete = true;
},
Err(error) => {
println!("inference failed: {}", error);
}
}
}

```


## run the example person segmentation model
## run the example person segmentation model (modnet)

```sh
cargo run --bin modnet -- --input assets/person.jpg
cargo run
```

> note: if you use `pip install onnxruntime`, you may need to run `ORT_STRATEGY=system cargo run`, see: https://docs.rs/ort/latest/ort/#how-to-get-binaries

## compatible bevy versions

| `bevy_ort` | `bevy` |
| :-- | :-- |
| `0.2.0` | `0.13` |
| `0.1.0` | `0.13` |

## credits
- [modnet](https://github.com/ZHKKKe/MODNet)
Binary file added assets/mask.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/person.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub use ort::{
Session,
};

pub mod models;


pub struct BevyOrtPlugin;
impl Plugin for BevyOrtPlugin {
Expand Down
2 changes: 2 additions & 0 deletions src/models/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#[cfg(feature = "modnet")]
pub mod modnet;
118 changes: 118 additions & 0 deletions src/models/modnet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use std::cmp::{
max,
min,
};

use bevy::{prelude::*, render::render_asset::RenderAssetUsages};
use image::{DynamicImage, GenericImageView, imageops::FilterType, ImageBuffer, Luma, RgbImage};
use ndarray::{Array, Array4, ArrayView4, Axis};


/// Convert a MODNet session output into a grayscale `Image`.
///
/// The model emits a single-channel alpha matte as an NCHW `f32` tensor with
/// values in `[0, 1]`; each value is clamped and scaled to an 8-bit luma pixel.
///
/// # Panics
/// Panics if the value is not an `f32` tensor, if its shape is not
/// `[1, 1, height, width]`, or if the tensor data is not contiguous.
pub fn modnet_output_to_luma_image(
    output_value: &ort::Value,
) -> Image {
    let tensor: ort::Tensor<f32> = output_value.extract_tensor::<f32>().unwrap();

    let data = tensor.view();

    // validate the layout instead of silently misreading shape[2]/shape[3]
    let shape = data.shape();
    assert_eq!(shape.len(), 4, "expected NCHW output, got shape {:?}", shape);
    assert_eq!(shape[0], 1, "expected batch size 1, got shape {:?}", shape);
    assert_eq!(shape[1], 1, "expected a single matte channel, got shape {:?}", shape);

    let height = shape[2];
    let width = shape[3];

    let tensor_data = ArrayView4::from_shape((1, 1, height, width), data.as_slice().unwrap())
        .expect("failed to create ArrayView4 from shape and data");

    // iteration over a (1, 1, h, w) view is row-major, matching image layout,
    // so the whole buffer can be built in one pass instead of per-pixel writes
    let pixels: Vec<u8> = tensor_data
        .iter()
        .map(|&alpha| (alpha.clamp(0.0, 1.0) * 255.0) as u8)
        .collect();

    let imgbuf = ImageBuffer::<Luma<u8>, Vec<u8>>::from_raw(width as u32, height as u32, pixels)
        .expect("pixel buffer length must equal width * height");

    Image::from_dynamic(DynamicImage::ImageLuma8(imgbuf), false, RenderAssetUsages::all())
}


/// Convert a bevy `Image` into the `(1, 3, height, width)` tensor MODNet expects.
///
/// The image is rescaled toward the model's reference resolution (512, with
/// both dimensions snapped to multiples of 32) and normalized to `[-1, 1]`.
///
/// # Panics
/// Panics if the image texture format is not `Rgba8UnormSrgb`.
pub fn image_to_modnet_input(
    image: &Image,
) -> Array4<f32> {
    assert_eq!(image.texture_descriptor.format, bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb);

    let ref_size = 512;

    let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size);

    // `try_into_dynamic` consumes the image, so work on a clone
    let dynamic_image = image.clone().try_into_dynamic().unwrap();
    let resized = resize_image(&dynamic_image, x_scale, y_scale);

    image_to_ndarray(&resized)
}


/// Compute per-axis scale factors that map an `im_w` x `im_h` image onto the
/// MODNet reference resolution, with target dimensions rounded down to
/// multiples of 32 (a MODNet input requirement).
fn get_scale_factor(im_h: u32, im_w: u32, ref_size: u32) -> (f32, f32) {
    // only rescale when the image is entirely smaller or entirely larger than
    // the reference size; otherwise keep the original dimensions
    let rescale = max(im_h, im_w) < ref_size || min(im_h, im_w) > ref_size;

    let (target_h, target_w) = if !rescale {
        (im_h, im_w)
    } else if im_w >= im_h {
        // landscape (or square): pin height to the reference size
        (ref_size, (im_w as f32 / im_h as f32 * ref_size as f32) as u32)
    } else {
        // portrait: pin width to the reference size
        ((im_h as f32 / im_w as f32 * ref_size as f32) as u32, ref_size)
    };

    // snap both dimensions down to the nearest multiple of 32
    let target_w = target_w - target_w % 32;
    let target_h = target_h - target_h % 32;

    (target_w as f32 / im_w as f32, target_h as f32 / im_h as f32)
}


/// Convert an RGB image into a `(1, 3, height, width)` tensor with channel
/// values normalized from `[0, 255]` to `[-1, 1]`.
fn image_to_ndarray(img: &RgbImage) -> Array4<f32> {
    let (width, height) = img.dimensions();

    // flatten the interleaved RGB bytes into normalized f32 values
    let mut raw = Vec::with_capacity((width * height * 3) as usize);
    for pixel in img.pixels() {
        for &channel in pixel.0.iter() {
            raw.push((channel as f32 - 127.5) / 127.5);
        }
    }

    // build an HWC array, then move channels to the front and prepend the
    // batch axis: [h, w, 3] -> [3, h, w] -> [1, 3, h, w]
    Array::from_shape_vec((height as usize, width as usize, 3), raw)
        .expect("failed to create ndarray from image raw data")
        .permuted_axes([2, 0, 1])
        .insert_axis(Axis(0))
}

/// Resize `image` by independent horizontal and vertical factors using
/// triangle (bilinear) filtering, returning an 8-bit RGB buffer.
fn resize_image(image: &DynamicImage, x_scale: f32, y_scale: f32) -> RgbImage {
    let (source_width, source_height) = image.dimensions();

    let target_width = (source_width as f32 * x_scale) as u32;
    let target_height = (source_height as f32 * y_scale) as u32;

    image
        .resize_exact(target_width, target_height, FilterType::Triangle)
        .to_rgb8()
}
Loading

0 comments on commit a6e01fc

Please sign in to comment.