From 1cead3b032c382f8cf63f72508c3222a3c58e85a Mon Sep 17 00:00:00 2001 From: Tomoya Oda <38136327+tmyoda@users.noreply.github.com> Date: Tue, 29 Aug 2023 11:44:03 +0100 Subject: [PATCH] feat(stepfunctions-tasks): `algorithmName` validation for `SageMakerCreateTrainingJob` (#26877) Referencing PR #26675, I have added validation for the `algorithmName` parameter in `SageMakerCreateTrainingJob`. However, it was suggested that changes for validation should be separated. So, I have created this PR. Docs for `algorithmName`: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#API_AlgorithmSpecification_Contents Exemption Request: This change does not alter the behavior. I believe the unit test `create-training-job.test.ts` that I have added is sufficient to test this change. ---- *By submitting this pull request, I confirm that my contribution is made under the terms of the Apache-2.0 license* --- .../lib/sagemaker/create-training-job.ts | 25 ++- .../sagemaker/create-training-job.test.ts | 143 ++++++++++++++++++ 2 files changed, 167 insertions(+), 1 deletion(-) diff --git a/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts b/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts index f04842249e24b..ff2ac0c76e94e 100644 --- a/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts +++ b/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts @@ -4,7 +4,7 @@ import { renderEnvironment, renderTags } from './private/utils'; import * as ec2 from '../../../aws-ec2'; import * as iam from '../../../aws-iam'; import * as sfn from '../../../aws-stepfunctions'; -import { Duration, Lazy, Size, Stack } from '../../../core'; +import { Duration, Lazy, Size, Stack, Token } from '../../../core'; import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; /** @@ -163,6 +163,14 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam throw new Error('Must define either an algorithm name or training image URI in the algorithm specification'); } + // check that both algorithm name and image are not defined + if (props.algorithmSpecification.algorithmName && props.algorithmSpecification.trainingImage) { + throw new Error('Cannot define both an algorithm name and training image URI in the algorithm specification'); + } + + // validate algorithm name + this.validateAlgorithmName(props.algorithmSpecification.algorithmName); + // set the input mode to 'File' if not defined this.algorithmSpecification = props.algorithmSpecification.trainingInputMode ? props.algorithmSpecification @@ -324,6 +332,21 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam : {}; } + private validateAlgorithmName(algorithmName?: string): void { + if (algorithmName === undefined || Token.isUnresolved(algorithmName)) { + return; + } + + if (algorithmName.length < 1 || 170 < algorithmName.length) { + throw new Error(`Algorithm name length must be between 1 and 170, but got ${algorithmName.length}`); + } + + const regex = /^(arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:[a-z\-]*\/)?([a-zA-Z0-9]([a-zA-Z0-9-]){0,62})(? { + + expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', { + trainingJobName: 'myTrainJob', + algorithmSpecification: { + algorithmName: 'BlazingText', + trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')), + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'), + }, + })) + .toThrowError(/Cannot define both an algorithm name and training image URI in the algorithm specification/); +}); + +test('create a SageMaker train task with trainingImage', () => { + + const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', { + trainingJobName: 'myTrainJob', + algorithmSpecification: { + trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')), + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'), + }, + }); + + // THEN + expect(stack.resolve(task.toStateJson())).toMatchObject({ + Parameters: { + AlgorithmSpecification: { + 'TrainingImage.$': '$.Training.imageName', + 'TrainingInputMode': 'File', + }, + }, + }); +}); + +test('create a SageMaker train task with image URI algorithmName', () => { + + const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', { + trainingJobName: 'myTrainJob', + algorithmSpecification: { + algorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees', + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'), + }, + }); + + // THEN + expect(stack.resolve(task.toStateJson())).toMatchObject({ + Parameters: { + AlgorithmSpecification: { + AlgorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees', + }, + }, + }); +}); + +test('Cannot create a SageMaker train task when algorithmName length is 171 or more', () => { + + expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', { + trainingJobName: 'myTrainJob', + algorithmSpecification: { + algorithmName: 'a'.repeat(171), // maximum length is 170 + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'), + }, + })) + .toThrowError(/Algorithm name length must be between 1 and 170, but got 171/); +}); + +test('Cannot create a SageMaker train task with incorrect algorithmName', () => { + + expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', { + trainingJobName: 'myTrainJob', + algorithmSpecification: { + algorithmName: 'Blazing_Text', // underscores are not allowed + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'), + }, + })) + .toThrowError(/Expected algorithm name to match pattern/); +});