Skip to content

Commit

Permalink
feat(stepfunctions-tasks): add support for ModelClientConfig to SageM…
Browse files Browse the repository at this point in the history
…akerCreateTransformJob (#11892)

I noticed that support for [ModelClientConfig](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTransformJob.html#sagemaker-CreateTransformJob-request-ModelClientConfig) was missing from this particular type of job, so this change adds it.

----

*By submitting this pull request, I confirm that my contribution is made under the terms of the Apache-2.0 license*
  • Loading branch information
setu4993 authored Dec 15, 2020
1 parent 78c185d commit bf05092
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 7 deletions.
4 changes: 4 additions & 0 deletions packages/@aws-cdk/aws-stepfunctions-tasks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,10 @@ You can call the [`CreateTransformJob`](https://docs.aws.amazon.com/sagemaker/la
new sfn.SagemakerTransformTask(this, 'Batch Inference', {
transformJobName: 'MyTransformJob',
modelName: 'MyModelName',
modelClientOptions: {
invocationsMaxRetries: 3, // default is 0
invocationsTimeout: cdk.Duration.minutes(5), // default is 60 seconds
},
role,
transformInput: {
transformDataSource: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,28 @@ export enum CompressionType {
// Create Transform Job types
//

/**
* Configures the timeout and maximum number of retries for processing a transform job invocation.
*
* Rendered by `SageMakerCreateTransformJob` into the `ModelClientConfig`
* section of the task parameters.
*
* @experimental
*/
export interface ModelClientOptions {

/**
* The maximum number of retries when invocation requests are failing.
*
* Must be between 0 and 3 (inclusive); a concrete value outside that range
* is rejected with an error when the task parameters are rendered.
* Rendered as `InvocationsMaxRetries`.
*
* @default 0
*/
readonly invocationsMaxRetries?: number;

/**
* The timeout duration for an invocation request.
*
* Must be between 1 and 3600 seconds (inclusive). The duration is converted
* to whole seconds and rendered as `InvocationsTimeoutInSeconds`.
*
* @default Duration.minutes(1)
*/
readonly invocationsTimeout?: Duration;
}

/**
* Dataset to be transformed and the Amazon S3 location where it is stored.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
import * as ec2 from '@aws-cdk/aws-ec2';
import * as iam from '@aws-cdk/aws-iam';
import * as sfn from '@aws-cdk/aws-stepfunctions';
import { Size, Stack } from '@aws-cdk/core';
import { Size, Stack, Token } from '@aws-cdk/core';
import { Construct } from 'constructs';
import { integrationResourceArn, validatePatternSupported } from '../private/task-utils';
import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './base-types';
import { BatchStrategy, ModelClientOptions, S3DataType, TransformInput, TransformOutput, TransformResources } from './base-types';
import { renderTags } from './private/utils';

/**
* Properties for creating an Amazon SageMaker training job task
* Properties for creating an Amazon SageMaker transform job task
*
* @experimental
*/
export interface SageMakerCreateTransformJobProps extends sfn.TaskStateBaseProps {
/**
* Training Job Name.
* Transform Job Name.
*/
readonly transformJobName: string;

/**
* Role for the Training Job.
* Role for the Transform Job.
*
* @default - A role is created with `AmazonSageMakerFullAccess` managed policy
*/
Expand Down Expand Up @@ -59,6 +59,13 @@ export interface SageMakerCreateTransformJobProps extends sfn.TaskStateBaseProps
*/
readonly modelName: string;

/**
* Configures the timeout and maximum number of retries for processing a transform job invocation.
*
* @default - 0 retries and 60 seconds of timeout
*/
readonly modelClientOptions?: ModelClientOptions;

/**
* Tags to be applied to the transform job.
*
Expand All @@ -85,7 +92,7 @@ export interface SageMakerCreateTransformJobProps extends sfn.TaskStateBaseProps
}

/**
* Class representing the SageMaker Create Training Job task.
* Class representing the SageMaker Create Transform Job task.
*
* @experimental
*/
Expand Down Expand Up @@ -147,7 +154,7 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase {
}

/**
* The execution role for the Sagemaker training job.
* The execution role for the Sagemaker transform job.
*
* Only available after task has been added to a state machine.
*/
Expand All @@ -164,6 +171,7 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase {
...this.renderEnvironment(this.props.environment),
...(this.props.maxConcurrentTransforms ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}),
...(this.props.maxPayload ? { MaxPayloadInMB: this.props.maxPayload.toMebibytes() } : {}),
...this.props.modelClientOptions ? this.renderModelClientOptions(this.props.modelClientOptions) : {},
ModelName: this.props.modelName,
...renderTags(this.props.tags),
...this.renderTransformInput(this.transformInput),
Expand All @@ -173,6 +181,23 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase {
};
}

/**
 * Renders the `ModelClientConfig` section of the CreateTransformJob parameters.
 *
 * Concrete (non-token) values are range-checked here so that an invalid
 * configuration fails early instead of being emitted into the state machine
 * definition. Unresolved tokens are passed through unchecked because their
 * value is not known yet.
 *
 * @param options the user-supplied retry / timeout settings
 * @returns an object with a single `ModelClientConfig` key; omitted settings
 *   fall back to 0 retries and a 60 second timeout
 * @throws Error if `invocationsMaxRetries` is outside 0..3 or
 *   `invocationsTimeout` is outside 1..3600 seconds
 */
private renderModelClientOptions(options: ModelClientOptions): { [key: string]: any } {
  const retries = options.invocationsMaxRetries;
  // Compare against `undefined` explicitly: a truthiness test would treat a
  // supplied value of 0 the same as "not supplied".
  if (retries !== undefined && !Token.isUnresolved(retries) && (retries < 0 || retries > 3)) {
    throw new Error(`invocationsMaxRetries should be between 0 and 3. Received: ${retries}.`);
  }

  const timeout = options.invocationsTimeout?.toSeconds();
  // A zero-second timeout is falsy; the explicit `undefined` check makes sure
  // it is still validated (and rejected) rather than rendered as-is.
  if (timeout !== undefined && !Token.isUnresolved(timeout) && (timeout < 1 || timeout > 3600)) {
    throw new Error(`invocationsTimeout should be between 1 and 3600 seconds. Received: ${timeout}.`);
  }

  return {
    ModelClientConfig: {
      InvocationsMaxRetries: retries ?? 0,
      InvocationsTimeoutInSeconds: timeout ?? 60,
    },
  };
}

private renderTransformInput(input: TransformInput): { [key: string]: any } {
return {
TransformInput: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ test('create complex transform job', () => {
const task = new SageMakerCreateTransformJob(stack, 'TransformTask', {
transformJobName: 'MyTransformJob',
modelName: 'MyModelName',
modelClientOptions: {
invocationsMaxRetries: 1,
invocationsTimeout: cdk.Duration.minutes(20),
},
integrationPattern: sfn.IntegrationPattern.RUN_JOB,
role,
transformInput: {
Expand Down Expand Up @@ -151,6 +155,10 @@ test('create complex transform job', () => {
Parameters: {
TransformJobName: 'MyTransformJob',
ModelName: 'MyModelName',
ModelClientConfig: {
InvocationsMaxRetries: 1,
InvocationsTimeoutInSeconds: 1200,
},
TransformInput: {
DataSource: {
S3DataSource: {
Expand Down

0 comments on commit bf05092

Please sign in to comment.