C-API output tensors reset #2683

Merged · merged 14 commits · Oct 22, 2024
4 changes: 2 additions & 2 deletions src/BUILD
@@ -1578,7 +1578,7 @@ cc_library(
srcs = ["test/openvino_remote_tensors_tests.cpp"],
data = [
"test/c_api/config_standard_dummy.json",
Collaborator: How about standard dummy config?

"test/c_api/config_gpu_dummy.json",
"test/configs/config_gpu_dummy.json",
"test/configs/config_gpu_face_detection_adas.json",
"test/dummy/1/dummy.xml",
"test/dummy/1/dummy.bin",
@@ -1627,7 +1627,7 @@ cc_library(
srcs = ["test/openvino_tests.cpp"],
data = [
"test/c_api/config_standard_dummy.json",
Collaborator: How about standard dummy config?

"test/c_api/config_gpu_dummy.json",
"test/configs/config_gpu_dummy.json",
"test/dummy/1/dummy.xml",
"test/dummy/1/dummy.bin",
],
31 changes: 31 additions & 0 deletions src/capi_frontend/capi.cpp
@@ -693,6 +693,21 @@ DLL_PUBLIC OVMS_Status* OVMS_InferenceRequestRemoveInput(OVMS_InferenceRequest*
return nullptr;
}

DLL_PUBLIC OVMS_Status* OVMS_InferenceRequestRemoveOutput(OVMS_InferenceRequest* req, const char* outputName) {
if (req == nullptr) {
return reinterpret_cast<OVMS_Status*>(new Status(StatusCode::NONEXISTENT_PTR, "inference request"));
}
if (outputName == nullptr) {
return reinterpret_cast<OVMS_Status*>(new Status(StatusCode::NONEXISTENT_PTR, "output name"));
}
InferenceRequest* request = reinterpret_cast<InferenceRequest*>(req);
auto status = request->removeOutput(outputName);
if (!status.ok()) {
return reinterpret_cast<OVMS_Status*>(new Status(std::move(status)));
}
return nullptr;
Collaborator: Is nullptr interpreted as OK status?
Collaborator (Author): Yes.

}
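As the exchange above confirms, these C-API entry points signal success by returning NULL and failure by returning an OVMS_Status that the caller owns. A minimal sketch of a typical call site follows; the status accessors OVMS_StatusGetCode, OVMS_StatusGetDetails and OVMS_StatusDelete are assumed from ovms.h, and the output name "prob" and helper name are purely illustrative:

```cpp
#include <cstdio>
#include "ovms.h"  // C-API declarations (OVMS_InferenceRequest, OVMS_Status, ...)

void removeOutputOrReport(OVMS_InferenceRequest* request) {
    OVMS_Status* status = OVMS_InferenceRequestRemoveOutput(request, "prob");
    if (status != nullptr) {  // non-NULL return means the call failed
        uint32_t code = 0;
        const char* details = nullptr;
        OVMS_StatusGetCode(status, &code);        // assumed accessor from ovms.h
        OVMS_StatusGetDetails(status, &details);  // assumed accessor from ovms.h
        std::fprintf(stderr, "OVMS_InferenceRequestRemoveOutput failed: %u %s\n", code, details);
        OVMS_StatusDelete(status);                // caller releases the status object
    }
}
```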

DLL_PUBLIC OVMS_Status* OVMS_InferenceRequestInputRemoveData(OVMS_InferenceRequest* req, const char* inputName) {
if (req == nullptr) {
return reinterpret_cast<OVMS_Status*>(new Status(StatusCode::NONEXISTENT_PTR, "inference request"));
@@ -708,6 +723,21 @@ DLL_PUBLIC OVMS_Status* OVMS_InferenceRequestInputRemoveData(OVMS_InferenceReque
return nullptr;
}

DLL_PUBLIC OVMS_Status* OVMS_InferenceRequestOutputRemoveData(OVMS_InferenceRequest* req, const char* outputName) {
if (req == nullptr) {
return reinterpret_cast<OVMS_Status*>(new Status(StatusCode::NONEXISTENT_PTR, "inference request"));
}
if (outputName == nullptr) {
return reinterpret_cast<OVMS_Status*>(new Status(StatusCode::NONEXISTENT_PTR, "output name"));
}
InferenceRequest* request = reinterpret_cast<InferenceRequest*>(req);
auto status = request->removeOutputBuffer(outputName);
if (!status.ok()) {
return reinterpret_cast<OVMS_Status*>(new Status(std::move(status)));
}
return nullptr;
}

DLL_PUBLIC OVMS_Status* OVMS_InferenceResponseOutput(OVMS_InferenceResponse* res, uint32_t id, const char** name, OVMS_DataType* datatype, const int64_t** shape, size_t* dimCount, const void** data, size_t* bytesize, OVMS_BufferType* bufferType, uint32_t* deviceId) {
if (res == nullptr) {
return reinterpret_cast<OVMS_Status*>(new Status(StatusCode::NONEXISTENT_PTR, "inference response"));
@@ -945,6 +975,7 @@ DLL_PUBLIC OVMS_Status* OVMS_Inference(OVMS_Server* serverPtr, OVMS_InferenceReq
}

if (!status.ok()) {
// TODO fixme error handling with callbacks - we may need to move callback usage here
return reinterpret_cast<OVMS_Status*>(new Status(std::move(status)));
}

14 changes: 14 additions & 0 deletions src/capi_frontend/inferencerequest.cpp
@@ -60,6 +60,13 @@ Status InferenceRequest::setOutputBuffer(const char* name, const void* addr, siz
}
return it->second.setBuffer(addr, byteSize, bufferType, deviceId);
}
Status InferenceRequest::removeOutputBuffer(const char* name) {
auto it = outputs.find(name);
Collaborator: Does it make sense to check name pointer validity first?
Collaborator (Author): It's checked in capi.cpp.

if (it == outputs.end()) {
return StatusCode::NONEXISTENT_TENSOR_FOR_REMOVE_BUFFER;
}
return it->second.removeBuffer();
}
Status InferenceRequest::removeInputBuffer(const char* name) {
auto it = inputs.find(name);
if (it == inputs.end()) {
@@ -99,6 +106,13 @@ Status InferenceRequest::removeInput(const char* name) {
}
return StatusCode::NONEXISTENT_TENSOR_FOR_REMOVAL;
}
Status InferenceRequest::removeOutput(const char* name) {
Collaborator: Same as above.
Collaborator (Author): It's checked in capi.cpp.

auto count = outputs.erase(name);
if (count) {
return StatusCode::OK;
}
return StatusCode::NONEXISTENT_TENSOR_FOR_REMOVAL;
}
Status InferenceRequest::addParameter(const char* parameterName, OVMS_DataType datatype, const void* data) {
auto [it, emplaced] = parameters.emplace(parameterName, InferenceParameter{parameterName, datatype, data});
return emplaced ? StatusCode::OK : StatusCode::DOUBLE_PARAMETER_INSERT;
3 changes: 2 additions & 1 deletion src/capi_frontend/inferencerequest.hpp
@@ -47,12 +47,13 @@ class InferenceRequest {
Status getOutput(const char* name, const InferenceTensor** tensor) const;
uint64_t getInputsSize() const;
Status removeInput(const char* name);
Status removeOutput(const char* name);
Status removeAllInputs();

Status setInputBuffer(const char* name, const void* addr, size_t byteSize, OVMS_BufferType, std::optional<uint32_t> deviceId);
Status setOutputBuffer(const char* name, const void* addr, size_t byteSize, OVMS_BufferType, std::optional<uint32_t> deviceId);
// TODO TBD add equivalent for outputs?
Status removeInputBuffer(const char* name);
Status removeOutputBuffer(const char* name);

Status addParameter(const char* parameterName, OVMS_DataType datatype, const void* data);
Status removeParameter(const char* parameterName);
17 changes: 12 additions & 5 deletions src/executingstreamidguard.cpp
@@ -30,20 +30,27 @@ ExecutingStreamIdGuard::CurrentRequestsMetricGuard::~CurrentRequestsMetricGuard(
}

ExecutingStreamIdGuard::ExecutingStreamIdGuard(OVInferRequestsQueue& inferRequestsQueue, ModelMetricReporter& reporter) :
StreamIdGuard(inferRequestsQueue),
currentRequestsMetricGuard(reporter),
inferRequestsQueue_(inferRequestsQueue),
id_(inferRequestsQueue_.getIdleStream().get()),
inferRequest(inferRequestsQueue.getInferRequest(id_)),
reporter(reporter) {
INCREMENT_IF_ENABLED(this->reporter.inferReqActive);
}

ExecutingStreamIdGuard::~ExecutingStreamIdGuard() {
DECREMENT_IF_ENABLED(this->reporter.inferReqActive);
}

StreamIdGuard::StreamIdGuard(OVInferRequestsQueue& inferRequestsQueue) :
inferRequestsQueue_(inferRequestsQueue),
id_(inferRequestsQueue_.getIdleStream().get()),
inferRequest(inferRequestsQueue.getInferRequest(id_)) {
}

StreamIdGuard::~StreamIdGuard() {
this->inferRequestsQueue_.returnStream(this->id_);
}

int ExecutingStreamIdGuard::getId() { return this->id_; }
ov::InferRequest& ExecutingStreamIdGuard::getInferRequest() { return this->inferRequest; }
int StreamIdGuard::getId() { return this->id_; }
ov::InferRequest& StreamIdGuard::getInferRequest() { return this->inferRequest; }

} // namespace ovms
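The split above leaves StreamIdGuard usable on its own: it borrows an idle stream id (and its ov::InferRequest) and returns it on destruction, without touching the inference metrics that ExecutingStreamIdGuard maintains. A minimal usage sketch, mirroring how checkForOutputTensorResetAbility in modelinstance.cpp below uses it; the free function name and the queue header path are assumptions:

```cpp
#include "executingstreamidguard.hpp"
#include "ovinferrequestsqueue.hpp"  // assumed header declaring ovms::OVInferRequestsQueue

void inspectOutputs(ovms::OVInferRequestsQueue& queue) {
    ovms::StreamIdGuard guard(queue);                     // blocks until an idle stream id is free
    ov::InferRequest& request = guard.getInferRequest();  // request bound to that stream id
    // ... read-only inspection, e.g. request.get_tensor(name) ...
}                                                         // ~StreamIdGuard() returns the stream id
```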
19 changes: 11 additions & 8 deletions src/executingstreamidguard.hpp
@@ -25,12 +25,19 @@ class OVInferRequestsQueue;
class ModelMetricReporter;
class OVInferRequestsQueue;

struct ExecutingStreamIdGuard {
ExecutingStreamIdGuard(ovms::OVInferRequestsQueue& inferRequestsQueue, ModelMetricReporter& reporter);
~ExecutingStreamIdGuard();

struct StreamIdGuard {
StreamIdGuard(ovms::OVInferRequestsQueue& inferRequestsQueue);
~StreamIdGuard();
int getId();
ov::InferRequest& getInferRequest();
OVInferRequestsQueue& inferRequestsQueue_;
const int id_;
ov::InferRequest& inferRequest;
};

struct ExecutingStreamIdGuard : public StreamIdGuard {
ExecutingStreamIdGuard(ovms::OVInferRequestsQueue& inferRequestsQueue, ModelMetricReporter& reporter);
~ExecutingStreamIdGuard();

private:
class CurrentRequestsMetricGuard {
@@ -40,11 +47,7 @@ struct ExecutingStreamIdGuard {
CurrentRequestsMetricGuard(ModelMetricReporter& reporter);
~CurrentRequestsMetricGuard();
};

CurrentRequestsMetricGuard currentRequestsMetricGuard;
OVInferRequestsQueue& inferRequestsQueue_;
const int id_;
ov::InferRequest& inferRequest;
ModelMetricReporter& reporter;
};

57 changes: 56 additions & 1 deletion src/modelinstance.cpp
@@ -964,6 +964,7 @@ Status ModelInstance::loadModelImpl(const ModelConfig& config, const DynamicMode
this->status.setLoading(ModelVersionStatusErrorCode::UNKNOWN);
return status;
}
this->checkForOutputTensorResetAbility();
this->loadTensorFactories();
} catch (const ov::Exception& e) {
SPDLOG_ERROR("exception occurred while loading model: {}", e.what());
@@ -1307,6 +1308,31 @@ void handleCallback(const InferenceRequest* request, InferenceResponse* response
}
}

// Captures the original output tensors of an ov::InferRequest on construction and
// restores them on destruction, so output buffers set through the C-API for one
// request do not leak into subsequent inferences that do not set their own outputs.
struct OutputKeeper {
std::unordered_map<std::string, ov::Tensor> outputs;
ov::InferRequest& request;
OutputKeeper(ov::InferRequest& request, const tensor_map_t& outputsInfo) :
request(request) {
for (auto [name, _] : outputsInfo) {
OV_LOGGER("ov::InferRequest: {}, request.get_tensor({})", reinterpret_cast<void*>(&request), name);
try {
ov::Tensor tensor = request.get_tensor(name);
OV_LOGGER("ov::Tensor(): {}", reinterpret_cast<void*>(&tensor));
outputs.emplace(std::make_pair(name, std::move(tensor)));
OV_LOGGER("ov::Tensor(ov::Tensor&&): {}", reinterpret_cast<void*>(&outputs.at(name)));
} catch (std::exception& e) {
SPDLOG_DEBUG("Resetting output:{}; for this model is not supported. Check C-API documentation for OVMS_InferenceRequestOutputSetData. Error:{}", name, e.what());
}
}
}
~OutputKeeper() {
for (auto [name, v] : outputs) {
OV_LOGGER("ov::InferRequest: {}, request.set_tensor({}, {})", reinterpret_cast<void*>(&request), name, reinterpret_cast<void*>(&v));
request.set_tensor(name, v);
}
}
};

template <typename RequestType, typename ResponseType>
Status ModelInstance::infer(const RequestType* requestProto,
ResponseType* responseProto,
@@ -1356,6 +1382,11 @@ Status ModelInstance::infer(const RequestType* requestProto,
timer.start(DESERIALIZE);
InputSink<ov::InferRequest&> inputSink(inferRequest);
bool isPipeline = false;

std::unique_ptr<OutputKeeper> outKeeper;
if (this->doesSupportOutputReset()) {
outKeeper = std::make_unique<OutputKeeper>(executingStreamIdGuard.getInferRequest(), getOutputsInfo());
}
status = deserializePredictRequest<ConcreteTensorProtoDeserializator, InputSink<ov::InferRequest&>>(*requestProto, getInputsInfo(), getOutputsInfo(), inputSink, isPipeline, this->tensorFactories);
timer.stop(DESERIALIZE);
if (!status.ok()) {
@@ -1397,6 +1428,26 @@ Status ModelInstance::infer(const RequestType* requestProto,
// handleCallback(requestProto, responseProto); to be enabled when callbacks are implemented in network API's
return status;
}
void ModelInstance::checkForOutputTensorResetAbility() {
StreamIdGuard guard(getInferRequestsQueue());
auto request = guard.getInferRequest();
bool allOutputsSupported = true;
for (auto [name, _] : getOutputsInfo()) {
try {
ov::Tensor tensor = request.get_tensor(name);
Collaborator: Is it possible that we capture an input name here, if the model has an input with the same name?
Collaborator (Author): The output name is logged in L1442/L1439.

} catch (std::exception& e) {
SPDLOG_LOGGER_WARN(modelmanager_logger, "Resetting output:{}; for model:{}; version:{}, is not supported. Check C-API documentation for OVMS_InferenceRequestOutputSetData. Error:{}", name, getName(), getVersion(), e.what());
allOutputsSupported = false;
} catch (...) {
SPDLOG_LOGGER_WARN(modelmanager_logger, "Resetting output:{}; for model:{}; version:{}, is not supported. Check C-API documentation for OVMS_InferenceRequestOutputSetData.", name, getName(), getVersion());
allOutputsSupported = false;
}
}
this->supportOutputTensorsReset = allOutputsSupported;
}
bool ModelInstance::doesSupportOutputReset() const {
return this->supportOutputTensorsReset;
}

#pragma GCC diagnostic pop
template <typename RequestType, typename ResponseType>
@@ -1448,6 +1499,10 @@ Status ModelInstance::inferAsync(const RequestType* requestProto,
timer.start(DESERIALIZE);
InputSink<ov::InferRequest&> inputSink(inferRequest);
bool isPipeline = false;
std::shared_ptr<OutputKeeper> outKeeper;
if (this->doesSupportOutputReset()) {
outKeeper = std::make_shared<OutputKeeper>(executingStreamIdGuard.getInferRequest(), getOutputsInfo());
}
status = deserializePredictRequest<ConcreteTensorProtoDeserializator, InputSink<ov::InferRequest&>>(*requestProto, getInputsInfo(), getOutputsInfo(), inputSink, isPipeline, this->tensorFactories);
timer.stop(DESERIALIZE);
if (!status.ok()) {
@@ -1467,7 +1522,7 @@ Status ModelInstance::inferAsync(const RequestType* requestProto,
void* userCallbackData = requestProto->getResponseCompleteCallbackData();
// here pass by copy into callback
{
inferRequest.set_callback([this, requestProto, &inferRequest, userCallback, userCallbackData, modelUnloadGuardPtrMoved = std::shared_ptr<ModelInstanceUnloadGuard>(std::move(modelUnloadGuardPtr))](std::exception_ptr exception) mutable {
inferRequest.set_callback([this, requestProto, &inferRequest, movedOutputKeeper = std::move(outKeeper), userCallback, userCallbackData, modelUnloadGuardPtrMoved = std::shared_ptr<ModelInstanceUnloadGuard>(std::move(modelUnloadGuardPtr))](std::exception_ptr exception) mutable {
SPDLOG_DEBUG("Entry of ov::InferRequest callback call");
if (exception) {
try {
6 changes: 6 additions & 0 deletions src/modelinstance.hpp
@@ -312,6 +312,12 @@ class ModelInstance {
virtual Status loadInputTensorsImpl(const ModelConfig& config, const DynamicModelParameter& parameter = DynamicModelParameter());

private:
/**
* @brief Determines whether the ov::InferRequest output tensors can be reset to their original state during inference. This is required for the output-setting functionality to remain interoperable between inferences that set outputs and those that do not.
*/
void checkForOutputTensorResetAbility();
bool supportOutputTensorsReset = true;
bool doesSupportOutputReset() const;
Status gatherReshapeInfo(bool isBatchingModeAuto, const DynamicModelParameter& parameter, bool& isReshapeRequired, std::map<std::string, ov::PartialShape>& modelShapes);

/**
14 changes: 14 additions & 0 deletions src/ovms.h
@@ -475,13 +475,27 @@ OVMS_Status* OVMS_InferenceRequestOutputSetData(OVMS_InferenceRequest* request,
// \return OVMS_Status object in case of failure
OVMS_Status* OVMS_InferenceRequestInputRemoveData(OVMS_InferenceRequest* request, const char* inputName);

// Remove the data of the output.
//
// \param request The request object
// \param outputName The name of the output with data to be removed
// \return OVMS_Status object in case of failure
OVMS_Status* OVMS_InferenceRequestOutputRemoveData(OVMS_InferenceRequest* request, const char* outputName);

// Remove input from the request.
//
// \param request The request object
// \param inputName The name of the input to be removed
// \return OVMS_Status object in case of failure
OVMS_Status* OVMS_InferenceRequestRemoveInput(OVMS_InferenceRequest* request, const char* inputName);

// Remove output from the request.
//
// \param request The request object
// \param outputName The name of the output to be removed
// \return OVMS_Status object in case of failure
OVMS_Status* OVMS_InferenceRequestRemoveOutput(OVMS_InferenceRequest* request, const char* outputName);
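Together with OVMS_InferenceRequestOutputSetData above, the two new calls let a client hand the server a caller-owned output buffer for one inference and then revert to server-allocated outputs afterwards. A sketch of that flow; the parameter order of OVMS_InferenceRequestOutputSetData is assumed to mirror OVMS_InferenceRequestInputSetData, the response-delete call and the output name "prob" are assumptions, and error checking is elided for brevity:

```cpp
#include <cstddef>
#include "ovms.h"

// Illustrative only: function name, output name, and buffer are placeholders.
void runWithUserOutputThenReset(OVMS_Server* server, OVMS_InferenceRequest* request,
                                void* userBuffer, size_t userBufferByteSize) {
    OVMS_InferenceResponse* response = nullptr;
    // 1) Bind the output to a caller-owned buffer before inference.
    OVMS_InferenceRequestOutputSetData(request, "prob", userBuffer, userBufferByteSize,
                                       OVMS_BUFFERTYPE_CPU, 0 /*deviceId*/);
    // 2) Run the inference; results for "prob" land in userBuffer.
    OVMS_Inference(server, request, &response);
    OVMS_InferenceResponseDelete(response);  // assumed cleanup call from ovms.h
    // 3) Detach the caller-owned buffer so the next inference allocates its own output...
    OVMS_InferenceRequestOutputRemoveData(request, "prob");
    // 4) ...or drop the output entry from the request entirely.
    OVMS_InferenceRequestRemoveOutput(request, "prob");
}
```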

// Add parameter to the request.
//
// \param request The request object
1 change: 1 addition & 0 deletions src/predict_request_validation_utils.cpp
@@ -1051,6 +1051,7 @@ Status RequestValidator<RequestType, InputTensorType, IteratorType, ShapeType>::
RETURN_IF_ERR(checkBatchSizeMismatch(proto, inputInfo->getBatchSize(), batchIndex, finalStatus, batchingMode, shapeMode));
RETURN_IF_ERR(checkShapeMismatch(proto, *inputInfo, batchIndex, finalStatus, batchingMode, shapeMode));
RETURN_IF_ERR(validateTensorContent(proto, inputInfo->getPrecision(), bufferId));
// TODO FIXME we need validation for output for C-API (if buffer is empty, shape against buffersize, type)
}
return finalStatus;
}
1 change: 1 addition & 0 deletions src/test/c_api_test_utils.hpp
@@ -93,6 +93,7 @@
ASSERT_NE(err, nullptr); \
} \
}

struct ServerSettingsGuard {
ServerSettingsGuard(int port) {
THROW_ON_ERROR_CAPI(OVMS_ServerSettingsNew(&settings));