From b636b275aa88f77859dd13dac78984ab5b83de50 Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Sat, 21 Sep 2024 20:41:56 -0700
Subject: [PATCH] Fix an issue where QNN models shared from another session
 use that session's logger (#22170)

### Description
Fix an issue where a QNN model shared from another session also uses the logger of that producer session, which causes confusion. Make the QNN model compute function use the logger of the current session.
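The gist of the change, as a minimal self-contained sketch (the `Logger`, `ModelBefore`, and `ModelAfter` types below are hypothetical stand-ins for illustration, not ONNX Runtime classes): `QnnModel` no longer stores the logger it was constructed with; instead, `ComposeGraph`, `FinalizeGraphs`, `SetupQnnInputOutput`, and `ExecuteGraph` take the current session's logger as a parameter, so a model shared across sessions logs to whichever session invokes it.

```cpp
#include <iostream>
#include <string>

// Hypothetical stand-in for onnxruntime::logging::Logger.
struct Logger {
  std::string session_id;
  void Log(const std::string& msg) const {
    std::cout << "[" << session_id << "] " << msg << "\n";
  }
};

// Before: the model binds to the producer session's logger at construction.
struct ModelBefore {
  const Logger& logger_;  // stays bound to session 1 even when session 2 runs it
  explicit ModelBefore(const Logger& logger) : logger_(logger) {}
  void Execute() const { logger_.Log("execute"); }
};

// After: the caller passes the current session's logger on every call.
struct ModelAfter {
  void Execute(const Logger& logger) const { logger.Log("execute"); }
};

int main() {
  Logger so1{"so1"}, so2{"so2"};

  ModelBefore shared_before{so1};  // compiled by session 1, then shared
  shared_before.Execute();         // prints: [so1] execute
  shared_before.Execute();         // prints: [so1] execute -- even for session 2's run

  ModelAfter shared_after;
  shared_after.Execute(so1);       // prints: [so1] execute
  shared_after.Execute(so2);       // prints: [so2] execute -- logs follow the caller
  return 0;
}
```

Passing the logger per call keeps a shared `QnnModel` stateless with respect to sessions; the compute function then captures the logger of the session that compiled the kernel, which is what the `[&logger]` capture in `CreateComputeFunc` below does.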
---
 .../qnn/builder/onnx_ctx_model_helper.cc      |  5 +-
 .../qnn/builder/onnx_ctx_model_helper.h       |  1 -
 .../qnn/builder/qnn_backend_manager.cc        |  5 +-
 .../qnn/builder/qnn_backend_manager.h         |  1 -
 .../core/providers/qnn/builder/qnn_model.cc   | 73 ++++++++++---------
 .../core/providers/qnn/builder/qnn_model.h    | 20 ++---
 .../providers/qnn/qnn_execution_provider.cc   | 21 +++---
 .../test/providers/qnn/qnn_ep_context_test.cc | 15 ++--
 8 files changed, 72 insertions(+), 69 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 8ba2c6170b96c..57ae8c354abb7 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -87,7 +87,6 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
 Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
                                 const onnxruntime::PathString& ctx_onnx_model_path,
                                 QnnBackendManager* qnn_backend_manager,
-                                const logging::Logger& logger,
                                 QnnModelLookupTable& qnn_models) {
   ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
   NodeAttrHelper node_helper(main_context_node);
@@ -97,7 +96,6 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
     return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
                                                                static_cast<uint64_t>(context_binary.length()),
                                                                main_context_node.Name(),
-                                                               logger,
                                                                qnn_models);
   }
 
@@ -147,7 +145,6 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
   return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
                                                              static_cast<uint64_t>(buffer_size),
                                                              main_context_node.Name(),
-                                                             logger,
                                                              qnn_models);
 }
 
@@ -158,7 +155,7 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
                                const logging::Logger& logger) {
   ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!");
   Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager,
-                                           logger, qnn_models);
+                                           qnn_models);
 
   // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
   if (!status.IsOK()) {
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
index 4ff7618b486e2..f308a7456d46c 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -49,7 +49,6 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
 Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
                                 const onnxruntime::PathString& ctx_onnx_model_path,
                                 QnnBackendManager* qnn_backend_manager,
-                                const logging::Logger& logger,
                                 QnnModelLookupTable& qnn_models);
 
 Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
index dde70fdcbdaa6..db5c2c5cb32ba 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -608,7 +608,6 @@ std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint6
 
 Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
                                                          std::string node_name,
-                                                         const logging::Logger& logger,
                                                          QnnModelLookupTable& qnn_models) {
   bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
                 nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
@@ -665,12 +664,12 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
   if (1 == graph_count) {
     // in case the EPContext node is generated from script
     // the graph name from the context binary may not match the EPContext node name
-    auto qnn_model = std::make_unique<qnn::QnnModel>(logger, this);
+    auto qnn_model = std::make_unique<qnn::QnnModel>(this);
     ORT_RETURN_IF_ERROR(qnn_model->DeserializeGraphInfoFromBinaryInfo(graphs_info[0], context));
     qnn_models.emplace(node_name, std::move(qnn_model));
   } else {
     for (uint32_t i = 0; i < graph_count; ++i) {
-      auto qnn_model = std::make_unique<qnn::QnnModel>(logger, this);
+      auto qnn_model = std::make_unique<qnn::QnnModel>(this);
       ORT_RETURN_IF_ERROR(qnn_model->DeserializeGraphInfoFromBinaryInfo(graphs_info[i], context));
       qnn_models.emplace(graphs_info[i].graphInfoV1.graphName, std::move(qnn_model));
     }
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
index d1a3b46a8fc55..b80f1374fcdc7 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -91,7 +91,6 @@ class QnnBackendManager {
 
   Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
                                         std::string node_name,
-                                        const logging::Logger& logger,
                                         std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);
 
   Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
index a09b1daa81726..f322456e0c8f0 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
@@ -17,7 +17,7 @@
 namespace onnxruntime {
 namespace qnn {
 
-bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper) {
+bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper, const logging::Logger& logger) {
   bool rt = true;
 
   graph_info_ = std::make_unique<GraphInfo>(model_wrapper.GetQnnGraph(),
@@ -25,7 +25,7 @@ bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper) {
                                             std::move(model_wrapper.GetGraphInputTensorWrappers()),
                                             std::move(model_wrapper.GetGraphOutputTensorWrappers()));
   if (graph_info_ == nullptr) {
-    LOGS(logger_, ERROR) << "GetGraphInfoFromModel() failed to allocate GraphInfo.";
+    LOGS(logger, ERROR) << "GetGraphInfoFromModel() failed to allocate GraphInfo.";
     return false;
   }
 
@@ -33,16 +33,19 @@ bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper) {
 }
 
 Status QnnModel::SetGraphInputOutputInfo(const GraphViewer& graph_viewer,
-                                         const onnxruntime::Node& fused_node) {
+                                         const onnxruntime::Node& fused_node,
+                                         const logging::Logger& logger) {
   auto graph_initializers = graph_viewer.GetAllInitializedTensors();
   for (auto graph_ini : graph_initializers) {
     initializer_inputs_.emplace(graph_ini.first);
   }
   auto input_defs = fused_node.InputDefs();
-  ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(input_defs, input_names_, inputs_info_, model_input_index_map_, true));
+  ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(input_defs, input_names_, inputs_info_,
+                                              model_input_index_map_, logger, true));
 
   auto output_defs = fused_node.OutputDefs();
-  ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(output_defs, output_names_, outputs_info_, model_output_index_map_));
+  ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(output_defs, output_names_, outputs_info_,
+                                              model_output_index_map_, logger));
 
   return Status::OK();
 }
@@ -51,6 +54,7 @@ Status QnnModel::ParseGraphInputOrOutput(ConstPointerContainer<std::vector<Node
                                          std::vector<std::string>& input_output_names,
                                          std::unordered_map<std::string, OnnxTensorInfo>& input_output_info_table,
                                          std::unordered_map<std::string, size_t>& input_output_index_map,
+                                         const logging::Logger& logger,
                                          bool is_input) {
   for (size_t i = 0, end = input_output_defs.size(), index = 0; i < end; ++i) {
     const auto& name = input_output_defs[i]->Name();
@@ -60,7 +64,7 @@ Status QnnModel::ParseGraphInputOrOutput(ConstPointerContainer<std::vector<Node
       continue;
     }
 
-    LOGS(logger_, VERBOSE) << (is_input ? "input " : "output ") << i << " " << name;
+    LOGS(logger, VERBOSE) << (is_input ? "input " : "output ") << i << " " << name;
     input_output_index_map.emplace(name, index++);
     const auto* shape_proto = input_output_defs[i]->Shape();  // consider use qnn_model_wrapper.GetOnnxShape
     ORT_RETURN_IF(shape_proto == nullptr, "shape_proto cannot be null for output: ", name);
@@ -91,8 +95,9 @@ const NodeUnit& QnnModel::GetNodeUnit(const Node* node,
 
 Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
                               const onnxruntime::Node& fused_node,
+                              const logging::Logger& logger,
                               const QnnGraph_Config_t** graph_configs) {
-  LOGS(logger_, VERBOSE) << "ComposeGraph Graph name: " << graph_viewer.Name();
+  LOGS(logger, VERBOSE) << "ComposeGraph Graph name: " << graph_viewer.Name();
 
   // Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is
   // valid throughout the lifetime of the ModelBuilder
@@ -102,9 +107,9 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
   // This name must be same with the EPContext node name
   const auto& graph_name = fused_node.Name();
 
-  ORT_RETURN_IF_ERROR(SetGraphInputOutputInfo(graph_viewer, fused_node));
+  ORT_RETURN_IF_ERROR(SetGraphInputOutputInfo(graph_viewer, fused_node, logger));
 
-  QnnModelWrapper qnn_model_wrapper = QnnModelWrapper(graph_viewer, logger_,
+  QnnModelWrapper qnn_model_wrapper = QnnModelWrapper(graph_viewer, logger,
                                                       qnn_backend_manager_->GetQnnInterface(),
                                                       qnn_backend_manager_->GetQnnBackendHandle(),
                                                       model_input_index_map_,
@@ -121,65 +126,65 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
   qnn_node_groups.reserve(node_unit_holder.size());
 
   ORT_RETURN_IF_ERROR(qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper, node_unit_map,
-                                            node_unit_holder.size(), logger_));
+                                            node_unit_holder.size(), logger));
 
   for (const std::unique_ptr<qnn::IQnnNodeGroup>& qnn_node_group : qnn_node_groups) {
-    Status status = qnn_node_group->AddToModelBuilder(qnn_model_wrapper, logger_);
+    Status status = qnn_node_group->AddToModelBuilder(qnn_model_wrapper, logger);
 
     if (!status.IsOK()) {
-      LOGS(logger_, ERROR) << "[QNN EP] Failed to add supported node to QNN graph during EP's compile call: "
-                           << status.ErrorMessage() << std::endl;
+      LOGS(logger, ERROR) << "[QNN EP] Failed to add supported node to QNN graph during EP's compile call: "
+                          << status.ErrorMessage() << std::endl;
       return status;
     }
   }
 
   ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph.");
 
-  rt = GetGraphInfoFromModel(qnn_model_wrapper);
+  rt = GetGraphInfoFromModel(qnn_model_wrapper, logger);
   if (!rt) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GetGraphInfoFromModel failed.");
   }
-  LOGS(logger_, VERBOSE) << "GetGraphInfoFromModel completed.";
+  LOGS(logger, VERBOSE) << "GetGraphInfoFromModel completed.";
   return Status::OK();
 }
 
-Status QnnModel::FinalizeGraphs() {
-  LOGS(logger_, VERBOSE) << "FinalizeGraphs started.";
+Status QnnModel::FinalizeGraphs(const logging::Logger& logger) {
+  LOGS(logger, VERBOSE) << "FinalizeGraphs started.";
   Qnn_ErrorHandle_t status = qnn_backend_manager_->GetQnnInterface().graphFinalize(graph_info_->Graph(),
                                                                                    qnn_backend_manager_->GetQnnProfileHandle(),
                                                                                    nullptr);
   if (QNN_GRAPH_NO_ERROR != status) {
-    LOGS(logger_, ERROR) << "Failed to finalize QNN graph. Error code: " << status;
+    LOGS(logger, ERROR) << "Failed to finalize QNN graph. Error code: " << status;
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to finalize QNN graph.");
   }
 
   ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo());
 
-  LOGS(logger_, VERBOSE) << "FinalizeGraphs completed.";
+  LOGS(logger, VERBOSE) << "FinalizeGraphs completed.";
   return Status::OK();
 }
 
-Status QnnModel::SetupQnnInputOutput() {
-  LOGS(logger_, VERBOSE) << "Setting up QNN input/output for graph: " << graph_info_->Name();
+Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) {
+  LOGS(logger, VERBOSE) << "Setting up QNN input/output for graph: " << graph_info_->Name();
 
   auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors());
 
   if (Status::OK() != result) {
-    LOGS(logger_, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
+    LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN input tensors!");
   }
 
   result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false);
   if (Status::OK() != result) {
-    LOGS(logger_, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
+    LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN output tensors!");
   }
 
   return Status::OK();
 }
 
-Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {
-  LOGS(logger_, VERBOSE) << "QnnModel::ExecuteGraphs";
+Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging::Logger& logger) {
+  LOGS(logger, VERBOSE) << "QnnModel::ExecuteGraphs";
   const size_t num_inputs = context.GetInputCount();
   const size_t num_outputs = context.GetOutputCount();
   ORT_RETURN_IF_NOT(qnn_input_infos_.size() <= num_inputs, "Inconsistent input sizes");
@@ -198,12 +203,12 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {
   qnn_inputs.reserve(qnn_input_infos_.size());
 
   for (const auto& qnn_input_info : qnn_input_infos_) {
-    LOGS(logger_, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName()
-                           << " index = " << qnn_input_info.ort_index;
+    LOGS(logger, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName()
+                          << " index = " << qnn_input_info.ort_index;
     auto ort_input_tensor = context.GetInput(qnn_input_info.ort_index);
     auto ort_tensor_size = TensorDataSize(ort_input_tensor);
-    LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size
-                           << "Ort tensor size: " << ort_tensor_size;
+    LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size
+                          << "Ort tensor size: " << ort_tensor_size;
     ORT_RETURN_IF_NOT(qnn_input_info.tensor_byte_size == ort_tensor_size,
                       "ORT Tensor data size does not match QNN tensor data size.");
@@ -217,13 +222,13 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {
 
   for (auto& qnn_output_info : qnn_output_infos_) {
     const std::string& model_output_name = qnn_output_info.tensor_wrapper->GetName();
-    LOGS(logger_, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index;
+    LOGS(logger, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index;
     const auto& ort_output_info = GetOutputInfo(model_output_name);
     const std::vector<int64_t>& output_shape = ort_output_info->shape_;
     auto ort_output_tensor = context.GetOutput(qnn_output_info.ort_index, output_shape.data(), output_shape.size());
     auto ort_tensor_size = TensorDataSize(ort_output_tensor);
-    LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size
-                           << "Ort tensor size: " << ort_tensor_size;
+    LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size
+                          << "Ort tensor size: " << ort_tensor_size;
     ORT_RETURN_IF_NOT(qnn_output_info.tensor_byte_size == ort_tensor_size,
                       "ORT Tensor data size does not match QNN tensor data size");
@@ -232,7 +237,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {
                           const_cast<void*>(ort_output_tensor.GetTensorData<void>()), qnn_output_info.tensor_byte_size);
   }
 
-  LOGS(logger_, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name();
+  LOGS(logger, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name();
   auto qnn_interface = qnn_backend_manager_->GetQnnInterface();
   auto profile_backend_handle = qnn_backend_manager_->GetQnnProfileHandle();
   Qnn_ErrorHandle_t execute_status = QNN_GRAPH_NO_ERROR;
@@ -257,7 +262,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {
 
   if (QNN_COMMON_ERROR_SYSTEM_COMMUNICATION == execute_status) {
     auto error_message = "NPU crashed. SSR detected. Caused QNN graph execute error. Error code: ";
-    LOGS(logger_, ERROR) << error_message << execute_status;
+    LOGS(logger, ERROR) << error_message << execute_status;
     return ORT_MAKE_STATUS(ONNXRUNTIME, ENGINE_ERROR, error_message, execute_status);
   }
 
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h
index 1416d9ba92671..83cf8f9f08fb0 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h
@@ -25,10 +25,8 @@ struct QnnTensorInfo {
 
 class QnnModel {
  public:
-  QnnModel(const logging::Logger& logger,
-           QnnBackendManager* qnn_backend_manager)
-      : logger_(logger),
-        qnn_backend_manager_(qnn_backend_manager) {
+  QnnModel(QnnBackendManager* qnn_backend_manager)
+      : qnn_backend_manager_(qnn_backend_manager) {
     qnn_backend_type_ = qnn_backend_manager_->GetQnnBackendType();
   }
 
@@ -37,13 +35,14 @@ class QnnModel {
 
   Status ComposeGraph(const GraphViewer& graph_viewer,
                       const onnxruntime::Node& fused_node,
+                      const logging::Logger& logger,
                       const QnnGraph_Config_t** graph_configs = nullptr);
 
-  Status FinalizeGraphs();
+  Status FinalizeGraphs(const logging::Logger& logger);
 
-  Status SetupQnnInputOutput();
+  Status SetupQnnInputOutput(const logging::Logger& logger);
 
-  Status ExecuteGraph(const Ort::KernelContext& context);
+  Status ExecuteGraph(const Ort::KernelContext& context, const logging::Logger& logger);
 
   const OnnxTensorInfo* GetOutputInfo(const std::string& name) const {
     auto it = outputs_info_.find(name);
@@ -55,11 +54,13 @@ class QnnModel {
   }
 
   Status SetGraphInputOutputInfo(const GraphViewer& graph_viewer,
-                                 const onnxruntime::Node& fused_node);
+                                 const onnxruntime::Node& fused_node,
+                                 const logging::Logger& logger);
 
   Status ParseGraphInputOrOutput(ConstPointerContainer<std::vector<NodeArg*>>& input_output_defs,
                                  std::vector<std::string>& input_output_names,
                                  std::unordered_map<std::string, OnnxTensorInfo>& input_output_info_table,
                                  std::unordered_map<std::string, size_t>& input_output_index,
+                                 const logging::Logger& logger,
                                  bool is_input = false);
 
   const std::unordered_set<std::string>& GetInitializerInputs() const { return initializer_inputs_; }
@@ -107,7 +108,7 @@ class QnnModel {
  private:
   const NodeUnit& GetNodeUnit(const Node* node,
                               const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map) const;
-  bool GetGraphInfoFromModel(QnnModelWrapper& model_wrapper);
+  bool GetGraphInfoFromModel(QnnModelWrapper& model_wrapper, const logging::Logger& logger);
 
   Status GetQnnTensorDataLength(const std::vector<uint32_t>& dims,
                                 Qnn_DataType_t data_type,
@@ -125,7 +126,6 @@ class QnnModel {
   }
 
 private:
-  const logging::Logger& logger_;
   std::unique_ptr<GraphInfo> graph_info_;
   QnnBackendManager* qnn_backend_manager_ = nullptr;
   // <node_name, input_index>, initializer inputs are excluded, keep the input index here
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index f2991df3b1b8e..698ceaea7c3b7 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -789,10 +789,10 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector<NodeComputeInfo>& nod
     ORT_UNUSED_PARAMETER(state);
   };
 
-  compute_info.compute_func = [](FunctionState state, const OrtApi*, OrtKernelContext* context) {
+  compute_info.compute_func = [&logger](FunctionState state, const OrtApi*, OrtKernelContext* context) {
     Ort::KernelContext ctx(context);
     qnn::QnnModel* model = reinterpret_cast<qnn::QnnModel*>(state);
-    Status result = model->ExecuteGraph(ctx);
+    Status result = model->ExecuteGraph(ctx, logger);
     return result;
   };
 
@@ -843,16 +843,15 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector<FusedNodeAn
   for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
     Node& fused_node = fused_node_and_graph.fused_node;
     const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
-    std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger,
-                                                                               qnn_backend_manager_.get());
+    std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(qnn_backend_manager_.get());
 
     qnn::QnnConfigsBuilder<QnnGraph_Config_t, QnnHtpGraph_CustomConfig_t> graph_configs_builder(QNN_GRAPH_CONFIG_INIT,
                                                                                                 QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT);
     InitQnnGraphConfigs(graph_configs_builder);
 
-    ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, graph_configs_builder.GetQnnConfigs()));
-    ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs());
-    ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
+    ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, logger, graph_configs_builder.GetQnnConfigs()));
+    ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs(logger));
+    ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput(logger));
 
     LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name();
     qnn_models_.emplace(fused_node.Name(), std::move(qnn_model));
@@ -894,8 +893,8 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
       std::string key = ep_context_node->Name();
       auto qnn_model_shared = SharedContext::GetInstance().GetSharedQnnModel(key);
      ORT_RETURN_IF(nullptr == qnn_model_shared, "Graph: " + key + " not found from shared EP contexts.");
-      ORT_RETURN_IF_ERROR(qnn_model_shared->SetGraphInputOutputInfo(graph_viewer, fused_node));
-      ORT_RETURN_IF_ERROR(qnn_model_shared->SetupQnnInputOutput());
+      ORT_RETURN_IF_ERROR(qnn_model_shared->SetGraphInputOutputInfo(graph_viewer, fused_node, logger));
+      ORT_RETURN_IF_ERROR(qnn_model_shared->SetupQnnInputOutput(logger));
       qnn_models_shared_.emplace(graph_meta_id, qnn_model_shared);
       use_shared_model_ = true;
       ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
@@ -929,8 +928,8 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
       std::string key = ep_context_node->Name();
       ORT_RETURN_IF(qnn_models.find(key) == qnn_models.end(), key + " key name not exist in table qnn_models.");
       auto qnn_model = std::move(qnn_models[key]);
-      ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
-      ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
+      ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node, logger));
+      ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput(logger));
 
       // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
       // the name here must be same with context->node_name in compute_info
diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index d293a0d9c96c1..a3f0ed55b83f2 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -976,9 +976,14 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions1) {
   UpdateEpContextModel(ctx_model_paths_to_update, last_qnn_ctx_binary_file_name,
                        DefaultLoggingManager().DefaultLogger());
 
-  Ort::SessionOptions so;
-  so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1");
-  so.AppendExecutionProvider("QNN", provider_options);
+  Ort::SessionOptions so1;
+  so1.SetLogId("so1");
+  so1.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1");
+  so1.AppendExecutionProvider("QNN", provider_options);
+  Ort::SessionOptions so2;
+  so2.SetLogId("so2");
+  so2.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1");
+  so2.AppendExecutionProvider("QNN", provider_options);
 
   EXPECT_TRUE(2 == ctx_model_paths.size());
 #ifdef _WIN32
@@ -988,8 +993,8 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions1) {
   std::string ctx_model_file1(ctx_model_paths[0].begin(), ctx_model_paths[0].end());
   std::string ctx_model_file2(ctx_model_paths[1].begin(), ctx_model_paths[1].end());
 #endif
-  Ort::Session session1(*ort_env, ctx_model_file1.c_str(), so);
-  Ort::Session session2(*ort_env, ctx_model_file2.c_str(), so);
+  Ort::Session session1(*ort_env, ctx_model_file1.c_str(), so1);
+  Ort::Session session2(*ort_env, ctx_model_file2.c_str(), so2);
 
   std::vector<const char*> input_names;
   std::vector<const char*> output_names;