Skip to content

Commit

Permalink
refactor(controller): add job create function in EvaluationService (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
anda-ren authored Dec 7, 2023
1 parent 43c9c05 commit e118ae4
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import ai.starwhale.mlops.configuration.FeaturesProperties;
import ai.starwhale.mlops.domain.dag.DagQuerier;
import ai.starwhale.mlops.domain.dag.bo.Graph;
import ai.starwhale.mlops.domain.evaluation.EvaluationService;
import ai.starwhale.mlops.domain.event.EventService;
import ai.starwhale.mlops.domain.ft.FineTuneAppService;
import ai.starwhale.mlops.domain.job.BizType;
Expand Down Expand Up @@ -96,6 +97,8 @@ public class JobController {
private final RunService runService;
private final UserJobConverter userJobConverter;

private final EvaluationService evaluationService;

public JobController(
JobServiceForWeb jobServiceForWeb,
FineTuneAppService fineTuneAppService, TaskService taskService,
Expand All @@ -106,7 +109,8 @@ public JobController(
FeaturesProperties featuresProperties,
EventService eventService,
RunService runService,
UserJobConverter userJobConverter
UserJobConverter userJobConverter,
EvaluationService evaluationService
) {
this.jobServiceForWeb = jobServiceForWeb;
this.fineTuneAppService = fineTuneAppService;
Expand All @@ -119,6 +123,7 @@ public JobController(
this.eventService = eventService;
this.runService = runService;
this.userJobConverter = userJobConverter;
this.evaluationService = evaluationService;

var actions = InvokerManager.<String, String>create()
.addInvoker("cancel", jobServiceForWeb::cancelJob);
Expand Down Expand Up @@ -224,6 +229,8 @@ public ResponseEntity<ResponseMessage<String>> createJob(
} else if (jobRequest.getType() == JobType.ONLINE_EVAL) {
jobId = jobServiceForWeb.createJob(userJobConverter.convert(projectUrl, jobRequest));
}
} else if (jobRequest.getType() == JobType.EVALUATION) {
jobId = evaluationService.createEvaluationJob(userJobConverter.convert(projectUrl, jobRequest));
} else {
jobId = jobServiceForWeb.createJob(userJobConverter.convert(projectUrl, jobRequest));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import ai.starwhale.mlops.domain.evaluation.bo.ConfigQuery;
import ai.starwhale.mlops.domain.evaluation.mapper.ViewConfigMapper;
import ai.starwhale.mlops.domain.evaluation.po.ViewConfigEntity;
import ai.starwhale.mlops.domain.job.JobCreator;
import ai.starwhale.mlops.domain.job.bo.UserJobCreateRequest;
import ai.starwhale.mlops.domain.project.ProjectService;
import ai.starwhale.mlops.domain.user.UserService;
import cn.hutool.core.io.FileUtil;
Expand All @@ -41,16 +43,20 @@ public class EvaluationService {
private final ViewConfigMapper viewConfigMapper;
private final ViewConfigConverter viewConfigConvertor;

private final JobCreator jobCreator;

public EvaluationService(
UserService userService,
ProjectService projectService,
ViewConfigMapper viewConfigMapper,
ViewConfigConverter viewConfigConvertor
ViewConfigConverter viewConfigConvertor,
JobCreator jobCreator
) {
this.userService = userService;
this.projectService = projectService;
this.viewConfigMapper = viewConfigMapper;
this.viewConfigConvertor = viewConfigConvertor;
this.jobCreator = jobCreator;
}


Expand Down Expand Up @@ -87,4 +93,9 @@ public Boolean createViewConfig(String projectUrl, ConfigRequest configRequest)
int res = viewConfigMapper.createViewConfig(entity);
return res > 0;
}

public Long createEvaluationJob(UserJobCreateRequest request) {
var job = jobCreator.createJob(request);
return job.getId();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import ai.starwhale.mlops.common.PageParams;
import ai.starwhale.mlops.configuration.FeaturesProperties;
import ai.starwhale.mlops.domain.dag.DagQuerier;
import ai.starwhale.mlops.domain.evaluation.EvaluationService;
import ai.starwhale.mlops.domain.event.EventService;
import ai.starwhale.mlops.domain.ft.FineTuneAppService;
import ai.starwhale.mlops.domain.job.DevWay;
Expand Down Expand Up @@ -84,6 +85,8 @@ public class JobControllerTest {

private FeaturesProperties featuresProperties;

private EvaluationService evaluationService;

@BeforeEach
public void setUp() {
jobServiceForWeb = mock(JobServiceForWeb.class);
Expand All @@ -92,6 +95,8 @@ public void setUp() {
modelServingService = mock(ModelServingService.class);
dagQuerier = mock(DagQuerier.class);
runtimeSuggestionService = mock(RuntimeSuggestionService.class);
evaluationService = mock(EvaluationService.class);
when(evaluationService.createEvaluationJob(any())).thenReturn(1L);
featuresProperties = new FeaturesProperties();
controller = new JobController(
jobServiceForWeb,
Expand All @@ -104,8 +109,8 @@ public void setUp() {
featuresProperties,
mock(EventService.class),
mock(RunService.class),
mock(UserJobConverter.class)
);
mock(UserJobConverter.class),
evaluationService);
}

@Test
Expand Down Expand Up @@ -306,8 +311,8 @@ public void testJobPauseDisabled() {
featuresProperties,
mock(EventService.class),
mock(RunService.class),
mock(UserJobConverter.class)
);
mock(UserJobConverter.class),
evaluationService);
assertThrows(StarwhaleApiException.class,
() -> controller.action("", "job1", "pause"));

Expand All @@ -329,8 +334,8 @@ public void testJobResumeDisabled() {
featuresProperties,
mock(EventService.class),
mock(RunService.class),
mock(UserJobConverter.class)
);
mock(UserJobConverter.class),
evaluationService);
assertThrows(StarwhaleApiException.class,
() -> controller.action("", "job1", "resume"));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,25 @@
import static org.hamcrest.Matchers.hasProperty;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import ai.starwhale.mlops.api.protocol.evaluation.ConfigRequest;
import ai.starwhale.mlops.domain.evaluation.bo.ConfigQuery;
import ai.starwhale.mlops.domain.evaluation.mapper.ViewConfigMapper;
import ai.starwhale.mlops.domain.evaluation.po.ViewConfigEntity;
import ai.starwhale.mlops.domain.job.JobCreator;
import ai.starwhale.mlops.domain.job.bo.Job;
import ai.starwhale.mlops.domain.job.bo.UserJobCreateRequest;
import ai.starwhale.mlops.domain.project.ProjectService;
import ai.starwhale.mlops.domain.user.UserService;
import ai.starwhale.mlops.domain.user.bo.User;
import java.util.Date;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand All @@ -42,18 +48,22 @@ public class EvaluationServiceTest {
private EvaluationService service;
private ViewConfigMapper viewConfigMapper;

private JobCreator jobCreator;

@BeforeEach
public void setUp() {
UserService userService = mock(UserService.class);
given(userService.currentUserDetail()).willReturn(User.builder().id(1L).build());
ProjectService projectService = mock(ProjectService.class);
given(projectService.getProjectId(same("1"))).willReturn(1L);

jobCreator = mock(JobCreator.class);
service = new EvaluationService(
userService,
projectService,
viewConfigMapper = mock(ViewConfigMapper.class),
new ViewConfigConverter()
new ViewConfigConverter(),
jobCreator
);
}

Expand Down Expand Up @@ -99,4 +109,12 @@ public void testCreateViewConfig() {
res = service.createViewConfig("1", request);
assertThat(res, is(false));
}

@Test
public void testJobCreate() {
when(jobCreator.createJob(any())).thenReturn(Job.builder().id(1234L).build());
Long jobId = service.createEvaluationJob(mock(UserJobCreateRequest.class));
Assertions.assertEquals(1234L, jobId);
}

}

0 comments on commit e118ae4

Please sign in to comment.