Skip to content

Commit

Permalink
add more capi to support stride (#62716)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tongkaio authored Mar 14, 2024
1 parent fde63d1 commit 683a141
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 0 deletions.
12 changes: 12 additions & 0 deletions paddle/phi/capi/include/c_meta_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,25 @@ int64_t PD_MetaTensorGetDim(const PD_MetaTensor *tensor,
size_t index,
PD_Status *status);

int64_t PD_MetaTensorGetNumStrides(const PD_MetaTensor *tensor,
PD_Status *status);

int64_t PD_MetaTensorGetStride(const PD_MetaTensor *tensor,
size_t index,
PD_Status *status);

bool PD_MetaTensorIsValid(const PD_MetaTensor *tensor, PD_Status *status);

void PD_MetaTensorSetDims(PD_MetaTensor *tensor,
int64_t ndims,
const int64_t *dims,
PD_Status *status);

void PD_MetaTensorSetStrides(PD_MetaTensor *tensor,
int64_t nstrides,
const int64_t *strides,
PD_Status *status);

void PD_MetaTensorSetDataType(PD_MetaTensor *tensor,
PD_DataType dtype,
PD_Status *status);
Expand Down
17 changes: 17 additions & 0 deletions paddle/phi/capi/include/c_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ int64_t PD_TensorGetDim(const PD_Tensor *tensor,
size_t index,
PD_Status *status);

int64_t PD_TensorGetNumStrides(const PD_Tensor *tensor, PD_Status *status);

int64_t PD_TensorGetStride(const PD_Tensor *tensor,
size_t index,
PD_Status *status);

void PD_TensorGetLoD(const PD_Tensor *tensor,
PD_List *data,
PD_List *offset,
Expand All @@ -52,11 +58,22 @@ bool PD_TensorIsValid(const PD_Tensor *tensor, PD_Status *status);

void *PD_TensorGetHolder(const PD_Tensor *tensor, PD_Status *status);

size_t PD_TensorGetOffset(const PD_Tensor *tensor, PD_Status *status);

void PD_TensorSetDims(PD_Tensor *tensor,
int64_t ndims,
const int64_t *dims,
PD_Status *status);

void PD_TensorSetOffset(PD_Tensor *tensor,
const int64_t offset,
PD_Status *status);

void PD_TensorSetStrides(PD_Tensor *tensor,
int64_t nstrides,
const int64_t *strides,
PD_Status *status);

void PD_TensorSetDataType(PD_Tensor *tensor,
PD_DataType dtype,
PD_Status *status);
Expand Down
66 changes: 66 additions & 0 deletions paddle/phi/capi/include/wrapper_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,19 @@ inline std::vector<int64_t> PD_TensorGetDims(PD_Tensor* tensor,
return std::vector<int64_t>();
}

inline std::vector<int64_t> PD_TensorGetStrides(PD_Tensor* tensor,
PD_Status* status) {
int64_t nstrides = PD_TensorGetNumStrides(tensor, status);
if (nstrides > 0) {
std::vector<int64_t> shape(nstrides);
for (int64_t i = 0; i < nstrides; ++i) {
shape[i] = PD_TensorGetStride(tensor, i, status);
}
return shape;
}
return std::vector<int64_t>();
}

inline std::vector<int64_t> PD_MetaTensorGetDims(PD_MetaTensor* tensor,
PD_Status* status) {
int64_t ndims = PD_MetaTensorGetNumDims(tensor, status);
Expand All @@ -85,6 +98,19 @@ inline std::vector<int64_t> PD_MetaTensorGetDims(PD_MetaTensor* tensor,
return std::vector<int64_t>();
}

inline std::vector<int64_t> PD_MetaTensorGetStrides(PD_MetaTensor* tensor,
PD_Status* status) {
int64_t nstrides = PD_MetaTensorGetNumStrides(tensor, status);
if (nstrides > 0) {
std::vector<int64_t> shape(nstrides);
for (int64_t i = 0; i < nstrides; ++i) {
shape[i] = PD_MetaTensorGetStride(tensor, i, status);
}
return shape;
}
return std::vector<int64_t>();
}

template <typename T>
class WrapperBase {
public:
Expand Down Expand Up @@ -134,13 +160,27 @@ class DenseTensor : public WrapperBase<PD_Tensor> {
return holder;
}

size_t offset() const {
C_Status status;
auto offset = PD_TensorGetOffset(raw_data(), &status);
PD_CHECK_STATUS(status);
return offset;
}

std::vector<int64_t> dims() const {
C_Status status;
auto dimension = PD_TensorGetDims(raw_data(), &status);
PD_CHECK_STATUS(status);
return dimension;
}

std::vector<int64_t> strides() const {
C_Status status;
auto strides = PD_TensorGetStrides(raw_data(), &status);
PD_CHECK_STATUS(status);
return strides;
}

PD_DataType dtype() const {
C_Status status;
auto data_type = PD_TensorGetPDDataType(raw_data(), &status);
Expand Down Expand Up @@ -207,6 +247,18 @@ class DenseTensor : public WrapperBase<PD_Tensor> {
PD_CHECK_STATUS(status);
}

void set_offset(const int64_t& offset) {
C_Status status;
PD_TensorSetOffset(raw_data(), offset, &status);
PD_CHECK_STATUS(status);
}

void set_strides(const std::vector<int64_t>& strides) {
C_Status status;
PD_TensorSetStrides(raw_data(), strides.size(), strides.data(), &status);
PD_CHECK_STATUS(status);
}

void set_dtype(PD_DataType data_type) {
C_Status status;
PD_TensorSetDataType(raw_data(), data_type, &status);
Expand Down Expand Up @@ -513,6 +565,13 @@ class MetaTensor : WrapperBase<PD_MetaTensor> {
return dimension;
}

std::vector<int64_t> strides() const {
C_Status status;
auto strides = PD_MetaTensorGetStrides(raw_data(), &status);
PD_CHECK_STATUS(status);
return strides;
}

PD_DataType dtype() const {
C_Status status;
auto data_type = PD_MetaTensorGetPDDataType(raw_data(), &status);
Expand Down Expand Up @@ -540,6 +599,13 @@ class MetaTensor : WrapperBase<PD_MetaTensor> {
PD_CHECK_STATUS(status);
}

void set_strides(const std::vector<int64_t>& strides) {
C_Status status;
PD_MetaTensorSetStrides(
raw_data(), strides.size(), strides.data(), &status);
PD_CHECK_STATUS(status);
}

void set_dtype(PD_DataType data_type) {
C_Status status;
PD_MetaTensorSetDataType(raw_data(), data_type, &status);
Expand Down
46 changes: 46 additions & 0 deletions paddle/phi/capi/lib/c_meta_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,36 @@ int64_t PD_MetaTensorGetDim(const PD_MetaTensor *tensor,
return cc_tensor->dims()[index];
}

int64_t PD_MetaTensorGetNumStrides(const PD_MetaTensor *tensor,
PD_Status *status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return 0;
}
*status = C_SUCCESS;
}

auto cc_tensor = reinterpret_cast<const phi::MetaTensor *>(tensor);
return cc_tensor->strides().size();
}

int64_t PD_MetaTensorGetStride(const PD_MetaTensor *tensor,
size_t index,
PD_Status *status) {
auto cc_tensor = reinterpret_cast<const phi::MetaTensor *>(tensor);

if (status) {
if (!tensor || index >= static_cast<size_t>(cc_tensor->strides().size())) {
*status = C_FAILED;
return 0;
}
*status = C_SUCCESS;
}

return cc_tensor->strides()[index];
}

bool PD_MetaTensorIsValid(const PD_MetaTensor *tensor, PD_Status *status) {
if (status) {
if (!tensor) {
Expand Down Expand Up @@ -117,6 +147,22 @@ void PD_MetaTensorSetDims(PD_MetaTensor *tensor,
cc_tensor->set_dims(common::make_ddim(shape));
}

void PD_MetaTensorSetStrides(PD_MetaTensor *tensor,
int64_t nstrides,
const int64_t *strides,
PD_Status *status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return;
}
*status = C_SUCCESS;
}
auto cc_tensor = reinterpret_cast<phi::MetaTensor *>(tensor);
std::vector<int> shape(strides, strides + nstrides);
cc_tensor->set_strides(common::make_ddim(shape));
}

void PD_MetaTensorSetDataType(PD_MetaTensor *tensor,
PD_DataType dtype,
PD_Status *status) {
Expand Down
72 changes: 72 additions & 0 deletions paddle/phi/capi/lib/c_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,35 @@ int64_t PD_TensorGetDim(const PD_Tensor* tensor,
return cc_tensor->dims()[index];
}

int64_t PD_TensorGetNumStrides(const PD_Tensor* tensor, PD_Status* status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return 0;
}
*status = C_SUCCESS;
}

auto cc_tensor = reinterpret_cast<const phi::DenseTensor*>(tensor);
return cc_tensor->strides().size();
}

int64_t PD_TensorGetStride(const PD_Tensor* tensor,
size_t index,
PD_Status* status) {
auto cc_tensor = reinterpret_cast<const phi::DenseTensor*>(tensor);

if (status) {
if (!tensor || index >= static_cast<size_t>(cc_tensor->strides().size())) {
*status = C_FAILED;
return 0;
}
*status = C_SUCCESS;
}

return cc_tensor->strides()[index];
}

void PD_TensorGetLoD(const PD_Tensor* tensor,
PD_List* data,
PD_List* offset,
Expand Down Expand Up @@ -185,6 +214,19 @@ void* PD_TensorGetHolder(const PD_Tensor* tensor, PD_Status* status) {
return cc_tensor->Holder().get();
}

size_t PD_TensorGetOffset(const PD_Tensor* tensor, PD_Status* status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return 0;
}
*status = C_SUCCESS;
}

auto cc_tensor = reinterpret_cast<const phi::DenseTensor*>(tensor);
return cc_tensor->offset();
}

void PD_TensorSetDims(PD_Tensor* tensor,
int64_t ndims,
const int64_t* dims,
Expand All @@ -201,6 +243,36 @@ void PD_TensorSetDims(PD_Tensor* tensor,
cc_tensor->Resize(common::make_ddim(shape));
}

void PD_TensorSetOffset(PD_Tensor* tensor,
const int64_t offset,
PD_Status* status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return;
}
*status = C_SUCCESS;
}
auto cc_tensor = reinterpret_cast<phi::DenseTensor*>(tensor);
cc_tensor->set_offset(offset);
}

void PD_TensorSetStrides(PD_Tensor* tensor,
int64_t nstrides,
const int64_t* strides,
PD_Status* status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return;
}
*status = C_SUCCESS;
}
auto cc_tensor = reinterpret_cast<phi::DenseTensor*>(tensor);
std::vector<int> shape(strides, strides + nstrides);
cc_tensor->set_strides(common::make_ddim(shape));
}

void PD_TensorSetDataType(PD_Tensor* tensor,
PD_DataType dtype,
PD_Status* status) {
Expand Down

0 comments on commit 683a141

Please sign in to comment.