Skip to content

Commit

Permalink
Jetstream Maxtext Deployment Module: All scale rules now in a single …
Browse files Browse the repository at this point in the history
…HPA (#730)

first commit
  • Loading branch information
Bslabe123 authored Jul 16, 2024
1 parent 6d21def commit b2ce31a
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 52 deletions.
38 changes: 8 additions & 30 deletions modules/jetstream-maxtext-deployment/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -73,41 +73,19 @@ module "prometheus_adapter" {
}

resource "kubernetes_manifest" "prometheus_adapter_hpa_custom_metric" {
for_each = {
for index, rule in var.hpa_config.rules :
index => {
index = index
target_query = rule.target_query
average_value_target = rule.average_value_target
}
if var.maxengine_deployment_settings.custom_metrics_enabled && var.hpa_config.metrics_adapter == "prometheus-adapter"
}

count = var.hpa_config.metrics_adapter == "prometheus-adapter" ? 1 : 0
manifest = yamldecode(templatefile(local.prometheus_jetstream_hpa_template, {
index = each.value.index
hpa_type = try(each.value.target_query, "")
hpa_averagevalue_target = try(each.value.average_value_target, 1)
hpa_min_replicas = var.hpa_config.min_replicas
hpa_max_replicas = var.hpa_config.max_replicas
hpa_min_replicas = var.hpa_config.min_replicas
hpa_max_replicas = var.hpa_config.max_replicas
rules = var.hpa_config.rules
}))
}

resource "kubernetes_manifest" "cmsa_hpa_custom_metric" {
for_each = {
for index, rule in var.hpa_config.rules :
index => {
index = index
target_query = rule.target_query
average_value_target = rule.average_value_target
}
if var.maxengine_deployment_settings.custom_metrics_enabled && var.hpa_config.metrics_adapter == "custom-metrics-stackdriver-adapter"
}

count = var.hpa_config.metrics_adapter == "custom-metrics-stackdriver-adapter" ? 1 : 0
manifest = yamldecode(templatefile(local.cmsa_jetstream_hpa_template, {
index = each.value.index
hpa_type = try(each.value.target_query, "")
hpa_averagevalue_target = try(each.value.average_value_target, 1)
hpa_min_replicas = var.hpa_config.min_replicas
hpa_max_replicas = var.hpa_config.max_replicas
hpa_min_replicas = var.hpa_config.min_replicas
hpa_max_replicas = var.hpa_config.max_replicas
rules = var.hpa_config.rules
}))
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: jetstream-hpa-${index}
name: jetstream-hpa
namespace: default
spec:
scaleTargetRef:
Expand All @@ -11,20 +11,22 @@ spec:
minReplicas: ${hpa_min_replicas}
maxReplicas: ${hpa_max_replicas}
metrics:
%{ if length(regexall("jetstream_.*", hpa_type)) > 0 }
%{ for rule in rules }
%{ if length(regexall("jetstream_.*", rule.target_query)) > 0 }
- type: Pods
pods:
metric:
name: prometheus.googleapis.com|${hpa_type}|gauge
name: ${rule.target_query}
target:
type: AverageValue
averageValue: ${hpa_averagevalue_target}
averageValue: ${rule.average_value_target}
%{ else }
- type: External
external:
metric:
name: kubernetes.io|node|accelerator|${hpa_type}
name: kubernetes.io|node|accelerator|${rule.target_query}
target:
type: AverageValue
averageValue: ${hpa_averagevalue_target}
%{ endif }
averageValue: ${rule.average_value_target}
%{ endif }
%{ endfor ~}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: jetstream-hpa-${index}
name: jetstream-hpa
namespace: default
spec:
scaleTargetRef:
Expand All @@ -11,20 +11,22 @@ spec:
minReplicas: ${hpa_min_replicas}
maxReplicas: ${hpa_max_replicas}
metrics:
%{ if length(regexall("jetstream_.*", hpa_type)) > 0 }
- type: Pods
pods:
%{ for rule in rules }
%{ if length(regexall("jetstream_.*", rule.target_query)) > 0 }
- type: External
external:
metric:
name: ${hpa_type}
name: ${rule.target_query}
target:
type: AverageValue
averageValue: ${hpa_averagevalue_target}
averageValue: ${rule.average_value_target}
%{ else }
- type: External
external:
metric:
name: ${hpa_type}
name: ${rule.target_query}
target:
type: AverageValue
averageValue: ${hpa_averagevalue_target}
%{ endif }
averageValue: ${rule.average_value_target}
%{ endif }
%{ endfor ~}
3 changes: 1 addition & 2 deletions modules/jetstream-maxtext-deployment/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ variable "maxengine_deployment_settings" {

model_name = string // Name of your LLM (for example: "gemma-7b")
parameters_path = string // Path to the paramters for your model
metrics_port = optional(number) // Emit Jetstream metrics on this port of each contaienr
custom_metrics_enabled = bool // Whether or not custom metrics are also emitted
metrics_port = optional(number) // Emit Jetstream metrics on this port of each container
metrics_scrape_interval = optional(number) // Interval for scraping metrics (default: 10s)

accelerator_selectors = object({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ For deploying autoscaling components via terraform, a few more variables to be s

```
maxengine_deployment_settings = {
custom_metrics_enabled = true
metrics_port = <same as above>
metrics_scrape_interval
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ maxengine_deployment_settings = {
maxengine_server_image = "us-docker.pkg.dev/cloud-tpu-images/inference/maxengine-server:v0.2.2"
jetstream_http_server_image = "us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.2"

custom_metrics_enabled = true
metrics_port = 9100
metrics_scrape_interval = 10
accelerator_selectors = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ variable "maxengine_deployment_settings" {

model_name = string // Name of your LLM (for example: "gemma-7b")
parameters_path = string // Path to the parameters for your model
metrics_port = optional(number) // Emit Jetstream metrics on this port of each contaienr
custom_metrics_enabled = bool // Whether or not custom metrics are also emitted
metrics_port = optional(number) // Emit Jetstream metrics on this port of each container
metrics_scrape_interval = optional(number) // Interval for scraping metrics (default: 10s)

accelerator_selectors = object({
Expand Down

0 comments on commit b2ce31a

Please sign in to comment.