Skip to content

Commit

Permalink
feat(trial): Add trial and metrics collector (kubeflow#3)
Browse files Browse the repository at this point in the history
* feat(trial): Add trial and metrics collector

Signed-off-by: Ce Gao <gaoce@caicloud.io>

* fix: Remove debug statements

Signed-off-by: Ce Gao <gaoce@caicloud.io>
  • Loading branch information
gaocegege authored and caicloud-bot committed May 7, 2019
1 parent 30f0f32 commit cd6d1e8
Show file tree
Hide file tree
Showing 18 changed files with 732 additions and 74 deletions.
16 changes: 16 additions & 0 deletions manifests/v1alpha3/crd/experiment.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
apiVersion: apiextensions.k8s.io/v1beta1
kind: CustomResourceDefinition
metadata:
creationTimestamp: null
labels:
controller-tools.k8s.io: "1.0"
name: experiments.kubeflow.org
spec:
group: kubeflow.org
version: v1alpha2
names:
kind: Experiment
plural: experiments
scope: Namespaced
subresources:
status: {}
16 changes: 16 additions & 0 deletions manifests/v1alpha3/crd/suggestion.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
apiVersion: apiextensions.k8s.io/v1beta1
kind: CustomResourceDefinition
metadata:
creationTimestamp: null
labels:
controller-tools.k8s.io: "1.0"
name: suggestions.kubeflow.org
spec:
group: kubeflow.org
version: v1alpha2
names:
kind: Suggestion
plural: suggestions
scope: Namespaced
subresources:
status: {}
16 changes: 16 additions & 0 deletions manifests/v1alpha3/crd/trial.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
apiVersion: apiextensions.k8s.io/v1beta1
kind: CustomResourceDefinition
metadata:
creationTimestamp: null
labels:
controller-tools.k8s.io: "1.0"
name: trials.kubeflow.org
spec:
group: kubeflow.org
version: v1alpha2
names:
kind: Trial
plural: trials
scope: Namespaced
subresources:
status: {}
29 changes: 29 additions & 0 deletions manifests/v1alpha3/sample/trial.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
apiVersion: kubeflow.org/v1alpha2
kind: Trial
metadata:
labels:
controller-tools.k8s.io: "1.0"
name: sample
spec:
objective: accuracy
metrics:
- loss
metricCollector: general
parameterAssignments:
- name: param1
value: "1.2345"
runSpec: |-
apiVersion: "kubeflow.org/v1beta1"
kind: "TFJob"
metadata:
name: "dist-mnist-for-e2e-test"
spec:
tfReplicaSpecs:
Worker:
replicas: 1
restartPolicy: Never
template:
spec:
containers:
- name: tensorflow
image: gaocegege/mnist:1
22 changes: 22 additions & 0 deletions pkg/api/operators/apis/addtoscheme_tfjob_v1beta2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package apis

import (
"github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1beta2"
)

func init() {
// Register the types with the Scheme so the components can map objects to GroupVersionKinds and back
AddToSchemes = append(AddToSchemes, v1beta2.SchemeBuilder.AddToScheme)
}
2 changes: 1 addition & 1 deletion pkg/api/operators/apis/suggestion/v1alpha2/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (

var (
// SchemeGroupVersion is group version used to register these objects
SchemeGroupVersion = schema.GroupVersion{Group: "suggestions.kubeflow.org", Version: "v1alpha2"}
SchemeGroupVersion = schema.GroupVersion{Group: "kubeflow.org", Version: "v1alpha2"}

// SchemeBuilder is used to add go types to the GroupVersionKind scheme
SchemeBuilder = &scheme.Builder{GroupVersion: SchemeGroupVersion}
Expand Down
12 changes: 9 additions & 3 deletions pkg/api/operators/apis/trial/v1alpha2/trial_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@ limitations under the License.
package v1alpha2

import (
"k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

type TrialSpec struct {
// Key-value pairs for hyperparameters and assignment values.
ParameterAssignments []ParameterAssignment `json:"parameterAssignments"`

Objective string `json:"objective"`
Metrics []string `json:"metrics"`
MetricsCollector string `json:"metricCollector"`

// Raw text for the trial run spec. This can be any generic Kubernetes
// runtime object. The trial operator should create the resource as written,
// and let the corresponding resource controller (e.g. tf-operator) handle
Expand Down Expand Up @@ -51,7 +55,7 @@ type TrialStatus struct {
Conditions []TrialCondition `json:"conditions,omitempty"`

// Results of the Trial - objectives and other metrics values.
Observation Observation `json:"observation,omitempty"`
Observation *Observation `json:"observation,omitempty"`
}

type ParameterAssignment struct {
Expand All @@ -65,8 +69,10 @@ type Metric struct {
}

type Observation struct {
// Objective is objective name and value.
Objective *Metric `json:"objective,omitempty"`
// Key-value pairs for metric names and values
Metrics []Metric `json:"metrics"`
Metrics []Metric `json:"metrics,omitempty"`
}

// +k8s:deepcopy-gen=true
Expand Down
16 changes: 15 additions & 1 deletion pkg/api/operators/apis/trial/v1alpha2/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 70 additions & 0 deletions pkg/controller/v1alpha3/trial/clientset/tf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package clientset

import (
"context"

tfv1beta1 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1beta1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
logf "sigs.k8s.io/controller-runtime/pkg/runtime/log"

trialv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/trial/v1alpha2"
"github.com/kubeflow/katib/pkg/controller/v1alpha3/recorder"
"github.com/kubeflow/katib/pkg/controller/v1alpha3/util"
)

const (
loggerNameTF = "tensorflow-client"
)

// TensorFlow is the type for tensorflow client. It abstracts the TFJob
// operations performed on behalf of a Trial; GeneralTF is the implementation
// provided in this file.
type TensorFlow interface {
// CreateOrUpdateTFJob takes the trial t and the TFJob that should exist for
// it, and returns an error if the operation fails.
CreateOrUpdateTFJob(t *trialv1alpha2.Trial, tfJob *tfv1beta1.TFJob) error
}
// GeneralTF is the general client for TensorFlow. It embeds the two
// collaborators its methods call directly.
type GeneralTF struct {
// Client is used to Get and Create TFJob objects.
client.Client
// Recorder is used to report changes made for a trial (ReportChange).
recorder.Recorder
}
// CreateOrUpdateTFJob makes sure the TFJob described by tfJob exists for the
// trial t.
//
// If no TFJob with tfJob's name and namespace is found, tfJob is created and
// the creation is reported through the recorder. If one already exists, its
// status is copied into tfJob so the caller can inspect it; the existing
// object is not modified.
//
// It returns any error from the lookup (other than not-found) or from the
// create call.
func (g *GeneralTF) CreateOrUpdateTFJob(t *trialv1alpha2.Trial, tfJob *tfv1beta1.TFJob) error {
	logger := logf.Log.WithName(loggerNameTF)

	found := &tfv1beta1.TFJob{}
	key := types.NamespacedName{
		Name:      tfJob.Name,
		Namespace: tfJob.Namespace,
	}
	err := g.Get(context.TODO(), key, found)
	if errors.IsNotFound(err) {
		logger.Info("Creating TFJob", "namespace", tfJob.Namespace, "name", tfJob.Name)
		if err := g.Create(context.TODO(), tfJob); err != nil {
			return err
		}
		g.ReportChange(t, util.FlagCreate, util.TypeTFJob)
		return nil
	}
	if err != nil {
		// Any other lookup failure must be surfaced: found is still empty, so
		// copying its status would wipe tfJob.Status and hide the error.
		return err
	}

	// TODO: updating an existing TFJob is not supported yet. When it is, compare
	// tfJob.Spec with found.Spec, write the change back with Update and report
	// it with util.FlagUpdate.
	tfJob.Status = found.Status
	return nil
}

// NewTF creates a new TFJob client backed by the given controller-runtime
// client and recorder.
func NewTF(c client.Client, r recorder.Recorder) TensorFlow {
	tf := new(GeneralTF)
	tf.Client = c
	tf.Recorder = r
	return tf
}
63 changes: 63 additions & 0 deletions pkg/controller/v1alpha3/trial/clientset/unstructured.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package clientset

import (
"context"

"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
logf "sigs.k8s.io/controller-runtime/pkg/runtime/log"

trialv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/trial/v1alpha2"
"github.com/kubeflow/katib/pkg/controller/v1alpha3/recorder"
"github.com/kubeflow/katib/pkg/controller/v1alpha3/util"
)

const (
loggerNameUnstructured = "unstructured-client"
)

// Unstructured is the type for unstructured client. It abstracts operations
// on a job expressed as an unstructured object on behalf of a Trial;
// GeneralUnstructured is the implementation provided in this file.
type Unstructured interface {
// CreateOrUpdateUnifiedJob takes the trial t and the job that should exist
// for it, and returns an error if the operation fails.
CreateOrUpdateUnifiedJob(t *trialv1alpha2.Trial, job *unstructured.Unstructured) error
}

// GeneralUnstructured is the general client for Unstructured. It embeds the
// two collaborators its methods call directly.
type GeneralUnstructured struct {
// Client is used to Get and Create the job object.
client.Client
// Recorder is used to report changes made for a trial (ReportChange).
recorder.Recorder
}

// CreateOrUpdateUnifiedJob makes sure the job exists for the trial t.
//
// If no object with the job's name and namespace is found, the job is created
// and the creation is reported through the recorder. An existing job is left
// untouched: updating is not supported for now.
//
// It returns any error from the lookup (other than not-found) or from the
// create call.
func (g *GeneralUnstructured) CreateOrUpdateUnifiedJob(t *trialv1alpha2.Trial, job *unstructured.Unstructured) error {
	// Use the fixed component name, as the TFJob client does; the object's
	// namespace and name are already attached to the log entry as key/values.
	logger := logf.Log.WithName(loggerNameUnstructured)

	key := types.NamespacedName{
		Name:      job.GetName(),
		Namespace: job.GetNamespace(),
	}
	// Get fills in a copy so the caller's job object is left as passed in.
	found := job.DeepCopy()
	err := g.Get(context.TODO(), key, found)
	if err == nil {
		// We do not support updating now.
		return nil
	}
	if !errors.IsNotFound(err) {
		return err
	}

	logger.Info("Creating Job", "namespace", job.GetNamespace(), "name", job.GetName())
	if err := g.Create(context.TODO(), job); err != nil {
		return err
	}
	// NOTE(review): the change is reported as util.TypeTFJob even though job
	// may be of another kind — confirm whether a generic type should be used.
	g.ReportChange(t, util.FlagCreate, util.TypeTFJob)
	return nil
}

// NewUnstructured creates a new Unstructured client.
func NewUnstructured(c client.Client, r recorder.Recorder) Unstructured {
return &GeneralUnstructured{
Client: c,
Recorder: r,
}
}
Loading

0 comments on commit cd6d1e8

Please sign in to comment.