Skip to content

Commit

Permalink
Clean up code.
Browse files Browse the repository at this point in the history
When printing the function parameters.
No functional change.
  • Loading branch information
kraiskil committed Aug 1, 2023
1 parent b382b82 commit baa5e64
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 42 deletions.
2 changes: 1 addition & 1 deletion src/graph_print.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void Graph::print_functions(std::ostream &dst)
for( auto n : nodes ) {
dst << "static inline void ";
dst << n->c_name() << "( ";
n->print_function_parameters_shapes(dst);
n->print_function_parameters_definition(dst);
dst << " )";
dst << std::endl << "{" << std::endl;

Expand Down
55 changes: 22 additions & 33 deletions src/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,56 +110,45 @@ void Node::multidirectional_broadcast_size(



// NB: old node implementations that dont use input_params and output_params
// must and have overridden this function.
// New mode node implementations use the Node::register_input() and
// Node::register_output() functions
void Node::print_parameters(std::ostream &dst, bool not_callsite ) const
{
// First create the parameter names as strings (with or without dimensions)
std::vector<std::string> params;

if( not_callsite )
{
for( auto i : input_params ) {
const Tensor *t = std::get<0>(i);
std::string name = std::get<1>(i);
for( auto i : input_params ) {
const Tensor *t = std::get<0>(i);
std::string name = std::get<1>(i);
if( not_callsite )
params.push_back( t->print_tensor_as_const(name) );
}
for( auto o : output_params ) {
const Tensor *t = std::get<0>(o);
// A node does not know at its resolve time if an optional
// output is used, so it registers all. Once all nodes
// are resolved, the tensor knows if some one uses it.
if( t->is_used() == false )
continue;
std::string name = std::get<1>(o);
params.push_back( t->print_tensor(name) );
}
}
else
{
for( auto i : input_params ) {
const Tensor *t = std::get<0>(i);
else
params.push_back( t->print_tensor_callsite() );
}
for( auto o : output_params ) {
const Tensor *t = std::get<0>(o);
if( t->is_used() == false )
continue;
}
for( auto o : output_params ) {
const Tensor *t = std::get<0>(o);
// A node does not know at its resolve time if an optional
// output is used, so it registers all. Once all nodes
// are resolved, the tensor knows if some one uses it.
if( t->is_used() == false )
continue;
std::string name = std::get<1>(o);
if( not_callsite )
params.push_back( t->print_tensor(name) );
else
params.push_back( t->print_tensor_callsite() );
}
}

// Then print the parmeters as comma-separated string
auto i = params.begin();
dst << *i ;
for( i++; i != params.end(); i++)
dst << ", " << *i;
}

void Node::print_function_parameters_shapes(std::ostream &destination) const
// parameters at function definition/declaration
void Node::print_function_parameters_definition(std::ostream &destination) const
{
print_parameters(destination, true);
}
// parameters when calling a function
void Node::print_function_parameters_callsite(std::ostream &destination) const
{
print_parameters(destination, false);
Expand Down
11 changes: 3 additions & 8 deletions src/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,12 @@ class Node {
* or decorated
* "float tensor_X[1][2][3], float tensor_Y[2][3][4]"
*
* "old" vs "new":
* Currently print_parameters() is splatted all over the node implementing
* subclasses. It used to be a pure virtual function for bad design reasons.
* The new way, which is being implemented piecemeal is to have the
* node::resolve() function create mappings for all of its function parameters
* node::resolve() function creates mappings for all of its function parameters
* so that each tensor has a "local name" corresponding to the tensor name in
* the ONNX Operands specificaion.
* In short, don't override print_parameters() for new node subclasses!
*/
virtual void print_parameters(std::ostream &destination, bool decorate ) const;
void print_function_parameters_shapes(std::ostream &destination) const;
void print_parameters(std::ostream &destination, bool decorate ) const;
void print_function_parameters_definition(std::ostream &destination) const;
void print_function_parameters_callsite(std::ostream &destination) const;

/* Figure out in what format the output is in.
Expand Down

0 comments on commit baa5e64

Please sign in to comment.