Add fold_casts optimization pass.
When possible, remove Cast nodes, modifying the output of the Cast's parent op instead.

Bump to C++20, because of std::erase.
kraiskil committed May 19, 2024
1 parent 2c32c1e commit f861b9a
Showing 9 changed files with 122 additions and 1 deletion.
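
Conceptually, the pass rewires the graph as in the following minimal sketch. The Node and Tensor structs are simplified stand-ins for onnx2c's real classes, and fold_one_cast is a hypothetical name; this only illustrates the retype-and-bypass idea, not the actual API:

#include <string>
#include <vector>

struct Node;

struct Tensor {
	std::string name;
	std::string data_type;        // e.g. "float" before the fold, "int32_t" after
	std::vector<Node*> consumers; // nodes reading this tensor
};

struct Node {
	std::string op;
	std::vector<Tensor*> inputs;
	std::vector<Tensor*> outputs;
};

// Fold one Cast node: retype the parent's output to the Cast's target
// type, then point every consumer of the Cast's output at the parent's
// output. The Cast node and its output tensor become dead.
void fold_one_cast(Node& cast)
{
	Tensor* in  = cast.inputs[0];   // the parent op's output
	Tensor* out = cast.outputs[0];  // the Cast's output
	in->data_type = out->data_type;
	for (Node* consumer : out->consumers)
		for (Tensor*& t : consumer->inputs)
			if (t == out)
				t = in;
}

The real pass (src/optimization_passes/fold_casts.cpp below) additionally refuses to fold when the Cast's input tensor has other consumers.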
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -6,7 +6,7 @@ project(onnx2c
VERSION 0.0.1
LANGUAGES C CXX
)
set (CMAKE_CXX_STANDARD 17)
set (CMAKE_CXX_STANDARD 20)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Werror")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wfatal-errors")

@@ -32,6 +32,7 @@ add_library(onnx2c_lib STATIC
src/node.cc
src/tensor.cc
src/util.cc
src/optimization_passes/fold_casts.cpp
src/optimization_passes/unionize_tensors.cpp
${CMAKE_CURRENT_BINARY_DIR}/onnx.pb.cc
src/nodes/cast.cc
1 change: 1 addition & 0 deletions README.md
@@ -85,6 +85,7 @@ See the [GCC wiki on floating point maths](https://gcc.gnu.org/wiki/FloatingPoin

Onnx2c has a few optimization passes that modify the generated output:
- Tensor unionization to wrap intermediate tensors in unions to help the compiler re-use the heap memory.
- Removing `Cast` nodes by modifying their predecessor node's output tensor.
- Optimization for AVR processors to put constants into instruction memory.
- An [experimental quantization option](quantization.md) to convert floating point calculation to integers.

4 changes: 4 additions & 0 deletions src/graph.h
@@ -40,6 +40,9 @@ class Graph {
	 * unions. This makes the memory buffers time-shared. */
	void unionize_tensors(void);

	/* Optimization step: fold Cast nodes into their predecessor. */
	void fold_casts(void);

	void addInitializedTensor(onnx::TensorProto &tensor);
	Tensor* getIoTensor(onnx::ValueInfoProto &vi);

@@ -92,6 +95,7 @@ class Graph {
	std::vector<Tensor *> tensor_unions;
	uint32_t add_to_free_union(Tensor *t);
	void mark_union_unoccupied(uint32_t);

};

}
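
For context on the other pass named above: unionization wraps intermediate tensors so their buffers are time-shared. A hand-written sketch of the idea in the generated C, with made-up names and sizes (not actual onnx2c output):

// Two intermediate tensors whose lifetimes do not overlap can share
// one allocation. Instead of emitting two separate buffers:
//
//   static float tensor_a[1024];
//   static float tensor_b[4096];
//
// the unionize pass wraps them in a union, so both occupy the same
// (larger) block of memory:
union tensor_union_0 {
	float tensor_a[1024];
	float tensor_b[4096];
};
static union tensor_union_0 tu0;  // tu0.tensor_a and tu0.tensor_b overlap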
2 changes: 2 additions & 0 deletions src/main.cc
@@ -24,6 +24,8 @@ int main(int argc, const char *argv[])

	std::cout.precision(20);
	toC::Graph toCgraph(onnx_model);
	if( options.opt_fold_casts )
		toCgraph.fold_casts();
	if( options.opt_unionize )
		toCgraph.unionize_tensors();
	toCgraph.print_source(std::cout);
17 changes: 17 additions & 0 deletions src/node.cc
@@ -203,3 +203,20 @@ unsigned Node::get_number_of_outputs(void) const
{
	return output_params.size();
}

bool Node::replace_input(Tensor *old, Tensor *replacement)
{
	for( auto &p : input_params )
	{
		if( std::get<0>(p) == old ) {
			LOG(DEBUG) << "Did replacement" << std::endl;
			std::get<0>(p) = replacement;
			return true;
		}
	}

	LOG(DEBUG) << "No replacement" << std::endl;
	return false;
}

5 changes: 5 additions & 0 deletions src/node.h
@@ -88,6 +88,11 @@ class Node {
	 * Start counting N from 0, including the non-optional outputs. */
	bool is_output_N_used(unsigned N) const;

	/* Replace input tensor 'old' with 'replacement'.
	 * Return false if 'old' is not an input tensor.
	 */
	bool replace_input(Tensor *old, Tensor *replacement);

	/* Not all node types have attributes. Override where needed */
	virtual void parseAttributes( onnx::NodeProto &node )
	{
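
The replace_input() contract is easy to exercise in isolation. A self-contained sketch with simplified stand-ins (the real input_params entries are tuple-like with the Tensor pointer as the first element, as seen in node.cc above; everything else here is illustrative):

#include <iostream>
#include <string>
#include <tuple>
#include <vector>

struct Tensor { std::string name; };

struct Node {
	std::vector<std::tuple<Tensor*, std::string>> input_params;

	bool replace_input(Tensor* old_t, Tensor* replacement)
	{
		for (auto& p : input_params)
			if (std::get<0>(p) == old_t) {
				std::get<0>(p) = replacement;
				return true;   // first match only
			}
		return false;          // 'old_t' was not an input of this node
	}
};

int main()
{
	Tensor a{"a"}, b{"b"};
	Node n{ { {&a, "X"}, {&a, "Y"} } };
	std::cout << n.replace_input(&a, &b) << "\n";  // 1: first 'a' replaced
	std::cout << n.replace_input(&a, &b) << "\n";  // 1: second 'a' replaced
	std::cout << n.replace_input(&a, &b) << "\n";  // 0: no 'a' inputs left
}

Note that each call replaces only the first matching entry, so a node consuming the same tensor as two separate inputs needs one call per occurrence.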
83 changes: 83 additions & 0 deletions src/optimization_passes/fold_casts.cpp
@@ -0,0 +1,83 @@
/* This file is part of onnx2c.
 *
 * Implemented here is the 'fold_casts' optimization pass
 * that tries to remove Cast nodes.
 */
#include "graph.h"
#include <cassert>
#include <cstdint>

using namespace toC;

void Graph::fold_casts(void)
{
	LOG(DEBUG) << "Optimisation pass: fold casts" << std::endl;
	std::vector<Node*> removed_nodes;

	// Loop over all Cast nodes
	for( auto n : nodes ) {
		if( n->op_name != "Cast" ) {
			LOG(TRACE) << n->onnx_name << " is not a Cast node, ignoring." << std::endl;
			continue;
		}
		LOG(DEBUG) << "considering 'Cast' Node: " << n->onnx_name << std::endl;

		// If the Cast node's input has other users, the
		// transformation becomes too difficult: the predecessor
		// node generating the input would then need to produce
		// two different outputs, one for the folded cast and one
		// for the other user(s). Skip folding these Cast nodes.
		assert(n->get_number_of_inputs() == 1);
		Tensor *input_tensor = n->get_input_tensor(0);
		Tensor *output_tensor = n->get_output_tensor(0);
		if( input_tensor->consumers.size() != 1 ) {
			LOG(DEBUG) << " skipping. Input tensor has other users." << std::endl;
			continue;
		}

		// Degenerate case where the graph input is directly the output.
		// This happens in unit tests at least, but other than that it sounds like an error.
		if( output_tensor->isIO && input_tensor->isIO ) {
			LOG(WARNING) << " Cast output is graph output??" << std::endl;
			continue;
		}

		LOG(DEBUG) << " folding away this Cast node." << std::endl;
		// Modify the predecessor node's output to
		// match the type of the Cast node's output.
		input_tensor->data_type = output_tensor->data_type;

		// Replace the input of each of the Cast output tensor's
		// users with the predecessor node's output, i.e. bypass
		// the Cast node.
		for( auto cn : output_tensor->consumers ) {
			if( !cn->replace_input(output_tensor, input_tensor) )
				LOG(FATAL) << output_tensor->name << " was not replaced" << std::endl;
		}

		// The Cast's output tensor is now unused. Erase and delete it
		// once, after the consumer loop: doing so inside the loop would
		// free the tensor whose consumer list is still being iterated,
		// and would double-delete it when there are multiple consumers.
		std::erase(tensors, output_tensor);
		delete output_tensor;

		// Mark the now orphaned Cast node for removal
		removed_nodes.push_back(n);
	}

	for( auto rn : removed_nodes ) {
		std::erase(nodes, rn);
		delete rn;
	}
	LOG(TRACE) << "folding Cast nodes finished" << std::endl;
}
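
The std::erase calls above are what motivates the C++20 bump: erase-by-value on containers was only added in C++20. A standalone comparison against the pre-C++20 erase-remove idiom:

#include <algorithm>
#include <vector>

int main()
{
	std::vector<int> v{1, 2, 3, 2};

	// C++20: erase all elements equal to 2; returns the number erased.
	std::erase(v, 2);

	// Pre-C++20 equivalent, the erase-remove idiom:
	v.assign({1, 2, 3, 2});
	v.erase(std::remove(v.begin(), v.end(), 2), v.end());
}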


7 changes: 7 additions & 0 deletions src/options.cc
@@ -68,6 +68,7 @@ void print_optimization_passes(void)
{
	std::cout << "Available optimization passes:" << std::endl;
	std::cout << " - 'unionize' (default:on)" << std::endl;
	std::cout << " - 'fold_casts' (default:on)" << std::endl;
	std::cout << " - 'none' (disable all optimization passes)" << std::endl;
}

Expand All @@ -78,6 +79,7 @@ void store_optimization_passes(const std::string &opt)
// disable all optimizations (i.e. override the default settings)
// then enable those that were requested
options.opt_unionize=false;
options.opt_fold_casts=false;
if( opt == "none" )
{
LOG(TRACE) << "Disabling all optimizations: " << opt << std::endl;
Expand All @@ -98,6 +100,11 @@ void store_optimization_passes(const std::string &opt)
LOG(DEBUG) << "Enabling 'Unionize tensors' optimization pass" << std::endl;
options.opt_unionize=true;
}
else if( item == "fold_casts" )
{
LOG(DEBUG) << "Enabling 'Fold casts' optimization pass" << std::endl;
options.opt_fold_casts=true;
}
else {
LOG(WARNING) << "Optimization pass " << item << " does not exist" << std::endl;
}
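
The strategy above, clear every pass flag and then re-enable the requested ones item by item, can be sketched standalone as follows. The flags struct is trimmed down, and the comma-splitting loop is an assumption about the part of the function this diff elides:

#include <sstream>
#include <string>

struct Opts {
	bool opt_unionize = true;    // defaults match src/options.h
	bool opt_fold_casts = true;
};

void store_optimization_passes(Opts& o, const std::string& opt)
{
	// Override the defaults first...
	o.opt_unionize = false;
	o.opt_fold_casts = false;
	if (opt == "none")
		return;                  // everything stays disabled

	// ...then re-enable each comma-separated item that was requested.
	std::istringstream ss(opt);
	std::string item;
	while (std::getline(ss, item, ',')) {
		if (item == "unionize")
			o.opt_unionize = true;
		else if (item == "fold_casts")
			o.opt_fold_casts = true;
		// the real code logs a warning for unknown pass names
	}
}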
1 change: 1 addition & 0 deletions src/options.h
@@ -13,6 +13,7 @@ struct onnx2c_opts
	bool quantize=false;
	bool target_avr=false;
	bool opt_unionize=true;
	bool opt_fold_casts=true;
	/*
	 * logging levels are
	 * cmd line aixlog Use
