diff --git a/pkg/controller.v1beta1/trial/trial_controller_test.go b/pkg/controller.v1beta1/trial/trial_controller_test.go index 4b82c225285..085f2e7c0fb 100644 --- a/pkg/controller.v1beta1/trial/trial_controller_test.go +++ b/pkg/controller.v1beta1/trial/trial_controller_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" "github.com/onsi/gomega" "github.com/prometheus/client_golang/prometheus" "github.com/spf13/viper" @@ -42,6 +43,7 @@ import ( "github.com/kubeflow/katib/pkg/controller.v1beta1/consts" trialutil "github.com/kubeflow/katib/pkg/controller.v1beta1/trial/util" "github.com/kubeflow/katib/pkg/controller.v1beta1/util" + managerclientmock "github.com/kubeflow/katib/pkg/mock/v1beta1/trial/managerclient" ) const ( @@ -87,7 +89,9 @@ func TestAdd(t *testing.T) { func TestReconcileBatchJob(t *testing.T) { g := gomega.NewGomegaWithT(t) - mockMC := &mockManagerClient{} + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + mockManagerClient := managerclientmock.NewMockManagerClient(mockCtrl) // Setup the Manager and Controller. Wrap the Controller Reconcile function so it writes each request to a // channel when it is finished. @@ -98,7 +102,7 @@ func TestReconcileBatchJob(t *testing.T) { r := &ReconcileTrial{ Client: mgr.GetClient(), scheme: mgr.GetScheme(), - ManagerClient: mockMC, + ManagerClient: mockManagerClient, recorder: mgr.GetEventRecorderFor(ControllerName), collector: trialutil.NewTrialsCollector(mgr.GetCache(), prometheus.NewRegistry()), } @@ -177,7 +181,8 @@ func TestReconcileBatchJob(t *testing.T) { t.Run(`Trial run with "Failed" BatchJob.`, func(t *testing.T) { g := gomega.NewGomegaWithT(t) - mockMC.msg = observationLogUnavailable + mockManagerClient.EXPECT().DeleteTrialObservationLog(gomock.Any()).Return(nil, nil) + trial := newFakeTrialBatchJob() batchJob := &batchv1.Job{} @@ -236,7 +241,10 @@ func TestReconcileBatchJob(t *testing.T) { t.Run(`Trail with "Complete" BatchJob and Available metrics.`, func(t *testing.T) { g := gomega.NewGomegaWithT(t) - mockMC.msg = observationLogAvailable + gomock.InOrder( + mockManagerClient.EXPECT().GetTrialObservationLog(gomock.Any()).Return(observationLogAvailable, nil).MinTimes(1), + mockManagerClient.EXPECT().DeleteTrialObservationLog(gomock.Any()).Return(nil, nil), + ) batchJob := &batchv1.Job{} batchJobCompleteMessage := "BatchJob completed test message" batchJobCompleteReason := "BatchJob completed test reason" @@ -284,7 +292,10 @@ func TestReconcileBatchJob(t *testing.T) { t.Run(`Trail with "Complete" BatchJob and Unavailable metrics.`, func(t *testing.T) { g := gomega.NewGomegaWithT(t) - mockMC.msg = observationLogUnavailable + gomock.InOrder( + mockManagerClient.EXPECT().GetTrialObservationLog(gomock.Any()).Return(observationLogUnavailable, nil).MinTimes(1), + mockManagerClient.EXPECT().DeleteTrialObservationLog(gomock.Any()).Return(nil, nil), + ) // Create the Trial trial := newFakeTrialBatchJob() g.Expect(c.Create(ctx, trial)).NotTo(gomega.HaveOccurred()) @@ -433,16 +444,3 @@ func newFakeTrialBatchJob() *trialsv1beta1.Trial { }, } } - -type mockManagerClient struct { - msg *api_pb.GetObservationLogReply -} - -func (c *mockManagerClient) GetTrialObservationLog(instance *trialsv1beta1.Trial) (*api_pb.GetObservationLogReply, error) { - return c.msg, nil -} - -func (c *mockManagerClient) DeleteTrialObservationLog(instance *trialsv1beta1.Trial) (*api_pb.DeleteObservationLogReply, error) { - c.msg = nil - return &api_pb.DeleteObservationLogReply{}, nil -}