Skip to content

Commit

Permalink
Avoid GetChildren when using Specific servable versions
Browse files Browse the repository at this point in the history
For some filesystem providers, like GCS, GetChildren does a lot more work that a
simple FileExists call. This change special cases the `SPECIFIC`
ServableVersionPolicy and does direct FileExists calls for each one. In the
common case of a single version, this can be a single stat() call and avoid an
expensive directory listing entirely.

This optimization *only* applies when the versions and directories are
equivalent to "base_dir/%d". So this fast path now happens before the
GetChildren call, but will fall back to the general case of a directory listing
when there are folders like:

 base_dir/
   - 00001/
   - 2/

and you want the specific version 1. Generally speaking, the support for
strtod-ifying the string name is nice, but forces the directory listing.

PiperOrigin-RevId: 627111947
  • Loading branch information
tensorflower-gardener authored and tensorflow-copybara committed Apr 22, 2024
1 parent 0fe7da7 commit 6fb9403
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 4 deletions.
2 changes: 2 additions & 0 deletions tensorflow_serving/sources/storage_path/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ cc_library(
"//tensorflow_serving/core:servable_id",
"//tensorflow_serving/core:source",
"//tensorflow_serving/core:storage_path",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:tensorflow",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/macros.h"
#include "tensorflow_serving/core/servable_data.h"
#include "tensorflow_serving/core/servable_id.h"

Expand Down Expand Up @@ -159,6 +164,49 @@ bool AspireLatestVersions(
return !children_by_version.empty();
}

// Like `AspireSpecificVersions` but use `FileExists` instead of GetChildren to
// remove unnecessary directory listings. Note that this function has to
// fallback to the general case when there are directories that *parse as* the
// version number via `strtod` but aren't equivalent (e.g., "base_dir/00001"
// rather than "base_dir/1").
//
// Returns true if all the models are loaded.
bool AspireSpecificVersionsFastPath(
const FileSystemStoragePathSourceConfig::ServableToMonitor& servable,
std::vector<ServableData<StoragePath>>* versions) {
if (servable.servable_version_policy().specific().versions().empty()) {
// There aren't any requested versions, WARN loudly and explicitly, since
// this is a likely configuration error. Return *true*, since we are done
// with processing this servable.
LOG(WARNING) << "No specific versions requested for servable "
<< servable.servable_name() << ".";
return true;
}

// First ensure that we find *all* the requested versions, so that we can use
// this fast path. If not, we'll call the general AspireSpecificVersions after
// a GetChildren call.
for (const int64_t version :
servable.servable_version_policy().specific().versions()) {
const string version_dir = absl::StrCat(version);
const string child_dir = io::JoinPath(servable.base_path(), version_dir);

const absl::Status status = Env::Default()->FileExists(child_dir);
if (!status.ok()) {
return false;
}
}

// We've found them all. Aspire them one by one.
for (const int64_t version :
servable.servable_version_policy().specific().versions()) {
const string version_dir = absl::StrCat(version);
AspireVersion(servable, version_dir, version, versions);
}

return true;
}

// Aspire versions for a servable configured with the "specific" version policy.
//
// 'children' represents a list of base-path children from the file system.
Expand Down Expand Up @@ -213,6 +261,16 @@ Status PollFileSystemForServable(
servable.servable_name(), " with error ", status.ToString());
}

if (servable.servable_version_policy().policy_choice_case() ==
FileSystemStoragePathSourceConfig::ServableVersionPolicy::kSpecific) {
// Special case the specific handler, to avoid GetChildren in the case where
// all of the directories match their version number.
if (AspireSpecificVersionsFastPath(servable, versions)) {
// We found them all, exit early.
return absl::OkStatus();
}
}

// Retrieve a list of base-path children from the file system.
std::vector<string> children;
TF_RETURN_IF_ERROR(
Expand Down Expand Up @@ -243,11 +301,10 @@ Status PollFileSystemForServable(
at_least_one_version_found =
AspireAllVersions(servable, children, versions);
break;
case FileSystemStoragePathSourceConfig::ServableVersionPolicy::kSpecific: {
case FileSystemStoragePathSourceConfig::ServableVersionPolicy::kSpecific:
at_least_one_version_found =
AspireSpecificVersions(servable, children_by_version, versions);
break;
}
default:
return errors::Internal("Unhandled servable version_policy: ",
servable.servable_version_policy().DebugString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,89 @@ TEST(FileSystemStoragePathSourceTest, SpecificVersions) {
.PollFileSystemAndInvokeCallback());
}

// This is the same as the `SpecificVersions` test above, but with leading zeros
// on one of the directories to ensure we maintain the `strtod` property of
// directory name => version number.
TEST(FileSystemStoragePathSourceTest, SpecificVersionsLeadingZeros) {
const string base_path =
io::JoinPath(testing::TmpDir(), "SpecificVersionsLeadingZeros");
TF_ASSERT_OK(Env::Default()->CreateDir(base_path));
for (const string& version :
{"non_numerical_child", "42", "33", "30", "21", "00017"}) {
TF_ASSERT_OK(Env::Default()->CreateDir(io::JoinPath(base_path, version)));
}

const FileSystemStoragePathSourceConfig config =
test_util::CreateProto<FileSystemStoragePathSourceConfig>(
strings::Printf("servables: { "
" servable_version_policy { "
" specific { "
" versions: 17"
" versions: 30"
" } "
" } "
" servable_name: 'test_servable_name' "
" base_path: '%s' "
"} "
// Disable the polling thread.
"file_system_poll_wait_seconds: -1 ",
base_path.c_str()));
std::unique_ptr<FileSystemStoragePathSource> source;
TF_ASSERT_OK(FileSystemStoragePathSource::Create(config, &source));
std::unique_ptr<test_util::MockStoragePathTarget> target(
new StrictMock<test_util::MockStoragePathTarget>);
ConnectSourceToTarget(source.get(), target.get());

EXPECT_CALL(
*target,
SetAspiredVersions(
Eq("test_servable_name"),
ElementsAre(
ServableData<StoragePath>({"test_servable_name", 17},
io::JoinPath(base_path, "00017")),
ServableData<StoragePath>({"test_servable_name", 30},
io::JoinPath(base_path, "30")))));

TF_ASSERT_OK(internal::FileSystemStoragePathSourceTestAccess(source.get())
.PollFileSystemAndInvokeCallback());
}

TEST(FileSystemStoragePathSourceTest, SpecificVersionsEmpty) {
const string base_path =
io::JoinPath(testing::TmpDir(), "SpecificVersionsEmpty");
TF_ASSERT_OK(Env::Default()->CreateDir(base_path));
for (const string& version :
{"non_numerical_child", "42", "33", "30", "21", "17"}) {
TF_ASSERT_OK(Env::Default()->CreateDir(io::JoinPath(base_path, version)));
}

const FileSystemStoragePathSourceConfig config =
test_util::CreateProto<FileSystemStoragePathSourceConfig>(
strings::Printf("servables: { "
" servable_version_policy { "
" specific { "
" } "
" } "
" servable_name: 'test_servable_name' "
" base_path: '%s' "
"} "
// Disable the polling thread.
"file_system_poll_wait_seconds: -1 ",
base_path.c_str()));
std::unique_ptr<FileSystemStoragePathSource> source;
TF_ASSERT_OK(FileSystemStoragePathSource::Create(config, &source));
std::unique_ptr<test_util::MockStoragePathTarget> target(
new StrictMock<test_util::MockStoragePathTarget>);
ConnectSourceToTarget(source.get(), target.get());

// The servable has no requested versions, but we still want to call
// SetAspiredVersions with an empty list for consistency.
EXPECT_CALL(*target, SetAspiredVersions(Eq("test_servable_name"), IsEmpty()));

TF_ASSERT_OK(internal::FileSystemStoragePathSourceTestAccess(source.get())
.PollFileSystemAndInvokeCallback());
}

TEST(FileSystemStoragePathSourceTest, DefaultVersionPolicy) {
// Validate that default version policy is to serve the latest servable
// version.
Expand Down Expand Up @@ -512,7 +595,7 @@ TEST(FileSystemStoragePathSourceTest, ChangeVersionPolicy) {
const string base_path_prefix =
io::JoinPath(testing::TmpDir(), "ChangeVersionPolicy_");
TF_ASSERT_OK(Env::Default()->CreateDir(base_path_prefix));
for (const string& version : {"1", "2", "3", "5", "8", "13"}) {
for (const string& version : {"1", "02", "3", "5", "8", "13"}) {
TF_ASSERT_OK(
Env::Default()->CreateDir(io::JoinPath(base_path_prefix, version)));
}
Expand Down Expand Up @@ -572,7 +655,7 @@ TEST(FileSystemStoragePathSourceTest, ChangeVersionPolicy) {
Eq("test_servable_name"),
ElementsAre(
ServableData<StoragePath>({"test_servable_name", 2},
io::JoinPath(base_path_prefix, "2")),
io::JoinPath(base_path_prefix, "02")),
ServableData<StoragePath>({"test_servable_name", 5},
io::JoinPath(base_path_prefix, "5")))));

Expand Down

0 comments on commit 6fb9403

Please sign in to comment.