Skip to content

Commit

Permalink
feat: specify max inference size
Browse files Browse the repository at this point in the history
  • Loading branch information
mosure committed Mar 9, 2024
1 parent 6b7b48d commit d3389ca
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "bevy_ort"
description = "bevy ort (onnxruntime) plugin"
version = "0.3.0"
version = "0.4.0"
edition = "2021"
authors = ["mosure <mitchell@mosure.me>"]
license = "MIT"
Expand Down
69 changes: 38 additions & 31 deletions src/models/modnet.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
use std::cmp::{
max,
min,
};

use bevy::{prelude::*, render::render_asset::RenderAssetUsages};
use image::{DynamicImage, GenericImageView, imageops::FilterType, ImageBuffer, Luma, RgbImage};
use ndarray::{Array, Array4, ArrayView4, Axis, s};
Expand Down Expand Up @@ -44,33 +39,39 @@ pub fn modnet_output_to_luma_images(
images
}


pub fn images_to_modnet_input(
images: Vec<&Image>,
max_size: Option<(u32, u32)>,
) -> Array4<f32> {
// TODO: better error handling
if images.is_empty() {
panic!("no images provided");
}

let ref_size = 512;

let &first_image = images.first().unwrap();
assert_eq!(first_image.texture_descriptor.format, bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb);

let dynamic_image = first_image.clone().try_into_dynamic().unwrap();
let (x_scale, y_scale) = get_scale_factor(dynamic_image.height(), dynamic_image.width(), ref_size);
let resized_image = resize_image(&dynamic_image, x_scale, y_scale);
let image = first_image.to_owned();

println!("image: {:?}", image.size());

let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size, max_size);
let resized_image = resize_image(&image.try_into_dynamic().unwrap(), x_scale, y_scale);
let first_image_ndarray = image_to_ndarray(&resized_image);

println!("scale_factor: {:?}", (x_scale, y_scale));
println!("first_image_ndarray: {:?}", first_image_ndarray.dim());

let single_image_shape = first_image_ndarray.dim();
let n_images = images.len();
let batch_shape = (n_images, single_image_shape.1, single_image_shape.2, single_image_shape.3);

let mut aggregate = Array4::<f32>::zeros(batch_shape);

for (i, &image) in images.iter().enumerate() {
let dynamic_image = image.clone().try_into_dynamic().unwrap();
let (x_scale, y_scale) = get_scale_factor(dynamic_image.height(), dynamic_image.width(), ref_size);
let resized_image = resize_image(&dynamic_image, x_scale, y_scale);
let image = image.to_owned();
let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size, max_size);
let resized_image = resize_image(&image.try_into_dynamic().unwrap(), x_scale, y_scale);
let image_ndarray = image_to_ndarray(&resized_image);

let slice = s![i, .., .., ..];
Expand All @@ -81,27 +82,33 @@ pub fn images_to_modnet_input(
}


fn get_scale_factor(im_h: u32, im_w: u32, ref_size: u32) -> (f32, f32) {
let mut im_rh;
let mut im_rw;
fn get_scale_factor(im_h: u32, im_w: u32, ref_size: u32, max_size: Option<(u32, u32)>) -> (f32, f32) {
// Calculate the scale factor based on the maximum size constraints
let scale_factor_max = max_size.map_or(1.0, |(max_w, max_h)| {
f32::min(max_w as f32 / im_w as f32, max_h as f32 / im_h as f32)
});

if max(im_h, im_w) < ref_size || min(im_h, im_w) > ref_size {
if im_w >= im_h {
im_rh = ref_size;
im_rw = (im_w as f32 / im_h as f32 * ref_size as f32) as u32;
} else {
im_rw = ref_size;
im_rh = (im_h as f32 / im_w as f32 * ref_size as f32) as u32;
}
// Calculate the target dimensions after applying the max scale factor (clipping to max_size)
let (target_h, target_w) = ((im_h as f32 * scale_factor_max).round() as u32, (im_w as f32 * scale_factor_max).round() as u32);

// Calculate the scale factor to fit within the reference size, considering the target dimensions
let (scale_factor_ref_w, scale_factor_ref_h) = if std::cmp::max(target_h, target_w) < ref_size {
let scale_factor = ref_size as f32 / std::cmp::max(target_h, target_w) as f32;
(scale_factor, scale_factor)
} else {
im_rh = im_h;
im_rw = im_w;
}
(1.0, 1.0) // Do not upscale if target dimensions are within reference size
};

// Calculate the final scale factor as the minimum of the max scale factor and the reference scale factor
let final_scale_w = f32::min(scale_factor_max, scale_factor_ref_w);
let final_scale_h = f32::min(scale_factor_max, scale_factor_ref_h);

im_rw = im_rw - im_rw % 32;
im_rh = im_rh - im_rh % 32;
// Adjust dimensions to ensure they are multiples of 32
let final_w = ((im_w as f32 * final_scale_w).round() as u32) - ((im_w as f32 * final_scale_w).round() as u32) % 32;
let final_h = ((im_h as f32 * final_scale_h).round() as u32) - ((im_h as f32 * final_scale_h).round() as u32) % 32;

(im_rw as f32 / im_w as f32, im_rh as f32 / im_h as f32)
// Return the scale factors based on the original image dimensions
(final_w as f32 / im_w as f32, final_h as f32 / im_h as f32)
}


Expand Down
2 changes: 1 addition & 1 deletion tools/modnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fn inference(
}

let image = images.get(&modnet.input).expect("failed to get image asset");
let input = images_to_modnet_input(vec![&image]);
let input = images_to_modnet_input(vec![&image], Some((256, 144)));

let output: Result<ort::SessionOutputs<'_>, String> = (|| {
let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?;
Expand Down

0 comments on commit d3389ca

Please sign in to comment.