Skip to content

Commit

Permalink
Support model revision and tokenizer revision in huggingface server (k…
Browse files Browse the repository at this point in the history
…serve#3558)

* support model revision and tokenizer revision

Signed-off-by: Lize Cai <lize.cai@sap.com>

* point to specified commit in test case

Signed-off-by: Lize Cai <lize.cai@sap.com>

* format code

Signed-off-by: Lize Cai <lize.cai@sap.com>

---------

Signed-off-by: Lize Cai <lize.cai@sap.com>
Signed-off-by: Dan Sun <dsun20@bloomberg.net>
Co-authored-by: Dan Sun <dsun20@bloomberg.net>
Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>
  • Loading branch information
2 people authored and tjandy98 committed Apr 10, 2024
1 parent 7f5da68 commit 5f6993c
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 11 deletions.
10 changes: 10 additions & 0 deletions python/huggingfaceserver/huggingfaceserver/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ def list_of_strings(arg):
"--model_dir", required=False, default=None, help="A local path to the model binary"
)
parser.add_argument("--model_id", required=False, help="Huggingface model id")
parser.add_argument(
"--model_revision", required=False, default=None, help="Huggingface model revision"
)
parser.add_argument(
"--tokenizer_revision",
required=False,
default=None,
help="Huggingface tokenizer revision",
)
parser.add_argument(
"--max_length", type=int, default=None, help="max sequence length for the tokenizer"
)
Expand Down Expand Up @@ -74,6 +83,7 @@ def list_of_strings(arg):
engine_args = None
if _vllm and not args.disable_vllm:
args.model = args.model_dir or args.model_id
args.revision = args.model_revision
engine_args = AsyncEngineArgs.from_cli_args(args)
predictor_config = PredictorConfig(
args.predictor_host,
Expand Down
30 changes: 19 additions & 11 deletions python/huggingfaceserver/huggingfaceserver/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def __init__(
self.model_dir = kwargs.get("model_dir", None)
if not self.model_id and not self.model_dir:
self.model_dir = "/mnt/models"
self.model_revision = kwargs.get("model_revision", None)
self.tokenizer_revision = kwargs.get("tokenizer_revision", None)
self.do_lower_case = not kwargs.get("disable_lower_case", False)
self.add_special_tokens = not kwargs.get("disable_special_tokens", False)
self.max_length = kwargs.get("max_length", None)
Expand Down Expand Up @@ -111,8 +113,7 @@ def infer_task_from_model_architecture(model_config: str):
)

@staticmethod
def infer_vllm_supported_from_model_architecture(model_config_path: str):
model_config = AutoConfig.from_pretrained(model_config_path)
def infer_vllm_supported_from_model_architecture(model_config: str):
architecture = model_config.architectures[0]
model_cls = ModelRegistry.load_model_cls(architecture)
if model_cls is None:
Expand All @@ -121,20 +122,24 @@ def infer_vllm_supported_from_model_architecture(model_config_path: str):

def load(self) -> bool:
model_id_or_path = self.model_id
revision = self.model_revision
tokenizer_revision = self.tokenizer_revision
if self.model_dir:
model_id_or_path = pathlib.Path(Storage.download(self.model_dir))
# TODO Read the mapping file, index to object name

model_config = AutoConfig.from_pretrained(model_id_or_path, revision=revision)

if self.use_vllm and self.device == torch.device("cuda"): # vllm needs gpu
if self.infer_vllm_supported_from_model_architecture(model_id_or_path):
if self.infer_vllm_supported_from_model_architecture(model_config):
logger.info("supported model by vLLM")
self.vllm_engine_args.tensor_parallel_size = torch.cuda.device_count()
self.vllm_engine = AsyncLLMEngine.from_engine_args(
self.vllm_engine_args
)
self.ready = True
return self.ready

model_config = AutoConfig.from_pretrained(model_id_or_path)

if not self.task:
self.task = self.infer_task_from_model_architecture(model_config)

Expand All @@ -154,16 +159,19 @@ def load(self) -> bool:
# https://github.com/huggingface/transformers/blob/1248f0925234f97da9eee98da2aa22f7b8dbeda1/src/transformers/generation/utils.py#L1376-L1388
self.tokenizer = AutoTokenizer.from_pretrained(
model_id_or_path,
revision=tokenizer_revision,
do_lower_case=self.do_lower_case,
device_map=self.device_map,
padding_side="left",
)
else:
self.tokenizer = AutoTokenizer.from_pretrained(
model_id_or_path,
revision=tokenizer_revision,
do_lower_case=self.do_lower_case,
device_map=self.device_map,
)

if not self.tokenizer.pad_token:
self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
logger.info(f"successfully loaded tokenizer for task: {self.task}")
Expand All @@ -172,27 +180,27 @@ def load(self) -> bool:
if not self.predictor_host:
if self.task == MLTask.sequence_classification.value:
self.model = AutoModelForSequenceClassification.from_pretrained(
model_id_or_path, device_map=self.device_map
model_id_or_path, revision=revision, device_map=self.device_map
)
elif self.task == MLTask.question_answering.value:
self.model = AutoModelForQuestionAnswering.from_pretrained(
model_id_or_path, device_map=self.device_map
model_id_or_path, revision=revision, device_map=self.device_map
)
elif self.task == MLTask.token_classification.value:
self.model = AutoModelForTokenClassification.from_pretrained(
model_id_or_path, device_map=self.device_map
model_id_or_path, revision=revision, device_map=self.device_map
)
elif self.task == MLTask.fill_mask.value:
self.model = AutoModelForMaskedLM.from_pretrained(
model_id_or_path, device_map=self.device_map
model_id_or_path, revision=revision, device_map=self.device_map
)
elif self.task == MLTask.text_generation.value:
self.model = AutoModelForCausalLM.from_pretrained(
model_id_or_path, device_map=self.device_map
model_id_or_path, revision=revision, device_map=self.device_map
)
elif self.task == MLTask.text2text_generation.value:
self.model = AutoModelForSeq2SeqLM.from_pretrained(
model_id_or_path, device_map=self.device_map
model_id_or_path, revision=revision, device_map=self.device_map
)
else:
raise ValueError(
Expand Down
28 changes: 28 additions & 0 deletions python/huggingfaceserver/huggingfaceserver/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,34 @@ def test_bert():
assert response == {"predictions": ["paris", "france"]}


def test_model_revision():
# https://huggingface.co/google-bert/bert-base-uncased
commit = "86b5e0934494bd15c9632b12f734a8a67f723594"
model = HuggingfaceModel(
"bert-base-uncased",
{
"model_id": "bert-base-uncased",
"model_revision": commit,
"tokenizer_revision": commit,
"disable_lower_case": False,
},
)
model.load()

response = asyncio.run(
model(
{
"instances": [
"The capital of France is [MASK].",
"The capital of [MASK] is paris.",
]
},
headers={},
)
)
assert response == {"predictions": ["paris", "france"]}


def test_bert_predictor_host(httpx_mock: HTTPXMock):
httpx_mock.add_response(
json={
Expand Down

0 comments on commit 5f6993c

Please sign in to comment.