-
Notifications
You must be signed in to change notification settings - Fork 2
/
deploy_model.py
68 lines (61 loc) · 2.76 KB
/
deploy_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import subprocess
import argparse
TGI_IMAGE = 'ghcr.io/huggingface/text-generation-inference:1.1.0'


def run_docker_container(model, data_path, gpus, port, follow_logs=False, *,
                         image=TGI_IMAGE, dry_run=False):
    """Start a detached text-generation-inference container on the given GPUs.

    One shard is launched per GPU id. When more than one GPU is given, the
    container runs with ``--sharded true --num-shard <n>``; otherwise it runs
    with ``--sharded false``.

    Args:
        model: Value passed to ``--model-id`` (e.g. a hub id, or a path
            inside the container under the ``/data`` mount).
        data_path: Host directory bind-mounted at ``/data`` in the container.
        gpus: Comma-separated GPU ids, e.g. ``"0"`` or ``"0,1"``. Whitespace
            and empty entries (e.g. a trailing comma) are ignored.
        port: Host port published to container port 80.
        follow_logs: After a successful start, stream ``docker logs -f``.
        image: Docker image to run (keyword-only).
        dry_run: If true, build and return the command without running
            docker (keyword-only).

    Returns:
        The ``docker run`` argument list that was (or, in dry-run mode,
        would be) executed.

    Raises:
        ValueError: If ``gpus`` contains no GPU ids.
        subprocess.CalledProcessError: If ``docker run`` exits non-zero.
    """
    # Normalise the GPU list so "0, 1" or "0,1," behave like "0,1".
    gpu_ids = [gpu.strip() for gpu in gpus.split(',') if gpu.strip()]
    if not gpu_ids:
        raise ValueError(f'No GPU ids given in {gpus!r}; expected e.g. "0" or "0,1".')
    device_list = ','.join(gpu_ids)
    num_shards = len(gpu_ids)
    sharded = num_shards > 1
    # Same naming scheme as before ("0,1" -> "tgi_01"), so existing
    # `docker logs tgi_01` / `docker stop tgi_01` habits keep working.
    container_name = f"tgi_{''.join(gpu_ids)}"

    print(f'Deploying model from: {data_path} on GPU(s): {device_list}, port: {port}')

    docker_command = [
        'docker', 'run',
        '--rm', '-d',
        # docker parses --gpus as CSV, so the literal double quotes keep a
        # comma-separated device list in one field (shell form: '"device=0,1"').
        '--gpus', f'"device={device_list}"',
        '--shm-size', '32g',
        '-p', f'{port}:80',
        '-v', f'{data_path}:/data',
        '--name', container_name,
        image,
        '--model-id', f'{model}',
        '--sharded', 'true' if sharded else 'false',
    ]
    if sharded:
        docker_command += ['--num-shard', f'{num_shards}']
    docker_command += [
        '--max-input-length=3000',
        '--max-total-tokens=4096',
        '--max-best-of=8',
        '--max-stop-sequences=20',
        '--max-batch-prefill-tokens=4096',
    ]
    if not sharded:
        # Preserved from the original: only single-GPU launches pass this flag
        # (it was commented out in the sharded branch).
        # NOTE(review): confirm that asymmetry is intentional.
        docker_command.append('--trust-remote-code')

    if dry_run:
        return docker_command

    # check=True: stop here if the container did not start, instead of
    # tailing logs of a container that does not exist.
    subprocess.run(docker_command, check=True)
    if follow_logs:
        subprocess.run(['docker', 'logs', '-f', container_name])
    return docker_command
if __name__ == "__main__":
    # Command-line entry point: parse deployment options and launch the container.
    parser = argparse.ArgumentParser(
        description='Run a Docker container with specified GPUs and port.',
        # Show each option's default in --help (the data path default is host-specific).
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '-m', '--model', type=str, required=False, default='/data',
        help='Value passed to --model-id: a model id such as "gpt2", or a path '
             'inside the container under the /data mount, e.g. "/data/my-model"',
    )
    parser.add_argument(
        '-d', '--data-path', type=str, required=False, default='/hdd1/yzh/ckpts/',
        help='Host directory holding the checkpoints; it is mounted at /data '
             'inside the container',
    )
    parser.add_argument('-g', '--gpus', type=str, required=True,
                        help='Specify the GPUs to use, e.g., "0,1"')
    parser.add_argument('-p', '--port', type=int, required=True,
                        help='Host port to publish, e.g., 8081')
    parser.add_argument('-f', '--follow-logs', action='store_true',
                        help='Follow logs immediately after starting the container')
    args = parser.parse_args()
    run_docker_container(
        model=args.model,
        data_path=args.data_path,
        gpus=args.gpus,
        port=args.port,
        follow_logs=args.follow_logs,
    )