diff --git a/server/controller/pom.xml b/server/controller/pom.xml index f68bba4c81..efe2ea8805 100644 --- a/server/controller/pom.xml +++ b/server/controller/pom.xml @@ -144,6 +144,14 @@ io.kubernetes client-java + + com.github.docker-java + docker-java + + + com.github.docker-java + docker-java-transport-httpclient5 + org.xerial.snappy snappy-java diff --git a/server/controller/src/main/java/ai/starwhale/mlops/api/DatasetBuildLogWsServer.java b/server/controller/src/main/java/ai/starwhale/mlops/api/DatasetBuildLogWsServer.java index be71f4702b..f4bf74a0f8 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/api/DatasetBuildLogWsServer.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/api/DatasetBuildLogWsServer.java @@ -17,8 +17,8 @@ package ai.starwhale.mlops.api; import ai.starwhale.mlops.common.IdConverter; -import ai.starwhale.mlops.schedule.k8s.log.CancellableJobLogCollector; -import ai.starwhale.mlops.schedule.k8s.log.CancellableJobLogK8sCollectorFactory; +import ai.starwhale.mlops.schedule.impl.k8s.log.CancellableJobLogK8sCollectorFactory; +import ai.starwhale.mlops.schedule.log.TaskLogStreamingCollector; import io.kubernetes.client.openapi.ApiException; import java.io.IOException; import java.util.concurrent.ExecutorService; @@ -51,7 +51,7 @@ public class DatasetBuildLogWsServer { private Long id; - private CancellableJobLogCollector logCollector; + private TaskLogStreamingCollector logCollector; @Autowired @@ -79,7 +79,7 @@ public void onOpen(Session session, @PathParam("name") String name, @PathParam(" String line; while (true) { try { - if ((line = logCollector.readLine()) == null) { + if ((line = logCollector.readLine(10L)) == null) { break; } sendMessage(line); diff --git a/server/controller/src/main/java/ai/starwhale/mlops/api/TaskLogWsServer.java b/server/controller/src/main/java/ai/starwhale/mlops/api/TaskLogWsServer.java index ddb0de7bd0..a8d1a5e7ac 100644 --- 
a/server/controller/src/main/java/ai/starwhale/mlops/api/TaskLogWsServer.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/api/TaskLogWsServer.java @@ -17,9 +17,11 @@ package ai.starwhale.mlops.api; import ai.starwhale.mlops.common.IdConverter; -import ai.starwhale.mlops.schedule.k8s.log.CancellableJobLogCollector; -import ai.starwhale.mlops.schedule.k8s.log.CancellableJobLogK8sCollectorFactory; -import io.kubernetes.client.openapi.ApiException; +import ai.starwhale.mlops.domain.job.step.bo.Step; +import ai.starwhale.mlops.domain.task.bo.Task; +import ai.starwhale.mlops.exception.StarwhaleException; +import ai.starwhale.mlops.schedule.log.TaskLogCollectorFactory; +import ai.starwhale.mlops.schedule.log.TaskLogStreamingCollector; import java.io.IOException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -43,7 +45,7 @@ public class TaskLogWsServer { private static IdConverter idConvertor; - private static CancellableJobLogK8sCollectorFactory logCollectorFactory; + private static TaskLogCollectorFactory taskLogCollectorFactory; private Session session; @@ -51,7 +53,7 @@ public class TaskLogWsServer { private Long id; - private CancellableJobLogCollector logCollector; + private TaskLogStreamingCollector logCollector; @Autowired @@ -60,18 +62,19 @@ public void setIdConvertor(IdConverter idConvertor) { } @Autowired - public void setLogCollectorFactory(CancellableJobLogK8sCollectorFactory factory) { - TaskLogWsServer.logCollectorFactory = factory; + public void setTaskLogCollectorFactory(TaskLogCollectorFactory taskLogCollectorFactory) { + TaskLogWsServer.taskLogCollectorFactory = taskLogCollectorFactory; } + @OnOpen public void onOpen(Session session, @PathParam("taskId") String taskId) { this.session = session; this.readerId = session.getId(); this.id = idConvertor.revert(taskId); try { - logCollector = logCollectorFactory.make(taskId); - } catch (IOException | ApiException e) { + logCollector = 
taskLogCollectorFactory.streamingCollector(Task.builder().id(id).step(new Step()).build()); + } catch (StarwhaleException e) { log.error("make k8s log collector failed", e); } log.info("Task log ws opened. reader={}, task={}", readerId, id); @@ -79,7 +82,7 @@ public void onOpen(Session session, @PathParam("taskId") String taskId) { String line; while (true) { try { - if ((line = logCollector.readLine()) == null) { + if ((line = logCollector.readLine(null)) == null) { break; } sendMessage(line); diff --git a/server/controller/src/main/java/ai/starwhale/mlops/api/protocol/job/ModelServingStatusVo.java b/server/controller/src/main/java/ai/starwhale/mlops/api/protocol/job/ModelServingStatusVo.java index ea0ff92bac..2f1fdf57c2 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/api/protocol/job/ModelServingStatusVo.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/api/protocol/job/ModelServingStatusVo.java @@ -16,7 +16,7 @@ package ai.starwhale.mlops.api.protocol.job; -import ai.starwhale.mlops.schedule.k8s.ResourceEventHolder; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceEventHolder; import com.fasterxml.jackson.annotation.JsonProperty; import io.swagger.v3.oas.annotations.media.Schema; import java.io.Serializable; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/configuration/schedule/ScheduleConfig.java b/server/controller/src/main/java/ai/starwhale/mlops/configuration/schedule/ScheduleConfig.java index 728e215a4c..00a733229d 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/configuration/schedule/ScheduleConfig.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/configuration/schedule/ScheduleConfig.java @@ -22,7 +22,7 @@ public class ScheduleConfig { /** - * when system receive kill signal when there are scheduled tasks running, system should wait for running tasks to + * when system receive kill signal and there are scheduled tasks running, system should wait for running tasks to * be done to 
exit * */ diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/dataset/DatasetService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/dataset/DatasetService.java index 0635897cf6..a939490a06 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/dataset/DatasetService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/dataset/DatasetService.java @@ -16,8 +16,8 @@ package ai.starwhale.mlops.domain.dataset; -import static ai.starwhale.mlops.schedule.k8s.ResourceOverwriteSpec.RESOURCE_CPU; -import static ai.starwhale.mlops.schedule.k8s.ResourceOverwriteSpec.RESOURCE_MEMORY; +import static ai.starwhale.mlops.schedule.impl.k8s.ResourceOverwriteSpec.RESOURCE_CPU; +import static ai.starwhale.mlops.schedule.impl.k8s.ResourceOverwriteSpec.RESOURCE_MEMORY; import static cn.hutool.core.util.BooleanUtil.toInt; import ai.starwhale.mlops.api.protocol.dataset.DatasetInfoVo; @@ -75,10 +75,10 @@ import ai.starwhale.mlops.exception.SwValidationException; import ai.starwhale.mlops.exception.SwValidationException.ValidSubject; import ai.starwhale.mlops.exception.api.StarwhaleApiException; -import ai.starwhale.mlops.schedule.k8s.ContainerOverwriteSpec; -import ai.starwhale.mlops.schedule.k8s.K8sClient; -import ai.starwhale.mlops.schedule.k8s.K8sJobTemplate; -import ai.starwhale.mlops.schedule.k8s.ResourceOverwriteSpec; +import ai.starwhale.mlops.schedule.impl.k8s.ContainerOverwriteSpec; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceOverwriteSpec; import ai.starwhale.mlops.storage.StorageAccessService; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; @@ -607,7 +607,7 @@ var record = buildRecordMapper.selectById(id); return new String(inputStream.readAllBytes(), StandardCharsets.UTF_8); } catch (IOException e) { throw new SwProcessException(ErrorType.DB, - 
MessageFormat.format("read build log path failed {}", id), e + MessageFormat.format("read build log path failed {0}", id), e ); } } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/dataset/build/log/BuildLogCollector.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/dataset/build/log/BuildLogCollector.java index ebd4164369..764ceca455 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/dataset/build/log/BuildLogCollector.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/dataset/build/log/BuildLogCollector.java @@ -20,8 +20,8 @@ import ai.starwhale.mlops.exception.StarwhaleException; import ai.starwhale.mlops.exception.SwProcessException; import ai.starwhale.mlops.exception.SwProcessException.ErrorType; -import ai.starwhale.mlops.schedule.k8s.K8sClient; -import ai.starwhale.mlops.schedule.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; import ai.starwhale.mlops.storage.StorageAccessService; import io.kubernetes.client.openapi.ApiException; import io.kubernetes.client.openapi.models.V1Pod; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/job/EnvService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/job/EnvService.java index 07783f6565..2e333f26a1 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/job/EnvService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/job/EnvService.java @@ -17,7 +17,7 @@ package ai.starwhale.mlops.domain.job; import ai.starwhale.mlops.api.protocol.runtime.DeviceVo; -import ai.starwhale.mlops.schedule.k8s.ResourceOverwriteSpec; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceOverwriteSpec; import java.util.ArrayList; import java.util.List; import org.springframework.stereotype.Service; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/job/ModelServingService.java 
b/server/controller/src/main/java/ai/starwhale/mlops/domain/job/ModelServingService.java index 13f82a26bb..991878f75d 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/job/ModelServingService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/job/ModelServingService.java @@ -43,10 +43,10 @@ import ai.starwhale.mlops.exception.SwProcessException; import ai.starwhale.mlops.exception.SwValidationException; import ai.starwhale.mlops.exception.api.StarwhaleApiException; -import ai.starwhale.mlops.schedule.k8s.K8sClient; -import ai.starwhale.mlops.schedule.k8s.K8sJobTemplate; -import ai.starwhale.mlops.schedule.k8s.ResourceEventHolder; -import ai.starwhale.mlops.schedule.k8s.ResourceOverwriteSpec; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceEventHolder; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceOverwriteSpec; import com.fasterxml.jackson.core.JsonProcessingException; import io.kubernetes.client.custom.IntOrString; import io.kubernetes.client.informer.ResourceEventHandler; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/job/cache/JobLoader.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/job/cache/JobLoader.java index de83b1de22..a10f5b3319 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/job/cache/JobLoader.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/job/cache/JobLoader.java @@ -23,6 +23,7 @@ import ai.starwhale.mlops.domain.task.status.WatchableTask; import ai.starwhale.mlops.domain.task.status.WatchableTaskFactory; import ai.starwhale.mlops.schedule.SwTaskScheduler; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; import java.util.Collection; import java.util.List; import java.util.stream.Collectors; @@ -44,11 +45,14 @@ public class JobLoader { final SwTaskScheduler swTaskScheduler; + final 
TaskReportReceiver taskReportReceiver; + public JobLoader(HotJobHolder jobHolder, WatchableTaskFactory watchableTaskFactory, - SwTaskScheduler swTaskScheduler) { + SwTaskScheduler swTaskScheduler, TaskReportReceiver taskReportReceiver) { this.jobHolder = jobHolder; this.watchableTaskFactory = watchableTaskFactory; this.swTaskScheduler = swTaskScheduler; + this.taskReportReceiver = taskReportReceiver; } public Job load(@NotNull Job job, Boolean resumePausedOrFailTasks) { @@ -91,7 +95,7 @@ void scheduleReadyTasks(Collection tasks) { if (CollectionUtils.isEmpty(tasks)) { return; } - swTaskScheduler.schedule(tasks); + swTaskScheduler.schedule(tasks, taskReportReceiver); } } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/runtime/RuntimeRegistryListener.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/runtime/RuntimeRegistryListener.java index 7a917411b3..1d6e747149 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/runtime/RuntimeRegistryListener.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/runtime/RuntimeRegistryListener.java @@ -19,7 +19,7 @@ import ai.starwhale.mlops.configuration.DockerSetting; import ai.starwhale.mlops.domain.system.SystemSetting; import ai.starwhale.mlops.domain.system.SystemSettingListener; -import ai.starwhale.mlops.schedule.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; import cn.hutool.json.JSONUtil; import io.kubernetes.client.openapi.ApiException; import io.kubernetes.client.openapi.models.V1ObjectMeta; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/runtime/RuntimeService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/runtime/RuntimeService.java index fc5c54efde..817accecd6 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/runtime/RuntimeService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/runtime/RuntimeService.java @@ -76,9 +76,9 @@ import 
ai.starwhale.mlops.exception.SwValidationException; import ai.starwhale.mlops.exception.SwValidationException.ValidSubject; import ai.starwhale.mlops.exception.api.StarwhaleApiException; -import ai.starwhale.mlops.schedule.k8s.ContainerOverwriteSpec; -import ai.starwhale.mlops.schedule.k8s.K8sClient; -import ai.starwhale.mlops.schedule.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.impl.k8s.ContainerOverwriteSpec; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; import ai.starwhale.mlops.storage.StorageAccessService; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/system/SystemService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/system/SystemService.java index 32d2453eca..832cdb12ce 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/system/SystemService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/system/SystemService.java @@ -71,7 +71,7 @@ public LatestVersionVo getLatestVersion() { } public List listResourcePools() { - return systemSettingService.getResourcePools(); + return systemSettingService.getResourcePoolsFromWeb(); } public void updateResourcePools(List resourcePools) { diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/system/SystemSettingService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/system/SystemSettingService.java index 06d984d22a..b99bdb8046 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/system/SystemSettingService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/system/SystemSettingService.java @@ -111,13 +111,17 @@ public ResourcePool queryResourcePool(String rpName) { .orElse(ResourcePool.defaults()); } - public List getResourcePools() { + public List getResourcePoolsFromWeb() { User user = userService.currentUserDetail(); - var 
pools = CollectionUtils.isEmpty(this.systemSetting.getResourcePoolSetting()) - ? List.of(ResourcePool.defaults()) : this.systemSetting.getResourcePoolSetting(); + var pools = getAllResourcePools(); return pools.stream().filter(rp -> rp.allowUser(user.getId())).collect(Collectors.toList()); } + public List getAllResourcePools() { + return CollectionUtils.isEmpty(this.systemSetting.getResourcePoolSetting()) + ? List.of(ResourcePool.defaults()) : this.systemSetting.getResourcePoolSetting(); + } + public void updateResourcePools(List resourcePools) { this.systemSetting.setResourcePoolSetting(resourcePools); systemSettingMapper.put(querySetting()); diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/system/resourcepool/bo/ResourcePool.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/system/resourcepool/bo/ResourcePool.java index 01a96bb9c9..4a950b0e52 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/system/resourcepool/bo/ResourcePool.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/system/resourcepool/bo/ResourcePool.java @@ -18,7 +18,7 @@ import ai.starwhale.mlops.common.Constants; import ai.starwhale.mlops.domain.runtime.RuntimeResource; -import ai.starwhale.mlops.schedule.k8s.ResourceOverwriteSpec; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceOverwriteSpec; import java.io.IOException; import java.util.ArrayList; import java.util.List; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/TaskService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/task/TaskService.java index 534ee89d4d..537581f140 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/TaskService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/task/TaskService.java @@ -107,7 +107,7 @@ public String logContent(Long taskId, String logFileName) { return new String(inputStream.readAllBytes(), StandardCharsets.UTF_8); } catch (IOException e) 
{ throw new SwProcessException(ErrorType.DB, - MessageFormat.format("read log path from db failed {}", taskId), + MessageFormat.format("read log path from db failed {0}", taskId), e); } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/bo/Task.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/task/bo/Task.java index 0b1d90d0dd..e84fccb2dd 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/bo/Task.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/task/bo/Task.java @@ -98,6 +98,9 @@ public boolean equals(Object obj) { return false; } Task tsk = (Task) obj; + if (null != id) { + return this.id.equals(tsk.id); + } return this.uuid.equals(tsk.uuid); } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/TaskStatusMachine.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/TaskStatusMachine.java index a9a29f2764..628bcb6ca4 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/TaskStatusMachine.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/TaskStatusMachine.java @@ -54,6 +54,17 @@ public boolean couldTransfer(TaskStatus statusNow, TaskStatus statusNew) { return transferMap.get(statusNow).contains(statusNew); } + public TaskStatus transfer(TaskStatus statusNow, TaskStatus statusNew) { + if (statusNow == CANCELLING) { + if (isFinal(statusNew)) { + return CANCELED; + } else { + return CANCELLING; + } + } + return statusNew; + } + public boolean isFinal(TaskStatus status) { return transferMap.get(status).isEmpty(); } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/WatchableTask.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/WatchableTask.java index 13ea54476c..7c1e8cdac3 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/WatchableTask.java +++ 
b/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/WatchableTask.java @@ -65,7 +65,7 @@ public void updateStatus(TaskStatus status) { oldStatus, status, originalTask.getId()); return; } - originalTask.updateStatus(status); + originalTask.updateStatus(taskStatusMachine.transfer(oldStatus, status)); log.debug("task status changed from {} to {} of id {}", oldStatus, status, originalTask.getId()); watchers.stream().filter(w -> { if (TaskStatusChangeWatcher.SKIPPED_WATCHERS.get() == null) { diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/watchers/TaskWatcherForSchedule.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/watchers/TaskWatcherForSchedule.java index 75a3c18b0d..6591ad5299 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/watchers/TaskWatcherForSchedule.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/watchers/TaskWatcherForSchedule.java @@ -21,6 +21,8 @@ import ai.starwhale.mlops.domain.task.status.TaskStatusChangeWatcher; import ai.starwhale.mlops.domain.task.status.TaskStatusMachine; import ai.starwhale.mlops.schedule.SwTaskScheduler; +import ai.starwhale.mlops.schedule.log.TaskLogSaver; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; import java.util.ArrayList; import java.util.List; import java.util.concurrent.DelayQueue; @@ -38,7 +40,7 @@ @Order(6) public class TaskWatcherForSchedule implements TaskStatusChangeWatcher { - final SwTaskScheduler taskScheduler; + final SwTaskScheduler swTaskScheduler; final TaskStatusMachine taskStatusMachine; @@ -46,14 +48,20 @@ public class TaskWatcherForSchedule implements TaskStatusChangeWatcher { final DelayQueue taskToDeletes; + private final TaskLogSaver taskLogSaver; + + final TaskReportReceiver taskReportReceiver; + public TaskWatcherForSchedule( - SwTaskScheduler taskScheduler, + SwTaskScheduler swTaskScheduler, TaskStatusMachine taskStatusMachine, - 
@Value("${sw.task.deletion-delay-minutes}") Long deletionDelayMinutes - ) { - this.taskScheduler = taskScheduler; + @Value("${sw.task.deletion-delay-minutes}") Long deletionDelayMinutes, + TaskLogSaver taskLogSaver, TaskReportReceiver taskReportReceiver) { + this.swTaskScheduler = swTaskScheduler; this.taskStatusMachine = taskStatusMachine; this.deletionDelayMilliseconds = TimeUnit.MILLISECONDS.convert(deletionDelayMinutes, TimeUnit.MINUTES); + this.taskLogSaver = taskLogSaver; + this.taskReportReceiver = taskReportReceiver; this.taskToDeletes = new DelayQueue<>(); } @@ -61,18 +69,20 @@ public TaskWatcherForSchedule( public void onTaskStatusChange(Task task, TaskStatus oldStatus) { if (task.getStatus() == TaskStatus.READY) { log.debug("task status changed to ready id: {} oldStatus: {}, scheduled", task.getId(), oldStatus); - taskScheduler.schedule(List.of(task)); + swTaskScheduler.schedule(List.of(task), taskReportReceiver); } else if (task.getStatus() == TaskStatus.CANCELLING || task.getStatus() == TaskStatus.PAUSED) { - taskScheduler.stop(List.of(task)); + taskLogSaver.saveLog(task); + swTaskScheduler.stop(List.of(task)); log.debug("task status changed to {} with id: {} newStatus: {}, stop scheduled immediately", task.getStatus(), task.getId(), task.getStatus()); - } else if (task.getStatus() == TaskStatus.SUCCESS || task.getStatus() == TaskStatus.FAIL) { + } else if (taskStatusMachine.isFinal(task.getStatus())) { log.debug("task status changed to {} with id: {} newStatus: {}, stop scheduled in delayed queue", task.getStatus(), task.getId(), task.getStatus()); + taskLogSaver.saveLog(task); if (deletionDelayMilliseconds <= 0) { - taskScheduler.stop(List.of(task)); + swTaskScheduler.stop(List.of(task)); } else { addToDeleteQueue(task); } @@ -97,7 +107,7 @@ public void processTaskDeletion() { log.debug("delete task {}", toDelete.getTask().getId()); toDelete = taskToDeletes.poll(); } - taskScheduler.stop(tasks); + swTaskScheduler.stop(tasks); } static class 
TaskToDelete implements Delayed { diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/upgrade/UpgradeService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/upgrade/UpgradeService.java index 5f283eb086..6cd64115cf 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/upgrade/UpgradeService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/upgrade/UpgradeService.java @@ -29,7 +29,7 @@ import ai.starwhale.mlops.exception.SwProcessException.ErrorType; import ai.starwhale.mlops.exception.SwValidationException; import ai.starwhale.mlops.exception.SwValidationException.ValidSubject; -import ai.starwhale.mlops.schedule.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; import cn.hutool.core.util.IdUtil; import cn.hutool.core.util.StrUtil; import io.kubernetes.client.openapi.ApiException; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/upgrade/step/UpdateK8sImage.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/upgrade/step/UpdateK8sImage.java index 99df01a005..88002c1d38 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/upgrade/step/UpdateK8sImage.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/upgrade/step/UpdateK8sImage.java @@ -21,7 +21,7 @@ import ai.starwhale.mlops.domain.upgrade.bo.Upgrade; import ai.starwhale.mlops.exception.SwProcessException; import ai.starwhale.mlops.exception.SwProcessException.ErrorType; -import ai.starwhale.mlops.schedule.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; import cn.hutool.json.JSONUtil; import io.kubernetes.client.custom.V1Patch; import io.kubernetes.client.openapi.ApiException; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/SwSchedulerAbstractFactory.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/SwSchedulerAbstractFactory.java new file mode 100644 index 0000000000..a460456314 --- /dev/null 
+++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/SwSchedulerAbstractFactory.java @@ -0,0 +1,32 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.starwhale.mlops.schedule; + + +import ai.starwhale.mlops.schedule.log.TaskLogCollectorFactory; + +/** + * This interface produces a family of objects that are sufficient to implement the scheduler function + */ +public interface SwSchedulerAbstractFactory { + + SwTaskScheduler buildSwTaskScheduler(); + + TaskLogCollectorFactory buildTaskLogCollectorFactory(); + + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/SwTaskScheduler.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/SwTaskScheduler.java index 67cc3255cb..e0317eb99a 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/SwTaskScheduler.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/SwTaskScheduler.java @@ -17,26 +17,33 @@ package ai.starwhale.mlops.schedule; import ai.starwhale.mlops.domain.task.bo.Task; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; import java.util.Collection; import java.util.concurrent.Future; /** * schedule tasks of jobs + * Spring framework has a class named TaskScheduler with a bean in the context. 
To avoid bean conflict the name here is + * SwTaskScheduler */ public interface SwTaskScheduler { /** - * scheduler should maintain the tasks to be scheduled - * - * @param tasks tasks to be scheduled + * @param tasks the tasks to be scheduled + * @param taskReportReceiver the report receiver that could process these tasks */ - void schedule(Collection tasks); + void schedule(Collection tasks, TaskReportReceiver taskReportReceiver); /** - * @param tasks tasks to be stop scheduled + * @param tasks tasks to be stopped */ void stop(Collection tasks); - Future exec(Task task, String ...command); + /** + * @param task the tasks to exec on + * @param command command may be wrapped with "sh -c" by implementations + * @return stdout, stderr in String[] + */ + Future exec(Task task, String... command); } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/TaskCommandGetter.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/TaskCommandGetter.java new file mode 100644 index 0000000000..2a4a1e627f --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/TaskCommandGetter.java @@ -0,0 +1,44 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.starwhale.mlops.schedule; + +import ai.starwhale.mlops.domain.task.bo.Task; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.springframework.stereotype.Service; + +@Service +public class TaskCommandGetter { + + public TaskCommand getCmd(Task task) { + //TODO get the real command of tasks according to schema of task + return TaskCommand.builder().cmd(new String[]{"run"}).build(); + } + + @Data + @Builder + @NoArgsConstructor + @AllArgsConstructor + public static class TaskCommand { + + String[] cmd; + String[] entrypoint; + } + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/TaskRunningEnvBuilder.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/TaskRunningEnvBuilder.java new file mode 100644 index 0000000000..027d5c2b00 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/TaskRunningEnvBuilder.java @@ -0,0 +1,156 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.starwhale.mlops.schedule; + + +import ai.starwhale.mlops.configuration.RunTimeProperties; +import ai.starwhale.mlops.configuration.security.TaskTokenValidator; +import ai.starwhale.mlops.domain.job.bo.Job; +import ai.starwhale.mlops.domain.runtime.RuntimeResource; +import ai.starwhale.mlops.domain.task.bo.Task; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceOverwriteSpec; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + + +@Service +public class TaskRunningEnvBuilder { + + static final String FORMATTER_URI_ARTIFACT = "%s/project/%s/%s/%s/version/%s"; + static final String FORMATTER_VERSION_ARTIFACT = "%s/version/%s"; + final String instanceUri; + final int devPort; + final int datasetLoadBatchSize; + final RunTimeProperties runTimeProperties; + final TaskTokenValidator taskTokenValidator; + + public TaskRunningEnvBuilder( + @Value("${sw.instance-uri}") String instanceUri, + @Value("${sw.task.dev-port}") int devPort, + @Value("${sw.dataset.load.batch-size}") int datasetLoadBatchSize, + RunTimeProperties runTimeProperties, + TaskTokenValidator taskTokenValidator + ) { + this.instanceUri = instanceUri; + this.devPort = devPort; + this.datasetLoadBatchSize = datasetLoadBatchSize; + this.runTimeProperties = runTimeProperties; + this.taskTokenValidator = taskTokenValidator; + } + + + public Map buildCoreContainerEnvs(Task task) { + Job swJob = task.getStep().getJob(); + var model = swJob.getModel(); + var runtime = swJob.getJobRuntime(); + Map coreContainerEnvs = new HashMap<>(); + var taskEnv = task.getTaskRequest().getEnv(); + if (!CollectionUtils.isEmpty(taskEnv)) { + taskEnv.forEach(env -> coreContainerEnvs.put(env.getName(), env.getValue())); + } + 
coreContainerEnvs.put("SW_RUNTIME_PYTHON_VERSION", runtime.getManifest().getEnvironment().getPython()); + coreContainerEnvs.put("SW_VERSION", runtime.getManifest().getEnvironment().getLock().getSwVersion()); + coreContainerEnvs.put("SW_TASK_STEP", task.getStep().getName()); + coreContainerEnvs.put("DATASET_CONSUMPTION_BATCH_SIZE", String.valueOf(datasetLoadBatchSize)); + // support multi dataset uris + coreContainerEnvs.put("SW_DATASET_URI", + swJob.getDataSets().stream() + .map(dataSet -> String.format( + FORMATTER_URI_ARTIFACT, + instanceUri, + dataSet.getProjectId(), + "dataset", + dataSet.getName(), + dataSet.getVersion()) + ).collect(Collectors.joining(" "))); + coreContainerEnvs.put("SW_MODEL_URI", + String.format( + FORMATTER_URI_ARTIFACT, + instanceUri, + model.getProjectId(), + "model", + model.getName(), + model.getVersion())); + coreContainerEnvs.put("SW_RUNTIME_URI", + String.format( + FORMATTER_URI_ARTIFACT, + instanceUri, + runtime.getProjectId(), + "runtime", + runtime.getName(), + runtime.getVersion())); + coreContainerEnvs.put("SW_MODEL_VERSION", + String.format(FORMATTER_VERSION_ARTIFACT, + model.getName(), model.getVersion())); + coreContainerEnvs.put("SW_RUNTIME_VERSION", + String.format(FORMATTER_VERSION_ARTIFACT, + runtime.getName(), runtime.getVersion())); + coreContainerEnvs.put("SW_RUN_HANDLER", task.getTaskRequest().getJobName()); + coreContainerEnvs.put("SW_TASK_INDEX", String.valueOf(task.getTaskRequest().getIndex())); + coreContainerEnvs.put("SW_TASK_NUM", String.valueOf(task.getTaskRequest().getTotal())); + coreContainerEnvs.put("SW_JOB_VERSION", swJob.getUuid()); + + // datastore env + coreContainerEnvs.put("SW_TOKEN", taskTokenValidator.getTaskToken(swJob.getOwner(), task.getId())); + coreContainerEnvs.put("SW_INSTANCE_URI", instanceUri); + coreContainerEnvs.put("SW_PROJECT", swJob.getProject().getName()); + coreContainerEnvs.put("SW_PYPI_INDEX_URL", runTimeProperties.getPypi().getIndexUrl()); + 
coreContainerEnvs.put("SW_PYPI_EXTRA_INDEX_URL", runTimeProperties.getPypi().getExtraIndexUrl()); + coreContainerEnvs.put("SW_PYPI_TRUSTED_HOST", runTimeProperties.getPypi().getTrustedHost()); + coreContainerEnvs.put("SW_PYPI_TIMEOUT", String.valueOf(runTimeProperties.getPypi().getTimeout())); + coreContainerEnvs.put("SW_PYPI_RETRIES", String.valueOf(runTimeProperties.getPypi().getRetries())); + if (StringUtils.hasText(runTimeProperties.getCondarc())) { + coreContainerEnvs.put("SW_CONDA_CONFIG", runTimeProperties.getCondarc()); + } + + // GPU resource + var resources = deviceResourceRequirements(task).stream(); + //TODO remove ResourceOverwriteSpec.RESOURCE_GPU dependency on k8s impl + var gpu = resources.anyMatch(r -> r.getType().equals(ResourceOverwriteSpec.RESOURCE_GPU) && r.getRequest() > 0); + // overwrite visible devices to none + if (!gpu) { + // https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#gpu-enumeration + coreContainerEnvs.put("NVIDIA_VISIBLE_DEVICES", ""); + } + + if (swJob.isDevMode()) { + coreContainerEnvs.put("SW_DEV_TOKEN", swJob.getDevPassword()); + coreContainerEnvs.put("SW_DEV_PORT", String.valueOf(devPort)); + } + + return coreContainerEnvs; + } + + public List deviceResourceRequirements(Task task) { + List runtimeResources = task.getTaskRequest().getRuntimeResources(); + var pool = task.getStep().getResourcePool(); + if (pool == null) { + // use resource pool of job, for backward compatibility + pool = task.getStep().getJob().getResourcePool(); + } + if (pool != null) { + runtimeResources = pool.patchResources(runtimeResources); + } + return runtimeResources; + } +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/BeanConfig.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/BeanConfig.java new file mode 100644 index 0000000000..8cdb1654b4 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/BeanConfig.java 
@@ -0,0 +1,66 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.starwhale.mlops.schedule.impl.docker; + +import ai.starwhale.mlops.domain.system.SystemSettingService; +import ai.starwhale.mlops.domain.task.status.TaskStatusMachine; +import ai.starwhale.mlops.schedule.impl.docker.reporting.ContainerStatusExplainer; +import ai.starwhale.mlops.schedule.impl.docker.reporting.DockerTaskReporter; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration("dockerSchedulerBeanConfiguration") +@ConditionalOnProperty(value = "sw.scheduler", havingValue = "docker") +public class BeanConfig { + + @Bean + public DockerClientFinder dockerClientFinder() { + return new DockerClientFinderSimpleImpl(); + } + + @Bean + public ContainerTaskMapper containerTaskMapper(DockerClientFinder dockerClientFinder) { + return new ContainerTaskMapper(dockerClientFinder); + } + + @Bean + public DockerTaskReporter taskReporter( + TaskReportReceiver taskReportReceiver, + SystemSettingService systemSettingService, + DockerClientFinder dockerClientFinder, + ContainerStatusExplainer containerStatusExplainer, + TaskStatusMachine taskStatusMachine, + ContainerTaskMapper containerTaskMapper + ) { + 
return new DockerTaskReporter( + taskReportReceiver, + systemSettingService, + dockerClientFinder, + containerTaskMapper, + containerStatusExplainer, + taskStatusMachine + ); + } + + @Bean + public ContainerStatusExplainer containerStatusExplainer() { + return new ContainerStatusExplainer(); + } + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/ContainerTaskMapper.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/ContainerTaskMapper.java new file mode 100644 index 0000000000..3e92b2e190 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/ContainerTaskMapper.java @@ -0,0 +1,63 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.starwhale.mlops.schedule.impl.docker; + +import ai.starwhale.mlops.domain.task.bo.Task; +import cn.hutool.core.util.StrUtil; +import com.github.dockerjava.api.model.Container; +import java.util.List; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; + +@Slf4j +public class ContainerTaskMapper { + + static final String CONTAINER_LABEL_TASK_ID = "starwhale-task-id"; + + final DockerClientFinder dockerClientFinder; + + public ContainerTaskMapper(DockerClientFinder dockerClientFinder) { + this.dockerClientFinder = dockerClientFinder; + } + + public String containerName(Task task) { + return String.format("starwhale-task-%d-%d", task.getId(), System.currentTimeMillis()); + } + + public Container containerOfTask(Task task) { + List containers = dockerClientFinder.findProperDockerClient(task.getStep().getResourcePool()) + .listContainersCmd().withShowAll(true) + .withLabelFilter(Map.of(CONTAINER_LABEL_TASK_ID, task.getId().toString())).exec(); + if (CollectionUtils.isEmpty(containers)) { + return null; + } + if (containers.size() > 1) { + log.warn("multiple containers found for task {}", task.getId()); + } + return containers.get(0); + } + + public Long taskIfOfContainer(Container container) { + String taskId = container.getLabels().get(CONTAINER_LABEL_TASK_ID); + if (null != taskId && StrUtil.isNumeric(taskId)) { + return Long.valueOf(taskId); + } + return null; + } + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/DockerClientFinder.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/DockerClientFinder.java new file mode 100644 index 0000000000..28a0c7c69e --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/DockerClientFinder.java @@ -0,0 +1,33 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. 
/**
 * Strategy for choosing the docker daemon, represented by a {@link DockerClient},
 * that serves a given resource pool.
 */
public interface DockerClientFinder {

    /**
     * Finds the unique DockerClient for a specific resource pool.
     * Implementations must return a consistent DockerClient across calls for the same resource pool.
     *
     * @param resourcePool the resource pool in which a task is to be scheduled
     * @return the corresponding docker client
     */
    DockerClient findProperDockerClient(ResourcePool resourcePool);

}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.starwhale.mlops.schedule.impl.docker; + +import ai.starwhale.mlops.domain.system.resourcepool.bo.ResourcePool; +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.core.DefaultDockerClientConfig; +import com.github.dockerjava.core.DockerClientImpl; +import com.github.dockerjava.httpclient5.ApacheDockerHttpClient; +import com.github.dockerjava.transport.DockerHttpClient; +import java.time.Duration; + +public class DockerClientFinderSimpleImpl implements DockerClientFinder { + + final DockerClient dockerClient; + + public DockerClientFinderSimpleImpl() { + DefaultDockerClientConfig clientConfig = DefaultDockerClientConfig.createDefaultConfigBuilder() + .withDockerHost("unix:///var/run/docker.sock").build(); + DockerHttpClient httpClient = new ApacheDockerHttpClient.Builder() + .dockerHost(clientConfig.getDockerHost()) + .sslConfig(clientConfig.getSSLConfig()) + .maxConnections(100) + .connectionTimeout(Duration.ofSeconds(30)) + .responseTimeout(Duration.ofSeconds(45)) + .build(); + this.dockerClient = DockerClientImpl.getInstance(clientConfig, httpClient); + } + + @Override + public DockerClient findProperDockerClient(ResourcePool resourcePool) { + return this.dockerClient; + } + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/HostResourceConfigBuilder.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/HostResourceConfigBuilder.java new file mode 100644 index 0000000000..1d08081192 --- /dev/null +++ 
b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/HostResourceConfigBuilder.java @@ -0,0 +1,48 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.starwhale.mlops.schedule.impl.docker; + +import ai.starwhale.mlops.domain.runtime.RuntimeResource; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceOverwriteSpec; +import com.github.dockerjava.api.model.HostConfig; +import java.util.List; + +/** + * https://docs.docker.com/config/containers/resource_constraints/ + */ +public class HostResourceConfigBuilder { + + HostConfig build(List runtimeResources) { + HostConfig hostConfig = HostConfig.newHostConfig(); + runtimeResources.forEach(runtimeResource -> { + if (ResourceOverwriteSpec.RESOURCE_CPU.equals(runtimeResource.getType())) { + // docker has no cpu reservation for a container. So, request is not processed + if (null != runtimeResource.getLimit()) { + hostConfig.withCpuCount(runtimeResource.getLimit().longValue()); + } + } + if (ResourceOverwriteSpec.RESOURCE_MEMORY.equals(runtimeResource.getType())) { + // docker has no memory reservation for a container. 
So, request is not processed + if (null != runtimeResource.getLimit()) { + hostConfig.withMemory(runtimeResource.getLimit().longValue()); + } + } + }); + return hostConfig; + } + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/SwSchedulerFactoryDocker.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/SwSchedulerFactoryDocker.java new file mode 100644 index 0000000000..2a2474f2af --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/SwSchedulerFactoryDocker.java @@ -0,0 +1,78 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.starwhale.mlops.schedule.impl.docker; + +import ai.starwhale.mlops.schedule.SwSchedulerAbstractFactory; +import ai.starwhale.mlops.schedule.SwTaskScheduler; +import ai.starwhale.mlops.schedule.TaskCommandGetter; +import ai.starwhale.mlops.schedule.TaskRunningEnvBuilder; +import ai.starwhale.mlops.schedule.impl.docker.log.TaskLogCollectorFactoryDocker; +import ai.starwhale.mlops.schedule.impl.docker.reporting.DockerTaskReporter; +import ai.starwhale.mlops.schedule.log.TaskLogCollectorFactory; +import java.util.concurrent.Executors; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + + +@Configuration +@ConditionalOnProperty(value = "sw.scheduler", havingValue = "docker") +public class SwSchedulerFactoryDocker implements SwSchedulerAbstractFactory { + + final DockerClientFinder dockerClientFinder; + + final ContainerTaskMapper containerTaskMapper; + + final DockerTaskReporter dockerTaskReporter; + + final TaskRunningEnvBuilder taskRunningEnvBuilder; + + final String network; + final String nodeIp; + final TaskCommandGetter taskCommandGetter; + + public SwSchedulerFactoryDocker(DockerClientFinder dockerClientFinder, ContainerTaskMapper containerTaskMapper, + DockerTaskReporter dockerTaskReporter, + TaskRunningEnvBuilder taskRunningEnvBuilder, @Value("${sw.infra.docker.network}") String network, + @Value("${sw.infra.docker.node-ip}") String nodeIp, + TaskCommandGetter taskCommandGetter) { + this.dockerClientFinder = dockerClientFinder; + this.containerTaskMapper = containerTaskMapper; + this.dockerTaskReporter = dockerTaskReporter; + this.taskRunningEnvBuilder = taskRunningEnvBuilder; + this.network = network; + this.nodeIp = nodeIp; + this.taskCommandGetter = taskCommandGetter; + } + + @Bean + @Override + public SwTaskScheduler 
buildSwTaskScheduler() { + return new SwTaskSchedulerDocker(dockerClientFinder, containerTaskMapper, dockerTaskReporter, + Executors.newCachedThreadPool(), taskRunningEnvBuilder, network, + nodeIp, taskCommandGetter, new HostResourceConfigBuilder()); + } + + @Bean + @Override + public TaskLogCollectorFactory buildTaskLogCollectorFactory() { + return new TaskLogCollectorFactoryDocker(dockerClientFinder, containerTaskMapper); + } + + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/SwTaskSchedulerDocker.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/SwTaskSchedulerDocker.java new file mode 100644 index 0000000000..aeea64d80a --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/SwTaskSchedulerDocker.java @@ -0,0 +1,259 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.starwhale.mlops.schedule.impl.docker; + +import ai.starwhale.mlops.domain.task.bo.Task; +import ai.starwhale.mlops.domain.task.status.TaskStatus; +import ai.starwhale.mlops.schedule.SwTaskScheduler; +import ai.starwhale.mlops.schedule.TaskCommandGetter; +import ai.starwhale.mlops.schedule.TaskCommandGetter.TaskCommand; +import ai.starwhale.mlops.schedule.TaskRunningEnvBuilder; +import ai.starwhale.mlops.schedule.impl.docker.reporting.DockerTaskReporter; +import ai.starwhale.mlops.schedule.reporting.ReportedTask; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.async.ResultCallback; +import com.github.dockerjava.api.command.CreateContainerCmd; +import com.github.dockerjava.api.command.CreateContainerResponse; +import com.github.dockerjava.api.command.ExecCreateCmd; +import com.github.dockerjava.api.command.ExecCreateCmdResponse; +import com.github.dockerjava.api.command.ExecStartCmd; +import com.github.dockerjava.api.exception.DockerException; +import com.github.dockerjava.api.model.Container; +import com.github.dockerjava.api.model.Frame; +import com.github.dockerjava.api.model.PullResponseItem; +import java.io.Closeable; +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.jetbrains.annotations.NotNull; +import org.springframework.util.CollectionUtils; + +@Slf4j +public class SwTaskSchedulerDocker implements SwTaskScheduler { + + public static Map CONTAINER_LABELS = Map.of("owner", "starwhale"); + final DockerClientFinder dockerClientFinder; + final ContainerTaskMapper containerTaskMapper; + final DockerTaskReporter dockerTaskReporter; + final ExecutorService cmdExecThreadPool; + final 
TaskRunningEnvBuilder taskRunningEnvBuilder; + final String network; + final String nodeIp; + final TaskCommandGetter taskCommandGetter; + + final HostResourceConfigBuilder hostResourceConfigBuilder; + + public SwTaskSchedulerDocker(DockerClientFinder dockerClientFinder, ContainerTaskMapper containerTaskMapper, + DockerTaskReporter dockerTaskReporter, ExecutorService cmdExecThreadPool, + TaskRunningEnvBuilder taskRunningEnvBuilder, String network, String nodeIp, + TaskCommandGetter taskCommandGetter, HostResourceConfigBuilder hostResourceConfigBuilder) { + this.dockerClientFinder = dockerClientFinder; + this.containerTaskMapper = containerTaskMapper; + this.dockerTaskReporter = dockerTaskReporter; + this.cmdExecThreadPool = cmdExecThreadPool; + this.taskRunningEnvBuilder = taskRunningEnvBuilder; + this.network = network; + this.nodeIp = nodeIp; + this.taskCommandGetter = taskCommandGetter; + this.hostResourceConfigBuilder = hostResourceConfigBuilder; + } + + @Override + public void schedule(Collection tasks, TaskReportReceiver taskReportReceiver) { + if (CollectionUtils.isEmpty(tasks)) { + return; + } + for (Task task : tasks) { + DockerClient dockerClient = dockerClientFinder.findProperDockerClient( + task.getStep().getResourcePool()); + String image = task.getStep().getJob().getJobRuntime().getImage(); + dockerClient.pullImageCmd(image).exec(new ResultCallback() { + @Override + public void onStart(Closeable closeable) { + ReportedTask rt = ReportedTask.builder() + .id(task.getId()) + .status(TaskStatus.PREPARING) + .startTimeMillis(System.currentTimeMillis()) + .retryCount(0) + .ip(nodeIp) + .build(); + taskReportReceiver.receive(List.of(rt)); + } + + @Override + public void onNext(PullResponseItem object) { + + } + + @Override + public void onError(Throwable throwable) { + log.error("creating container error ", throwable); + ReportedTask rt = ReportedTask.builder() + .id(task.getId()) + .status(TaskStatus.FAIL) + .stopTimeMillis(System.currentTimeMillis()) + 
.retryCount(0) + .failedReason(throwable.getMessage()) + .ip(nodeIp) + .build(); + taskReportReceiver.receive(List.of(rt)); + + } + + @Override + public void onComplete() { + Map labels = new HashMap(); + labels.put(ContainerTaskMapper.CONTAINER_LABEL_TASK_ID, task.getId().toString()); + labels.putAll(CONTAINER_LABELS); + + CreateContainerCmd createContainerCmd = dockerClient.createContainerCmd(image) + .withEnv(buildEnvs(task)) + .withName(containerTaskMapper.containerName(task)) + .withHostConfig(hostResourceConfigBuilder.build( + taskRunningEnvBuilder.deviceResourceRequirements(task)).withNetworkMode(network)) + .withLabels(labels); + TaskCommand taskCommand = taskCommandGetter.getCmd(task); + if (null != taskCommand.getEntrypoint()) { + createContainerCmd.withEntrypoint(taskCommand.getEntrypoint()); + } else if (null != taskCommand.getCmd()) { + createContainerCmd.withCmd(taskCommand.getCmd()); + } + CreateContainerResponse createContainerResponse = createContainerCmd.exec(); + dockerClient.startContainerCmd(createContainerResponse.getId()).exec(); + ReportedTask rt = ReportedTask.builder() + .id(task.getId()) + .status(TaskStatus.RUNNING) + .startTimeMillis(System.currentTimeMillis()) + .retryCount(0) + .ip(nodeIp) + .build(); + taskReportReceiver.receive(List.of(rt)); + } + + @Override + public void close() throws IOException { + + } + }); + } + + + } + + @NotNull + private List buildEnvs(Task task) { + Map containerEnvs = taskRunningEnvBuilder.buildCoreContainerEnvs(task); + List envs = containerEnvs.entrySet().stream().map( + es -> String.format("%s=%s", es.getKey(), es.getValue()) + ).collect(Collectors.toList()); + return envs; + } + + @Override + public void stop(Collection tasks) { + if (CollectionUtils.isEmpty(tasks)) { + return; + } + tasks.forEach(t -> { + DockerClient dockerClient = dockerClientFinder.findProperDockerClient( + t.getStep().getResourcePool()); + Container container = containerTaskMapper.containerOfTask(t); + if (null == container) { + 
return; + } + if ("exited".equalsIgnoreCase(container.getState())) { + try { + dockerClient.removeContainerCmd(container.getId()).withForce(true).withRemoveVolumes(true).exec(); + } catch (DockerException e) { + log.warn("try to remove container with error", e); + } + } else { + try { + dockerClient.killContainerCmd(container.getId()).exec(); + } catch (DockerException e) { + log.warn("try to kill container with error", e); + } + } + + }); + + } + + @Override + public Future exec(Task task, String... command) { + DockerClient dockerClient = dockerClientFinder.findProperDockerClient(task.getStep().getResourcePool()); + var execCommand = List.of("sh", "-c", String.join(" ", command)).toArray(new String[0]); + + ExecCreateCmd execCreateCmd = dockerClient.execCreateCmd(containerTaskMapper.containerOfTask(task).getId()) + .withCmd(execCommand) + .withAttachStdout(true) + .withAttachStderr(true) + .withTty(true); + ExecCreateCmdResponse exec = execCreateCmd.exec(); + ExecStartCmd execStartCmd = dockerClient.execStartCmd(exec.getId()); + Object lock = new Object(); + StringBuilder stringBuilder = new StringBuilder(); + execStartCmd.exec(new ResultCallback() { + @Override + public void onStart(Closeable closeable) { + + } + + @Override + public void onNext(Frame object) { + stringBuilder.append(object.toString()); + } + + @Override + public void onError(Throwable throwable) { + synchronized (lock) { + lock.notifyAll(); + } + } + + @Override + public void onComplete() { + synchronized (lock) { + lock.notifyAll(); + } + } + + @Override + public void close() throws IOException { + synchronized (lock) { + lock.notifyAll(); + } + } + }); + + return cmdExecThreadPool.submit(() -> { + synchronized (lock) { + lock.wait(); + } + + return new String[]{stringBuilder.toString(), ""}; + }); + } + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogCollectorFactoryDocker.java 
b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogCollectorFactoryDocker.java new file mode 100644 index 0000000000..685c70b777 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogCollectorFactoryDocker.java @@ -0,0 +1,48 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.starwhale.mlops.schedule.impl.docker.log; + +import ai.starwhale.mlops.domain.task.bo.Task; +import ai.starwhale.mlops.exception.StarwhaleException; +import ai.starwhale.mlops.schedule.impl.docker.ContainerTaskMapper; +import ai.starwhale.mlops.schedule.impl.docker.DockerClientFinder; +import ai.starwhale.mlops.schedule.log.TaskLogCollectorFactory; +import ai.starwhale.mlops.schedule.log.TaskLogOfflineCollector; +import ai.starwhale.mlops.schedule.log.TaskLogStreamingCollector; + +public class TaskLogCollectorFactoryDocker implements TaskLogCollectorFactory { + + final DockerClientFinder dockerClientFinder; + + final ContainerTaskMapper containerTaskMapper; + + public TaskLogCollectorFactoryDocker(DockerClientFinder dockerClientFinder, + ContainerTaskMapper containerTaskMapper) { + this.dockerClientFinder = dockerClientFinder; + this.containerTaskMapper = containerTaskMapper; + } + + @Override + public TaskLogOfflineCollector offlineCollector(Task task) throws StarwhaleException { + return new 
TaskLogOfflineCollectorDocker(task, dockerClientFinder, containerTaskMapper); + } + + @Override + public TaskLogStreamingCollector streamingCollector(Task task) throws StarwhaleException { + return new TaskLogStreamingCollectorDocker(task, dockerClientFinder, containerTaskMapper); + } +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogOfflineCollectorDocker.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogOfflineCollectorDocker.java new file mode 100644 index 0000000000..35c8622523 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogOfflineCollectorDocker.java @@ -0,0 +1,109 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.starwhale.mlops.schedule.impl.docker.log; + +import ai.starwhale.mlops.domain.task.bo.Task; +import ai.starwhale.mlops.schedule.impl.docker.ContainerTaskMapper; +import ai.starwhale.mlops.schedule.impl.docker.DockerClientFinder; +import ai.starwhale.mlops.schedule.log.TaskLogOfflineCollector; +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.async.ResultCallback; +import com.github.dockerjava.api.command.LogContainerCmd; +import com.github.dockerjava.api.model.Container; +import com.github.dockerjava.api.model.Frame; +import io.vavr.Tuple2; +import java.io.Closeable; +import java.io.IOException; + +public class TaskLogOfflineCollectorDocker implements TaskLogOfflineCollector { + + final Task task; + + final DockerClient dockerClient; + + final DockerClientFinder dockerClientFinder; + + final ContainerTaskMapper containerTaskMapper; + + final Object lock = new Object(); + + StringBuffer logBuffer = new StringBuffer(); + + + public TaskLogOfflineCollectorDocker(Task task, DockerClientFinder dockerClientFinder, + ContainerTaskMapper containerTaskMapper) { + this.task = task; + this.dockerClientFinder = dockerClientFinder; + this.dockerClient = this.dockerClientFinder.findProperDockerClient(task.getStep().getResourcePool()); + this.containerTaskMapper = containerTaskMapper; + } + + @Override + public Tuple2 collect() { + logBuffer = new StringBuffer(); + Container container = this.containerTaskMapper.containerOfTask(task); + if (null == container) { + return null; + } + LogContainerCmd logContainerCmd = dockerClient.logContainerCmd( + container.getId()) + .withStdErr(true) + .withStdOut(true) + .withFollowStream(false); + logContainerCmd.exec(new ResultCallback() { + @Override + public void onStart(Closeable closeable) { + + } + + @Override + public void onNext(Frame object) { + logBuffer.append(object.toString()); + logBuffer.append("\n"); + } + + @Override + public void onError(Throwable throwable) { + synchronized 
(lock) { + lock.notifyAll(); + } + } + + @Override + public void onComplete() { + synchronized (lock) { + lock.notifyAll(); + } + } + + @Override + public void close() throws IOException { + synchronized (lock) { + lock.notifyAll(); + } + } + }); + synchronized (this.lock) { + try { + this.lock.wait(); + return new Tuple2<>(container.getNames()[0], this.logBuffer.toString()); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogStreamingCollectorDocker.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogStreamingCollectorDocker.java new file mode 100644 index 0000000000..fbefca9404 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogStreamingCollectorDocker.java @@ -0,0 +1,134 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.starwhale.mlops.schedule.impl.docker.log; + +import ai.starwhale.mlops.domain.task.bo.Task; +import ai.starwhale.mlops.schedule.impl.docker.ContainerTaskMapper; +import ai.starwhale.mlops.schedule.impl.docker.DockerClientFinder; +import ai.starwhale.mlops.schedule.log.TaskLogStreamingCollector; +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.async.ResultCallback; +import com.github.dockerjava.api.command.LogContainerCmd; +import com.github.dockerjava.api.model.Frame; +import java.io.Closeable; +import java.io.IOException; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; + + +@Slf4j +public class TaskLogStreamingCollectorDocker implements TaskLogStreamingCollector { + + final DockerClient dockerClient; + + final DockerClientFinder dockerClientFinder; + + final ContainerTaskMapper containerTaskMapper; + + final BlockingQueue logLines; + + Closeable closeable; + + Boolean closed = Boolean.FALSE; + + public TaskLogStreamingCollectorDocker(Task task, DockerClientFinder dockerClientFinder, + ContainerTaskMapper containerTaskMapper) { + this.dockerClientFinder = dockerClientFinder; + this.dockerClient = this.dockerClientFinder.findProperDockerClient(task.getStep().getResourcePool()); + this.containerTaskMapper = containerTaskMapper; + this.logLines = new LinkedBlockingQueue<>(); + LogContainerCmd logContainerCmd = dockerClient.logContainerCmd( + this.containerTaskMapper.containerOfTask(task).getId()) + .withStdErr(true) + .withStdOut(true) + .withFollowStream(true); + var that = this; + logContainerCmd.exec(new ResultCallback() { + @Override + public void onStart(Closeable closeable) { + that.closeable = closeable; + } + + @Override + public void onNext(Frame object) { + try { + logLines.put(object.toString()); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + 
public void onError(Throwable throwable) { + signalClose(); + } + + @Override + public void onComplete() { + signalClose(); + } + + @Override + public void close() throws IOException { + signalClose(); + } + + private void signalClose() { + closed = Boolean.TRUE; + try { + logLines.put(""); + } catch (InterruptedException e) { + log.error("putting to logLines failed, readLine thread may starving"); + throw new RuntimeException(e); + } + } + }); + } + + @Override + public String readLine(Long waitTimeSeconds) throws IOException { + if (this.closed) { + return null; + } + try { + if (null == waitTimeSeconds) { + return logLines.take(); + } else { + return this.logLines.poll(waitTimeSeconds, TimeUnit.SECONDS); + } + } catch (InterruptedException e) { + log.error("taking log line Interrupted", e); + return null; + } + } + + @Override + public void cancel() { + try { + if (null != this.closeable) { + this.closeable.close(); + } + this.logLines.clear(); + this.closed = Boolean.TRUE; + } catch (IOException e) { + throw new RuntimeException(e); + } + + } +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/reporting/ContainerStatusExplainer.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/reporting/ContainerStatusExplainer.java new file mode 100644 index 0000000000..e0a119cff0 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/reporting/ContainerStatusExplainer.java @@ -0,0 +1,57 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.starwhale.mlops.schedule.impl.docker.reporting; + +import ai.starwhale.mlops.domain.task.status.TaskStatus; +import com.github.dockerjava.api.model.Container; +import java.util.HashMap; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; + + +@Slf4j +public class ContainerStatusExplainer { + + static final Map STATUS_MAP = new HashMap<>() { + { + put("running", TaskStatus.RUNNING); + put("created", TaskStatus.PREPARING); + put("dead", TaskStatus.FAIL); + put("paused", TaskStatus.RUNNING); + put("restarting", TaskStatus.RUNNING); + } + }; + + public TaskStatus statusOf(Container c) { + String state = c.getState(); + for (var entry : STATUS_MAP.entrySet()) { + if (entry.getKey().equalsIgnoreCase(state)) { + return entry.getValue(); + } + } + if ("exited".equalsIgnoreCase(state)) { + if (c.getStatus().toUpperCase().contains("Exited (0)".toUpperCase())) { + return TaskStatus.SUCCESS; + } + return TaskStatus.FAIL; + } + + log.warn("unexpected docker state detected State:{} Status: {}", state, c.getStatus()); + return TaskStatus.UNKNOWN; + } + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/reporting/DockerTaskReporter.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/reporting/DockerTaskReporter.java new file mode 100644 index 0000000000..a1c7e0a4b7 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/docker/reporting/DockerTaskReporter.java @@ -0,0 +1,97 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.starwhale.mlops.schedule.impl.docker.reporting; + +import ai.starwhale.mlops.domain.system.SystemSettingService; +import ai.starwhale.mlops.domain.system.resourcepool.bo.ResourcePool; +import ai.starwhale.mlops.domain.task.status.TaskStatus; +import ai.starwhale.mlops.domain.task.status.TaskStatusMachine; +import ai.starwhale.mlops.schedule.impl.docker.ContainerTaskMapper; +import ai.starwhale.mlops.schedule.impl.docker.DockerClientFinder; +import ai.starwhale.mlops.schedule.impl.docker.SwTaskSchedulerDocker; +import ai.starwhale.mlops.schedule.reporting.ReportedTask; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.model.Container; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.jetbrains.annotations.Nullable; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.util.CollectionUtils; + +@Slf4j +public class DockerTaskReporter { + + final TaskReportReceiver taskReportReceiver; + + final SystemSettingService systemSettingService; + + final DockerClientFinder dockerClientFinder; + + final ContainerTaskMapper containerTaskMapper; + + final ContainerStatusExplainer containerStatusExplainer; + + final TaskStatusMachine taskStatusMachine; + + public 
DockerTaskReporter(TaskReportReceiver taskReportReceiver, SystemSettingService systemSettingService, + DockerClientFinder dockerClientFinder, + ContainerTaskMapper containerTaskMapper, ContainerStatusExplainer containerStatusExplainer, + TaskStatusMachine taskStatusMachine) { + this.taskReportReceiver = taskReportReceiver; + this.systemSettingService = systemSettingService; + this.dockerClientFinder = dockerClientFinder; + this.containerTaskMapper = containerTaskMapper; + this.containerStatusExplainer = containerStatusExplainer; + this.taskStatusMachine = taskStatusMachine; + } + + @Scheduled(initialDelay = 10000, fixedDelay = 3000) + public void reportTasks() { + + List resourcePools = systemSettingService.getAllResourcePools(); + if (CollectionUtils.isEmpty(resourcePools)) { + resourcePools = List.of(new ResourcePool()); + } + Set distinctDockerClients = resourcePools.stream() + .map(resourcePool -> dockerClientFinder.findProperDockerClient(resourcePool)) + .collect(Collectors.toSet()); + distinctDockerClients.forEach(dockerClient -> { + List containers = dockerClient.listContainersCmd() + .withLabelFilter(SwTaskSchedulerDocker.CONTAINER_LABELS).withShowAll(true).exec(); + taskReportReceiver.receive(containers.stream().map(c -> containerToTaskReport(c)).filter(Objects::nonNull) + .collect(Collectors.toList())); + }); + + } + + public void reportTask(Container c) { + taskReportReceiver.receive(List.of(containerToTaskReport(c))); + } + + @Nullable + private ReportedTask containerToTaskReport(Container c) { + TaskStatus status = containerStatusExplainer.statusOf(c); + Long stopMilli = taskStatusMachine.isFinal(status) ? System.currentTimeMillis() : null; + String failReason = TaskStatus.FAIL == status ? 
c.getStatus() : null; + return new ReportedTask(containerTaskMapper.taskIfOfContainer(c), status, 0, null, null, stopMilli, failReason); + } + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/K8sClientConfig.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/BeanConfig.java similarity index 65% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/K8sClientConfig.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/BeanConfig.java index 6491d5b0f9..ccf533c358 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/K8sClientConfig.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/BeanConfig.java @@ -14,20 +14,43 @@ * limitations under the License. */ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import io.kubernetes.client.informer.SharedInformerFactory; import io.kubernetes.client.openapi.ApiClient; +import io.kubernetes.client.openapi.ApiException; import io.kubernetes.client.openapi.apis.AppsV1Api; import io.kubernetes.client.openapi.apis.BatchV1Api; import io.kubernetes.client.openapi.apis.CoreV1Api; import io.kubernetes.client.util.Config; import java.io.IOException; +import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -@Configuration -public class K8sClientConfig { +@Configuration("k8sSchedulerBeanConfiguration") +//TODO add this line when ds building/ model serving/ runtime building are decoupled from k8s +//@ConditionalOnProperty(value = "sw.scheduler", havingValue = "k8s") +public class BeanConfig { + + @Bean + public K8sClient k8sClient( + ApiClient client, + CoreV1Api coreV1Api, + BatchV1Api batchV1Api, + AppsV1Api appsV1Api, + @Value("${sw.infra.k8s.name-space}") String ns, + SharedInformerFactory informerFactory + ) throws ApiException 
{ + return new K8sClient( + client, + coreV1Api, + batchV1Api, + appsV1Api, + ns, + informerFactory + ); + } @Bean public ApiClient apiClient() throws IOException { @@ -55,4 +78,5 @@ public AppsV1Api appsV1Api(ApiClient apiClient) { public SharedInformerFactory sharedInformerFactory(ApiClient apiClient) { return new SharedInformerFactory(apiClient); } + } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/ContainerOverwriteSpec.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/ContainerOverwriteSpec.java similarity index 96% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/ContainerOverwriteSpec.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/ContainerOverwriteSpec.java index 9a9c1919d9..4e56b518db 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/ContainerOverwriteSpec.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/ContainerOverwriteSpec.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import io.kubernetes.client.openapi.models.V1EnvVar; import io.kubernetes.client.openapi.models.V1Probe; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/EventWatcher.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/EventWatcher.java similarity index 93% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/EventWatcher.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/EventWatcher.java index 4cecb751e9..430982ee0a 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/EventWatcher.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/EventWatcher.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import io.kubernetes.client.informer.ResourceEventHandler; import io.kubernetes.client.openapi.models.CoreV1Event; @@ -22,12 +22,14 @@ import io.kubernetes.client.openapi.models.V1Node; import io.kubernetes.client.openapi.models.V1Pod; import lombok.extern.slf4j.Slf4j; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.event.ApplicationReadyEvent; import org.springframework.context.event.EventListener; import org.springframework.stereotype.Component; @Slf4j @Component +@ConditionalOnProperty(value = "sw.scheduler", havingValue = "k8s") public class EventWatcher { final K8sClient k8sClient; final ResourceEventHandler eventHandlerJob; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/JobEventHandler.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/JobEventHandler.java similarity index 96% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/JobEventHandler.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/JobEventHandler.java index b5b2e731e2..48372656f9 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/JobEventHandler.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/JobEventHandler.java @@ -14,15 +14,15 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import ai.starwhale.mlops.domain.dataset.DatasetService; import ai.starwhale.mlops.domain.dataset.build.BuildStatus; import ai.starwhale.mlops.domain.runtime.RuntimeService; import ai.starwhale.mlops.domain.task.status.TaskStatus; import ai.starwhale.mlops.domain.task.status.TaskStatusMachine; -import ai.starwhale.mlops.reporting.ReportedTask; -import ai.starwhale.mlops.reporting.TaskModifyReceiver; +import ai.starwhale.mlops.schedule.reporting.ReportedTask; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; import io.kubernetes.client.informer.ResourceEventHandler; import io.kubernetes.client.openapi.models.V1Job; import io.kubernetes.client.openapi.models.V1JobCondition; @@ -41,19 +41,19 @@ @Component public class JobEventHandler implements ResourceEventHandler { - private final TaskModifyReceiver taskModifyReceiver; + private final TaskReportReceiver taskReportReceiver; private final TaskStatusMachine taskStatusMachine; private final RuntimeService runtimeService; private final DatasetService datasetService; private final K8sClient k8sClient; public JobEventHandler( - TaskModifyReceiver taskModifyReceiver, + TaskReportReceiver taskReportReceiver, TaskStatusMachine taskStatusMachine, RuntimeService runtimeService, DatasetService datasetService, K8sClient k8sClient) { - this.taskModifyReceiver = taskModifyReceiver; + this.taskReportReceiver = taskReportReceiver; this.taskStatusMachine = taskStatusMachine; this.runtimeService = runtimeService; this.datasetService = datasetService; @@ -224,7 +224,7 @@ private void updateEvalTask(V1Job job, boolean onDelete) { .retryCount(retryNum) .failedReason(StringUtils.hasText(failedReason) ? 
failedReason : null) .build(); - taskModifyReceiver.receive(List.of(report)); + taskReportReceiver.receive(List.of(report)); } private String conditionsLogString(List conditions) { diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/K8sClient.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/K8sClient.java similarity index 96% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/K8sClient.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/K8sClient.java index dd57d910e7..3756ec7dfc 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/K8sClient.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/K8sClient.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import io.kubernetes.client.Exec; import io.kubernetes.client.custom.V1Patch; @@ -48,9 +48,11 @@ import io.kubernetes.client.util.labels.LabelSelector; import java.io.IOException; import java.io.InputStream; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import okhttp3.Call; @@ -59,10 +61,9 @@ import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; @Slf4j -@Component public class K8sClient { private final ApiClient client; @@ -300,12 +301,14 @@ public V1Pod podOfJob(String selector) throws ApiException { if (podList.getItems().isEmpty()) { return null; } - if (podList.getItems().size() > 1) { - throw new ApiException("to many pods"); + var pods = podList.getItems(); + if (CollectionUtils.isEmpty(pods)) { + return null; } - - V1Pod 
pod = podList.getItems().get(0); - return pod; + Optional first = pods.stream() + .sorted(Comparator.comparing(p -> ((V1Pod) p).getMetadata().getCreationTimestamp()).reversed()) + .findFirst(); + return first.get(); } private void appendLog(V1Pod pod, StringBuilder logBuilder, String containerName) { diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/K8sJobTemplate.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/K8sJobTemplate.java similarity index 98% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/K8sJobTemplate.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/K8sJobTemplate.java index 3ddcda187c..581c1a927c 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/K8sJobTemplate.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/K8sJobTemplate.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import ai.starwhale.mlops.domain.system.resourcepool.bo.Toleration; import io.kubernetes.client.custom.IntOrString; @@ -50,6 +50,8 @@ import org.springframework.util.StringUtils; @Component +// TODO add this line when ds building/ model serving/ runtime building are decoupled from k8s +//@ConditionalOnProperty(value = "sw.scheduler", havingValue = "k8s") public class K8sJobTemplate { public static final Map starwhaleJobLabel = Map.of("owner", "starwhale"); diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/K8sTaskScheduler.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/K8sSwTaskScheduler.java similarity index 57% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/K8sTaskScheduler.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/K8sSwTaskScheduler.java index 3fa8696fe4..150ef0bb77 100644 --- 
a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/K8sTaskScheduler.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/K8sSwTaskScheduler.java @@ -14,11 +14,8 @@ * limitations under the License. */ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; -import ai.starwhale.mlops.configuration.RunTimeProperties; -import ai.starwhale.mlops.configuration.security.TaskTokenValidator; -import ai.starwhale.mlops.domain.job.bo.Job; import ai.starwhale.mlops.domain.job.bo.JobRuntime; import ai.starwhale.mlops.domain.runtime.RuntimeResource; import ai.starwhale.mlops.domain.system.resourcepool.bo.Toleration; @@ -26,11 +23,11 @@ import ai.starwhale.mlops.domain.task.status.TaskStatus; import ai.starwhale.mlops.domain.task.status.TaskStatusChangeWatcher; import ai.starwhale.mlops.domain.task.status.watchers.TaskWatcherForSchedule; -import ai.starwhale.mlops.domain.task.status.watchers.log.TaskLogK8sCollector; -import ai.starwhale.mlops.exception.StarwhaleException; import ai.starwhale.mlops.exception.SwProcessException; import ai.starwhale.mlops.exception.SwProcessException.ErrorType; import ai.starwhale.mlops.schedule.SwTaskScheduler; +import ai.starwhale.mlops.schedule.TaskRunningEnvBuilder; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; import ai.starwhale.mlops.storage.StorageAccessService; import cn.hutool.json.JSONUtil; import io.kubernetes.client.openapi.ApiException; @@ -47,78 +44,48 @@ import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.jetbrains.annotations.NotNull; -import org.springframework.beans.factory.annotation.Value; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; -import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; -@Service @Slf4j -public class K8sTaskScheduler implements SwTaskScheduler { +public class 
K8sSwTaskScheduler implements SwTaskScheduler { final K8sClient k8sClient; - final RunTimeProperties runTimeProperties; - - final TaskTokenValidator taskTokenValidator; - final K8sJobTemplate k8sJobTemplate; - final String instanceUri; - final int devPort; - final int datasetLoadBatchSize; + final TaskRunningEnvBuilder taskRunningEnvBuilder; final String restartPolicy; final int backoffLimit; final StorageAccessService storageAccessService; - final ThreadPoolTaskScheduler taskScheduler; - - private final TaskLogK8sCollector taskLogK8sCollector; + final ThreadPoolTaskScheduler cmdExecThreadPool; - public K8sTaskScheduler( + public K8sSwTaskScheduler( K8sClient k8sClient, - TaskTokenValidator taskTokenValidator, - RunTimeProperties runTimeProperties, K8sJobTemplate k8sJobTemplate, - @Value("${sw.instance-uri}") String instanceUri, - @Value("${sw.task.dev-port}") int devPort, - @Value("${sw.dataset.load.batch-size}") int datasetLoadBatchSize, - @Value("${sw.infra.k8s.job.restart-policy}") String restartPolicy, - @Value("${sw.infra.k8s.job.backoff-limit}") Integer backoffLimit, + TaskRunningEnvBuilder taskRunningEnvBuilder, + String restartPolicy, + Integer backoffLimit, StorageAccessService storageAccessService, - TaskLogK8sCollector taskLogK8sCollector, - ThreadPoolTaskScheduler taskScheduler + ThreadPoolTaskScheduler cmdExecThreadPool ) { this.k8sClient = k8sClient; - this.taskTokenValidator = taskTokenValidator; - this.runTimeProperties = runTimeProperties; this.k8sJobTemplate = k8sJobTemplate; - this.instanceUri = instanceUri; - this.devPort = devPort; + this.taskRunningEnvBuilder = taskRunningEnvBuilder; this.storageAccessService = storageAccessService; - this.datasetLoadBatchSize = datasetLoadBatchSize; this.restartPolicy = restartPolicy; this.backoffLimit = backoffLimit; - this.taskLogK8sCollector = taskLogK8sCollector; - this.taskScheduler = taskScheduler; + this.cmdExecThreadPool = cmdExecThreadPool; } @Override - public void schedule(Collection tasks) { + 
public void schedule(Collection tasks, TaskReportReceiver taskReportReceiver) { tasks.forEach(this::deployTaskToK8s); } @Override public void stop(Collection tasks) { tasks.forEach(task -> { - try { - // K8s do not support job suspend before 1.24, so we collect logs and delete job directly - // https://kubernetes.io/docs/concepts/workloads/controllers/job/#suspending-a-job - taskLogK8sCollector.collect(task); - } catch (StarwhaleException e) { - log.warn("collect task {} log failed, {}", task.getId(), e.getMessage()); - } - try { k8sClient.deleteJob(task.getId().toString()); } catch (ApiException e) { @@ -137,20 +104,13 @@ public Future exec(Task task, String... command) { if (pods.getItems().size() != 1) { throw new SwProcessException(ErrorType.K8S, "multiple pods found for task " + task.getId()); } - return taskScheduler.submit( + return cmdExecThreadPool.submit( () -> k8sClient.execInPod(pods.getItems().get(0).getMetadata().getName(), null, command)); } catch (ApiException e) { throw new SwProcessException(ErrorType.K8S, "exec command failed: " + e.getResponseBody(), e); } } - /** - * {instance}/project/{projectName}/dataset/{datasetName}/version/{version} - */ - static final String FORMATTER_URI_ARTIFACT = "%s/project/%s/%s/%s/version/%s"; - - static final String FORMATTER_VERSION_ARTIFACT = "%s/version/%s"; - static final String ANNOTATION_KEY_JOB_ID = "starwhale.ai/job-id"; static final String ANNOTATION_KEY_TASK_ID = "starwhale.ai/task-id"; static final String ANNOTATION_KEY_USER_ID = "starwhale.ai/user-id"; @@ -254,85 +214,7 @@ private List getPatchedResources(Task task) { @NotNull private List buildCoreContainerEnvs(Task task) { - Job swJob = task.getStep().getJob(); - var model = swJob.getModel(); - var runtime = swJob.getJobRuntime(); - Map coreContainerEnvs = new HashMap<>(); - var taskEnv = task.getTaskRequest().getEnv(); - if (!CollectionUtils.isEmpty(taskEnv)) { - taskEnv.forEach(env -> coreContainerEnvs.put(env.getName(), env.getValue())); - } - 
coreContainerEnvs.put("SW_RUNTIME_PYTHON_VERSION", runtime.getManifest().getEnvironment().getPython()); - coreContainerEnvs.put("SW_VERSION", runtime.getManifest().getEnvironment().getLock().getSwVersion()); - coreContainerEnvs.put("SW_TASK_STEP", task.getStep().getName()); - coreContainerEnvs.put("DATASET_CONSUMPTION_BATCH_SIZE", String.valueOf(datasetLoadBatchSize)); - // support multi dataset uris - coreContainerEnvs.put("SW_DATASET_URI", - swJob.getDataSets().stream() - .map(dataSet -> String.format( - FORMATTER_URI_ARTIFACT, - instanceUri, - dataSet.getProjectId(), - "dataset", - dataSet.getName(), - dataSet.getVersion()) - ).collect(Collectors.joining(" "))); - coreContainerEnvs.put("SW_MODEL_URI", - String.format( - FORMATTER_URI_ARTIFACT, - instanceUri, - model.getProjectId(), - "model", - model.getName(), - model.getVersion())); - coreContainerEnvs.put("SW_RUNTIME_URI", - String.format( - FORMATTER_URI_ARTIFACT, - instanceUri, - runtime.getProjectId(), - "runtime", - runtime.getName(), - runtime.getVersion())); - coreContainerEnvs.put("SW_MODEL_VERSION", - String.format(FORMATTER_VERSION_ARTIFACT, - model.getName(), model.getVersion())); - coreContainerEnvs.put("SW_RUNTIME_VERSION", - String.format(FORMATTER_VERSION_ARTIFACT, - runtime.getName(), runtime.getVersion())); - coreContainerEnvs.put("SW_RUN_HANDLER", task.getTaskRequest().getJobName()); - coreContainerEnvs.put("SW_TASK_INDEX", String.valueOf(task.getTaskRequest().getIndex())); - coreContainerEnvs.put("SW_TASK_NUM", String.valueOf(task.getTaskRequest().getTotal())); - coreContainerEnvs.put("SW_JOB_VERSION", swJob.getUuid()); - - // datastore env - coreContainerEnvs.put("SW_TOKEN", taskTokenValidator.getTaskToken(swJob.getOwner(), task.getId())); - coreContainerEnvs.put("SW_INSTANCE_URI", instanceUri); - coreContainerEnvs.put("SW_PROJECT", swJob.getProject().getName()); - coreContainerEnvs.put("SW_PYPI_INDEX_URL", runTimeProperties.getPypi().getIndexUrl()); - 
coreContainerEnvs.put("SW_PYPI_EXTRA_INDEX_URL", runTimeProperties.getPypi().getExtraIndexUrl()); - coreContainerEnvs.put("SW_PYPI_TRUSTED_HOST", runTimeProperties.getPypi().getTrustedHost()); - coreContainerEnvs.put("SW_PYPI_TIMEOUT", String.valueOf(runTimeProperties.getPypi().getTimeout())); - coreContainerEnvs.put("SW_PYPI_RETRIES", String.valueOf(runTimeProperties.getPypi().getRetries())); - if (StringUtils.hasText(runTimeProperties.getCondarc())) { - coreContainerEnvs.put("SW_CONDA_CONFIG", runTimeProperties.getCondarc()); - } - - // GPU resource - var resources = getPatchedResources(task).stream(); - var gpu = resources.anyMatch(r -> r.getType().equals(ResourceOverwriteSpec.RESOURCE_GPU) && r.getRequest() > 0); - // overwrite visible devices to none - if (!gpu) { - // https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#gpu-enumeration - coreContainerEnvs.put("NVIDIA_VISIBLE_DEVICES", ""); - } - - if (swJob.isDevMode()) { - coreContainerEnvs.put("SW_DEV_TOKEN", swJob.getDevPassword()); - coreContainerEnvs.put("SW_DEV_PORT", String.valueOf(devPort)); - } - - var envs = mapToEnv(coreContainerEnvs); - + var envs = mapToEnv(taskRunningEnvBuilder.buildCoreContainerEnvs(task)); envs.add( new V1EnvVar() .name("SW_POD_NAME") diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/ResourceEventHolder.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/ResourceEventHolder.java similarity index 99% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/ResourceEventHolder.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/ResourceEventHolder.java index cfce9febbe..904161400a 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/ResourceEventHolder.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/ResourceEventHolder.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import io.kubernetes.client.informer.ResourceEventHandler; import io.kubernetes.client.openapi.models.CoreV1Event; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/ResourceOverwriteSpec.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/ResourceOverwriteSpec.java similarity index 96% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/ResourceOverwriteSpec.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/ResourceOverwriteSpec.java index b742bc1dd4..63c1b5905e 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/ResourceOverwriteSpec.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/ResourceOverwriteSpec.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import ai.starwhale.mlops.domain.runtime.RuntimeResource; import io.kubernetes.client.custom.Quantity; @@ -42,7 +42,7 @@ public class ResourceOverwriteSpec { public static final String RESOURCE_MEMORY = "memory"; - static final String RESOURCE_GPU = "nvidia.com/gpu"; + public static final String RESOURCE_GPU = "nvidia.com/gpu"; public static Set K8S_BUILTIN_RESOURCES = Set.of(RESOURCE_CPU, RESOURCE_MEMORY); diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/SwSchedulerFactoryK8S.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/SwSchedulerFactoryK8S.java new file mode 100644 index 0000000000..068078bc46 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/SwSchedulerFactoryK8S.java @@ -0,0 +1,84 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.starwhale.mlops.schedule.impl.k8s; + +import ai.starwhale.mlops.schedule.SwSchedulerAbstractFactory; +import ai.starwhale.mlops.schedule.SwTaskScheduler; +import ai.starwhale.mlops.schedule.TaskRunningEnvBuilder; +import ai.starwhale.mlops.schedule.impl.k8s.log.TaskLogK8sCollectorFactory; +import ai.starwhale.mlops.schedule.log.TaskLogCollectorFactory; +import ai.starwhale.mlops.storage.StorageAccessService; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; + +@Configuration +@ConditionalOnProperty(value = "sw.scheduler", havingValue = "k8s") +public class SwSchedulerFactoryK8S implements SwSchedulerAbstractFactory { + + final K8sClient k8sClient; + + final K8sJobTemplate k8sJobTemplate; + + final TaskRunningEnvBuilder taskRunningEnvBuilder; + + final String restartPolicy; + final int backoffLimit; + final StorageAccessService storageAccessService; + final ThreadPoolTaskScheduler cmdExecThreadPool; + + public SwSchedulerFactoryK8S( + K8sClient k8sClient, + K8sJobTemplate k8sJobTemplate, + TaskRunningEnvBuilder taskRunningEnvBuilder, + @Value("${sw.infra.k8s.job.restart-policy}") String restartPolicy, + 
@Value("${sw.infra.k8s.job.backoff-limit}") Integer backoffLimit, + StorageAccessService storageAccessService, + ThreadPoolTaskScheduler cmdExecThreadPool + ) { + this.k8sClient = k8sClient; + this.k8sJobTemplate = k8sJobTemplate; + this.taskRunningEnvBuilder = taskRunningEnvBuilder; + this.restartPolicy = restartPolicy; + this.backoffLimit = backoffLimit; + this.storageAccessService = storageAccessService; + this.cmdExecThreadPool = cmdExecThreadPool; + } + + @Bean + @Override + public SwTaskScheduler buildSwTaskScheduler() { + return new K8sSwTaskScheduler( + k8sClient, + k8sJobTemplate, + taskRunningEnvBuilder, + restartPolicy, + backoffLimit, + storageAccessService, + cmdExecThreadPool + ); + } + + @Bean + @Override + public TaskLogCollectorFactory buildTaskLogCollectorFactory() { + return new TaskLogK8sCollectorFactory(k8sClient, k8sJobTemplate); + } + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/Util.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/Util.java similarity index 94% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/Util.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/Util.java index c1379d6444..bb135bfb61 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/Util.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/Util.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import java.time.OffsetDateTime; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogK8sCollectorFactory.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/log/CancellableJobLogK8sCollectorFactory.java similarity index 71% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogK8sCollectorFactory.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/log/CancellableJobLogK8sCollectorFactory.java index eaa4f2a07f..19a3877308 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogK8sCollectorFactory.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/log/CancellableJobLogK8sCollectorFactory.java @@ -14,22 +14,25 @@ * limitations under the License. */ -package ai.starwhale.mlops.schedule.k8s.log; +package ai.starwhale.mlops.schedule.impl.k8s.log; -import ai.starwhale.mlops.schedule.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.log.TaskLogStreamingCollector; import io.kubernetes.client.openapi.ApiException; import java.io.IOException; import org.springframework.stereotype.Service; @Service +//TODO remove this class when dataset build is a job public class CancellableJobLogK8sCollectorFactory { + private final K8sClient k8sClient; public CancellableJobLogK8sCollectorFactory(K8sClient k8sClient) { this.k8sClient = k8sClient; } - public CancellableJobLogCollector make(String jobName) throws IOException, ApiException { - return new CancellableJobLogK8sCollector(this.k8sClient, jobName); + public TaskLogStreamingCollector make(String jobName) throws IOException, ApiException { + return new TaskLogK8sStreamingCollector(this.k8sClient, jobName); } } diff --git 
a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/log/TaskLogK8sCollectorFactory.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/log/TaskLogK8sCollectorFactory.java new file mode 100644 index 0000000000..491dfccb47 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/log/TaskLogK8sCollectorFactory.java @@ -0,0 +1,66 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.starwhale.mlops.schedule.impl.k8s.log; + +import ai.starwhale.mlops.domain.task.bo.Task; +import ai.starwhale.mlops.exception.StarwhaleException; +import ai.starwhale.mlops.exception.SwProcessException; +import ai.starwhale.mlops.exception.SwProcessException.ErrorType; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.log.TaskLogCollectorFactory; +import ai.starwhale.mlops.schedule.log.TaskLogOfflineCollector; +import ai.starwhale.mlops.schedule.log.TaskLogStreamingCollector; +import io.kubernetes.client.openapi.ApiException; +import java.io.IOException; +import java.text.MessageFormat; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class TaskLogK8sCollectorFactory implements TaskLogCollectorFactory { + + final K8sClient k8sClient; + + final K8sJobTemplate k8sJobTemplate; + + public TaskLogK8sCollectorFactory(K8sClient k8sClient, K8sJobTemplate k8sJobTemplate) { + this.k8sClient = k8sClient; + this.k8sJobTemplate = k8sJobTemplate; + } + + + @Override + public TaskLogOfflineCollector offlineCollector(Task task) throws StarwhaleException { + return new TaskLogOfflineCollectorK8s(k8sClient, k8sJobTemplate.getJobContainerNames( + k8sJobTemplate.loadJob(K8sJobTemplate.WORKLOAD_TYPE_EVAL)), task); + } + + @Override + public TaskLogStreamingCollector streamingCollector(Task task) throws StarwhaleException { + try { + return new TaskLogK8sStreamingCollector(this.k8sClient, String.valueOf(task.getId())); + } catch (IOException e) { + throw new SwProcessException(ErrorType.NETWORK, + MessageFormat.format("read k8s api exception {0}", e.getMessage()), + e); + } catch (ApiException e) { + throw new SwProcessException(ErrorType.INFRA, + MessageFormat.format("k8s api exception {0}", e.getResponseBody()), + e); + } + } +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogK8sCollector.java 
b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/log/TaskLogK8sStreamingCollector.java similarity index 85% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogK8sCollector.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/log/TaskLogK8sStreamingCollector.java index c3e70484e5..405a0abdd8 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogK8sCollector.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/log/TaskLogK8sStreamingCollector.java @@ -14,9 +14,10 @@ * limitations under the License. */ -package ai.starwhale.mlops.schedule.k8s.log; +package ai.starwhale.mlops.schedule.impl.k8s.log; -import ai.starwhale.mlops.schedule.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.log.TaskLogStreamingCollector; import io.kubernetes.client.openapi.ApiException; import java.io.BufferedReader; import java.io.IOException; @@ -25,14 +26,15 @@ import okhttp3.Call; import okhttp3.Response; -public class CancellableJobLogK8sCollector implements CancellableJobLogCollector { +public class TaskLogK8sStreamingCollector implements TaskLogStreamingCollector { + public static final String WORKER_CONTAINER = "worker"; final K8sClient k8sClient; final Call call; final Response resp; final BufferedReader bufferedReader; - public CancellableJobLogK8sCollector(K8sClient k8sClient, String jobName) + public TaskLogK8sStreamingCollector(K8sClient k8sClient, String jobName) throws IOException, ApiException { this.k8sClient = k8sClient; call = k8sClient.readLog(getPodName(jobName), WORKER_CONTAINER, true); @@ -62,7 +64,7 @@ private String getPodName(String taskId) throws ApiException { } @Override - public String readLine() throws IOException { + public String readLine(Long timeWaitSeconds) throws IOException { return bufferedReader.readLine(); } diff --git 
a/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/log/TaskLogOfflineCollectorK8s.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/log/TaskLogOfflineCollectorK8s.java new file mode 100644 index 0000000000..decad092e2 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/log/TaskLogOfflineCollectorK8s.java @@ -0,0 +1,70 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.starwhale.mlops.schedule.impl.k8s.log; + +import ai.starwhale.mlops.domain.task.bo.Task; +import ai.starwhale.mlops.exception.StarwhaleException; +import ai.starwhale.mlops.exception.SwProcessException; +import ai.starwhale.mlops.exception.SwProcessException.ErrorType; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.log.TaskLogOfflineCollector; +import io.kubernetes.client.openapi.ApiException; +import io.kubernetes.client.openapi.models.V1Pod; +import io.vavr.Tuple2; +import java.text.MessageFormat; +import java.util.List; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; + + +@Slf4j +public class TaskLogOfflineCollectorK8s implements TaskLogOfflineCollector { + + final K8sClient k8sClient; + + final List containers; + + final Task task; + + public TaskLogOfflineCollectorK8s(K8sClient k8sClient, List containers, Task task) { + this.k8sClient = k8sClient; + this.containers = containers; + this.task = task; + } + + @Override + public Tuple2 collect() throws StarwhaleException { + log.debug("logging for task {} begins...", task.getId()); + try { + V1Pod v1Pod = k8sClient.podOfJob(K8sClient.toV1LabelSelector(Map.of( + K8sJobTemplate.JOB_IDENTITY_LABEL, task.getId().toString()))); + if (null == v1Pod) { + log.error("pod not exists for task {}", task.getId()); + return null; /* no pod found: callers must handle a null result */ + } + String logName = v1Pod.getMetadata().getName(); + String taskLog = k8sClient.logOfPod(v1Pod, containers); + return new Tuple2<>(logName, taskLog); + } catch (ApiException e) { + throw new SwProcessException(ErrorType.INFRA, /* MessageFormat placeholders are indexed: {0} */ + MessageFormat.format("k8s api exception {0}", e.getResponseBody()), + e); + } + } + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/NodeEventHandler.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/reporting/NodeEventHandler.java similarity index 93% rename from 
server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/NodeEventHandler.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/reporting/NodeEventHandler.java index 9c7ae0e195..5e3fb072a9 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/NodeEventHandler.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/reporting/NodeEventHandler.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s.reporting; import ai.starwhale.mlops.domain.system.agent.AgentCache; import ai.starwhale.mlops.domain.system.agent.AgentStatus; @@ -24,9 +24,11 @@ import io.kubernetes.client.openapi.models.V1NodeSpec; import io.kubernetes.client.openapi.models.V1NodeStatus; import java.math.BigDecimal; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.stereotype.Service; @Service +@ConditionalOnProperty(value = "sw.scheduler", havingValue = "k8s") public class NodeEventHandler implements ResourceEventHandler { final AgentCache agentCache; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/PodEventHandler.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/reporting/PodEventHandler.java similarity index 86% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/PodEventHandler.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/reporting/PodEventHandler.java index bb567da6e1..1aa6478775 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/PodEventHandler.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/impl/k8s/reporting/PodEventHandler.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s.reporting; import ai.starwhale.mlops.domain.dataset.DatasetService; import ai.starwhale.mlops.domain.dataset.build.BuildStatus; @@ -22,37 +22,41 @@ import ai.starwhale.mlops.domain.job.cache.HotJobHolder; import ai.starwhale.mlops.domain.task.bo.Task; import ai.starwhale.mlops.domain.task.status.TaskStatus; -import ai.starwhale.mlops.domain.task.status.watchers.log.TaskLogK8sCollector; -import ai.starwhale.mlops.reporting.ReportedTask; -import ai.starwhale.mlops.reporting.TaskModifyReceiver; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.impl.k8s.Util; +import ai.starwhale.mlops.schedule.log.TaskLogSaver; +import ai.starwhale.mlops.schedule.reporting.ReportedTask; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; import io.kubernetes.client.informer.ResourceEventHandler; import io.kubernetes.client.openapi.models.V1Pod; import java.util.Collection; import java.util.List; import lombok.extern.slf4j.Slf4j; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @Slf4j @Component +@ConditionalOnProperty(value = "sw.scheduler", havingValue = "k8s") public class PodEventHandler implements ResourceEventHandler { - final TaskLogK8sCollector taskLogK8sCollector; + final TaskLogSaver taskLogSaver; final BuildLogCollector buildLogCollector; - final TaskModifyReceiver taskModifyReceiver; + final TaskReportReceiver taskReportReceiver; final HotJobHolder jobHolder; final DatasetService datasetService; public PodEventHandler( - TaskLogK8sCollector taskLogK8sCollector, + TaskLogSaver taskLogSaver, BuildLogCollector buildLogCollector, - TaskModifyReceiver taskModifyReceiver, + TaskReportReceiver taskReportReceiver, HotJobHolder jobHolder, DatasetService 
datasetService) { - this.taskLogK8sCollector = taskLogK8sCollector; + this.taskLogSaver = taskLogSaver; this.buildLogCollector = buildLogCollector; - this.taskModifyReceiver = taskModifyReceiver; + this.taskReportReceiver = taskReportReceiver; this.jobHolder = jobHolder; this.datasetService = datasetService; } @@ -181,9 +185,18 @@ private void updateEvalTask(V1Pod pod) { .startTimeMillis(startTime) .stopTimeMillis(null) .build(); - taskModifyReceiver.receive(List.of(report)); + taskReportReceiver.receive(List.of(report)); } + /** + * In k8s implementation of taskScheduler there is task retry support + * So, every time a pod finishes we collect the log for the pod. + * This is a compensation for the log collecting in task watcher which only collect log once just before task + * finishes + * + * @param pod pod + * @param type type + */ private void collectLog(V1Pod pod, String type) { log.debug("collect log for pod {} status {}", pod.getMetadata().getName(), pod.getStatus()); if (null == pod.getStatus() @@ -208,7 +221,7 @@ private void collectLog(V1Pod pod, String type) { return; } Task task = optionalTasks.stream().findAny().get(); - taskLogK8sCollector.collect(task); + taskLogSaver.saveLog(task); } break; case K8sJobTemplate.WORKLOAD_TYPE_DATASET_BUILD: diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/log/TaskLogCollectorFactory.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/log/TaskLogCollectorFactory.java new file mode 100644 index 0000000000..8b79a6aef4 --- /dev/null +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/log/TaskLogCollectorFactory.java @@ -0,0 +1,39 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.starwhale.mlops.schedule.log; + +import ai.starwhale.mlops.domain.task.bo.Task; + +public interface TaskLogCollectorFactory { + + /** + * collect the whole log of a task with the name of the execution + * + * @param task the target task + * @return a TaskLogOfflineCollector + */ + TaskLogOfflineCollector offlineCollector(Task task); + + /** + * return a streaming task log reader which could be closed at anytime + * + * @param task the target task + * @return a TaskLogStreamingCollector + */ + TaskLogStreamingCollector streamingCollector(Task task); + +} diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/watchers/log/TaskLogCollector.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/log/TaskLogOfflineCollector.java similarity index 71% rename from server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/watchers/log/TaskLogCollector.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/log/TaskLogOfflineCollector.java index 1b4177dfa8..97f066aa5e 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/watchers/log/TaskLogCollector.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/log/TaskLogOfflineCollector.java @@ -14,13 +14,12 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.domain.task.status.watchers.log; +package ai.starwhale.mlops.schedule.log; -import ai.starwhale.mlops.domain.task.bo.Task; -import ai.starwhale.mlops.exception.StarwhaleException; +import io.vavr.Tuple2; -public interface TaskLogCollector { +public interface TaskLogOfflineCollector { - void collect(Task task) throws StarwhaleException; + Tuple2 collect(); } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/watchers/log/TaskLogK8sCollector.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/log/TaskLogSaver.java similarity index 53% rename from server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/watchers/log/TaskLogK8sCollector.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/log/TaskLogSaver.java index fec79edb96..2b6145323e 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/task/status/watchers/log/TaskLogK8sCollector.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/log/TaskLogSaver.java @@ -14,72 +14,53 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.domain.task.status.watchers.log; +package ai.starwhale.mlops.schedule.log; import ai.starwhale.mlops.domain.task.bo.Task; import ai.starwhale.mlops.exception.StarwhaleException; import ai.starwhale.mlops.exception.SwProcessException; import ai.starwhale.mlops.exception.SwProcessException.ErrorType; -import ai.starwhale.mlops.schedule.k8s.K8sClient; -import ai.starwhale.mlops.schedule.k8s.K8sJobTemplate; import ai.starwhale.mlops.storage.StorageAccessService; -import io.kubernetes.client.openapi.ApiException; -import io.kubernetes.client.openapi.models.V1Pod; +import io.vavr.Tuple2; import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.text.MessageFormat; -import java.util.List; -import java.util.Map; import lombok.extern.slf4j.Slf4j; -import org.springframework.stereotype.Component; +import org.springframework.stereotype.Service; import org.springframework.util.StringUtils; @Slf4j -@Component -public class TaskLogK8sCollector implements TaskLogCollector { +@Service +public class TaskLogSaver { + final TaskLogCollectorFactory taskLogCollectorFactory; final StorageAccessService storageService; - final K8sClient k8sClient; - - final List containers; - - public TaskLogK8sCollector(StorageAccessService storageService, - K8sClient k8sClient, K8sJobTemplate k8sJobTemplate) { + public TaskLogSaver(TaskLogCollectorFactory taskLogCollectorFactory, StorageAccessService storageService) { + this.taskLogCollectorFactory = taskLogCollectorFactory; this.storageService = storageService; - this.k8sClient = k8sClient; - this.containers = k8sJobTemplate.getJobContainerNames( - k8sJobTemplate.loadJob(K8sJobTemplate.WORKLOAD_TYPE_EVAL)); } - @Override - public void collect(Task task) throws StarwhaleException { + public void saveLog(Task task) throws StarwhaleException { log.debug("logging for task {} begins...", task.getId()); try { - V1Pod v1Pod = k8sClient.podOfJob(K8sClient.toV1LabelSelector(Map.of( - 
K8sJobTemplate.JOB_IDENTITY_LABEL, task.getId().toString()))); - if (null == v1Pod) { - log.warn("pod not exists for task {}", task.getId()); + Tuple2 logInfo = taskLogCollectorFactory.offlineCollector(task).collect(); + if (null == logInfo) { return; } - String logName = v1Pod.getMetadata().getName(); - String taskLog = k8sClient.logOfPod(v1Pod, containers); + String taskLog = logInfo._2(); log.debug("logs for task {} collected {} ...", task.getId(), StringUtils.hasText(taskLog) ? taskLog.substring(0, Math.min(taskLog.length() - 1, 100)) : ""); - String logPath = resolveLogPath(task, logName); + String logPath = resolveLogPath(task, logInfo._1()); log.debug("putting log to storage at path {}", logPath); storageService.put(logPath, taskLog.getBytes(StandardCharsets.UTF_8)); - } catch (ApiException e) { - throw new SwProcessException(ErrorType.INFRA, - MessageFormat.format("k8s api exception {}", e.getResponseBody()), - e); } catch (IOException e) { - throw new SwProcessException(ErrorType.STORAGE, "uploading log failed", e); + throw new SwProcessException(ErrorType.STORAGE, "uploading log to storage failed", e); } } private String resolveLogPath(Task task, String logName) { return task.getResultRootPath().logDir() + "/" + logName; } + } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogCollector.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/log/TaskLogStreamingCollector.java similarity index 81% rename from server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogCollector.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/log/TaskLogStreamingCollector.java index 9590ab023d..63dedd2a8a 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogCollector.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/log/TaskLogStreamingCollector.java @@ -14,13 +14,13 @@ * limitations under the 
License. */ -package ai.starwhale.mlops.schedule.k8s.log; +package ai.starwhale.mlops.schedule.log; import java.io.IOException; -public interface CancellableJobLogCollector { +public interface TaskLogStreamingCollector { - String readLine() throws IOException; + String readLine(Long waitTimeSeconds) throws IOException; void cancel(); } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/reporting/ReportedTask.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/reporting/ReportedTask.java similarity index 95% rename from server/controller/src/main/java/ai/starwhale/mlops/reporting/ReportedTask.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/reporting/ReportedTask.java index b39e557dec..eb6f731a65 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/reporting/ReportedTask.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/reporting/ReportedTask.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package ai.starwhale.mlops.reporting; +package ai.starwhale.mlops.schedule.reporting; import ai.starwhale.mlops.domain.task.status.TaskStatus; import lombok.AllArgsConstructor; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/reporting/SimpleTaskModifyReceiver.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/reporting/SimpleTaskReportReceiver.java similarity index 96% rename from server/controller/src/main/java/ai/starwhale/mlops/reporting/SimpleTaskModifyReceiver.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/reporting/SimpleTaskReportReceiver.java index 37051a385a..441bc210c7 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/reporting/SimpleTaskModifyReceiver.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/reporting/SimpleTaskReportReceiver.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.reporting; +package ai.starwhale.mlops.schedule.reporting; import ai.starwhale.mlops.domain.job.cache.HotJobHolder; import ai.starwhale.mlops.domain.task.bo.Task; @@ -32,13 +32,13 @@ */ @Slf4j @Service -public class SimpleTaskModifyReceiver implements TaskModifyReceiver { +public class SimpleTaskReportReceiver implements TaskReportReceiver { final HotJobHolder jobHolder; final TaskMapper taskMapper; - public SimpleTaskModifyReceiver(HotJobHolder jobHolder, TaskMapper taskMapper) { + public SimpleTaskReportReceiver(HotJobHolder jobHolder, TaskMapper taskMapper) { this.jobHolder = jobHolder; this.taskMapper = taskMapper; } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/reporting/TaskModifyReceiver.java b/server/controller/src/main/java/ai/starwhale/mlops/schedule/reporting/TaskReportReceiver.java similarity index 90% rename from server/controller/src/main/java/ai/starwhale/mlops/reporting/TaskModifyReceiver.java rename to server/controller/src/main/java/ai/starwhale/mlops/schedule/reporting/TaskReportReceiver.java index 40e0c161e2..145adc63a2 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/reporting/TaskModifyReceiver.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/schedule/reporting/TaskReportReceiver.java @@ -14,14 +14,14 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.reporting; +package ai.starwhale.mlops.schedule.reporting; import java.util.List; /** * receive task status from agents */ -public interface TaskModifyReceiver { +public interface TaskReportReceiver { /** * process the report from Agent diff --git a/server/controller/src/main/resources/application.yaml b/server/controller/src/main/resources/application.yaml index 4faeec2a26..36c87bd255 100644 --- a/server/controller/src/main/resources/application.yaml +++ b/server/controller/src/main/resources/application.yaml @@ -53,6 +53,9 @@ sw: min-time-to-live-in-seconds: ${SW_ONLINE_EVAL_MIN_TTL_SECS:1800} # 30min instance-uri: ${SW_INSTANCE_URI:http://console.pre.intra.starwhale.ai} infra: + docker: + network: ${SW_DOCKER_CONTAINER_NETWORK:host} + node-ip: ${SW_DOCKER_CONTAINER_NODE_IP:127.0.0.1} k8s: name-space: ${SW_K8S_NAME_SPACE:default} host-path-for-cache: ${SW_K8S_HOST_PATH_FOR_CACHE:} @@ -64,6 +67,7 @@ sw: template-path: ${SW_K8S_JOB_TEMPLATE_PATH:} restart-policy: ${SW_K8S_JOB_RESTART_POLICY:Never} backoff-limit: ${SW_K8S_JOB_BACKOFF_LIMIT:10} + scheduler: ${SW_SCHEDULER:k8s} storage: type: ${SW_STORAGE_TYPE:minio} path-prefix: ${SW_STORAGE_PREFIX:starwhale} diff --git a/server/controller/src/test/java/ai/starwhale/mlops/api/JobControllerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/api/JobControllerTest.java index 3530df1abf..3d07299bd5 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/api/JobControllerTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/api/JobControllerTest.java @@ -55,7 +55,7 @@ import ai.starwhale.mlops.domain.job.RuntimeSuggestionService; import ai.starwhale.mlops.domain.task.TaskService; import ai.starwhale.mlops.exception.api.StarwhaleApiException; -import ai.starwhale.mlops.schedule.k8s.ResourceEventHolder; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceEventHolder; import com.github.pagehelper.Page; import java.util.List; import java.util.Objects; diff --git 
a/server/controller/src/test/java/ai/starwhale/mlops/api/TaskLogWsServerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/api/TaskLogWsServerTest.java index d1f7aa2a60..a6c60b94ac 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/api/TaskLogWsServerTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/api/TaskLogWsServerTest.java @@ -22,8 +22,10 @@ import static org.mockito.Mockito.when; import ai.starwhale.mlops.common.IdConverter; -import ai.starwhale.mlops.schedule.k8s.log.CancellableJobLogK8sCollector; -import ai.starwhale.mlops.schedule.k8s.log.CancellableJobLogK8sCollectorFactory; +import ai.starwhale.mlops.domain.job.step.bo.Step; +import ai.starwhale.mlops.domain.task.bo.Task; +import ai.starwhale.mlops.schedule.impl.k8s.log.TaskLogK8sStreamingCollector; +import ai.starwhale.mlops.schedule.log.TaskLogCollectorFactory; import io.kubernetes.client.openapi.ApiException; import java.io.IOException; import java.util.concurrent.TimeUnit; @@ -32,33 +34,33 @@ import org.junit.jupiter.api.Test; public class TaskLogWsServerTest { - private CancellableJobLogK8sCollectorFactory factory; - private CancellableJobLogK8sCollector logK8sCollector; + + private TaskLogK8sStreamingCollector logK8sCollector; private IdConverter idConvertor; private Session session; @BeforeEach public void setup() { - factory = mock(CancellableJobLogK8sCollectorFactory.class); idConvertor = mock(IdConverter.class); session = mock(Session.class); - logK8sCollector = mock(CancellableJobLogK8sCollector.class); + logK8sCollector = mock(TaskLogK8sStreamingCollector.class); } @Test public void testOpen() throws IOException, ApiException, InterruptedException { var server = new TaskLogWsServer(); server.setIdConvertor(idConvertor); - server.setLogCollectorFactory(factory); + TaskLogCollectorFactory logCollectorFactory = mock(TaskLogCollectorFactory.class); + when(logCollectorFactory.streamingCollector(any())).thenReturn(logK8sCollector); + 
server.setTaskLogCollectorFactory(logCollectorFactory); final Long taskId = 1L; - when(factory.make(taskId.toString())).thenReturn(logK8sCollector); when(session.getId()).thenReturn("1"); when(idConvertor.revert(any())).thenReturn(taskId); - when(logK8sCollector.readLine()).thenReturn("foo"); + when(logK8sCollector.readLine(any())).thenReturn("foo"); server.onOpen(session, "1"); - verify(factory).make(taskId.toString()); + verify(logCollectorFactory).streamingCollector(Task.builder().id(taskId).step(new Step()).build()); TimeUnit.MILLISECONDS.sleep(500); - verify(logK8sCollector).readLine(); + verify(logK8sCollector).readLine(any()); } } diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/PageTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/PageTest.java index ea230cbefd..76d39ec220 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/PageTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/PageTest.java @@ -59,8 +59,14 @@ import ai.starwhale.mlops.domain.task.status.TaskStatus; import ai.starwhale.mlops.domain.user.mapper.UserMapper; import ai.starwhale.mlops.domain.user.po.UserEntity; -import ai.starwhale.mlops.schedule.k8s.K8sJobTemplate; -import ai.starwhale.mlops.schedule.k8s.ResourceEventHolder; +import ai.starwhale.mlops.schedule.TaskCommandGetter; +import ai.starwhale.mlops.schedule.TaskRunningEnvBuilder; +import ai.starwhale.mlops.schedule.impl.docker.ContainerTaskMapper; +import ai.starwhale.mlops.schedule.impl.docker.DockerClientFinderSimpleImpl; +import ai.starwhale.mlops.schedule.impl.docker.log.TaskLogCollectorFactoryDocker; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceEventHolder; import com.github.pagehelper.autoconfigure.PageHelperAutoConfiguration; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import java.util.UUID; @@ -78,20 
+84,28 @@ @MybatisTest @ComponentScan( basePackages = { - "ai.starwhale.mlops.common", - "ai.starwhale.mlops.domain", - "ai.starwhale.mlops.datastore", - "ai.starwhale.mlops.reporting", - "ai.starwhale.mlops.resulting", - "ai.starwhale.mlops.configuration.security"}, + "ai.starwhale.mlops.common", + "ai.starwhale.mlops.domain", + "ai.starwhale.mlops.datastore", + "ai.starwhale.mlops.schedule.reporting", + "ai.starwhale.mlops.schedule.log", + "ai.starwhale.mlops.schedule.impl.docker", + "ai.starwhale.mlops.resulting", + "ai.starwhale.mlops.configuration.security"}, excludeFilters = { - @ComponentScan.Filter(type = FilterType.ASSIGNABLE_TYPE, value = ModelServingService.class) + @ComponentScan.Filter(type = FilterType.ASSIGNABLE_TYPE, value = {ModelServingService.class, + K8sClient.class}) } ) @ImportAutoConfiguration(PageHelperAutoConfiguration.class) -@Import({K8sJobTemplate.class, ResourceEventHolder.class, SimpleMeterRegistry.class}) +@Import({K8sJobTemplate.class, ResourceEventHolder.class, SimpleMeterRegistry.class, TaskRunningEnvBuilder.class, + TaskLogCollectorFactoryDocker.class, + DockerClientFinderSimpleImpl.class, + ContainerTaskMapper.class, + TaskCommandGetter.class}) @AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE) public class PageTest extends MySqlContainerHolder { + @Autowired private JobService jobService; @Autowired @@ -194,10 +208,10 @@ public void testModelList() { for (int i = 0; i < 19; i++) { var res = modelMapper.insert( ModelEntity.builder() - .modelName(String.format("model-%d", i)) - .projectId(projectId) - .ownerId(userId) - .build()); + .modelName(String.format("model-%d", i)) + .projectId(projectId) + .ownerId(userId) + .build()); assertTrue(res > 0); } @@ -225,10 +239,10 @@ public void testRuntimeList() { for (int i = 0; i < 19; i++) { var res = runtimeMapper.insert( RuntimeEntity.builder() - .runtimeName(String.format("rt-%d", i)) - .projectId(projectId) - .ownerId(userId) - .build()); + 
.runtimeName(String.format("rt-%d", i)) + .projectId(projectId) + .ownerId(userId) + .build()); assertTrue(res > 0); } @@ -256,10 +270,10 @@ public void testDatasetList() { for (int i = 0; i < 19; i++) { var res = datasetMapper.insert( DatasetEntity.builder() - .datasetName(String.format("ds-%d", i)) - .projectId(projectId) - .ownerId(userId) - .build()); + .datasetName(String.format("ds-%d", i)) + .projectId(projectId) + .ownerId(userId) + .build()); assertTrue(res > 0); } @@ -297,28 +311,28 @@ public void testJobList() { .versionName("model-version-1") .ownerId(userId) .jobs("mnist.evaluator:MNISTInference.cmp:\n" - + "- cls_name: ''\n" - + " concurrency: 1\n" - + " needs: []\n" - + " resources: []\n" - + " name: mnist.evaluator:MNISTInference.ppl\n" - + " replicas: 1\n" - + "- cls_name: ''\n" - + " concurrency: 1\n" - + " needs:\n" - + " - mnist.evaluator:MNISTInference.ppl\n" - + " resources:\n" - + " - type: cpu \n" - + " request: 0.1\n" - + " limit: 0.1\n" - + " - type: nvidia.com/gpu \n" - + " request: 1\n" - + " limit: 1\n" - + " - type: memory \n" - + " request: 1\n" - + " limit: 1\n" - + " name: mnist.evaluator:MNISTInference.cmp\n" - + " replicas: 1\n") + + "- cls_name: ''\n" + + " concurrency: 1\n" + + " needs: []\n" + + " resources: []\n" + + " name: mnist.evaluator:MNISTInference.ppl\n" + + " replicas: 1\n" + + "- cls_name: ''\n" + + " concurrency: 1\n" + + " needs:\n" + + " - mnist.evaluator:MNISTInference.ppl\n" + + " resources:\n" + + " - type: cpu \n" + + " request: 0.1\n" + + " limit: 0.1\n" + + " - type: nvidia.com/gpu \n" + + " request: 1\n" + + " limit: 1\n" + + " - type: memory \n" + + " request: 1\n" + + " limit: 1\n" + + " name: mnist.evaluator:MNISTInference.cmp\n" + + " replicas: 1\n") .build(); assertTrue(modelVersionMapper.insert(modelVersion) > 0); @@ -341,17 +355,17 @@ public void testJobList() { for (int i = 0; i < 19; i++) { var res = jobMapper.addJob( JobEntity.builder() - .name(String.format("job-%d", i)) - 
.jobUuid(String.format("uuid-%d", i)) - .modelVersionId(modelVersion.getId()) - .runtimeVersionId(runtimeVersion.getId()) - .type(JobType.EVALUATION) - .jobStatus(JobStatus.CREATED) - .resultOutputPath("path") - .projectId(projectId) - .ownerId(userId) - .isDeleted(0) - .build()); + .name(String.format("job-%d", i)) + .jobUuid(String.format("uuid-%d", i)) + .modelVersionId(modelVersion.getId()) + .runtimeVersionId(runtimeVersion.getId()) + .type(JobType.EVALUATION) + .jobStatus(JobStatus.CREATED) + .resultOutputPath("path") + .projectId(projectId) + .ownerId(userId) + .isDeleted(0) + .build()); assertTrue(res > 0); } @@ -385,28 +399,28 @@ public void testTaskList() { .versionName("model-version-1") .ownerId(userId) .jobs("mnist.evaluator:MNISTInference.cmp:\n" - + "- cls_name: ''\n" - + " concurrency: 1\n" - + " needs: []\n" - + " resources: []\n" - + " name: mnist.evaluator:MNISTInference.ppl\n" - + " replicas: 1\n" - + "- cls_name: ''\n" - + " concurrency: 1\n" - + " needs:\n" - + " - mnist.evaluator:MNISTInference.ppl\n" - + " resources:\n" - + " - type: cpu \n" - + " request: 0.1\n" - + " limit: 0.1\n" - + " - type: nvidia.com/gpu \n" - + " request: 1\n" - + " limit: 1\n" - + " - type: memory \n" - + " request: 1\n" - + " limit: 1\n" - + " name: mnist.evaluator:MNISTInference.cmp\n" - + " replicas: 1\n") + + "- cls_name: ''\n" + + " concurrency: 1\n" + + " needs: []\n" + + " resources: []\n" + + " name: mnist.evaluator:MNISTInference.ppl\n" + + " replicas: 1\n" + + "- cls_name: ''\n" + + " concurrency: 1\n" + + " needs:\n" + + " - mnist.evaluator:MNISTInference.ppl\n" + + " resources:\n" + + " - type: cpu \n" + + " request: 0.1\n" + + " limit: 0.1\n" + + " - type: nvidia.com/gpu \n" + + " request: 1\n" + + " limit: 1\n" + + " - type: memory \n" + + " request: 1\n" + + " limit: 1\n" + + " name: mnist.evaluator:MNISTInference.cmp\n" + + " replicas: 1\n") .build(); assertTrue(modelVersionMapper.insert(modelVersion) > 0); var job = JobEntity.builder() @@ -437,11 
+451,11 @@ public void testTaskList() { for (int i = 0; i < 19; i++) { var res = taskMapper.addTask( TaskEntity.builder() - .stepId(step.getId()) - .taskUuid(String.format("uuid-%d", i)) - .taskStatus(TaskStatus.CREATED) - .taskRequest("request") - .build()); + .stepId(step.getId()) + .taskUuid(String.format("uuid-%d", i)) + .taskStatus(TaskStatus.CREATED) + .taskRequest("request") + .build()); assertTrue(res > 0); } diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/dataset/DatasetServiceTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/dataset/DatasetServiceTest.java index fe19030e08..271f8aeaea 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/dataset/DatasetServiceTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/dataset/DatasetServiceTest.java @@ -16,8 +16,8 @@ package ai.starwhale.mlops.domain.dataset; -import static ai.starwhale.mlops.schedule.k8s.K8sJobTemplate.JOB_TYPE_LABEL; -import static ai.starwhale.mlops.schedule.k8s.K8sJobTemplate.WORKLOAD_TYPE_DATASET_BUILD; +import static ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate.JOB_TYPE_LABEL; +import static ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate.WORKLOAD_TYPE_DATASET_BUILD; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.hasItem; @@ -91,8 +91,8 @@ import ai.starwhale.mlops.exception.SwNotFoundException; import ai.starwhale.mlops.exception.SwProcessException; import ai.starwhale.mlops.exception.SwValidationException; -import ai.starwhale.mlops.schedule.k8s.K8sClient; -import ai.starwhale.mlops.schedule.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; import ai.starwhale.mlops.storage.StorageAccessService; import io.kubernetes.client.openapi.ApiException; import io.kubernetes.client.openapi.models.V1Job; diff --git 
a/server/controller/src/test/java/ai/starwhale/mlops/domain/job/EnvServiceTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/job/EnvServiceTest.java index f6be718474..c7b047da4c 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/job/EnvServiceTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/job/EnvServiceTest.java @@ -22,7 +22,7 @@ import static org.hamcrest.Matchers.iterableWithSize; import static org.hamcrest.Matchers.notNullValue; -import ai.starwhale.mlops.schedule.k8s.ResourceOverwriteSpec; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceOverwriteSpec; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/job/JobLoaderTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/job/JobLoaderTest.java index 449abc6ec5..2b07dae4ff 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/job/JobLoaderTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/job/JobLoaderTest.java @@ -31,6 +31,7 @@ import ai.starwhale.mlops.domain.task.status.WatchableTask; import ai.starwhale.mlops.domain.task.status.WatchableTaskFactory; import ai.starwhale.mlops.schedule.SwTaskScheduler; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; import java.util.List; import java.util.Set; import org.junit.jupiter.api.BeforeEach; @@ -51,13 +52,16 @@ public class JobLoaderTest { SwTaskScheduler swTaskScheduler; + TaskReportReceiver taskReportReceiver; + @BeforeEach public void setUp() { mockJob = new JobMockHolder().mockJob(); jobHolder = mock(HotJobHolder.class); watchableTaskFactory = mock(WatchableTaskFactory.class); swTaskScheduler = mock(SwTaskScheduler.class); - jobLoader = new JobLoader(jobHolder, watchableTaskFactory, swTaskScheduler); + taskReportReceiver = mock(TaskReportReceiver.class); + jobLoader = new JobLoader(jobHolder, watchableTaskFactory, swTaskScheduler, 
taskReportReceiver); } @@ -71,7 +75,7 @@ public void testJobLoader() { jobLoader.load(mockJob, false); verify(jobHolder, times(1)).adopt(mockJob); verify(watchableTaskFactory, times(mockJob.getSteps().size())).wrapTasks(anyCollection()); - verify(swTaskScheduler).schedule(Set.of(readyTask)); + verify(swTaskScheduler).schedule(Set.of(readyTask), taskReportReceiver); verify(failedTask, times(0)).updateStatus(TaskStatus.READY); } diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/job/JobServiceTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/job/JobServiceTest.java index 0ca4a7fa57..8a44ca3205 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/job/JobServiceTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/job/JobServiceTest.java @@ -108,7 +108,7 @@ public class JobServiceTest { private TrashService trashService; private SystemSettingService systemSettingService; private JobSpecParser jobSpecParser; - private SwTaskScheduler taskScheduler; + private SwTaskScheduler swTaskScheduler; @BeforeEach public void setUp() { @@ -140,13 +140,13 @@ public void setUp() { trashService = mock(TrashService.class); systemSettingService = mock(SystemSettingService.class); jobSpecParser = new JobSpecParser(); - taskScheduler = mock(SwTaskScheduler.class); + swTaskScheduler = mock(SwTaskScheduler.class); service = new JobService( taskMapper, jobConverter, jobBoConverter, runtimeService, jobSpliterator, hotJobHolder, projectService, jobDao, jobLoader, modelService, resultQuerier, datasetService, storagePathCoordinator, userService, mock(JobUpdateHelper.class), - trashService, systemSettingService, jobSpecParser, taskScheduler); + trashService, systemSettingService, jobSpecParser, swTaskScheduler); } @Test @@ -458,7 +458,7 @@ public void testExec() { when(hotJobHolder.ofIds(eq(List.of(task.getId())))).thenReturn(List.of(job)); var expected = ExecResponse.builder().stdout("stdout").stderr("stderr").build(); 
- when(taskScheduler.exec(eq(task), any())).thenReturn(new Future<>() { + when(swTaskScheduler.exec(eq(task), any())).thenReturn(new Future<>() { @Override public boolean cancel(boolean mayInterruptIfRunning) { return false; diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/job/ModelServingServiceTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/job/ModelServingServiceTest.java index 4d4f33f987..baa6d65a78 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/job/ModelServingServiceTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/job/ModelServingServiceTest.java @@ -48,10 +48,10 @@ import ai.starwhale.mlops.domain.system.mapper.SystemSettingMapper; import ai.starwhale.mlops.domain.user.UserService; import ai.starwhale.mlops.domain.user.bo.User; -import ai.starwhale.mlops.schedule.k8s.K8sClient; -import ai.starwhale.mlops.schedule.k8s.K8sJobTemplate; -import ai.starwhale.mlops.schedule.k8s.ResourceEventHolder; -import ai.starwhale.mlops.schedule.k8s.ResourceOverwriteSpec; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceEventHolder; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceOverwriteSpec; import io.kubernetes.client.informer.ResourceEventHandler; import io.kubernetes.client.openapi.ApiException; import io.kubernetes.client.openapi.models.V1ObjectMeta; diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java index 27fcf79bcb..01e96fdcb8 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java @@ -78,9 +78,13 @@ import ai.starwhale.mlops.exception.SwValidationException; import 
ai.starwhale.mlops.exception.api.StarwhaleApiException; import ai.starwhale.mlops.schedule.SwTaskScheduler; -import ai.starwhale.mlops.schedule.k8s.K8sClient; -import ai.starwhale.mlops.schedule.k8s.K8sJobTemplate; -import ai.starwhale.mlops.schedule.k8s.ResourceEventHolder; +import ai.starwhale.mlops.schedule.impl.docker.ContainerTaskMapper; +import ai.starwhale.mlops.schedule.impl.docker.DockerClientFinderSimpleImpl; +import ai.starwhale.mlops.schedule.impl.docker.log.TaskLogCollectorFactoryDocker; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceEventHolder; +import ai.starwhale.mlops.schedule.log.TaskLogSaver; import ai.starwhale.mlops.storage.LengthAbleInputStream; import ai.starwhale.mlops.storage.StorageAccessService; import ai.starwhale.mlops.storage.memory.StorageAccessServiceMemory; @@ -136,11 +140,14 @@ "ai.starwhale.mlops.common", "ai.starwhale.mlops.domain", "ai.starwhale.mlops.datastore", - "ai.starwhale.mlops.reporting", + "ai.starwhale.mlops.schedule.reporting", "ai.starwhale.mlops.resulting", "ai.starwhale.mlops.configuration.security"}, excludeFilters = {@ComponentScan.Filter(type = FilterType.ASSIGNABLE_TYPE, value = ModelServingService.class)}) -@Import({K8sJobTemplate.class, ResourceEventHolder.class, SimpleMeterRegistry.class}) +@Import({K8sJobTemplate.class, ResourceEventHolder.class, SimpleMeterRegistry.class, TaskLogSaver.class, + DockerClientFinderSimpleImpl.class, + ContainerTaskMapper.class, + TaskLogCollectorFactoryDocker.class}) @TestInstance(TestInstance.Lifecycle.PER_CLASS) public class ModelServiceTest extends MySqlContainerHolder { @@ -786,7 +793,7 @@ public void testListModelVersionView() { var res = modelService.listModelVersionView("1"); assertEquals(2, res.size()); assertThat(res.get(1), allOf(hasProperty("projectName", is("starwhale")), - hasProperty("modelName", is("m")))); + hasProperty("modelName", 
is("m")))); assertThat(res.get(0), allOf(hasProperty("projectName", is("starwhale")), hasProperty("modelName", is("m1")), diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/runtime/RuntimeRegistryListenerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/runtime/RuntimeRegistryListenerTest.java index 80fa392e08..f2da967455 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/runtime/RuntimeRegistryListenerTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/runtime/RuntimeRegistryListenerTest.java @@ -25,7 +25,7 @@ import ai.starwhale.mlops.configuration.DockerSetting; import ai.starwhale.mlops.domain.system.SystemSetting; -import ai.starwhale.mlops.schedule.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; import io.kubernetes.client.openapi.ApiException; import io.kubernetes.client.openapi.models.V1Secret; import javax.servlet.http.HttpServletResponse; diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/runtime/RuntimeServiceTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/runtime/RuntimeServiceTest.java index 56efc46018..cfe2f13b87 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/runtime/RuntimeServiceTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/runtime/RuntimeServiceTest.java @@ -89,8 +89,8 @@ import ai.starwhale.mlops.exception.SwProcessException; import ai.starwhale.mlops.exception.SwValidationException; import ai.starwhale.mlops.exception.api.StarwhaleApiException; -import ai.starwhale.mlops.schedule.k8s.K8sClient; -import ai.starwhale.mlops.schedule.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; import ai.starwhale.mlops.storage.LengthAbleInputStream; import ai.starwhale.mlops.storage.StorageAccessService; import io.kubernetes.client.openapi.ApiException; diff --git 
a/server/controller/src/test/java/ai/starwhale/mlops/domain/system/SystemSettingServiceTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/system/SystemSettingServiceTest.java index d5d15ca459..7168a96696 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/system/SystemSettingServiceTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/system/SystemSettingServiceTest.java @@ -148,7 +148,7 @@ public void testUpdate() { systemSettingService.getSystemSetting().getDockerSetting().getPassword()); Assertions.assertFalse(systemSettingService.getSystemSetting().getDockerSetting().isInsecure()); // get the custom resource pool - Assertions.assertEquals(1, systemSettingService.getResourcePools().size()); + Assertions.assertEquals(1, systemSettingService.getResourcePoolsFromWeb().size()); Assertions.assertEquals(3, systemSettingService.queryResourcePool("custom").getResources().size()); // get the default resource pool Assertions.assertEquals(ResourcePool.defaults(), systemSettingService.queryResourcePool("not_exists")); @@ -264,12 +264,12 @@ public void testStartWithoutData() throws Exception { @Test public void testUpdateResourcePools() { systemSettingService.updateSetting(YAML2); - Assertions.assertEquals(1, systemSettingService.getResourcePools().size()); + Assertions.assertEquals(1, systemSettingService.getResourcePoolsFromWeb().size()); verify(listener).onUpdate(systemSettingService.getSystemSetting()); var pool = ResourcePool.builder().name("foo").metadata(Map.of("bar", "baz")).build(); systemSettingService.updateResourcePools(List.of(pool)); - Assertions.assertEquals(1, systemSettingService.getResourcePools().size()); + Assertions.assertEquals(1, systemSettingService.getResourcePoolsFromWeb().size()); Assertions.assertEquals(pool, systemSettingService.queryResourcePool("foo")); verify(listener, times(2)).onUpdate(systemSettingService.getSystemSetting()); } @@ -279,12 +279,12 @@ public void 
testPrivateResourcePool() { // private pool should not be visible to any users var pool = ResourcePool.builder().name("foo").isPrivate(true).build(); systemSettingService.updateResourcePools(List.of(pool)); - Assertions.assertEquals(0, systemSettingService.getResourcePools().size()); + Assertions.assertEquals(0, systemSettingService.getResourcePoolsFromWeb().size()); // update pool to be visible to user 2 pool.setVisibleUserIds(List.of(2L)); systemSettingService.updateResourcePools(List.of(pool)); - Assertions.assertEquals(1, systemSettingService.getResourcePools().size()); + Assertions.assertEquals(1, systemSettingService.getResourcePoolsFromWeb().size()); Assertions.assertEquals(pool, systemSettingService.queryResourcePool("foo")); // get the rendered yaml, should contain visibleUserIds diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/task/TaskWatcherForScheduleTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/task/TaskWatcherForScheduleTest.java index 9d42c4c8db..f699959660 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/task/TaskWatcherForScheduleTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/task/TaskWatcherForScheduleTest.java @@ -28,6 +28,8 @@ import ai.starwhale.mlops.domain.task.status.TaskStatusMachine; import ai.starwhale.mlops.domain.task.status.watchers.TaskWatcherForSchedule; import ai.starwhale.mlops.schedule.SwTaskScheduler; +import ai.starwhale.mlops.schedule.log.TaskLogSaver; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; import java.util.List; import java.util.UUID; import org.junit.jupiter.api.Test; @@ -39,12 +41,16 @@ public class TaskWatcherForScheduleTest { final TaskStatusMachine taskStatusMachine = new TaskStatusMachine(); + TaskLogSaver taskLogSaver = mock(TaskLogSaver.class); + + TaskReportReceiver taskReportReceiver = mock(TaskReportReceiver.class); + @Test public void testChangeAdopt() { - SwTaskScheduler taskScheduler = mock( 
+ SwTaskScheduler swTaskScheduler = mock( SwTaskScheduler.class); - TaskWatcherForSchedule taskWatcherForSchedule = new TaskWatcherForSchedule(taskScheduler, - taskStatusMachine, 0L); + TaskWatcherForSchedule taskWatcherForSchedule = new TaskWatcherForSchedule(swTaskScheduler, + taskStatusMachine, 0L, taskLogSaver, taskReportReceiver); Task task = Task.builder() .id(1L) .uuid(UUID.randomUUID().toString()) @@ -52,38 +58,39 @@ public void testChangeAdopt() { .step(Step.builder().job(Job.builder().jobRuntime(JobRuntime.builder().build()).build()).build()) .build(); taskWatcherForSchedule.onTaskStatusChange(task, TaskStatus.CREATED); - verify(taskScheduler).schedule(List.of(task)); - verify(taskScheduler, times(0)).stop(List.of(task)); + verify(swTaskScheduler).schedule(List.of(task), taskReportReceiver); + verify(swTaskScheduler, times(0)).stop(List.of(task)); } @Test public void testChangeStopSchedule() { - SwTaskScheduler taskScheduler = mock( + SwTaskScheduler swTaskScheduler = mock( SwTaskScheduler.class); - TaskWatcherForSchedule taskWatcherForSchedule = new TaskWatcherForSchedule(taskScheduler, - taskStatusMachine, 100L); + TaskWatcherForSchedule taskWatcherForSchedule = new TaskWatcherForSchedule(swTaskScheduler, + taskStatusMachine, 100L, taskLogSaver, taskReportReceiver); Task task = Task.builder() .id(1L) .uuid(UUID.randomUUID().toString()) .status(TaskStatus.PAUSED) + .startTime(System.currentTimeMillis()) .step(Step.builder().job(Job.builder().jobRuntime(JobRuntime.builder() .build()).build()).build()) .build(); taskWatcherForSchedule.onTaskStatusChange(task, TaskStatus.READY); - verify(taskScheduler).stop(List.of(task)); + verify(swTaskScheduler).stop(List.of(task)); task.updateStatus(TaskStatus.CANCELED); taskWatcherForSchedule.onTaskStatusChange(task, TaskStatus.READY); - verify(taskScheduler, times(0)).schedule(List.of(task)); + verify(swTaskScheduler, times(0)).schedule(List.of(task), taskReportReceiver); // canceled do not trigger schedule - 
verify(taskScheduler).stop(List.of(task)); + verify(swTaskScheduler).stop(List.of(task)); } @Test public void testDelayStopSchedule() throws InterruptedException { - SwTaskScheduler taskScheduler = mock( + SwTaskScheduler swTaskScheduler = mock( SwTaskScheduler.class); - TaskWatcherForSchedule taskWatcherForSchedule = new TaskWatcherForSchedule(taskScheduler, - taskStatusMachine, 1L); + TaskWatcherForSchedule taskWatcherForSchedule = new TaskWatcherForSchedule(swTaskScheduler, + taskStatusMachine, 1L, taskLogSaver, taskReportReceiver); Task task = Task.builder() .id(1L) .uuid(UUID.randomUUID().toString()) @@ -96,18 +103,18 @@ public void testDelayStopSchedule() throws InterruptedException { task.updateStatus(TaskStatus.SUCCESS); taskWatcherForSchedule.onTaskStatusChange(task, TaskStatus.RUNNING); taskWatcherForSchedule.processTaskDeletion(); - verify(taskScheduler, times(0)).stop(List.of(task)); + verify(swTaskScheduler, times(0)).stop(List.of(task)); Thread.sleep(2000); taskWatcherForSchedule.processTaskDeletion(); - verify(taskScheduler, times(1)).stop(List.of(task, task)); + verify(swTaskScheduler, times(1)).stop(List.of(task, task)); } @Test public void testChangeIgnore() { - SwTaskScheduler taskScheduler = mock( + SwTaskScheduler swTaskScheduler = mock( SwTaskScheduler.class); - TaskWatcherForSchedule taskWatcherForSchedule = new TaskWatcherForSchedule(taskScheduler, - taskStatusMachine, 0L); + TaskWatcherForSchedule taskWatcherForSchedule = new TaskWatcherForSchedule(swTaskScheduler, + taskStatusMachine, 0L, taskLogSaver, taskReportReceiver); Task task = Task.builder() .id(1L) .uuid(UUID.randomUUID().toString()) @@ -118,14 +125,15 @@ public void testChangeIgnore() { taskWatcherForSchedule.onTaskStatusChange(task, TaskStatus.SUCCESS); taskWatcherForSchedule.onTaskStatusChange(task, TaskStatus.FAIL); - verify(taskScheduler, times(0)).schedule(List.of(task)); - verify(taskScheduler, times(0)).stop(List.of(task)); + verify(swTaskScheduler, 
times(0)).schedule(List.of(task), taskReportReceiver); + verify(swTaskScheduler, times(0)).stop(List.of(task)); } @Test public void testCancelling() { - SwTaskScheduler taskScheduler = mock(SwTaskScheduler.class); - var taskWatcherForSchedule = new TaskWatcherForSchedule(taskScheduler, taskStatusMachine, 0L); + SwTaskScheduler swTaskScheduler = mock(SwTaskScheduler.class); + var taskWatcherForSchedule = new TaskWatcherForSchedule(swTaskScheduler, taskStatusMachine, 0L, taskLogSaver, + taskReportReceiver); Task task = Task.builder() .id(1L) .uuid(UUID.randomUUID().toString()) @@ -133,7 +141,7 @@ public void testCancelling() { .step(Step.builder().job(Job.builder().jobRuntime(JobRuntime.builder().build()).build()).build()) .build(); taskWatcherForSchedule.onTaskStatusChange(task, TaskStatus.RUNNING); - verify(taskScheduler, times(0)).schedule(List.of(task)); - verify(taskScheduler, times(1)).stop(List.of(task)); + verify(swTaskScheduler, times(0)).schedule(List.of(task), taskReportReceiver); + verify(swTaskScheduler, times(1)).stop(List.of(task)); } } diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/task/log/TaskLogK8sCollectorTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/task/log/TaskLogK8sCollectorFactoryTest.java similarity index 51% rename from server/controller/src/test/java/ai/starwhale/mlops/domain/task/log/TaskLogK8sCollectorTest.java rename to server/controller/src/test/java/ai/starwhale/mlops/domain/task/log/TaskLogK8sCollectorFactoryTest.java index fc7bbfdf61..f0acef64c0 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/task/log/TaskLogK8sCollectorTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/task/log/TaskLogK8sCollectorFactoryTest.java @@ -16,53 +16,65 @@ package ai.starwhale.mlops.domain.task.log; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; import 
static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import ai.starwhale.mlops.domain.task.bo.ResultPath; import ai.starwhale.mlops.domain.task.bo.Task; -import ai.starwhale.mlops.domain.task.status.watchers.log.TaskLogK8sCollector; -import ai.starwhale.mlops.schedule.k8s.K8sClient; -import ai.starwhale.mlops.schedule.k8s.K8sJobTemplate; -import ai.starwhale.mlops.storage.StorageAccessService; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.impl.k8s.log.TaskLogK8sCollectorFactory; +import ai.starwhale.mlops.schedule.log.TaskLogOfflineCollector; +import ai.starwhale.mlops.schedule.log.TaskLogStreamingCollector; import io.kubernetes.client.openapi.ApiException; import io.kubernetes.client.openapi.models.V1ObjectMeta; import io.kubernetes.client.openapi.models.V1Pod; +import io.kubernetes.client.openapi.models.V1PodList; +import io.kubernetes.client.openapi.models.V1PodStatus; +import java.io.ByteArrayInputStream; import java.io.IOException; -import java.nio.charset.StandardCharsets; +import okhttp3.Call; +import okhttp3.Response; +import okhttp3.ResponseBody; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -public class TaskLogK8sCollectorTest { - - StorageAccessService storageService; +public class TaskLogK8sCollectorFactoryTest { K8sClient k8sClient; K8sJobTemplate k8sJobTemplate; - TaskLogK8sCollector taskLogK8sCollector; + TaskLogK8sCollectorFactory taskLogK8sCollector; @BeforeEach public void setup() { - storageService = mock(StorageAccessService.class); k8sClient = mock(K8sClient.class); k8sJobTemplate = mock(K8sJobTemplate.class); - taskLogK8sCollector = new TaskLogK8sCollector(storageService, k8sClient, k8sJobTemplate); + taskLogK8sCollector = new 
TaskLogK8sCollectorFactory(k8sClient, k8sJobTemplate); } @Test public void testNormal() throws IOException, ApiException { String log = "this is log"; - V1Pod v1Pod = new V1Pod().metadata(new V1ObjectMeta().name("x")); + V1Pod v1Pod = new V1Pod().metadata(new V1ObjectMeta().name("x")).status(new V1PodStatus().phase("Running")); when(k8sClient.podOfJob(anyString())).thenReturn(v1Pod); when(k8sClient.logOfPod(eq(v1Pod), anyList())).thenReturn(log); - taskLogK8sCollector.collect(Task.builder().id(1L).resultRootPath(new ResultPath("root")).build()); - verify(storageService).put(anyString(), eq(log.getBytes(StandardCharsets.UTF_8))); + when(k8sClient.getPodsByJobName(any())).thenReturn(new V1PodList().addItemsItem(v1Pod)); + Call call = mock(Call.class); + Response response = mock(Response.class); + when(call.execute()).thenReturn(response); + ResponseBody responseBody = mock(ResponseBody.class); + when(response.body()).thenReturn(responseBody); + when(responseBody.byteStream()).thenReturn(new ByteArrayInputStream(new byte[]{})); + when(k8sClient.readLog("x", "worker", true)).thenReturn(call); + Assertions.assertInstanceOf(TaskLogOfflineCollector.class, taskLogK8sCollector.offlineCollector(new Task())); + Assertions.assertInstanceOf(TaskLogStreamingCollector.class, + taskLogK8sCollector.streamingCollector(Task.builder().id(1L).build())); } diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/task/status/TaskStatusMachineTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/task/status/TaskStatusMachineTest.java new file mode 100644 index 0000000000..b9c03e2ebe --- /dev/null +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/task/status/TaskStatusMachineTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.starwhale.mlops.domain.task.status; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class TaskStatusMachineTest { + + @Test + public void testCancelling() { + TaskStatusMachine taskStatusMachine = new TaskStatusMachine(); + Assertions.assertEquals(TaskStatus.CANCELLING, + taskStatusMachine.transfer(TaskStatus.CANCELLING, TaskStatus.RUNNING)); + Assertions.assertEquals(TaskStatus.CANCELLING, + taskStatusMachine.transfer(TaskStatus.CANCELLING, TaskStatus.PREPARING)); + Assertions.assertEquals(TaskStatus.CANCELED, + taskStatusMachine.transfer(TaskStatus.CANCELLING, TaskStatus.SUCCESS)); + Assertions.assertEquals(TaskStatus.CANCELED, + taskStatusMachine.transfer(TaskStatus.CANCELLING, TaskStatus.FAIL)); + Assertions.assertEquals(TaskStatus.CANCELED, + taskStatusMachine.transfer(TaskStatus.CANCELLING, TaskStatus.CANCELED)); + } + +} diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/upgrade/UpgradeServiceTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/upgrade/UpgradeServiceTest.java index e84cfcc8cb..2eaea5e14c 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/upgrade/UpgradeServiceTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/upgrade/UpgradeServiceTest.java @@ -29,7 +29,7 @@ import ai.starwhale.mlops.domain.upgrade.bo.Version; import ai.starwhale.mlops.domain.upgrade.step.UpgradeStepManager; import ai.starwhale.mlops.exception.SwValidationException; -import ai.starwhale.mlops.schedule.k8s.K8sClient; 
+import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; import io.kubernetes.client.openapi.models.V1Container; import io.kubernetes.client.openapi.models.V1Deployment; import io.kubernetes.client.openapi.models.V1DeploymentList; diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/upgrade/step/UpgradeStepTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/upgrade/step/UpgradeStepTest.java index ac61ab2eee..e73902664b 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/upgrade/step/UpgradeStepTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/upgrade/step/UpgradeStepTest.java @@ -28,7 +28,7 @@ import ai.starwhale.mlops.domain.upgrade.bo.Upgrade; import ai.starwhale.mlops.domain.upgrade.bo.Upgrade.Status; import ai.starwhale.mlops.domain.upgrade.bo.Version; -import ai.starwhale.mlops.schedule.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; import ai.starwhale.mlops.storage.StorageAccessService; import io.kubernetes.client.openapi.models.V1Pod; import java.util.List; diff --git a/server/controller/src/test/java/ai/starwhale/mlops/reporting/TaskModifyReceiverImpTest.java b/server/controller/src/test/java/ai/starwhale/mlops/reporting/TaskReportReceiverImpTest.java similarity index 88% rename from server/controller/src/test/java/ai/starwhale/mlops/reporting/TaskModifyReceiverImpTest.java rename to server/controller/src/test/java/ai/starwhale/mlops/reporting/TaskReportReceiverImpTest.java index aa5d56dec2..b0a71e5395 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/reporting/TaskModifyReceiverImpTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/reporting/TaskReportReceiverImpTest.java @@ -25,6 +25,8 @@ import ai.starwhale.mlops.domain.task.bo.Task; import ai.starwhale.mlops.domain.task.mapper.TaskMapper; import ai.starwhale.mlops.domain.task.status.TaskStatus; +import ai.starwhale.mlops.schedule.reporting.ReportedTask; +import 
ai.starwhale.mlops.schedule.reporting.SimpleTaskReportReceiver; import java.util.Collections; import java.util.List; import java.util.Set; @@ -33,11 +35,11 @@ import org.junit.jupiter.api.Test; /** - * test for {@link SimpleTaskModifyReceiver} + * test for {@link SimpleTaskReportReceiver} */ -public class TaskModifyReceiverImpTest { +public class TaskReportReceiverImpTest { - SimpleTaskModifyReceiver taskStatusReceiver; + SimpleTaskReportReceiver taskStatusReceiver; HotJobHolder jobHolder; TaskMapper taskMapper; @@ -46,7 +48,7 @@ public class TaskModifyReceiverImpTest { public void setup() { jobHolder = mock(HotJobHolder.class); taskMapper = mock(TaskMapper.class); - taskStatusReceiver = new SimpleTaskModifyReceiver(jobHolder, taskMapper); + taskStatusReceiver = new SimpleTaskReportReceiver(jobHolder, taskMapper); } @Test diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/TaskRunningEnvBuilderTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/TaskRunningEnvBuilderTest.java new file mode 100644 index 0000000000..59293954ee --- /dev/null +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/TaskRunningEnvBuilderTest.java @@ -0,0 +1,170 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.starwhale.mlops.schedule; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import ai.starwhale.mlops.configuration.RunTimeProperties; +import ai.starwhale.mlops.configuration.RunTimeProperties.Pypi; +import ai.starwhale.mlops.configuration.RunTimeProperties.RunConfig; +import ai.starwhale.mlops.configuration.security.TaskTokenValidator; +import ai.starwhale.mlops.domain.dataset.bo.DataSet; +import ai.starwhale.mlops.domain.job.JobType; +import ai.starwhale.mlops.domain.job.bo.Job; +import ai.starwhale.mlops.domain.job.bo.JobRuntime; +import ai.starwhale.mlops.domain.job.spec.Env; +import ai.starwhale.mlops.domain.job.step.bo.Step; +import ai.starwhale.mlops.domain.model.Model; +import ai.starwhale.mlops.domain.project.bo.Project; +import ai.starwhale.mlops.domain.runtime.RuntimeResource; +import ai.starwhale.mlops.domain.runtime.RuntimeService; +import ai.starwhale.mlops.domain.system.resourcepool.bo.ResourcePool; +import ai.starwhale.mlops.domain.task.bo.ResultPath; +import ai.starwhale.mlops.domain.task.bo.Task; +import ai.starwhale.mlops.domain.task.bo.TaskRequest; +import ai.starwhale.mlops.domain.task.status.TaskStatus; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class TaskRunningEnvBuilderTest { + + static final String CONDARC = "channels:\n" + + " - defaults\n" + + "show_channel_urls: true\n" + + "default_channels:\n" + + " - http://nexus.starwhale.ai/repository/anaconda/main\n" + + " - http://nexus.starwhale.ai/repository/anaconda/r\n" + + " - http://nexus.starwhale.ai/repository/anaconda/msys2\n" + + "custom_channels:\n" + + " conda-forge: http://nexus.starwhale.ai/repository/conda-cloud\n" + + " nvidia: http://nexus.starwhale.ai/repository/conda-cloud\n" + + "ssl_verify: false\n" + + "default_threads: 10"; + Map expectedEnvs = new 
HashMap() { + { + put("SW_RUNTIME_PYTHON_VERSION", "3.10"); + put("SW_VERSION", "0.5.1"); + put("SW_ENV", "test"); + put("SW_PROJECT", "project"); + put("DATASET_CONSUMPTION_BATCH_SIZE", "50"); + put("SW_DATASET_URI", "http://instanceUri/project/103/dataset/swdsN/version/swdsV"); + put("SW_MODEL_URI", "http://instanceUri/project/101/model/swmpN/version/swmpV"); + put("SW_RUNTIME_URI", "http://instanceUri/project/102/runtime/swrtN/version/swrtV"); + put("SW_MODEL_VERSION", "swmpN/version/swmpV"); + put("SW_RUNTIME_VERSION", "swrtN/version/swrtV"); + put("SW_TASK_INDEX", "1"); + put("SW_TASK_NUM", "1"); + put("SW_PYPI_INDEX_URL", "indexU"); + put("SW_PYPI_EXTRA_INDEX_URL", "extraU"); + put("SW_PYPI_TRUSTED_HOST", "trustedH"); + put("SW_JOB_VERSION", "juuid"); + put("SW_TOKEN", "tt"); + put("SW_INSTANCE_URI", "http://instanceUri"); + put("SW_TASK_STEP", "cmp"); + put("NVIDIA_VISIBLE_DEVICES", ""); + put("SW_PYPI_RETRIES", "1"); + put("SW_PYPI_TIMEOUT", "2"); + put("SW_RUN_HANDLER", null); + put("SW_DEV_TOKEN", null); + put("SW_DEV_PORT", "8000"); + put("SW_CONDA_CONFIG", "channels:\n" + + " - defaults\n" + + "show_channel_urls: true\n" + + "default_channels:\n" + + " - http://nexus.starwhale.ai/repository/anaconda/main\n" + + " - http://nexus.starwhale.ai/repository/anaconda/r\n" + + " - http://nexus.starwhale.ai/repository/anaconda/msys2\n" + + "custom_channels:\n" + + " conda-forge: http://nexus.starwhale.ai/repository/conda-cloud\n" + + " nvidia: http://nexus.starwhale.ai/repository/conda-cloud\n" + + "ssl_verify: false\n" + + "default_threads: 10"); + } + }; + + + @Test + public void testEnvs() { + RunTimeProperties runTimeProperties = new RunTimeProperties( + "", new RunConfig(), new RunConfig(), new Pypi("indexU", "extraU", "trustedH", 1, 2), CONDARC); + TaskTokenValidator taskTokenValidator = mock(TaskTokenValidator.class); + when(taskTokenValidator.getTaskToken(any(), any())).thenReturn("tt"); + TaskRunningEnvBuilder builder = new 
TaskRunningEnvBuilder("http://instanceUri", + 8000, + 50, + runTimeProperties, taskTokenValidator); + assertMapEquals(expectedEnvs, builder.buildCoreContainerEnvs(mockTask(true))); + } + + private void assertMapEquals(Map expectedEnvs, Map actualEnv) { + Assertions.assertEquals(expectedEnvs.size(), actualEnv.size()); + expectedEnvs.forEach((k, v) -> Assertions.assertEquals(v, actualEnv.get(k))); + } + + private Task mockTask(boolean devMode) { + Job job = Job.builder() + .id(1L) + .model(Model.builder().name("swmpN").version("swmpV").projectId(101L).build()) + .jobRuntime(JobRuntime.builder() + .name("swrtN") + .version("swrtV") + .image("imageRT") + .storagePath("path_rt") + .projectId(102L) + .manifest(new RuntimeService.RuntimeManifest( + "", + new RuntimeService.RuntimeManifest.Environment("3.10", + new RuntimeService.RuntimeManifest.Lock("0.5.1")), null)) + .build()) + .type(JobType.EVALUATION) + .devMode(devMode) + .uuid("juuid") + .dataSets( + List.of(DataSet.builder() + .indexTable("it").path("swds_path").name("swdsN").version("swdsV") + .size(300L).projectId(103L).build())) + .stepSpec("") + .resourcePool(ResourcePool.builder().name("bj01").build()) + .project(Project.builder().name("project").id(100L).build()) + .build(); + Step step = new Step(); + step.setId(1L); + step.setName("cmp"); + step.setJob(job); + step.setResourcePool(job.getResourcePool()); + return Task.builder() + .id(1L) + .taskRequest(TaskRequest.builder().index(1).total(2).build()) + .step(step) + .resultRootPath(new ResultPath("task")) + .uuid("uuid") + .status(TaskStatus.READY) + .taskRequest(TaskRequest.builder() + .index(1) + .total(1) + .runtimeResources(List.of(new RuntimeResource("cpu", 1f, 1f))) + .env(List.of(Env.builder().name("SW_ENV").value("test").build())) + .build()) + .build(); + } + +} diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/ContainerTaskMapperTest.java 
b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/ContainerTaskMapperTest.java new file mode 100644 index 0000000000..7626efc049 --- /dev/null +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/ContainerTaskMapperTest.java @@ -0,0 +1,70 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.starwhale.mlops.schedule.impl.docker; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import ai.starwhale.mlops.domain.job.step.bo.Step; +import ai.starwhale.mlops.domain.task.bo.Task; +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.model.Container; +import java.util.Map; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ContainerTaskMapperTest { + + DockerClientFinder dockerClientFinder; + DockerClient dockerClient; + + LocalDockerTool localDockerTool = new LocalDockerTool(); + + @BeforeEach + public void setup() { + dockerClientFinder = mock(DockerClientFinder.class); + this.dockerClient = localDockerTool.getDockerClient(); + when(dockerClientFinder.findProperDockerClient(any())).thenReturn(dockerClient); + } + + static final String IMAGE_HELLO_WORLD = "hello-world:linux"; + + @Test + public void testContainerExistedOne() throws 
InterruptedException { + ContainerTaskMapper cm = new ContainerTaskMapper(dockerClientFinder); + String containerName = "sw-ut-container"; + try (var tc = localDockerTool.startContainerBlocking(IMAGE_HELLO_WORLD, containerName, + Map.of("starwhale-task-id", "1"), null, null)) { + Task task = Task.builder().id(1L).step(new Step()).build(); + Container container = cm.containerOfTask(task); + Assertions.assertEquals("/" + containerName, container.getNames()[0]); + Assertions.assertEquals(1L, cm.taskIfOfContainer(container)); + } + + + } + + @Test + public void testContainerNotExisted() throws InterruptedException { + ContainerTaskMapper cm = new ContainerTaskMapper(dockerClientFinder); + Task task = Task.builder().id(123344L).step(new Step()).build(); + Container container = cm.containerOfTask(task); + Assertions.assertNull(container); + } +} diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/HostResourceConfigBuilderTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/HostResourceConfigBuilderTest.java new file mode 100644 index 0000000000..29775f88d1 --- /dev/null +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/HostResourceConfigBuilderTest.java @@ -0,0 +1,50 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.starwhale.mlops.schedule.impl.docker; + +import ai.starwhale.mlops.domain.runtime.RuntimeResource; +import com.github.dockerjava.api.command.InspectContainerResponse; +import com.github.dockerjava.api.model.HostConfig; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class HostResourceConfigBuilderTest { + + static final String IMAGE_BUSY_BOX = "busybox:latest"; + + @Test + public void testRun() { + LocalDockerTool localDockerTool = new LocalDockerTool(); + String containerName = "sw-ut-HostResourceConfigBuilderTest" + System.currentTimeMillis(); + HostResourceConfigBuilder builder = new HostResourceConfigBuilder(); + HostConfig hostConfig = builder.build( + List.of(RuntimeResource.builder().type("memory").limit(1024 * 1024 * 10f).build(), + RuntimeResource.builder().type("cpu").limit(1f).build())); + try (var tc = localDockerTool.startContainerBlocking(IMAGE_BUSY_BOX, containerName, Map.of(), + new String[]{"tail", "-f", "/dev/null"}, hostConfig)) { + InspectContainerResponse response = tc.dockerClient.inspectContainerCmd(tc.name).exec(); + Assertions.assertEquals(1L, response.getHostConfig().getCpuCount()); + Assertions.assertEquals(1024 * 1024 * 10L, response.getHostConfig().getMemory()); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + } + +} diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/LocalDockerTool.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/LocalDockerTool.java new file mode 100644 index 0000000000..6b30f283be --- /dev/null +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/LocalDockerTool.java @@ -0,0 +1,141 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.starwhale.mlops.schedule.impl.docker; + +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.async.ResultCallback; +import com.github.dockerjava.api.command.CreateContainerCmd; +import com.github.dockerjava.api.command.CreateContainerResponse; +import com.github.dockerjava.api.command.StartContainerCmd; +import com.github.dockerjava.api.model.HostConfig; +import com.github.dockerjava.api.model.PullResponseItem; +import com.github.dockerjava.core.DefaultDockerClientConfig; +import com.github.dockerjava.core.DockerClientImpl; +import com.github.dockerjava.httpclient5.ApacheDockerHttpClient; +import com.github.dockerjava.transport.DockerHttpClient; +import java.io.Closeable; +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; + + +@Slf4j +public class LocalDockerTool { + + final DockerClient dockerClient; + + public LocalDockerTool() { + DefaultDockerClientConfig clientConfig = DefaultDockerClientConfig.createDefaultConfigBuilder() + .withDockerHost("unix:///var/run/docker.sock").build(); + DockerHttpClient httpClient = new ApacheDockerHttpClient.Builder() + .dockerHost(clientConfig.getDockerHost()) + .sslConfig(clientConfig.getSSLConfig()) + .maxConnections(100) + .connectionTimeout(Duration.ofSeconds(30)) + .responseTimeout(Duration.ofSeconds(45)) + .build(); + this.dockerClient = DockerClientImpl.getInstance(clientConfig, httpClient); + } 
+ + public DockerClient getDockerClient() { + return this.dockerClient; + } + + public TempDockerContainer startContainerBlocking(String image, String containerName, Map labels, + String[] cmds, HostConfig hostConfig) + throws InterruptedException { + Object lock = new Object(); + List rl = new ArrayList<>(); + dockerClient.pullImageCmd(image).exec(new ResultCallback() { + @Override + public void onStart(Closeable closeable) { + } + + @Override + public void onNext(PullResponseItem object) { + + } + + @Override + public void onError(Throwable throwable) { + + synchronized (lock) { + log.error("pulling image {} failed", image, throwable); + lock.notifyAll(); + } + + } + + @Override + public void onComplete() { + CreateContainerCmd createContainerCmd = dockerClient.createContainerCmd( + image + ) + .withName(containerName); + if (!CollectionUtils.isEmpty(labels)) { + createContainerCmd.withLabels(labels); + } + if (null != cmds && cmds.length > 0) { + createContainerCmd.withCmd(cmds); + } + if (null != hostConfig) { + createContainerCmd.withHostConfig(hostConfig); + } + CreateContainerResponse exec = createContainerCmd.exec(); + StartContainerCmd startContainerCmd = dockerClient.startContainerCmd(exec.getId()); + startContainerCmd.exec(); + synchronized (lock) { + lock.notifyAll(); + } + } + + @Override + public void close() throws IOException { + synchronized (lock) { + lock.notifyAll(); + } + } + }); + + synchronized (lock) { + lock.wait(); + } + return new TempDockerContainer(this.dockerClient, containerName); + } + + public static class TempDockerContainer implements AutoCloseable { + + final DockerClient dockerClient; + + final String name; + + public TempDockerContainer(DockerClient dockerClient, String name) { + this.dockerClient = dockerClient; + this.name = name; + } + + @Override + public void close() { + dockerClient.removeContainerCmd(name).withForce(true).withRemoveVolumes(true).exec(); + } + } + +} diff --git 
a/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/SwTaskSchedulerDockerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/SwTaskSchedulerDockerTest.java new file mode 100644 index 0000000000..ee97cb4f4d --- /dev/null +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/SwTaskSchedulerDockerTest.java @@ -0,0 +1,172 @@ +/* + * Copyright 2022 Starwhale, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.starwhale.mlops.schedule.impl.docker; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import ai.starwhale.mlops.domain.job.bo.Job; +import ai.starwhale.mlops.domain.job.bo.JobRuntime; +import ai.starwhale.mlops.domain.job.step.bo.Step; +import ai.starwhale.mlops.domain.task.bo.Task; +import ai.starwhale.mlops.domain.task.status.TaskStatus; +import ai.starwhale.mlops.schedule.TaskCommandGetter; +import ai.starwhale.mlops.schedule.TaskCommandGetter.TaskCommand; +import ai.starwhale.mlops.schedule.TaskRunningEnvBuilder; +import ai.starwhale.mlops.schedule.impl.docker.reporting.DockerTaskReporter; +import ai.starwhale.mlops.schedule.reporting.ReportedTask; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.model.Container; +import com.github.dockerjava.core.DefaultDockerClientConfig; +import com.github.dockerjava.core.DockerClientImpl; +import com.github.dockerjava.httpclient5.ApacheDockerHttpClient; +import com.github.dockerjava.transport.DockerHttpClient; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +public class SwTaskSchedulerDockerTest { + + static final String IMAGE_BUSY_BOX = "busybox:latest"; + DockerClientFinder dockerClientFinder; + ContainerTaskMapper containerTaskMapper; + DockerTaskReporter dockerTaskReporter; + ExecutorService 
cmdExecThreadPool; + TaskRunningEnvBuilder taskRunningEnvBuilder; + String network; + String nodeIp; + SwTaskSchedulerDocker swTaskSchedulerDocker; + TaskCommandGetter taskCommandGetter; + DockerClient dockerClient; + + TaskReportReceiver taskReportReceiver; + + @BeforeEach + public void setup() { + taskReportReceiver = mock(TaskReportReceiver.class); + dockerClientFinder = mock(DockerClientFinder.class); + dockerClient = localDocker(); + when(dockerClientFinder.findProperDockerClient(any())).thenReturn(dockerClient); + containerTaskMapper = mock(ContainerTaskMapper.class); + String containerName = "sw-ut-busybox"; + Container container = mock(Container.class); + when(container.getId()).thenReturn(containerName); + when(container.getState()).thenReturn("exited"); + when(containerTaskMapper.containerOfTask(any())).thenReturn(container); + when(containerTaskMapper.containerName(any())).thenReturn(containerName); + dockerTaskReporter = mock(DockerTaskReporter.class); + cmdExecThreadPool = Executors.newCachedThreadPool(); + ; + taskRunningEnvBuilder = mock(TaskRunningEnvBuilder.class); + taskCommandGetter = mock(TaskCommandGetter.class); + network = "host"; + nodeIp = "127.1.0.2"; + swTaskSchedulerDocker = new SwTaskSchedulerDocker( + dockerClientFinder, + containerTaskMapper, + dockerTaskReporter, + cmdExecThreadPool, + taskRunningEnvBuilder, + network, + nodeIp, + taskCommandGetter, new HostResourceConfigBuilder()); + try { + dockerClient.removeContainerCmd(containerName).withForce(true).exec(); + } catch (Exception e) { + System.out.println("sw-ut-busybox may not exist"); + } + + + } + + @Test + public void testExec() throws ExecutionException, InterruptedException { + Task task = Task.builder() + .id(1L) + .step(Step.builder() + .job(Job.builder().jobRuntime(JobRuntime.builder().image(IMAGE_BUSY_BOX).build()).build()) + .build() + ) + .build(); + testSchedule(task); + Object lock = new Object(); + doAnswer(new Answer() { + @Override + public Object 
answer(InvocationOnMock invocationOnMock) throws Throwable { + List reportedTasks = (List) invocationOnMock.getArguments()[0]; + if (reportedTasks.get(0).getStatus() == TaskStatus.RUNNING) { + synchronized (lock) { + lock.notifyAll(); + } + } + return null; + } + }).when(taskReportReceiver).receive(anyList()); + synchronized (lock) { + //a timeout is set in case there are errors in the test where there is no chance a container is started. + lock.wait(1000 * 60 * 5); + } + Future future = swTaskSchedulerDocker.exec(task, "echo", "$ENV_NAME"); + String[] strings = future.get(); + Assertions.assertEquals("env_value", strings[0].replace("STDOUT:", "").strip()); + testStop(task); + } + + private void testStop(Task task) { + swTaskSchedulerDocker.stop(Set.of(task)); + var lc = dockerClient.listContainersCmd().withNameFilter(Set.of("sw-ut-busybox")).withShowAll(true).exec(); + Assertions.assertEquals(0, lc.size()); + + } + + private void testSchedule(Task task) { + when(taskCommandGetter.getCmd(any())).thenReturn( + new TaskCommand(new String[]{"tail", "-f", "/dev/null"}, null)); + when(taskRunningEnvBuilder.buildCoreContainerEnvs(any())).thenReturn(Map.of("ENV_NAME", "env_value")); + + swTaskSchedulerDocker.schedule(Set.of(task), taskReportReceiver); + + } + + DockerClient localDocker() { + DefaultDockerClientConfig clientConfig = DefaultDockerClientConfig.createDefaultConfigBuilder() + .withDockerHost("unix:///var/run/docker.sock").build(); + DockerHttpClient httpClient = new ApacheDockerHttpClient.Builder() + .dockerHost(clientConfig.getDockerHost()) + .sslConfig(clientConfig.getSSLConfig()) + .maxConnections(100) + .connectionTimeout(Duration.ofSeconds(30)) + .responseTimeout(Duration.ofSeconds(45)) + .build(); + return DockerClientImpl.getInstance(clientConfig, httpClient); + } + +} diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogOfflineCollectorDockerTest.java 
b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogOfflineCollectorDockerTest.java
new file mode 100644
index 0000000000..844f4d7e51
--- /dev/null
+++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogOfflineCollectorDockerTest.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright 2022 Starwhale, Inc. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ai.starwhale.mlops.schedule.impl.docker.log;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import ai.starwhale.mlops.domain.job.step.bo.Step;
+import ai.starwhale.mlops.domain.task.bo.Task;
+import ai.starwhale.mlops.schedule.impl.docker.ContainerTaskMapper;
+import ai.starwhale.mlops.schedule.impl.docker.DockerClientFinder;
+import ai.starwhale.mlops.schedule.impl.docker.LocalDockerTool;
+import com.github.dockerjava.api.DockerClient;
+import com.github.dockerjava.api.model.Container;
+import io.vavr.Tuple2;
+import java.util.Map;
+import java.util.UUID;
+import lombok.extern.slf4j.Slf4j;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+
+/**
+ * Test for {@link TaskLogOfflineCollectorDocker}: starts a container from {@code hello-world:linux}
+ * through {@link LocalDockerTool} and compares the collected log with {@link #OUT_PUT_HELLO_WORLD}.
+ * NOTE(review): presumably needs a reachable Docker daemon and the image available - confirm.
+ */
+@Slf4j
+public class TaskLogOfflineCollectorDockerTest {
+
+    static final String IMAGE_HELLO_WORLD = "hello-world:linux";
+    // Expected collector output; every line carries a "STDOUT: " prefix and ends with "\n".
+    // NOTE(review): the text contains "(amd64)", so it looks architecture-specific - confirm on other hosts.
+    // NOTE(review): verify the spacing inside these literals against the image's real output.
+    static final String OUT_PUT_HELLO_WORLD = "STDOUT: \n"
+            + "STDOUT: Hello from Docker!\n"
+            + "STDOUT: This message shows that your installation appears to be working correctly.\n"
+            + "STDOUT: \n"
+            + "STDOUT: To generate this message, Docker took the following steps:\n"
+            + "STDOUT: 1. The Docker client contacted the Docker daemon.\n"
+            + "STDOUT: 2. The Docker daemon pulled the \"hello-world\" image from the Docker Hub.\n"
+            + "STDOUT: (amd64)\n"
+            + "STDOUT: 3. The Docker daemon created a new container from that image which runs the\n"
+            + "STDOUT: executable that produces the output you are currently reading.\n"
+            + "STDOUT: 4. The Docker daemon streamed that output to the Docker client, which sent it\n"
+            + "STDOUT: to your terminal.\n"
+            + "STDOUT: \n"
+            + "STDOUT: To try something more ambitious, you can run an Ubuntu container with:\n"
+            + "STDOUT: $ docker run -it ubuntu bash\n"
+            + "STDOUT: \n"
+            + "STDOUT: Share images, automate workflows, and more with a free Docker ID:\n"
+            + "STDOUT: https://hub.docker.com/\n"
+            + "STDOUT: \n"
+            + "STDOUT: For more examples and ideas, visit:\n"
+            + "STDOUT: https://docs.docker.com/get-started/\n"
+            + "STDOUT: \n";
+    // Random container name, generated once when the class is initialised (static field).
+    static String containerName = UUID.randomUUID().toString();
+    DockerClient dockerClient;
+    LocalDockerTool localDockerTool = new LocalDockerTool();
+    DockerClientFinder dockerClientFinder;
+    TaskLogOfflineCollectorDocker logOfflineCollectorDocker;
+
+    /**
+     * Creates the collector under test for a task with id 1: the mocked client finder always returns
+     * the client from {@link LocalDockerTool}, and the mocked mapper resolves that task to a mocked
+     * container whose id and single name are {@link #containerName}.
+     */
+    @BeforeEach
+    public void setup() {
+        this.dockerClient = localDockerTool.getDockerClient();
+        dockerClientFinder = mock(DockerClientFinder.class);
+        when(dockerClientFinder.findProperDockerClient(any())).thenReturn(this.dockerClient);
+        ContainerTaskMapper containerTaskMapper = mock(ContainerTaskMapper.class);
+        Task task = Task.builder().id(1L).step(new Step()).build();
+
+        Container container = mock(Container.class);
+        when(container.getId()).thenReturn(containerName);
+        when(container.getNames()).thenReturn(new String[]{containerName});
+        when(containerTaskMapper.containerOfTask(task)).thenReturn(container);
+        when(containerTaskMapper.taskIfOfContainer(container)).thenReturn(1L);
+        logOfflineCollectorDocker = new TaskLogOfflineCollectorDocker(task, dockerClientFinder, containerTaskMapper);
+    }
+
+    /**
+     * Starts the hello-world container via {@code startContainerBlocking}, collects its log and
+     * asserts that the tuple holds the container name ({@code _1}) and the exact expected text
+     * ({@code _2}). The try-with-resources block closes the returned container handle afterwards.
+     */
+    @Test
+    public void testOfflineLog() throws InterruptedException {
+        try (var td = localDockerTool.startContainerBlocking(IMAGE_HELLO_WORLD,
+                containerName, Map.of(), null, null)) {
+            Tuple2 stringTuple2 = logOfflineCollectorDocker.collect();
+            Assertions.assertEquals(containerName, stringTuple2._1());
+            Assertions.assertEquals(OUT_PUT_HELLO_WORLD, stringTuple2._2());
+        }
+    }
+
+
+}
diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogStreamingCollectorDockerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogStreamingCollectorDockerTest.java
new file mode 100644
index 0000000000..52d73620b2
--- /dev/null
+++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/log/TaskLogStreamingCollectorDockerTest.java
@@ -0,0 +1,116 @@
+/*
+ * Copyright 2022 Starwhale, Inc. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ai.starwhale.mlops.schedule.impl.docker.log;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import ai.starwhale.mlops.domain.job.step.bo.Step;
+import ai.starwhale.mlops.domain.task.bo.Task;
+import ai.starwhale.mlops.schedule.impl.docker.ContainerTaskMapper;
+import ai.starwhale.mlops.schedule.impl.docker.DockerClientFinder;
+import ai.starwhale.mlops.schedule.impl.docker.LocalDockerTool;
+import ai.starwhale.mlops.schedule.impl.docker.LocalDockerTool.TempDockerContainer;
+import com.github.dockerjava.api.DockerClient;
+import com.github.dockerjava.api.model.Container;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.UUID;
+import lombok.extern.slf4j.Slf4j;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+
+/**
+ * Test for {@link TaskLogStreamingCollectorDocker}: starts a container from {@code hello-world:linux}
+ * through {@link LocalDockerTool}, reads its log line by line and checks that every expected line
+ * shows up. NOTE(review): presumably needs a reachable Docker daemon - confirm.
+ */
+@Slf4j
+public class TaskLogStreamingCollectorDockerTest {
+
+    static final String IMAGE_HELLO_WORLD = "hello-world:linux";
+    // Expected log lines; each one is stripped and matched by containment in doCollectLog, so neither
+    // leading/trailing whitespace nor line order is checked.
+    // NOTE(review): the text contains "(amd64)", so it looks architecture-specific - confirm on other hosts.
+    static final String OUT_PUT_HELLO_WORLD = "Hello from Docker!\n"
+            + "This message shows that your installation appears to be working correctly.\n"
+            + "\n"
+            + "To generate this message, Docker took the following steps:\n"
+            + " 1. The Docker client contacted the Docker daemon.\n"
+            + " 2. The Docker daemon pulled the \"hello-world\" image from the Docker Hub.\n"
+            + " (amd64)\n"
+            + " 3. The Docker daemon created a new container from that image which runs the\n"
+            + " executable that produces the output you are currently reading.\n"
+            + " 4. The Docker daemon streamed that output to the Docker client, which sent it\n"
+            + " to your terminal.\n"
+            + "\n"
+            + "To try something more ambitious, you can run an Ubuntu container with:\n"
+            + " $ docker run -it ubuntu bash\n"
+            + "\n"
+            + "Share images, automate workflows, and more with a free Docker ID:\n"
+            + " https://hub.docker.com/\n"
+            + "\n"
+            + "For more examples and ideas, visit:\n"
+            + " https://docs.docker.com/get-started/\n";
+    // Random container name, generated once when the class is initialised (static field).
+    static String containerName = UUID.randomUUID().toString();
+    DockerClient dockerClient;
+    DockerClientFinder dockerClientFinder;
+
+    LocalDockerTool localDockerTool = new LocalDockerTool();
+    TaskLogStreamingCollectorDocker logStreamingCollector;
+
+    /**
+     * Obtains the Docker client from {@link LocalDockerTool} and makes the mocked client finder
+     * return it for every argument.
+     */
+    @BeforeEach
+    public void setup() {
+        this.dockerClient = localDockerTool.getDockerClient();
+        dockerClientFinder = mock(DockerClientFinder.class);
+        when(dockerClientFinder.findProperDockerClient(any())).thenReturn(this.dockerClient);
+    }
+
+    /**
+     * Starts the hello-world container via {@code startContainerBlocking} and runs
+     * {@link #doCollectLog(String)} against it; the try-with-resources block closes the returned
+     * container handle afterwards.
+     */
+    @Test
+    public void testReadLine() throws InterruptedException, IOException {
+        try (TempDockerContainer tempDockerContainer = localDockerTool.startContainerBlocking(IMAGE_HELLO_WORLD,
+                containerName, Map.of(), null, null)) {
+            doCollectLog(containerName);
+        }
+
+    }
+
+    /**
+     * Builds a streaming collector for a task with id 1 that the mocked mapper resolves to a mocked
+     * container with the given id, reads lines until {@code readLine(null)} returns {@code null},
+     * cancels the collector and asserts that each stripped line of {@link #OUT_PUT_HELLO_WORLD} is
+     * contained in the collected text.
+     *
+     * @param containerName value the mocked container returns from {@code getId()}
+     * @throws IOException declared for the collector calls made here
+     */
+    private void doCollectLog(String containerName) throws IOException {
+        ContainerTaskMapper containerTaskMapper = mock(ContainerTaskMapper.class);
+        Task task = Task.builder().id(1L).step(new Step()).build();
+        Container container = mock(Container.class);
+        when(container.getId()).thenReturn(containerName);
+
+        when(containerTaskMapper.containerOfTask(task)).thenReturn(container);
+        when(containerTaskMapper.taskIfOfContainer(container)).thenReturn(1L);
+        logStreamingCollector = new TaskLogStreamingCollectorDocker(task, dockerClientFinder, containerTaskMapper);
+
+        String log = "";
+        StringBuilder wholeLog = new StringBuilder();
+        // A null line ends the loop; every non-null line is appended followed by "\n".
+        while (null != log) {
+            log = logStreamingCollector.readLine(null);
+            if (null != log) {
+                wholeLog.append(log);
+                wholeLog.append("\n");
+            }
+
+        }
+        logStreamingCollector.cancel();
+        String wholeLogStr = wholeLog.toString();
+        // Containment check per expected line; blank expected lines strip to "" and always match.
+        Arrays.stream(OUT_PUT_HELLO_WORLD.split("\n")).forEach(l -> {
+            Assertions.assertTrue(wholeLogStr.contains(l.strip()));
+        });
+
+    }
+
+}
diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/reporting/ContainerStatusExplainerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/reporting/ContainerStatusExplainerTest.java
new file mode 100644
index 0000000000..06d1b04e47
--- /dev/null
+++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/reporting/ContainerStatusExplainerTest.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2022 Starwhale, Inc. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ai.starwhale.mlops.schedule.impl.docker.reporting;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import ai.starwhale.mlops.domain.task.status.TaskStatus;
+import com.github.dockerjava.api.model.Container;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Unit test for {@link ContainerStatusExplainer#statusOf}, driven by a mocked docker-java
+ * {@link Container}.
+ */
+public class ContainerStatusExplainerTest {
+
+
+    /**
+     * Checks three mappings on the same mocked container: state {@code running} gives RUNNING;
+     * state {@code exited} with status {@code "Exited (0) blab-la"} gives SUCCESS; state
+     * {@code exited} with status {@code "Exited (1) blab-la"} gives FAIL.
+     */
+    @Test
+    public void testStatus() {
+        ContainerStatusExplainer containerStatusExplainer = new ContainerStatusExplainer();
+        Container c = mock(Container.class);
+        when(c.getState()).thenReturn("running");
+        Assertions.assertEquals(TaskStatus.RUNNING, containerStatusExplainer.statusOf(c));
+
+        when(c.getState()).thenReturn("exited");
+        when(c.getStatus()).thenReturn("Exited (0) blab-la");
+        Assertions.assertEquals(TaskStatus.SUCCESS, containerStatusExplainer.statusOf(c));
+
+        when(c.getState()).thenReturn("exited");
+        when(c.getStatus()).thenReturn("Exited (1) blab-la");
+        Assertions.assertEquals(TaskStatus.FAIL, containerStatusExplainer.statusOf(c));
+    }
+
+
+}
diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/reporting/DockerTaskReporterTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/reporting/DockerTaskReporterTest.java
new file mode 100644
index 0000000000..2bd33629cf
--- /dev/null
+++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/docker/reporting/DockerTaskReporterTest.java
@@ -0,0 +1,130 @@
+/*
+ * Copyright 2022 Starwhale, Inc. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ai.starwhale.mlops.schedule.impl.docker.reporting;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import ai.starwhale.mlops.domain.system.SystemSettingService;
+import ai.starwhale.mlops.domain.system.resourcepool.bo.ResourcePool;
+import ai.starwhale.mlops.domain.task.status.TaskStatus;
+import ai.starwhale.mlops.domain.task.status.TaskStatusMachine;
+import ai.starwhale.mlops.schedule.impl.docker.ContainerTaskMapper;
+import ai.starwhale.mlops.schedule.impl.docker.DockerClientFinder;
+import ai.starwhale.mlops.schedule.impl.docker.SwTaskSchedulerDocker;
+import ai.starwhale.mlops.schedule.reporting.ReportedTask;
+import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver;
+import com.github.dockerjava.api.DockerClient;
+import com.github.dockerjava.api.command.ListContainersCmd;
+import com.github.dockerjava.api.model.Container;
+import java.util.List;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+
+/**
+ * Unit test for {@link DockerTaskReporter}; every collaborator is a Mockito mock, so no Docker
+ * daemon is contacted.
+ */
+public class DockerTaskReporterTest {
+
+    TaskReportReceiver taskReportReceiver;
+    SystemSettingService systemSettingService;
+    DockerClientFinder dockerClientFinder;
+    ContainerTaskMapper containerTaskMapper;
+    ContainerStatusExplainer containerStatusExplainerOnLabel;
+    TaskStatusMachine taskStatusMachine;
+    DockerTaskReporter dockerTaskReporter;
+
+    /** Creates fresh mocks for all collaborators and the reporter under test. */
+    @BeforeEach
+    public void setup() {
+        taskReportReceiver = mock(TaskReportReceiver.class);
+        systemSettingService = mock(SystemSettingService.class);
+        dockerClientFinder = mock(DockerClientFinder.class);
+        containerTaskMapper = mock(ContainerTaskMapper.class);
+        containerStatusExplainerOnLabel = mock(ContainerStatusExplainer.class);
+        taskStatusMachine = mock(TaskStatusMachine.class);
+        dockerTaskReporter = new DockerTaskReporter(taskReportReceiver, systemSettingService, dockerClientFinder,
+                containerTaskMapper, containerStatusExplainerOnLabel, taskStatusMachine);
+    }
+
+    /**
+     * Reports the same mocked container twice. First with status CANCELED: the received task has
+     * id 1, status CANCELED, no failed reason and a stop time. Then with status FAIL and container
+     * status text {@code "Exit (1) blab-la"}: the received task has status FAIL, that text as failed
+     * reason and a stop time.
+     */
+    @Test
+    public void testReportTask() {
+        Container c = mock(Container.class);
+        when(c.getNames()).thenReturn(new String[]{"a"});
+        when(containerTaskMapper.taskIfOfContainer(any())).thenReturn(1L);
+        when(containerStatusExplainerOnLabel.statusOf(c)).thenReturn(TaskStatus.CANCELED);
+        when(taskStatusMachine.isFinal(any())).thenReturn(true);
+        dockerTaskReporter.reportTask(c);
+        ArgumentCaptor captor = ArgumentCaptor.forClass(List.class);
+        verify(taskReportReceiver).receive(captor.capture());
+        List reportedTasks = captor.getValue();
+        ReportedTask t = (ReportedTask) reportedTasks.get(0);
+        Assertions.assertEquals(1L, t.getId());
+        Assertions.assertEquals(TaskStatus.CANCELED, t.getStatus());
+        Assertions.assertNull(t.getFailedReason());
+        Assertions.assertNotNull(t.getStopTimeMillis());
+
+        String failReason = "Exit (1) blab-la";
+        when(c.getStatus()).thenReturn(failReason);
+        when(containerStatusExplainerOnLabel.statusOf(c)).thenReturn(TaskStatus.FAIL);
+        when(taskStatusMachine.isFinal(any())).thenReturn(true);
+        dockerTaskReporter.reportTask(c);
+        // times(2): the receiver has now been called once per reportTask invocation;
+        // getValue() returns the most recently captured argument, i.e. the FAIL report.
+        verify(taskReportReceiver, times(2)).receive(captor.capture());
+        reportedTasks = captor.getValue();
+        t = (ReportedTask) reportedTasks.get(0);
+        Assertions.assertEquals(1L, t.getId());
+        Assertions.assertEquals(TaskStatus.FAIL, t.getStatus());
+        Assertions.assertEquals(failReason, t.getFailedReason());
+        Assertions.assertNotNull(t.getStopTimeMillis());
+
+    }
+
+    /** Runs {@link #doTest()} with the settings service returning two resource pools. */
+    @Test
+    public void testReportTasks() {
+        when(systemSettingService.getAllResourcePools()).thenReturn(List.of(new ResourcePool(), new ResourcePool()));
+        doTest();
+    }
+
+    /** Runs {@link #doTest()} with the settings service returning {@code null} for the pools. */
+    @Test
+    public void testReportTasksWithNoResourcePool() {
+        when(systemSettingService.getAllResourcePools()).thenReturn(null);
+        doTest();
+    }
+
+    /**
+     * Stubs a Docker client whose container listing (filtered by
+     * {@code SwTaskSchedulerDocker.CONTAINER_LABELS}, show-all enabled) yields one container mapped
+     * to task 1 with status CANCELED, calls {@code reportTasks()} and asserts that the receiver was
+     * called exactly once with a one-element list.
+     */
+    private void doTest() {
+        DockerClient dockerClient = mock(DockerClient.class);
+        when(dockerClientFinder.findProperDockerClient(any())).thenReturn(dockerClient);
+        ListContainersCmd listContainersCmd = mock(
+                ListContainersCmd.class);
+        when(listContainersCmd.withLabelFilter(SwTaskSchedulerDocker.CONTAINER_LABELS)).thenReturn(listContainersCmd);
+        when(listContainersCmd.withShowAll(true)).thenReturn(listContainersCmd);
+        when(dockerClient.listContainersCmd()).thenReturn(listContainersCmd);
+        Container c = mock(Container.class);
+        when(listContainersCmd.exec()).thenReturn(List.of(c));
+        when(c.getNames()).thenReturn(new String[]{"a"});
+        when(containerTaskMapper.taskIfOfContainer(any())).thenReturn(1L);
+        when(containerStatusExplainerOnLabel.statusOf(c)).thenReturn(TaskStatus.CANCELED);
+        when(taskStatusMachine.isFinal(any())).thenReturn(true);
+
+        dockerTaskReporter.reportTasks();
+        ArgumentCaptor captor = ArgumentCaptor.forClass(List.class);
+        verify(taskReportReceiver).receive(captor.capture());
+        Assertions.assertEquals(1, captor.getValue().size());
+    }
+
+}
diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/JobEventHandlerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/JobEventHandlerTest.java
similarity index 89%
rename from server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/JobEventHandlerTest.java
rename to server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/JobEventHandlerTest.java
index 59087b57fd..082717f82a 100644
---
a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/JobEventHandlerTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/JobEventHandlerTest.java @@ -14,10 +14,10 @@ * limitations under the License. */ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; -import static ai.starwhale.mlops.schedule.k8s.K8sJobTemplate.JOB_TYPE_LABEL; -import static ai.starwhale.mlops.schedule.k8s.K8sJobTemplate.WORKLOAD_TYPE_EVAL; +import static ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate.JOB_TYPE_LABEL; +import static ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate.WORKLOAD_TYPE_EVAL; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; @@ -30,8 +30,8 @@ import ai.starwhale.mlops.domain.runtime.RuntimeService; import ai.starwhale.mlops.domain.task.status.TaskStatus; import ai.starwhale.mlops.domain.task.status.TaskStatusMachine; -import ai.starwhale.mlops.reporting.ReportedTask; -import ai.starwhale.mlops.reporting.TaskModifyReceiver; +import ai.starwhale.mlops.schedule.reporting.ReportedTask; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; import io.kubernetes.client.openapi.ApiException; import io.kubernetes.client.openapi.models.V1Job; import io.kubernetes.client.openapi.models.V1JobCondition; @@ -49,7 +49,7 @@ public class JobEventHandlerTest { - TaskModifyReceiver taskModifyReceiver; + TaskReportReceiver taskReportReceiver; DatasetService datasetService; JobEventHandler jobEventHandler; private final K8sClient k8sClient = mock(K8sClient.class); @@ -58,11 +58,11 @@ public class JobEventHandlerTest { @BeforeEach public void setUp() throws ApiException { - taskModifyReceiver = mock(TaskModifyReceiver.class); + taskReportReceiver = mock(TaskReportReceiver.class); datasetService = mock(DatasetService.class); TaskStatusMachine taskStatusMachine = new TaskStatusMachine(); 
jobEventHandler = new JobEventHandler( - taskModifyReceiver, taskStatusMachine, mock(RuntimeService.class), datasetService, k8sClient); + taskReportReceiver, taskStatusMachine, mock(RuntimeService.class), datasetService, k8sClient); var pod = new V1Pod().metadata(new V1ObjectMeta().name("1")); pod.setStatus(new V1PodStatus().startTime(startTime)); @@ -88,7 +88,7 @@ public void testOnAddSuccess() { .startTimeMillis(startTime.toInstant().toEpochMilli()) .stopTimeMillis(completeTime.toInstant().toEpochMilli()) .build(); - verify(taskModifyReceiver).receive(List.of(expected)); + verify(taskReportReceiver).receive(List.of(expected)); } @Test @@ -109,7 +109,7 @@ public void testOnAddFail() { .stopTimeMillis(endTime.toInstant().toEpochMilli()) .retryCount(1) .build(); - verify(taskModifyReceiver).receive(List.of(expected)); + verify(taskReportReceiver).receive(List.of(expected)); // test with reason and message var con = new V1JobCondition() @@ -124,7 +124,7 @@ public void testOnAddFail() { .failedReason("job failed: reason, message") .retryCount(1) .build(); - verify(taskModifyReceiver).receive(List.of(expected2)); + verify(taskReportReceiver).receive(List.of(expected2)); // prefer using pod status var pod = new V1Pod().metadata(new V1ObjectMeta().name("1")); @@ -140,7 +140,7 @@ public void testOnAddFail() { .failedReason("job failed: reason, message\npod failed: foo, bar\nfoo, bar") .retryCount(1) .build(); - verify(taskModifyReceiver).receive(List.of(expected3)); + verify(taskReportReceiver).receive(List.of(expected3)); } @Test @@ -161,7 +161,7 @@ public void testOnUpdateSuccess() { .stopTimeMillis(endTime.toInstant().toEpochMilli()) .ip(null) .build(); - verify(taskModifyReceiver).receive(List.of(expected)); + verify(taskReportReceiver).receive(List.of(expected)); } @Test @@ -183,7 +183,7 @@ public void testOnUpdateFail() { .startTimeMillis(startTime.toInstant().toEpochMilli()) .stopTimeMillis(endTime.toInstant().toEpochMilli()) .build(); - 
verify(taskModifyReceiver).receive(List.of(expected)); + verify(taskReportReceiver).receive(List.of(expected)); verify(k8sClient, times(2)).getPodsByJobNameQuietly("1"); } @@ -201,7 +201,7 @@ public void testOnUpdateUnknown() { .status(TaskStatus.UNKNOWN) .retryCount(0) .build(); - verify(taskModifyReceiver).receive(List.of(expected)); + verify(taskReportReceiver).receive(List.of(expected)); } @Test @@ -211,7 +211,7 @@ public void testOnDelete() { v1Job.setStatus(new V1JobStatus().active(1)); jobEventHandler.onDelete(v1Job, false); var args = ArgumentCaptor.forClass(List.class); - verify(taskModifyReceiver).receive(args.capture()); + verify(taskReportReceiver).receive(args.capture()); assertEquals(1, args.getValue().size()); ReportedTask reportedTask = (ReportedTask) args.getValue().get(0); assertEquals(1L, reportedTask.getId()); diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/K8sClientTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/K8sClientTest.java similarity index 97% rename from server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/K8sClientTest.java rename to server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/K8sClientTest.java index 0c85dd36c5..fd567d2e90 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/K8sClientTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/K8sClientTest.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -23,6 +23,8 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; import io.kubernetes.client.informer.ResourceEventHandler; import io.kubernetes.client.informer.SharedIndexInformer; import io.kubernetes.client.informer.SharedInformerFactory; diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/K8sJobTemplateTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/K8sJobTemplateTest.java similarity index 97% rename from server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/K8sJobTemplateTest.java rename to server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/K8sJobTemplateTest.java index e7ae1140d4..75d94edbe9 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/K8sJobTemplateTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/K8sJobTemplateTest.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.hasEntry; @@ -22,6 +22,9 @@ import ai.starwhale.mlops.domain.runtime.RuntimeResource; import ai.starwhale.mlops.domain.system.resourcepool.bo.Toleration; +import ai.starwhale.mlops.schedule.impl.k8s.ContainerOverwriteSpec; +import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceOverwriteSpec; import io.kubernetes.client.custom.Quantity; import io.kubernetes.client.openapi.models.V1Container; import io.kubernetes.client.openapi.models.V1EnvVar; diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/K8sTaskSchedulerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/K8sSwTaskSchedulerTest.java similarity index 55% rename from server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/K8sTaskSchedulerTest.java rename to server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/K8sSwTaskSchedulerTest.java index 861ff6f74b..3898704be2 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/K8sTaskSchedulerTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/K8sSwTaskSchedulerTest.java @@ -14,22 +14,17 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import ai.starwhale.mlops.configuration.RunTimeProperties; -import ai.starwhale.mlops.configuration.RunTimeProperties.Pypi; -import ai.starwhale.mlops.configuration.RunTimeProperties.RunConfig; -import ai.starwhale.mlops.configuration.security.TaskTokenValidator; import ai.starwhale.mlops.domain.dataset.bo.DataSet; import ai.starwhale.mlops.domain.job.JobType; import ai.starwhale.mlops.domain.job.bo.Job; @@ -47,8 +42,9 @@ import ai.starwhale.mlops.domain.task.bo.Task; import ai.starwhale.mlops.domain.task.bo.TaskRequest; import ai.starwhale.mlops.domain.task.status.TaskStatus; -import ai.starwhale.mlops.domain.task.status.watchers.log.TaskLogK8sCollector; import ai.starwhale.mlops.exception.SwProcessException; +import ai.starwhale.mlops.schedule.TaskRunningEnvBuilder; +import ai.starwhale.mlops.schedule.reporting.TaskReportReceiver; import ai.starwhale.mlops.storage.StorageAccessService; import io.kubernetes.client.custom.Quantity; import io.kubernetes.client.openapi.ApiException; @@ -58,12 +54,10 @@ import io.kubernetes.client.openapi.models.V1Pod; import io.kubernetes.client.openapi.models.V1PodList; import java.io.IOException; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ExecutionException; -import java.util.stream.Collectors; import java.util.stream.Stream; import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.Assertions; @@ -71,48 +65,29 @@ import 
org.mockito.ArgumentCaptor; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; -public class K8sTaskSchedulerTest { - - static final String CONDARC = "channels:\n" - + " - defaults\n" - + "show_channel_urls: true\n" - + "default_channels:\n" - + " - http://nexus.starwhale.ai/repository/anaconda/main\n" - + " - http://nexus.starwhale.ai/repository/anaconda/r\n" - + " - http://nexus.starwhale.ai/repository/anaconda/msys2\n" - + "custom_channels:\n" - + " conda-forge: http://nexus.starwhale.ai/repository/conda-cloud\n" - + " nvidia: http://nexus.starwhale.ai/repository/conda-cloud\n" - + "ssl_verify: false\n" - + "default_threads: 10"; +public class K8sSwTaskSchedulerTest { + @Test public void testScheduler() throws IOException, ApiException { K8sClient k8sClient = mock(K8sClient.class); - K8sTaskScheduler scheduler = buildK8sScheduler(k8sClient); - scheduler.schedule(Set.of(mockTask(false))); + K8sSwTaskScheduler scheduler = buildK8sScheduler(k8sClient); + scheduler.schedule(Set.of(mockTask(false)), mock(TaskReportReceiver.class)); verify(k8sClient).deployJob(any()); } @NotNull - private K8sTaskScheduler buildK8sScheduler(K8sClient k8sClient) throws IOException { - TaskTokenValidator taskTokenValidator = mock(TaskTokenValidator.class); - when(taskTokenValidator.getTaskToken(any(), any())).thenReturn("tt"); - RunTimeProperties runTimeProperties = new RunTimeProperties( - "", new RunConfig(), new RunConfig(), new Pypi("indexU", "extraU", "trustedH", 1, 2), CONDARC); + private K8sSwTaskScheduler buildK8sScheduler(K8sClient k8sClient) throws IOException { + StorageAccessService storageAccessService = mock(StorageAccessService.class); when(storageAccessService.list(eq("path_rt"))).thenReturn(Stream.of("path_rt")); when(storageAccessService.signedUrl(eq("path_rt"), any())).thenReturn("s3://bucket/path_rt"); - return new K8sTaskScheduler(k8sClient, - taskTokenValidator, - runTimeProperties, + return new K8sSwTaskScheduler(k8sClient, new 
K8sJobTemplateMock(""), - "http://instanceUri", - 8000, - 50, - "OnFailure", 10, + mock(TaskRunningEnvBuilder.class), + "rp", + 10, storageAccessService, - mock(TaskLogK8sCollector.class), mock(ThreadPoolTaskScheduler.class) ); } @@ -121,68 +96,24 @@ private K8sTaskScheduler buildK8sScheduler(K8sClient k8sClient) throws IOExcepti public void testException() throws ApiException, IOException { K8sClient k8sClient = mock(K8sClient.class); when(k8sClient.deployJob(any())).thenThrow(new ApiException()); - K8sTaskScheduler scheduler = buildK8sScheduler(k8sClient); + K8sSwTaskScheduler scheduler = buildK8sScheduler(k8sClient); Task task = mockTask(false); - scheduler.schedule(Set.of(task)); + scheduler.schedule(Set.of(task), mock(TaskReportReceiver.class)); Assertions.assertEquals(TaskStatus.FAIL, task.getStatus()); } - @Test - public void testRenderWithoutGpuResource() throws IOException, ApiException { - var client = mock(K8sClient.class); - - var runTimeProperties = new RunTimeProperties( - "", new RunConfig(), new RunConfig(), new Pypi("", "", "", 1, 2), CONDARC); - var k8sJobTemplate = new K8sJobTemplate("", "", "", ""); - var scheduler = new K8sTaskScheduler( - client, - mock(TaskTokenValidator.class), - runTimeProperties, - k8sJobTemplate, - "", - 8000, - 50, - "OnFailure", - 10, - mock(StorageAccessService.class), - mock(TaskLogK8sCollector.class), - mock(ThreadPoolTaskScheduler.class) - ); - var task = mockTask(false); - scheduler.schedule(Set.of(task)); - var jobArgumentCaptor = ArgumentCaptor.forClass(V1Job.class); - task.getTaskRequest() - .setRuntimeResources(List.of(new RuntimeResource(ResourceOverwriteSpec.RESOURCE_GPU, 1f, 0f))); - scheduler.schedule(Set.of(task)); - - verify(client, times(2)).deployJob(jobArgumentCaptor.capture()); - var jobs = jobArgumentCaptor.getAllValues(); - var expectedEnv = new V1EnvVar().name("NVIDIA_VISIBLE_DEVICES").value(""); - Assertions.assertTrue(jobs.get(0).getSpec().getTemplate().getSpec() - 
.getContainers().get(0).getEnv().contains(expectedEnv)); - Assertions.assertFalse(jobs.get(1).getSpec().getTemplate().getSpec() - .getContainers().get(0).getEnv().contains(expectedEnv)); - } - @Test public void testRenderWithDefaultGpuResourceInPool() throws IOException, ApiException { var client = mock(K8sClient.class); - var runTimeProperties = new RunTimeProperties( - "", new RunConfig(), new RunConfig(), new Pypi("", "", "", 1, 2), CONDARC); var k8sJobTemplate = new K8sJobTemplate("", "", "", ""); - var scheduler = new K8sTaskScheduler( + var scheduler = new K8sSwTaskScheduler( client, - mock(TaskTokenValidator.class), - runTimeProperties, k8sJobTemplate, - "", - 8000, + mock(TaskRunningEnvBuilder.class), + "rp", 50, - "OnFailure", - 10, mock(StorageAccessService.class), - mock(TaskLogK8sCollector.class), mock(ThreadPoolTaskScheduler.class) ); var task = mockTask(false); @@ -194,7 +125,7 @@ public void testRenderWithDefaultGpuResourceInPool() throws IOException, ApiExce var jobArgumentCaptor = ArgumentCaptor.forClass(V1Job.class); // set no resource spec in task task.getTaskRequest().setRuntimeResources(List.of()); - scheduler.schedule(Set.of(task)); + scheduler.schedule(Set.of(task), mock(TaskReportReceiver.class)); verify(client, times(1)).deployJob(jobArgumentCaptor.capture()); var jobs = jobArgumentCaptor.getAllValues(); @@ -209,25 +140,18 @@ public void testRenderWithDefaultGpuResourceInPool() throws IOException, ApiExce public void testDevMode() throws IOException, ApiException { var client = mock(K8sClient.class); - var runTimeProperties = new RunTimeProperties( - "", new RunConfig(), new RunConfig(), new Pypi("", "", "", 1, 2), CONDARC); var k8sJobTemplate = new K8sJobTemplate("", "", "", ""); - var scheduler = new K8sTaskScheduler( + var scheduler = new K8sSwTaskScheduler( client, - mock(TaskTokenValidator.class), - runTimeProperties, k8sJobTemplate, - "", - 8000, + mock(TaskRunningEnvBuilder.class), + "rp", 50, - "OnFailure", - 10, 
mock(StorageAccessService.class), - mock(TaskLogK8sCollector.class), mock(ThreadPoolTaskScheduler.class) ); var task = mockTask(true); - scheduler.schedule(Set.of(task)); + scheduler.schedule(Set.of(task), mock(TaskReportReceiver.class)); var jobArgumentCaptor = ArgumentCaptor.forClass(V1Job.class); verify(client, times(1)).deployJob(jobArgumentCaptor.capture()); @@ -241,16 +165,16 @@ private Task mockTask(boolean devMode) { .id(1L) .model(Model.builder().name("swmpN").version("swmpV").projectId(101L).build()) .jobRuntime(JobRuntime.builder() - .name("swrtN") - .version("swrtV") - .image("imageRT") - .storagePath("path_rt") - .projectId(102L) - .manifest(new RuntimeService.RuntimeManifest( - "", - new RuntimeService.RuntimeManifest.Environment("3.10", - new RuntimeService.RuntimeManifest.Lock("0.5.1")), null)) - .build()) + .name("swrtN") + .version("swrtV") + .image("imageRT") + .storagePath("path_rt") + .projectId(102L) + .manifest(new RuntimeService.RuntimeManifest( + "", + new RuntimeService.RuntimeManifest.Environment("3.10", + new RuntimeService.RuntimeManifest.Lock("0.5.1")), null)) + .build()) .type(JobType.EVALUATION) .devMode(devMode) .uuid("juuid") @@ -298,46 +222,6 @@ public V1Job renderJob(V1Job job, String jobName, String restartPolicy, int back Assertions.assertEquals("imageRT", worker.getImage()); Assertions.assertIterableEquals(Map.of("cpu", new Quantity("1000m")).entrySet(), worker.getResourceOverwriteSpec().getResourceSelector().getRequests().entrySet()); - Map expectedEnvs = new HashMap<>() { - }; - expectedEnvs.put("SW_RUNTIME_PYTHON_VERSION", "3.10"); - expectedEnvs.put("SW_VERSION", "0.5.1"); - expectedEnvs.put("SW_ENV", "test"); - expectedEnvs.put("SW_PROJECT", "project"); - expectedEnvs.put("DATASET_CONSUMPTION_BATCH_SIZE", "50"); - expectedEnvs.put("SW_DATASET_URI", "http://instanceUri/project/103/dataset/swdsN/version/swdsV"); - expectedEnvs.put("SW_MODEL_URI", "http://instanceUri/project/101/model/swmpN/version/swmpV"); - 
expectedEnvs.put("SW_RUNTIME_URI", "http://instanceUri/project/102/runtime/swrtN/version/swrtV"); - expectedEnvs.put("SW_MODEL_VERSION", "swmpN/version/swmpV"); - expectedEnvs.put("SW_RUNTIME_VERSION", "swrtN/version/swrtV"); - expectedEnvs.put("SW_TASK_INDEX", "1"); - expectedEnvs.put("SW_TASK_NUM", "1"); - expectedEnvs.put("SW_PYPI_INDEX_URL", "indexU"); - expectedEnvs.put("SW_PYPI_EXTRA_INDEX_URL", "extraU"); - expectedEnvs.put("SW_PYPI_TRUSTED_HOST", "trustedH"); - expectedEnvs.put("SW_JOB_VERSION", "juuid"); - expectedEnvs.put("SW_TOKEN", "tt"); - expectedEnvs.put("SW_INSTANCE_URI", "http://instanceUri"); - expectedEnvs.put("SW_TASK_STEP", "cmp"); - expectedEnvs.put("NVIDIA_VISIBLE_DEVICES", ""); - expectedEnvs.put("SW_PYPI_RETRIES", "1"); - expectedEnvs.put("SW_PYPI_TIMEOUT", "2"); - expectedEnvs.put("SW_CONDA_CONFIG", "channels:\n" - + " - defaults\n" - + "show_channel_urls: true\n" - + "default_channels:\n" - + " - http://nexus.starwhale.ai/repository/anaconda/main\n" - + " - http://nexus.starwhale.ai/repository/anaconda/r\n" - + " - http://nexus.starwhale.ai/repository/anaconda/msys2\n" - + "custom_channels:\n" - + " conda-forge: http://nexus.starwhale.ai/repository/conda-cloud\n" - + " nvidia: http://nexus.starwhale.ai/repository/conda-cloud\n" - + "ssl_verify: false\n" - + "default_threads: 10"); - Map actualEnv = worker.getEnvs().stream() - .filter(envVar -> envVar.getValue() != null) - .collect(Collectors.toMap(V1EnvVar::getName, V1EnvVar::getValue)); - assertMapEquals(expectedEnvs, actualEnv); return null; } @@ -350,27 +234,19 @@ private void assertMapEquals(Map expectedEnvs, Map + verify(taskLogSaver, times(0)).saveLog(any()); + verify(taskReportReceiver, times(1)).receive(any()); + verify(taskReportReceiver).receive(argThat(tasks -> tasks.size() == 1 && tasks.get(0).getId() == 3L && tasks.get(0).getStatus() == TaskStatus.PREPARING @@ -103,7 +104,7 @@ public void testRunning() { public void testTaskNotFound() { 
when(hotJobHolder.tasksOfIds(List.of(3L))).thenReturn(List.of()); podEventHandler.onUpdate(null, v1Pod); - verify(taskLogK8sCollector, times(0)).collect(any()); + verify(taskLogSaver, times(0)).saveLog(any()); } @Test @@ -112,12 +113,12 @@ public void testPodScheduled() { v1Pod.getStatus().phase("Pending"); v1Pod.getStatus().conditions(List.of(new V1PodCondition().status("True").type("PodScheduled"))); podEventHandler.onUpdate(null, v1Pod); - verify(taskLogK8sCollector, times(0)).collect(any()); + verify(taskLogSaver, times(0)).saveLog(any()); var expect = ReportedTask.builder() .id(3L) .status(TaskStatus.PREPARING) .build(); - verify(taskModifyReceiver, times(1)).receive(List.of(expect)); + verify(taskReportReceiver, times(1)).receive(List.of(expect)); } @Test @@ -125,7 +126,7 @@ public void testTerminating() { v1Pod.getMetadata().setDeletionTimestamp(OffsetDateTime.now()); v1Pod.getStatus().setPhase("Running"); podEventHandler.onUpdate(null, v1Pod); - verify(taskLogK8sCollector, never()).collect(any()); - verify(taskModifyReceiver, never()).receive(any()); + verify(taskLogSaver, never()).saveLog(any()); + verify(taskReportReceiver, never()).receive(any()); } } diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/ResourceEventHolderTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/ResourceEventHolderTest.java similarity index 96% rename from server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/ResourceEventHolderTest.java rename to server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/ResourceEventHolderTest.java index 5716f56ff9..82581afc2b 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/ResourceEventHolderTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/ResourceEventHolderTest.java @@ -14,8 +14,9 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceEventHolder; import io.kubernetes.client.openapi.models.CoreV1Event; import io.kubernetes.client.openapi.models.V1ObjectMeta; import io.kubernetes.client.openapi.models.V1ObjectReference; diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/ResourceOverwriteSpecTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/ResourceOverwriteSpecTest.java similarity index 95% rename from server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/ResourceOverwriteSpecTest.java rename to server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/ResourceOverwriteSpecTest.java index c101d3c88d..62dc8835ce 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/ResourceOverwriteSpecTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/ResourceOverwriteSpecTest.java @@ -14,9 +14,10 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.schedule.k8s; +package ai.starwhale.mlops.schedule.impl.k8s; import ai.starwhale.mlops.domain.runtime.RuntimeResource; +import ai.starwhale.mlops.schedule.impl.k8s.ResourceOverwriteSpec; import io.kubernetes.client.custom.Quantity; import java.util.List; import org.junit.jupiter.api.Assertions; diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogK8sCollectorFactoryTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/log/CancellableJobLogK8sCollectorFactoryTest.java similarity index 95% rename from server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogK8sCollectorFactoryTest.java rename to server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/log/CancellableJobLogK8sCollectorFactoryTest.java index b9d225bb15..9d0955794f 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogK8sCollectorFactoryTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/log/CancellableJobLogK8sCollectorFactoryTest.java @@ -14,14 +14,14 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.schedule.k8s.log; +package ai.starwhale.mlops.schedule.impl.k8s.log; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import ai.starwhale.mlops.schedule.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; import io.kubernetes.client.openapi.ApiException; import io.kubernetes.client.openapi.models.V1ObjectMeta; import io.kubernetes.client.openapi.models.V1Pod; @@ -37,6 +37,7 @@ import org.junit.jupiter.api.Test; public class CancellableJobLogK8sCollectorFactoryTest { + @Test public void testMake() throws IOException, ApiException { var k8sClient = mock(K8sClient.class); diff --git a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogK8sCollectorTest.java b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/log/CancellableJobLogK8sCollectorTest.java similarity index 92% rename from server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogK8sCollectorTest.java rename to server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/log/CancellableJobLogK8sCollectorTest.java index b5c9094ebc..d1c914f55d 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/schedule/k8s/log/CancellableJobLogK8sCollectorTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/schedule/impl/k8s/log/CancellableJobLogK8sCollectorTest.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package ai.starwhale.mlops.schedule.k8s.log; +package ai.starwhale.mlops.schedule.impl.k8s.log; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; @@ -25,7 +25,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import ai.starwhale.mlops.schedule.k8s.K8sClient; +import ai.starwhale.mlops.schedule.impl.k8s.K8sClient; import io.kubernetes.client.openapi.ApiException; import io.kubernetes.client.openapi.models.V1ObjectMeta; import io.kubernetes.client.openapi.models.V1Pod; @@ -41,6 +41,7 @@ import org.junit.jupiter.api.Test; public class CancellableJobLogK8sCollectorTest { + K8sClient k8sClient; @BeforeEach @@ -66,9 +67,9 @@ public void testInitAndRead() throws IOException, ApiException { when(call.execute()).thenReturn(resp); when(k8sClient.readLog(eq("running-pod"), anyString(), anyBoolean())).thenReturn(call); - var ins = new CancellableJobLogK8sCollector(k8sClient, "1"); + var ins = new TaskLogK8sStreamingCollector(k8sClient, "1"); - assertThat(ins.readLine(), is(line)); + assertThat(ins.readLine(1L), is(line)); verify(k8sClient).getPodsByJobName("1"); verify(call).execute(); } diff --git a/server/pom.xml b/server/pom.xml index a42a41b5d8..b73a879c25 100644 --- a/server/pom.xml +++ b/server/pom.xml @@ -49,6 +49,8 @@ 2.7.7 2.7.7 1.11.2 + 1.9.0 + 3.3.2 @@ -265,8 +267,22 @@ org.apache.commons commons-csv - 1.9.0 + ${commons-csv.version} + + + com.github.docker-java + docker-java + ${docker-java.version} + + + + com.github.docker-java + docker-java-transport-httpclient5 + ${docker-java.version} + + +