diff --git a/tensorflow_serving/core/BUILD b/tensorflow_serving/core/BUILD index 33ee3a71c44..799692fd813 100644 --- a/tensorflow_serving/core/BUILD +++ b/tensorflow_serving/core/BUILD @@ -417,6 +417,7 @@ cc_library( ":servable_id", ":servable_state", "//tensorflow_serving/util:event_bus", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "@org_tensorflow//tensorflow/core:lib", ], diff --git a/tensorflow_serving/core/servable_state_monitor.cc b/tensorflow_serving/core/servable_state_monitor.cc index 8e2c5cf751b..46db7821b24 100644 --- a/tensorflow_serving/core/servable_state_monitor.cc +++ b/tensorflow_serving/core/servable_state_monitor.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/time/time.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow_serving/core/servable_state.h" @@ -234,11 +235,11 @@ void ServableStateMonitor::Notify(const NotifyFn& notify_fn) { notify_fns_.push_back(notify_fn); } -bool ServableStateMonitor::WaitUntilServablesReachState( +bool ServableStateMonitor::WaitUntilServablesReachStateWithTimeout( const std::vector& servables, - const ServableState::ManagerState goal_state, + const ServableState::ManagerState goal_state, absl::Duration timeout, std::map* const states_reached) { - bool reached_goal_state; + bool reached_goal_state = false; Notification notified; NotifyWhenServablesReachState( servables, goal_state, @@ -251,10 +252,19 @@ bool ServableStateMonitor::WaitUntilServablesReachState( reached_goal_state = incoming_reached_goal_state; notified.Notify(); }); - notified.WaitForNotification(); + notified.WaitForNotificationWithTimeout(timeout); return reached_goal_state; } +bool ServableStateMonitor::WaitUntilServablesReachState( + const std::vector& servables, + const ServableState::ManagerState goal_state, + std::map* const states_reached) { + return WaitUntilServablesReachStateWithTimeout( + servables, goal_state, + /*timeout=*/absl::InfiniteDuration(), states_reached); +} + void ServableStateMonitor::PreHandleEvent( const EventBus::EventAndTime& state_and_time) {} diff --git a/tensorflow_serving/core/servable_state_monitor.h b/tensorflow_serving/core/servable_state_monitor.h index a943ff8ef12..23f8528e289 100644 --- a/tensorflow_serving/core/servable_state_monitor.h +++ b/tensorflow_serving/core/servable_state_monitor.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" @@ -156,11 +157,19 @@ class ServableStateMonitor { /// /// To understand the return value and the return parameter 'states_reached', /// please read the documentation on NotifyWhenServablesReachState(...). + /// WaitUntilServablesReachStateWithTimeout and WaitUntilServablesReachState + /// perform the same function, but the former has a timeout while the latter + /// waits indefinitely. + bool WaitUntilServablesReachStateWithTimeout( + const std::vector& servables, + ServableState::ManagerState goal_state, absl::Duration timeout, + std::map* states_reached = + nullptr) TF_LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT; bool WaitUntilServablesReachState( const std::vector& servables, ServableState::ManagerState goal_state, std::map* states_reached = - nullptr) TF_LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT; + nullptr) TF_MUST_USE_RESULT; // Subscribes to all servable state changes hitting this monitor. This is // called after the monitor updates its own state based on the event. diff --git a/tensorflow_serving/model_servers/BUILD b/tensorflow_serving/model_servers/BUILD index ebfe97fbc04..9b9d4ba1985 100644 --- a/tensorflow_serving/model_servers/BUILD +++ b/tensorflow_serving/model_servers/BUILD @@ -96,6 +96,7 @@ cc_library( "//tensorflow_serving/util:event_bus", "//tensorflow_serving/util:unique_ptr_with_deps", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:cc_wkt_protos", "@org_tensorflow//tensorflow/core:lib", diff --git a/tensorflow_serving/model_servers/server_core.cc b/tensorflow_serving/model_servers/server_core.cc index 16a39a8f0ff..f8e41850e8a 100644 --- a/tensorflow_serving/model_servers/server_core.cc +++ b/tensorflow_serving/model_servers/server_core.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow_serving/config/file_system_storage_path_source.pb.h" #include "tensorflow_serving/core/load_servables_fast.h" +#include "tensorflow_serving/core/servable_state_monitor.h" #include "tensorflow_serving/model_servers/model_platform_types.h" #include "tensorflow_serving/resources/resource_values.h" #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.h" @@ -296,9 +297,10 @@ Status ServerCore::WaitUntilModelsAvailable(const std::set& models, awaited_servables.push_back(ServableRequest::Latest(model)); } std::map states_reached; - const bool all_models_available = monitor->WaitUntilServablesReachState( - awaited_servables, ServableState::ManagerState::kAvailable, - &states_reached); + const bool all_models_available = + monitor->WaitUntilServablesReachStateWithTimeout( + awaited_servables, ServableState::ManagerState::kAvailable, + options_.servable_state_waiter_timeout, &states_reached); if (!all_models_available) { const int num_unavailable_models = std::count_if( states_reached.begin(), states_reached.end(), @@ -367,6 +369,7 @@ Status ServerCore::AddModelsViaModelConfigList() { } else { // Create a fresh servable state monitor, to avoid getting confused if we're // re-loading a model-version that has previously been unloaded. + ServableStateMonitor fresh_servable_state_monitor( servable_event_bus_.get()); diff --git a/tensorflow_serving/model_servers/server_core.h b/tensorflow_serving/model_servers/server_core.h index 13088c561a7..ca3553e9697 100644 --- a/tensorflow_serving/model_servers/server_core.h +++ b/tensorflow_serving/model_servers/server_core.h @@ -24,6 +24,7 @@ limitations under the License. #include "google/protobuf/any.pb.h" #include "absl/base/macros.h" +#include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/cpu_info.h" @@ -207,6 +208,8 @@ class ServerCore : public Manager { // If true, propagate current context to children threads (periodic // functions) in AspiredVersionsManager. bool with_current_context = false; + + absl::Duration servable_state_waiter_timeout = absl::InfiniteDuration(); }; virtual ~ServerCore() = default;