Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancement for Custom CRD #1333

Merged
merged 11 commits into from
Oct 13, 2020
9 changes: 9 additions & 0 deletions pkg/apis/controller/trials/v1beta1/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ func (trial *Trial) IsKilled() bool {
return hasCondition(trial, TrialKilled)
}

// IsMetricsUnavailable returns true if Trial metrics are not available
func (trial *Trial) IsMetricsUnavailable() bool {
cond := getCondition(trial, TrialSucceeded)
if cond != nil && cond.Status == v1.ConditionFalse {
return true
}
return false
}

func (trial *Trial) IsCompleted() bool {
return trial.IsSucceeded() || trial.IsFailed() || trial.IsKilled()
}
Expand Down
48 changes: 39 additions & 9 deletions pkg/controller.v1beta1/trial/trial_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package trial
import (
"context"
"fmt"
"time"

batchv1beta "k8s.io/api/batch/v1beta1"
corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -174,6 +175,12 @@ type ReconcileTrial struct {
collector *trialutil.TrialsCollector
}

// Map which contains number of requeuing for each trial if observation logs are not available
// That is needed if Job is succeeded but metrics are not reported yet
// Key = Trial name, value = requeue count
var trialRequeueCount = make(map[string]int)
gaocegege marked this conversation as resolved.
Show resolved Hide resolved
var maxRequeueCount = 5

// Reconcile reads that state of the cluster for a Trial object and makes changes based on the state read
// and what is in the Trial.Spec
// +kubebuilder:rbac:groups=trials.kubeflow.org,resources=trials,verbs=get;list;watch;create;update;patch;delete
Expand Down Expand Up @@ -232,6 +239,24 @@ func (r *ReconcileTrial) Reconcile(request reconcile.Request) (reconcile.Result,
}
}

// Restart Reconcile for maxRequeueCount times
if instance.IsMetricsUnavailable() {

count, ok := trialRequeueCount[instance.GetName()]
if !ok {
trialRequeueCount[instance.GetName()] = 1
} else {
trialRequeueCount[instance.GetName()]++
}

if count <= maxRequeueCount {
logger.Info("Trial metrics are not available, reconciler requeued", "requeue count", maxRequeueCount)
return reconcile.Result{
RequeueAfter: time.Second * 5,
}, nil
}
}

return reconcile.Result{}, nil
}

Expand Down Expand Up @@ -319,6 +344,7 @@ func (r *ReconcileTrial) reconcileJob(instance *trialsv1beta1.Trial, desiredJob
gvk := schema.FromAPIVersionAndKind(apiVersion, kind)

// Add annotation to desired Job to disable istio sidecar
// TODO (andreyvelich): Can be removed after custom CRD implementation
err = util.TrainingJobAnnotations(desiredJob)
if err != nil {
logger.Error(err, "TrainingJobAnnotations error")
Expand All @@ -333,15 +359,19 @@ func (r *ReconcileTrial) reconcileJob(instance *trialsv1beta1.Trial, desiredJob
if instance.IsCompleted() {
return nil, nil
}
jobProvider, err := jobv1beta1.New(desiredJob.GetKind())
if err != nil {
return nil, err
}
// mutate desiredJob according to provider
if err := jobProvider.MutateJob(instance, desiredJob); err != nil {
logger.Error(err, "Mutating desiredSpec of km.Training error")
return nil, err
}

// TODO (andreyvelich): Mutate job needs to be refactored (ref: https://github.com/kubeflow/katib/issues/1320)
// Currently, commented since we don't do Mutate Job for SupportedJobList
// jobProvider, err := jobv1beta1.New(desiredJob.GetKind())
// if err != nil {
// return nil, err
// }
// // mutate desiredJob according to provider
// if err := jobProvider.MutateJob(instance, desiredJob); err != nil {
// logger.Error(err, "Mutating desiredSpec of km.Training error")
// return nil, err
// }

logger.Info("Creating Job", "kind", kind,
"name", desiredJob.GetName())
err = r.Create(context.TODO(), desiredJob)
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller.v1beta1/trial/trial_controller_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (r *ReconcileTrial) UpdateTrialStatusCondition(instance *trialsv1beta1.Tria
r.recorder.Eventf(instance, corev1.EventTypeNormal, JobSucceededReason, eventMsg)
r.collector.IncreaseTrialsSucceededCount(instance.Namespace)
} else {
// TODO (andreyvelich): Is is correct to mark succeeded status false when metrics are unavailable?
// TODO (andreyvelich): Is it correct to mark succeeded status false when metrics are unavailable?
msg := "Metrics are not available"
reason := TrialMetricsUnavailableReason

Expand Down
10 changes: 1 addition & 9 deletions pkg/controller.v1beta1/util/annotations.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,13 @@ limitations under the License.
package util

import (
"fmt"

batchv1 "k8s.io/api/batch/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
logf "sigs.k8s.io/controller-runtime/pkg/runtime/log"

suggestionsv1beta1 "github.com/kubeflow/katib/pkg/apis/controller/suggestions/v1beta1"
"github.com/kubeflow/katib/pkg/controller.v1beta1/consts"
jobv1beta1 "github.com/kubeflow/katib/pkg/job/v1beta1"
pytorchv1 "github.com/kubeflow/pytorch-operator/pkg/apis/pytorch/v1"
tfv1 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1"
)
Expand Down Expand Up @@ -102,13 +99,8 @@ func TrainingJobAnnotations(desiredJob *unstructured.Unstructured) error {
}
return nil
default:
// Annotation appending of custom job can be done in Provider.MutateJob.
if _, ok := jobv1beta1.SupportedJobList[kind]; ok {
return nil
}
return fmt.Errorf("Invalid Katib Training Job kind %v", kind)
return nil
}

}

func appendAnnotation(annotations map[string]string, newAnnotationName string, newAnnotationValue string) map[string]string {
Expand Down
38 changes: 20 additions & 18 deletions pkg/db/v1beta1/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ func NewDBInterface() (common.KatibDBInterface, error) {
}

func (d *dbConn) RegisterObservationLog(trialName string, observationLog *v1beta1.ObservationLog) error {
var mname, mvalue string
sqlQuery := "INSERT INTO observation_logs (trial_name, time, metric_name, value) VALUES "
values := []interface{}{}

for _, mlog := range observationLog.MetricLogs {
mname = mlog.Metric.Name
mvalue = mlog.Metric.Value
if mlog.TimeStamp == "" {
continue
}
Expand All @@ -104,22 +104,24 @@ func (d *dbConn) RegisterObservationLog(trialName string, observationLog *v1beta
return fmt.Errorf("Error parsing start time %s: %v", mlog.TimeStamp, err)
}
sqlTimeStr := t.UTC().Format(mysqlTimeFmt)
_, err = d.db.Exec(
`INSERT INTO observation_logs (
trial_name,
time,
metric_name,
value
) VALUES (?, ?, ?, ?)`,
trialName,
sqlTimeStr,
mname,
mvalue,
)
if err != nil {
return err
}

sqlQuery += "(?, ?, ?, ?),"
values = append(values, trialName, sqlTimeStr, mlog.Metric.Name, mlog.Metric.Value)
}
sqlQuery = sqlQuery[0 : len(sqlQuery)-1]

// Prepare the statement
stmt, err := d.db.Prepare(sqlQuery)
if err != nil {
return fmt.Errorf("Pepare SQL statement failed: %v", err)
}

// Execute INSERT
_, err = stmt.Exec(values...)
if err != nil {
return fmt.Errorf("Execute SQL INSERT failed: %v", err)
}

return nil
}

Expand Down
43 changes: 14 additions & 29 deletions pkg/db/v1beta1/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,37 +60,22 @@ func TestRegisterObservationLog(t *testing.T) {
Value: "0.5",
},
},
{
TimeStamp: "2016-12-31T20:02:05.123456Z",
Metric: &api_pb.Metric{
Name: "precision",
Value: "88.7",
},
},
{
TimeStamp: "2016-12-31T20:02:05.123456Z",
Metric: &api_pb.Metric{
Name: "recall",
Value: "89.2",
},
},
},
}
for _, m := range obsLog.MetricLogs {
mock.ExpectExec(
`INSERT INTO observation_logs \(
trial_name,
time,
metric_name,
value
\)`,
).WithArgs(
"test1_trial1",
"2016-12-31 20:02:05.123456",
m.Metric.Name,
m.Metric.Value,
).WillReturnResult(sqlmock.NewResult(1, 1))
}
mock.ExpectPrepare("INSERT")
mock.ExpectExec(
"INSERT",
).WithArgs(
"test1_trial1",
"2016-12-31 20:02:05.123456",
"f1_score",
"88.95",
"test1_trial1",
"2016-12-31 20:02:05.123456",
"loss",
"0.5",
).WillReturnResult(sqlmock.NewResult(1, 1))

err := dbInterface.RegisterObservationLog("test1_trial1", obsLog)
if err != nil {
t.Errorf("RegisterExperiment failed: %v", err)
Expand Down
4 changes: 2 additions & 2 deletions pkg/webhook/v1beta1/experiment/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ func (g *DefaultValidator) validateTrialTemplate(instance *experimentsv1beta1.Ex

// Check if Job is supported
// Check if Job can be converted to Batch Job/TFJob/PyTorchJob
// Not default CRDs can be omitted later
// Other jobs are not validated
if err := g.validateSupportedJob(runSpec); err != nil {
return fmt.Errorf("Invalid spec.trialTemplate: %v", err)
}
Expand Down Expand Up @@ -336,7 +336,7 @@ func (g *DefaultValidator) validateSupportedJob(runSpec *unstructured.Unstructur
return nil
}
}
return fmt.Errorf("Job type %v not supported", gvk)
return nil
}

func validatePatchJob(runSpec *unstructured.Unstructured, job interface{}, jobType string) error {
Expand Down
18 changes: 9 additions & 9 deletions pkg/webhook/v1beta1/experiment/validator/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,9 @@ spec:
emptyAPIVersionJob.TypeMeta.APIVersion = ""
emptyAPIVersionStr := convertBatchJobToString(emptyAPIVersionJob)

invalidJobType := newFakeBatchJob()
invalidJobType.TypeMeta.Kind = "InvalidKind"
invalidJobTypeStr := convertBatchJobToString(invalidJobType)
customJobType := newFakeBatchJob()
customJobType.TypeMeta.Kind = "CustomKind"
customJobTypeStr := convertBatchJobToString(customJobType)

emptyConfigMap := p.EXPECT().GetTrialTemplate(gomock.Any()).Return("", errors.New(string(metav1.StatusReasonNotFound)))

Expand All @@ -371,7 +371,7 @@ spec:
invalidParameterTemplate := p.EXPECT().GetTrialTemplate(gomock.Any()).Return(invalidParameterJobStr, nil)
notEmptyMetadataTemplate := p.EXPECT().GetTrialTemplate(gomock.Any()).Return(notEmptyMetadataStr, nil)
emptyAPIVersionTemplate := p.EXPECT().GetTrialTemplate(gomock.Any()).Return(emptyAPIVersionStr, nil)
invalidJobTypeTemplate := p.EXPECT().GetTrialTemplate(gomock.Any()).Return(invalidJobTypeStr, nil)
customJobTypeTemplate := p.EXPECT().GetTrialTemplate(gomock.Any()).Return(customJobTypeStr, nil)

gomock.InOrder(
emptyConfigMap,
Expand All @@ -384,7 +384,7 @@ spec:
invalidParameterTemplate,
notEmptyMetadataTemplate,
emptyAPIVersionTemplate,
invalidJobTypeTemplate,
customJobTypeTemplate,
)

tcs := []struct {
Expand Down Expand Up @@ -550,15 +550,15 @@ spec:
Err: true,
testDescription: "Trial template doesn't contain APIVersion or Kind",
},
// Trial Template has invalid Kind
// invalidJobTypeTemplate case
// Trial Template has custom Kind
// customJobTypeTemplate case
{
Instance: func() *experimentsv1beta1.Experiment {
i := newFakeInstance()
return i
}(),
Err: true,
testDescription: "Trial template has invalid Kind",
Err: false,
testDescription: "Trial template has custom Kind",
},
}
for _, tc := range tcs {
Expand Down
2 changes: 1 addition & 1 deletion pkg/webhook/v1beta1/pod/inject_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (s *sidecarInjector) Mutate(pod *v1.Pod, namespace string) (*v1.Pod, error)

mountPath, pathKind := getMountPath(trial.Spec.MetricsCollector)
if mountPath != "" {
if err = mutateVolume(mutatedPod, jobKind, mountPath, injectContainer.Name, pathKind); err != nil {
if err = mutateVolume(mutatedPod, jobKind, mountPath, injectContainer.Name, trial.Spec.PrimaryContainerName, pathKind); err != nil {
return nil, err
}
}
Expand Down
11 changes: 10 additions & 1 deletion pkg/webhook/v1beta1/pod/inject_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ func TestMutateVolume(t *testing.T) {
JobKind string
MountPath string
SidecarContainerName string
PrimaryContainerName string
PathKind common.FileSystemKind
Err bool
}{
Expand Down Expand Up @@ -433,6 +434,12 @@ func TestMutateVolume(t *testing.T) {
},
{
Name: "metrics-collector",
VolumeMounts: []v1.VolumeMount{
{
Name: common.MetricsVolume,
MountPath: filepath.Dir(common.DefaultFilePath),
},
},
},
},
Volumes: []v1.Volume{
Expand All @@ -447,7 +454,8 @@ func TestMutateVolume(t *testing.T) {
},
JobKind: "Job",
MountPath: common.DefaultFilePath,
SidecarContainerName: "train-job",
SidecarContainerName: "metrics-collector",
PrimaryContainerName: "train-job",
PathKind: common.FileKind,
}

Expand All @@ -456,6 +464,7 @@ func TestMutateVolume(t *testing.T) {
tc.JobKind,
tc.MountPath,
tc.SidecarContainerName,
tc.PrimaryContainerName,
tc.PathKind)
if err != nil {
t.Errorf("mutateVolume failed: %v", err)
Expand Down
15 changes: 10 additions & 5 deletions pkg/webhook/v1beta1/pod/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func getMarkCompletedCommand(mountPath string, pathKind common.FileSystemKind) s
return fmt.Sprintf("echo %s > %s", mccommon.TrainingCompleted, pidFile)
}

func mutateVolume(pod *v1.Pod, jobKind, mountPath, sidecarContainerName string, pathKind common.FileSystemKind) error {
func mutateVolume(pod *v1.Pod, jobKind, mountPath, sidecarContainerName, primaryContainerName string, pathKind common.FileSystemKind) error {
metricsVol := v1.Volume{
Name: common.MetricsVolume,
VolumeSource: v1.VolumeSource{
Expand All @@ -257,11 +257,16 @@ func mutateVolume(pod *v1.Pod, jobKind, mountPath, sidecarContainerName string,
if c.Name == sidecarContainerName {
shouldMount = true
} else {
jobProvider, err := jobv1beta1.New(jobKind)
if err != nil {
return err
if primaryContainerName != "" && c.Name == primaryContainerName {
shouldMount = true
// TODO (andreyvelich): This can be deleted after switch to custom CRD
} else if primaryContainerName == "" {
jobProvider, err := jobv1beta1.New(jobKind)
if err != nil {
return err
}
shouldMount = jobProvider.IsTrainingContainer(i, c)
}
shouldMount = jobProvider.IsTrainingContainer(i, c)
}
if shouldMount {
indexList = append(indexList, i)
Expand Down
Loading