diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc
index e507b7675d637..d68c49fe6b33d 100644
--- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc
+++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc
@@ -88,6 +88,13 @@ void Tensor::ReshapeStrings(const size_t &shape) {
 
 template <typename T>
 T *Tensor::mutable_data(PlaceType place) {
+#ifdef PADDLE_WITH_ONNXRUNTIME
+  // When is_ort_tensor_ is set, hand back the pointer from ORTGetMutableData
+  // and skip the LoDTensor path below; `place` is not consulted on this branch.
+  if (is_ort_tensor_) {
+    return ORTGetMutableData<T>();
+  }
+#endif
   EAGER_GET_TENSOR(paddle::framework::LoDTensor);
   PADDLE_ENFORCE_GT(
       tensor->numel(),
@@ -670,6 +677,27 @@ void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
   binding_ = binding;
 }
 
+// Returns a writable, typed pointer taken from the output value at position
+// idx_ of the IO binding referenced by binding_. Fails with PreconditionNotMet
+// when the binding has expired and with OutOfRange when idx_ is not a valid
+// position in the binding's output list.
+// NOTE(review): `outputs` is a local vector, so the returned pointer is only
+// usable if the binding keeps the underlying buffer alive after the local
+// handles are destroyed -- confirm against Ort::IoBinding::GetOutputValues.
+template <typename T>
+T *Tensor::ORTGetMutableData() {
+  auto binding = binding_.lock();
+  PADDLE_ENFORCE_NOT_NULL(binding,
+                          paddle::platform::errors::PreconditionNotMet(
+                              "output tensor [%s] no binding ptr", name_));
+  std::vector<Ort::Value> outputs = binding->GetOutputValues();
+  // Reject an unset or stale idx_ instead of indexing past the end.
+  PADDLE_ENFORCE_LT(static_cast<size_t>(idx_),
+                    outputs.size(),
+                    paddle::platform::errors::OutOfRange(
+                        "output tensor [%s] index is out of range", name_));
+  Ort::Value &value = outputs[idx_];
+  return value.GetTensorMutableData<T>();
+}
+
 template <typename T>
 void Tensor::ORTCopyToCpu(T *data) const {
   auto binding = binding_.lock();
diff --git a/paddle/fluid/inference/api/paddle_tensor.h b/paddle/fluid/inference/api/paddle_tensor.h
index 2afe2d32e2f60..35ff80b9cbf09 100644
--- a/paddle/fluid/inference/api/paddle_tensor.h
+++ b/paddle/fluid/inference/api/paddle_tensor.h
@@ -190,6 +190,10 @@ class PD_INFER_DECL Tensor {
 
   void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding);
 
+  // Typed, writable pointer to the data of the bound ONNX Runtime output value.
+  template <typename T>
+  T* ORTGetMutableData();
+
   template <typename T>
   void ORTCopyFromCpu(const T* data);
 