diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py index 2c2e0e4be57..8be9e52f6da 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py @@ -115,7 +115,9 @@ def create_experiment( constants.EXPERIMENT_PLURAL, experiment, ) - experiment_name = outputs["metadata"]["name"] # if "generate_name" is used, "name" gets a prefix from server + experiment_name = outputs["metadata"][ + "name" + ] # if "generate_name" is used, "name" gets a prefix from server except multiprocessing.TimeoutError: raise TimeoutError( f"Timeout to create Katib Experiment: {namespace}/{experiment_name}" @@ -155,9 +157,13 @@ def tune( parameters: Dict[str, Any], base_image: str = constants.BASE_IMAGE_TENSORFLOW, namespace: Optional[str] = None, - env_per_trial: Optional[Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]] = None, + env_per_trial: Optional[ + Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]] + ] = None, algorithm_name: str = "random", - algorithm_settings: Union[dict, List[models.V1beta1AlgorithmSetting], None] = None, + algorithm_settings: Union[ + dict, List[models.V1beta1AlgorithmSetting], None + ] = None, objective_metric_name: str = None, additional_metric_names: List[str] = [], objective_type: str = "maximize", @@ -188,9 +194,10 @@ def tune( objective function. base_image: Image to use when executing the objective function. namespace: Namespace for the Experiment. - env_per_trial: Environment variable(s) to be attached to each trial container. - You can specify a dictionary as a mapping object representing the environment variables. - Otherwise, you can specify a list, in which the element can either be a kubernetes.client.models.V1EnvVar (documented here: + env_per_trial: Environment variable(s) to be attached to each trial container. + You can specify a dictionary as a mapping object representing the environment + variables. Otherwise, you can specify a list, in which the element can either + be a kubernetes.client.models.V1EnvVar (documented here: https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md) or a kubernetes.client.models.V1EnvFromSource (documented here: https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md) @@ -232,7 +239,7 @@ def tune( pip_index_url: The PyPI url from which to install Python packages. Raises: - ValueError: Objective function has invalid arguments. + ValueError: Function arguments have incorrect type or value. TimeoutError: Timeout to create Katib Experiment. RuntimeError: Failed to create Katib Experiment. """ @@ -258,7 +265,10 @@ def tune( # Add Algorithm to the Katib Experiment. if isinstance(algorithm_settings, dict): - algorithm_settings = [models.V1beta1AlgorithmSetting(name=str(k), value=str(v)) for k, v in algorithm_settings.items()] + algorithm_settings = [ + models.V1beta1AlgorithmSetting(name=str(k), value=str(v)) + for k, v in algorithm_settings.items() + ] experiment.spec.algorithm = models.V1beta1AlgorithmSpec( algorithm_name=algorithm_name, @@ -340,15 +350,24 @@ def tune( requests=resources_per_trial, limits=resources_per_trial, ) - + + env = [] + env_from = [] if isinstance(env_per_trial, dict): - env, env_from = [client.V1EnvVar(name=str(k), value=str(v)) for k, v in env_per_trial.items()] or None, None - - if env_per_trial: - env = [x for x in env_per_trial if isinstance(x, client.V1EnvVar)] or None - env_from = [x for x in env_per_trial if isinstance(x, client.V1EnvFromSource)] or None - else: - env, env_from = None, None + env = [ + client.V1EnvVar(name=str(k), value=str(v)) + for k, v in env_per_trial.items() + ] + elif env_per_trial: + for x in env_per_trial: + if isinstance(x, client.V1EnvVar): + env.append(x) + elif isinstance(x, client.V1EnvFromSource): + env_from.append(x) + else: + raise ValueError( + f"Incorrect value for env_per_trial: {env_per_trial}" + ) # Create Trial specification. trial_spec = client.V1Job( @@ -1156,9 +1175,9 @@ def get_success_trial_details( ): output = {} output["name"] = trial.metadata.name - output[ - "parameter_assignments" - ] = trial.spec.parameter_assignments + output["parameter_assignments"] = ( + trial.spec.parameter_assignments + ) output["metrics"] = trial.status.observation.metrics result.append(output) except multiprocessing.TimeoutError: