Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Aidan's work not for merging #166

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions infra/modules/sagemaker_deployment/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ resource "aws_sagemaker_endpoint_configuration" "endpoint_config" {
async_inference_config {
output_config {
s3_output_path = var.s3_output_path
notification_config {
include_inference_response_in = ["SUCCESS_NOTIFICATION_TOPIC"]
success_topic = var.sns_success_topic_arn
}
}
}
}
Expand All @@ -37,6 +41,7 @@ resource "aws_sagemaker_endpoint_configuration" "endpoint_config" {
# Async SageMaker endpoint backed by the configuration above.
resource "aws_sagemaker_endpoint" "sagemaker_endpoint" {
  name                 = var.endpoint_name
  endpoint_config_name = aws_sagemaker_endpoint_configuration.endpoint_config.name

  # No explicit depends_on: referencing endpoint_config.name already creates an
  # implicit dependency on the configuration, and Terraform's depends_on may
  # only list resources/modules — putting var.sns_success_topic_arn there is
  # invalid.
}

# Autoscaling Target Resource
Expand Down Expand Up @@ -78,7 +83,7 @@ resource "aws_appautoscaling_policy" "scale_in_to_zero_policy" {

step_scaling_policy_configuration {
adjustment_type = "ExactCapacity"


step_adjustment {
metric_interval_lower_bound = null # No lower bound to cover everything
Expand Down Expand Up @@ -108,7 +113,7 @@ resource "aws_appautoscaling_policy" "scale_in_to_zero_based_on_backlog" {

step_scaling_policy_configuration {
adjustment_type = "ExactCapacity" # Set the capacity exactly to zero

# Step adjustment for when there are zero queries in the backlog
step_adjustment {
metric_interval_lower_bound = null # No lower bound (cover everything below 0)
Expand Down Expand Up @@ -145,7 +150,7 @@ resource "aws_cloudwatch_metric_alarm" "cloudwatch_alarm" {
period = var.alarms[count.index].period
statistic = var.alarms[count.index].statistic

# Define dimensions based on the count index -
# Define dimensions based on the count index -
# first alarm will not have a null variantName
dimensions = count.index == 0 ? {
EndpointName = aws_sagemaker_endpoint.sagemaker_endpoint.name
Expand Down
3 changes: 2 additions & 1 deletion infra/modules/sagemaker_deployment/outputs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ output "scale_in_to_zero_policy_arn" {
output "scale_in_to_zero_based_on_backlog_arn" {
description = "ARN of the autoscaling policy to scale in to zero for backlog queries when 0 for x minutes"
value = aws_appautoscaling_policy.scale_in_to_zero_based_on_backlog.arn
}
}

5 changes: 5 additions & 0 deletions infra/modules/sagemaker_deployment/variables.tf
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# SNS topic that receives SageMaker async success notifications.
variable "sns_success_topic_arn" {
  description = "ARN of the SNS topic for Sagemaker successful async outputs"
  type        = string
}

variable "model_name" {
type = string
description = "Name of the SageMaker model"
Expand Down
13 changes: 6 additions & 7 deletions infra/modules/sagemaker_init/iam/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ data "aws_iam_policy_document" "sagemaker_inference_policy_document" {
]
}

statement {
actions = [
"SNS:Publish",
]
resources = ["arn:aws:sns:eu-west-2:${var.account_id}:async-sagemaker-success-topic"]
}

statement {
actions = [
Expand Down Expand Up @@ -159,10 +165,3 @@ resource "aws_iam_role_policy_attachment" "sagemaker_inference_role_policy" {
role = aws_iam_role.inference_role.name
policy_arn = aws_iam_policy.sagemaker_access_policy.arn
}







8 changes: 6 additions & 2 deletions infra/modules/sagemaker_init/iam/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ variable "prefix" {
description = "Prefix for naming IAM resources"
}


variable "sagemaker_default_bucket_name" {
type = string
description = "name of the default S3 bucket used by sagemaker"
Expand All @@ -12,4 +11,9 @@ variable "sagemaker_default_bucket_name" {
variable "aws_s3_bucket_notebook" {
type = any
description = "S3 bucket for notebooks"
}
}

# Used to build account-scoped ARNs (e.g. the SNS topic in the inference policy).
variable "account_id" {
  description = "AWS Account ID"
  type        = string
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import ast
import json

import boto3

def lambda_handler(event, context):
    """Lambda entry point: process every SNS record in the invocation event.

    Args:
        event: Lambda event dict containing a 'Records' list of SNS deliveries.
        context: Lambda context object (unused).
    """
    for record in event['Records']:
        process_message(record)
        # Plain string — the original used an f-string with no placeholders.
        print("sns message processed")

def process_message(record):
    """Copy a completed SageMaker async-inference output into the invoking
    user's notebook area.

    Parses the SageMaker success notification carried in the SNS record,
    derives the federated user id from the original input location, and copies
    the output object into ``user/federated/<user>/sagemaker/outputs/`` in the
    same bucket the input came from.

    Args:
        record: one entry from the Lambda event's 'Records' list; the
            notification body is at record['Sns']['Message'].

    Raises:
        Exception: re-raises any parsing or copy failure after logging it.
    """
    try:
        message_str = record['Sns']['Message']
        s3 = boto3.resource('s3')
        # The notification is a JSON document, so parse it with json.loads.
        # The previous code called ast.literal_eval(message) — `message` was an
        # undefined name (NameError), and literal_eval cannot parse JSON's
        # true/false/null anyway.
        message_dict = json.loads(message_str)
        print(message_dict)

        sagemaker_endpoint_name = message_dict["requestParameters"]["endpointName"]
        input_file_uri = message_dict["requestParameters"]["inputLocation"]
        # inputLocation looks like s3://<bucket>/user/federated/<user>/... —
        # TODO confirm this layout holds for every caller.
        input_file_bucket = input_file_uri.split("/user/federated/")[0].split("s3://")[1]
        federated_user_id = input_file_uri.split("/user/federated/")[1].split("/")[0]

        output_file_uri = message_dict["responseParameters"]["outputLocation"]
        # outputLocation looks like https://<bucket>.s3.eu-west-2.amazonaws.com/<key>.
        # Use maxsplit=1 so keys containing "/" stay intact — the previous code
        # kept only the first path segment.
        output_host_and_key = output_file_uri.split("https://")[1]
        output_file_bucket = output_host_and_key.split("/")[0].split(".s3.eu-west-2.amazonaws.com")[0]
        output_file_key = output_host_and_key.split("/", 1)[1]

        copy_source = {
            'Bucket': output_file_bucket,
            'Key': output_file_key
        }
        s3_filepath_output = f"user/federated/{federated_user_id}/sagemaker/outputs/{output_file_key}"
        s3.meta.client.copy(copy_source, input_file_bucket, s3_filepath_output)

        print(f"User {federated_user_id} called Sagemaker endpoint {sagemaker_endpoint_name} and the output file key was {s3_filepath_output}")

    except Exception as e:
        # Log the actual failure before re-raising so CloudWatch shows the cause.
        print(f"An error occurred: {e}")
        raise
104 changes: 104 additions & 0 deletions infra/modules/sagemaker_output_mover/main.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@

# Execution role for the output-mover Lambda; trusts only the Lambda service.
# NOTE(review): the role name is hard-coded (no prefix), so only one instance
# of this module can exist per account — confirm that is intended.
resource "aws_iam_role" "iam_for_lambda_s3_move" {
name = "iam_for_lambda_s3_move"
assume_role_policy = jsonencode({
Version = "2012-10-17"
Statement = [
{
Action = "sts:AssumeRole"
Effect = "Allow"
Principal = {
Service = "lambda.amazonaws.com"
}
}]})
}

# Inline execution policy for the output-mover Lambda:
# receive/subscribe on the async-success topic, read SageMaker output objects,
# write into the notebooks bucket, and write CloudWatch logs.
resource "aws_iam_role_policy" "policy_for_lambda_s3_move" {
name = "policy_for_lambda_s3_move"
role = aws_iam_role.iam_for_lambda_s3_move.id

policy = jsonencode({
Version = "2012-10-17"
Statement = [
{
Action = ["SNS:Receive", "SNS:Subscribe"]
Effect = "Allow"
Resource = aws_sns_topic.async-sagemaker-success-topic.arn
},
{
# NOTE(review): "*sagemaker*" matches any bucket/object ARN containing
# "sagemaker" in this account — consider narrowing to the actual output bucket.
Action = ["s3:GetObject"]
Effect = "Allow"
Resource = "arn:aws:s3:::*sagemaker*"
},
{
# NOTE(review): a bare trailing "*" also matches sibling buckets sharing
# this ARN as a prefix; "${var.s3_bucket_notebooks_arn}/*" would be tighter.
Action = ["s3:PutObject"]
Effect = "Allow"
Resource = "${var.s3_bucket_notebooks_arn}*"
},
{
Action = ["logs:CreateLogGroup","logs:CreateLogStream","logs:PutLogEvents","logs:DescribeLogStreams"]
Effect = "Allow"
Resource = "arn:aws:logs:*:*:*"
}
]
})
}

# Zip the Lambda source so Terraform can upload it and detect changes by hash.
data "archive_file" "lambda_payload" {
type = "zip"
source_file = "${path.module}/lambda_function/s3_move_output.py"
output_path = "${path.module}/lambda_function/payload.zip"
}

# Lambda that copies completed SageMaker async outputs back to the user's
# notebook area; redeployed whenever the zipped source hash changes.
resource "aws_lambda_function" "lambda_s3_move_output" {
filename = data.archive_file.lambda_payload.output_path
source_code_hash = data.archive_file.lambda_payload.output_base64sha256
function_name = "lambda_s3_move_output"
role = aws_iam_role.iam_for_lambda_s3_move.arn
handler = "s3_move_output.lambda_handler"
runtime = "python3.12"
timeout = 30
}


# Topic SageMaker publishes async-inference success notifications to.
# The access policy is attached at creation from the policy document below.
resource "aws_sns_topic" "async-sagemaker-success-topic" {
name = "async-sagemaker-success-topic"
policy = data.aws_iam_policy_document.sns_publish_and_read_policy.json
}

# Deliver each success notification to the output-mover Lambda.
resource "aws_sns_topic_subscription" "topic_lambda" {
topic_arn = aws_sns_topic.async-sagemaker-success-topic.arn
protocol = "lambda"
endpoint = aws_lambda_function.lambda_s3_move_output.arn
}

# Resource-based permission allowing SNS (this topic only) to invoke the Lambda;
# required in addition to the subscription above.
resource "aws_lambda_permission" "with_sns" {
statement_id = "AllowExecutionFromSNS"
action = "lambda:InvokeFunction"
function_name = aws_lambda_function.lambda_s3_move_output.function_name
principal = "sns.amazonaws.com"
source_arn = aws_sns_topic.async-sagemaker-success-topic.arn
}

# Topic access policy: SageMaker may publish; Lambda may receive/subscribe.
# The topic ARN is constructed from region/account variables rather than
# referencing the topic itself, because the policy is attached at topic
# creation (a self-reference would be a dependency cycle).
data "aws_iam_policy_document" "sns_publish_and_read_policy" {
statement {
sid = "sns_publish_and_read_policy_1"
actions = ["SNS:Publish"]
effect = "Allow"
principals {
type = "Service"
identifiers = ["sagemaker.amazonaws.com"]
}
resources = ["arn:aws:sns:${var.aws_region}:${var.account_id}:async-sagemaker-success-topic"]
}
statement {
sid = "sns_publish_and_read_policy_2"
actions = ["SNS:Receive","SNS:Subscribe"]
effect = "Allow"
principals {
type = "Service"
identifiers = ["lambda.amazonaws.com"]
}
resources = ["arn:aws:sns:${var.aws_region}:${var.account_id}:async-sagemaker-success-topic"]
}
}
3 changes: 3 additions & 0 deletions infra/modules/sagemaker_output_mover/outputs.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Exposed so deployments can wire SageMaker async success notifications to this topic.
output "sns_success_topic_arn" {
  description = "ARN of the SNS topic that receives SageMaker async success notifications"
  value       = aws_sns_topic.async-sagemaker-success-topic.arn
}
14 changes: 14 additions & 0 deletions infra/modules/sagemaker_output_mover/variables.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Inputs for the sagemaker_output_mover module.

variable "account_id" {
type = string
description = "AWS Account ID"
}

# Used (with account_id) to build the SNS topic ARN in the topic policy.
variable "aws_region" {
type = string
description = "AWS Region in format e.g. us-west-1"
}

variable "s3_bucket_notebooks_arn" {
type = string
description = "S3 Bucket for notebook user data storage"
}
33 changes: 33 additions & 0 deletions infra/sagemaker.tf
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ module "iam" {
prefix = var.prefix
sagemaker_default_bucket_name = var.sagemaker_default_bucket
aws_s3_bucket_notebook = aws_s3_bucket.notebooks
account_id = data.aws_caller_identity.aws_caller_identity.account_id
}


Expand Down Expand Up @@ -62,6 +63,31 @@ resource "aws_security_group_rule" "notebooks_endpoint_egress_sagemaker" {
protocol = "tcp"
}

# Egress for the notebooks endpoints SG to reach SNS.
resource "aws_security_group_rule" "main_egress_sns" {
  description = "endpoint-egress-from-main-vpc"

  security_group_id = aws_security_group.notebooks_endpoints.id
  # NOTE(review): consider restricting to the VPC/endpoint CIDR instead of
  # all destinations — confirm what this rule actually needs to reach.
  cidr_blocks = ["0.0.0.0/0"]

  # SNS is reached over HTTPS only, so the previous all-ports (0-65535)
  # opening is unnecessary; restrict to TCP 443.
  type      = "egress"
  from_port = "443"
  to_port   = "443"
  protocol  = "tcp"
}

# Ingress to the notebooks endpoints SG for SNS (VPC endpoint) traffic.
resource "aws_security_group_rule" "main_ingress_sns" {
  description = "endpoint-ingress-to-main-vpc"

  security_group_id = aws_security_group.notebooks_endpoints.id
  # NOTE(review): 0.0.0.0/0 ingress is far too broad — this should almost
  # certainly be the VPC CIDR (or a source security group). Confirm and tighten.
  cidr_blocks = ["0.0.0.0/0"]

  # Interface endpoints accept HTTPS only; the previous rule opened every TCP
  # port (0-65535) to the world. Restrict to 443.
  type      = "ingress"
  from_port = "443"
  to_port   = "443"
  protocol  = "tcp"
}


# SageMaker Execution Role Output
output "execution_role" {
value = module.iam.execution_role
Expand Down Expand Up @@ -109,6 +135,13 @@ module "sns" {
account_id = data.aws_caller_identity.aws_caller_identity.account_id
}

# Wires up the SNS success topic + Lambda that copy completed SageMaker async
# outputs back into the invoking user's notebooks bucket.
module "sagemaker_output_mover" {
source = "./modules/sagemaker_output_mover"
account_id = data.aws_caller_identity.aws_caller_identity.account_id
aws_region = data.aws_region.aws_region.name
s3_bucket_notebooks_arn = aws_s3_bucket.notebooks.arn
}

module "log_group" {
source = "./modules/logs"
prefix = "data-workspace-sagemaker"
Expand Down
10 changes: 6 additions & 4 deletions infra/sagemaker_llm_resources.tf
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
module "gpt_neo_125_deployment" {
source = "./modules/sagemaker_deployment"
model_name = "gpt-neo-125m"
sns_success_topic_arn = module.sagemaker_output_mover.sns_success_topic_arn
execution_role_arn = module.iam.inference_role
container_image = var.hugging_face_model_image
model_data_url = "${var.sagemaker_models_folder}/gpt-neo-125m.tar.gz"
Expand Down Expand Up @@ -50,7 +51,7 @@ module "gpt_neo_125_deployment" {
comparison_operator = "LessThanThreshold"
threshold = 5.0
evaluation_periods = 3
datapoints_to_alarm = 2 # 2 out of 5 periods breaching then scale down to ensure
datapoints_to_alarm = 2 # 2 out of 5 periods breaching then scale down to ensure
period = 60
statistic = "Average"
alarm_actions = [module.gpt_neo_125_deployment.scale_in_to_zero_policy_arn]
Expand All @@ -63,7 +64,7 @@ module "gpt_neo_125_deployment" {
comparison_operator = "LessThanThreshold"
threshold = 0
evaluation_periods = 3
datapoints_to_alarm = 2 # 2 out of 3 periods breaching then scale down to ensure
datapoints_to_alarm = 2 # 2 out of 3 periods breaching then scale down to ensure
period = 60
statistic = "Sum"
alarm_actions = [module.gpt_neo_125_deployment.scale_in_to_zero_based_on_backlog_arn]
Expand Down Expand Up @@ -179,6 +180,7 @@ module "gpt_neo_125_deployment" {
module "llama_3_2_1b_deployment" {
source = "./modules/sagemaker_deployment"
model_name = "Llama-3-2-1B"
sns_success_topic_arn = module.sagemaker_output_mover.sns_success_topic_arn
execution_role_arn = module.iam.inference_role
container_image = var.hugging_face_model_image
model_data_url = "${var.sagemaker_models_folder}/Llama-3.2-1B.tar.gz"
Expand Down Expand Up @@ -224,7 +226,7 @@ module "llama_3_2_1b_deployment" {
comparison_operator = "LessThanThreshold"
threshold = 5.0
evaluation_periods = 3
datapoints_to_alarm = 2 # 2 out of 3 periods breaching then scale down to ensure
datapoints_to_alarm = 2 # 2 out of 3 periods breaching then scale down to ensure
period = 60
statistic = "Average"
alarm_actions = [module.llama_3_2_1b_deployment.scale_in_to_zero_policy_arn]
Expand All @@ -237,7 +239,7 @@ module "llama_3_2_1b_deployment" {
comparison_operator = "LessThanThreshold"
threshold = 0
evaluation_periods = 3
datapoints_to_alarm = 2 # 2 out of 3 periods breaching then scale down to ensure
datapoints_to_alarm = 2 # 2 out of 3 periods breaching then scale down to ensure
period = 60
statistic = "Sum"
alarm_actions = [module.llama_3_2_1b_deployment.scale_in_to_zero_based_on_backlog_arn]
Expand Down
4 changes: 4 additions & 0 deletions infra/security_groups.tf
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@


resource "aws_security_group" "dns_rewrite_proxy" {
name = "${var.prefix}-dns-rewrite-proxy"
description = "${var.prefix}-dns-rewrite-proxy"
Expand Down Expand Up @@ -603,6 +605,8 @@ resource "aws_security_group" "ecr_api" {
}
}



resource "aws_security_group_rule" "ecr_api_ingress_https_from_dns_rewrite_proxy" {
description = "ingress-https-from-dns-rewrite-proxy"

Expand Down
Loading