Skip to content

Commit

Permalink
set_output_names in python
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Apr 20, 2023
1 parent db4f3a9 commit d43c448
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 32 deletions.
13 changes: 13 additions & 0 deletions ffi/py/tests/mobilenet_onnx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ def test_inference_model():
assert str(model.output_fact(0)) == "B,1000,F32"
typed = model.into_typed()

def test_set_output_names_on_inference_model():
model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")
model.set_input_fact(0, "B,3,224,224,f32")
model.set_output_fact(0, None)
model.analyse()
model.set_output_names(["mobilenetv20_output_pred_fwd"])
assert str(model.output_fact(0)) == "B,1000,1,1,F32"

def test_typed_model():
model = tract.nnef().model_for_path("mobilenet_v2_1.0.onnx.nnef.tgz")
assert model.input_count() == 1
Expand All @@ -72,6 +80,11 @@ def test_typed_model():
assert str(model.output_fact(0)) == "1,1000,F32"
model.declutter()

def test_set_output_names():
model = tract.nnef().model_for_path("mobilenet_v2_1.0.onnx.nnef.tgz")
model.set_output_names(["conv_53"])
assert str(model.output_fact(0)) == "1,1000,1,1,F32"

def test_concretize():
model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")
model.set_input_fact(0, "B,3,224,224,f32")
Expand Down
11 changes: 11 additions & 0 deletions ffi/py/tract/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ def set_input_fact(self, input_id: int, fact: Union[InferenceFact, str, None]) -
else:
check(lib.tract_inference_model_set_input_fact(self.ptr, input_id, fact.ptr))

def set_output_names(self, names: List[str]):
"""Change the output nodes of the model"""
self._valid()
nb = len(names)
names_str = []
names_ptr = (c_char_p * nb)()
for ix, n in enumerate(names):
names_str.append(str(n).encode("utf-8"))
names_ptr[ix] = names_str[ix]
check(lib.tract_inference_model_set_output_names(self.ptr, nb, names_ptr))

def output_name(self, output_id: int) -> str:
"""Return the name of the `output_id`th output."""
self._valid()
Expand Down
71 changes: 45 additions & 26 deletions ffi/py/tract/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class Model:
"""
Main model object
# Main model object
## Central focus point of the model transformation pipeline
Expand All @@ -23,27 +23,29 @@ class Model:
## Model cooking
But some model transformations can be peformed on the Model class:
* declutter (getting rid of training artefacts)
* "pulsification" (transforming a batch-oriented model into a streaming model)
* symbol substitution (make N or Batch a fixed number, unlocking potential optimisation later on)
* static cost evalation and dynamic profiling
* ...
But some model transformations can be peformed on the `Model` class:
* declutter (getting rid of training artefacts)
* "pulsification" (transforming a batch-oriented model into a streaming model)
* symbol substitution (make N or Batch a fixed number, unlocking potential optimisation later on)
* static cost evalation and dynamic profiling
* ...
In some situation, these operation are done "on-the-fly" when a ONNX or NNEF model is loaded,
at start-up time. In other situation, when start-up time becomes an issue, it may be beneficial
to "pre-cook" the model: apply the transformations one time, serialize the model as NNEF (with
tract-opl extension if needed). At start-up this model can be significantly less expensive to
"cook" for inference.
# Model and TypedModel
## Model and TypedModel
This class is actually a wrapper around the "TypedModel" in Rust codebase. The "Typed"
bit means than all shapes and element types in all input, output and temporary values must
known. There is support in tract for symbols in dimensions, with some limited computation
capabilities on symbolic expression. For instance, it is relatively frequent to work with
a Model where all tensors shapes start with the `N` or `Batch`.
"""

def __init__(self, ptr):
self.ptr = ptr

Expand All @@ -55,60 +57,71 @@ def _valid(self):
if self.ptr == None:
raise TractError("invalid model (maybe already consumed ?)")

"""Return the number of inputs of the model"""
def input_count(self) -> int:
"""Return the number of inputs of the model"""
self._valid()
i = c_size_t()
check(lib.tract_model_nbio(self.ptr, byref(i), None))
return i.value

"""Return the number of outputs of the model"""
def output_count(self) -> int:
"""Return the number of outputs of the model"""
self._valid()
i = c_size_t()
check(lib.tract_model_nbio(self.ptr, None, byref(i)))
return i.value

"""Return the name of the input_id-th input"""
def input_name(self, input_id: int) -> str:
"""Return the name of the input_id-th input"""
self._valid()
cstring = c_char_p()
check(lib.tract_model_input_name(self.ptr, input_id, byref(cstring)))
result = str(cstring.value, "utf-8")
lib.tract_free_cstring(cstring)
return result

"""Return the fact of the input_id-th input"""
def input_fact(self, input_id: int) -> Fact:
"""Return the fact of the input_id-th input"""
self._valid()
fact = c_void_p()
check(lib.tract_model_input_fact(self.ptr, input_id, byref(fact)))
return Fact(fact)

"""Return the name of the output_id-th output"""
def set_output_names(self, names: List[str]):
"""Change the output nodes of the model"""
self._valid()
nb = len(names)
names_str = []
names_ptr = (c_char_p * nb)()
for ix, n in enumerate(names):
names_str.append(str(n).encode("utf-8"))
names_ptr[ix] = names_str[ix]
check(lib.tract_model_set_output_names(self.ptr, nb, names_ptr))

def output_name(self, output_id: int) -> str:
"""Return the name of the output_id-th output"""
self._valid()
cstring = c_char_p()
check(lib.tract_model_output_name(self.ptr, output_id, byref(cstring)))
result = str(cstring.value, "utf-8")
lib.tract_free_cstring(cstring)
return result

"""Return the fact of the output_id-th output"""
def output_fact(self, input_id: int) -> Fact:
"""Return the fact of the output_id-th output"""
self._valid()
fact = c_void_p()
check(lib.tract_model_output_fact(self.ptr, input_id, byref(fact)))
return Fact(fact)

"""Substitute symbols by a value
def concretize_symbols(self, values: Dict[str, int]) -> None:
"""Substitute symbols by a value
Replace all occurencies of the symbols in the dictionary, in all the Model facts shapes.
Replace all occurencies of the symbols in the dictionary, in all the Model facts shapes.
While this is not strictly necesary, the optimizing steps may make better choices if the model
is informed of some specific symbol values.
"""
def concretize_symbols(self, values: Dict[str, int]) -> None:
While this is not strictly necesary, the optimizing steps may make better choices if the model
is informed of some specific symbol values.
"""
self._valid()
nb = len(values)
names_str = []
Expand All @@ -120,39 +133,43 @@ def concretize_symbols(self, values: Dict[str, int]) -> None:
values_list[ix] = v
check(lib.tract_model_concretize_symbols(self.ptr, c_size_t(nb), names, values_list))

"""Pulsify a model.
`pulse` is typically a one-length dictionary mapping the time dimension symbol to a pulse len.
"""
def pulse(self, symbol: str, pulse: Union[str, int]) -> None:
"""Pulsify a model.
`pulse` is typically a one-length dictionary mapping the time dimension symbol to a pulse len.
"""
self._valid()
check(lib.tract_model_pulse_simple(byref(self.ptr), symbol.encode("utf-8"), str(pulse).encode("utf-8")))

def declutter(self) -> None:
"""Declutter the model in place"""
self._valid()
check(lib.tract_model_declutter(self.ptr))

def optimize(self) -> None:
"""Optimize the model in place"""
self._valid()
check(lib.tract_model_optimize(self.ptr))

"""Convenience method performing `declutter()` and returning the model"""
def into_decluttered(self) -> "Model":
"""Convenience method performing `declutter()` and returning the model"""
self.declutter();
return self

"""Convenience method performing `optimize()` and returning the model"""
def into_optimized(self) -> "Model":
"""Convenience method performing `optimize()` and returning the model"""
self.optimize()
return self

def into_runnable(self) -> Runnable:
"""Transform the model into a Runnable model ready to be used"""
self._valid()
runnable = c_void_p()
check(lib.tract_model_into_runnable(byref(self.ptr), byref(runnable)))
return Runnable(runnable)

def property_keys(self) -> List[str]:
"""Extract the list of properties from a model"""
self._valid()
count = c_size_t()
check(lib.tract_model_property_count(self.ptr, byref(count)))
Expand All @@ -166,12 +183,14 @@ def property_keys(self) -> List[str]:
return names

def property(self, name: str) -> Value:
"""Query a property by name"""
self._valid()
value = c_void_p()
check(lib.tract_model_property(self.ptr, str(name).encode("utf-8"), byref(value)))
return Value(value)

def profile_json(self, inputs: Union[None, List[Union[Value, numpy.ndarray]]]) -> str:
"""Profile the model on the provided input"""
self._valid()
cstring = c_char_p()
input_values = []
Expand Down
75 changes: 69 additions & 6 deletions ffi/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#![allow(clippy::missing_safety_doc)]

use anyhow::Context;
use tract_libcli::annotations::Annotations;
use tract_libcli::profile::BenchLimits;
use std::cell::RefCell;
use std::convert::TryFrom;
use std::ffi::{c_char, c_void, CStr, CString};
use std::sync::Arc;
use tract_data::internal::parse_tdim;
use tract_libcli::annotations::Annotations;
use tract_libcli::profile::BenchLimits;
use tract_pulse::model::{PulsedModel, PulsedModelExt};

use tract_nnef::internal as native;
Expand Down Expand Up @@ -233,7 +233,10 @@ pub unsafe extern "C" fn tract_nnef_enable_pulse(nnef: *mut TractNnef) -> TRACT_
}

#[no_mangle]
pub unsafe extern "C" fn tract_nnef_allow_extended_identifier_syntax(nnef: *mut TractNnef, enable: bool) -> TRACT_RESULT {
pub unsafe extern "C" fn tract_nnef_allow_extended_identifier_syntax(
nnef: *mut TractNnef,
enable: bool,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(nnef);
(*nnef).0.allow_extended_identifier_syntax(enable);
Expand Down Expand Up @@ -469,6 +472,25 @@ pub unsafe extern "C" fn tract_inference_model_set_input_fact(
})
}

/// Change the model outputs nodes (by name).
///
/// `names` is an array containing `len` pointers to null terminated strings.
#[no_mangle]
pub unsafe extern "C" fn tract_inference_model_set_output_names(
model: *mut TractInferenceModel,
len: usize,
names: *const *const c_char,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(model, names, *names);
let node_names = (0..len)
.map(|i| Ok(CStr::from_ptr(*names.add(i)).to_str()?.to_owned()))
.collect::<TractResult<Vec<_>>>()?;
(*model).0.set_output_names(&node_names)?;
Ok(())
})
}

/// Query an output fact for an InferenceModel.
///
/// The return model must be freed using `tract_inference_fact_destroy`.
Expand Down Expand Up @@ -648,6 +670,7 @@ pub unsafe extern "C" fn tract_model_output_name(
})
}

/// Query the fact of a model output.
#[no_mangle]
pub unsafe extern "C" fn tract_model_output_fact(
model: *const TractModel,
Expand All @@ -663,6 +686,31 @@ pub unsafe extern "C" fn tract_model_output_fact(
})
}

/// Change the model outputs nodes (by name).
///
/// `names` is an array containing `len` pointers to null terminated strings.
#[no_mangle]
pub unsafe extern "C" fn tract_model_set_output_names(
model: *mut TractModel,
len: usize,
names: *const *const c_char,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(model, names, *names);
let node_names = (0..len)
.map(|i| Ok(CStr::from_ptr(*names.add(i)).to_str()?.to_owned()))
.collect::<TractResult<Vec<_>>>()?;
dbg!(&node_names);
(*model).0.set_output_names(&node_names)?;
Ok(())
})
}

/// Give value one or more symbols used in the model.
///
/// * symbols is an array of `nb_symbols` pointers to null-terminated UTF-8 string for the symbols
/// names to substitue
/// * values is an array of `nb_symbols` integer values
#[no_mangle]
pub unsafe extern "C" fn tract_model_concretize_symbols(
model: *mut TractModel,
Expand All @@ -686,6 +734,11 @@ pub unsafe extern "C" fn tract_model_concretize_symbols(
})
}

/// Perform pulsification of the model
///
/// `stream_symbol` is a pointer to a null-terminated UTF-8 string reprensenting the symbol to be
/// used as the time
/// `pulse_expr` is the pulse value to be used, as a null-terminated UTF-8 string.
#[no_mangle]
pub unsafe extern "C" fn tract_model_pulse_simple(
model: *mut *mut TractModel,
Expand Down Expand Up @@ -730,6 +783,7 @@ pub unsafe extern "C" fn tract_model_optimize(model: *mut TractModel) -> TRACT_R
})
}

/// Perform a profile of the model using the provided inputs.
#[no_mangle]
pub unsafe extern "C" fn tract_model_profile_json(
model: *mut TractModel,
Expand All @@ -743,9 +797,18 @@ pub unsafe extern "C" fn tract_model_profile_json(
tract_libcli::profile::extract_costs(&mut annotations, model)?;
if !inputs.is_null() {
let input_len = model.inputs.len();
let values:TVec<TValue> =
std::slice::from_raw_parts(inputs, input_len).iter().map(|tv| (**tv).0.clone()).collect();
tract_libcli::profile::profile(model, &BenchLimits::default(), &mut annotations, &values, None, true)?;
let values: TVec<TValue> = std::slice::from_raw_parts(inputs, input_len)
.iter()
.map(|tv| (**tv).0.clone())
.collect();
tract_libcli::profile::profile(
model,
&BenchLimits::default(),
&mut annotations,
&values,
None,
true,
)?;
}
let export = tract_libcli::export::GraphPerfInfo::from(model, &annotations);
*json = CString::new(serde_json::to_string(&export)?)?.into_raw();
Expand Down

0 comments on commit d43c448

Please sign in to comment.