Jetstream Maxtext Module #719

Merged
merged 35 commits into from
Jul 9, 2024
Changes from 4 commits
2a7e568
first commit
Bslabe123 Jul 1, 2024
2824bb2
terraform fmt
Bslabe123 Jul 1, 2024
d8e1228
Update README.md
Bslabe123 Jul 1, 2024
cc3d7c5
prometheus adapter module in main
Bslabe123 Jul 1, 2024
d627b02
remove apply.sh
Bslabe123 Jul 1, 2024
f662092
typo
Bslabe123 Jul 1, 2024
4f1ba93
terraform fmt
Bslabe123 Jul 1, 2024
2086f62
Merge remote-tracking branch 'origin/main' into jetstream-module
Bslabe123 Jul 2, 2024
c8b3687
large cleanup and validation
Bslabe123 Jul 2, 2024
b63b93c
moved fields and made module variables consistent with example variables
Bslabe123 Jul 2, 2024
1a0444e
parameterized accelerator selectors
Bslabe123 Jul 2, 2024
505fbe1
parameterize metrics scrape interval
Bslabe123 Jul 2, 2024
b441e28
fmt
Bslabe123 Jul 2, 2024
d193afa
fmt
Bslabe123 Jul 2, 2024
dfd8078
load parameters parameterization and multiple hpa resources
Bslabe123 Jul 2, 2024
e36f05f
fmt
Bslabe123 Jul 2, 2024
2cc9a1f
parameterized model name
Bslabe123 Jul 2, 2024
4b37ebd
update readme and validators
Bslabe123 Jul 2, 2024
de8a8f7
changes to jetstream module deployment readme
Bslabe123 Jul 2, 2024
359cc6d
terraform fmt
Bslabe123 Jul 2, 2024
d86578f
accelerator_memory_used_percentage -> memory_used_percentage
Bslabe123 Jul 3, 2024
ccc736a
changes to READMEs
Bslabe123 Jul 3, 2024
d77e30c
tweaks
Bslabe123 Jul 3, 2024
482fd8e
metrics port optional
Bslabe123 Jul 3, 2024
58081c4
sample tfvars no longer includes autoscaling config
Bslabe123 Jul 3, 2024
bb81350
example autoscaling config
Bslabe123 Jul 3, 2024
7af7810
Update README.md
Bslabe123 Jul 8, 2024
ab88992
Update README.md
Bslabe123 Jul 8, 2024
aece9b0
Update README.md
Bslabe123 Jul 8, 2024
376a90a
strengthen hpa config validation
Bslabe123 Jul 8, 2024
161c333
More updates to readmes
Bslabe123 Jul 8, 2024
880cb36
tweak to readme
Bslabe123 Jul 8, 2024
eb44246
typo
Bslabe123 Jul 8, 2024
d5d05f4
missing kubectl apply
Bslabe123 Jul 8, 2024
cc64c39
typos
Bslabe123 Jul 9, 2024
65 changes: 65 additions & 0 deletions modules/jetstream-maxtext-deployment/README.md
@@ -0,0 +1,65 @@
## Bash equivalent of this module

Ensure the following environment variables are set before running:
- BUCKET_NAME: Bucket name to be used in your checkpoint path
- (optional) METRICS_PORT: Port to emit custom metrics on; if unset, no PodMonitoring resource is applied
- (optional) MAXENGINE_SERVER_IMAGE: Maxengine server container image
- (optional) JETSTREAM_HTTP_SERVER_IMAGE: Jetstream HTTP server container image

```
# Default images, used when none are provided in the environment.
if [ -z "$MAXENGINE_SERVER_IMAGE" ]; then
    MAXENGINE_SERVER_IMAGE="us-docker.pkg.dev/cloud-tpu-images/inference/maxengine-server:v0.2.2"
fi

if [ -z "$JETSTREAM_HTTP_SERVER_IMAGE" ]; then
    JETSTREAM_HTTP_SERVER_IMAGE="us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.2"
fi

if [ -z "$BUCKET_NAME" ]; then
    echo "Must provide BUCKET_NAME in environment" 1>&2
    exit 2
fi

JETSTREAM_MANIFEST=$(mktemp)
PODMONITORING_MANIFEST=$(mktemp)

# Render each template straight into its own temp file. Writing with ">" from
# the template avoids reading and appending to the same file in one pipeline.
if [ -n "$METRICS_PORT" ]; then
    METRICS_PORT_ARG="prometheus_port=$METRICS_PORT"
    sed "s|\${metrics_port}|$METRICS_PORT|g" ./templates/podmonitoring.yaml.tftpl > "$PODMONITORING_MANIFEST"
    kubectl apply -f "$PODMONITORING_MANIFEST"
else
    METRICS_PORT_ARG=""
fi

# "|" is the sed delimiter so image paths and gs:// URLs need no escaping.
# The argument name load_parameters_path matches the one used in main.tf.
sed -e "s|\${metrics_port_arg}|$METRICS_PORT_ARG|g" \
    -e "s|\${maxengine_server_image}|$MAXENGINE_SERVER_IMAGE|g" \
    -e "s|\${jetstream_http_server_image}|$JETSTREAM_HTTP_SERVER_IMAGE|g" \
    -e "s|\${load_parameters_path_arg}|load_parameters_path=gs://$BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items|g" \
    ./templates/deployment.yaml.tftpl > "$JETSTREAM_MANIFEST"

kubectl apply -f "$JETSTREAM_MANIFEST"
```

### Metrics Adapter

#### Custom Metrics Stackdriver Adapter

Follow the [Custom-metrics-stackdriver-adapter README](LINK HERE)

#### Prometheus Adapter

Follow the [Prometheus-adapter README](AWAITING OTHER MERGE). A few notes:

This module requires the cluster name to be passed in manually via the CLUSTER_NAME variable to filter incoming metrics. This is a consequence of differing cluster name schemas between GKE and standard k8s clusters. If the cluster name isn't already known, obtain it as follows:

- For GKE clusters, remove all characters up to and including the last underscore of the current context: `kubectl config current-context | awk -F'_' ' { print $NF }'`.
- For other clusters, the cluster name is simply the output of `kubectl config current-context`.
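
The two cases above can be folded into one helper. This is only a sketch: the function name `cluster_name_from_context` is hypothetical, and it assumes GKE contexts have the shape `gke_PROJECT_LOCATION_CLUSTER`, so the text after the last underscore is the cluster name.

```shell
# Hypothetical helper: derive the cluster name from a kubectl context string.
# Contexts starting with "gke_" are assumed to end in "_<cluster>"; any other
# context is assumed to already be the plain cluster name.
cluster_name_from_context() {
  case "$1" in
    gke_*) printf '%s\n' "$1" | awk -F'_' '{ print $NF }' ;;
    *)     printf '%s\n' "$1" ;;
  esac
}

# Usage (requires kubectl and a configured context):
#   CLUSTER_NAME=$(cluster_name_from_context "$(kubectl config current-context)")
```

Branching on the `gke_` prefix keeps a non-GKE context that happens to contain underscores intact.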

Set the PROMETHEUS_HELM_VALUES_FILE env var as follows:

```
PROMETHEUS_HELM_VALUES_FILE=$(mktemp)
sed "s/\${cluster_name}/$CLUSTER_NAME/g" ../templates/values.yaml.tftpl >> "$PROMETHEUS_HELM_VALUES_FILE"
```
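
With sed-based rendering like the snippet above, a misspelled placeholder passes through silently. One possible guard, sketched under assumptions: the function name `render_cluster_name` is hypothetical, the template is read from stdin, and placeholders are assumed to use lowercase letters and underscores only.

```shell
# Hypothetical helper: substitute ${cluster_name} in a template read from
# stdin, then fail if any ${...} placeholder is still present in the output.
# Assumes the cluster name contains no "/" or "&", which sed would interpret.
render_cluster_name() {
  rendered=$(sed "s/\${cluster_name}/$1/g") || return 1
  if printf '%s\n' "$rendered" | grep -q '\${[a-z_]*}'; then
    echo "unrendered placeholder remains" 1>&2
    return 1
  fi
  printf '%s\n' "$rendered"
}

# Usage:
#   render_cluster_name "$CLUSTER_NAME" < ../templates/values.yaml.tftpl > "$PROMETHEUS_HELM_VALUES_FILE"
```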


34 changes: 34 additions & 0 deletions modules/jetstream-maxtext-deployment/apply.sh
@@ -0,0 +1,34 @@
# Default images, used when none are provided in the environment.
if [ -z "$MAXENGINE_SERVER_IMAGE" ]; then
    MAXENGINE_SERVER_IMAGE="us-docker.pkg.dev/cloud-tpu-images/inference/maxengine-server:v0.2.2"
fi

if [ -z "$JETSTREAM_HTTP_SERVER_IMAGE" ]; then
    JETSTREAM_HTTP_SERVER_IMAGE="us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.2"
fi

if [ -z "$BUCKET_NAME" ]; then
    echo "Must provide BUCKET_NAME in environment" 1>&2
    exit 2
fi

JETSTREAM_MANIFEST=$(mktemp)
PODMONITORING_MANIFEST=$(mktemp)

# Render each template straight into its own temp file. Writing with ">" from
# the template avoids reading and appending to the same file in one pipeline.
if [ -n "$METRICS_PORT" ]; then
    METRICS_PORT_ARG="prometheus_port=$METRICS_PORT"
    sed "s|\${metrics_port}|$METRICS_PORT|g" ./templates/podmonitoring.yaml.tftpl > "$PODMONITORING_MANIFEST"
    kubectl apply -f "$PODMONITORING_MANIFEST"
else
    METRICS_PORT_ARG=""
fi

# "|" is the sed delimiter so image paths and gs:// URLs need no escaping.
# The argument name load_parameters_path matches the one used in main.tf.
sed -e "s|\${metrics_port_arg}|$METRICS_PORT_ARG|g" \
    -e "s|\${maxengine_server_image}|$MAXENGINE_SERVER_IMAGE|g" \
    -e "s|\${jetstream_http_server_image}|$JETSTREAM_HTTP_SERVER_IMAGE|g" \
    -e "s|\${load_parameters_path_arg}|load_parameters_path=gs://$BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items|g" \
    ./templates/deployment.yaml.tftpl > "$JETSTREAM_MANIFEST"

kubectl apply -f "$JETSTREAM_MANIFEST"
81 changes: 81 additions & 0 deletions modules/jetstream-maxtext-deployment/main.tf
@@ -0,0 +1,81 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

locals {
  deployment_template         = "${path.module}/templates/deployment.yaml.tftpl"
  service_template            = "${path.module}/templates/service.yaml.tftpl"
  podmonitoring_template      = "${path.module}/templates/podmonitoring.yaml.tftpl"
  cmsa_jetstream_hpa_template = "${path.module}/templates/custom-metrics-stackdriver-adapter/hpa.jetstream.yaml.tftpl"
}

resource "kubernetes_manifest" "jetstream-deployment" {
  count = 1
  manifest = yamldecode(templatefile(local.deployment_template, {
    maxengine_server_image      = var.maxengine_server_image
    jetstream_http_server_image = var.jetstream_http_server_image
    load_parameters_path_arg    = format("load_parameters_path=gs://%s/final/unscanned/gemma_7b-it/0/checkpoints/0/items", var.bucket_name)
    metrics_port_arg            = var.metrics_port != null ? format("prometheus_port=%d", var.metrics_port) : "",
  }))
}

resource "kubernetes_manifest" "jetstream-service" {
  count    = 1
  manifest = yamldecode(file(local.service_template))
}

resource "kubernetes_manifest" "jetstream-podmonitoring" {
  count = var.metrics_port != null ? 1 : 0
  manifest = yamldecode(templatefile(local.podmonitoring_template, {
    metrics_port = var.metrics_port != null ? var.metrics_port : "",
  }))
}

## CMSA module pending https://github.com/GoogleCloudPlatform/ai-on-gke/pull/718/files merge
module "custom_metrics_stackdriver_adapter" {
  count  = var.metrics_adapter == "custom-metrics-stackdriver-adapter" ? 1 : 0
  source = "../custom-metrics-stackdriver-adapter"
  workload_identity = {
    enabled    = true
    project_id = var.project_id
  }
}

## Prometheus adapter module pending https://github.com/GoogleCloudPlatform/ai-on-gke/pull/716/files merge
module "prometheus_adapter" {
  count  = var.metrics_adapter == "prometheus-adapter" ? 1 : 0
  source = "../prometheus-adapter"
  credentials_config = {
    kubeconfig = {
      path : "~/.kube/config"
    }
  }
  project_id   = var.project_id
  cluster_name = var.cluster_name
  # Resolve the template relative to this module rather than the filesystem root.
  config_file = templatefile("${path.module}/templates/prometheus-adapter/values.yaml.tftpl", {
    cluster_name = var.cluster_name
  })
}

resource "kubernetes_manifest" "hpa_custom_metric" {
  count = (var.custom_metrics_enabled && var.hpa_type != null || var.hpa_type != "memory_used") && var.hpa_averagevalue_target != null ? 1 : 0
  manifest = yamldecode(templatefile(local.cmsa_jetstream_hpa_template, {
    hpa_type                = try(var.hpa_type, "")
    hpa_averagevalue_target = try(var.hpa_averagevalue_target, 1)
    hpa_min_replicas        = var.hpa_min_replicas
    hpa_max_replicas        = var.hpa_max_replicas
  }))
}
@@ -0,0 +1,30 @@
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: jetstream-hpa
  namespace: default
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: maxengine-server
  minReplicas: ${hpa_min_replicas}
  maxReplicas: ${hpa_max_replicas}
  metrics:
%{ if length(regexall("jetstream_.*", hpa_type)) > 0 }
  - type: Pods
    pods:
      metric:
        name: prometheus.googleapis.com|${hpa_type}|gauge
      target:
        type: AverageValue
        averageValue: ${hpa_averagevalue_target}
%{ else }
  - type: External
    external:
      metric:
        name: kubernetes.io|node|accelerator|memory_used
      target:
        type: AverageValue
        averageValue: ${hpa_averagevalue_target}
%{ endif }
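
Note on the `%{ if ... }` directive in this template: `regexall("jetstream_.*", hpa_type)` is unanchored, so it returns at least one match whenever `hpa_type` contains `jetstream_`, which selects the `Pods` branch; any other value falls through to `External`. A rough shell equivalent of that decision, for readers less familiar with Terraform template directives (the function name `hpa_metric_source` is made up for illustration):

```shell
# Hypothetical helper mirroring the template's branch selection.
# The regex in the template is unanchored, so the match may occur anywhere
# in the metric name; hence the leading wildcard in the case pattern.
hpa_metric_source() {
  case "$1" in
    *jetstream_*) echo "Pods" ;;
    *)            echo "External" ;;
  esac
}

hpa_metric_source "jetstream_prefill_backlog_size"
hpa_metric_source "memory_used"
```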
@@ -0,0 +1,51 @@
apiVersion: apps/v1
kind: Deployment
metadata:
  name: maxengine-server
  namespace: default
spec:
  replicas: 1
  selector:
    matchLabels:
      app: maxengine-server
  template:
    metadata:
      labels:
        app: maxengine-server
    spec:
      nodeSelector:
        cloud.google.com/gke-tpu-topology: 2x2
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
      containers:
      - name: maxengine-server
        image: ${maxengine_server_image}
        imagePullPolicy: Always
        securityContext:
          privileged: true
        args:
        - model_name=gemma-7b
        - tokenizer_path=assets/tokenizer.gemma
        - per_device_batch_size=4
        - max_prefill_predict_length=1024
        - max_target_length=2048
        - async_checkpointing=false
        - ici_fsdp_parallelism=1
        - ici_autoregressive_parallelism=-1
        - ici_tensor_parallelism=1
        - scan_layers=false
        - weight_dtype=bfloat16
        - attention=dot_product
        - ${load_parameters_path_arg}
        - ${metrics_port_arg}
        ports:
        - containerPort: 9000
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4
      - name: jetstream-http
        image: ${jetstream_http_server_image}
        imagePullPolicy: Always
        ports:
        - containerPort: 8000
@@ -0,0 +1,15 @@
apiVersion: monitoring.googleapis.com/v1
kind: PodMonitoring
metadata:
  name: jetstream-podmonitoring
  namespace: default
spec:
  endpoints:
  - interval: 1s
    path: "/"
    port: ${metrics_port}
  targetLabels:
    metadata:
    - pod
    - container
    - node
@@ -0,0 +1,30 @@
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: jetstream-hpa
  namespace: ${namespace}
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: maxengine-server
  minReplicas: ${hpa_min_replicas}
  maxReplicas: ${hpa_max_replicas}
  metrics:
%{ if length(regexall("jetstream_.*", hpa_type)) > 0 }
  - type: Pods
    pods:
      metric:
        name: ${hpa_type}
      target:
        type: AverageValue
        averageValue: ${hpa_averagevalue_target}
%{ else }
  - type: External
    external:
      metric:
        name: ${hpa_type}
      target:
        type: AverageValue
        averageValue: ${hpa_averagevalue_target}
%{ endif }
@@ -0,0 +1,38 @@
rules:
  default: false
  external:
  - seriesQuery: 'jetstream_prefill_backlog_size'
    resources:
      template: <<.Resource>>
    name:
      matches: ""
      as: "jetstream_prefill_backlog_size"
    metricsQuery: sum(<<.Series>>{<<.LabelMatchers>>,cluster="${cluster_name}"})
  - seriesQuery: 'jetstream_transfer_backlog_size'
    resources:
      template: <<.Resource>>
    name:
      matches: ""
      as: "jetstream_transfer_backlog_size"
    metricsQuery: sum(<<.Series>>{<<.LabelMatchers>>,cluster="${cluster_name}"})
  - seriesQuery: 'jetstream_generate_backlog_size'
    resources:
      template: <<.Resource>>
    name:
      matches: ""
      as: "jetstream_generate_backlog_size"
    metricsQuery: sum(<<.Series>>{<<.LabelMatchers>>,cluster="${cluster_name}"})
  - seriesQuery: 'jetstream_slots_used_percentage'
    resources:
      template: <<.Resource>>
    name:
      matches: ""
      as: "jetstream_slots_used_percentage"
    metricsQuery: sum(<<.Series>>{<<.LabelMatchers>>,cluster="${cluster_name}"})
  - seriesQuery: 'kubernetes_io:node_accelerator_memory_used'
    resources:
      template: <<.Resource>>
    name:
      matches: ""
      as: "accelerator_memory_used_percentage"
    metricsQuery: avg(kubernetes_io:node_accelerator_memory_used{cluster_name="${cluster_name}"}) / avg(kubernetes_io:node_accelerator_memory_total{cluster_name="${cluster_name}"})
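
Note on the last rule: it divides the average of `memory_used` by the average of `memory_total`, so despite the `_percentage` suffix the exposed value is a fraction between 0 and 1, not a number between 0 and 100. A small worked example with made-up per-node values (the numbers and the function name `memory_used_ratio` are illustrative only):

```shell
# Illustrative only: avg(used) / avg(total) over two hypothetical nodes.
# Arguments: used1 used2 total1 total2 (same unit for all four).
memory_used_ratio() {
  awk -v u1="$1" -v u2="$2" -v t1="$3" -v t2="$4" \
    'BEGIN { printf "%.2f\n", ((u1 + u2) / 2) / ((t1 + t2) / 2) }'
}

# Two nodes using 4 and 12 units out of 16 each: (8 / 16) prints 0.50
memory_used_ratio 4 12 16 16
```

An HPA target for this metric would presumably be written on the same 0 to 1 scale.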
17 changes: 17 additions & 0 deletions modules/jetstream-maxtext-deployment/templates/service.yaml.tftpl
@@ -0,0 +1,17 @@
apiVersion: v1
kind: Service
metadata:
  name: jetstream-svc
  namespace: default
spec:
  selector:
    app: maxengine-server
  ports:
  - protocol: TCP
    name: jetstream-http
    port: 8000
    targetPort: 8000
  - protocol: TCP
    name: jetstream-grpc
    port: 9000
    targetPort: 9000