Skip to content

Commit

Permalink
Merge branch 'main' into fix/network-acl-name
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Aug 29, 2023
2 parents eb8218e + 1cead3b commit 56f9f8d
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { renderEnvironment, renderTags } from './private/utils';
import * as ec2 from '../../../aws-ec2';
import * as iam from '../../../aws-iam';
import * as sfn from '../../../aws-stepfunctions';
import { Duration, Lazy, Size, Stack } from '../../../core';
import { Duration, Lazy, Size, Stack, Token } from '../../../core';
import { integrationResourceArn, validatePatternSupported } from '../private/task-utils';

/**
Expand Down Expand Up @@ -163,6 +163,14 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam
throw new Error('Must define either an algorithm name or training image URI in the algorithm specification');
}

// check that both algorithm name and image are not defined
if (props.algorithmSpecification.algorithmName && props.algorithmSpecification.trainingImage) {
throw new Error('Cannot define both an algorithm name and training image URI in the algorithm specification');
}

// validate algorithm name
this.validateAlgorithmName(props.algorithmSpecification.algorithmName);

// set the input mode to 'File' if not defined
this.algorithmSpecification = props.algorithmSpecification.trainingInputMode
? props.algorithmSpecification
Expand Down Expand Up @@ -324,6 +332,21 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam
: {};
}

private validateAlgorithmName(algorithmName?: string): void {
if (algorithmName === undefined || Token.isUnresolved(algorithmName)) {
return;
}

if (algorithmName.length < 1 || 170 < algorithmName.length) {
throw new Error(`Algorithm name length must be between 1 and 170, but got ${algorithmName.length}`);
}

const regex = /^(arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:[a-z\-]*\/)?([a-zA-Z0-9]([a-zA-Z0-9-]){0,62})(?<!-)$/;
if (!regex.test(algorithmName)) {
throw new Error(`Expected algorithm name to match pattern ${regex.source}, but got ${algorithmName}`);
}
}

private makePolicyStatements(): iam.PolicyStatement[] {
// set the sagemaker role or create new one
this._grantPrincipal = this._role =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,146 @@ test('Cannot create a SageMaker train task with both algorithm name and image na
}))
.toThrowError(/Must define either an algorithm name or training image URI in the algorithm specification/);
});

test('Cannot create a SageMaker train task with both algorithm name and image name defined', () => {

expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
algorithmName: 'BlazingText',
trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')),
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
}))
.toThrowError(/Cannot define both an algorithm name and training image URI in the algorithm specification/);
});

test('create a SageMaker train task with trainingImage', () => {

const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')),
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
});

// THEN
expect(stack.resolve(task.toStateJson())).toMatchObject({
Parameters: {
AlgorithmSpecification: {
'TrainingImage.$': '$.Training.imageName',
'TrainingInputMode': 'File',
},
},
});
});

test('create a SageMaker train task with image URI algorithmName', () => {

const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
algorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees',
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
});

// THEN
expect(stack.resolve(task.toStateJson())).toMatchObject({
Parameters: {
AlgorithmSpecification: {
AlgorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees',
},
},
});
});

test('Cannot create a SageMaker train task when algorithmName length is 171 or more', () => {

expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
algorithmName: 'a'.repeat(171), // maximum length is 170
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
}))
.toThrowError(/Algorithm name length must be between 1 and 170, but got 171/);
});

test('Cannot create a SageMaker train task with incorrect algorithmName', () => {

expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
algorithmName: 'Blazing_Text', // underscores are not allowed
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
}))
.toThrowError(/Expected algorithm name to match pattern/);
});

0 comments on commit 56f9f8d

Please sign in to comment.