From 2d4f82b2b17a438e459c99d625311c77fce69d20 Mon Sep 17 00:00:00 2001 From: Google AI Edge Date: Fri, 20 Dec 2024 02:39:37 -0800 Subject: [PATCH] Fix Dispatch API test cases The NPU binary files were not being loaded from the right path PiperOrigin-RevId: 708249997 --- tflite/experimental/litert/core/filesystem.cc | 4 ++-- .../dispatch/dispatch_api_google_tensor_test.cc | 5 +++-- .../litert/vendors/mediatek/dispatch/BUILD | 1 + .../mediatek/dispatch/dispatch_api_mediatek_test.cc | 6 ++++-- .../qualcomm/compiler/qnn_compiler_plugin.cc | 10 +++++++++- .../litert/vendors/qualcomm/dispatch/BUILD | 1 + .../qualcomm/dispatch/dispatch_api_qualcomm_test.cc | 13 ++++++++----- 7 files changed, 28 insertions(+), 12 deletions(-) diff --git a/tflite/experimental/litert/core/filesystem.cc b/tflite/experimental/litert/core/filesystem.cc index 5cd56fe5..1beb28c3 100644 --- a/tflite/experimental/litert/core/filesystem.cc +++ b/tflite/experimental/litert/core/filesystem.cc @@ -77,7 +77,7 @@ bool Exists(absl::string_view path) { return StdExists(MakeStdPath(path)); } Expected Size(absl::string_view path) { auto std_path = MakeStdPath(path); if (!StdExists(std_path)) { - return Error(kLiteRtStatusErrorNotFound); + return Error(kLiteRtStatusErrorNotFound, "File not found"); } return StdSize(std_path); } @@ -86,7 +86,7 @@ Expected> LoadBinaryFile(absl::string_view path) { auto std_path = MakeStdPath(path); if (!StdExists(std_path)) { - return Error(kLiteRtStatusErrorFileIO); + return Error(kLiteRtStatusErrorFileIO, "File not found"); } OwningBufferRef buf(StdSize(std_path)); diff --git a/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc b/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc index 7d8420c7..a8201548 100644 --- a/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc +++ b/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc @@ -61,9 +61,10 @@ TEST(DispatchApi, GoogleTensor) { kLiteRtStatusOk); ABSL_LOG(INFO) << "device_context: " << device_context; - auto model_file_name = kGoogleTensorModelFileName; + auto model_file_name = + litert::testing::GetTestFilePath(kGoogleTensorModelFileName); auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model); + EXPECT_TRUE(model) << model.Error(); ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() << " bytes"; diff --git a/tflite/experimental/litert/vendors/mediatek/dispatch/BUILD b/tflite/experimental/litert/vendors/mediatek/dispatch/BUILD index acc89511..69a95f33 100644 --- a/tflite/experimental/litert/vendors/mediatek/dispatch/BUILD +++ b/tflite/experimental/litert/vendors/mediatek/dispatch/BUILD @@ -75,6 +75,7 @@ cc_test( "//tflite/experimental/litert/c:litert_common", "//tflite/experimental/litert/c:litert_tensor_buffer", "//tflite/experimental/litert/core:filesystem", + "//tflite/experimental/litert/test:common", "//tflite/experimental/litert/test:simple_model_npu", "//tflite/experimental/litert/vendors/c:litert_dispatch_c_api", "@com_google_absl//absl/log", diff --git a/tflite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc b/tflite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc index 111f9663..469df9dc 100644 --- a/tflite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc +++ b/tflite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc @@ -24,6 +24,7 @@ #include "tflite/experimental/litert/c/litert_tensor_buffer.h" #include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h" #include "tflite/experimental/litert/core/filesystem.h" +#include "tflite/experimental/litert/test/common.h" #include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h" #include "tflite/experimental/litert/vendors/c/litert_dispatch.h" @@ -60,9 +61,10 @@ TEST(DispatchApi, MediaTek) { kLiteRtStatusOk); ABSL_LOG(INFO) << "device_context: " << device_context; - auto model_file_name = kMediaTekModelFileName; + auto model_file_name = + litert::testing::GetTestFilePath(kMediaTekModelFileName); auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model); + EXPECT_TRUE(model) << model.Error(); ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() << " bytes"; diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc index 74ffc974..a1e26920 100644 --- a/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc @@ -155,6 +155,9 @@ struct LiteRtCompiledResultT { LiteRtStatus LiteRtGetCompiledResultByteCode( LiteRtCompiledResult compiled_result, const void** byte_code, size_t* byte_code_size) { + if (!compiled_result || !byte_code || !byte_code_size) { + return kLiteRtStatusErrorInvalidArgument; + } *byte_code = compiled_result->context_bin.data(); *byte_code_size = compiled_result->context_bin.size(); return kLiteRtStatusOk; @@ -163,7 +166,9 @@ LiteRtStatus LiteRtGetCompiledResultByteCode( LiteRtStatus LiteRtGetCompiledResultCallInfo( LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, const void** call_info, size_t* call_info_size) { - if (call_idx >= compiled_result->graph_names.size()) { + if (!compiled_result || !call_info || !call_info_size) { + return kLiteRtStatusErrorInvalidArgument; + } else if (call_idx >= compiled_result->graph_names.size()) { return kLiteRtStatusErrorIndexOOB; } @@ -175,6 +180,9 @@ LiteRtStatus LiteRtGetCompiledResultCallInfo( LiteRtStatus LiteRtGetNumCompiledResultCalls( LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { + if (!compiled_result || !num_calls) { + return kLiteRtStatusErrorInvalidArgument; + } *num_calls = compiled_result->graph_names.size(); return kLiteRtStatusOk; } diff --git a/tflite/experimental/litert/vendors/qualcomm/dispatch/BUILD b/tflite/experimental/litert/vendors/qualcomm/dispatch/BUILD index e0134009..41ad94b1 100644 --- a/tflite/experimental/litert/vendors/qualcomm/dispatch/BUILD +++ b/tflite/experimental/litert/vendors/qualcomm/dispatch/BUILD @@ -83,6 +83,7 @@ cc_test( "//tflite/experimental/litert/c:litert_common", "//tflite/experimental/litert/c:litert_tensor_buffer", "//tflite/experimental/litert/core:filesystem", + "//tflite/experimental/litert/test:common", "//tflite/experimental/litert/test:simple_model_npu", "//tflite/experimental/litert/vendors/c:litert_dispatch_c_api", "@com_google_absl//absl/log", diff --git a/tflite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc b/tflite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc index 29057d91..cfa54372 100644 --- a/tflite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc +++ b/tflite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc @@ -24,6 +24,7 @@ #include "tflite/experimental/litert/c/litert_tensor_buffer.h" #include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h" #include "tflite/experimental/litert/core/filesystem.h" +#include "tflite/experimental/litert/test/common.h" #include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h" #include "tflite/experimental/litert/vendors/c/litert_dispatch.h" @@ -60,9 +61,10 @@ TEST(Qualcomm, DispatchApiWithFastRpc) { kLiteRtStatusOk); ABSL_LOG(INFO) << "device_context: " << device_context; - auto model_file_name = kQualcommModelFileName; + auto model_file_name = + litert::testing::GetTestFilePath(kQualcommModelFileName); auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model); + EXPECT_TRUE(model) << model.Error(); ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() << " bytes"; @@ -311,9 +313,10 @@ TEST(Qualcomm, DispatchApiWithDmaBuf) { kLiteRtStatusOk); ABSL_LOG(INFO) << "device_context: " << device_context; - auto model_file_name = kQualcommModelFileName; - auto model = ::litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model); + auto model_file_name = + litert::testing::GetTestFilePath(kQualcommModelFileName); + auto model = litert::internal::LoadBinaryFile(model_file_name); + EXPECT_TRUE(model) << model.Error(); ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() << " bytes";