C-API output tensors reset #2683
Changes from 11 commits
@@ -1578,7 +1578,7 @@ cc_library(
    srcs = ["test/openvino_remote_tensors_tests.cpp"],
    data = [
        "test/c_api/config_standard_dummy.json",
        "test/c_api/config_gpu_dummy.json",
        "test/configs/config_gpu_dummy.json",
        "test/configs/config_gpu_face_detection_adas.json",
        "test/dummy/1/dummy.xml",
        "test/dummy/1/dummy.bin",
@@ -1627,7 +1627,7 @@ cc_library(
    srcs = ["test/openvino_tests.cpp"],
    data = [
        "test/c_api/config_standard_dummy.json",

Review comment: How about standard dummy config?

        "test/c_api/config_gpu_dummy.json",
        "test/configs/config_gpu_dummy.json",
        "test/dummy/1/dummy.xml",
        "test/dummy/1/dummy.bin",
    ],
@@ -693,6 +693,21 @@ DLL_PUBLIC OVMS_Status* OVMS_InferenceRequestRemoveInput(OVMS_InferenceRequest*
    return nullptr;
}

DLL_PUBLIC OVMS_Status* OVMS_InferenceRequestRemoveOutput(OVMS_InferenceRequest* req, const char* outputName) {
    if (req == nullptr) {
        return reinterpret_cast<OVMS_Status*>(new Status(StatusCode::NONEXISTENT_PTR, "inference request"));
    }
    if (outputName == nullptr) {
        return reinterpret_cast<OVMS_Status*>(new Status(StatusCode::NONEXISTENT_PTR, "output name"));
    }
    InferenceRequest* request = reinterpret_cast<InferenceRequest*>(req);
    auto status = request->removeOutput(outputName);
    if (!status.ok()) {
        return reinterpret_cast<OVMS_Status*>(new Status(std::move(status)));
    }
    return nullptr;
}

Review comment: Is …
Reply: yes
DLL_PUBLIC OVMS_Status* OVMS_InferenceRequestInputRemoveData(OVMS_InferenceRequest* req, const char* inputName) {
    if (req == nullptr) {
        return reinterpret_cast<OVMS_Status*>(new Status(StatusCode::NONEXISTENT_PTR, "inference request"));

@@ -708,6 +723,21 @@ DLL_PUBLIC OVMS_Status* OVMS_InferenceRequestInputRemoveData(OVMS_InferenceReque
    return nullptr;
}

DLL_PUBLIC OVMS_Status* OVMS_InferenceRequestOutputRemoveData(OVMS_InferenceRequest* req, const char* outputName) {
    if (req == nullptr) {
        return reinterpret_cast<OVMS_Status*>(new Status(StatusCode::NONEXISTENT_PTR, "inference request"));
    }
    if (outputName == nullptr) {
        return reinterpret_cast<OVMS_Status*>(new Status(StatusCode::NONEXISTENT_PTR, "output name"));
    }
    InferenceRequest* request = reinterpret_cast<InferenceRequest*>(req);
    auto status = request->removeOutputBuffer(outputName);
    if (!status.ok()) {
        return reinterpret_cast<OVMS_Status*>(new Status(std::move(status)));
    }
    return nullptr;
}

DLL_PUBLIC OVMS_Status* OVMS_InferenceResponseOutput(OVMS_InferenceResponse* res, uint32_t id, const char** name, OVMS_DataType* datatype, const int64_t** shape, size_t* dimCount, const void** data, size_t* bytesize, OVMS_BufferType* bufferType, uint32_t* deviceId) {
    if (res == nullptr) {
        return reinterpret_cast<OVMS_Status*>(new Status(StatusCode::NONEXISTENT_PTR, "inference response"));

@@ -945,6 +975,7 @@ DLL_PUBLIC OVMS_Status* OVMS_Inference(OVMS_Server* serverPtr, OVMS_InferenceReq
    }

    if (!status.ok()) {
        // TODO fixme error handling with callbacks - we may need to move callback usage here
        return reinterpret_cast<OVMS_Status*>(new Status(std::move(status)));
    }
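The two new entry points mirror the existing input-side calls: OVMS_InferenceRequestOutputRemoveData detaches a previously supplied output buffer, while OVMS_InferenceRequestRemoveOutput drops the output entry from the request altogether. A minimal usage sketch, assuming `server` and `request` were already created with the usual C-API calls and that an output named "out" had a buffer attached via OVMS_InferenceRequestOutputSetData; `check()` is a hypothetical status-handling helper, not part of the OVMS C-API:

```cpp
// Sketch only. Assumes `server` and `request` already exist and that an output buffer
// was attached earlier via OVMS_InferenceRequestOutputSetData(request, "out", ...).
// check() is a hypothetical helper that handles a non-null OVMS_Status*.
OVMS_InferenceResponse* response = nullptr;
check(OVMS_Inference(server, request, &response));  // results land in the user-provided "out" buffer

// Detach the user buffer so the next inference falls back to server-managed output memory.
check(OVMS_InferenceRequestOutputRemoveData(request, "out"));

// Or remove the output entry from the request entirely.
check(OVMS_InferenceRequestRemoveOutput(request, "out"));
```

Both calls return an OVMS_Status* carrying NONEXISTENT_PTR when the request or the name is null, matching the input-side variants shown above.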
@@ -60,6 +60,13 @@ Status InferenceRequest::setOutputBuffer(const char* name, const void* addr, siz
    }
    return it->second.setBuffer(addr, byteSize, bufferType, deviceId);
}
Status InferenceRequest::removeOutputBuffer(const char* name) {
    auto it = outputs.find(name);

Review comment: Does it make sense to check …
Reply: It's checked at capi.cpp

    if (it == outputs.end()) {
        return StatusCode::NONEXISTENT_TENSOR_FOR_REMOVE_BUFFER;
    }
    return it->second.removeBuffer();
}
Status InferenceRequest::removeInputBuffer(const char* name) {
    auto it = inputs.find(name);
    if (it == inputs.end()) {

@@ -99,6 +106,13 @@ Status InferenceRequest::removeInput(const char* name) {
    }
    return StatusCode::NONEXISTENT_TENSOR_FOR_REMOVAL;
}
Status InferenceRequest::removeOutput(const char* name) {

Review comment: same as above
Reply: It's checked at capi.cpp

    auto count = outputs.erase(name);
    if (count) {
        return StatusCode::OK;
    }
    return StatusCode::NONEXISTENT_TENSOR_FOR_REMOVAL;
}
Status InferenceRequest::addParameter(const char* parameterName, OVMS_DataType datatype, const void* data) {
    auto [it, emplaced] = parameters.emplace(parameterName, InferenceParameter{parameterName, datatype, data});
    return emplaced ? StatusCode::OK : StatusCode::DOUBLE_PARAMETER_INSERT;
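As the review exchange notes, the nullptr checks live at the C-API boundary in capi.cpp, so these methods only handle the "name not present" case. The two removal paths also differ in scope: removeOutputBuffer detaches the user-supplied buffer but keeps the output registered on the request, while removeOutput erases the entry itself. A toy illustration of that split, with invented types (OutputEntry, ToyRequest) rather than the real OVMS ones:

```cpp
#include <optional>
#include <string>
#include <unordered_map>

// Invented stand-ins for the request-side output bookkeeping.
struct OutputEntry {
    std::optional<const void*> buffer;  // user-provided output memory, if any
};

struct ToyRequest {
    std::unordered_map<std::string, OutputEntry> outputs;

    // Analogue of removeOutputBuffer(): keep the output, drop only its buffer.
    bool removeOutputBuffer(const std::string& name) {
        auto it = outputs.find(name);
        if (it == outputs.end()) {
            return false;  // maps to NONEXISTENT_TENSOR_FOR_REMOVE_BUFFER
        }
        it->second.buffer.reset();
        return true;
    }

    // Analogue of removeOutput(): erase the whole entry, as outputs.erase(name) does above.
    bool removeOutput(const std::string& name) {
        return outputs.erase(name) > 0;  // zero erased maps to NONEXISTENT_TENSOR_FOR_REMOVAL
    }
};
```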
@@ -964,6 +964,7 @@ Status ModelInstance::loadModelImpl(const ModelConfig& config, const DynamicMode
        this->status.setLoading(ModelVersionStatusErrorCode::UNKNOWN);
        return status;
    }
    this->checkForOutputTensorResetAbility();
    this->loadTensorFactories();
} catch (const ov::Exception& e) {
    SPDLOG_ERROR("exception occurred while loading model: {}", e.what());

@@ -1307,6 +1308,31 @@ void handleCallback(const InferenceRequest* request, InferenceResponse* response
    }
}
struct OutputKeeper {
    std::unordered_map<std::string, ov::Tensor> outputs;
    ov::InferRequest& request;
    OutputKeeper(ov::InferRequest& request, const tensor_map_t& outputsInfo) :
        request(request) {
        for (auto [name, _] : outputsInfo) {
            OV_LOGGER("ov::InferRequest: {}, request.get_tensor({})", reinterpret_cast<void*>(&request), name);
            try {
                ov::Tensor tensor = request.get_tensor(name);
                OV_LOGGER("ov::Tensor(): {}", reinterpret_cast<void*>(&tensor));
                outputs.emplace(std::make_pair(name, std::move(tensor)));
                OV_LOGGER("ov::Tensor(ov::Tensor&&): {}", reinterpret_cast<void*>(&outputs.at(name)));
            } catch (std::exception& e) {
                SPDLOG_DEBUG("Resetting output:{}; for this model is not supported. Check C-API documentation for OVMS_InferenceRequestOutputSetData. Error:{}", name, e.what());
            }
        }
    }
    ~OutputKeeper() {
        for (auto [name, v] : outputs) {
            OV_LOGGER("ov::InferRequest: {}, request.set_tensor({}, {})", reinterpret_cast<void*>(&request), name, reinterpret_cast<void*>(&v));
            request.set_tensor(name, v);
        }
    }
};

template <typename RequestType, typename ResponseType>
Status ModelInstance::infer(const RequestType* requestProto,
    ResponseType* responseProto,
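OutputKeeper is effectively an RAII guard around a pooled ov::InferRequest: the constructor snapshots the current output tensors (skipping any output whose tensor cannot be fetched by name), and the destructor writes them back, so a user-supplied output buffer set during deserialization only overrides the pooled tensor for a single inference. A standalone toy version of the same save-and-restore idea, with invented types standing in for ov::InferRequest and ov::Tensor:

```cpp
#include <string>
#include <unordered_map>
#include <vector>

// Invented stand-in: an inference request owning one tensor (here: int) per output name.
struct ToyInferRequest {
    std::unordered_map<std::string, int> tensors;
    int get_tensor(const std::string& name) const { return tensors.at(name); }
    void set_tensor(const std::string& name, int t) { tensors[name] = t; }
};

// Toy counterpart of OutputKeeper: snapshot on construction, restore on destruction,
// so whatever gets set on the request in between is undone when the guard leaves scope.
class ToyOutputKeeper {
public:
    ToyOutputKeeper(ToyInferRequest& request, const std::vector<std::string>& outputNames)
        : request(request) {
        for (const auto& name : outputNames) {
            saved.emplace(name, request.get_tensor(name));
        }
    }
    ~ToyOutputKeeper() {
        for (const auto& [name, tensor] : saved) {
            request.set_tensor(name, tensor);
        }
    }

private:
    ToyInferRequest& request;
    std::unordered_map<std::string, int> saved;
};
```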
@@ -1356,6 +1382,11 @@ Status ModelInstance::infer(const RequestType* requestProto,
    timer.start(DESERIALIZE);
    InputSink<ov::InferRequest&> inputSink(inferRequest);
    bool isPipeline = false;

    std::unique_ptr<OutputKeeper> outKeeper;
    if (this->doesSupportOutputReset()) {
        outKeeper = std::make_unique<OutputKeeper>(executingStreamIdGuard.getInferRequest(), getOutputsInfo());
    }
    status = deserializePredictRequest<ConcreteTensorProtoDeserializator, InputSink<ov::InferRequest&>>(*requestProto, getInputsInfo(), getOutputsInfo(), inputSink, isPipeline, this->tensorFactories);
    timer.stop(DESERIALIZE);
    if (!status.ok()) {
@@ -1397,6 +1428,26 @@ Status ModelInstance::infer(const RequestType* requestProto,
    // handleCallback(requestProto, responseProto); to be enabled when callbacks are implemented in network API's
    return status;
}
void ModelInstance::checkForOutputTensorResetAbility() {
    StreamIdGuard guard(getInferRequestsQueue());
    auto request = guard.getInferRequest();
    bool allOutputsSupported = true;
    for (auto [name, _] : getOutputsInfo()) {
        try {
            ov::Tensor tensor = request.get_tensor(name);

Review comment: Is it possible that we capture input name here? If model has an input called …
Reply: Output name is logged in L1442/1439

        } catch (std::exception& e) {
            SPDLOG_LOGGER_WARN(modelmanager_logger, "Resetting output:{}; for model:{}; version:{}, is not supported. Check C-API documentation for OVMS_InferenceRequestOutputSetData. Error:{}", name, getName(), getVersion(), e.what());
            allOutputsSupported = false;
        } catch (...) {
            SPDLOG_LOGGER_WARN(modelmanager_logger, "Resetting output:{}; for model:{}; version:{}, is not supported. Check C-API documentation for OVMS_InferenceRequestOutputSetData.", name, getName(), getVersion());
            allOutputsSupported = false;
        }
    }
    this->supportOutputTensorsReset = allOutputsSupported;
}
bool ModelInstance::doesSupportOutputReset() const {
    return this->supportOutputTensorsReset;
}

#pragma GCC diagnostic pop
template <typename RequestType, typename ResponseType>
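The capability probe runs once per model load rather than on every request: each name from getOutputsInfo() is looked up on a pooled ov::InferRequest, any throwing get_tensor() call marks the whole model as unsupported, and the cached flag is what infer()/inferAsync() later read through doesSupportOutputReset(). A condensed sketch of that probe-and-cache pattern, with an invented ToyModel and a callback standing in for the get_tensor() attempt:

```cpp
#include <functional>
#include <string>
#include <vector>

// Probe once at load time, cache the answer, read the cached flag on the hot path.
class ToyModel {
public:
    // `canFetchOutput` plays the role of request.get_tensor(name) succeeding or throwing.
    void checkForOutputTensorResetAbility(const std::vector<std::string>& outputNames,
                                          const std::function<bool(const std::string&)>& canFetchOutput) {
        bool allOutputsSupported = true;
        for (const auto& name : outputNames) {
            if (!canFetchOutput(name)) {
                allOutputsSupported = false;  // a single unsupported output disables the feature
            }
        }
        supportOutputTensorsReset = allOutputsSupported;
    }
    bool doesSupportOutputReset() const { return supportOutputTensorsReset; }

private:
    bool supportOutputTensorsReset{false};
};
```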
@@ -1448,6 +1499,10 @@ Status ModelInstance::inferAsync(const RequestType* requestProto,
    timer.start(DESERIALIZE);
    InputSink<ov::InferRequest&> inputSink(inferRequest);
    bool isPipeline = false;
    std::shared_ptr<OutputKeeper> outKeeper;
    if (this->doesSupportOutputReset()) {
        outKeeper = std::make_shared<OutputKeeper>(executingStreamIdGuard.getInferRequest(), getOutputsInfo());
    }
    status = deserializePredictRequest<ConcreteTensorProtoDeserializator, InputSink<ov::InferRequest&>>(*requestProto, getInputsInfo(), getOutputsInfo(), inputSink, isPipeline, this->tensorFactories);
    timer.stop(DESERIALIZE);
    if (!status.ok()) {

@@ -1467,7 +1522,7 @@ Status ModelInstance::inferAsync(const RequestType* requestProto,
    void* userCallbackData = requestProto->getResponseCompleteCallbackData();
    // here pass by copy into callback
    {
        inferRequest.set_callback([this, requestProto, &inferRequest, userCallback, userCallbackData, modelUnloadGuardPtrMoved = std::shared_ptr<ModelInstanceUnloadGuard>(std::move(modelUnloadGuardPtr))](std::exception_ptr exception) mutable {
        inferRequest.set_callback([this, requestProto, &inferRequest, movedOutputKeeper = std::move(outKeeper), userCallback, userCallbackData, modelUnloadGuardPtrMoved = std::shared_ptr<ModelInstanceUnloadGuard>(std::move(modelUnloadGuardPtr))](std::exception_ptr exception) mutable {
            SPDLOG_DEBUG("Entry of ov::InferRequest callback call");
            if (exception) {
                try {
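In the asynchronous path the guard cannot simply live on the stack, because inferAsync() returns before the inference finishes; that is why outKeeper is a std::shared_ptr here and is moved into the completion callback as movedOutputKeeper, tying the restore in ~OutputKeeper() to the callback's lifetime instead of the enclosing scope. A minimal standalone illustration of that lifetime trick (the stored std::function stands in for the callback kept by ov::InferRequest::set_callback):

```cpp
#include <functional>
#include <iostream>
#include <memory>

struct Guard {
    ~Guard() { std::cout << "restore outputs\n"; }  // plays the role of ~OutputKeeper()
};

int main() {
    std::function<void()> storedCallback;  // stands in for the callback held by the infer request
    {
        auto guard = std::make_shared<Guard>();
        // Moving the shared_ptr into the lambda ties the guard's lifetime to the callback,
        // mirroring `movedOutputKeeper = std::move(outKeeper)` in the diff above.
        storedCallback = [kept = std::move(guard)] { std::cout << "async inference completed\n"; };
    }  // nothing is restored here: the lambda still owns the guard
    storedCallback();           // completion fires
    storedCallback = nullptr;   // callback destroyed -> "restore outputs" printed now
}
```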