diff --git a/cmd/suggestion/hyperband/v1alpha3/main.py b/cmd/suggestion/hyperband/v1alpha3/main.py index 57ccc1fe571..e9ad4aebc1a 100644 --- a/cmd/suggestion/hyperband/v1alpha3/main.py +++ b/cmd/suggestion/hyperband/v1alpha3/main.py @@ -8,10 +8,12 @@ _ONE_DAY_IN_SECONDS = 60 * 60 * 24 DEFAULT_PORT = "0.0.0.0:6789" + def serve(): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) service = HyperbandService() - api_pb2_grpc.add_SuggestionServicer_to_server(service, server) health_pb2_grpc.add_HealthServicer_to_server(service, server) + api_pb2_grpc.add_SuggestionServicer_to_server(service, server) + health_pb2_grpc.add_HealthServicer_to_server(service, server) server.add_insecure_port(DEFAULT_PORT) print("Listening...") @@ -22,5 +24,6 @@ def serve(): except KeyboardInterrupt: server.stop(0) + if __name__ == "__main__": serve() diff --git a/pkg/suggestion/v1alpha3/hyperband_service.py b/pkg/suggestion/v1alpha3/hyperband_service.py index 0d4e6290fa6..bec1e628a30 100644 --- a/pkg/suggestion/v1alpha3/hyperband_service.py +++ b/pkg/suggestion/v1alpha3/hyperband_service.py @@ -25,6 +25,7 @@ def GetSuggestions(self, request, context): try: reply = api_pb2.GetSuggestionsReply() experiment = request.experiment + self.all_trials = request.trials alg_settings = experiment.spec.algorithm.algorithm_setting param = HyperBandParam.convert(alg_settings) @@ -126,7 +127,7 @@ def get_objective_value(t): return float(m.value) top_trials = [] - all_trials = self._get_trials(experiment.name) + all_trials = self.all_trials latest_trials = self._get_last_trials(all_trials, latest_trials_num) for t in latest_trials: