Skip to content

Commit

Permalink
[SDK] Fix env per Trial parameter in tune API (#2304)
Browse files Browse the repository at this point in the history
Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
  • Loading branch information
andreyvelich authored Apr 11, 2024
1 parent 7df05c2 commit 086093f
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def create_experiment(
constants.EXPERIMENT_PLURAL,
experiment,
)
experiment_name = outputs["metadata"]["name"] # if "generate_name" is used, "name" gets a prefix from server
experiment_name = outputs["metadata"][
"name"
] # if "generate_name" is used, "name" gets a prefix from server
except multiprocessing.TimeoutError:
raise TimeoutError(
f"Timeout to create Katib Experiment: {namespace}/{experiment_name}"
Expand Down Expand Up @@ -155,9 +157,13 @@ def tune(
parameters: Dict[str, Any],
base_image: str = constants.BASE_IMAGE_TENSORFLOW,
namespace: Optional[str] = None,
env_per_trial: Optional[Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]] = None,
env_per_trial: Optional[
Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]
] = None,
algorithm_name: str = "random",
algorithm_settings: Union[dict, List[models.V1beta1AlgorithmSetting], None] = None,
algorithm_settings: Union[
dict, List[models.V1beta1AlgorithmSetting], None
] = None,
objective_metric_name: str = None,
additional_metric_names: List[str] = [],
objective_type: str = "maximize",
Expand Down Expand Up @@ -188,9 +194,10 @@ def tune(
objective function.
base_image: Image to use when executing the objective function.
namespace: Namespace for the Experiment.
env_per_trial: Environment variable(s) to be attached to each trial container.
You can specify a dictionary as a mapping object representing the environment variables.
Otherwise, you can specify a list, in which the element can either be a kubernetes.client.models.V1EnvVar (documented here:
env_per_trial: Environment variable(s) to be attached to each trial container.
You can specify a dictionary as a mapping object representing the environment
variables. Otherwise, you can specify a list, in which the element can either
be a kubernetes.client.models.V1EnvVar (documented here:
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
or a kubernetes.client.models.V1EnvFromSource (documented here:
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
Expand Down Expand Up @@ -232,7 +239,7 @@ def tune(
pip_index_url: The PyPI url from which to install Python packages.
Raises:
ValueError: Objective function has invalid arguments.
ValueError: Function arguments have incorrect type or value.
TimeoutError: Timeout to create Katib Experiment.
RuntimeError: Failed to create Katib Experiment.
"""
Expand All @@ -258,7 +265,10 @@ def tune(

# Add Algorithm to the Katib Experiment.
if isinstance(algorithm_settings, dict):
algorithm_settings = [models.V1beta1AlgorithmSetting(name=str(k), value=str(v)) for k, v in algorithm_settings.items()]
algorithm_settings = [
models.V1beta1AlgorithmSetting(name=str(k), value=str(v))
for k, v in algorithm_settings.items()
]

experiment.spec.algorithm = models.V1beta1AlgorithmSpec(
algorithm_name=algorithm_name,
Expand Down Expand Up @@ -340,15 +350,24 @@ def tune(
requests=resources_per_trial,
limits=resources_per_trial,
)


env = []
env_from = []
if isinstance(env_per_trial, dict):
env, env_from = [client.V1EnvVar(name=str(k), value=str(v)) for k, v in env_per_trial.items()] or None, None

if env_per_trial:
env = [x for x in env_per_trial if isinstance(x, client.V1EnvVar)] or None
env_from = [x for x in env_per_trial if isinstance(x, client.V1EnvFromSource)] or None
else:
env, env_from = None, None
env = [
client.V1EnvVar(name=str(k), value=str(v))
for k, v in env_per_trial.items()
]
elif env_per_trial:
for x in env_per_trial:
if isinstance(x, client.V1EnvVar):
env.append(x)
elif isinstance(x, client.V1EnvFromSource):
env_from.append(x)
else:
raise ValueError(
f"Incorrect value for env_per_trial: {env_per_trial}"
)

# Create Trial specification.
trial_spec = client.V1Job(
Expand Down Expand Up @@ -1156,9 +1175,9 @@ def get_success_trial_details(
):
output = {}
output["name"] = trial.metadata.name
output[
"parameter_assignments"
] = trial.spec.parameter_assignments
output["parameter_assignments"] = (
trial.spec.parameter_assignments
)
output["metrics"] = trial.status.observation.metrics
result.append(output)
except multiprocessing.TimeoutError:
Expand Down

0 comments on commit 086093f

Please sign in to comment.