diff --git a/x-pack/plugins/ml/common/types/trained_models.ts b/x-pack/plugins/ml/common/types/trained_models.ts index b4767db5d5344..1de77b523266d 100644 --- a/x-pack/plugins/ml/common/types/trained_models.ts +++ b/x-pack/plugins/ml/common/types/trained_models.ts @@ -51,7 +51,7 @@ export interface TrainedModelStat { } >; }; - deployment_stats?: Omit; + deployment_stats?: TrainedModelDeploymentStatsResponse; model_size_stats?: TrainedModelModelSizeStats; } @@ -128,6 +128,7 @@ export interface InferenceConfigResponse { export interface TrainedModelDeploymentStatsResponse { model_id: string; + deployment_id: string; inference_threads: number; model_threads: number; state: DeploymentState; @@ -163,6 +164,8 @@ export interface TrainedModelDeploymentStatsResponse { } export interface AllocatedModel { + key: string; + deployment_id: string; inference_threads: number; allocation_status: { target_allocation_count: number; diff --git a/x-pack/plugins/ml/common/util/validators.ts b/x-pack/plugins/ml/common/util/validators.ts index 4cbef8470cfc0..e890db5893ad6 100644 --- a/x-pack/plugins/ml/common/util/validators.ts +++ b/x-pack/plugins/ml/common/util/validators.ts @@ -99,3 +99,13 @@ export function timeIntervalInputValidator() { return null; }; } + +export function dictionaryValidator(dict: string[], shouldInclude: boolean = false) { + const dictSet = new Set(dict); + return (value: string) => { + if (dictSet.has(value) !== shouldInclude) { + return { matchDict: value }; + } + return null; + }; +} diff --git a/x-pack/plugins/ml/public/application/memory_usage/nodes_overview/allocated_models.tsx b/x-pack/plugins/ml/public/application/memory_usage/nodes_overview/allocated_models.tsx index 198c01806e06b..6b86287b6faa1 100644 --- a/x-pack/plugins/ml/public/application/memory_usage/nodes_overview/allocated_models.tsx +++ b/x-pack/plugins/ml/public/application/memory_usage/nodes_overview/allocated_models.tsx @@ -39,6 +39,33 @@ export const AllocatedModels: FC = ({ const euiTheme = useEuiTheme(); const columns: Array> = [ + { + id: 'deployment_id', + field: 'deployment_id', + name: i18n.translate('xpack.ml.trainedModels.nodesList.modelsList.deploymentIdHeader', { + defaultMessage: 'ID', + }), + width: '150px', + sortable: true, + truncateText: false, + 'data-test-subj': 'mlAllocatedModelsTableDeploymentId', + }, + { + name: i18n.translate('xpack.ml.trainedModels.nodesList.modelsList.modelRoutingStateHeader', { + defaultMessage: 'Routing state', + }), + width: '100px', + 'data-test-subj': 'mlAllocatedModelsTableRoutingState', + render: (v: AllocatedModel) => { + const { routing_state: routingState, reason } = v.node.routing_state; + + return ( + + {routingState} + + ); + }, + }, { id: 'node_name', field: 'node.name', @@ -193,22 +220,6 @@ export const AllocatedModels: FC = ({ return v.node.number_of_pending_requests; }, }, - { - name: i18n.translate('xpack.ml.trainedModels.nodesList.modelsList.modelRoutingStateHeader', { - defaultMessage: 'Routing state', - }), - width: '100px', - 'data-test-subj': 'mlAllocatedModelsTableRoutingState', - render: (v: AllocatedModel) => { - const { routing_state: routingState, reason } = v.node.routing_state; - - return ( - - {routingState} - - ); - }, - }, ].filter((v) => !hideColumns.includes(v.id!)); return ( @@ -219,7 +230,7 @@ export const AllocatedModels: FC = ({ isExpandable={false} isSelectable={false} items={models} - itemId={'model_id'} + itemId={'key'} rowProps={(item) => ({ 'data-test-subj': `mlAllocatedModelTableRow row-${item.model_id}`, })} diff --git a/x-pack/plugins/ml/public/application/model_management/deployment_setup.tsx b/x-pack/plugins/ml/public/application/model_management/deployment_setup.tsx index 9068f679cf261..622fb0a961183 100644 --- a/x-pack/plugins/ml/public/application/model_management/deployment_setup.tsx +++ b/x-pack/plugins/ml/public/application/model_management/deployment_setup.tsx @@ -5,25 +5,27 @@ * 2.0. */ -import React, { FC, useState, useMemo } from 'react'; +import React, { FC, useMemo, useState } from 'react'; import { i18n } from '@kbn/i18n'; import { FormattedMessage } from '@kbn/i18n-react'; import { - EuiForm, + EuiButton, + EuiButtonEmpty, EuiButtonGroup, - EuiFormRow, + EuiCallOut, + EuiDescribedFormGroup, EuiFieldNumber, + EuiFieldText, + EuiForm, + EuiFormRow, + EuiLink, EuiModal, - EuiModalHeader, - EuiModalHeaderTitle, EuiModalBody, EuiModalFooter, - EuiButtonEmpty, - EuiButton, - EuiCallOut, + EuiModalHeader, + EuiModalHeaderTitle, + EuiSelect, EuiSpacer, - EuiDescribedFormGroup, - EuiLink, } from '@elastic/eui'; import { toMountPoint, wrapWithTheme } from '@kbn/kibana-react-plugin/public'; import type { Observable } from 'rxjs'; @@ -31,17 +33,26 @@ import type { CoreTheme, OverlayStart } from '@kbn/core/public'; import { css } from '@emotion/react'; import { numberValidator } from '@kbn/ml-agg-utils'; import { isCloudTrial } from '../services/ml_server_info'; -import { composeValidators, requiredValidator } from '../../../common/util/validators'; +import { + composeValidators, + dictionaryValidator, + requiredValidator, +} from '../../../common/util/validators'; +import { ModelItem } from './models_list'; interface DeploymentSetupProps { config: ThreadingParams; onConfigChange: (config: ThreadingParams) => void; + errors: Partial>; + isUpdate?: boolean; + deploymentsParams?: Record; } export interface ThreadingParams { numOfAllocations: number; threadsPerAllocations?: number; priority?: 'low' | 'normal'; + deploymentId?: string; } const THREADS_MAX_EXPONENT = 4; @@ -49,10 +60,21 @@ const THREADS_MAX_EXPONENT = 4; /** * Form for setting threading params. */ -export const DeploymentSetup: FC = ({ config, onConfigChange }) => { +export const DeploymentSetup: FC = ({ + config, + onConfigChange, + errors, + isUpdate, + deploymentsParams, +}) => { const numOfAllocation = config.numOfAllocations; const threadsPerAllocations = config.threadsPerAllocations; + const defaultDeploymentId = useMemo(() => { + return config.deploymentId; + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + const threadsPerAllocationsOptions = useMemo( () => new Array(THREADS_MAX_EXPONENT).fill(null).map((v, i) => { @@ -72,6 +94,70 @@ export const DeploymentSetup: FC = ({ config, onConfigChan return ( + + + + } + description={ + + } + > + + } + hasChildLabel={false} + isInvalid={!!errors.deploymentId} + error={ + + } + > + {!isUpdate ? ( + { + onConfigChange({ ...config, deploymentId: e.target.value }); + }} + data-test-subj={'mlModelsStartDeploymentModalDeploymentId'} + /> + ) : ( + { + return { text: v, value: v }; + })} + value={config.deploymentId} + onChange={(e) => { + const update = e.target.value; + onConfigChange({ + ...config, + deploymentId: update, + numOfAllocations: deploymentsParams![update].numOfAllocations, + }); + }} + data-test-subj={'mlModelsStartDeploymentModalDeploymentSelectId'} + /> + )} + + + {config.priority !== undefined ? ( = ({ config, onConfigChan }; interface StartDeploymentModalProps { - modelId: string; + model: ModelItem; startModelDeploymentDocUrl: string; onConfigChange: (config: ThreadingParams) => void; onClose: () => void; initialParams?: ThreadingParams; + modelAndDeploymentIds?: string[]; } /** * Modal window wrapper for {@link DeploymentSetup} */ export const StartUpdateDeploymentModal: FC = ({ - modelId, + model, onConfigChange, onClose, startModelDeploymentDocUrl, initialParams, + modelAndDeploymentIds, }) => { + const isUpdate = !!initialParams; + const [config, setConfig] = useState( initialParams ?? { numOfAllocations: 1, threadsPerAllocations: 1, priority: isCloudTrial() ? 'low' : 'normal', + deploymentId: model.model_id, } ); - const isUpdate = initialParams !== undefined; + const deploymentIdValidator = useMemo(() => { + if (isUpdate) { + return () => null; + } + + const otherModelAndDeploymentIds = [...(modelAndDeploymentIds ?? [])]; + otherModelAndDeploymentIds.splice(otherModelAndDeploymentIds?.indexOf(model.model_id), 1); + + return dictionaryValidator([ + ...model.deployment_ids, + ...otherModelAndDeploymentIds, + // check for deployment with the default ID + ...(model.deployment_ids.includes(model.model_id) ? [''] : []), + ]); + }, [modelAndDeploymentIds, model.deployment_ids, model.model_id, isUpdate]); const numOfAllocationsValidator = composeValidators( requiredValidator(), numberValidator({ min: 1, integerOnly: true }) ); - const errors = numOfAllocationsValidator(config.numOfAllocations); + const numOfAllocationsErrors = numOfAllocationsValidator(config.numOfAllocations); + const deploymentIdErrors = deploymentIdValidator(config.deploymentId ?? ''); + + const errors = { + ...(numOfAllocationsErrors ? { numOfAllocations: numOfAllocationsErrors } : {}), + ...(deploymentIdErrors ? { deploymentId: deploymentIdErrors } : {}), + }; return ( = ({ ) : ( )} @@ -313,7 +424,19 @@ export const StartUpdateDeploymentModal: FC = ({ /> - + >( + (acc, curr) => { + acc[curr.deployment_id] = { numOfAllocations: curr.number_of_allocations }; + return acc; + }, + {} + )} + /> @@ -346,7 +469,7 @@ export const StartUpdateDeploymentModal: FC = ({ form={'startDeploymentForm'} onClick={onConfigChange.bind(null, config)} fill - disabled={!!errors} + disabled={Object.keys(errors).length > 0} data-test-subj={'mlModelsStartDeploymentModalStartButton'} > {isUpdate ? ( @@ -373,9 +496,13 @@ export const StartUpdateDeploymentModal: FC = ({ * @param overlays * @param theme$ */ -export const getUserInputThreadingParamsProvider = +export const getUserInputModelDeploymentParamsProvider = (overlays: OverlayStart, theme$: Observable, startModelDeploymentDocUrl: string) => - (modelId: string, initialParams?: ThreadingParams): Promise => { + ( + model: ModelItem, + initialParams?: ThreadingParams, + deploymentIds?: string[] + ): Promise => { return new Promise(async (resolve) => { try { const modalSession = overlays.openModal( @@ -384,7 +511,8 @@ export const getUserInputThreadingParamsProvider = { modalSession.close(); diff --git a/x-pack/plugins/ml/public/application/model_management/expanded_row.tsx b/x-pack/plugins/ml/public/application/model_management/expanded_row.tsx index 4d6a3c744408a..0fd2068d8bcf3 100644 --- a/x-pack/plugins/ml/public/application/model_management/expanded_row.tsx +++ b/x-pack/plugins/ml/public/application/model_management/expanded_row.tsx @@ -5,7 +5,7 @@ * 2.0. */ -import React, { FC, useEffect, useState, useMemo, useCallback } from 'react'; +import React, { FC, useMemo, useCallback } from 'react'; import { omit, pick } from 'lodash'; import { EuiBadge, @@ -110,8 +110,6 @@ export function useListItemsFormatter() { } export const ExpandedRow: FC = ({ item }) => { - const [modelItems, setModelItems] = useState([]); - const formatToListItems = useListItemsFormatter(); const { @@ -144,42 +142,39 @@ export const ExpandedRow: FC = ({ item }) => { license_level, }; - useEffect( - function updateModelItems() { - (async function () { - const deploymentStats = stats.deployment_stats; - const modelSizeStats = stats.model_size_stats; + const deploymentStatItems: AllocatedModel[] = useMemo(() => { + const deploymentStats = stats.deployment_stats; + const modelSizeStats = stats.model_size_stats; - if (!deploymentStats || !modelSizeStats) return; + if (!deploymentStats || !modelSizeStats) return []; - const items: AllocatedModel[] = deploymentStats.nodes.map((n) => { - const nodeName = Object.values(n.node)[0].name; - return { - ...deploymentStats, - ...modelSizeStats, - node: { - ...pick(n, [ - 'average_inference_time_ms', - 'inference_count', - 'routing_state', - 'last_access', - 'number_of_pending_requests', - 'start_time', - 'throughput_last_minute', - 'number_of_allocations', - 'threads_per_allocation', - ]), - name: nodeName, - } as AllocatedModel['node'], - }; - }); + const items: AllocatedModel[] = deploymentStats.flatMap((perDeploymentStat) => { + return perDeploymentStat.nodes.map((n) => { + const nodeName = Object.values(n.node)[0].name; + return { + key: `${perDeploymentStat.deployment_id}_${nodeName}`, + ...perDeploymentStat, + ...modelSizeStats, + node: { + ...pick(n, [ + 'average_inference_time_ms', + 'inference_count', + 'routing_state', + 'last_access', + 'number_of_pending_requests', + 'start_time', + 'throughput_last_minute', + 'number_of_allocations', + 'threads_per_allocation', + ]), + name: nodeName, + } as AllocatedModel['node'], + }; + }); + }); - setModelItems(items); - })(); - }, - // eslint-disable-next-line react-hooks/exhaustive-deps - [stats.deployment_stats] - ); + return items; + }, [stats]); const tabs: EuiTabbedContentTab[] = [ { @@ -313,7 +308,7 @@ export const ExpandedRow: FC = ({ item }) => {
- {!!modelItems?.length ? ( + {!!deploymentStatItems?.length ? ( <> @@ -325,7 +320,7 @@ export const ExpandedRow: FC = ({ item }) => { - + @@ -379,7 +374,7 @@ export const ExpandedRow: FC = ({ item }) => { }, ] : []), - ...((pipelines && Object.keys(pipelines).length > 0) || stats.ingest + ...((isPopulatedObject(pipelines) && Object.keys(pipelines).length > 0) || stats.ingest ? [ { id: 'pipelines', @@ -389,8 +384,10 @@ export const ExpandedRow: FC = ({ item }) => { {' '} - {stats.pipeline_count} + /> + + {isPopulatedObject(pipelines) ? Object.keys(pipelines!).length : 0} + ), content: ( diff --git a/x-pack/plugins/ml/public/application/model_management/force_stop_dialog.tsx b/x-pack/plugins/ml/public/application/model_management/force_stop_dialog.tsx index 1800f9b13db03..39b108127b127 100644 --- a/x-pack/plugins/ml/public/application/model_management/force_stop_dialog.tsx +++ b/x-pack/plugins/ml/public/application/model_management/force_stop_dialog.tsx @@ -5,33 +5,108 @@ * 2.0. */ -import React, { FC } from 'react'; -import { EuiConfirmModal } from '@elastic/eui'; +import React, { type FC, useState, useMemo, useCallback } from 'react'; +import { + EuiCallOut, + EuiCheckboxGroup, + EuiCheckboxGroupOption, + EuiConfirmModal, + EuiSpacer, +} from '@elastic/eui'; import { FormattedMessage } from '@kbn/i18n-react'; import { i18n } from '@kbn/i18n'; import type { OverlayStart, ThemeServiceStart } from '@kbn/core/public'; import { toMountPoint, wrapWithTheme } from '@kbn/kibana-react-plugin/public'; +import { isPopulatedObject } from '@kbn/ml-is-populated-object'; +import { isDefined } from '@kbn/ml-is-defined'; import type { ModelItem } from './models_list'; interface ForceStopModelConfirmDialogProps { model: ModelItem; onCancel: () => void; - onConfirm: () => void; + onConfirm: (deploymentIds: string[]) => void; } -export const ForceStopModelConfirmDialog: FC = ({ +/** + * Confirmation is required when there are multiple model deployments + * or associated pipelines. + */ +export const StopModelDeploymentsConfirmDialog: FC = ({ model, onConfirm, onCancel, }) => { + const [checkboxIdToSelectedMap, setCheckboxIdToSelectedMap] = useState>( + {} + ); + + const options: EuiCheckboxGroupOption[] = useMemo( + () => + model.deployment_ids.map((deploymentId) => { + return { + id: deploymentId, + label: deploymentId, + }; + }), + [model.deployment_ids] + ); + + const onChange = useCallback((id: string) => { + setCheckboxIdToSelectedMap((prev) => { + return { + ...prev, + [id]: !prev[id], + }; + }); + }, []); + + const selectedDeploymentIds = useMemo( + () => + model.deployment_ids.length > 1 + ? Object.keys(checkboxIdToSelectedMap).filter((id) => checkboxIdToSelectedMap[id]) + : model.deployment_ids, + [model.deployment_ids, checkboxIdToSelectedMap] + ); + + const deploymentPipelinesMap = useMemo(() => { + if (!isPopulatedObject(model.pipelines)) return {}; + return Object.entries(model.pipelines).reduce((acc, [pipelineId, pipelineDef]) => { + const deploymentIds: string[] = (pipelineDef?.processors ?? []) + .map((v) => v?.inference?.model_id) + .filter(isDefined); + deploymentIds.forEach((dId) => { + if (acc[dId]) { + acc[dId].push(pipelineId); + } else { + acc[dId] = [pipelineId]; + } + }); + return acc; + }, {} as Record); + }, [model.pipelines]); + + const pipelineWarning = useMemo(() => { + if (model.deployment_ids.length === 1 && isPopulatedObject(model.pipelines)) { + return Object.keys(model.pipelines); + } + return [ + ...new Set( + Object.entries(deploymentPipelinesMap) + .filter(([deploymentId]) => selectedDeploymentIds.includes(deploymentId)) + .flatMap(([, pipelineNames]) => pipelineNames) + ), + ].sort(); + }, [model, deploymentPipelinesMap, selectedDeploymentIds]); + return ( = { defaultMessage: 'Stop' } )} buttonColor="danger" + confirmButtonDisabled={model.deployment_ids.length > 1 && selectedDeploymentIds.length === 0} > - -
    - {Object.keys(model.pipelines!) - .sort() - .map((pipelineName) => { - return
  • {pipelineName}
  • ; - })} -
+ {model.deployment_ids.length > 1 ? ( + <> + + ), + }} + options={options} + idToSelectedMap={checkboxIdToSelectedMap} + onChange={onChange} + /> + + + ) : null} + + {pipelineWarning.length > 0 ? ( + <> + + } + color="warning" + iconType="warning" + > +

+

    + {pipelineWarning.map((pipelineName) => { + return
  • {pipelineName}
  • ; + })} +
+

+
+ + ) : null}
); }; export const getUserConfirmationProvider = - (overlays: OverlayStart, theme: ThemeServiceStart) => async (forceStopModel: ModelItem) => { + (overlays: OverlayStart, theme: ThemeServiceStart) => + async (forceStopModel: ModelItem): Promise => { return new Promise(async (resolve, reject) => { try { const modalSession = overlays.openModal( toMountPoint( wrapWithTheme( - { modalSession.close(); - resolve(false); + reject(); }} - onConfirm={() => { + onConfirm={(deploymentIds: string[]) => { modalSession.close(); - resolve(true); + resolve(deploymentIds); }} />, theme.theme$ @@ -80,7 +188,7 @@ export const getUserConfirmationProvider = ) ); } catch (e) { - resolve(false); + reject(); } }); }; diff --git a/x-pack/plugins/ml/public/application/model_management/model_actions.tsx b/x-pack/plugins/ml/public/application/model_management/model_actions.tsx index 7d4b6866f7a16..8fdb50f7b40f3 100644 --- a/x-pack/plugins/ml/public/application/model_management/model_actions.tsx +++ b/x-pack/plugins/ml/public/application/model_management/model_actions.tsx @@ -11,14 +11,14 @@ import { isPopulatedObject } from '@kbn/ml-is-populated-object'; import { EuiToolTip } from '@elastic/eui'; import React, { useCallback, useMemo } from 'react'; import { + BUILT_IN_MODEL_TAG, DEPLOYMENT_STATE, TRAINED_MODEL_TYPE, - BUILT_IN_MODEL_TAG, } from '@kbn/ml-trained-models-utils'; import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models'; import { getUserConfirmationProvider } from './force_stop_dialog'; import { useToastNotificationService } from '../services/toast_notification_service'; -import { getUserInputThreadingParamsProvider } from './deployment_setup'; +import { getUserInputModelDeploymentParamsProvider } from './deployment_setup'; import { useMlKibana, useMlLocator, useNavigateToPath } from '../contexts/kibana'; import { getAnalysisType } from '../../../common/util/analytics_utils'; import { DataFrameAnalysisConfigType } from '../../../common/types/data_frame_analytics'; @@ -32,12 +32,14 @@ export function useModelActions({ onLoading, isLoading, fetchModels, + modelAndDeploymentIds, }: { isLoading: boolean; - onTestAction: (model: string) => void; + onTestAction: (model: ModelItem) => void; onModelsDeleteRequest: (modelsIds: string[]) => void; onLoading: (isLoading: boolean) => void; fetchModels: () => void; + modelAndDeploymentIds: string[]; }): Array> { const { services: { @@ -67,8 +69,9 @@ export function useModelActions({ [overlays, theme] ); - const getUserInputThreadingParams = useMemo( - () => getUserInputThreadingParamsProvider(overlays, theme.theme$, startModelDeploymentDocUrl), + const getUserInputModelDeploymentParams = useMemo( + () => + getUserInputModelDeploymentParamsProvider(overlays, theme.theme$, startModelDeploymentDocUrl), [overlays, theme.theme$, startModelDeploymentDocUrl] ); @@ -151,26 +154,27 @@ export function useModelActions({ type: 'icon', isPrimary: true, enabled: (item) => { - const { state } = item.stats?.deployment_stats ?? {}; - return ( - canStartStopTrainedModels && - !isLoading && - state !== DEPLOYMENT_STATE.STARTED && - state !== DEPLOYMENT_STATE.STARTING - ); + return canStartStopTrainedModels && !isLoading; }, available: (item) => item.model_type === TRAINED_MODEL_TYPE.PYTORCH, onClick: async (item) => { - const threadingParams = await getUserInputThreadingParams(item.model_id); + const modelDeploymentParams = await getUserInputModelDeploymentParams( + item, + undefined, + modelAndDeploymentIds + ); - if (!threadingParams) return; + if (!modelDeploymentParams) return; try { onLoading(true); await trainedModelsApiService.startModelAllocation(item.model_id, { - number_of_allocations: threadingParams.numOfAllocations, - threads_per_allocation: threadingParams.threadsPerAllocations!, - priority: threadingParams.priority!, + number_of_allocations: modelDeploymentParams.numOfAllocations, + threads_per_allocation: modelDeploymentParams.threadsPerAllocations!, + priority: modelDeploymentParams.priority!, + deployment_id: !!modelDeploymentParams.deploymentId + ? modelDeploymentParams.deploymentId + : item.model_id, }); displaySuccessToast( i18n.translate('xpack.ml.trainedModels.modelsList.startSuccess', { @@ -213,18 +217,23 @@ export function useModelActions({ item.model_type === TRAINED_MODEL_TYPE.PYTORCH && canStartStopTrainedModels && !isLoading && - item.stats?.deployment_stats?.state === DEPLOYMENT_STATE.STARTED, + !!item.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTED), onClick: async (item) => { - const threadingParams = await getUserInputThreadingParams(item.model_id, { - numOfAllocations: item.stats?.deployment_stats?.number_of_allocations!, + const deploymentToUpdate = item.deployment_ids[0]; + + const deploymentParams = await getUserInputModelDeploymentParams(item, { + deploymentId: deploymentToUpdate, + numOfAllocations: item.stats!.deployment_stats.find( + (v) => v.deployment_id === deploymentToUpdate + )!.number_of_allocations, }); - if (!threadingParams) return; + if (!deploymentParams) return; try { onLoading(true); - await trainedModelsApiService.updateModelDeployment(item.model_id, { - number_of_allocations: threadingParams.numOfAllocations, + await trainedModelsApiService.updateModelDeployment(deploymentParams.deploymentId!, { + number_of_allocations: deploymentParams.numOfAllocations, }); displaySuccessToast( i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccess', { @@ -265,26 +274,23 @@ export function useModelActions({ isPrimary: true, available: (item) => item.model_type === TRAINED_MODEL_TYPE.PYTORCH, enabled: (item) => - canStartStopTrainedModels && - !isLoading && - isPopulatedObject(item.stats?.deployment_stats) && - item.stats?.deployment_stats?.state !== DEPLOYMENT_STATE.STOPPING, + canStartStopTrainedModels && !isLoading && item.deployment_ids.length > 0, onClick: async (item) => { const requireForceStop = isPopulatedObject(item.pipelines); + const hasMultipleDeployments = item.deployment_ids.length > 1; - if (requireForceStop) { - const hasUserApproved = await getUserConfirmation(item); - if (!hasUserApproved) return; - } - - if (requireForceStop) { - const hasUserApproved = await getUserConfirmation(item); - if (!hasUserApproved) return; + let deploymentIds: string[] = item.deployment_ids; + if (requireForceStop || hasMultipleDeployments) { + try { + deploymentIds = await getUserConfirmation(item); + } catch (error) { + return; + } } try { onLoading(true); - await trainedModelsApiService.stopModelAllocation(item.model_id, { + await trainedModelsApiService.stopModelAllocation(deploymentIds, { force: requireForceStop, }); displaySuccessToast( @@ -363,28 +369,29 @@ export function useModelActions({ type: 'icon', isPrimary: true, available: isTestable, - onClick: (item) => onTestAction(item.model_id), - enabled: (item) => canTestTrainedModels && isTestable(item, true), + onClick: (item) => onTestAction(item), + enabled: (item) => canTestTrainedModels && isTestable(item, true) && !isLoading, }, ], [ - canDeleteTrainedModels, + urlLocator, + navigateToUrl, + navigateToPath, canStartStopTrainedModels, - canTestTrainedModels, - displayErrorToast, + isLoading, + getUserInputModelDeploymentParams, + modelAndDeploymentIds, + onLoading, + trainedModelsApiService, displaySuccessToast, + fetchModels, + displayErrorToast, getUserConfirmation, - getUserInputThreadingParams, + onModelsDeleteRequest, + canDeleteTrainedModels, isBuiltInModel, - navigateToPath, - navigateToUrl, onTestAction, - trainedModelsApiService, - urlLocator, - onModelsDeleteRequest, - onLoading, - fetchModels, - isLoading, + canTestTrainedModels, ] ); } diff --git a/x-pack/plugins/ml/public/application/model_management/models_list.tsx b/x-pack/plugins/ml/public/application/model_management/models_list.tsx index 48a1e6266b235..5887dc05c5369 100644 --- a/x-pack/plugins/ml/public/application/model_management/models_list.tsx +++ b/x-pack/plugins/ml/public/application/model_management/models_list.tsx @@ -18,7 +18,7 @@ import { EuiTitle, SearchFilterConfig, } from '@elastic/eui'; - +import { groupBy } from 'lodash'; import { i18n } from '@kbn/i18n'; import { FormattedMessage } from '@kbn/i18n-react'; import { EuiBasicTableColumn } from '@elastic/eui/src/components/basic_table/basic_table'; @@ -27,15 +27,21 @@ import { FIELD_FORMAT_IDS } from '@kbn/field-formats-plugin/common'; import { isPopulatedObject } from '@kbn/ml-is-populated-object'; import { usePageUrlState } from '@kbn/ml-url-state'; import { useTimefilter } from '@kbn/ml-date-picker'; -import { BUILT_IN_MODEL_TYPE, BUILT_IN_MODEL_TAG } from '@kbn/ml-trained-models-utils'; +import { + BUILT_IN_MODEL_TYPE, + BUILT_IN_MODEL_TAG, + DEPLOYMENT_STATE, +} from '@kbn/ml-trained-models-utils'; +import { isDefined } from '@kbn/ml-is-defined'; import { useModelActions } from './model_actions'; import { ModelsTableToConfigMapping } from '.'; import { ModelsBarStats, StatsBar } from '../components/stats_bar'; import { useMlKibana } from '../contexts/kibana'; import { useTrainedModelsApiService } from '../services/ml_api_service/trained_models'; -import { +import type { ModelPipelines, TrainedModelConfigResponse, + TrainedModelDeploymentStatsResponse, TrainedModelStat, } from '../../../common/types/trained_models'; import { DeleteModelsModal } from './delete_models_modal'; @@ -49,12 +55,13 @@ import { useRefresh } from '../routing/use_refresh'; import { SavedObjectsWarning } from '../components/saved_objects_warning'; import { TestTrainedModelFlyout } from './test_models'; -type Stats = Omit; +type Stats = Omit; export type ModelItem = TrainedModelConfigResponse & { type?: string[]; - stats?: Stats; + stats?: Stats & { deployment_stats: TrainedModelDeploymentStatsResponse[] }; pipelines?: ModelPipelines['pipelines'] | null; + deployment_ids: string[]; }; export type ModelItemFull = Required; @@ -120,7 +127,7 @@ export const ModelsList: FC = ({ const [itemIdToExpandedRowMap, setItemIdToExpandedRowMap] = useState>( {} ); - const [showTestFlyout, setShowTestFlyout] = useState(null); + const [modelToTest, setModelToTest] = useState(null); const isBuiltInModel = useCallback( (item: ModelItem) => item.tags.includes(BUILT_IN_MODEL_TAG), @@ -150,11 +157,11 @@ export const ModelsList: FC = ({ type: [ model.model_type, ...Object.keys(model.inference_config), - ...(isBuiltInModel(model) ? [BUILT_IN_MODEL_TYPE] : []), + ...(isBuiltInModel(model as ModelItem) ? [BUILT_IN_MODEL_TYPE] : []), ], } : {}), - }; + } as ModelItem; newItems.push(tableItem); if (itemIdToExpandedRowMap[model.model_id]) { @@ -162,7 +169,8 @@ export const ModelsList: FC = ({ } } - // Need to fetch state for all models to enable/disable actions + // Need to fetch stats for all models to enable/disable actions + // TODO combine fetching models definitions and stats into a single function await fetchModelsStats(newItems); setItems(newItems); @@ -219,15 +227,19 @@ export const ModelsList: FC = ({ const { trained_model_stats: modelsStatsResponse } = await trainedModelsApiService.getTrainedModelStats(models.map((m) => m.model_id)); - for (const { model_id: id, ...stats } of modelsStatsResponse) { - const model = models.find((m) => m.model_id === id); - if (model) { - model.stats = { - ...(model.stats ?? {}), - ...stats, - }; - } - } + const groupByModelId = groupBy(modelsStatsResponse, 'model_id'); + + models.forEach((model) => { + const modelStats = groupByModelId[model.model_id]; + model.stats = { + ...(model.stats ?? {}), + ...modelStats[0], + deployment_stats: modelStats.map((d) => d.deployment_stats).filter(isDefined), + }; + model.deployment_ids = modelStats + .map((v) => v.deployment_stats?.deployment_id) + .filter(isDefined); + }); } return true; @@ -263,15 +275,23 @@ export const ModelsList: FC = ({ })); }, [items]); + const modelAndDeploymentIds = useMemo( + () => [ + ...new Set([...items.flatMap((v) => v.deployment_ids), ...items.map((i) => i.model_id)]), + ], + [items] + ); + /** * Table actions */ const actions = useModelActions({ isLoading, fetchModels: fetchModelsData, - onTestAction: setShowTestFlyout, + onTestAction: setModelToTest, onModelsDeleteRequest: setModelIdsToDelete, onLoading: setIsLoading, + modelAndDeploymentIds, }); const toggleDetails = async (item: ModelItem) => { @@ -351,11 +371,14 @@ export const ModelsList: FC = ({ name: i18n.translate('xpack.ml.trainedModels.modelsList.stateHeader', { defaultMessage: 'State', }), - sortable: (item) => item.stats?.deployment_stats?.state, align: 'left', - truncateText: true, + truncateText: false, render: (model: ModelItem) => { - const state = model.stats?.deployment_stats?.state; + const state = model.stats?.deployment_stats?.some( + (v) => v.state === DEPLOYMENT_STATE.STARTED + ) + ? DEPLOYMENT_STATE.STARTED + : ''; return state ? {state} : null; }, 'data-test-subj': 'mlModelsTableColumnDeploymentState', @@ -533,11 +556,8 @@ export const ModelsList: FC = ({ modelIds={modelIdsToDelete} /> )} - {showTestFlyout === null ? null : ( - + {modelToTest === null ? null : ( + )} ); diff --git a/x-pack/plugins/ml/public/application/model_management/pipelines/pipelines.tsx b/x-pack/plugins/ml/public/application/model_management/pipelines/pipelines.tsx index ac8156ae0053b..7eb20d77ec8d1 100644 --- a/x-pack/plugins/ml/public/application/model_management/pipelines/pipelines.tsx +++ b/x-pack/plugins/ml/public/application/model_management/pipelines/pipelines.tsx @@ -45,9 +45,8 @@ export const ModelPipelines: FC = ({ pipelines, ingestStats const pipelineDefinition = pipelines?.[pipelineName]; return ( - <> + @@ -81,7 +80,7 @@ export const ModelPipelines: FC = ({ pipelines, ingestStats initialIsOpen={initialIsOpen} > - {ingestStats?.pipelines ? ( + {ingestStats!.pipelines[pipelineName]?.processors ? ( @@ -93,7 +92,7 @@ export const ModelPipelines: FC = ({ pipelines, ingestStats - + ) : null} @@ -123,7 +122,7 @@ export const ModelPipelines: FC = ({ pipelines, ingestStats ) : null} - + ); })} diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/models/inference_base.ts b/x-pack/plugins/ml/public/application/model_management/test_models/models/inference_base.ts index 1e3e3ecd5dfd7..2eff762343077 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/models/inference_base.ts +++ b/x-pack/plugins/ml/public/application/model_management/test_models/models/inference_base.ts @@ -61,6 +61,8 @@ export abstract class InferenceBase { protected abstract readonly inferenceTypeLabel: string; protected readonly modelInputField: string; + protected _deploymentId: string | null = null; + protected inputText$ = new BehaviorSubject([]); private inputField$ = new BehaviorSubject(''); private inferenceResult$ = new BehaviorSubject(null); @@ -76,7 +78,8 @@ export abstract class InferenceBase { constructor( protected readonly trainedModelsApi: ReturnType, protected readonly model: estypes.MlTrainedModelConfig, - protected readonly inputType: INPUT_TYPE + protected readonly inputType: INPUT_TYPE, + protected readonly deploymentId: string ) { this.modelInputField = model.input?.field_names[0] ?? DEFAULT_INPUT_FIELD; this.inputField$.next(this.modelInputField); @@ -243,7 +246,7 @@ export abstract class InferenceBase { ): estypes.IngestProcessorContainer[] { const processor: estypes.IngestProcessorContainer = { inference: { - model_id: this.model.model_id, + model_id: this.deploymentId ?? this.model.model_id, target_field: this.inferenceType, field_map: { [this.inputField$.getValue()]: this.modelInputField, @@ -277,7 +280,7 @@ export abstract class InferenceBase { const inferenceConfig = getInferenceConfig(); const resp = (await this.trainedModelsApi.inferTrainedModel( - this.model.model_id, + this.deploymentId ?? this.model.model_id, { docs: this.getInferDocs(), ...(inferenceConfig ? { inference_config: inferenceConfig } : {}), diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/models/ner/ner_inference.ts b/x-pack/plugins/ml/public/application/model_management/test_models/models/ner/ner_inference.ts index f01127e94cda6..15cdec114aad4 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/models/ner/ner_inference.ts +++ b/x-pack/plugins/ml/public/application/model_management/test_models/models/ner/ner_inference.ts @@ -36,9 +36,10 @@ export class NerInference extends InferenceBase { constructor( trainedModelsApi: ReturnType, model: estypes.MlTrainedModelConfig, - inputType: INPUT_TYPE + inputType: INPUT_TYPE, + deploymentId: string ) { - super(trainedModelsApi, model, inputType); + super(trainedModelsApi, model, inputType, deploymentId); this.initialize(); } diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/models/question_answering/question_answering_inference.ts b/x-pack/plugins/ml/public/application/model_management/test_models/models/question_answering/question_answering_inference.ts index 6c80703340d1d..b428442f8908d 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/models/question_answering/question_answering_inference.ts +++ b/x-pack/plugins/ml/public/application/model_management/test_models/models/question_answering/question_answering_inference.ts @@ -62,9 +62,10 @@ export class QuestionAnsweringInference extends InferenceBase, model: estypes.MlTrainedModelConfig, - inputType: INPUT_TYPE + inputType: INPUT_TYPE, + deploymentId: string ) { - super(trainedModelsApi, model, inputType); + super(trainedModelsApi, model, inputType, deploymentId); this.initialize( [this.questionText$.pipe(map((questionText) => questionText !== ''))], diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/fill_mask_inference.ts b/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/fill_mask_inference.ts index 2ec63d4453288..0c109292f16f1 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/fill_mask_inference.ts +++ b/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/fill_mask_inference.ts @@ -34,9 +34,10 @@ export class FillMaskInference extends InferenceBase constructor( trainedModelsApi: ReturnType, model: estypes.MlTrainedModelConfig, - inputType: INPUT_TYPE + inputType: INPUT_TYPE, + deploymentId: string ) { - super(trainedModelsApi, model, inputType); + super(trainedModelsApi, model, inputType, deploymentId); this.initialize([ this.inputText$.pipe(map((inputText) => inputText.every((t) => t.includes(MASK)))), diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/lang_ident_inference.ts b/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/lang_ident_inference.ts index ff6565de32b19..198d30f42b677 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/lang_ident_inference.ts +++ b/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/lang_ident_inference.ts @@ -30,9 +30,10 @@ export class LangIdentInference extends InferenceBase, model: estypes.MlTrainedModelConfig, - inputType: INPUT_TYPE + inputType: INPUT_TYPE, + deploymentId: string ) { - super(trainedModelsApi, model, inputType); + super(trainedModelsApi, model, inputType, deploymentId); this.initialize(); } diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/text_classification_inference.ts b/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/text_classification_inference.ts index 362e162c24b5e..3ee9588162544 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/text_classification_inference.ts +++ b/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/text_classification_inference.ts @@ -30,9 +30,10 @@ export class TextClassificationInference extends InferenceBase, model: estypes.MlTrainedModelConfig, - inputType: INPUT_TYPE + inputType: INPUT_TYPE, + deploymentId: string ) { - super(trainedModelsApi, model, inputType); + super(trainedModelsApi, model, inputType, deploymentId); this.initialize(); } diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/zero_shot_classification_inference.ts b/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/zero_shot_classification_inference.ts index 62af6f2dd55ae..19cc970826821 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/zero_shot_classification_inference.ts +++ b/x-pack/plugins/ml/public/application/model_management/test_models/models/text_classification/zero_shot_classification_inference.ts @@ -37,9 +37,10 @@ export class ZeroShotClassificationInference extends InferenceBase, model: estypes.MlTrainedModelConfig, - inputType: INPUT_TYPE + inputType: INPUT_TYPE, + deploymentId: string ) { - super(trainedModelsApi, model, inputType); + super(trainedModelsApi, model, inputType, deploymentId); this.initialize( [this.labelsText$.pipe(map((labelsText) => labelsText !== ''))], diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/models/text_embedding/text_embedding_inference.ts b/x-pack/plugins/ml/public/application/model_management/test_models/models/text_embedding/text_embedding_inference.ts index 1200690d665aa..258e4161f626b 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/models/text_embedding/text_embedding_inference.ts +++ b/x-pack/plugins/ml/public/application/model_management/test_models/models/text_embedding/text_embedding_inference.ts @@ -42,9 +42,10 @@ export class TextEmbeddingInference extends InferenceBase constructor( trainedModelsApi: ReturnType, model: estypes.MlTrainedModelConfig, - inputType: INPUT_TYPE + inputType: INPUT_TYPE, + deploymentId: string ) { - super(trainedModelsApi, model, inputType); + super(trainedModelsApi, model, inputType, deploymentId); this.initialize(); } diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/selected_model.tsx b/x-pack/plugins/ml/public/application/model_management/test_models/selected_model.tsx index 3749620b47b13..c719aa7368cfa 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/selected_model.tsx +++ b/x-pack/plugins/ml/public/application/model_management/test_models/selected_model.tsx @@ -29,42 +29,42 @@ import { INPUT_TYPE } from './models/inference_base'; interface Props { model: estypes.MlTrainedModelConfig; inputType: INPUT_TYPE; + deploymentId: string; } -export const SelectedModel: FC = ({ model, inputType }) => { +export const SelectedModel: FC = ({ model, inputType, deploymentId }) => { const { trainedModels } = useMlApiContext(); - const inferrer: InferrerType | undefined = useMemo(() => { + const inferrer = useMemo(() => { if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) { const taskType = Object.keys(model.inference_config)[0]; switch (taskType) { case SUPPORTED_PYTORCH_TASKS.NER: - return new NerInference(trainedModels, model, inputType); + return new NerInference(trainedModels, model, inputType, deploymentId); break; case SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION: - return new TextClassificationInference(trainedModels, model, inputType); + return new TextClassificationInference(trainedModels, model, inputType, deploymentId); break; case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION: - return new ZeroShotClassificationInference(trainedModels, model, inputType); + return new ZeroShotClassificationInference(trainedModels, model, inputType, deploymentId); break; case SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING: - return new TextEmbeddingInference(trainedModels, model, inputType); + return new TextEmbeddingInference(trainedModels, model, inputType, deploymentId); break; case SUPPORTED_PYTORCH_TASKS.FILL_MASK: - return new FillMaskInference(trainedModels, model, inputType); + return new FillMaskInference(trainedModels, model, inputType, deploymentId); break; case SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING: - return new QuestionAnsweringInference(trainedModels, model, inputType); + return new QuestionAnsweringInference(trainedModels, model, inputType, deploymentId); break; - default: break; } } else if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) { - return new LangIdentInference(trainedModels, model, inputType); + return new LangIdentInference(trainedModels, model, inputType, deploymentId); } - }, [inputType, model, trainedModels]); + }, [inputType, model, trainedModels, deploymentId]); useEffect(() => { return () => { diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/test_flyout.tsx b/x-pack/plugins/ml/public/application/model_management/test_models/test_flyout.tsx index 1ee8a853bb477..d3fa13b233f5f 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/test_flyout.tsx +++ b/x-pack/plugins/ml/public/application/model_management/test_models/test_flyout.tsx @@ -5,50 +5,35 @@ * 2.0. */ -import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; -import React, { FC, useState, useEffect } from 'react'; +import React, { FC, useState } from 'react'; import { FormattedMessage } from '@kbn/i18n-react'; import { EuiFlyout, - EuiFlyoutHeader, - EuiTitle, EuiFlyoutBody, + EuiFlyoutHeader, + EuiFormRow, + EuiSelect, EuiSpacer, EuiTab, EuiTabs, + EuiTitle, useEuiPaddingSize, } from '@elastic/eui'; import { SelectedModel } from './selected_model'; import { INPUT_TYPE } from './models/inference_base'; -import { useTrainedModelsApiService } from '../../services/ml_api_service/trained_models'; +import { type ModelItem } from '../models_list'; interface Props { - modelId: string; + model: ModelItem; onClose: () => void; } -export const TestTrainedModelFlyout: FC = ({ modelId, onClose }) => { +export const TestTrainedModelFlyout: FC = ({ model, onClose }) => { + const [deploymentId, setDeploymentId] = useState(model.deployment_ids[0]); const mediumPadding = useEuiPaddingSize('m'); - const trainedModelsApiService = useTrainedModelsApiService(); const [inputType, setInputType] = useState(INPUT_TYPE.TEXT); - const [model, setModel] = useState(null); - - useEffect( - function fetchModel() { - trainedModelsApiService.getTrainedModels(modelId).then((resp) => { - if (resp.length) { - setModel(resp[0]); - } - }); - }, - [modelId, trainedModelsApiService] - ); - - if (model === null) { - return null; - } return ( <> @@ -68,6 +53,32 @@ export const TestTrainedModelFlyout: FC = ({ modelId, onClose }) => { + {model.deployment_ids.length > 1 ? ( + <> + + } + > + { + return { text: v, value: v }; + })} + value={deploymentId} + onChange={(e) => { + setDeploymentId(e.target.value); + }} + /> + + + + ) : null} + = ({ modelId, onClose }) => { - + diff --git a/x-pack/plugins/ml/public/application/model_management/test_models/utils.ts b/x-pack/plugins/ml/public/application/model_management/test_models/utils.ts index 2048af3e31173..bb0dc6e9973e8 100644 --- a/x-pack/plugins/ml/public/application/model_management/test_models/utils.ts +++ b/x-pack/plugins/ml/public/application/model_management/test_models/utils.ts @@ -22,7 +22,7 @@ export function isTestable(modelItem: ModelItem, checkForState = false) { Object.keys(modelItem.inference_config)[0] as SupportedPytorchTasksType ) && (checkForState === false || - modelItem.stats?.deployment_stats?.state === DEPLOYMENT_STATE.STARTED) + modelItem.stats?.deployment_stats?.some((v) => v.state === DEPLOYMENT_STATE.STARTED)) ) { return true; } diff --git a/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts b/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts index 1548a298acd18..2975982eccad0 100644 --- a/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts +++ b/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts @@ -132,6 +132,7 @@ export function trainedModelsApiProvider(httpService: HttpService) { number_of_allocations: number; threads_per_allocation: number; priority: 'low' | 'normal'; + deployment_id?: string; } ) { return httpService.http<{ acknowledge: boolean }>({ @@ -141,11 +142,11 @@ export function trainedModelsApiProvider(httpService: HttpService) { }); }, - stopModelAllocation(modelId: string, options: { force: boolean } = { force: false }) { + stopModelAllocation(deploymentsIds: string[], options: { force: boolean } = { force: false }) { const force = options?.force; return httpService.http<{ acknowledge: boolean }>({ - path: `${apiBasePath}/trained_models/${modelId}/deployment/_stop`, + path: `${apiBasePath}/trained_models/${deploymentsIds.join(',')}/deployment/_stop`, method: 'POST', query: { force }, }); diff --git a/x-pack/plugins/ml/server/lib/ml_client/ml_client.ts b/x-pack/plugins/ml/server/lib/ml_client/ml_client.ts index 691279977740d..d7904c182d09e 100644 --- a/x-pack/plugins/ml/server/lib/ml_client/ml_client.ts +++ b/x-pack/plugins/ml/server/lib/ml_client/ml_client.ts @@ -491,7 +491,6 @@ export function getMlClient( return mlClient.startTrainedModelDeployment(...p); }, async updateTrainedModelDeployment(...p: Parameters) { - await modelIdsCheck(p); const { model_id: modelId, number_of_allocations: numberOfAllocations } = p[0]; return client.asInternalUser.transport.request({ method: 'POST', @@ -500,11 +499,9 @@ export function getMlClient( }); }, async stopTrainedModelDeployment(...p: Parameters) { - await modelIdsCheck(p); return mlClient.stopTrainedModelDeployment(...p); }, async inferTrainedModel(...p: Parameters) { - await modelIdsCheck(p); // Temporary workaround for the incorrect inferTrainedModelDeployment function in the esclient if ( // @ts-expect-error TS complains it's always false diff --git a/x-pack/plugins/ml/server/models/model_management/__mocks__/mock_deployment_response.json b/x-pack/plugins/ml/server/models/model_management/__mocks__/mock_deployment_response.json index cc36165debe5d..cbeda1e304a16 100644 --- a/x-pack/plugins/ml/server/models/model_management/__mocks__/mock_deployment_response.json +++ b/x-pack/plugins/ml/server/models/model_management/__mocks__/mock_deployment_response.json @@ -7,6 +7,7 @@ }, "pipeline_count" : 0, "deployment_stats": { + "deployment_id": "distilbert-base-uncased-finetuned-sst-2-english", "model_id": "distilbert-base-uncased-finetuned-sst-2-english", "inference_threads": 1, "model_threads": 1, @@ -102,6 +103,7 @@ }, "pipeline_count" : 0, "deployment_stats": { + "deployment_id": "elastic__distilbert-base-cased-finetuned-conll03-english", "model_id": "elastic__distilbert-base-cased-finetuned-conll03-english", "inference_threads": 1, "model_threads": 1, @@ -197,6 +199,7 @@ }, "pipeline_count" : 0, "deployment_stats": { + "deployment_id": "sentence-transformers__msmarco-minilm-l-12-v3", "model_id": "sentence-transformers__msmarco-minilm-l-12-v3", "inference_threads": 1, "model_threads": 1, @@ -292,6 +295,7 @@ }, "pipeline_count" : 0, "deployment_stats": { + "deployment_id": "typeform__mobilebert-uncased-mnli", "model_id": "typeform__mobilebert-uncased-mnli", "inference_threads": 1, "model_threads": 1, diff --git a/x-pack/plugins/ml/server/models/model_management/memory_usage.test.ts b/x-pack/plugins/ml/server/models/model_management/memory_usage.test.ts index e08a94e4d951c..3f85487a4cbf3 100644 --- a/x-pack/plugins/ml/server/models/model_management/memory_usage.test.ts +++ b/x-pack/plugins/ml/server/models/model_management/memory_usage.test.ts @@ -150,7 +150,6 @@ describe('Model service', () => { }, nodes: [ { - name: 'node3', allocated_models: [ { allocation_status: { @@ -158,12 +157,12 @@ describe('Model service', () => { state: 'started', target_allocation_count: 3, }, + deployment_id: 'distilbert-base-uncased-finetuned-sst-2-english', inference_threads: 1, + key: 'distilbert-base-uncased-finetuned-sst-2-english_node3', model_id: 'distilbert-base-uncased-finetuned-sst-2-english', model_size_bytes: 267386880, - required_native_memory_bytes: 534773760, model_threads: 1, - state: 'started', node: { average_inference_time_ms: 0, inference_count: 0, @@ -171,6 +170,8 @@ describe('Model service', () => { routing_state: 'started', }, }, + required_native_memory_bytes: 534773760, + state: 'started', }, { allocation_status: { @@ -178,12 +179,12 @@ describe('Model service', () => { state: 'started', target_allocation_count: 3, }, + deployment_id: 'elastic__distilbert-base-cased-finetuned-conll03-english', inference_threads: 1, + key: 'elastic__distilbert-base-cased-finetuned-conll03-english_node3', model_id: 'elastic__distilbert-base-cased-finetuned-conll03-english', model_size_bytes: 260947500, - required_native_memory_bytes: 521895000, model_threads: 1, - state: 'started', node: { average_inference_time_ms: 0, inference_count: 0, @@ -191,6 +192,8 @@ describe('Model service', () => { routing_state: 'started', }, }, + required_native_memory_bytes: 521895000, + state: 'started', }, { allocation_status: { @@ -198,12 +201,12 @@ describe('Model service', () => { state: 'started', target_allocation_count: 3, }, + deployment_id: 'sentence-transformers__msmarco-minilm-l-12-v3', inference_threads: 1, + key: 'sentence-transformers__msmarco-minilm-l-12-v3_node3', model_id: 'sentence-transformers__msmarco-minilm-l-12-v3', model_size_bytes: 133378867, - required_native_memory_bytes: 266757734, model_threads: 1, - state: 'started', node: { average_inference_time_ms: 0, inference_count: 0, @@ -211,6 +214,8 @@ describe('Model service', () => { routing_state: 'started', }, }, + required_native_memory_bytes: 266757734, + state: 'started', }, { allocation_status: { @@ -218,12 +223,12 @@ describe('Model service', () => { state: 'started', target_allocation_count: 3, }, + deployment_id: 'typeform__mobilebert-uncased-mnli', inference_threads: 1, + key: 'typeform__mobilebert-uncased-mnli_node3', model_id: 'typeform__mobilebert-uncased-mnli', model_size_bytes: 100139008, - required_native_memory_bytes: 200278016, model_threads: 1, - state: 'started', node: { average_inference_time_ms: 0, inference_count: 0, @@ -231,6 +236,8 @@ describe('Model service', () => { routing_state: 'started', }, }, + required_native_memory_bytes: 200278016, + state: 'started', }, ], attributes: { @@ -239,7 +246,6 @@ describe('Model service', () => { }, id: '3qIoLFnbSi-DwVrYioUCdw', memory_overview: { - ml_max_in_bytes: 1073741824, anomaly_detection: { total: 0, }, @@ -250,6 +256,7 @@ describe('Model service', () => { jvm: 1073741824, total: 15599742976, }, + ml_max_in_bytes: 1073741824, trained_models: { by_model: [ { @@ -272,10 +279,10 @@ describe('Model service', () => { total: 1555161790, }, }, + name: 'node3', roles: ['data', 'ingest', 'master', 'ml', 'transform'], }, { - name: 'node2', allocated_models: [ { allocation_status: { @@ -283,18 +290,20 @@ describe('Model service', () => { state: 'started', target_allocation_count: 3, }, + deployment_id: 'distilbert-base-uncased-finetuned-sst-2-english', inference_threads: 1, + key: 'distilbert-base-uncased-finetuned-sst-2-english_node2', model_id: 'distilbert-base-uncased-finetuned-sst-2-english', model_size_bytes: 267386880, - required_native_memory_bytes: 534773760, model_threads: 1, - state: 'started', node: { routing_state: { reason: 'The object cannot be set twice!', routing_state: 'failed', }, }, + required_native_memory_bytes: 534773760, + state: 'started', }, { allocation_status: { @@ -302,18 +311,20 @@ describe('Model service', () => { state: 'started', target_allocation_count: 3, }, + deployment_id: 'elastic__distilbert-base-cased-finetuned-conll03-english', inference_threads: 1, + key: 'elastic__distilbert-base-cased-finetuned-conll03-english_node2', model_id: 'elastic__distilbert-base-cased-finetuned-conll03-english', model_size_bytes: 260947500, - required_native_memory_bytes: 521895000, model_threads: 1, - state: 'started', node: { routing_state: { reason: 'The object cannot be set twice!', routing_state: 'failed', }, }, + required_native_memory_bytes: 521895000, + state: 'started', }, { allocation_status: { @@ -321,18 +332,20 @@ describe('Model service', () => { state: 'started', target_allocation_count: 3, }, + deployment_id: 'sentence-transformers__msmarco-minilm-l-12-v3', inference_threads: 1, + key: 'sentence-transformers__msmarco-minilm-l-12-v3_node2', model_id: 'sentence-transformers__msmarco-minilm-l-12-v3', model_size_bytes: 133378867, - required_native_memory_bytes: 266757734, model_threads: 1, - state: 'started', node: { routing_state: { reason: 'The object cannot be set twice!', routing_state: 'failed', }, }, + required_native_memory_bytes: 266757734, + state: 'started', }, { allocation_status: { @@ -340,18 +353,20 @@ describe('Model service', () => { state: 'started', target_allocation_count: 3, }, + deployment_id: 'typeform__mobilebert-uncased-mnli', inference_threads: 1, + key: 'typeform__mobilebert-uncased-mnli_node2', model_id: 'typeform__mobilebert-uncased-mnli', model_size_bytes: 100139008, - required_native_memory_bytes: 200278016, model_threads: 1, - state: 'started', node: { routing_state: { reason: 'The object cannot be set twice!', routing_state: 'failed', }, }, + required_native_memory_bytes: 200278016, + state: 'started', }, ], attributes: { @@ -360,7 +375,6 @@ describe('Model service', () => { }, id: 'DpCy7SOBQla3pu0Dq-tnYw', memory_overview: { - ml_max_in_bytes: 1073741824, anomaly_detection: { total: 0, }, @@ -371,6 +385,7 @@ describe('Model service', () => { jvm: 1073741824, total: 15599742976, }, + ml_max_in_bytes: 1073741824, trained_models: { by_model: [ { @@ -393,6 +408,7 @@ describe('Model service', () => { total: 1555161790, }, }, + name: 'node2', roles: ['data', 'master', 'ml', 'transform'], }, { @@ -403,12 +419,12 @@ describe('Model service', () => { state: 'started', target_allocation_count: 3, }, + deployment_id: 'distilbert-base-uncased-finetuned-sst-2-english', inference_threads: 1, + key: 'distilbert-base-uncased-finetuned-sst-2-english_node1', model_id: 'distilbert-base-uncased-finetuned-sst-2-english', model_size_bytes: 267386880, - required_native_memory_bytes: 534773760, model_threads: 1, - state: 'started', node: { average_inference_time_ms: 0, inference_count: 0, @@ -416,6 +432,8 @@ describe('Model service', () => { routing_state: 'started', }, }, + required_native_memory_bytes: 534773760, + state: 'started', }, { allocation_status: { @@ -423,12 +441,12 @@ describe('Model service', () => { state: 'started', target_allocation_count: 3, }, + deployment_id: 'elastic__distilbert-base-cased-finetuned-conll03-english', inference_threads: 1, + key: 'elastic__distilbert-base-cased-finetuned-conll03-english_node1', model_id: 'elastic__distilbert-base-cased-finetuned-conll03-english', model_size_bytes: 260947500, - required_native_memory_bytes: 521895000, model_threads: 1, - state: 'started', node: { average_inference_time_ms: 0, inference_count: 0, @@ -436,6 +454,8 @@ describe('Model service', () => { routing_state: 'started', }, }, + required_native_memory_bytes: 521895000, + state: 'started', }, { allocation_status: { @@ -443,12 +463,12 @@ describe('Model service', () => { state: 'started', target_allocation_count: 3, }, + deployment_id: 'sentence-transformers__msmarco-minilm-l-12-v3', inference_threads: 1, + key: 'sentence-transformers__msmarco-minilm-l-12-v3_node1', model_id: 'sentence-transformers__msmarco-minilm-l-12-v3', model_size_bytes: 133378867, - required_native_memory_bytes: 266757734, model_threads: 1, - state: 'started', node: { average_inference_time_ms: 0, inference_count: 0, @@ -456,6 +476,8 @@ describe('Model service', () => { routing_state: 'started', }, }, + required_native_memory_bytes: 266757734, + state: 'started', }, { allocation_status: { @@ -463,12 +485,12 @@ describe('Model service', () => { state: 'started', target_allocation_count: 3, }, + deployment_id: 'typeform__mobilebert-uncased-mnli', inference_threads: 1, + key: 'typeform__mobilebert-uncased-mnli_node1', model_id: 'typeform__mobilebert-uncased-mnli', model_size_bytes: 100139008, - required_native_memory_bytes: 200278016, model_threads: 1, - state: 'started', node: { average_inference_time_ms: 0, inference_count: 0, @@ -476,6 +498,8 @@ describe('Model service', () => { routing_state: 'started', }, }, + required_native_memory_bytes: 200278016, + state: 'started', }, ], attributes: { @@ -484,7 +508,6 @@ describe('Model service', () => { }, id: 'pt7s6lKHQJaP4QHKtU-Q0Q', memory_overview: { - ml_max_in_bytes: 1073741824, anomaly_detection: { total: 0, }, @@ -495,6 +518,7 @@ describe('Model service', () => { jvm: 1073741824, total: 15599742976, }, + ml_max_in_bytes: 1073741824, trained_models: { by_model: [ { diff --git a/x-pack/plugins/ml/server/models/model_management/memory_usage.ts b/x-pack/plugins/ml/server/models/model_management/memory_usage.ts index 29c51055efe5b..541b396b0a6e5 100644 --- a/x-pack/plugins/ml/server/models/model_management/memory_usage.ts +++ b/x-pack/plugins/ml/server/models/model_management/memory_usage.ts @@ -199,6 +199,7 @@ export class MemoryUsageService { ...rest, ...modelSizeState, node: nodeRest, + key: `${rest.deployment_id}_${node.name}`, }; }); diff --git a/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts b/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts index dd0a15fa20f83..034ad761f2863 100644 --- a/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts +++ b/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts @@ -19,6 +19,7 @@ export const threadingParamsSchema = schema.maybe( number_of_allocations: schema.number(), threads_per_allocation: schema.number(), priority: schema.oneOf([schema.literal('low'), schema.literal('normal')]), + deployment_id: schema.maybe(schema.string()), }) ); diff --git a/x-pack/plugins/ml/server/routes/trained_models.ts b/x-pack/plugins/ml/server/routes/trained_models.ts index c09b46f7c23cb..06a0849488a9c 100644 --- a/x-pack/plugins/ml/server/routes/trained_models.ts +++ b/x-pack/plugins/ml/server/routes/trained_models.ts @@ -10,13 +10,13 @@ import { RouteInitialization } from '../types'; import { wrapError } from '../client/error_wrapper'; import { getInferenceQuerySchema, + inferTrainedModelBody, + inferTrainedModelQuery, modelIdSchema, optionalModelIdSchema, + pipelineSimulateBody, putTrainedModelQuerySchema, - inferTrainedModelQuery, - inferTrainedModelBody, threadingParamsSchema, - pipelineSimulateBody, updateDeploymentParamsSchema, } from './schemas/inference_schema'; @@ -59,14 +59,33 @@ export function trainedModelsRoutes({ router, routeGuard }: RouteInitialization) const result = body.trained_model_configs as TrainedModelConfigResponse[]; try { if (withPipelines) { + // Also need to retrieve the list of deployment IDs from stats + const stats = await mlClient.getTrainedModelsStats({ + ...(modelId ? { model_id: modelId } : {}), + size: 10000, + }); + + const modelDeploymentsMap = stats.trained_model_stats.reduce((acc, curr) => { + if (!curr.deployment_stats) return acc; + // @ts-ignore elasticsearch-js client is missing deployment_id + const deploymentId = curr.deployment_stats.deployment_id; + if (acc[curr.model_id]) { + acc[curr.model_id].push(deploymentId); + } else { + acc[curr.model_id] = [deploymentId]; + } + return acc; + }, {} as Record); + const modelIdsAndAliases: string[] = Array.from( - new Set( - result + new Set([ + ...result .map(({ model_id: id, metadata }) => { return [id, ...(metadata?.model_aliases ?? [])]; }) - .flat() - ) + .flat(), + ...Object.values(modelDeploymentsMap).flat(), + ]) ); const pipelinesResponse = await modelsProvider(client).getModelsPipelines( @@ -81,6 +100,12 @@ export function trainedModelsRoutes({ router, routeGuard }: RouteInitialization) ...(pipelinesResponse.get(alias) ?? {}), }; }, {}), + ...(modelDeploymentsMap[model.model_id] ?? []).reduce((acc, deploymentId) => { + return { + ...acc, + ...(pipelinesResponse.get(deploymentId) ?? {}), + }; + }, {}), }; } } diff --git a/x-pack/plugins/translations/translations/fr-FR.json b/x-pack/plugins/translations/translations/fr-FR.json index 453deac3c93e0..d3bfc3f941a3a 100644 --- a/x-pack/plugins/translations/translations/fr-FR.json +++ b/x-pack/plugins/translations/translations/fr-FR.json @@ -21379,7 +21379,6 @@ "xpack.ml.timeSeriesExplorer.timeSeriesChart.zoomAggregationIntervalLabel": "(intervalle d'agrégation : {focusAggInt}, étendue du compartiment : {bucketSpan})", "xpack.ml.trainedModels.modelsList.deleteModal.header": "Supprimer {modelsCount, plural, one {{modelId}} other {# modèles}} ?", "xpack.ml.trainedModels.modelsList.fetchDeletionErrorMessage": "La suppression {modelsCount, plural, one {du modèle} other {des modèles}} a échoué", - "xpack.ml.trainedModels.modelsList.forceStopDialog.title": "Arrêter le modèle {modelId} ?", "xpack.ml.trainedModels.modelsList.selectedModelsMessage": "{modelsCount, plural, one {# modèle sélectionné} other {# modèles sélectionnés}}", "xpack.ml.trainedModels.modelsList.startDeployment.modalTitle": "Démarrer le déploiement de {modelId}", "xpack.ml.trainedModels.modelsList.startFailed": "Impossible de démarrer \"{modelId}\"", diff --git a/x-pack/plugins/translations/translations/ja-JP.json b/x-pack/plugins/translations/translations/ja-JP.json index a1e81adaff790..3c27b1d843c3b 100644 --- a/x-pack/plugins/translations/translations/ja-JP.json +++ b/x-pack/plugins/translations/translations/ja-JP.json @@ -21366,7 +21366,6 @@ "xpack.ml.timeSeriesExplorer.timeSeriesChart.updatedAnnotationNotificationMessage": "ID {jobId}のジョブの注釈が更新されました。", "xpack.ml.timeSeriesExplorer.timeSeriesChart.zoomAggregationIntervalLabel": "(アグリゲーション間隔:{focusAggInt}、バケットスパン:{bucketSpan})", "xpack.ml.trainedModels.modelsList.fetchDeletionErrorMessage": "{modelsCount, plural, other {モデル}}の削除が失敗しました", - "xpack.ml.trainedModels.modelsList.forceStopDialog.title": "モデル{modelId}を停止しますか?", "xpack.ml.trainedModels.modelsList.selectedModelsMessage": "{modelsCount, plural, other {#個のモデル}}を選択済み", "xpack.ml.trainedModels.modelsList.startDeployment.modalTitle": "{modelId}デプロイを開始", "xpack.ml.trainedModels.modelsList.startFailed": "\"{modelId}\"の開始に失敗しました", diff --git a/x-pack/plugins/translations/translations/zh-CN.json b/x-pack/plugins/translations/translations/zh-CN.json index 4175b6ff20149..7d6d834b7a36a 100644 --- a/x-pack/plugins/translations/translations/zh-CN.json +++ b/x-pack/plugins/translations/translations/zh-CN.json @@ -21378,7 +21378,6 @@ "xpack.ml.timeSeriesExplorer.timeSeriesChart.zoomAggregationIntervalLabel": "(聚合时间间隔:{focusAggInt},存储桶跨度:{bucketSpan})", "xpack.ml.trainedModels.modelsList.deleteModal.header": "删除 {modelsCount, plural, one {{modelId}} other {# 个模型}}?", "xpack.ml.trainedModels.modelsList.fetchDeletionErrorMessage": "{modelsCount, plural, other {模型}}删除失败", - "xpack.ml.trainedModels.modelsList.forceStopDialog.title": "停止模型 {modelId}", "xpack.ml.trainedModels.modelsList.selectedModelsMessage": "{modelsCount, plural, other {# 个模型}}已选择", "xpack.ml.trainedModels.modelsList.startDeployment.modalTitle": "启动 {modelId} 部署", "xpack.ml.trainedModels.modelsList.startFailed": "无法启动“{modelId}”",