diff --git a/tutorials-and-examples/inference-servers/checkpoints/Dockerfile b/tutorials-and-examples/inference-servers/checkpoints/Dockerfile
index 527d01577..57cdc8bba 100644
--- a/tutorials-and-examples/inference-servers/checkpoints/Dockerfile
+++ b/tutorials-and-examples/inference-servers/checkpoints/Dockerfile
@@ -24,7 +24,8 @@ RUN apt -y update && apt install -y google-cloud-cli
 
 RUN pip install kaggle && \
 pip install huggingface_hub[cli] && \
-pip install google-jetstream
+pip install google-jetstream && \
+pip install llama-toolchain
 
 COPY checkpoint_converter.sh /usr/bin/
 RUN chmod +x /usr/bin/checkpoint_converter.sh
diff --git a/tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh b/tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh
index d52ae35ec..270a935ab 100644
--- a/tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh
+++ b/tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh
@@ -1,5 +1,6 @@
 #!/bin/bash
 
+set -e
 export KAGGLE_CONFIG_DIR="/kaggle"
 export HUGGINGFACE_TOKEN_DIR="/huggingface"
 INFERENCE_SERVER="jetstream-maxtext"
@@ -19,6 +20,7 @@ check_gsbucket() {
   BUCKET_NAME=$1
   if [ -z $BUCKET_NAME ]; then
     echo "BUCKET_NAME is empty, please provide a GSBucket"
+    exit 1
   fi
 }
 
@@ -26,6 +28,7 @@ check_model_path() {
   MODEL_PATH=$1
   if [ -z $MODEL_PATH ]; then
     echo "MODEL_PATH is empty, please provide the model path"
+    exit 1
   fi
 }
 
@@ -49,10 +52,15 @@ download_huggingface_checkpoint() {
   MODEL_NAME=$2
 
   INPUT_CKPT_DIR_LOCAL=/base/
-  mkdir /base/
+
+  if [ ! -d "/base" ]; then
+    mkdir /base/
+  fi
 
   huggingface-cli login --token $(cat ${HUGGINGFACE_TOKEN_DIR}/HUGGINGFACE_TOKEN)
   huggingface-cli download ${MODEL_PATH} --local-dir ${INPUT_CKPT_DIR_LOCAL}
+  echo "Completed downloading model ${MODEL_PATH}"
+
   if [[ $MODEL_NAME == *"llama"* ]]; then
     if [[ $MODEL_NAME == "llama-2" ]]; then
       TOKENIZER_PATH=/base/tokenizer.model
@@ -64,37 +72,146 @@ download_huggingface_checkpoint() {
     fi
   elif [[ $MODEL_NAME == *"gemma"* ]]; then
     TOKENIZER_PATH=/base/tokenizer.model
+    if [[ $MODEL_PATH == *"gemma-2b-it-pytorch"* ]]; then
+      huggingface-cli download google/gemma-2b-pytorch config.json --local-dir ${INPUT_CKPT_DIR_LOCAL}
+    fi
   else
     echo -e "Unclear of tokenizer.model for ${MODEL_NAME}. May have to manually upload."
   fi
 }
 
+download_meta_checkpoint() {
+  META_URL=$1
+  MODEL_PATH=$2
+  echo -e "$META_URL" | llama download --source meta --model-id $MODEL_PATH
+}
+
 convert_maxtext_checkpoint() {
   BUCKET_NAME=$1
-  MODEL_NAME=$2
-  VARIATION_NAME=$3
-  MODEL_SIZE=$4
-  MAXTEXT_VERSION=$5
-
-  if [ -z $MAXTEXT_VERSION ]; then
-    MAXTEXT_VERSION=jetstream-v0.2.2
+  MODEL_PATH=$2
+  MODEL_NAME=$3
+  OUTPUT_CKPT_DIR=$4
+  VERSION=$5
+  HUGGINGFACE=$6
+  META_URL=$7
+
+  echo -e "\nbucket name=${BUCKET_NAME}"
+  echo -e "\nmodel path=${MODEL_PATH}"
+  echo -e "\nmodel name=${MODEL_NAME}"
+  echo -e "\nversion=${VERSION}"
+  echo -e "\noutput ckpt dir=${OUTPUT_CKPT_DIR}"
+  echo -e "\nhuggingface=${HUGGINGFACE}"
+  echo -e "\nurl=${META_URL}"
+
+  if [ -z $VERSION ]; then
+    VERSION=jetstream-v0.2.2
   fi
 
   git clone https://github.com/google/maxtext.git
 
   # checkout stable MaxText commit
   cd maxtext
-  git checkout ${MAXTEXT_VERSION}
+  git checkout ${VERSION}
   python3 -m pip install -r requirements.txt
+
+  if [[ $VERSION == "jetstream-v0.2.2" || $VERSION == "jetstream-v0.2.1" || $VERSION == "jetstream-v0.2.0" ]]; then
+    pip3 install orbax-checkpoint==0.5.20
+  else
+    pip3 install orbax-checkpoint==0.6.0
+  fi
+
   echo -e "\nCloned MaxText repository and completed installing requirements"
 
-  python3 MaxText/convert_gemma_chkpt.py --base_model_path gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}/${VARIATION_NAME} --maxtext_model_path gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME} --model_size ${MODEL_SIZE}
-  echo -e "\nCompleted conversion of checkpoint to gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}"
+  if [[ $MODEL_PATH == *"gemma"* ]]; then
+    download_kaggle_checkpoint "$BUCKET_NAME" "$MODEL_PATH"
+    OUTPUT_CKPT_DIR_SCANNED=gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}
+    OUTPUT_CKPT_DIR_UNSCANNED=gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}
+
+    python3 MaxText/convert_gemma_chkpt.py --base_model_path gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}/${VARIATION_NAME} --maxtext_model_path=${OUTPUT_CKPT_DIR_SCANNED} --model_size ${MODEL_SIZE}
+    echo -e "\nCompleted conversion of checkpoint to gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}"
+
+    MAXTEXT_MODEL_NAME=${MODEL_NAME}-${MODEL_SIZE}
+
+  elif [[ $MODEL_PATH == *"Llama"* ]]; then
+
+    if [ $HUGGINGFACE == "True" ]; then
+      echo "Checkpoint weights are from HuggingFace"
+      download_huggingface_checkpoint "$MODEL_PATH" "$MODEL_NAME"
 
-  RUN_NAME=0
+    else
+      echo "Checkpoint weights are from Meta, use llama CLI"
+
+      if [ -z $META_URL ]; then
+        echo "META_URL is empty, please provide the Meta url by visiting https://www.llama.com/llama-downloads/ and agreeing to the Terms and Conditions."
+        exit 1
+      fi
+      echo "META_URL: $META_URL"
+
+      INPUT_CKPT_DIR_LOCAL=/root/.llama/checkpoints/$MODEL_PATH/
+      download_meta_checkpoint "$META_URL" "$MODEL_PATH"
+    fi
 
-  python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml force_unroll=true model_name=${MODEL_NAME}-${MODEL_SIZE} async_checkpointing=false run_name=${RUN_NAME} load_parameters_path=gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}/0/items base_output_directory=gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}
-  echo -e "\nCompleted unscanning checkpoint to gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}/${RUN_NAME}/checkpoints/0/items"
+    echo "Setting model size for $MODEL_PATH"
+    if [[ $MODEL_NAME == "llama-2" ]]; then
+      if [[ $MODEL_PATH == *"7B"* ]] || [[ $MODEL_PATH == *"7b"* ]]; then
+        MODEL_SIZE="llama2-7b"
+      elif [[ $MODEL_PATH == *"13B"* ]] || [[ $MODEL_PATH == *"13b"* ]]; then
+        MODEL_SIZE="llama2-13b"
+      elif [[ $MODEL_PATH == *"70B"* ]] || [[ $MODEL_PATH == *"70b"* ]]; then
+        MODEL_SIZE="llama2-70b"
+      elif [[ $MODEL_PATH == *"405B"* ]] || [[ $MODEL_PATH == *"405b"* ]]; then
+        MODEL_SIZE="llama2-405b"
+      else
+        echo -e "\nUnclear llama2 model: $MODEL_PATH"
+      fi
+
+    elif [[ $MODEL_NAME == "llama-3" ]]; then
+      if [[ $MODEL_PATH == *"8B"* ]] || [[ $MODEL_PATH == *"8b"* ]]; then
+        MODEL_SIZE="llama3-8b"
+      elif [[ $MODEL_PATH == *"70B"* ]] || [[ $MODEL_PATH == *"70b"* ]]; then
+        MODEL_SIZE="llama3-70b"
+      elif [[ $MODEL_PATH == *"405B"* ]] || [[ $MODEL_PATH == *"405b"* ]]; then
+        MODEL_SIZE="llama3-405b"
+      else
+        echo -e "\nUnclear llama3 model: $MODEL_PATH"
+      fi
+
+    else
+      echo -e "\nUnclear llama model"
+    fi
+
+    echo "Model size for $MODEL_PATH is $MODEL_SIZE"
+
+    OUTPUT_CKPT_DIR_SCANNED=${OUTPUT_CKPT_DIR}/scanned
+    OUTPUT_CKPT_DIR_UNSCANNED=${OUTPUT_CKPT_DIR}/unscanned
+
+    TOKENIZER_PATH=${INPUT_CKPT_DIR_LOCAL}/tokenizer.model
+
+    pip3 install torch
+    echo -e "\ninput dir=${INPUT_CKPT_DIR_LOCAL}"
+    echo -e "\nmaxtext model path=${OUTPUT_CKPT_DIR_UNSCANNED}"
+    echo -e "\nmodel path=${MODEL_PATH}"
+    echo -e "\nmodel size=${MODEL_SIZE}"
+
+    cd /maxtext/
+    python3 MaxText/llama_ckpt_conversion_inference_only.py --base-model-path ${INPUT_CKPT_DIR_LOCAL} --maxtext-model-path ${OUTPUT_CKPT_DIR_UNSCANNED} --model-size ${MODEL_SIZE}
+    echo -e "\nCompleted conversion of checkpoint to ${OUTPUT_CKPT_DIR_UNSCANNED}/0/items"
+
+    gcloud storage cp ${TOKENIZER_PATH} ${OUTPUT_CKPT_DIR_UNSCANNED}
+
+    touch commit_success.txt
+    gcloud storage cp commit_success.txt ${OUTPUT_CKPT_DIR_UNSCANNED}/0/items
+
+  else
+    echo -e "\nUnclear model"
+  fi
+
+  if [[ $MODEL_PATH == *"gemma"* ]]; then
+    RUN_NAME=0
+
+    python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml force_unroll=true model_name=${MAXTEXT_MODEL_NAME} async_checkpointing=false run_name=${RUN_NAME} load_parameters_path=${OUTPUT_CKPT_DIR_SCANNED}/0/items base_output_directory=${OUTPUT_CKPT_DIR_UNSCANNED}
+    echo -e "\nCompleted unscanning checkpoint to ${OUTPUT_CKPT_DIR_UNSCANNED}/${RUN_NAME}/checkpoints/0/items"
+  fi
 }
 
 convert_pytorch_checkpoint() {
@@ -173,7 +290,7 @@
 
 }
 
-while getopts 'b:s:m:n:h:t:q:v:i:o:' flag; do
+while getopts 'b:s:m:n:h:t:q:v:i:o:u:' flag; do
   case "${flag}" in
     b) BUCKET_NAME="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
     s) INFERENCE_SERVER="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
@@ -185,6 +302,7 @@ while getopts 'b:s:m:n:h:t:q:v:i:o:' flag; do
     v) VERSION="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
     i) INPUT_DIRECTORY="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
     o) OUTPUT_DIRECTORY="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
+    u) META_URL="$(echo ${OPTARG} | awk -F'=' '{print $2"="$3"="$4"="$5"="$6}')" ;;
    *) print_usage
       exit 1 ;;
  esac
@@ -197,8 +315,7 @@ case ${INFERENCE_SERVER} in
   jetstream-maxtext)
     check_gsbucket "$BUCKET_NAME"
     check_model_path "$MODEL_PATH"
-    download_kaggle_checkpoint "$BUCKET_NAME" "$MODEL_PATH"
-    convert_maxtext_checkpoint "$BUCKET_NAME" "$MODEL_NAME" "$VARIATION_NAME" "$MODEL_SIZE" "$VERSION"
+    convert_maxtext_checkpoint "$BUCKET_NAME" "$MODEL_PATH" "$MODEL_NAME" "$OUTPUT_DIRECTORY" "$VERSION" "$HUGGINGFACE" "$META_URL"
     ;;
   jetstream-pytorch)
     check_model_path "$MODEL_PATH"
diff --git a/tutorials-and-examples/inference-servers/jetstream/http-server/http_server.py b/tutorials-and-examples/inference-servers/jetstream/http-server/http_server.py
index 9c1faa276..fa35dc03d 100644
--- a/tutorials-and-examples/inference-servers/jetstream/http-server/http_server.py
+++ b/tutorials-and-examples/inference-servers/jetstream/http-server/http_server.py
@@ -48,6 +48,33 @@ def root():
   )
   return response
 
+@app.get("/healthcheck")
+async def healthcheck():
+  try:
+    request = jetstream_pb2.HealthCheckRequest()
+
+    options = [("grpc.keepalive_timeout_ms", 10000)]
+    async with grpc.aio.insecure_channel("127.0.0.1:9000", options=options) as channel:
+      stub = jetstream_pb2_grpc.OrchestratorStub(channel)
+      response = stub.HealthCheck(request)
+      response = await response
+
+      if response.is_live == False:
+        raise fastapi.HTTPException(status_code=500, detail="Healthcheck failed, is_live = False")
+
+      is_live = {"is_live": response.is_live}
+      response = {"response": is_live}
+
+      response = fastapi.Response(
+        content=json.dumps(response, indent=4), media_type="application/json"
+      )
+      return response
+
+  except Exception as e:
+    logging.exception("Exception in healthcheck")
+    logging.exception(e)
+    raise fastapi.HTTPException(status_code=500, detail="Healthcheck failed")
+
 
 @app.post("/generate", status_code=200)
 async def generate(request: GenerateRequest):
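
Usage sketch (editor's note, appended after the patch and not part of it): one way the new converter flags and the healthcheck route might be exercised once the image is built from this Dockerfile. The bucket, model id, output path, signed Meta URL, and port below are placeholders or assumptions rather than values from this change; the signed URL must come from https://www.llama.com/llama-downloads/, and each flag value is read from the text after "=" by the getopts/awk parsing above.

# Convert Meta-distributed Llama 2 7B weights for JetStream/MaxText (all values are placeholders).
checkpoint_converter.sh -s=jetstream-maxtext -b=my-ckpt-bucket \
  -m=Llama-2-7b -n=llama-2 -h=False \
  -o=gs://my-ckpt-bucket/llama-2-7b -u='https://download.llamameta.net/<signed-url>'

# Probe the new healthcheck endpoint, assuming the JetStream HTTP server listens on port 8000.
curl http://localhost:8000/healthcheck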