[Logprobs] Support logprobs=1 (#2612)
comaniac authored Nov 1, 2023
1 parent d5e4b27 commit af4dfe3
Showing 5 changed files with 81 additions and 10 deletions.
11 changes: 9 additions & 2 deletions fastchat/protocol/openai_api_protocol.py
@@ -48,6 +48,13 @@ class UsageInfo(BaseModel):
completion_tokens: Optional[int] = 0


class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class ChatCompletionRequest(BaseModel):
model: str
messages: Union[str, List[Dict[str, str]]]
@@ -160,7 +167,7 @@ class CompletionRequest(BaseModel):
class CompletionResponseChoice(BaseModel):
index: int
text: str
- logprobs: Optional[int] = None
+ logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


@@ -176,7 +183,7 @@ class CompletionResponse(BaseModel):
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
- logprobs: Optional[float] = None
+ logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


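For reference, the new LogProbs model mirrors the logprobs object of the OpenAI completions API: parallel lists of decoded tokens, their log-probabilities, their character offsets into the returned text, and (once logprobs > 1 is supported) the top alternatives per position. Below is a minimal sketch of a populated instance; the standalone copy of the class and the numeric values are illustrative only, not part of the commit.

# Illustrative sketch: values are made up; in FastChat the instance is
# built from generate_stream()'s output, not by hand.
from typing import Dict, List, Optional

from pydantic import BaseModel, Field


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


example = LogProbs(
    tokens=["Once", " upon", " a", " time"],
    token_logprobs=[None, -0.9, -0.2, -0.1],  # the first token has no preceding context
    text_offset=[0, 4, 9, 11],                # character offset of each token in the text
    top_logprobs=[{}, {}, {}, {}],            # left empty until logprobs > 1 is supported
)
print(example.json())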
47 changes: 46 additions & 1 deletion fastchat/serve/inference.py
@@ -78,6 +78,7 @@ def generate_stream(
top_p = float(params.get("top_p", 1.0))
top_k = int(params.get("top_k", -1)) # -1 means disable
max_new_tokens = int(params.get("max_new_tokens", 256))
logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1.
echo = bool(params.get("echo", True))
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
@@ -99,6 +100,8 @@
input_echo_len = len(input_ids)

if model.config.is_encoder_decoder:
if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models.
raise NotImplementedError
encoder_output = model.encoder(
input_ids=torch.as_tensor([input_ids], device=device)
)[0]
@@ -107,8 +110,11 @@
dtype=torch.int64,
device=device,
)
else:
start_ids = torch.as_tensor([input_ids], device=device)

past_key_values = out = None
token_logprobs = [None] # The first token has no logprobs.
sent_interrupt = False
finish_reason = None
for i in range(max_new_tokens):
@@ -121,9 +127,19 @@
)
logits = model.lm_head(out[0])
else:
- out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
+ out = model(input_ids=start_ids, use_cache=True)
logits = out.logits
past_key_values = out.past_key_values

if logprobs is not None:
# Prefill logprobs for the prompt.
shift_input_ids = start_ids[..., 1:].contiguous()
shift_logits = logits[..., :-1, :].contiguous()
shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
for label_id, logit in zip(
shift_input_ids[0].tolist(), shift_logits[0]
):
token_logprobs.append(logit[label_id])
else: # decoding
if model.config.is_encoder_decoder:
out = model.decoder(
@@ -173,6 +189,11 @@
tokens = [int(token) for token in indices.tolist()]
token = tokens[0]
output_ids.append(token)
if logprobs is not None:
# Cannot use last_token_logits because logprobs is based on raw logits.
token_logprobs.append(
torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
)

if token in stop_token_ids:
stopped = True
@@ -194,6 +215,28 @@
spaces_between_special_tokens=False,
clean_up_tokenization_spaces=True,
)
ret_logprobs = None
if logprobs is not None:
ret_logprobs = {
"text_offset": [],
"tokens": [
tokenizer.decode(token)
for token in (
output_ids if echo else output_ids[input_echo_len:]
)
],
"token_logprobs": token_logprobs
if echo
else token_logprobs[input_echo_len:],
"top_logprobs": [{}]
* len(token_logprobs if echo else token_logprobs[input_echo_len:]),
}
# Compute text_offset
curr_pos = 0
for text in ret_logprobs["tokens"]:
ret_logprobs["text_offset"].append(curr_pos)
curr_pos += len(text)

# TODO: Workaround for incomplete sentences interrupting the output; this should be made more elegant.
if judge_sent_end and stopped and not is_sentence_complete(output):
if len(tokens) > 1:
@@ -231,6 +274,7 @@
if not partially_stopped:
yield {
"text": output,
"logprobs": ret_logprobs,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
@@ -251,6 +295,7 @@

yield {
"text": output,
"logprobs": ret_logprobs,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
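The heart of the change is in generate_stream: during prefill, the logits at position i score token i+1, so the prompt's log-probabilities come from log-softmaxing the logits shifted left by one and indexing them with the input ids shifted right by one; each decoding step then appends the log-softmax value of the sampled token at the final position, taken from the raw logits rather than the temperature/top-p adjusted ones. A standalone sketch of the same arithmetic follows; the checkpoint name is a placeholder, and torch.gather is used here where the commit loops over zip().

# Hedged sketch, not FastChat code: reproduces the shift-and-index logprob
# computation outside generate_stream(). Any causal LM behaves the same;
# "gpt2" is only an example checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
out = model(input_ids=input_ids, use_cache=True)

# Prefill: logits[:, i] predict token i+1, so shift labels and logits by one.
token_logprobs = [None]  # the first prompt token has nothing conditioning it
shift_labels = input_ids[..., 1:]
shift_logprobs = torch.log_softmax(out.logits[..., :-1, :], dim=-1)
token_logprobs += (
    shift_logprobs[0].gather(-1, shift_labels[0].unsqueeze(-1)).squeeze(-1).tolist()
)

# One greedy decoding step: record the logprob of the token actually emitted.
last_logprobs = torch.log_softmax(out.logits[0, -1, :], dim=-1)
next_token = int(torch.argmax(last_logprobs))
token_logprobs.append(last_logprobs[next_token].item())

print(tokenizer.convert_ids_to_tokens(input_ids[0].tolist()), token_logprobs)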
14 changes: 12 additions & 2 deletions fastchat/serve/openai_api_server.py
@@ -49,6 +49,7 @@
EmbeddingsRequest,
EmbeddingsResponse,
ErrorResponse,
LogProbs,
ModelCard,
ModelList,
ModelPermission,
@@ -229,6 +230,11 @@ def process_input(model_name, inp):
return inp


def create_openai_logprobs(logprob_dict):
"""Create OpenAI-style logprobs."""
return LogProbs(**logprob_dict) if logprob_dict is not None else None


def _add_to_set(s, new_stop):
if not s:
return
@@ -250,6 +256,7 @@ async def get_gen_params(
frequency_penalty: Optional[float],
max_tokens: Optional[int],
echo: Optional[bool],
logprobs: Optional[int] = None,
stop: Optional[Union[str, List[str]]],
best_of: Optional[int] = None,
use_beam_search: Optional[bool] = None,
@@ -291,6 +298,7 @@
"model": model_name,
"prompt": prompt,
"temperature": temperature,
"logprobs": logprobs,
"top_p": top_p,
"top_k": top_k,
"presence_penalty": presence_penalty,
@@ -516,6 +524,7 @@ async def create_completion(request: CompletionRequest):
frequency_penalty=request.frequency_penalty,
presence_penalty=request.presence_penalty,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
echo=request.echo,
stop=request.stop,
best_of=request.best_of,
@@ -541,7 +550,7 @@ async def create_completion(request: CompletionRequest):
CompletionResponseChoice(
index=i,
text=content["text"],
- logprobs=content.get("logprobs", None),
+ logprobs=create_openai_logprobs(content.get("logprobs", None)),
finish_reason=content.get("finish_reason", "stop"),
)
)
@@ -573,6 +582,7 @@ async def generate_completion_stream_generator(
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
echo=request.echo,
stop=request.stop,
)
@@ -592,7 +602,7 @@
choice_data = CompletionResponseStreamChoice(
index=i,
text=delta_text,
- logprobs=content.get("logprobs", None),
+ logprobs=create_openai_logprobs(content.get("logprobs", None)),
finish_reason=content.get("finish_reason", None),
)
chunk = CompletionStreamResponse(
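On the API-server side the new field simply flows through: the request's logprobs value is forwarded via get_gen_params into the worker payload, and the logprobs dict yielded by generate_stream is wrapped into the Pydantic LogProbs model by create_openai_logprobs before being attached to the response choice. A hedged sketch of that last step with a hand-written worker payload; the dict contents are illustrative only.

# Illustrative only: "engine_output" imitates one item yielded by
# generate_stream(); the conversion itself is just LogProbs(**logprob_dict).
from fastchat.protocol.openai_api_protocol import CompletionResponseChoice, LogProbs


def create_openai_logprobs(logprob_dict):
    """Create OpenAI-style logprobs."""
    return LogProbs(**logprob_dict) if logprob_dict is not None else None


engine_output = {
    "text": " there was a cat",
    "logprobs": {
        "text_offset": [0, 6, 10, 12],
        "tokens": [" there", " was", " a", " cat"],
        "token_logprobs": [-0.8, -1.1, -0.4, -2.3],
        "top_logprobs": [{}, {}, {}, {}],
    },
    "finish_reason": "stop",
}

choice = CompletionResponseChoice(
    index=0,
    text=engine_output["text"],
    logprobs=create_openai_logprobs(engine_output.get("logprobs", None)),
    finish_reason=engine_output.get("finish_reason", "stop"),
)
print(choice.json())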
1 change: 0 additions & 1 deletion docs/commands/test_process.md → tests/README.md
@@ -1,5 +1,4 @@
## Unit tests for FastChat
- The scripts are under [FastChat/tests](../../tests).

### Test CLI Inference

18 changes: 14 additions & 4 deletions tests/test_openai_api.py
@@ -19,10 +19,14 @@ def test_list_models():
return names


- def test_completion(model):
+ def test_completion(model, logprob):
prompt = "Once upon a time"
- completion = openai.Completion.create(model=model, prompt=prompt, max_tokens=64)
- print(prompt + completion.choices[0].text)
+ completion = openai.Completion.create(
+ model=model, prompt=prompt, logprobs=logprob, max_tokens=64
+ )
+ print(f"full text: {prompt + completion.choices[0].text}", flush=True)
+ if completion.choices[0].logprobs is not None:
+ print(f"logprobs: {completion.choices[0].logprobs.token_logprobs}", flush=True)


def test_completion_stream(model):
@@ -104,7 +108,13 @@ def test_openai_curl():

for model in models:
print(f"===== Test {model} ======")
- test_completion(model)

+ if model in ["fastchat-t5-3b-v1.0"]:
+ logprob = None
+ else:
+ logprob = 1

+ test_completion(model, logprob)
test_completion_stream(model)
test_chat_completion(model)
test_chat_completion_stream(model)
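End to end, a non-streaming completion request with logprobs=1 against a running openai_api_server returns the new object on each choice, as the updated test exercises. Below is a hedged usage example with the pre-1.0 openai client that the test uses; the base URL, key handling, and model name are deployment-specific placeholders, not values from the commit.

# Assumes a FastChat controller, model worker, and openai_api_server are
# already running; adjust the address and model name for your deployment.
import openai

openai.api_key = "EMPTY"                      # a real key is typically not needed for a local server
openai.api_base = "http://localhost:8000/v1"  # openai_api_server commonly listens on port 8000

completion = openai.Completion.create(
    model="vicuna-7b-v1.5",    # placeholder model name
    prompt="Once upon a time",
    max_tokens=16,
    logprobs=1,                # per-token logprobs; values > 1 are not supported yet
    echo=True,                 # include prompt tokens (their first logprob is None)
)

lp = completion.choices[0].logprobs
for offset, tok, logprob in zip(lp.text_offset, lp.tokens, lp.token_logprobs):
    print(f"{offset:4d}  {logprob}  {tok!r}")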
