From f861b9a4a66c7418878ce3fd6fdb01f7ba17ba62 Mon Sep 17 00:00:00 2001
From: Kalle Raiskila
Date: Mon, 13 Feb 2023 09:16:01 +0100
Subject: [PATCH] Add fold_casts optimization pass.

When possible, remove Cast nodes, modifying the output of the Cast's
parent op instead.

Bump to C++20, because of std::erase.
---
 CMakeLists.txt                         |  3 +-
 README.md                              |  1 +
 src/graph.h                            |  4 ++
 src/main.cc                            |  2 +
 src/node.cc                            | 17 ++++++
 src/node.h                             |  5 ++
 src/optimization_passes/fold_casts.cpp | 83 ++++++++++++++++++++++++++
 src/options.cc                         |  7 +++
 src/options.h                          |  1 +
 9 files changed, 122 insertions(+), 1 deletion(-)
 create mode 100644 src/optimization_passes/fold_casts.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4ed9703..2fd2e83 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -6,7 +6,7 @@ project(onnx2c
 	VERSION 0.0.1
 	LANGUAGES C CXX
 	)
-set (CMAKE_CXX_STANDARD 17)
+set (CMAKE_CXX_STANDARD 20)
 
 set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Werror")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wfatal-errors")
@@ -32,6 +32,7 @@ add_library(onnx2c_lib STATIC
 	src/node.cc
 	src/tensor.cc
 	src/util.cc
+	src/optimization_passes/fold_casts.cpp
 	src/optimization_passes/unionize_tensors.cpp
 	${CMAKE_CURRENT_BINARY_DIR}/onnx.pb.cc
 	src/nodes/cast.cc
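
The C++ standard bump above exists only to get std::erase, which fold_casts.cpp below
uses to drop elements from std::vectors. As a minimal illustration of what the new
standard buys here (this snippet is an editorial sketch, not part of the patch):

    #include <algorithm>
    #include <cassert>
    #include <vector>

    int main()
    {
        std::vector<int> v{1, 2, 3, 2};

        // C++20: erase every element equal to 2, in one call.
        std::erase(v, 2);
        assert(v == (std::vector<int>{1, 3}));

        // Pre-C++20 spelling of the same thing: the erase-remove idiom.
        v = {1, 2, 3, 2};
        v.erase(std::remove(v.begin(), v.end(), 2), v.end());
        assert(v == (std::vector<int>{1, 3}));
        return 0;
    }
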
diff --git a/README.md b/README.md
index 99e16b6..6dbb1bc 100644
--- a/README.md
+++ b/README.md
@@ -85,6 +85,7 @@ See the [GCC wiki on floating point maths](https://gcc.gnu.org/wiki/FloatingPointMath)
 
 Onnx2c has a few optimization passes that modify the generated output:
  - Tensor unionization to wrap intermediate tensors in unions to help the compiler re-use the heap memory.
+ - Removing `Cast` nodes by modifying their predecessor node's output tensor.
  - Optimization for AVR processors to put constants into instruction memory.
  - An [experimental quantization option](quantization.md) to convert floating point calculation to integers.

diff --git a/src/graph.h b/src/graph.h
index 3fe284f..d441208 100644
--- a/src/graph.h
+++ b/src/graph.h
@@ -40,6 +40,9 @@ class Graph {
 	 * unions. This make the memory buffers time shared. */
 	void unionize_tensors(void);
 
+	/* Optimization step: fold Cast nodes into their predecessors. */
+	void fold_casts(void);
+
 	void addInitializedTensor(onnx::TensorProto &tensor);
 	Tensor* getIoTensor(onnx::ValueInfoProto &vi);
@@ -92,6 +95,7 @@ class Graph {
 	std::vector tensor_unions;
 	uint32_t add_to_free_union(Tensor *t);
 	void mark_union_unoccupied(uint32_t);
+
 };
 
 }

diff --git a/src/main.cc b/src/main.cc
index 0e5bbdd..8fff570 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -24,6 +24,8 @@ int main(int argc, const char *argv[])
 	std::cout.precision(20);
 
 	toC::Graph toCgraph(onnx_model);
+	if( options.opt_fold_casts )
+		toCgraph.fold_casts();
 	if( options.opt_unionize )
 		toCgraph.unionize_tensors();
 	toCgraph.print_source(std::cout);

diff --git a/src/node.cc b/src/node.cc
index d441afe..9a2ca8f 100644
--- a/src/node.cc
+++ b/src/node.cc
@@ -203,3 +203,20 @@ unsigned Node::get_number_of_outputs(void) const
 {
 	return output_params.size();
 }
+
+bool Node::replace_input(Tensor *old, Tensor *replacement)
+{
+	for( auto &p : input_params )
+	{
+		if( std::get<0>(p) == old ) {
+			LOG(DEBUG) << "Replacing input tensor " << old->name << " in node " << onnx_name << std::endl;
+			std::get<0>(p) = replacement;
+			return true;
+		}
+	}
+
+	LOG(DEBUG) << "Tensor to replace is not an input of this node" << std::endl;
+	return false;
+}
+

diff --git a/src/node.h b/src/node.h
index 648d4b1..dc06adb 100644
--- a/src/node.h
+++ b/src/node.h
@@ -88,6 +88,11 @@ class Node {
 	 * Start counting N from 0, including the non-optional outputs.
 	 */
 	bool is_output_N_used(unsigned N) const;
 
+	/* Replace input tensor 'old' with 'replacement'.
+	 * Return false if 'old' is not an input tensor of this node.
+	 */
+	bool replace_input(Tensor *old, Tensor *replacement);
+
 	/* Not all node types have attributes. Override where needed */
 	virtual void parseAttributes( onnx::NodeProto &node ) {

diff --git a/src/optimization_passes/fold_casts.cpp b/src/optimization_passes/fold_casts.cpp
new file mode 100644
index 0000000..c3afcb8
--- /dev/null
+++ b/src/optimization_passes/fold_casts.cpp
@@ -0,0 +1,83 @@
+/* This file is part of onnx2c.
+ *
+ * Implemented here is the 'fold_casts' optimization pass
+ * that tries to remove Cast nodes.
+ */
+#include "graph.h"
+#include <cassert>
+#include <vector>
+
+using namespace toC;
+
+void Graph::fold_casts(void)
+{
+	LOG(DEBUG) << "Optimisation pass: fold casts" << std::endl;
+	std::vector<Node*> removed_nodes;
+
+	// Loop over all Cast nodes
+	for( auto n : nodes ) {
+		if( n->op_name != "Cast" ) {
+			LOG(TRACE) << n->onnx_name << " is not a Cast node, ignoring." << std::endl;
+			continue;
+		}
+		LOG(DEBUG) << "considering 'Cast' node: " << n->onnx_name << std::endl;
+
+		// If the Cast node's input has other users,
+		// the transformation becomes too difficult:
+		// the predecessor node that generates the input
+		// would then need to generate two different outputs,
+		// one for the folded cast and one for the other user(s).
+		// Skip folding these Cast nodes.
+		assert(n->get_number_of_inputs() == 1);
+		Tensor *input_tensor = n->get_input_tensor(0);
+		Tensor *output_tensor = n->get_output_tensor(0);
+		if( input_tensor->consumers.size() != 1 ) {
+			LOG(DEBUG) << "  skipping: input tensor has other users." << std::endl;
+			continue;
+		}
+
+		// Degenerate case where the graph input is directly the graph output.
+		// This happens at least in unit tests, but anywhere else it sounds like an error.
+		if( output_tensor->isIO && input_tensor->isIO ) {
+			LOG(WARNING) << "  Cast output is a graph output?" << std::endl;
+			continue;
+		}
+
+		LOG(DEBUG) << "  folding away this Cast node." << std::endl;
+		// Modify the predecessor node's output to
+		// match the type of the Cast node's output.
+		input_tensor->data_type = output_tensor->data_type;
+
+		// Replace the Cast output tensor in the inputs of all
+		// of its users with the predecessor node's output,
+		// i.e. bypass the Cast node.
+		for( auto cn : output_tensor->consumers ) {
+			if( !cn->replace_input(output_tensor, input_tensor) )
+				LOG(FATAL) << output_tensor->name << " was not replaced" << std::endl;
+		}
+
+		// The bypassed output tensor is now unused; remove it.
+		// (Do this after the loop above: erasing and deleting the
+		// tensor while still iterating its consumer list would be
+		// a use-after-free.)
+		std::erase(tensors, output_tensor);
+		delete output_tensor;
+
+		// Mark the now orphaned Cast node for removal.
+		removed_nodes.push_back(n);
+	}
+
+	for( auto rn : removed_nodes ) {
+		std::erase(nodes, rn);
+		delete rn;
+	}
+	LOG(TRACE) << "folding Cast nodes finished" << std::endl;
+}
+
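
The transformation is easiest to picture on a toy graph: given Producer -> Cast ->
Consumer, the producer's output tensor adopts the Cast's output type and the consumer
is re-pointed at it, leaving the Cast node orphaned. The standalone sketch below
demonstrates that rewiring; ToyTensor and ToyNode are simplified illustrative
stand-ins, not onnx2c's real Tensor and Node classes:

    #include <algorithm>
    #include <cassert>
    #include <string>
    #include <vector>

    // Simplified stand-ins for onnx2c's Tensor and Node classes.
    struct ToyNode;
    struct ToyTensor {
        std::string name;
        std::string data_type;            // e.g. "float32", "int8"
        std::vector<ToyNode*> consumers;  // nodes that read this tensor
    };
    struct ToyNode {
        std::string op;
        std::vector<ToyTensor*> inputs;
        std::vector<ToyTensor*> outputs;
    };

    // Fold one Cast node whose input has no other users: retype the
    // producer's output and re-point every reader of the Cast's output at it.
    void fold_cast(ToyNode *cast, std::vector<ToyTensor*> &tensors,
                   std::vector<ToyNode*> &nodes)
    {
        ToyTensor *in  = cast->inputs[0];
        ToyTensor *out = cast->outputs[0];
        assert(in->consumers.size() == 1);  // only the Cast reads 'in'

        in->data_type = out->data_type;     // producer now emits the cast-to type
        in->consumers.clear();
        for( ToyNode *cn : out->consumers ) {
            std::replace(cn->inputs.begin(), cn->inputs.end(), out, in);
            in->consumers.push_back(cn);    // 'in' inherits the Cast's readers
        }
        std::erase(tensors, out);           // the C++20 std::erase from above
        std::erase(nodes, cast);
        delete out;
        delete cast;
    }

    int main()
    {
        ToyTensor *a = new ToyTensor{"a", "float32", {}};
        ToyTensor *b = new ToyTensor{"b", "int8", {}};
        ToyNode *prod = new ToyNode{"Conv", {}, {a}};
        ToyNode *cast = new ToyNode{"Cast", {a}, {b}};
        ToyNode *relu = new ToyNode{"Relu", {b}, {}};
        a->consumers = {cast};
        b->consumers = {relu};
        std::vector<ToyTensor*> tensors{a, b};
        std::vector<ToyNode*>   nodes{prod, cast, relu};

        fold_cast(cast, tensors, nodes);
        assert(relu->inputs[0] == a);       // Cast bypassed
        assert(a->data_type == "int8");     // producer output retyped
        return 0;                           // (demo leaks prod/relu/a; fine here)
    }
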
diff --git a/src/options.cc b/src/options.cc
index b0cc197..bca918e 100644
--- a/src/options.cc
+++ b/src/options.cc
@@ -68,6 +68,7 @@ void print_optimization_passes(void)
 {
 	std::cout << "Available optimization passes:" << std::endl;
 	std::cout << "  - 'unionize' (defaut:on)" << std::endl;
+	std::cout << "  - 'fold_casts' (default:on)" << std::endl;
 	std::cout << "  - 'none' (disable all optimization passes)" << std::endl;
 }
 
@@ -78,6 +79,7 @@ void store_optimization_passes(const std::string &opt)
 	// disable all optimizations (i.e. override the default settings)
 	// then enable those that were requested
 	options.opt_unionize=false;
+	options.opt_fold_casts=false;
 
 	if( opt == "none" ) {
 		LOG(TRACE) << "Disabling all optimizations: " << opt << std::endl;
@@ -98,6 +100,11 @@ void store_optimization_passes(const std::string &opt)
 		LOG(DEBUG) << "Enabling 'Unionize tensors' optimization pass" << std::endl;
 		options.opt_unionize=true;
 	}
+	else if( item == "fold_casts" )
+	{
+		LOG(DEBUG) << "Enabling 'Fold casts' optimization pass" << std::endl;
+		options.opt_fold_casts=true;
+	}
 	else {
 		LOG(WARNING) << "Optimization pass " << item << " does not exist" << std::endl;
 	}

diff --git a/src/options.h b/src/options.h
index bb580b2..6385dda 100644
--- a/src/options.h
+++ b/src/options.h
@@ -13,6 +13,7 @@ struct onnx2c_opts
 	bool quantize=false;
 	bool target_avr=false;
 	bool opt_unionize=true;
+	bool opt_fold_casts=true;
 
 	/*
 	 * logging levels are
 	 *  cmd line  aixlog   Use
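
The pass-selection logic above follows a disable-then-enable pattern:
store_optimization_passes() first turns every pass off, then re-enables each
requested item, so naming any single pass implicitly disables the others. The
self-contained sketch below models that pattern; the comma-separated splitting
is an assumption, since the real tokenization loop is outside this excerpt:

    #include <cassert>
    #include <sstream>
    #include <string>

    // Minimal model of the option handling above; Opts stands in
    // for the real onnx2c_opts struct.
    struct Opts { bool opt_unionize = true; bool opt_fold_casts = true; };

    void store_optimization_passes(Opts &o, const std::string &opt)
    {
        // disable all optimizations (i.e. override the default settings)
        o.opt_unionize = o.opt_fold_casts = false;
        if( opt == "none" )
            return;                  // leave everything off
        // then enable those that were requested
        std::stringstream ss(opt);
        std::string item;
        while( std::getline(ss, item, ',') ) {   // assumed separator
            if( item == "unionize" )   o.opt_unionize = true;
            if( item == "fold_casts" ) o.opt_fold_casts = true;
        }
    }

    int main()
    {
        Opts o;
        store_optimization_passes(o, "fold_casts");
        assert(o.opt_fold_casts && !o.opt_unionize);  // only the named pass is on
        store_optimization_passes(o, "none");
        assert(!o.opt_fold_casts && !o.opt_unionize);
        return 0;
    }
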