From add2a5468c6a2517d6bc8aa952243c94383f3c49 Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Sun, 31 Jul 2022 21:38:26 -0500 Subject: [PATCH] ort backend support output mutable data (#44724) --- .../inference/api/details/zero_copy_tensor.cc | 16 ++++++++++++++++ paddle/fluid/inference/api/paddle_tensor.h | 3 +++ 2 files changed, 19 insertions(+) diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index e507b7675d637..d68c49fe6b33d 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -88,6 +88,11 @@ void Tensor::ReshapeStrings(const size_t &shape) { template T *Tensor::mutable_data(PlaceType place) { +#ifdef PADDLE_WITH_ONNXRUNTIME + if (is_ort_tensor_) { + return ORTGetMutableData(); + } +#endif EAGER_GET_TENSOR(paddle::framework::LoDTensor); PADDLE_ENFORCE_GT( tensor->numel(), @@ -670,6 +675,17 @@ void Tensor::SetOrtBinding(const std::shared_ptr binding) { binding_ = binding; } +template +T *Tensor::ORTGetMutableData() { + auto binding = binding_.lock(); + PADDLE_ENFORCE_NOT_NULL(binding, + paddle::platform::errors::PreconditionNotMet( + "output tensor [%s] no binding ptr", name_)); + std::vector outputs = binding->GetOutputValues(); + Ort::Value &value = outputs[idx_]; + return value.GetTensorMutableData(); +} + template void Tensor::ORTCopyToCpu(T *data) const { auto binding = binding_.lock(); diff --git a/paddle/fluid/inference/api/paddle_tensor.h b/paddle/fluid/inference/api/paddle_tensor.h index 2afe2d32e2f60..35ff80b9cbf09 100644 --- a/paddle/fluid/inference/api/paddle_tensor.h +++ b/paddle/fluid/inference/api/paddle_tensor.h @@ -190,6 +190,9 @@ class PD_INFER_DECL Tensor { void SetOrtBinding(const std::shared_ptr binding); + template + T* ORTGetMutableData(); + template void ORTCopyFromCpu(const T* data);