diff --git a/pkg/webhook/v1beta1/pod/const.go b/pkg/webhook/v1beta1/pod/const.go index e4566e4f620..7853ae52389 100644 --- a/pkg/webhook/v1beta1/pod/const.go +++ b/pkg/webhook/v1beta1/pod/const.go @@ -23,6 +23,10 @@ import ( const ( MasterRole = "master" BatchJob = "Job" + // TrialKind is the name of Trial kind + TrialKind = "Trial" + // TrialAPIVersion is the name of Trial API Version + TrialAPIVersion = "kubeflow.org/v1beta1" ) var ( diff --git a/pkg/webhook/v1beta1/pod/inject_webhook.go b/pkg/webhook/v1beta1/pod/inject_webhook.go index 7a51e34d4ac..67b2303b9d2 100644 --- a/pkg/webhook/v1beta1/pod/inject_webhook.go +++ b/pkg/webhook/v1beta1/pod/inject_webhook.go @@ -18,13 +18,13 @@ package pod import ( "context" - "fmt" + "errors" "net/http" - "path/filepath" - "strings" "github.com/spf13/viper" v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime/schema" apitypes "k8s.io/apimachinery/pkg/types" "k8s.io/utils/pointer" "sigs.k8s.io/controller-runtime/pkg/client" @@ -35,10 +35,8 @@ import ( common "github.com/kubeflow/katib/pkg/apis/controller/common/v1beta1" trialsv1beta1 "github.com/kubeflow/katib/pkg/apis/controller/trials/v1beta1" - katibmanagerv1beta1 "github.com/kubeflow/katib/pkg/common/v1beta1" "github.com/kubeflow/katib/pkg/controller.v1beta1/consts" - jobv1beta1 "github.com/kubeflow/katib/pkg/job/v1beta1" - mccommon "github.com/kubeflow/katib/pkg/metricscollector/v1beta1/common" + "github.com/kubeflow/katib/pkg/controller.v1beta1/util" "github.com/kubeflow/katib/pkg/util/v1beta1/katibconfig" ) @@ -108,7 +106,13 @@ func NewSidecarInjector(c client.Client) *sidecarInjector { } func (s *sidecarInjector) MutationRequired(pod *v1.Pod, ns string) (bool, error) { - jobKind, jobName, err := getKatibJob(pod) + object, err := util.ConvertObjectToUnstructured(pod) + if err != nil { + return false, err + } + + // Try to get Katib job kind and job name from mutating pod + jobKind, jobName, err := s.getKatibJob(object, ns) if err != nil { return false, nil } @@ -141,9 +145,17 @@ func (s *sidecarInjector) MutationRequired(pod *v1.Pod, ns string) (bool, error) func (s *sidecarInjector) Mutate(pod *v1.Pod, namespace string) (*v1.Pod, error) { mutatedPod := pod.DeepCopy() - kind, trialName, _ := getKatibJob(pod) + object, err := util.ConvertObjectToUnstructured(pod) + if err != nil { + return nil, err + } + + // Try to get Katib job kind and job name from mutating pod + jobKind, jobName, _ := s.getKatibJob(object, namespace) + trial := &trialsv1beta1.Trial{} - if err := s.client.Get(context.TODO(), apitypes.NamespacedName{Name: trialName, Namespace: namespace}, trial); err != nil { + // jobName and Trial name is equal + if err := s.client.Get(context.TODO(), apitypes.NamespacedName{Name: jobName, Namespace: namespace}, trial); err != nil { return nil, err } @@ -157,16 +169,16 @@ func (s *sidecarInjector) Mutate(pod *v1.Pod, namespace string) (*v1.Pod, error) mountPath, pathKind := getMountPath(trial.Spec.MetricsCollector) if mountPath != "" { - if err = mutateVolume(mutatedPod, kind, mountPath, injectContainer.Name, pathKind); err != nil { + if err = mutateVolume(mutatedPod, jobKind, mountPath, injectContainer.Name, pathKind); err != nil { return nil, err } } if needWrapWorkerContainer(trial.Spec.MetricsCollector) { - if err = wrapWorkerContainer(mutatedPod, namespace, kind, mountPath, pathKind, trial.Spec.MetricsCollector); err != nil { + if err = wrapWorkerContainer(mutatedPod, namespace, jobKind, mountPath, pathKind, trial.Spec.MetricsCollector); err != nil { return nil, err } } - log.Info("Inject metrics collector sidecar container", "Pod Generate Name", mutatedPod.GenerateName, "Trial", trialName) + log.Info("Inject metrics collector sidecar container", "Pod Generate Name", mutatedPod.GenerateName, "Trial", jobName) return mutatedPod, nil } @@ -206,143 +218,53 @@ func (s *sidecarInjector) getMetricsCollectorContainer(trial *trialsv1beta1.Tria return &injectContainer, nil } -func getMetricsCollectorArgs(trialName, metricName string, mc common.MetricsCollectorSpec) []string { - args := []string{"-t", trialName, "-m", metricName, "-s", katibmanagerv1beta1.GetDBManagerAddr()} - if mountPath, _ := getMountPath(mc); mountPath != "" { - args = append(args, "-path", mountPath) - } - if mc.Source != nil && mc.Source.Filter != nil && len(mc.Source.Filter.MetricsFormat) > 0 { - args = append(args, "-f", strings.Join(mc.Source.Filter.MetricsFormat, ";")) - } - return args -} - -func getMountPath(mc common.MetricsCollectorSpec) (string, common.FileSystemKind) { - if mc.Collector.Kind == common.StdOutCollector { - return common.DefaultFilePath, common.FileKind - } else if mc.Collector.Kind == common.FileCollector { - return mc.Source.FileSystemPath.Path, common.FileKind - } else if mc.Collector.Kind == common.TfEventCollector { - return mc.Source.FileSystemPath.Path, common.DirectoryKind - } else if mc.Collector.Kind == common.CustomCollector { - if mc.Source == nil || mc.Source.FileSystemPath == nil { - return "", common.InvalidKind - } - return mc.Source.FileSystemPath.Path, mc.Source.FileSystemPath.Kind - } else { - return "", common.InvalidKind - } -} - -func needWrapWorkerContainer(mc common.MetricsCollectorSpec) bool { - mcKind := mc.Collector.Kind - for _, kind := range NeedWrapWorkerMetricsCollecterList { - if mcKind == kind { - return true - } - } - return false -} - -func wrapWorkerContainer( - pod *v1.Pod, namespace, jobKind, metricsFile string, - pathKind common.FileSystemKind, - mc common.MetricsCollectorSpec) error { - index := -1 - for i, c := range pod.Spec.Containers { - jobProvider, err := jobv1beta1.New(jobKind) - if err != nil { - return err - } - if jobProvider.IsTrainingContainer(i, c) { - index = i - break +func (s *sidecarInjector) getKatibJob(object *unstructured.Unstructured, namespace string) (string, string, error) { + owners := object.GetOwnerReferences() + // jobKind and jobName points to the object kind and name that Trial is created + jobKind := "" + jobName := "" + // Search for Trial owner in object owner references + // Trial is owned object if kind = Trial kind and API version = Trial API version + for _, owner := range owners { + if owner.Kind == TrialKind && owner.APIVersion == TrialAPIVersion { + jobKind = object.GetKind() + jobName = object.GetName() } } - if index >= 0 { - command := []string{"sh", "-c"} - args, err := getContainerCommand(pod, namespace, index) - if err != nil { - return err - } - // If the first two commands are sh -c, we do not inject command. - if args[0] == "sh" || args[0] == "bash" { - if args[1] == "-c" { - command = args[0:2] - args = args[2:] + // If Trial is not found in object owners search for nested owners + if jobKind == "" { + i := 0 + // Search for Trial ownership unless jobKind is empty and owners is exists + for jobKind == "" && i < len(owners) { + nestedJob := &unstructured.Unstructured{} + // Get group and version from owner API version + gv, err := schema.ParseGroupVersion(owners[i].APIVersion) + if err != nil { + return "", "", err } - } - if mc.Collector.Kind == common.StdOutCollector { - redirectStr := fmt.Sprintf("1>%s 2>&1", metricsFile) - args = append(args, redirectStr) - } - args = append(args, "&&", getMarkCompletedCommand(metricsFile, pathKind)) - argsStr := strings.Join(args, " ") - c := &pod.Spec.Containers[index] - c.Command = command - c.Args = []string{argsStr} - } - return nil -} - -func getMarkCompletedCommand(mountPath string, pathKind common.FileSystemKind) string { - dir := mountPath - if pathKind == common.FileKind { - dir = filepath.Dir(mountPath) - } - // $$ is process id in shell - pidFile := filepath.Join(dir, "$$$$.pid") - return fmt.Sprintf("echo %s > %s", mccommon.TrainingCompleted, pidFile) -} - -func mutateVolume(pod *v1.Pod, jobKind, mountPath, sidecarContainerName string, pathKind common.FileSystemKind) error { - metricsVol := v1.Volume{ - Name: common.MetricsVolume, - VolumeSource: v1.VolumeSource{ - EmptyDir: &v1.EmptyDirVolumeSource{}, - }, - } - dir := mountPath - if pathKind == common.FileKind { - dir = filepath.Dir(mountPath) - } - vm := v1.VolumeMount{ - Name: metricsVol.Name, - MountPath: dir, - } - indexList := []int{} - for i, c := range pod.Spec.Containers { - shouldMount := false - if c.Name == sidecarContainerName { - shouldMount = true - } else { - jobProvider, err := jobv1beta1.New(jobKind) + gvk := schema.GroupVersionKind{ + Group: gv.Group, + Version: gv.Version, + Kind: owners[i].Kind, + } + // Set GVK for nested unstructured object + nestedJob.SetGroupVersionKind(gvk) + // Get nested object from cluster. + // Nested object namespace must be equal to object namespace + err = s.client.Get(context.TODO(), apitypes.NamespacedName{Name: owners[i].Name, Namespace: namespace}, nestedJob) if err != nil { - return err + return "", "", err } - shouldMount = jobProvider.IsTrainingContainer(i, c) - } - if shouldMount { - indexList = append(indexList, i) - } - } - for _, i := range indexList { - c := &pod.Spec.Containers[i] - if c.VolumeMounts == nil { - c.VolumeMounts = make([]v1.VolumeMount, 0) + // Recursively search for Trial ownership in nested object + jobKind, jobName, err = s.getKatibJob(nestedJob, namespace) + i++ } - c.VolumeMounts = append(c.VolumeMounts, vm) - pod.Spec.Containers[i] = *c } - pod.Spec.Volumes = append(pod.Spec.Volumes, metricsVol) - return nil -} - -func getSidecarContainerName(cKind common.CollectorKind) string { - if cKind == common.StdOutCollector || cKind == common.FileCollector { - return mccommon.MetricLoggerCollectorContainerName - } else { - return mccommon.MetricCollectorContainerName + // If jobKind is empty after the loop, Trial doesn't own the object + if jobKind == "" { + return "", "", errors.New("The Pod doesn't belong to Katib Job") } + + return jobKind, jobName, nil } diff --git a/pkg/webhook/v1beta1/pod/inject_webhook_test.go b/pkg/webhook/v1beta1/pod/inject_webhook_test.go index 34821b0e9dd..4612f25f1ba 100644 --- a/pkg/webhook/v1beta1/pod/inject_webhook_test.go +++ b/pkg/webhook/v1beta1/pod/inject_webhook_test.go @@ -1,17 +1,29 @@ package pod import ( + "context" "reflect" + "sync" "testing" + "time" "path/filepath" - common "github.com/kubeflow/katib/pkg/apis/controller/common/v1beta1" - "github.com/kubeflow/katib/pkg/controller.v1beta1/consts" - mccommon "github.com/kubeflow/katib/pkg/metricscollector/v1beta1/common" + "github.com/onsi/gomega" + appsv1 "k8s.io/api/apps/v1" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/envtest" + "sigs.k8s.io/controller-runtime/pkg/manager" + + common "github.com/kubeflow/katib/pkg/apis/controller/common/v1beta1" + "github.com/kubeflow/katib/pkg/controller.v1beta1/consts" + "github.com/kubeflow/katib/pkg/controller.v1beta1/util" + mccommon "github.com/kubeflow/katib/pkg/metricscollector/v1beta1/common" + tfv1 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1" ) func TestWrapWorkerContainer(t *testing.T) { @@ -425,70 +437,258 @@ func TestGetSidecarContainerName(t *testing.T) { } } +func StartTestManager(mgr manager.Manager, g *gomega.GomegaWithT) (chan struct{}, *sync.WaitGroup) { + stop := make(chan struct{}) + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + g.Expect(mgr.Start(stop)).NotTo(gomega.HaveOccurred()) + }() + return stop, wg +} + func TestGetKatibJob(t *testing.T) { + // Start test k8s server + envTest := &envtest.Environment{ + CRDDirectoryPaths: []string{ + filepath.Join("..", "..", "..", "..", "manifests", "v1beta1", "katib-controller"), + filepath.Join("..", "..", "..", "..", "test", "unit", "v1beta1", "crds"), + }, + } + + cfg, err := envTest.Start() + if err != nil { + t.Error(err) + } + + g := gomega.NewGomegaWithT(t) + + mgr, err := manager.New(cfg, manager.Options{}) + g.Expect(err).NotTo(gomega.HaveOccurred()) + + stopMgr, mgrStopped := StartTestManager(mgr, g) + defer func() { + close(stopMgr) + mgrStopped.Wait() + }() + + c := mgr.GetClient() + si := NewSidecarInjector(c) + + namespace := "default" + trialName := "trial-name" + podName := "pod-name" + deployName := "deploy-name" + tfJobName := "tfjob-name" + timeout := time.Second * 5 + testCases := []struct { - Pod v1.Pod + Pod *v1.Pod + TFJob *tfv1.TFJob + Deployment *appsv1.Deployment ExpectedJobKind string ExpectedJobName string Err bool - Name string + TestDescription string }{ { - Pod: v1.Pod{ + Pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ + Name: podName, + Namespace: namespace, OwnerReferences: []metav1.OwnerReference{ { APIVersion: "kubeflow.org/v1", - Kind: "PyTorchJob", - Name: "OwnerName", + Kind: "TFJob", + Name: tfJobName + "-1", }, }, }, }, - ExpectedJobKind: "PyTorchJob", - ExpectedJobName: "OwnerName", + TFJob: &tfv1.TFJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: tfJobName + "-1", + Namespace: namespace, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "kubeflow.org/v1beta1", + Kind: "Trial", + Name: trialName + "-1", + UID: "test-uid", + }, + }, + }, + }, + ExpectedJobKind: "TFJob", + ExpectedJobName: tfJobName + "-1", Err: false, - Name: "Valid Pod", + TestDescription: "Valid run with ownership sequence: Trial -> TFJob -> Pod", }, { - Pod: v1.Pod{ + Pod: &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: podName, + Namespace: namespace, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "kubeflow.org/v1", + Kind: "TFJob", + Name: tfJobName + "-2", + }, + { + APIVersion: "apps/v1", + Kind: "Deployment", + Name: deployName + "-2", + }, + }, + }, + }, + TFJob: &tfv1.TFJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: tfJobName + "-2", + Namespace: namespace, + }, + }, + Deployment: &appsv1.Deployment{ ObjectMeta: metav1.ObjectMeta{ + Name: deployName + "-2", + Namespace: namespace, OwnerReferences: []metav1.OwnerReference{ { - APIVersion: "notkubeflow.org/v1", - Kind: "PyTorchJob", - Name: "OwnerName", + APIVersion: "kubeflow.org/v1beta1", + Kind: "Trial", + Name: trialName + "-2", + UID: "test-uid", + }, + }, + }, + Spec: appsv1.DeploymentSpec{ + Selector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "test-key": "test-value", + }, + }, + Template: v1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + "test-key": "test-value", + }, + }, + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "test", + Image: "test", + }, + }, }, }, }, }, - Err: true, - Name: "Invalid APIVersion", + ExpectedJobKind: "Deployment", + ExpectedJobName: deployName + "-2", + Err: false, + TestDescription: "Valid run with ownership sequence: Trial -> Deployment -> Pod, TFJob -> Pod", }, { - Pod: v1.Pod{ + Pod: &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: podName, + Namespace: namespace, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "kubeflow.org/v1", + Kind: "TFJob", + Name: tfJobName + "-3", + }, + }, + }, + }, + TFJob: &tfv1.TFJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: tfJobName + "-3", + Namespace: namespace, + }, + }, + Err: true, + TestDescription: "Run for not Trial's pod with ownership sequence: TFJob -> Pod", + }, + { + Pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ + Name: podName, + Namespace: namespace, OwnerReferences: []metav1.OwnerReference{ { APIVersion: "kubeflow.org/v1", - Kind: "MXJob", - Name: "OwnerName", + Kind: "TFJob", + Name: tfJobName + "-4", + }, + }, + }, + }, + Err: true, + TestDescription: "Run when Pod owns TFJob that doesn't exists", + }, + { + Pod: &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: podName, + Namespace: namespace, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "invalid/api/version", + Kind: "TFJob", + Name: tfJobName + "-4", }, }, }, }, - Err: true, - Name: "Invalid Kind", + Err: true, + TestDescription: "Run when Pod owns TFJob with invalid API version", }, } for _, tc := range testCases { - jobKind, jobName, err := getKatibJob(&tc.Pod) + // Create TFJob if it is needed + if tc.TFJob != nil { + tfJobUnstr, err := util.ConvertObjectToUnstructured(tc.TFJob) + gvk := schema.GroupVersionKind{ + Group: "kubeflow.org", + Version: "v1", + Kind: "TFJob", + } + tfJobUnstr.SetGroupVersionKind(gvk) + if err != nil { + t.Errorf("ConvertObjectToUnstructured error %v", err) + } + + g.Expect(c.Create(context.TODO(), tfJobUnstr)).NotTo(gomega.HaveOccurred()) + + // Wait that TFJob is created + g.Eventually(func() error { + return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.TFJob.Name}, tfJobUnstr) + }, timeout).ShouldNot(gomega.HaveOccurred()) + } + + // Create Deployment if it is needed + if tc.Deployment != nil { + g.Expect(c.Create(context.TODO(), tc.Deployment)).NotTo(gomega.HaveOccurred()) + + // Wait that Deployment is created + g.Eventually(func() error { + return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.Deployment.Name}, tc.Deployment) + }, timeout).ShouldNot(gomega.HaveOccurred()) + } + + object, _ := util.ConvertObjectToUnstructured(tc.Pod) + jobKind, jobName, err := si.getKatibJob(object, namespace) if !tc.Err && err != nil { - t.Errorf("Case %v failed. Error %v", tc.Name, err) + t.Errorf("Case %v failed. Error %v", tc.TestDescription, err) } else if !tc.Err && (tc.ExpectedJobKind != jobKind || tc.ExpectedJobName != jobName) { t.Errorf("Case %v failed. Expected jobKind %v, got %v, Expected jobName %v, got %v", - tc.Name, tc.ExpectedJobKind, jobKind, tc.ExpectedJobName, jobName) + tc.TestDescription, tc.ExpectedJobKind, jobKind, tc.ExpectedJobName, jobName) } else if tc.Err && err == nil { t.Errorf("Expected error got nil") } diff --git a/pkg/webhook/v1beta1/pod/utils.go b/pkg/webhook/v1beta1/pod/utils.go index 577612fc3b6..063eb3fba74 100644 --- a/pkg/webhook/v1beta1/pod/utils.go +++ b/pkg/webhook/v1beta1/pod/utils.go @@ -19,6 +19,8 @@ package pod import ( "errors" "fmt" + "path/filepath" + "strings" "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/authn/k8schain" @@ -26,35 +28,13 @@ import ( crv1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/remote" v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime/schema" + common "github.com/kubeflow/katib/pkg/apis/controller/common/v1beta1" + katibmanagerv1beta1 "github.com/kubeflow/katib/pkg/common/v1beta1" jobv1beta1 "github.com/kubeflow/katib/pkg/job/v1beta1" + mccommon "github.com/kubeflow/katib/pkg/metricscollector/v1beta1/common" ) -func getKatibJob(pod *v1.Pod) (string, string, error) { - for _, gvk := range jobv1beta1.SupportedJobList { - owners := pod.GetOwnerReferences() - for _, owner := range owners { - if isMatchGVK(owner, gvk) { - return owner.Kind, owner.Name, nil - } - } - } - return "", "", errors.New("The Pod doesn't belong to Katib Job") -} - -func isMatchGVK(owner metav1.OwnerReference, gvk schema.GroupVersionKind) bool { - if owner.Kind != gvk.Kind { - return false - } - gv := gvk.Group + "/" + gvk.Version - if gv != owner.APIVersion { - return false - } - return true -} - func isPrimaryPod(podLabels, primaryLabels map[string]string) bool { for primaryKey, primaryValue := range primaryLabels { @@ -155,3 +135,144 @@ func getContainerCommand(pod *v1.Pod, namespace string, containerIndex int) ([]s } return args, nil } + +func getMetricsCollectorArgs(trialName, metricName string, mc common.MetricsCollectorSpec) []string { + args := []string{"-t", trialName, "-m", metricName, "-s", katibmanagerv1beta1.GetDBManagerAddr()} + if mountPath, _ := getMountPath(mc); mountPath != "" { + args = append(args, "-path", mountPath) + } + if mc.Source != nil && mc.Source.Filter != nil && len(mc.Source.Filter.MetricsFormat) > 0 { + args = append(args, "-f", strings.Join(mc.Source.Filter.MetricsFormat, ";")) + } + return args +} + +func getMountPath(mc common.MetricsCollectorSpec) (string, common.FileSystemKind) { + if mc.Collector.Kind == common.StdOutCollector { + return common.DefaultFilePath, common.FileKind + } else if mc.Collector.Kind == common.FileCollector { + return mc.Source.FileSystemPath.Path, common.FileKind + } else if mc.Collector.Kind == common.TfEventCollector { + return mc.Source.FileSystemPath.Path, common.DirectoryKind + } else if mc.Collector.Kind == common.CustomCollector { + if mc.Source == nil || mc.Source.FileSystemPath == nil { + return "", common.InvalidKind + } + return mc.Source.FileSystemPath.Path, mc.Source.FileSystemPath.Kind + } else { + return "", common.InvalidKind + } +} + +func needWrapWorkerContainer(mc common.MetricsCollectorSpec) bool { + mcKind := mc.Collector.Kind + for _, kind := range NeedWrapWorkerMetricsCollecterList { + if mcKind == kind { + return true + } + } + return false +} + +func wrapWorkerContainer( + pod *v1.Pod, namespace, jobKind, metricsFile string, + pathKind common.FileSystemKind, + mc common.MetricsCollectorSpec) error { + index := -1 + for i, c := range pod.Spec.Containers { + jobProvider, err := jobv1beta1.New(jobKind) + if err != nil { + return err + } + if jobProvider.IsTrainingContainer(i, c) { + index = i + break + } + } + if index >= 0 { + command := []string{"sh", "-c"} + args, err := getContainerCommand(pod, namespace, index) + if err != nil { + return err + } + // If the first two commands are sh -c, we do not inject command. + if args[0] == "sh" || args[0] == "bash" { + if args[1] == "-c" { + command = args[0:2] + args = args[2:] + } + } + if mc.Collector.Kind == common.StdOutCollector { + redirectStr := fmt.Sprintf("1>%s 2>&1", metricsFile) + args = append(args, redirectStr) + } + args = append(args, "&&", getMarkCompletedCommand(metricsFile, pathKind)) + argsStr := strings.Join(args, " ") + c := &pod.Spec.Containers[index] + c.Command = command + c.Args = []string{argsStr} + } + return nil +} + +func getMarkCompletedCommand(mountPath string, pathKind common.FileSystemKind) string { + dir := mountPath + if pathKind == common.FileKind { + dir = filepath.Dir(mountPath) + } + // $$ is process id in shell + pidFile := filepath.Join(dir, "$$$$.pid") + return fmt.Sprintf("echo %s > %s", mccommon.TrainingCompleted, pidFile) +} + +func mutateVolume(pod *v1.Pod, jobKind, mountPath, sidecarContainerName string, pathKind common.FileSystemKind) error { + metricsVol := v1.Volume{ + Name: common.MetricsVolume, + VolumeSource: v1.VolumeSource{ + EmptyDir: &v1.EmptyDirVolumeSource{}, + }, + } + dir := mountPath + if pathKind == common.FileKind { + dir = filepath.Dir(mountPath) + } + vm := v1.VolumeMount{ + Name: metricsVol.Name, + MountPath: dir, + } + indexList := []int{} + for i, c := range pod.Spec.Containers { + shouldMount := false + if c.Name == sidecarContainerName { + shouldMount = true + } else { + jobProvider, err := jobv1beta1.New(jobKind) + if err != nil { + return err + } + shouldMount = jobProvider.IsTrainingContainer(i, c) + } + if shouldMount { + indexList = append(indexList, i) + } + } + for _, i := range indexList { + c := &pod.Spec.Containers[i] + if c.VolumeMounts == nil { + c.VolumeMounts = make([]v1.VolumeMount, 0) + } + c.VolumeMounts = append(c.VolumeMounts, vm) + pod.Spec.Containers[i] = *c + } + pod.Spec.Volumes = append(pod.Spec.Volumes, metricsVol) + + return nil +} + +func getSidecarContainerName(cKind common.CollectorKind) string { + if cKind == common.StdOutCollector || cKind == common.FileCollector { + return mccommon.MetricLoggerCollectorContainerName + } else { + return mccommon.MetricCollectorContainerName + } +}