Skip to content

Commit

Permalink
Node input tensors not 'const'
Browse files Browse the repository at this point in the history
Refactoring. No functional changes.
  • Loading branch information
kraiskil committed Apr 22, 2023
1 parent bdfe0f7 commit 34f8af0
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 11 deletions.
8 changes: 4 additions & 4 deletions src/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ Tensor* Graph::getIoTensor(onnx::ValueInfoProto &vi)



bool Graph::getNodeInputTensors(const onnx::NodeProto &node, std::vector<const Tensor*> &inputs)
bool Graph::getNodeInputTensors(const onnx::NodeProto &node, std::vector<Tensor*> &inputs)
{
// TODO: ugly. Move where?
static const Tensor unused;
static Tensor unused;

// if all inputs can be found in the tensors-vector, then yes, inputs are resolved
for( auto i : node.input() )
Expand Down Expand Up @@ -237,7 +237,7 @@ bool Graph::getNodeInputTensors(const onnx::NodeProto &node, std::vector<const T
*/
bool Graph::tryResolveNode(onnx::NodeProto &node)
{
std::vector<const Tensor*> inputs;
std::vector<Tensor*> inputs;
LOG(DEBUG) << "Resolving ONNX node " << node.name() <<std::endl;

for( auto o : nodes )
Expand Down Expand Up @@ -575,7 +575,7 @@ Tensor *Graph::findTensor(const std::string &name) const
return NULL;
}

void Graph::replaceWithQuantized(std::vector<const Tensor*> &inputs)
void Graph::replaceWithQuantized(std::vector<Tensor*> &inputs)
{
for( unsigned i=0; i<inputs.size(); i++ ) {
if(inputs[i]->quantizedCopy)
Expand Down
4 changes: 2 additions & 2 deletions src/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class Graph {
void addInitializedTensor(onnx::TensorProto &tensor);
Tensor* getIoTensor(onnx::ValueInfoProto &vi);

void replaceWithQuantized(std::vector<const Tensor*> &inputs);
bool getNodeInputTensors(const onnx::NodeProto &node, std::vector<const Tensor*> &inputs);
void replaceWithQuantized(std::vector<Tensor*> &inputs);
bool getNodeInputTensors(const onnx::NodeProto &node, std::vector<Tensor*> &inputs);

bool tryResolveNode(onnx::NodeProto &node);
bool hasUnresolvedNodes(void);
Expand Down
2 changes: 1 addition & 1 deletion src/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Node {
std::string onnx_name; //ONNX name of the individual node
std::string op_name; //ONNX name of node type
static int64_t onnx_ir_version;
std::vector< const Tensor*> inputs; // List of input tensors in the .onnx file
std::vector<Tensor*> inputs; // List of input tensors in the .onnx file

// NB: this is deprecated. Whenever a node is updated,
// any reference to this variable should be removed.
Expand Down
2 changes: 1 addition & 1 deletion src/nodes/concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace toC {
}

// inputs
std::vector<const Tensor *> node_inputs; // 'inputs' in the spec
std::vector<Tensor *> node_inputs; // 'inputs' in the spec

// output
const Tensor *concat_result ;
Expand Down
2 changes: 1 addition & 1 deletion src/nodes/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class MatMul : public Node {
register_output(rv, "Y");
}

void result_dim( const std::vector< const Tensor*> &inputs, int32_t &rows, int32_t &cols) const
void result_dim( const std::vector< Tensor*> &inputs, int32_t &rows, int32_t &cols) const
{
// TODO: this is the check for vectors. Check equivalent for N-dimensons: N>2
if( inputs[0]->data_dim[1] != 0 && inputs[1]->data_dim[1] != 0 )
Expand Down
2 changes: 1 addition & 1 deletion src/nodes/matmulinteger.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class MatMulInteger : public Node {
outputs.push_back(rv);
}

void result_dim( const std::vector< const Tensor*> &inputs, int32_t &rows, int32_t &cols) const
void result_dim( const std::vector< Tensor*> &inputs, int32_t &rows, int32_t &cols) const
{
// TODO: this is the check for vectors. Check equivalent for N-dimensons: N>2
if( inputs[0]->data_dim[1] != 0 && inputs[1]->data_dim[1] != 0 )
Expand Down
2 changes: 1 addition & 1 deletion src/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Tensor {
// IO tensors still get initialized e.g. in the test suite
bool isRecursive;// tensor that one node uses both output and input.
// may additionally be used as input for other nodes
const Tensor *quantizedCopy; // non-NULL if there is a quantized version of this
Tensor *quantizedCopy; // non-NULL if there is a quantized version of this
bool isQuantized; // is this a quantized copy
std::vector<int> data_dim;
onnx::TensorProto_DataType data_type;
Expand Down

0 comments on commit 34f8af0

Please sign in to comment.