Skip to content

Commit

Permalink
refactor: Use manager client to get log for test
Browse files Browse the repository at this point in the history
Signed-off-by: Ce Gao <gaoce@caicloud.io>
  • Loading branch information
gaocegege committed May 22, 2019
1 parent 73d940d commit 8787e7d
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 42 deletions.
18 changes: 17 additions & 1 deletion pkg/controller/v1alpha2/trial/managerclient/managerclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ type ManagerClient interface {
CreateTrialInDB(instance *trialsv1alpha2.Trial) error
UpdateTrialStatusInDB(instance *trialsv1alpha2.Trial) error
GetTrialObservation(instance *trialsv1alpha2.Trial) error
GetTrialObservationLog(
instance *trialsv1alpha2.Trial) (*api_pb.GetObservationLogReply, error)
GetTrialConf(instance *trialsv1alpha2.Trial) *api_pb.Trial
}

Expand Down Expand Up @@ -66,8 +68,22 @@ func (d *DefaultClient) UpdateTrialStatusInDB(instance *trialsv1alpha2.Trial) er
return nil
}

func (d *DefaultClient) GetTrialObservation(instance *trialsv1alpha2.Trial) error {
func (d *DefaultClient) GetTrialObservationLog(
instance *trialsv1alpha2.Trial) (*api_pb.GetObservationLogReply, error) {
// read GetObservationLog call and update observation field
objectiveMetricName := instance.Spec.Objective.ObjectiveMetricName
request := &api_pb.GetObservationLogRequest{
TrialName: instance.Name,
MetricName: objectiveMetricName,
}
reply, err := common.GetObservationLog(request)
if err != nil {
return nil, err
}
return reply, nil
}

func (d *DefaultClient) GetTrialObservation(instance *trialsv1alpha2.Trial) error {
return nil
}

Expand Down
9 changes: 4 additions & 5 deletions pkg/controller/v1alpha2/trial/trial_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ import (
trialsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/trial/v1alpha2"
commonv1alpha2 "github.com/kubeflow/katib/pkg/common/v1alpha2"
"github.com/kubeflow/katib/pkg/controller/v1alpha2/trial/managerclient"
trialutil "github.com/kubeflow/katib/pkg/controller/v1alpha2/trial/util"
)

var (
Expand Down Expand Up @@ -159,7 +158,7 @@ func (r *ReconcileTrial) Reconcile(request reconcile.Request) (reconcile.Result,
instance.Status.CompletionTime = &metav1.Time{}
}
msg := "Trial is created"
instance.MarkTrialStatusCreated(trialutil.TrialCreatedReason, msg)
instance.MarkTrialStatusCreated(TrialCreatedReason, msg)
err = r.CreateTrialInDB(instance)
if err != nil {
logger.Error(err, "Create trial in DB error")
Expand Down Expand Up @@ -220,11 +219,11 @@ func (r *ReconcileTrial) reconcileTrial(instance *trialsv1alpha2.Trial) error {
//Job already exists
//TODO Can desired Spec differ from deployedSpec?
if deployedJob != nil {
if err = trialutil.UpdateTrialStatusCondition(instance, deployedJob); err != nil {
if err = r.UpdateTrialStatusCondition(instance, deployedJob); err != nil {
logger.Error(err, "Update trial status condition error")
return err
}
if err = trialutil.UpdateTrialStatusObservation(instance, deployedJob); err != nil {
if err = r.UpdateTrialStatusObservation(instance, deployedJob); err != nil {
logger.Error(err, "Update trial status observation error")
return err
}
Expand Down Expand Up @@ -259,7 +258,7 @@ func (r *ReconcileTrial) reconcileJob(instance *trialsv1alpha2.Trial, desiredJob
}

msg := "Trial is running"
instance.MarkTrialStatusRunning(trialutil.TrialRunningReason, msg)
instance.MarkTrialStatusRunning(TrialRunningReason, msg)
return deployedJob, nil
}

Expand Down
10 changes: 10 additions & 0 deletions pkg/controller/v1alpha2/trial/trial_controller_consts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package trial

const (
DefaultJobKind = "Job"
TrialCreatedReason = "TrialCreated"
TrialRunningReason = "TrialRunning"
TrialSucceededReason = "TrialSucceeded"
TrialFailedReason = "TrialFailed"
TrialKilledReason = "TrialKilled"
)
4 changes: 4 additions & 0 deletions pkg/controller/v1alpha2/trial/trial_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

commonv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/common/v1alpha2"
trialsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/trial/v1alpha2"
api_pb "github.com/kubeflow/katib/pkg/api/v1alpha2"
managerclientmock "github.com/kubeflow/katib/pkg/mock/v1alpha2/trial/managerclient"
)

Expand Down Expand Up @@ -96,6 +97,9 @@ func TestReconcileTFJobTrial(t *testing.T) {
mc := managerclientmock.NewMockManagerClient(mockCtrl)
mc.EXPECT().CreateTrialInDB(gomock.Any()).Return(nil).AnyTimes()
mc.EXPECT().UpdateTrialStatusInDB(gomock.Any()).Return(nil).AnyTimes()
mc.EXPECT().GetTrialObservationLog(gomock.Any()).Return(&api_pb.GetObservationLogReply{
ObservationLog: nil,
}, nil).AnyTimes()

// Setup the Manager and Controller. Wrap the Controller Reconcile function so it writes each request to a
// channel when it is finished.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package util
package trial

import (
"strconv"
Expand All @@ -22,27 +22,14 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
logf "sigs.k8s.io/controller-runtime/pkg/runtime/log"

commonv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/common/v1alpha2"
trialsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/trial/v1alpha2"
api_pb "github.com/kubeflow/katib/pkg/api/v1alpha2"
common "github.com/kubeflow/katib/pkg/common/v1alpha2"
commonv1beta2 "github.com/kubeflow/tf-operator/pkg/apis/common/v1beta2"
)

var log = logf.Log.WithName("trial-status-util")

const (
DefaultJobKind = "Job"
TrialCreatedReason = "TrialCreated"
TrialRunningReason = "TrialRunning"
TrialSucceededReason = "TrialSucceeded"
TrialFailedReason = "TrialFailed"
TrialKilledReason = "TrialKilled"
)

func UpdateTrialStatusCondition(instance *trialsv1alpha2.Trial, deployedJob *unstructured.Unstructured) error {
func (r *ReconcileTrial) UpdateTrialStatusCondition(instance *trialsv1alpha2.Trial, deployedJob *unstructured.Unstructured) error {

kind := deployedJob.GetKind()
status, ok, unerr := unstructured.NestedFieldCopy(deployedJob.Object, "status")
Expand Down Expand Up @@ -96,29 +83,23 @@ func UpdateTrialStatusCondition(instance *trialsv1alpha2.Trial, deployedJob *uns
return nil
}

func UpdateTrialStatusObservation(instance *trialsv1alpha2.Trial, deployedJob *unstructured.Unstructured) error {

// read GetObservationLog call and update observation field
func (r *ReconcileTrial) UpdateTrialStatusObservation(instance *trialsv1alpha2.Trial, deployedJob *unstructured.Unstructured) error {
objectiveMetricName := instance.Spec.Objective.ObjectiveMetricName
request := &api_pb.GetObservationLogRequest{
TrialName: instance.Name,
MetricName: objectiveMetricName,
}
if reply, err := common.GetObservationLog(request); err != nil {
reply, err := r.GetTrialObservationLog(instance)
if err != nil {
return err
} else {
if reply.ObservationLog != nil {
bestObjectiveValue := getBestObjectiveMetricValue(reply.ObservationLog.MetricLogs, instance.Spec.Objective.Type)
if bestObjectiveValue != nil {
if instance.Status.Observation == nil {
instance.Status.Observation = &commonv1alpha2.Observation{}
metric := commonv1alpha2.Metric{Name: objectiveMetricName, Value: *bestObjectiveValue}
instance.Status.Observation.Metrics = []commonv1alpha2.Metric{metric}
} else {
for index, metric := range instance.Status.Observation.Metrics {
if metric.Name == objectiveMetricName {
instance.Status.Observation.Metrics[index].Value = *bestObjectiveValue
}
}
if reply.ObservationLog != nil {
bestObjectiveValue := getBestObjectiveMetricValue(reply.ObservationLog.MetricLogs, instance.Spec.Objective.Type)
if bestObjectiveValue != nil {
if instance.Status.Observation == nil {
instance.Status.Observation = &commonv1alpha2.Observation{}
metric := commonv1alpha2.Metric{Name: objectiveMetricName, Value: *bestObjectiveValue}
instance.Status.Observation.Metrics = []commonv1alpha2.Metric{metric}
} else {
for index, metric := range instance.Status.Observation.Metrics {
if metric.Name == objectiveMetricName {
instance.Status.Observation.Metrics[index].Value = *bestObjectiveValue
}
}
}
Expand Down
15 changes: 15 additions & 0 deletions pkg/mock/v1alpha2/trial/managerclient/katibmanager.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8787e7d

Please sign in to comment.