Generalize CategoryMapper test. (llvm#1172)
Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>

Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com>
Ettore Tiotto and AlexandreEichenberger authored Feb 14, 2022
1 parent 90c2b8b commit 54941f9
Showing 4 changed files with 192 additions and 103 deletions.
16 changes: 8 additions & 8 deletions include/onnx-mlir/Runtime/OMTensor.h
@@ -141,7 +141,7 @@ void omTensorDestroy(OMTensor *tensor);
* @return pointer to the data buffer of the OMTensor,
* NULL if the data buffer is not set.
*/
-void *omTensorGetDataPtr(OMTensor *tensor);
+void *omTensorGetDataPtr(const OMTensor *tensor);

/**
* \brief OMTensor data shape getter.
@@ -155,7 +155,7 @@ void *omTensorGetDataPtr(OMTensor *tensor);
* @param tensor pointer to the OMTensor
* @return pointer to the data shape array.
*/
-int64_t *omTensorGetShape(OMTensor *tensor);
+int64_t *omTensorGetShape(const OMTensor *tensor);

/**
* \brief OMTensor data shape setter.
@@ -185,7 +185,7 @@ void omTensorSetShape(OMTensor *tensor, int64_t *shape);
* @param tensor pointer to the OMTensor
* @return pointer to the data strides array.
*/
-int64_t *omTensorGetStrides(OMTensor *tensor);
+int64_t *omTensorGetStrides(const OMTensor *tensor);

/**
* \brief OMTensor data strides setter
@@ -230,7 +230,7 @@ void omTensorSetStridesWithPyArrayStrides(
* @param tensor pointer to the OMTensor
* @return ONNX data type of the data buffer elements.
*/
-OM_DATA_TYPE omTensorGetDataType(OMTensor *tensor);
+OM_DATA_TYPE omTensorGetDataType(const OMTensor *tensor);

/**
* \brief OMTensor data type setter
@@ -253,30 +253,30 @@ static inline int64_t getDataTypeSize(OM_DATA_TYPE dataType) {
* @param tensor pointer to the OMTensor
* @return the total size of the data buffer in bytes.
*/
-int64_t omTensorGetBufferSize(OMTensor *tensor);
+int64_t omTensorGetBufferSize(const OMTensor *tensor);

/**
* \brief OMTensor rank getter
*
* @param tensor, pointer to the OMTensor
* @return rank of data shape and strides of the OMTensor.
*/
-int64_t omTensorGetRank(OMTensor *tensor);
+int64_t omTensorGetRank(const OMTensor *tensor);

/**
* \brief OMTensor number of elements getter
*
* @param tensor, pointer to the OMTensor
* @return the number of elements in the data buffer.
*/
-int64_t omTensorGetNumElems(OMTensor *tensor);
+int64_t omTensorGetNumElems(const OMTensor *tensor);

/**
* \brief OMTensor owning flag getter
*
* @return owning flag of the OMTensor.
*/
-int64_t omTensorGetOwning(OMTensor *tensor);
+int64_t omTensorGetOwning(const OMTensor *tensor);

/**
* \brief OMTensor owning flag setter
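The getters const-qualified above can now be called through a const OMTensor *. Below is a minimal sketch of a read-only caller, assuming only the public header shown in this diff (include/onnx-mlir/Runtime/OMTensor.h) is on the include path; the helper name describeTensor is hypothetical and not part of this commit.

#include <onnx-mlir/Runtime/OMTensor.h>
#include <cstdio>

// Hypothetical read-only helper; it compiles only because the getters above
// now accept a const OMTensor *.
static void describeTensor(const OMTensor *tensor) {
  const int64_t rank = omTensorGetRank(tensor);
  const int64_t *shape = omTensorGetShape(tensor);
  std::printf("rank %lld, %lld bytes, shape [", (long long)rank,
      (long long)omTensorGetBufferSize(tensor));
  for (int64_t i = 0; i < rank; ++i)
    std::printf("%s%lld", i ? ", " : "", (long long)shape[i]);
  std::printf("]\n");
}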
55 changes: 31 additions & 24 deletions src/Runtime/OMTensor.inc
@@ -4,7 +4,7 @@

//===--------- OMTensor.inc - C/C++ Neutral OMTensor Implementation--------===//
//
-// Copyright 2019-2020 The IBM Research Authors.
+// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
@@ -225,7 +225,7 @@ void omTensorDestroy(OMTensor *tensor) {
}

/* OMTensor data getter */
-void *omTensorGetDataPtr(OMTensor *tensor) { return tensor->_alignedPtr; }
+void *omTensorGetDataPtr(const OMTensor *tensor) { return tensor->_alignedPtr; }

/**
* OMTensor allocated and aligned pointer setter.
@@ -255,7 +255,7 @@ void omTensorSetDataPtr(
}

/* OMTensor data shape getter */
-int64_t *omTensorGetShape(OMTensor *tensor) { return tensor->_shape; }
+int64_t *omTensorGetShape(const OMTensor *tensor) { return tensor->_shape; }

/* OMTensor data shape setter */
void omTensorSetShape(OMTensor *tensor, int64_t *shape) {
@@ -264,7 +264,7 @@ void omTensorSetShape(OMTensor *tensor, int64_t *shape) {
}

/* OMTensor data strides getter */
-int64_t *omTensorGetStrides(OMTensor *tensor) { return tensor->_strides; }
+int64_t *omTensorGetStrides(const OMTensor *tensor) { return tensor->_strides; }

/* OMTensor data strides setter */
void omTensorSetStrides(OMTensor *tensor, int64_t *strides) {
@@ -281,24 +281,26 @@ void omTensorSetStridesWithPyArrayStrides(
}

/* OMTensor data type getter */
-OM_DATA_TYPE omTensorGetDataType(OMTensor *tensor) { return tensor->_dataType; }
+OM_DATA_TYPE omTensorGetDataType(const OMTensor *tensor) {
+  return tensor->_dataType;
+}

/* OMTensor data type setter */
void omTensorSetDataType(OMTensor *tensor, OM_DATA_TYPE dataType) {
tensor->_dataType = dataType;
}

/* OMTensor data buffer size getter */
-int64_t omTensorGetBufferSize(OMTensor *tensor) {
+int64_t omTensorGetBufferSize(const OMTensor *tensor) {
return getNumElems(tensor->_shape, tensor->_rank) *
getDataTypeSize(tensor->_dataType);
}

/* OMTensor rank getter */
-int64_t omTensorGetRank(OMTensor *tensor) { return tensor->_rank; }
+int64_t omTensorGetRank(const OMTensor *tensor) { return tensor->_rank; }

/* OMTensor number of elements getter */
-int64_t omTensorGetNumElems(OMTensor *tensor) {
+int64_t omTensorGetNumElems(const OMTensor *tensor) {
// Using signed indices helps detect when index falls below 0.
// Verify that strides are dense, meaning that there're
// no skipping elements.
@@ -317,7 +319,7 @@ int64_t omTensorGetNumElems(OMTensor *tensor) {
*
* @return owning flag of the OMTensor.
*/
-int64_t omTensorGetOwning(OMTensor *tensor) { return tensor->_owning; }
+int64_t omTensorGetOwning(const OMTensor *tensor) { return tensor->_owning; }

/**
* OMTensor owning flag setter.
@@ -336,7 +338,7 @@ void omTensorSetOwning(OMTensor *tensor, int64_t owning) {
* @return pointer to the allocated data buffer of the OMTensor,
* NULL if the allocated data buffer is not set.
*/
-void *omTensorGetAllocatedPtr(OMTensor *tensor) {
+void *omTensorGetAllocatedPtr(const OMTensor *tensor) {
return tensor->_allocatedPtr;
}

@@ -403,25 +405,25 @@ OMTensor *omTensorCreateWithRandomData(

/* Access an element (by reference) at offset computed by index array */
template <typename T>
-T &omTensorGetElem(OMTensor *omt, std::vector<int64_t> indexes) {
+T &omTensorGetElem(const OMTensor *omt, std::vector<int64_t> indexes) {
int64_t elemOffset = omTensorComputeElemOffset(omt, indexes);
return ((T *)omt->_alignedPtr)[elemOffset];
}

/* Access an element (by reference) at linear offset */
template <typename T>
-T &omTensorGetElemByOffset(OMTensor *omt, int64_t index) {
+T &omTensorGetElemByOffset(const OMTensor *omt, int64_t index) {
return ((T *)omt->_alignedPtr)[index];
}

/* Compute strides vector from shape vector */
-std::vector<int64_t> omTensorComputeStridesFromShape(OMTensor *omt) {
+std::vector<int64_t> omTensorComputeStridesFromShape(const OMTensor *omt) {
return computeStridesFromShape(omt->_shape, omt->_rank);
}

/* Compute linear element offset from multi-dimensional index array */
int64_t omTensorComputeElemOffset(
-    OMTensor *omt, std::vector<int64_t> &indexes) {
+    const OMTensor *omt, std::vector<int64_t> &indexes) {
return computeElemOffset(omt->_strides, omt->_rank, indexes);
}

@@ -529,20 +531,25 @@ template OMTensor *omTensorCreateWithRandomData<float>(
template OMTensor *omTensorCreateWithRandomData<double>(
std::vector<int64_t> shape, double lbound, double ubound);

-template bool &omTensorGetElem<bool>(OMTensor *, std::vector<int64_t> indexes);
+template bool &omTensorGetElem<bool>(
+    const OMTensor *, std::vector<int64_t> indexes);
template int32_t &omTensorGetElem<int32_t>(
-    OMTensor *, std::vector<int64_t> indexes);
+    const OMTensor *, std::vector<int64_t> indexes);
template int64_t &omTensorGetElem<int64_t>(
-    OMTensor *, std::vector<int64_t> indexes);
+    const OMTensor *, std::vector<int64_t> indexes);
template float &omTensorGetElem<float>(
-    OMTensor *, std::vector<int64_t> indexes);
+    const OMTensor *, std::vector<int64_t> indexes);
template double &omTensorGetElem<double>(
-    OMTensor *, std::vector<int64_t> indexes);
-
-template int32_t &omTensorGetElemByOffset<int32_t>(OMTensor *, int64_t index);
-template int64_t &omTensorGetElemByOffset<int64_t>(OMTensor *, int64_t index);
-template float &omTensorGetElemByOffset<float>(OMTensor *, int64_t indexs);
-template double &omTensorGetElemByOffset<double>(OMTensor *, int64_t index);
+    const OMTensor *, std::vector<int64_t> indexes);
+
+template int32_t &omTensorGetElemByOffset<int32_t>(
+    const OMTensor *, int64_t index);
+template int64_t &omTensorGetElemByOffset<int64_t>(
+    const OMTensor *, int64_t index);
+template float &omTensorGetElemByOffset<float>(
+    const OMTensor *, int64_t indexs);
+template double &omTensorGetElemByOffset<double>(
+    const OMTensor *, int64_t index);

template bool omTensorAreTwoOmtsClose<int32_t>(
OMTensor *a, OMTensor *b, float rtol, float atol);
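The explicit instantiations above cover the element-access helpers used by the runtime tests. Below is a short sketch of a test-style check that reads elements through a const OMTensor *, assuming the internal src/Runtime/OMTensorHelper.h header is on the include path (as it is for the runtime unit tests); firstElementsMatch and the rank-2 index are hypothetical and not from this commit.

#include <vector>
#include "OMTensorHelper.h" // internal header from src/Runtime

// Hypothetical check: reading the first element by multi-dimensional index
// and by linear offset should agree; both helpers now take a const OMTensor *.
static bool firstElementsMatch(const OMTensor *omt) {
  std::vector<int64_t> index = {0, 0}; // assumes a rank-2 tensor
  float byIndex = omTensorGetElem<float>(omt, index);
  float byOffset = omTensorGetElemByOffset<float>(omt, 0);
  return byIndex == byOffset;
}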
11 changes: 6 additions & 5 deletions src/Runtime/OMTensorHelper.h
@@ -108,7 +108,7 @@ OMTensor *omTensorCreateWithRandomData(
* @return typed element by reference at the offset computed by the index array.
*/
template <typename T>
-T &omTensorGetElem(OMTensor *omt, std::vector<int64_t> indexes);
+T &omTensorGetElem(const OMTensor *omt, std::vector<int64_t> indexes);

/**
* OMTensor data element getter by index
@@ -118,15 +118,15 @@ T &omTensorGetElem(OMTensor *omt, std::vector<int64_t> indexes);
* @return typed element by reference at the linear offset.
*/
template <typename T>
-T &omTensorGetElemByOffset(OMTensor *omt, int64_t index);
+T &omTensorGetElemByOffset(const OMTensor *omt, int64_t index);

/**
* OMTensor strides computation
*
* @param omt, pointer to the OMTensor
* @return data strides of the OMTensor computed from the data sizes.
*/
-std::vector<int64_t> omTensorComputeStridesFromShape(OMTensor *omt);
+std::vector<int64_t> omTensorComputeStridesFromShape(const OMTensor *omt);

/**
* OMTensor linear offset computation
@@ -135,7 +135,8 @@ std::vector<int64_t> omTensorComputeStridesFromShape(OMTensor *omt);
* @param indexes, multi-dimensional index array
* @return linear offset.
*/
-int64_t omTensorComputeElemOffset(OMTensor *omt, std::vector<int64_t> &indexes);
+int64_t omTensorComputeElemOffset(
+    const OMTensor *omt, std::vector<int64_t> &indexes);

/**
* OMTensor index set computation
@@ -145,7 +146,7 @@ int64_t omTensorComputeElemOffset(OMTensor *omt, std::vector<int64_t> &indexes);
* that can be used to access this OMTensor's constituent elements)
* for the whole OMTensor.
*/
-std::vector<std::vector<int64_t>> omTensorComputeIndexSet(OMTensor *omt);
+std::vector<std::vector<int64_t>> omTensorComputeIndexSet(const OMTensor *omt);

/**
* OMTensor "distance" computation
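The strides and offset helpers above are const-qualified as well. Below is a brief sketch of the relationship between them, assuming indexes.size() equals the tensor rank; checkOffset is a hypothetical name and the dot-product equivalence is an assumption about the helper's behavior, not something stated by this commit.

#include <vector>
#include <onnx-mlir/Runtime/OMTensor.h>
#include "OMTensorHelper.h" // internal header from src/Runtime

// Hypothetical sanity check: the linear offset of an index set should equal
// the dot product of the indexes with the tensor's strides.
static bool checkOffset(const OMTensor *omt, std::vector<int64_t> indexes) {
  const int64_t *strides = omTensorGetStrides(omt);
  int64_t expected = 0;
  for (size_t i = 0; i < indexes.size(); ++i)
    expected += strides[i] * indexes[i];
  return omTensorComputeElemOffset(omt, indexes) == expected;
}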