Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dynamic jobProvider and suggestionComposer registration #1069

Merged
merged 10 commits into from
Feb 28, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions pkg/controller.v1alpha3/consts/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ const (

// DefaultKatibNamespaceEnvName is the default env name of katib namespace
DefaultKatibNamespaceEnvName = "KATIB_CORE_NAMESPACE"
// DefaultKatibComposerEnvName is the default env name of katib suggestion composer
DefaultKatibComposerEnvName = "KATIB_SUGGESTION_COMPOSER"

// KatibConfigMapName is the config map constants
// Configmap name which includes Katib's configuration
Expand Down Expand Up @@ -102,18 +104,6 @@ const (
// JobKindPyTorch is the kind of PyTorchJob.
JobKindPyTorch = "PyTorchJob"

// JobVersionJob is the api version of Kubernetes Job.
JobVersionJob = "v1"
// JobVersionTF is the api version of TFJob.
JobVersionTF = "v1"
// JobVersionPyTorch is the api version of PyTorchJob.
JobVersionPyTorch = "v1"

// JobGroupJob is the group name of Kubernetes Job.
JobGroupJob = "batch"
// JobGroupKubeflow is the group name of Kubeflow.
JobGroupKubeflow = "kubeflow.org"

// AnnotationIstioSidecarInjectName is the annotation of Istio Sidecar
AnnotationIstioSidecarInjectName = "sidecar.istio.io/inject"

Expand All @@ -124,4 +114,6 @@ const (
var (
// DefaultKatibNamespace is the default namespace of katib deployment.
// It is resolved with env.GetEnvOrDefault from the variable named by
// DefaultKatibNamespaceEnvName, with "kubeflow" passed as the default.
DefaultKatibNamespace = env.GetEnvOrDefault(DefaultKatibNamespaceEnvName, "kubeflow")
// DefaultComposer is the default composer of katib suggestion.
// It is resolved with env.GetEnvOrDefault from the variable named by
// DefaultKatibComposerEnvName, with "General" passed as the default.
DefaultComposer = env.GetEnvOrDefault(DefaultKatibComposerEnvName, "General")
)
26 changes: 20 additions & 6 deletions pkg/controller.v1alpha3/suggestion/composer/composer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package composer

import (
"fmt"
"reflect"
"sigs.k8s.io/controller-runtime/pkg/manager"

"github.com/spf13/viper"
appsv1 "k8s.io/api/apps/v1"
Expand All @@ -28,23 +30,27 @@ const (
defaultGRPCHealthCheckProbe = "/bin/grpc_health_probe"
)

var log = logf.Log.WithName("suggestion-composer")
var (
// log is the package logger, named "suggestion-composer".
log = logf.Log.WithName("suggestion-composer")
// ComposerRegistry maps a composer name to the reflect.Type that New
// instantiates for that name.
ComposerRegistry = make(map[string]reflect.Type)
)

// Composer produces the Kubernetes objects for a Suggestion and knows how to
// construct a ready-to-use instance of itself from a controller manager.
type Composer interface {
// DesiredDeployment returns the Deployment computed for s, or an error.
DesiredDeployment(s *suggestionsv1alpha3.Suggestion) (*appsv1.Deployment, error)
// DesiredService returns the Service computed for s, or an error.
DesiredService(s *suggestionsv1alpha3.Suggestion) (*corev1.Service, error)
// CreateComposer returns a Composer built from mgr. New invokes it on an
// instance obtained from ComposerRegistry.
CreateComposer(mgr manager.Manager) Composer
}

// General is the Composer implementation defined in this package. It holds
// the runtime scheme and embeds the client it was created with.
type General struct {
scheme *runtime.Scheme
client.Client
}

func New(scheme *runtime.Scheme, client client.Client) Composer {
return &General{
scheme: scheme,
Client: client,
}
// New returns the Composer registered in ComposerRegistry under
// consts.DefaultComposer, initialised from mgr.
//
// The registered type is instantiated through reflection and used only as a
// prototype: its CreateComposer method builds the Composer that is returned.
// When the registered type is a pointer type, the prototype is a pointer to a
// freshly allocated zero value (never a nil pointer), so CreateComposer
// implementations may safely touch their receiver.
//
// New panics if nothing is registered under consts.DefaultComposer, or if the
// registered type does not implement Composer. Both indicate a wiring mistake
// that cannot be recovered from at runtime.
func New(mgr manager.Manager) Composer {
	composerType, ok := ComposerRegistry[consts.DefaultComposer]
	if !ok || composerType == nil {
		// Fail with a message that names the missing composer instead of the
		// opaque "reflect: New(nil)" panic.
		panic(fmt.Sprintf("suggestion composer %q is not registered", consts.DefaultComposer))
	}

	var prototype reflect.Value
	if composerType.Kind() == reflect.Ptr {
		// Registered as *T: allocate a T and keep the *T.
		prototype = reflect.New(composerType.Elem())
	} else {
		// Registered as T: use the zero value of T.
		prototype = reflect.New(composerType).Elem()
	}
	return prototype.Interface().(Composer).CreateComposer(mgr)
}

func (g *General) DesiredDeployment(s *suggestionsv1alpha3.Suggestion) (*appsv1.Deployment, error) {
Expand Down Expand Up @@ -208,3 +214,11 @@ func (g *General) desiredContainer(s *suggestionsv1alpha3.Suggestion) (*corev1.C
}
return c, nil
}

// CreateComposer builds a new General from the scheme and client exposed by
// mgr. The receiver is not read.
func (g *General) CreateComposer(mgr manager.Manager) Composer {
	composer := &General{
		scheme: mgr.GetScheme(),
		Client: mgr.GetClient(),
	}
	return composer
}

// init registers the *General type in ComposerRegistry. The key is
// consts.DefaultComposer, i.e. whatever composer name is configured, not a
// fixed literal.
func init() {
	generalType := reflect.TypeOf((*General)(nil))
	ComposerRegistry[consts.DefaultComposer] = generalType
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func newReconciler(mgr manager.Manager) reconcile.Reconciler {
Client: mgr.GetClient(),
SuggestionClient: suggestionclient.New(),
scheme: mgr.GetScheme(),
Composer: composer.New(mgr.GetScheme(), mgr.GetClient()),
Composer: composer.New(mgr),
recorder: mgr.GetRecorder(ControllerName),
}
}
Expand Down
11 changes: 10 additions & 1 deletion pkg/controller.v1alpha3/trial/trial_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func add(mgr manager.Manager, r reconcile.Reconciler) error {
return err
}

for _, gvk := range jobv1alpha3.GetSupportedJobList() {
for _, gvk := range jobv1alpha3.SupportedJobList {
unstructuredJob := &unstructured.Unstructured{}
unstructuredJob.SetGroupVersionKind(gvk)
err = c.Watch(
Expand Down Expand Up @@ -274,6 +274,15 @@ func (r *ReconcileTrial) reconcileJob(instance *trialsv1alpha3.Trial, desiredJob
if instance.IsCompleted() {
return nil, nil
}
jobProvider, err := jobv1alpha3.New(desiredJob.GetKind())
if err != nil {
return nil, err
}
// mutate desiredJob according to provider
if err := jobProvider.MutateJob(instance, desiredJob); err != nil {
logger.Error(err, "Mutating desiredSpec of km.Training error")
return nil, err
}
logger.Info("Creating Job", "kind", kind,
"name", desiredJob.GetName())
err = r.Create(context.TODO(), desiredJob)
Expand Down
2 changes: 2 additions & 0 deletions pkg/controller.v1alpha3/trial/trial_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package trial

import (
"bytes"
"github.com/kubeflow/katib/pkg/job/v1alpha3/kubeflow"
"testing"
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved
"time"

Expand Down Expand Up @@ -38,6 +39,7 @@ var tfJobKey = types.NamespacedName{Name: "test", Namespace: namespace}

// init installs the zap logger and registers the kubeflow job providers
// before any test in this package runs.
func init() {
logf.SetLogger(logf.ZapLogger(true))
// NOTE(review): presumably needed so the TFJob/PyTorchJob kinds resolve
// during the tests — confirm whether the package's own init already covers it.
kubeflow.Register()
}

func TestCreateTFJobTrial(t *testing.T) {
Expand Down
47 changes: 6 additions & 41 deletions pkg/job/v1alpha3/consts.go
Original file line number Diff line number Diff line change
@@ -1,48 +1,13 @@
package v1alpha3

import (
"github.com/kubeflow/katib/pkg/controller.v1alpha3/consts"
"k8s.io/apimachinery/pkg/runtime/schema"
)

const (
// JobNameLabel represents the label key for the job name, the value is job name
JobNameLabel = "job-name"
// JobRoleLabel represents the label key for the job role, e.g. the value is master
JobRoleLabel = "job-role"
// TFJobRoleLabel is deprecated in kubeflow 0.7, but we need to be compatible.
TFJobRoleLabel = "tf-job-role"
// PyTorchJobRoleLabel is deprecated in kubeflow 0.7, but we need to be compatible.
PyTorchJobRoleLabel = "pytorch-job-role"
var (
// JobRoleMap maps a job kind to the label keys used to determine if the
// replica is master.
// Katib will inject metrics collector into master replica.
JobRoleMap = make(map[string][]string)
// SupportedJobList maps each supported job kind to its GVK.
SupportedJobList = make(map[string]schema.GroupVersionKind)
)

// JobRoleMap is the map which is used to determine if the replica is master.
// Katib will inject metrics collector into master replica.
var JobRoleMap = map[string][]string{
// Job kind does not support distributed training, thus no master.
consts.JobKindJob: {},
consts.JobKindTF: {JobRoleLabel, TFJobRoleLabel},
consts.JobKindPyTorch: {JobRoleLabel, PyTorchJobRoleLabel},
}

// GetSupportedJobList returns the GVKs of the supported jobs, in the order
// Kubernetes Job, TFJob, PyTorchJob. A new slice is built on every call.
func GetSupportedJobList() []schema.GroupVersionKind {
	return []schema.GroupVersionKind{
		{Group: consts.JobGroupJob, Version: consts.JobVersionJob, Kind: consts.JobKindJob},
		{Group: consts.JobGroupKubeflow, Version: consts.JobVersionTF, Kind: consts.JobKindTF},
		{Group: consts.JobGroupKubeflow, Version: consts.JobVersionPyTorch, Kind: consts.JobKindPyTorch},
	}
}
26 changes: 26 additions & 0 deletions pkg/job/v1alpha3/job/job.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package job

import (
"github.com/kubeflow/katib/pkg/apis/controller/trials/v1alpha3"
"github.com/kubeflow/katib/pkg/controller.v1alpha3/consts"
job "github.com/kubeflow/katib/pkg/job/v1alpha3"
commonv1 "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"reflect"
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved
logf "sigs.k8s.io/controller-runtime/pkg/runtime/log"
)

Expand Down Expand Up @@ -65,3 +70,24 @@ func (j Job) IsTrainingContainer(index int, c corev1.Container) bool {
}
return false
}
// MutateJob is a no-op for the Job provider: both arguments are ignored and
// nil is always returned.
func (j Job) MutateJob(_ *v1alpha3.Trial, _ *unstructured.Unstructured) error {
	return nil
}

// Create returns a new, empty Job provider. Neither kind nor the receiver is
// used.
func (j *Job) Create(kind string) job.Provider {
	return new(Job)
}

func Register() {
job.ProviderRegistry[consts.JobKindJob] = reflect.TypeOf(&Job{})
job.SupportedJobList[consts.JobKindJob] = schema.GroupVersionKind{
Group: "batch",
Version: "v1",
Kind: "Job",
}
job.JobRoleMap[consts.JobKindJob] = []string{}
}

// init registers the Job provider when this package is initialised.
func init() {
Register()
}
35 changes: 34 additions & 1 deletion pkg/job/v1alpha3/kubeflow/kubeflow.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package kubeflow

import (
"github.com/kubeflow/katib/pkg/apis/controller/trials/v1alpha3"
job "github.com/kubeflow/katib/pkg/job/v1alpha3"
pytorchv1 "github.com/kubeflow/pytorch-operator/pkg/apis/pytorch/v1"
commonv1 "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
tfv1 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"reflect"
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved
logf "sigs.k8s.io/controller-runtime/pkg/runtime/log"

"github.com/kubeflow/katib/pkg/controller.v1alpha3/consts"
Expand Down Expand Up @@ -75,3 +78,33 @@ func (k Kubeflow) IsTrainingContainer(index int, c corev1.Container) bool {
}
return false
}

// MutateJob is a no-op for the Kubeflow provider: both arguments are ignored
// and nil is always returned.
func (k Kubeflow) MutateJob(_ *v1alpha3.Trial, _ *unstructured.Unstructured) error {
	return nil
}

// Create returns a new Kubeflow provider whose Kind is set to kind. The
// receiver is not used.
func (k *Kubeflow) Create(kind string) job.Provider {
	provider := Kubeflow{Kind: kind}
	return &provider
}

func Register() {
job.ProviderRegistry[consts.JobKindTF] = reflect.TypeOf(&Kubeflow{})
job.SupportedJobList[consts.JobKindTF] = schema.GroupVersionKind{
Group: "kubeflow.org",
Version: "v1",
Kind: "TFJob",
}
job.JobRoleMap[consts.JobKindTF] = []string{"job-role", "tf-job-role"}
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved
job.ProviderRegistry[consts.JobKindPyTorch] = reflect.TypeOf(&Kubeflow{})
job.SupportedJobList[consts.JobKindPyTorch] = schema.GroupVersionKind{
Group: "kubeflow.org",
Version: "v1",
Kind: "PyTorchJob",
}
job.JobRoleMap[consts.JobKindPyTorch] = []string{"job-role", "pytorch-job-role"}

}

// init registers the Kubeflow provider when this package is initialised.
func init() {
Register()
}
27 changes: 15 additions & 12 deletions pkg/job/v1alpha3/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,39 @@ package v1alpha3

import (
"fmt"
"github.com/kubeflow/katib/pkg/apis/controller/trials/v1alpha3"
"reflect"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"

"github.com/kubeflow/katib/pkg/controller.v1alpha3/consts"
"github.com/kubeflow/katib/pkg/job/v1alpha3/job"
"github.com/kubeflow/katib/pkg/job/v1alpha3/kubeflow"
commonv1 "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
)

var (
	// ProviderRegistry maps a job kind to the reflect.Type of the Provider
	// implementation registered for that kind. New looks kinds up here.
	ProviderRegistry = make(map[string]reflect.Type)
)

// Provider provides utilities for different jobs. Implementations are looked
// up by job kind through New.
type Provider interface {
// GetDeployedJobStatus gets the status of the deployed job as a job
// condition, or an error.
GetDeployedJobStatus(
deployedJob *unstructured.Unstructured) (*commonv1.JobCondition, error)
// IsTrainingContainer returns if the c is the actual training container.
IsTrainingContainer(index int, c corev1.Container) bool
// MutateJob lets the provider adjust the job for the given trial before
// the job is created, if necessary.
MutateJob(*v1alpha3.Trial, *unstructured.Unstructured) error
// Create returns a Provider for the given kind. New calls it on an
// instance obtained from the registered type.
Create(kind string) Provider
}

// New creates a new Provider.
func New(kind string) (Provider, error) {
switch kind {
case consts.JobKindJob:
return &job.Job{}, nil
case consts.JobKindPyTorch, consts.JobKindTF:
return &kubeflow.Kubeflow{
Kind: kind,
}, nil
default:
if providerType, ok := ProviderRegistry[kind]; ok {
ptr := reflect.New(providerType).Elem().Interface().(Provider)
return ptr.Create(kind), nil
} else {
return nil, fmt.Errorf(
"Failed to create the provider: Unknown kind %s", kind)
"failed to create the provider: Unknown kind %s", kind)
}
}
2 changes: 1 addition & 1 deletion pkg/webhook/v1alpha3/experiment/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (g *DefaultValidator) validateTrialTemplate(instance *experimentsv1alpha3.E

func (g *DefaultValidator) validateSupportedJob(job *unstructured.Unstructured) error {
gvk := job.GroupVersionKind()
supportedJobs := jobv1alpha3.GetSupportedJobList()
supportedJobs := jobv1alpha3.SupportedJobList
for _, sJob := range supportedJobs {
if gvk == sJob {
return nil
Expand Down
4 changes: 4 additions & 0 deletions pkg/webhook/v1alpha3/experiment/validator/validator_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package validator

import (
"github.com/kubeflow/katib/pkg/job/v1alpha3/job"
"github.com/kubeflow/katib/pkg/job/v1alpha3/kubeflow"
"testing"

"github.com/golang/mock/gomock"
Expand All @@ -15,6 +17,8 @@ import (

// init installs the zap logger and registers the job and kubeflow providers
// before any test in this package runs.
func init() {
logf.SetLogger(logf.ZapLogger(false))
// NOTE(review): presumably needed so the validator sees the supported job
// kinds during the tests — confirm.
job.Register()
kubeflow.Register()
}

func TestValidateTFJobTrialTemplate(t *testing.T) {
Expand Down
7 changes: 6 additions & 1 deletion pkg/webhook/v1alpha3/pod/inject_webhook_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
package pod

import (
"github.com/kubeflow/katib/pkg/job/v1alpha3/kubeflow"
"testing"

common "github.com/kubeflow/katib/pkg/apis/controller/common/v1alpha3"
v1 "k8s.io/api/core/v1"
"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/equality"
)

// init registers the kubeflow job providers before any test in this package
// runs.
func init() {
kubeflow.Register()
}

func TestWrapWorkerContainer(t *testing.T) {
testCases := []struct {
Pod *v1.Pod
Expand Down
2 changes: 1 addition & 1 deletion pkg/webhook/v1alpha3/pod/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
)

func getKatibJob(pod *v1.Pod) (string, string, error) {
for _, gvk := range jobv1alpha3.GetSupportedJobList() {
for _, gvk := range jobv1alpha3.SupportedJobList {
owners := pod.GetOwnerReferences()
for _, owner := range owners {
if isMatchGVK(owner, gvk) {
Expand Down