Skip to content

Commit

Permalink
latest mostly working
Browse files Browse the repository at this point in the history
  • Loading branch information
aidanrussell committed Dec 9, 2024
1 parent d2724e4 commit e6254e0
Show file tree
Hide file tree
Showing 13 changed files with 167 additions and 143 deletions.

This file was deleted.

116 changes: 3 additions & 113 deletions infra/modules/sagemaker_deployment/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ resource "aws_sagemaker_endpoint_configuration" "endpoint_config" {
output_config {
s3_output_path = var.s3_output_path
notification_config {
include_inference_response_in = ["SUCCESS_NOTIFICATION_TOPIC", "ERROR_NOTIFICATION_TOPIC"]
success_topic = aws_sns_topic.async-sagemaker-success-topic.arn
error_topic = aws_sns_topic.async-sagemaker-error-topic.arn
include_inference_response_in = ["SUCCESS_NOTIFICATION_TOPIC"]
success_topic = var.sns_success_topic_arn
}
}
}
Expand All @@ -42,7 +41,7 @@ resource "aws_sagemaker_endpoint_configuration" "endpoint_config" {
resource "aws_sagemaker_endpoint" "sagemaker_endpoint" {
name = var.endpoint_name
endpoint_config_name = aws_sagemaker_endpoint_configuration.endpoint_config.name
depends_on = [aws_sagemaker_endpoint_configuration.endpoint_config, aws_sns_topic.async-sagemaker-error-topic, aws_sns_topic.async-sagemaker-success-topic]
depends_on = [aws_sagemaker_endpoint_configuration.endpoint_config, var.sns_success_topic_arn]
}

# Autoscaling Target Resource
Expand Down Expand Up @@ -162,112 +161,3 @@ resource "aws_cloudwatch_metric_alarm" "cloudwatch_alarm" {

alarm_actions = var.alarms[count.index].alarm_actions != null ? var.alarms[count.index].alarm_actions : []
}

resource "aws_iam_role" "iam_for_lambda" {
name = "iam_for_lambda"
assume_role_policy = jsonencode({
Version = "2012-10-17"
Statement = [
{
Action = "sts:AssumeRole"
Effect = "Allow"
Principal = {
Service = "lambda.amazonaws.com"
}
}]})
}

resource "aws_iam_role_policy" "policy_for_lambda" {
name = "test_policy"
role = aws_iam_role.iam_for_lambda.id

policy = jsonencode({
Version = "2012-10-17"
Statement = [
{
Action = [
"SNS:Subscribe",
"SNS:SetTopicAttributes",
"SNS:RemovePermission",
"SNS:Receive",
"SNS:Publish",
"SNS:ListSubscriptionsByTopic",
"SNS:GetTopicAttributes",
"SNS:DeleteTopic",
"SNS:AddPermission",
]
Effect = "Allow"
Resource = "*"
},
]
})
}

data "archive_file" "lambda_payload" {
type = "zip"
source_dir = "./lambda_function/s3_move_output.py"
output_path = "./lambda_function/payload.zip"
}

resource "aws_lambda_function" "lambda_s3_move_output" {
filename = data.archive_file.lambda_payload.output_path
source_code_hash = data.archive_file.lambda_payload.output_base64sha256
function_name = "lambda_s3_move_output"
role = aws_iam_role.iam_for_lambda.arn
handler = "s3_move_output.lambda_handler"
runtime = "python3.12"
}


resource "aws_sns_topic" "async-sagemaker-success-topic" {
name = "async-sagemaker-success-topic"
#application_success_feedback_role_arn
#application_failure_feedback_role_arn
policy = jsonencode({
Version = "2012-10-17"
Statement = [
{
Action = [
"SNS:Subscribe",
"SNS:SetTopicAttributes",
"SNS:RemovePermission",
"SNS:Receive",
"SNS:Publish",
"SNS:ListSubscriptionsByTopic",
"SNS:GetTopicAttributes",
"SNS:DeleteTopic",
"SNS:AddPermission",
]
Effect = "Allow"
Resource = "*"
},
]
})

}

resource "aws_sns_topic" "async-sagemaker-error-topic" {
name = "async-sagemaker-error-topic"
#application_success_feedback_role_arn
#application_failure_feedback_role_arn
policy = jsonencode({
Version = "2012-10-17"
Statement = [
{
Action = [
"SNS:Subscribe",
"SNS:SetTopicAttributes",
"SNS:RemovePermission",
"SNS:Receive",
"SNS:Publish",
"SNS:ListSubscriptionsByTopic",
"SNS:GetTopicAttributes",
"SNS:DeleteTopic",
"SNS:AddPermission",
]
Effect = "Allow"
Resource = "*"
},
]
})
}
7 changes: 0 additions & 7 deletions infra/modules/sagemaker_deployment/outputs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,3 @@ output "scale_in_to_zero_based_on_backlog_arn" {
value = aws_appautoscaling_policy.scale_in_to_zero_based_on_backlog.arn
}

output "sns_error_topic_arn" {
value = aws_sns_topic.async-sagemaker-error-topic
}

output "sns_success_topic_arn" {
value = aws_sns_topic.async-sagemaker-success-topic
}
5 changes: 5 additions & 0 deletions infra/modules/sagemaker_deployment/variables.tf
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
variable "sns_success_topic_arn" {
type = string
description = "ARN of the SNS topic for Sagemaker successful async outputs"
}

variable "model_name" {
type = string
description = "Name of the SageMaker model"
Expand Down
10 changes: 1 addition & 9 deletions infra/modules/sagemaker_init/iam/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,9 @@ data "aws_iam_policy_document" "sagemaker_inference_policy_document" {

statement {
actions = [
"SNS:Subscribe",
"SNS:SetTopicAttributes",
"SNS:RemovePermission",
"SNS:Receive",
"SNS:Publish",
"SNS:ListSubscriptionsByTopic",
"SNS:GetTopicAttributes",
"SNS:DeleteTopic",
"SNS:AddPermission",
]
resources = ["*"]
resources = ["arn:aws:sns:eu-west-2:${var.account_id}:async-sagemaker-success-topic"]
}

statement {
Expand Down
5 changes: 5 additions & 0 deletions infra/modules/sagemaker_init/iam/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,8 @@ variable "aws_s3_bucket_notebook" {
type = any
description = "S3 bucket for notebooks"
}

variable "account_id" {
type = string
description = "AWS Account ID"
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import boto3

def lambda_handler(event, context):
try:
s3 = boto3.resource('s3')

sagemaker_endpoint_name = event["requestParameters"]["endpointName"]
input_file_uri = event["requestParameters"]["inputLocation"]
input_file_bucket = input_file_uri.split("/user/federated/")[0].split("s3://")[1]
federated_user_id = input_file_uri.split("/user/federated/")[1].split("/")[0]

output_file_uri = event["responseParameters"]["outputLocation"]
output_file_bucket = output_file_uri.split("https://")[1].split("/")[0].split(".s3.eu-west-2.amazonaws.com")[0]
output_file_key = output_file_uri.split("https://")[1].split("/")[1]

copy_source = {
'Bucket': output_file_bucket,
'Key': output_file_key
}
s3_filepath_output = f"user/federated/{federated_user_id}/sagemaker/outputs/{output_file_key}"
s3.meta.client.copy(copy_source, input_file_bucket, s3_filepath_output)

print(f"User {federated_user_id} called Sagemaker endpoint {sagemaker_endpoint_name} and the output file key was {s3_filepath_output}")

except Exception as e:
print("An error occurred")
raise e
99 changes: 99 additions & 0 deletions infra/modules/sagemaker_output_mover/main.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@

resource "aws_iam_role" "iam_for_lambda_s3_move" {
name = "iam_for_lambda_s3_move"
assume_role_policy = jsonencode({
Version = "2012-10-17"
Statement = [
{
Action = "sts:AssumeRole"
Effect = "Allow"
Principal = {
Service = "lambda.amazonaws.com"
}
}]})
}

resource "aws_iam_role_policy" "policy_for_lambda_s3_move" {
name = "policy_for_lambda_s3_move"
role = aws_iam_role.iam_for_lambda_s3_move.id

policy = jsonencode({
Version = "2012-10-17"
Statement = [
{
Action = ["SNS:Receive", "SNS:Subscribe"]
Effect = "Allow"
Resource = aws_sns_topic.async-sagemaker-success-topic.arn
},
{
Action = ["s3:GetObject"]
Effect = "Allow"
Resource = "arn:aws:s3:::*sagemaker*"
},
{
Action = ["s3:PutObject"]
Effect = "Allow"
Resource = "${var.s3_bucket_notebooks_arn}*"
}
]
})
}

data "archive_file" "lambda_payload" {
type = "zip"
source_file = "${path.module}/lambda_function/s3_move_output.py"
output_path = "${path.module}/lambda_function/payload.zip"
}

resource "aws_lambda_function" "lambda_s3_move_output" {
filename = data.archive_file.lambda_payload.output_path
source_code_hash = data.archive_file.lambda_payload.output_base64sha256
function_name = "lambda_s3_move_output"
role = aws_iam_role.iam_for_lambda_s3_move.arn
handler = "s3_move_output.lambda_handler"
runtime = "python3.12"
timeout = 30
}


resource "aws_sns_topic" "async-sagemaker-success-topic" {
name = "async-sagemaker-success-topic"
policy = data.aws_iam_policy_document.sns_publish_and_read_policy.json
}

resource "aws_sns_topic_subscription" "topic_lambda" {
topic_arn = aws_sns_topic.async-sagemaker-success-topic.arn
protocol = "lambda"
endpoint = aws_lambda_function.lambda_s3_move_output.arn
}

resource "aws_lambda_permission" "with_sns" {
statement_id = "AllowExecutionFromSNS"
action = "lambda:InvokeFunction"
function_name = aws_lambda_function.lambda_s3_move_output.function_name
principal = "sns.amazonaws.com"
source_arn = aws_sns_topic.async-sagemaker-success-topic.arn
}

data "aws_iam_policy_document" "sns_publish_and_read_policy" {
statement {
sid = "sns_publish_and_read_policy_1"
actions = ["SNS:Publish"]
effect = "Allow"
principals {
type = "Service"
identifiers = ["sagemaker.amazonaws.com"]
}
resources = ["arn:aws:sns:${var.aws_region}:${var.account_id}:async-sagemaker-success-topic"]
}
statement {
sid = "sns_publish_and_read_policy_2"
actions = ["SNS:Receive","SNS:Subscribe"]
effect = "Allow"
principals {
type = "Service"
identifiers = ["lambda.amazonaws.com"]
}
resources = ["arn:aws:sns:${var.aws_region}:${var.account_id}:async-sagemaker-success-topic"]
}
}
3 changes: 3 additions & 0 deletions infra/modules/sagemaker_output_mover/outputs.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
output "sns_success_topic_arn" {
value = aws_sns_topic.async-sagemaker-success-topic.arn
}
14 changes: 14 additions & 0 deletions infra/modules/sagemaker_output_mover/variables.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
variable "account_id" {
type = string
description = "AWS Account ID"
}

variable "aws_region" {
type = string
description = "AWS Region in format e.g. us-west-1"
}

variable "s3_bucket_notebooks_arn" {
type = string
description = "S3 Bucket for notebook user data storage"
}
8 changes: 8 additions & 0 deletions infra/sagemaker.tf
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ module "iam" {
prefix = var.prefix
sagemaker_default_bucket_name = var.sagemaker_default_bucket
aws_s3_bucket_notebook = aws_s3_bucket.notebooks
account_id = data.aws_caller_identity.aws_caller_identity.account_id
}


Expand Down Expand Up @@ -134,6 +135,13 @@ module "sns" {
account_id = data.aws_caller_identity.aws_caller_identity.account_id
}

module "sagemaker_output_mover" {
source = "./modules/sagemaker_output_mover"
account_id = data.aws_caller_identity.aws_caller_identity.account_id
aws_region = data.aws_region.aws_region.name
s3_bucket_notebooks_arn = aws_s3_bucket.notebooks.arn
}

module "log_group" {
source = "./modules/logs"
prefix = "data-workspace-sagemaker"
Expand Down
2 changes: 2 additions & 0 deletions infra/sagemaker_llm_resources.tf
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
module "gpt_neo_125_deployment" {
source = "./modules/sagemaker_deployment"
model_name = "gpt-neo-125m"
sns_success_topic_arn = module.sagemaker_output_mover.sns_success_topic_arn
execution_role_arn = module.iam.inference_role
container_image = var.hugging_face_model_image
model_data_url = "${var.sagemaker_models_folder}/gpt-neo-125m.tar.gz"
Expand Down Expand Up @@ -179,6 +180,7 @@ module "gpt_neo_125_deployment" {
module "llama_3_2_1b_deployment" {
source = "./modules/sagemaker_deployment"
model_name = "Llama-3-2-1B"
sns_success_topic_arn = module.sagemaker_output_mover.sns_success_topic_arn
execution_role_arn = module.iam.inference_role
container_image = var.hugging_face_model_image
model_data_url = "${var.sagemaker_models_folder}/Llama-3.2-1B.tar.gz"
Expand Down

0 comments on commit e6254e0

Please sign in to comment.