Skip to content

Commit

Permalink
Update multi adapter lora notebook (#517)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored Oct 2, 2024
1 parent 1e0451f commit bbc192c
Showing 1 changed file with 10 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@
"source": [
"%%writefile lora-multi-adapter/serving.properties\n",
"option.model_id=huggyllama/llama-7b\n",
"option.engine=Python\n",
"option.rolling_batch=vllm\n",
"option.engine=MPI\n",
"option.rolling_batch=lmi-dist\n",
"option.tensor_parallel_degree=1\n",
"option.enable_lora=true\n",
"option.gpu_memory_utilization=0.8"
Expand Down Expand Up @@ -187,7 +187,7 @@
"metadata": {},
"outputs": [],
"source": [
"role = \"arn:aws:iam::125045733377:role/AmazonSageMaker-ExecutionRole-djl\" # execution role for the endpoint\n",
"role = sagemaker.get_execution_role() # execution role for the endpoint\n",
"sess = sagemaker.session.Session() # sagemaker session for interacting with different AWS APIs\n",
"bucket = sess.default_bucket() # bucket to house artifacts\n",
"model_bucket = sess.default_bucket() # bucket to house artifacts\n",
Expand Down Expand Up @@ -219,13 +219,12 @@
"outputs": [],
"source": [
"inference_image_uri = image_uris.retrieve(\n",
" framework=\"djl-deepspeed\",\n",
" framework=\"djl-lmi\",\n",
" region=region,\n",
" version=\"0.27.0\"\n",
" )",
" version=\"0.29.0\"\n",
" )\n",
"model_name_acc = name_from_base(f\"lora-multi-adapter\")\n",
"\n",
"# LoRA Adapters feature is a preview feature and ENABLE_ADAPTERS_PREVIEW environmnet variable should be set to use it\n",
"create_model_response = sm_client.create_model(\n",
" ModelName=model_name_acc,\n",
" ExecutionRoleArn=role,\n",
Expand Down Expand Up @@ -324,8 +323,7 @@
"\n",
"response_model = smr_client.invoke_endpoint(\n",
" EndpointName=endpoint_name,\n",
" Body=json.dumps({\"inputs\": [\"Tell me about Alpacas\", \"Invente uma desculpa criativa pra dizer que não preciso ir à festa.\", \"Tell me about AWS\"],\n",
" \"adapters\": [\"eng_alpaca\", \"portuguese_alpaca\", \"eng_alpaca\"]}),\n",
" Body=json.dumps({\"inputs\": \"Tell me about Alpacas\", \"adapters\": \"eng_alpaca\"}),\n",
" ContentType=\"application/json\",\n",
")\n",
"\n",
Expand Down Expand Up @@ -362,9 +360,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "conda_python3",
"language": "python",
"name": "python3"
"name": "conda_python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -376,7 +374,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down

0 comments on commit bbc192c

Please sign in to comment.