Skip to content

Commit

Permalink
KEP-2170: Add TrainJob conditions (#2322)
Browse files Browse the repository at this point in the history
* KEP-2170: Implement TrainJob conditions

Signed-off-by: Yuki Iwai <yuki.iwai.tz@gmail.com>

* Fix API comments

Signed-off-by: Yuki Iwai <yuki.iwai.tz@gmail.com>

* Make condition message constants

Signed-off-by: Yuki Iwai <yuki.iwai.tz@gmail.com>

* Stop connecting condition type and reason in JobSet plugin

Signed-off-by: Yuki Iwai <yuki.iwai.tz@gmail.com>

---------

Signed-off-by: Yuki Iwai <yuki.iwai.tz@gmail.com>
  • Loading branch information
tenzen-y authored Nov 9, 2024
1 parent 95be3c0 commit 94dee0e
Show file tree
Hide file tree
Showing 19 changed files with 528 additions and 23 deletions.
8 changes: 7 additions & 1 deletion api.v2/openapi-spec/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,13 @@
"items": {
"default": {},
"$ref": "#/definitions/v1.Condition"
}
},
"x-kubernetes-list-map-keys": [
"type"
],
"x-kubernetes-list-type": "map",
"x-kubernetes-patch-merge-key": "type",
"x-kubernetes-patch-strategy": "merge"
},
"jobsStatus": {
"description": "JobsStatus tracks the child Jobs in TrainJob.",
Expand Down
1 change: 0 additions & 1 deletion hack/violation_exception_v2alpha1.list
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,PodSpecOverride,Volumes
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TorchElasticPolicy,Metrics
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TrainJobSpec,PodSpecOverrides
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TrainJobStatus,Conditions
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TrainJobStatus,JobsStatus
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,Trainer,Args
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,Trainer,Command
Expand Down
3 changes: 3 additions & 0 deletions manifests/v2/base/crds/kubeflow.org_trainjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3055,6 +3055,9 @@ spec:
- type
type: object
type: array
x-kubernetes-list-map-keys:
- type
x-kubernetes-list-type: map
jobsStatus:
description: JobsStatus tracks the child Jobs in TrainJob.
items:
Expand Down
10 changes: 10 additions & 0 deletions pkg/apis/kubeflow.org/v2alpha1/openapi_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 44 additions & 1 deletion pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,43 @@ type TrainJob struct {
Status TrainJobStatus `json:"status,omitempty"`
}

const (
// TrainJobSuspended means that TrainJob is suspended.
TrainJobSuspended string = "Suspended"

// TrainJobComplete means that the TrainJob has completed its execution.
TrainJobComplete string = "Complete"

// TrainJobFailed means that the actual jobs have failed its execution.
TrainJobFailed string = "Failed"

// TrainJobCreated means that the actual jobs creation has succeeded.
TrainJobCreated string = "Created"
)

const (
// TrainJobSuspendedReason is the "Suspended" condition reason.
// When the TrainJob is suspended, this is added.
TrainJobSuspendedReason string = "Suspended"

// TrainJobResumedReason is the "Suspended" condition reason.
// When the TrainJob suspension is changed from True to False, this is added.
TrainJobResumedReason string = "Resumed"

// TrainJobJobsCreationSucceededReason is the "Created" condition reason.
// When the creating objects succeeded after building succeeded, this is added.
TrainJobJobsCreationSucceededReason string = "JobsCreationSucceeded"

// TrainJobJobsBuildFailedReason is the "Created" condition reason.
// When the building objects based on the TrainJob and the specified runtime failed,
// this is added.
TrainJobJobsBuildFailedReason string = "JobsBuildFailed"

// TrainJobJobsCreationFailedReason is the "Created" condition reason.
// When the creating objects failed even though building succeeded, this is added.
TrainJobJobsCreationFailedReason string = "JobsCreationFailed"
)

// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object
// +resource:path=trainjobs

Expand Down Expand Up @@ -269,7 +306,13 @@ type ContainerOverride struct {
// TrainJobStatus represents the current status of TrainJob.
type TrainJobStatus struct {
// Conditions for the TrainJob.
Conditions []metav1.Condition `json:"conditions,omitempty"`
//
// +optional
// +listType=map
// +listMapKey=type
// +patchStrategy=merge
// +patchMergeKey=type
Conditions []metav1.Condition `json:"conditions,omitempty" patchStrategy:"merge" patchMergeKey:"type"`

// JobsStatus tracks the child Jobs in TrainJob.
JobsStatus []JobStatus `json:"jobsStatus,omitempty"`
Expand Down
20 changes: 20 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ const (

// TorchEnvMasterPort is the env name for the master node port.
TorchEnvMasterPort string = "PET_MASTER_PORT"

// TrainJobJobsCreationSucceededMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationSucceeded"} condition.
TrainJobJobsCreationSucceededMessage = "Succeeded to create Jobs"

// TrainJobJobsBuildFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsBuildFailed"} condition.
TrainJobJobsBuildFailedMessage = "Failed to build Jobs"

// TrainJobJobsCreationFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationFailed"} condition.
TrainJobJobsCreationFailedMessage = "Failed to create Jobs"

// TrainJobSuspendedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Suspended"} condition.
TrainJobSuspendedMessage = "TrainJob is suspended"

// TrainJobResumedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Resumed"} condition.
TrainJobResumedMessage = "TrainJob is resumed"
)

var (
Expand Down
125 changes: 110 additions & 15 deletions pkg/controller.v2/trainjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ import (
"fmt"

"github.com/go-logr/logr"
"github.com/kubeflow/training-operator/pkg/constants"
"k8s.io/apimachinery/pkg/api/equality"
"k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
Expand All @@ -36,6 +40,15 @@ import (

var errorUnsupportedRuntime = errors.New("the specified runtime is not supported")

type objsOpState int

const (
creationSucceeded objsOpState = iota
buildFailed objsOpState = iota
creationFailed objsOpState = iota
updateFailed objsOpState = iota
)

type TrainJobReconciler struct {
log logr.Logger
client client.Client
Expand Down Expand Up @@ -63,29 +76,41 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob))
ctx = ctrl.LoggerInto(ctx, log)
log.V(2).Info("Reconciling TrainJob")
if err := r.createOrUpdateObjs(ctx, &trainJob); err != nil {
return ctrl.Result{}, err
if isTrainJobFinished(&trainJob) {
log.V(5).Info("TrainJob has already been finished")
return ctrl.Result{}, nil
}
// TODO (tenzen-y): Do update the status.
return ctrl.Result{}, nil
}

func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error {
log := ctrl.LoggerFrom(ctx)

runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
runtime, ok := r.runtimes[runtimeRefGK]
if !ok {
return fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
return ctrl.Result{}, fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
}
opState, err := r.reconcileObjects(ctx, runtime, &trainJob)

originStatus := trainJob.Status.DeepCopy()
setSuspendedCondition(&trainJob)
setCreatedCondition(&trainJob, opState)
if terminalCondErr := setTerminalCondition(ctx, runtime, &trainJob); terminalCondErr != nil {
return ctrl.Result{}, errors.Join(err, terminalCondErr)
}
if !equality.Semantic.DeepEqual(&trainJob, originStatus) {
return ctrl.Result{}, errors.Join(err, r.client.Status().Update(ctx, &trainJob))
}
return ctrl.Result{}, err
}

func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobruntimes.Runtime, trainJob *kubeflowv2.TrainJob) (objsOpState, error) {
log := ctrl.LoggerFrom(ctx)

objs, err := runtime.NewObjects(ctx, trainJob)
if err != nil {
return err
return buildFailed, err
}
for _, obj := range objs {
var gvk schema.GroupVersionKind
if gvk, err = apiutil.GVKForObject(obj.DeepCopyObject(), r.client.Scheme()); err != nil {
return err
return buildFailed, err
}
logKeysAndValues := []any{
"groupVersionKind", gvk.String(),
Expand All @@ -102,21 +127,91 @@ func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *k
}
switch {
case created:
log.V(5).Info("Succeeded to create object", logKeysAndValues)
log.V(5).Info("Succeeded to create object", logKeysAndValues...)
continue
case client.IgnoreAlreadyExists(creationErr) != nil:
return creationErr
return creationFailed, creationErr
default:
// This indicates CREATE operation has not been performed or the object has already existed in the cluster.
if err = r.client.Update(ctx, obj); err != nil {
return err
return updateFailed, err
}
log.V(5).Info("Succeeded to update object", logKeysAndValues)
log.V(5).Info("Succeeded to update object", logKeysAndValues...)
}
}
return creationSucceeded, nil
}

func setCreatedCondition(trainJob *kubeflowv2.TrainJob, opState objsOpState) {
var newCond metav1.Condition
switch opState {
case creationSucceeded:
newCond = metav1.Condition{
Type: kubeflowv2.TrainJobCreated,
Status: metav1.ConditionTrue,
Message: constants.TrainJobJobsCreationSucceededMessage,
Reason: kubeflowv2.TrainJobJobsCreationSucceededReason,
}
case buildFailed:
newCond = metav1.Condition{
Type: kubeflowv2.TrainJobCreated,
Status: metav1.ConditionFalse,
Message: constants.TrainJobJobsBuildFailedMessage,
Reason: kubeflowv2.TrainJobJobsBuildFailedReason,
}
// TODO (tenzen-y): Provide more granular message based on creation or update failure.
case creationFailed, updateFailed:
newCond = metav1.Condition{
Type: kubeflowv2.TrainJobCreated,
Status: metav1.ConditionFalse,
Message: constants.TrainJobJobsCreationFailedMessage,
Reason: kubeflowv2.TrainJobJobsCreationFailedReason,
}
default:
return
}
meta.SetStatusCondition(&trainJob.Status.Conditions, newCond)
}

func setSuspendedCondition(trainJob *kubeflowv2.TrainJob) {
var newCond metav1.Condition
switch {
case ptr.Deref(trainJob.Spec.Suspend, false):
newCond = metav1.Condition{
Type: kubeflowv2.TrainJobSuspended,
Status: metav1.ConditionTrue,
Message: constants.TrainJobSuspendedMessage,
Reason: kubeflowv2.TrainJobSuspendedReason,
}
case meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobSuspended):
newCond = metav1.Condition{
Type: kubeflowv2.TrainJobSuspended,
Status: metav1.ConditionFalse,
Message: constants.TrainJobResumedMessage,
Reason: kubeflowv2.TrainJobResumedReason,
}
default:
return
}
meta.SetStatusCondition(&trainJob.Status.Conditions, newCond)
}

func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trainJob *kubeflowv2.TrainJob) error {
terminalCond, err := runtime.TerminalCondition(ctx, trainJob)
if err != nil {
return err
}
if terminalCond != nil {
meta.SetStatusCondition(&trainJob.Status.Conditions, *terminalCond)
}
return nil
}

func isTrainJobFinished(trainJob *kubeflowv2.TrainJob) bool {
return meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobComplete) ||
meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobFailed)
}

func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
return schema.GroupKind{
Group: ptr.Deref(runtimeRef.APIGroup, ""),
Expand Down
5 changes: 5 additions & 0 deletions pkg/runtime.v2/core/clustertrainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"fmt"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -59,6 +60,10 @@ func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *kubef
return r.buildObjects(ctx, trainJob, clTrainingRuntime.Spec.Template, clTrainingRuntime.Spec.MLPolicy, clTrainingRuntime.Spec.PodGroupPolicy)
}

func (r *ClusterTrainingRuntime) TerminalCondition(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) {
return r.TrainingRuntime.TerminalCondition(ctx, trainJob)
}

func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
return nil
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/runtime.v2/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"fmt"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -127,6 +128,10 @@ func (r *TrainingRuntime) buildObjects(
return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
}

func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) {
return r.framework.RunTerminalConditionPlugins(ctx, trainJob)
}

func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
var builders []runtime.ReconcilerBuilder
for _, ex := range r.framework.WatchExtensionPlugins() {
Expand Down
Loading

0 comments on commit 94dee0e

Please sign in to comment.