diff --git a/README.md b/README.md index ce07e9d..0ad74f5 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,10 @@ Call that from your main program to run inference. Function parameters are named Using the compiler `-ffast-math` (or equivalent) when compiling onnx2c-generated code increases computation speed. See the [GCC wiki on floating point maths](https://gcc.gnu.org/wiki/FloatingPointMath) for details. -Onnx2c has an [experimental quantization option](quantization.md) to convert floating point calculation to integers. +Onnx2c has a few optimization passes that modify the generated output: + - Tensor unionization to wrap intermediate tensors in unions to help the compiler re-use the heap memory. + - Optimization for AVR processors to put constants into instruction memory. + - An [experimental quantization option](quantization.md) to convert floating point calculation to integers. `./onnx2c -h` prints out all available command line options. diff --git a/src/main.cc b/src/main.cc index 88e1d49..0e5bbdd 100644 --- a/src/main.cc +++ b/src/main.cc @@ -24,7 +24,8 @@ int main(int argc, const char *argv[]) std::cout.precision(20); toC::Graph toCgraph(onnx_model); - toCgraph.unionize_tensors(); + if( options.opt_unionize ) + toCgraph.unionize_tensors(); toCgraph.print_source(std::cout); } diff --git a/src/options.cc b/src/options.cc index 60119d6..b0cc197 100644 --- a/src/options.cc +++ b/src/options.cc @@ -64,12 +64,54 @@ void store_define_option(const std::string &opt) options.dim_defines[name] = val_num; } +void print_optimization_passes(void) +{ + std::cout << "Available optimization passes:" << std::endl; + std::cout << " - 'unionize' (defaut:on)" << std::endl; + std::cout << " - 'none' (disable all optimization passes)" << std::endl; +} + +void store_optimization_passes(const std::string &opt) +{ + LOG(TRACE) << "Parsing optimizations: " << opt << std::endl; + + // disable all optimizations (i.e. override the default settings) + // then enable those that were requested + options.opt_unionize=false; + if( opt == "none" ) + { + LOG(TRACE) << "Disabling all optimizations: " << opt << std::endl; + return; + } + + if( opt == "help" ) + { + print_optimization_passes(); + exit(0); + } + std::vector result; + std::stringstream ss (opt); + std::string item; + while (getline (ss, item, ',')) { + if( item == "unionize" ) + { + LOG(DEBUG) << "Enabling 'Unionize tensors' optimization pass" << std::endl; + options.opt_unionize=true; + } + else { + LOG(WARNING) << "Optimization pass " << item << " does not exist" << std::endl; + } + } + LOG(TRACE) << "That was all optimizations" << std::endl; +} + void parse_cmdline_options(int argc, const char *argv[]) { args::ArgumentParser parser("Generate C code from an ONNX graph file."); args::Flag avr(parser, "avr", "Target AVR-GCC", {'a', "avr"}); args::ValueFlagList define(parser, "dim:size", "Define graph input dimension. Can be given multiple times", {'d', "define"}); args::ValueFlag loglevel(parser, "level", "Logging verbosity. 0(none)-4(all)", {'l',"log"}); + args::ValueFlag optimizations(parser, "opt[,opt]...", "Specify optimization passes to run. ('help' to list available)", {'p', "optimizations"}); args::Flag help(parser, "help", "Print this help text.", {'h',"help"}); args::Flag quantize(parser, "quantize", "Quantize network (EXPERIMENTAL!)", {'q', "quantize"}); args::Flag version(parser, "version", "Print onnx2c version", {'v', "version"}); @@ -109,6 +151,7 @@ void parse_cmdline_options(int argc, const char *argv[]) store_define_option(d); } } + if (optimizations) { store_optimization_passes( args::get(optimizations) ); } if (input) { options.input_file = args::get(input); } if (options.input_file == "" ) { std::cerr << "No input file given"; hint_at_help_and_exit(); } } diff --git a/src/options.h b/src/options.h index a1e70f7..59e1789 100644 --- a/src/options.h +++ b/src/options.h @@ -11,6 +11,7 @@ struct onnx2c_opts { bool quantize=false; bool target_avr=false; + bool opt_unionize=true; /* * logging levels are * cmd line aixlog Use diff --git a/src/util.cc b/src/util.cc index 5047000..f309c93 100644 --- a/src/util.cc +++ b/src/util.cc @@ -6,9 +6,6 @@ #include "tensor.h" #include "util.h" -// command line option -extern bool target_avr; - std::string cify_name(const std::string &in) { // Replace all non-allowed characters with underscore diff --git a/test/onnx_backend_tests_generator.cc b/test/onnx_backend_tests_generator.cc index 9bd71de..7a9b70f 100644 --- a/test/onnx_backend_tests_generator.cc +++ b/test/onnx_backend_tests_generator.cc @@ -72,6 +72,9 @@ int main(int argc, char *argv[]) exit(1); } + options.logging_level = 1; + AixLog::Log::init(AixLog::Severity::error); + onnx::ModelProto onnx_model; std::string dir(argv[1]);