diff --git a/src/nodes/cast.cc b/src/nodes/cast.cc index cdefd32..5bbeecc 100644 --- a/src/nodes/cast.cc +++ b/src/nodes/cast.cc @@ -25,11 +25,12 @@ void Cast::resolve(void) switch(to) { + case onnx::TensorProto_DataType_INT32: + output_type = "int32_t"; break; case onnx::TensorProto_DataType_FLOAT: - output_type = "float"; + output_type = "float"; break; case onnx::TensorProto_DataType_DOUBLE: - output_type = "double"; - break; + output_type = "double"; break; default: ERROR("Unimplemented casting to requested type"); }