diff --git a/hack/gen-python-sdk/post_gen.py b/hack/gen-python-sdk/post_gen.py index b61021edeab..a0d1649d205 100644 --- a/hack/gen-python-sdk/post_gen.py +++ b/hack/gen-python-sdk/post_gen.py @@ -41,6 +41,8 @@ def _rewrite_helper(input_file, output_file, rewrite_rules): if (output_file == "sdk/python/v1beta1/kubeflow/katib/__init__.py"): lines.append("# Import Katib API client.\n") lines.append("from kubeflow.katib.api.katib_client import KatibClient\n") + lines.append("# Import Katib report metrics functions\n") + lines.append("from kubeflow.katib.api.report_metrics import report_metrics\n") lines.append("# Import Katib helper functions.\n") lines.append("import kubeflow.katib.api.search as search\n") lines.append("# Import Katib helper constants.\n") diff --git a/pkg/apis/controller/common/v1beta1/common_types.go b/pkg/apis/controller/common/v1beta1/common_types.go index 8722e8a474d..251f8887042 100644 --- a/pkg/apis/controller/common/v1beta1/common_types.go +++ b/pkg/apis/controller/common/v1beta1/common_types.go @@ -220,8 +220,8 @@ const ( CustomCollector CollectorKind = "Custom" // When model training source code persists metrics into persistent layer - // directly, metricsCollector isn't in need, and its kind is "noneCollector" - NoneCollector CollectorKind = "None" + // directly, the sidecar container isn't needed, and its kind is "pushCollector" + PushCollector CollectorKind = "Push" MetricsVolume = "metrics-volume" ) diff --git a/pkg/controller.v1beta1/consts/const.go b/pkg/controller.v1beta1/consts/const.go index b59fb4f4bc6..2cffe30cde3 100644 --- a/pkg/controller.v1beta1/consts/const.go +++ b/pkg/controller.v1beta1/consts/const.go @@ -60,6 +60,9 @@ const ( // resources list which can be used as trial template ConfigTrialResources = "trial-resources" + // EnvTrialName is the env variable of Trial name + EnvTrialName = "KATIB_TRIAL_NAME" + // LabelExperimentName is the label of experiment name. 
LabelExperimentName = "katib.kubeflow.org/experiment" // LabelSuggestionName is the label of suggestion name. diff --git a/pkg/webhook/v1beta1/experiment/validator/validator.go b/pkg/webhook/v1beta1/experiment/validator/validator.go index aeabbc0f463..56bde4b0621 100644 --- a/pkg/webhook/v1beta1/experiment/validator/validator.go +++ b/pkg/webhook/v1beta1/experiment/validator/validator.go @@ -488,7 +488,7 @@ func (g *DefaultValidator) validateMetricsCollector(inst *experimentsv1beta1.Exp } // TODO(hougangliu): log warning message if some field will not be used for the metricsCollector kind switch mcKind { - case commonapiv1beta1.NoneCollector, commonapiv1beta1.StdOutCollector: + case commonapiv1beta1.PushCollector, commonapiv1beta1.StdOutCollector: return allErrs case commonapiv1beta1.FileCollector: if mcSpec.Source == nil || mcSpec.Source.FileSystemPath == nil || diff --git a/pkg/webhook/v1beta1/pod/inject_webhook.go b/pkg/webhook/v1beta1/pod/inject_webhook.go index f48b4b96a07..3932a6bbfd3 100644 --- a/pkg/webhook/v1beta1/pod/inject_webhook.go +++ b/pkg/webhook/v1beta1/pod/inject_webhook.go @@ -140,6 +140,13 @@ func (s *SidecarInjector) Mutate(pod *v1.Pod, namespace string) (*v1.Pod, error) // Add Katib Trial labels to the Pod metadata. mutatePodMetadata(mutatedPod, trial) + // Add env variables to the Pod's primary container. + // This is needed by the push-based metrics collection function `report_metrics` in the Python SDK. + // Currently, we only pass the Trial name as env variable `KATIB_TRIAL_NAME` to the training container. + if err := mutatePodEnv(mutatedPod, trial); err != nil { + return nil, err + } + // Do the following mutation only for the Primary pod. // If PrimaryPodLabel is not set we mutate all pods which are related to Trial job. // Otherwise, mutate pod only with the appropriate labels. 
@@ -147,8 +154,8 @@ func (s *SidecarInjector) Mutate(pod *v1.Pod, namespace string) (*v1.Pod, error) return mutatedPod, nil } - // If Metrics Collector in None, skip the mutation. - if trial.Spec.MetricsCollector.Collector.Kind == common.NoneCollector { + // If Metrics Collector is Push, skip the mutation. + if trial.Spec.MetricsCollector.Collector.Kind == common.PushCollector { return mutatedPod, nil } diff --git a/pkg/webhook/v1beta1/pod/inject_webhook_test.go b/pkg/webhook/v1beta1/pod/inject_webhook_test.go index ab1646b6769..4436c10e7f1 100644 --- a/pkg/webhook/v1beta1/pod/inject_webhook_test.go +++ b/pkg/webhook/v1beta1/pod/inject_webhook_test.go @@ -25,6 +25,8 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/onsi/gomega" appsv1 "k8s.io/api/apps/v1" batchv1 "k8s.io/api/batch/v1" @@ -1067,3 +1069,103 @@ func TestMutatePodMetadata(t *testing.T) { } } } + +func TestMutatePodEnv(t *testing.T) { + testcases := map[string]struct { + pod *v1.Pod + trial *trialsv1beta1.Trial + mutatedPod *v1.Pod + wantError error + }{ + "Valid case for mutating Pod's env variable": { + pod: &v1.Pod{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "training-container", + }, + }, + }, + }, + trial: &trialsv1beta1.Trial{ + Spec: trialsv1beta1.TrialSpec{ + PrimaryContainerName: "training-container", + }, + }, + mutatedPod: &v1.Pod{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "training-container", + Env: []v1.EnvVar{ + { + Name: consts.EnvTrialName, + ValueFrom: &v1.EnvVarSource{ + FieldRef: &v1.ObjectFieldSelector{ + FieldPath: fmt.Sprintf("metadata.labels['%s']", consts.LabelTrialName), + }, + }, + }, + }, + }, + }, + }, + }, + }, + "Mismatch for Pod name and primaryContainerName in Trial": { + pod: &v1.Pod{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "training-container", + }, + }, + }, + }, + trial: &trialsv1beta1.Trial{ + Spec: trialsv1beta1.TrialSpec{ + 
PrimaryContainerName: "training-containers", + }, + }, + mutatedPod: &v1.Pod{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "training-container", + }, + }, + }, + }, + wantError: fmt.Errorf( + "Unable to find primary container %v in mutated pod containers %v", + "training-containers", + []v1.Container{ + { + Name: "training-container", + }, + }, + ), + }, + } + + for name, testcase := range testcases { + t.Run(name, func(t *testing.T) { + err := mutatePodEnv(testcase.pod, testcase.trial) + // Compare error with expected error + if testcase.wantError != nil && err != nil { + if diff := cmp.Diff(testcase.wantError.Error(), err.Error()); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got):\n%s", diff) + } + } else if testcase.wantError != nil || err != nil { + t.Errorf( + "Unexpected error (-want,+got):\n%s", + cmp.Diff(testcase.wantError, err, cmpopts.EquateErrors()), + ) + } + // Compare Pod with expected pod after mutation + if diff := cmp.Diff(testcase.mutatedPod, testcase.pod); len(diff) != 0 { + t.Errorf("Unexpected mutated result (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/pkg/webhook/v1beta1/pod/utils.go b/pkg/webhook/v1beta1/pod/utils.go index 6381bf6d895..7dad82553cf 100644 --- a/pkg/webhook/v1beta1/pod/utils.go +++ b/pkg/webhook/v1beta1/pod/utils.go @@ -281,6 +281,32 @@ func mutatePodMetadata(pod *v1.Pod, trial *trialsv1beta1.Trial) { pod.Labels = podLabels } +// mutatePodEnv appends the KATIB_TRIAL_NAME env variable to the Pod's primary container. +// It returns an error when getPrimaryContainerIndex finds no container named trial.Spec.PrimaryContainerName. +func mutatePodEnv(pod *v1.Pod, trial *trialsv1beta1.Trial) error { + // Search for the primary container + index := getPrimaryContainerIndex(pod.Spec.Containers, trial.Spec.PrimaryContainerName) + if index < 0 { + return fmt.Errorf("Unable to find primary container %v in mutated pod containers %v", + trial.Spec.PrimaryContainerName, pod.Spec.Containers) + } + + // Pass env variable KATIB_TRIAL_NAME to the primary container using fieldPath. + // append handles a nil Env slice, so no explicit initialization is required. + pod.Spec.Containers[index].Env = append( + pod.Spec.Containers[index].Env, + v1.EnvVar{ + Name: consts.EnvTrialName, + ValueFrom: &v1.EnvVarSource{ + FieldRef: 
&v1.ObjectFieldSelector{ + FieldPath: fmt.Sprintf("metadata.labels['%s']", consts.LabelTrialName), + }, + }, + }, + ) + return nil +} + func getSidecarContainerName(cKind common.CollectorKind) string { if cKind == common.StdOutCollector || cKind == common.FileCollector { return mccommon.MetricLoggerCollectorContainerName diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py index 7988dbaa898..398e73ef908 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py @@ -186,6 +186,7 @@ def tune( retain_trials: bool = False, packages_to_install: List[str] = None, pip_index_url: str = "https://pypi.org/simple", + metrics_collector_config: Dict[str, Any] = None, ): """Create HyperParameter Tuning Katib Experiment from the objective function. @@ -248,6 +249,9 @@ def tune( to the base image packages. These packages are installed before executing the objective function. pip_index_url: The PyPI url from which to install Python packages. + metrics_collector_config: Specify the config of metrics collector, + for example, `metrics_collector_config = {"kind": "Push"}`. + Currently, only the `StdOut` (default) and `Push` kinds are supported. Raises: ValueError: Function arguments have incorrect type or value. @@ -380,6 +384,18 @@ def tune( f"Incorrect value for env_per_trial: {env_per_trial}" ) + # Add metrics collector to the Katib Experiment. + # Only the `kind` key is read for now; it defaults to `StdOut` when no config is given. + if metrics_collector_config is None: + metrics_collector_config = {"kind": "StdOut"} + if "kind" not in metrics_collector_config: + raise ValueError( + f"Incorrect value for metrics_collector_config: {metrics_collector_config}" + ) + experiment.spec.metrics_collector_spec = models.V1beta1MetricsCollectorSpec( + collector=models.V1beta1CollectorSpec(kind=metrics_collector_config["kind"]) + ) + # Create Trial specification. 
trial_spec = client.V1Job( api_version="batch/v1",