Commit
JetStream checkpoint converter support for Llama models on MaxText (#840)

* Update pip in JetStream PyTorch and checkpoint Dockerfiles

* Add support for Llama model conversions from Meta and HF to MaxText; update HTTP server healthcheck
vivianrwu authored Oct 2, 2024
1 parent 8aa4d6a commit 6a38ad5
Showing 3 changed files with 163 additions and 18 deletions.
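To ground the changes below, here is a hypothetical end-to-end invocation of checkpoint_converter.sh for a Hugging Face Llama checkpoint. Values follow the FLAG=value form the script's getopts/awk parsing expects; the bucket, model path, and the -h (HuggingFace) flag spelling are assumptions for illustration:

# Sketch only: convert a Hugging Face Llama-2 checkpoint for JetStream MaxText.
# For Meta-hosted weights, pass -u META_URL=<signed url> instead of -h.
./checkpoint_converter.sh \
  -s INFERENCE_SERVER=jetstream-maxtext \
  -b BUCKET_NAME=my-ckpt-bucket \
  -m MODEL_PATH=meta-llama/Llama-2-7b \
  -n MODEL_NAME=llama-2 \
  -h HUGGINGFACE=True \
  -v VERSION=jetstream-v0.2.2 \
  -o OUTPUT_DIRECTORY=gs://my-ckpt-bucket/final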
Dockerfile
@@ -24,7 +24,8 @@ RUN apt -y update && apt install -y google-cloud-cli

RUN pip install kaggle && \
pip install huggingface_hub[cli] && \
-pip install google-jetstream
+pip install google-jetstream && \
+pip install llama-toolchain

COPY checkpoint_converter.sh /usr/bin/
RUN chmod +x /usr/bin/checkpoint_converter.sh
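With llama-toolchain added, the image now carries the llama CLI used below for Meta-hosted downloads. A quick smoke test of the built image might look like this (image tag and entrypoint override are illustrative):

# Sketch: confirm the llama CLI and converter script are present in the image.
docker build -t checkpoint-converter .
docker run --rm --entrypoint /bin/bash checkpoint-converter -c 'llama --help && command -v checkpoint_converter.sh'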
checkpoint_converter.sh
@@ -1,5 +1,6 @@
#!/bin/bash

+set -e
export KAGGLE_CONFIG_DIR="/kaggle"
export HUGGINGFACE_TOKEN_DIR="/huggingface"
INFERENCE_SERVER="jetstream-maxtext"
@@ -19,13 +20,15 @@ check_gsbucket() {
BUCKET_NAME=$1
if [ -z $BUCKET_NAME ]; then
echo "BUCKET_NAME is empty, please provide a GSBucket"
+exit 1
fi
}

check_model_path() {
MODEL_PATH=$1
if [ -z $MODEL_PATH ]; then
echo "MODEL_PATH is empty, please provide the model path"
+exit 1
fi
}

@@ -49,10 +52,15 @@ download_huggingface_checkpoint() {
MODEL_NAME=$2

INPUT_CKPT_DIR_LOCAL=/base/
-mkdir /base/

+if [ ! -d "/base" ]; then
+mkdir /base/
+fi
huggingface-cli login --token $(cat ${HUGGINGFACE_TOKEN_DIR}/HUGGINGFACE_TOKEN)
huggingface-cli download ${MODEL_PATH} --local-dir ${INPUT_CKPT_DIR_LOCAL}

+echo "Completed downloading model ${MODEL_PATH}"

if [[ $MODEL_NAME == *"llama"* ]]; then
if [[ $MODEL_NAME == "llama-2" ]]; then
TOKENIZER_PATH=/base/tokenizer.model
@@ -64,37 +72,146 @@ download_huggingface_checkpoint() {
fi
elif [[ $MODEL_NAME == *"gemma"* ]]; then
TOKENIZER_PATH=/base/tokenizer.model
if [[ $MODEL_PATH == *"gemma-2b-it-pytorch"* ]]; then
huggingface-cli download google/gemma-2b-pytorch config.json --local-dir ${INPUT_CKPT_DIR_LOCAL}
fi
else
echo -e "Unclear of tokenizer.model for ${MODEL_NAME}. May have to manually upload."
fi
}
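Note that the function reads the Hugging Face token from a file under HUGGINGFACE_TOKEN_DIR (/huggingface, per the export at the top of the script) rather than from the environment. A minimal sketch of preparing that file, with a placeholder token and a hypothetical model path:

# Sketch: stage the token file the converter expects, then fetch a checkpoint.
mkdir -p /huggingface
printf '%s' 'hf_your_token_here' > /huggingface/HUGGINGFACE_TOKEN
download_huggingface_checkpoint 'meta-llama/Llama-2-7b-hf' 'llama-2'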

+download_meta_checkpoint() {
+META_URL=$1
+MODEL_PATH=$2
+echo -e "$META_URL" | llama download --source meta --model-id $MODEL_PATH
+}
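The echo pipe feeds the signed URL to the llama CLI's prompt so the download can run non-interactively. A standalone sketch (URL shape is a placeholder; a real one comes from https://www.llama.com/llama-downloads/):

# Sketch: feed a (placeholder) signed Meta URL to the downloader.
META_URL='https://download.llamameta.net/<signed-url>'
download_meta_checkpoint "$META_URL" 'Llama-2-7b'   # model ID illustrative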

convert_maxtext_checkpoint() {
BUCKET_NAME=$1
-MODEL_NAME=$2
-VARIATION_NAME=$3
-MODEL_SIZE=$4
-MAXTEXT_VERSION=$5
-
-if [ -z $MAXTEXT_VERSION ]; then
-MAXTEXT_VERSION=jetstream-v0.2.2
+MODEL_PATH=$2
+MODEL_NAME=$3
+OUTPUT_CKPT_DIR=$4
+VERSION=$5
+HUGGINGFACE=$6
+META_URL=$7
+
+echo -e "\nbucket name=${BUCKET_NAME}"
+echo -e "\nmodel path=${MODEL_PATH}"
+echo -e "\nmodel name=${MODEL_NAME}"
+echo -e "\nversion=${VERSION}"
+echo -e "\noutput ckpt dir=${OUTPUT_CKPT_DIR}"
+echo -e "\nhuggingface=${HUGGINGFACE}"
+echo -e "\nurl=${META_URL}"
+
+if [ -z $VERSION ]; then
+VERSION=jetstream-v0.2.2
fi

git clone https://github.com/google/maxtext.git

# checkout stable MaxText commit
cd maxtext
-git checkout ${MAXTEXT_VERSION}
+git checkout ${VERSION}
python3 -m pip install -r requirements.txt

+if [[ $VERSION == "jetstream-v0.2.2" || $VERSION == "jetstream-v0.2.1" || $VERSION == "jetstream-v0.2.0" ]]; then
+pip3 install orbax-checkpoint==0.5.20
+else
+pip3 install orbax-checkpoint==0.6.0
+fi

echo -e "\nCloned MaxText repository and completed installing requirements"

-python3 MaxText/convert_gemma_chkpt.py --base_model_path gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}/${VARIATION_NAME} --maxtext_model_path gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME} --model_size ${MODEL_SIZE}
-echo -e "\nCompleted conversion of checkpoint to gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}"
+if [[ $MODEL_PATH == *"gemma"* ]]; then
+download_kaggle_checkpoint "$BUCKET_NAME" "$MODEL_PATH"
+OUTPUT_CKPT_DIR_SCANNED=gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}
+OUTPUT_CKPT_DIR_UNSCANNED=gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}
+
+python3 MaxText/convert_gemma_chkpt.py --base_model_path gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}/${VARIATION_NAME} --maxtext_model_path=${OUTPUT_CKPT_DIR_SCANNED} --model_size ${MODEL_SIZE}
+echo -e "\nCompleted conversion of checkpoint to gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}"
+
+MAXTEXT_MODEL_NAME=${MODEL_NAME}-${MODEL_SIZE}

elif [[ $MODEL_PATH == *"Llama"* ]]; then

if [ $HUGGINGFACE == "True" ]; then
echo "Checkpoint weights are from HuggingFace"
download_huggingface_checkpoint "$MODEL_PATH" "$MODEL_NAME"

RUN_NAME=0
else
echo "Checkpoint weights are from Meta, use llama CLI"

if [ -z $META_URL ]; then
echo "META_URL is empty, please provide the Meta url by visiting https://www.llama.com/llama-downloads/ and agreeing to the Terms and Conditions."
exit 1
fi
echo "META_URL: $META_URL"

INPUT_CKPT_DIR_LOCAL=/root/.llama/checkpoints/$MODEL_PATH/
download_meta_checkpoint "$META_URL" "$MODEL_PATH"
fi

python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml force_unroll=true model_name=${MODEL_NAME}-${MODEL_SIZE} async_checkpointing=false run_name=${RUN_NAME} load_parameters_path=gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}/0/items base_output_directory=gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}
echo -e "\nCompleted unscanning checkpoint to gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}/${RUN_NAME}/checkpoints/0/items"
echo "Setting model size for $MODEL_PATH"
if [[ $MODEL_NAME == "llama-2" ]]; then
if [[ $MODEL_PATH == *"7B"* ]] || [[ $MODEL_PATH == *"7b"* ]]; then
MODEL_SIZE="llama2-7b"
elif [[ $MODEL_PATH == *"13B"* ]] || [[ $MODEL_PATH == *"13b"* ]]; then
MODEL_SIZE="llama2-13b"
elif [[ $MODEL_PATH == *"70B"* ]] || [[ $MODEL_PATH == *"70b"* ]]; then
MODEL_SIZE="llama2-70b"
elif [[ $MODEL_PATH == *"405B"* ]] || [[ $MODEL_PATH == *"405b"* ]]; then
MODEL_SIZE="llama2-405b"
else
echo -e "\nUnclear llama2 model: $MODEL_PATH"
fi

elif [[ $MODEL_NAME == "llama-3" ]]; then
if [[ $MODEL_PATH == *"8B"* ]] || [[ $MODEL_PATH == *"8b"* ]]; then
MODEL_SIZE="llama3-8b"
elif [[ $MODEL_PATH == *"70B"* ]] || [[ $MODEL_PATH == *"70b"* ]]; then
MODEL_SIZE="llama3-70b"
elif [[ $MODEL_PATH == *"405B"* ]] || [[ $MODEL_PATH == *"405b"* ]]; then
MODEL_SIZE="llama3-405b"
else
echo -e "\nUnclear llama3 model: $MODEL_PATH"
fi

else
echo -e "\nUnclear llama model"
fi

echo "Model size for $MODEL_PATH is $MODEL_SIZE"

+OUTPUT_CKPT_DIR_SCANNED=${OUTPUT_CKPT_DIR}/scanned
+OUTPUT_CKPT_DIR_UNSCANNED=${OUTPUT_CKPT_DIR}/unscanned
+
+TOKENIZER_PATH=${INPUT_CKPT_DIR_LOCAL}/tokenizer.model
+
+pip3 install torch
+echo -e "\ninput dir=${INPUT_CKPT_DIR_LOCAL}"
+echo -e "\nmaxtext model path=${OUTPUT_CKPT_DIR_UNSCANNED}"
+echo -e "\nmodel path=${MODEL_PATH}"
+echo -e "\nmodel size=${MODEL_SIZE}"
+
+cd /maxtext/
+python3 MaxText/llama_ckpt_conversion_inference_only.py --base-model-path ${INPUT_CKPT_DIR_LOCAL} --maxtext-model-path ${OUTPUT_CKPT_DIR_UNSCANNED} --model-size ${MODEL_SIZE}
+echo -e "\nCompleted conversion of checkpoint to ${OUTPUT_CKPT_DIR_UNSCANNED}/0/items"
+
+gcloud storage cp ${TOKENIZER_PATH} ${OUTPUT_CKPT_DIR_UNSCANNED}
+
+touch commit_success.txt
+gcloud storage cp commit_success.txt ${OUTPUT_CKPT_DIR_UNSCANNED}/0/items
+
+else
+echo -e "\nUnclear model"
+fi
+
+if [[ $MODEL_PATH == *"gemma"* ]]; then
+RUN_NAME=0
+
+python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml force_unroll=true model_name=${MAXTEXT_MODEL_NAME} async_checkpointing=false run_name=${RUN_NAME} load_parameters_path=${OUTPUT_CKPT_DIR_SCANNED}/0/items base_output_directory=${OUTPUT_CKPT_DIR_UNSCANNED}
+echo -e "\nCompleted unscanning checkpoint to ${OUTPUT_CKPT_DIR_UNSCANNED}/${RUN_NAME}/checkpoints/0/items"
+fi
}
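After convert_maxtext_checkpoint finishes, the unscanned checkpoint directory should hold the converted weights, the tokenizer, and the commit_success.txt marker written above (the completion-marker convention Orbax uses on GCS). A sketch of a post-run sanity check (bucket and paths illustrative):

# Sketch: verify the converted artifacts landed where the server will look.
OUTPUT_CKPT_DIR_UNSCANNED=gs://my-ckpt-bucket/final/unscanned
gcloud storage ls "${OUTPUT_CKPT_DIR_UNSCANNED}/0/items/commit_success.txt"
gcloud storage ls "${OUTPUT_CKPT_DIR_UNSCANNED}/tokenizer.model"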

convert_pytorch_checkpoint() {
@@ -173,7 +290,7 @@ convert_pytorch_checkpoint() {
}


-while getopts 'b:s:m:n:h:t:q:v:i:o:' flag; do
+while getopts 'b:s:m:n:h:t:q:v:i:o:u:' flag; do
case "${flag}" in
b) BUCKET_NAME="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
s) INFERENCE_SERVER="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
@@ -185,6 +302,7 @@ while getopts 'b:s:m:n:h:t:q:v:i:o:' flag; do
v) VERSION="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
i) INPUT_DIRECTORY="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
o) OUTPUT_DIRECTORY="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
+u) META_URL="$(echo ${OPTARG} | awk -F'=' '{print $2"="$3"="$4"="$5"="$6}')" ;;
*) print_usage
exit 1 ;;
esac
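The -u handler rejoins fields because a signed Meta URL itself contains '=' characters; taking only awk's second field, as the other flags do, would truncate it. A small sketch of the behavior (URL shortened and illustrative; with fewer than six '='-separated fields the rejoin pads trailing '=' signs):

OPTARG='META_URL=https://example.net/f?Policy=A&Signature=B&Key-Pair-Id=C&ID=D'
echo "$OPTARG" | awk -F'=' '{print $2"="$3"="$4"="$5"="$6}'
# -> https://example.net/f?Policy=A&Signature=B&Key-Pair-Id=C&ID=D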
@@ -197,8 +315,7 @@ case ${INFERENCE_SERVER} in
jetstream-maxtext)
check_gsbucket "$BUCKET_NAME"
check_model_path "$MODEL_PATH"
-download_kaggle_checkpoint "$BUCKET_NAME" "$MODEL_PATH"
-convert_maxtext_checkpoint "$BUCKET_NAME" "$MODEL_NAME" "$VARIATION_NAME" "$MODEL_SIZE" "$VERSION"
+convert_maxtext_checkpoint "$BUCKET_NAME" "$MODEL_PATH" "$MODEL_NAME" "$OUTPUT_DIRECTORY" "$VERSION" "$HUGGINGFACE" "$META_URL"
;;
jetstream-pytorch)
check_model_path "$MODEL_PATH"
HTTP server (Python)
@@ -48,6 +48,33 @@ def root():
)
return response

@app.get("/healthcheck")
async def healthcheck():
try:
request = jetstream_pb2.HealthCheckRequest()

options = [("grpc.keepalive_timeout_ms", 10000)]
async with grpc.aio.insecure_channel("127.0.0.1:9000", options=options) as channel:
stub = jetstream_pb2_grpc.OrchestratorStub(channel)
response = stub.HealthCheck(request)
response = await response

if response.is_live == False:
raise fastapi.HTTPException(status_code=500, detail="Healthcheck failed, is_live = False")

is_live = {"is_live": response.is_live}
response = {"response": is_live}

response = fastapi.Response(
content=json.dumps(response, indent=4), media_type="application/json"
)
return response

except Exception as e:
logging.exception("Exception in healthcheck")
logging.exception(e)
raise fastapi.HTTPException(status_code=500, detail="Healthcheck failed")


@app.post("/generate", status_code=200)
async def generate(request: GenerateRequest):
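Once the HTTP server is running, the new endpoint can be probed directly; the server port here is illustrative, while the gRPC orchestrator address 127.0.0.1:9000 is hard-coded above:

# Sketch: expect {"response": {"is_live": true}} when the orchestrator is live.
curl -s http://localhost:8000/healthcheck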
