Skip to content

Commit

Permalink
Enable Ort objects to be stored in a resizable std::vector (#22608)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Allow some classes to be default constructed.
The effect is the same as constructing it with nullptr.
Make default ctor visible from the base classes.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Multiple customers complained that when storing Ort::Value
in an instance of std::vector, vector can not be resized.

We enable that with allowing it default constructed.
  • Loading branch information
yuslepukhin authored Oct 29, 2024
1 parent 951d9aa commit e106131
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 6 deletions.
37 changes: 33 additions & 4 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,9 @@ using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
* constructors to construct an instance of a Status object from exceptions.
*/
struct Status : detail::Base<OrtStatus> {
using Base = detail::Base<OrtStatus>;
using Base::Base;

explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used
explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API.
explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception
Expand Down Expand Up @@ -728,6 +731,9 @@ struct Env : detail::Base<OrtEnv> {
*
*/
struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
using Base = detail::Base<OrtCustomOpDomain>;
using Base::Base;

explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used

/// \brief Wraps OrtApi::CreateCustomOpDomain
Expand Down Expand Up @@ -963,8 +969,10 @@ struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
*
*/
struct ModelMetadata : detail::Base<OrtModelMetadata> {
explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
using Base = detail::Base<OrtModelMetadata>;
using Base::Base;

explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used

/** \brief Returns a copy of the producer name.
*
Expand Down Expand Up @@ -1237,6 +1245,9 @@ using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl<detail::U
*
*/
struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
using Base = detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo>;
using Base::Base;

explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API
ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; }
Expand All @@ -1258,6 +1269,9 @@ using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl<detail::Unowned<const
*
*/
struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
using Base = detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo>;
using Base::Base;

explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; }
Expand Down Expand Up @@ -1293,6 +1307,9 @@ using ConstMapTypeInfo = detail::MapTypeInfoImpl<detail::Unowned<const OrtMapTyp
*
*/
struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
using Base = detail::MapTypeInfoImpl<OrtMapTypeInfo>;
using Base::Base;

explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; }
Expand Down Expand Up @@ -1324,6 +1341,9 @@ using ConstTypeInfo = detail::TypeInfoImpl<detail::Unowned<const OrtTypeInfo>>;
/// the information about contained sequence or map depending on the ONNXType.
/// </summary>
struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
using Base = detail::TypeInfoImpl<OrtTypeInfo>;
using Base::Base;

explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop

Expand Down Expand Up @@ -1661,11 +1681,11 @@ using UnownedValue = detail::ValueImpl<detail::Unowned<OrtValue>>;
*/
struct Value : detail::ValueImpl<OrtValue> {
using Base = detail::ValueImpl<OrtValue>;
using Base::Base;
using OrtSparseValuesParam = detail::OrtSparseValuesParam;
using Shape = detail::Shape;

explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API
explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
Value(Value&&) = default;
Value& operator=(Value&&) = default;

Expand Down Expand Up @@ -1941,6 +1961,10 @@ struct ArenaCfg : detail::Base<OrtArenaCfg> {
/// This struct provides life time management for custom op attribute
/// </summary>
struct OpAttr : detail::Base<OrtOpAttr> {
using Base = detail::Base<OrtOpAttr>;
using Base::Base;

explicit OpAttr(std::nullptr_t) {}
OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
};

Expand Down Expand Up @@ -2183,6 +2207,8 @@ using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelIn
/// so it does not destroy the pointer the kernel does not own.
/// </summary>
struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
using Base = detail::KernelInfoImpl<OrtKernelInfo>;
using Base::Base;
explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later
explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance
ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
Expand All @@ -2192,6 +2218,9 @@ struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
/// Create and own custom defined operation.
/// </summary>
struct Op : detail::Base<OrtOp> {
using Base = detail::Base<OrtOp>;
using Base::Base;

explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used

explicit Op(OrtOp*); ///< Take ownership of the OrtOp
Expand Down
4 changes: 2 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ inline void ThrowOnError(const Status& st) {
}
}

inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
inline Status::Status(OrtStatus* status) noexcept : detail::Base<OrtStatus>{status} {
}

inline Status::Status(const std::exception& e) noexcept {
Expand Down Expand Up @@ -1908,7 +1908,7 @@ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::

inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}

inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
inline Op::Op(OrtOp* p) : detail::Base<OrtOp>(p) {}

inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
const char** type_constraint_names,
Expand Down
26 changes: 26 additions & 0 deletions onnxruntime/test/shared_lib/test_nontensor_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,32 @@ TEST(CApiTest, SparseTensorFillSparseTensorFormatAPI) {
}
}

TEST(CApi, TestResize) {
std::vector<Ort::Value> values;
values.resize(10);

std::vector<Ort::Status> sts;
sts.resize(5);

std::vector<Ort::CustomOpDomain> domains;
domains.resize(5);

std::vector<Ort::TensorTypeAndShapeInfo> type_and_shape;
type_and_shape.resize(5);

std::vector<Ort::SequenceTypeInfo> seq_type_info;
seq_type_info.resize(5);

std::vector<Ort::MapTypeInfo> map_type_info;
map_type_info.resize(5);

std::vector<Ort::TypeInfo> type_info;
type_info.resize(5);

std::vector<Ort::OpAttr> op_attr;
op_attr.resize(5);
}

TEST(CApiTest, SparseTensorFillSparseFormatStringsAPI) {
auto allocator = Ort::AllocatorWithDefaultOptions();
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
Expand Down

0 comments on commit e106131

Please sign in to comment.