Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ai): AI增加多机多卡 分布式训练 #1261

Merged
merged 17 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changeset/funny-peas-wave.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@scow/scheduler-adapter-protos": patch
"@scow/config": patch
"@scow/ai": patch
---

AI 增加多机多卡分布式训练和对华为 GPU 的特殊处理
2 changes: 2 additions & 0 deletions apps/ai/src/app/(auth)/jobs/[clusterId]/AppSessionsTable.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ export const AppSessionsTable: React.FC<Props> = ({ cluster, status }) => {
title: "作业ID",
dataIndex: "jobId",
width: "8%",
defaultSortOrder: "descend",
sorter: (a, b) => compareNumber(a.jobId, b.jobId),
},
{
Expand Down Expand Up @@ -258,6 +259,7 @@ export const AppSessionsTable: React.FC<Props> = ({ cluster, status }) => {

const filteredData = useMemo(() => {
if (!data) { return []; }

return data.sessions.filter((x) => {
if (query.appJobName) {
return x.sessionId.toLowerCase().includes(query.appJobName.toLowerCase());
Expand Down
102 changes: 83 additions & 19 deletions apps/ai/src/app/(auth)/jobs/[clusterId]/LaunchAppForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import { ModelInterface, ModelVersionInterface } from "src/models/Model";
import { DatasetInterface } from "src/server/trpc/route/dataset/dataset";
import { DatasetVersionInterface } from "src/server/trpc/route/dataset/datasetVersion";
import { AppCustomAttribute, CreateAppInput } from "src/server/trpc/route/jobs/apps";
import { TrainJobInput } from "src/server/trpc/route/jobs/jobs";
import { FrameworkType, TrainJobInput } from "src/server/trpc/route/jobs/jobs";
import { formatSize } from "src/utils/format";
import { parseBooleanParam } from "src/utils/parse";
import { trpc } from "src/utils/trpc";
Expand Down Expand Up @@ -59,6 +59,7 @@ interface FixedFormFields {
useCustomImage: boolean;
image: { type: AccessibilityType, name: number };
remoteImageUrl: string | undefined;
framework: FrameworkType | undefined;
startCommand?: string;
showDataset: boolean;
dataset: { type: AccessibilityType, name: number, version: number };
Expand All @@ -67,6 +68,7 @@ interface FixedFormFields {
mountPoints: string[] | undefined;
partition: string | undefined;
coreCount: number;
nodeCount: number;
gpuCount: number | undefined;
account: string;
maxTime: number;
Expand All @@ -90,6 +92,8 @@ interface Partition {
gpus: number;
nodes: number;
comment?: string;
gpuType?: string;

}

export enum AccessibilityType {
Expand All @@ -103,6 +107,7 @@ const genAppJobName = (appName: string): string => {
};

const initialValues = {
nodeCount:1,
coreCount: 1,
gpuCount: 1,
maxTime: 60,
Expand All @@ -127,7 +132,16 @@ export const LaunchAppForm = (props: Props) => {

const [currentPartitionInfo, setCurrentPartitionInfo] = useState<Partition | undefined>();


const [frameworkOptions, setFrameworkOptions] = useState<{ value: FrameworkType, label: string }[]>([
{
value: "tensorflow",
label: "TensorFlow",
},
{
value: "pytorch",
label: "PyTorch",
},
]);

const showAlgorithm = Form.useWatch("showAlgorithm", form);
const showDataset = Form.useWatch("showDataset", form);
Expand Down Expand Up @@ -223,8 +237,7 @@ export const LaunchAppForm = (props: Props) => {
.map((x) => ({ label: `${x.name}:${x.tag}`, value: x.id }));
}, [images]);

// 暂时写死为1
const nodeCount = 1;
const nodeCount = Form.useWatch("nodeCount", form);

const coreCount = Form.useWatch("coreCount", form);

Expand All @@ -251,10 +264,26 @@ export const LaunchAppForm = (props: Props) => {
} else {
form.setFieldValue("coreCount", 1);
}

form.setFieldValue("framework", undefined);

setCurrentPartitionInfo(partitionInfo);

};

useEffect(() => {
// 特殊处理,如果是华为Ascend910,则增加MindSpore选项
if (currentPartitionInfo?.gpuType === "huawei.com/Ascend910") {
setFrameworkOptions((prevOptions) => {
return prevOptions.find((item) => item.value === "mindspore") ?
prevOptions : [...prevOptions, { value: "mindspore", label: "MindSpore" }];
});
} else {
setFrameworkOptions((prevOptions) => {
return prevOptions.filter((item) => item.value !== "mindspore");
});
}
}, [currentPartitionInfo]);
const customFormItems = useMemo(() => attributes.map((item, index) => {
const rules: Rule[] = item.type === "NUMBER"
? [{ type: "integer" }, { required: item.required }]
Expand Down Expand Up @@ -437,16 +466,19 @@ export const LaunchAppForm = (props: Props) => {
appJobName: genAppJobName(appName ?? "trainJobs"),
});
} else {
const { account, partition, gpuCount, coreCount, maxTime, mountPoints } = inputParams;
const { account, partition, gpuCount, coreCount, maxTime, mountPoints, nodeCount } = inputParams;
const workingDir = "workingDirectory" in inputParams ? inputParams.workingDirectory : undefined;
const customAttributes = "customAttributes" in inputParams ? inputParams.customAttributes : {};
const command = "command" in inputParams ? inputParams.command : undefined;
const framework = "framework" in inputParams ? inputParams.framework : undefined;
form.setFieldsValue({
mountPoints,
customFields: {
...customAttributes,
workingDir,
},
nodeCount,
framework,
account,
partition,
gpuCount,
Expand Down Expand Up @@ -487,7 +519,7 @@ export const LaunchAppForm = (props: Props) => {
}}
onFinish={async () => {

const { appJobName, algorithm, dataset, image, remoteImageUrl, startCommand, model,
const { appJobName, algorithm, dataset, image, remoteImageUrl, framework, startCommand, model,
mountPoints, account, partition, coreCount,
gpuCount, maxTime, command, customFields } = await form.validateFields();

Expand All @@ -499,6 +531,7 @@ export const LaunchAppForm = (props: Props) => {
algorithm: algorithm?.version,
image: image?.name,
remoteImageUrl,
framework,
isDatasetPrivate,
dataset: dataset?.version,
isModelPrivate,
Expand All @@ -514,6 +547,7 @@ export const LaunchAppForm = (props: Props) => {
maxTime: maxTime,
memory: memorySize,
command: command || "",
gpuType: currentPartitionInfo!.gpuType,
});
} else {
let workingDirectory: string | undefined;
Expand Down Expand Up @@ -551,6 +585,7 @@ export const LaunchAppForm = (props: Props) => {
memory: memorySize,
workingDirectory,
customAttributes: customFormKeyValue.customFields,
gpuType: currentPartitionInfo!.gpuType,
});
}
}
Expand Down Expand Up @@ -656,6 +691,20 @@ export const LaunchAppForm = (props: Props) => {
>
<Input placeholder="请输入远程镜像地址" />
</Form.Item>
{/* 分布式训练或者华为的卡训练,需要指定训练框架 */}
{(isTraining && (nodeCount > 1 || currentPartitionInfo?.gpuType === "huawei.com/Ascend910")) ? (
<>
{/* 手动选择算法框架,下拉框只有 tensorflow, pytorch */}
<Form.Item
label="分布式训练框架"
name="framework"
rules={[{ required: true }]}
>
<Select options={frameworkOptions}>
</Select>
</Form.Item>
</>
) : null}
{(customImage && !isTraining) ? (
<Form.Item
label="启动命令"
Expand Down Expand Up @@ -946,58 +995,73 @@ export const LaunchAppForm = (props: Props) => {
onChange={handlePartitionChange}
/>
</Form.Item>
{/* <Form.Item
<Form.Item
label="节点数"
name="nodeCount"
dependencies={["partition"]}
rules={[
{ required: true, type: "integer",
max: currentPartitionInfo?.nodes,
{
required: true,
type: "integer",
},
]}
>
<InputNumber
min={1}
max={currentPartitionInfo?.nodes}
max={isTraining ? undefined : 1}
{...inputNumberFloorConfig}
/>
</Form.Item> */}
</Form.Item>
{
currentPartitionInfo?.gpus ? (
<Form.Item
label="GPU卡数"
label="单节点GPU卡数"
name="gpuCount"
dependencies={["partition"]}
rules={[
{
required: true,
type: "integer",
max: currentPartitionInfo?.gpus,
// 单机最多8张卡
max: 8,
validator: (_, value) => {
const nodeCount = form.getFieldValue("nodeCount") || 0;
if (currentPartitionInfo
&& currentPartitionInfo.gpus > 0
&& (nodeCount * value > currentPartitionInfo.gpus)) {
return Promise.reject(new Error("Total GPUs exceed the available GPUs in the partition"));
}
return Promise.resolve();
},
},
]}
>
<InputNumber
min={1}
max={currentPartitionInfo?.gpus}
max={8}
{...inputNumberFloorConfig}
/>
</Form.Item>
) : (
<Form.Item
label="CPU核心数"
label="单节点CPU核心数"
name="coreCount"
dependencies={["partition"]}
rules={[
{ required: true,
type: "integer",
max: currentPartitionInfo ?
currentPartitionInfo.cores : undefined },
validator: (_, value) => {
const nodeCount = form.getFieldValue("nodeCount") || 0;
if (currentPartitionInfo && (nodeCount * value > currentPartitionInfo.cores)) {
return Promise.reject(new Error("Total cores exceed the available cores in the partition"));
}
return Promise.resolve();
},
},
]}
>
<InputNumber
min={1}
max={currentPartitionInfo ?
currentPartitionInfo.cores / currentPartitionInfo.nodes : undefined }
{...inputNumberFloorConfig}
/>
</Form.Item>
Expand Down
4 changes: 3 additions & 1 deletion apps/ai/src/server/trpc/route/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,16 @@ export type NavLink = z.infer<typeof NavLinkSchema>;
export type UiConfig = z.infer<typeof UiConfigSchema>;


const PartitionSchema = z.object({
export const PartitionSchema = z.object({
name: z.string(),
memMb: z.number(),
cores: z.number(),
gpus: z.number(),
nodes: z.number(),
qos: z.array(z.string()),
comment: z.string().optional(),
gpuType: z.string().optional(),
vramMb: z.number().optional(),
});

const ClusterConfigSchema = z.object({
Expand Down
Loading
Loading