[Logprobs] Support logprobs=1 #2612

Merged 10 commits on Nov 1, 2023
11 changes: 9 additions & 2 deletions fastchat/protocol/openai_api_protocol.py
@@ -48,6 +48,13 @@ class UsageInfo(BaseModel):
completion_tokens: Optional[int] = 0


class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class ChatCompletionRequest(BaseModel):
model: str
messages: Union[str, List[Dict[str, str]]]
@@ -160,7 +167,7 @@ class CompletionRequest(BaseModel):
class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[int] = None
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


@@ -176,7 +183,7 @@ class CompletionResponse(BaseModel):
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[float] = None
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


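The new `LogProbs` model mirrors the `logprobs` object of OpenAI's legacy completions API and is what each response choice now carries instead of a bare number. A minimal round-trip sketch, assuming a FastChat checkout that includes this PR and the pydantic v1 API FastChat used at the time; the values are made up:

```python
# Validate a worker-style logprobs dict with the new model and serialize it.
# The numbers are illustrative, not real model output.
from fastchat.protocol.openai_api_protocol import LogProbs

payload = {
    "text_offset": [0, 4, 9, 11],                   # running character offsets
    "tokens": ["Once", " upon", " a", " time"],
    "token_logprobs": [None, -1.37, -0.52, -0.08],  # the first token has no logprob
    "top_logprobs": [{}, {}, {}, {}],               # placeholders; only logprobs=1 is supported
}
print(LogProbs(**payload).dict())
```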
47 changes: 46 additions & 1 deletion fastchat/serve/inference.py
@@ -78,6 +78,7 @@ def generate_stream(
top_p = float(params.get("top_p", 1.0))
top_k = int(params.get("top_k", -1)) # -1 means disable
max_new_tokens = int(params.get("max_new_tokens", 256))
logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1.
echo = bool(params.get("echo", True))
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
@@ -99,6 +100,8 @@
input_echo_len = len(input_ids)

if model.config.is_encoder_decoder:
if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models.
raise NotImplementedError
encoder_output = model.encoder(
input_ids=torch.as_tensor([input_ids], device=device)
)[0]
@@ -107,8 +110,11 @@
dtype=torch.int64,
device=device,
)
else:
start_ids = torch.as_tensor([input_ids], device=device)

past_key_values = out = None
token_logprobs = [None] # The first token has no logprobs.
sent_interrupt = False
finish_reason = None
for i in range(max_new_tokens):
@@ -121,9 +127,19 @@
)
logits = model.lm_head(out[0])
else:
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
out = model(input_ids=start_ids, use_cache=True)
logits = out.logits
past_key_values = out.past_key_values

if logprobs is not None:
# Prefill logprobs for the prompt.
shift_input_ids = start_ids[..., 1:].contiguous()
shift_logits = logits[..., :-1, :].contiguous()
shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
for label_id, logit in zip(
shift_input_ids[0].tolist(), shift_logits[0]
):
token_logprobs.append(logit[label_id])
else: # decoding
if model.config.is_encoder_decoder:
out = model.decoder(
@@ -173,6 +189,11 @@
tokens = [int(token) for token in indices.tolist()]
token = tokens[0]
output_ids.append(token)
if logprobs is not None:
# Cannot use last_token_logits because logprobs is based on raw logits.
token_logprobs.append(
torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
)

if token in stop_token_ids:
stopped = True
@@ -194,6 +215,28 @@
spaces_between_special_tokens=False,
clean_up_tokenization_spaces=True,
)
ret_logprobs = None
if logprobs is not None:
ret_logprobs = {
"text_offset": [],
"tokens": [
tokenizer.decode(token)
for token in (
output_ids if echo else output_ids[input_echo_len:]
)
],
"token_logprobs": token_logprobs
if echo
else token_logprobs[input_echo_len:],
"top_logprobs": [{}]
* len(token_logprobs if echo else token_logprobs[input_echo_len:]),
}
# Compute text_offset
curr_pos = 0
for text in ret_logprobs["tokens"]:
ret_logprobs["text_offset"].append(curr_pos)
curr_pos += len(text)

# TODO: Patch for incomplete sentences interrupting the output; this could be reworked into a more elegant solution.
if judge_sent_end and stopped and not is_sentence_complete(output):
if len(tokens) > 1:
@@ -231,6 +274,7 @@ generate_stream(
if not partially_stopped:
yield {
"text": output,
"logprobs": ret_logprobs,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
@@ -251,6 +295,7 @@

yield {
"text": output,
"logprobs": ret_logprobs,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
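The worker-side change above does the actual math: during the prefill step the prompt logits are shifted by one position so that position i scores token i+1, each generated token's logprob is read from the log-softmax of the final position's raw logits, and `text_offset` is the running character offset of each decoded token. A self-contained sketch of the same bookkeeping on dummy tensors; the vocabulary size, token ids, and decoded strings are all made up:

```python
# Stand-alone illustration of the logprob bookkeeping in generate_stream(),
# using random logits instead of a real model.
import torch

vocab_size = 32
input_ids = [5, 11, 7, 2]                      # prompt token ids (illustrative)
start_ids = torch.as_tensor([input_ids])

# Prefill: log-probability of each prompt token given its prefix.
logits = torch.randn(1, len(input_ids), vocab_size)   # stand-in for model output
token_logprobs = [None]                        # the first token has no logprob
shift_input_ids = start_ids[..., 1:].contiguous()
shift_logits = torch.log_softmax(logits[..., :-1, :], dim=-1).tolist()
for label_id, logit in zip(shift_input_ids[0].tolist(), shift_logits[0]):
    token_logprobs.append(logit[label_id])

# Decoding: logprob of the chosen token, taken from the raw logits of the
# last position (not from the temperature/top-p adjusted distribution).
step_logits = torch.randn(1, 1, vocab_size)    # stand-in for one decode step
token = int(torch.argmax(step_logits[0, -1, :]))       # greedy pick, for the sketch
token_logprobs.append(
    torch.log_softmax(step_logits[0, -1, :], dim=-1)[token].tolist()
)

# text_offset: running character offset of each decoded token (echo=True case).
tokens_text = ["Once", " upon", " a", " time", ","]    # stand-in for tokenizer.decode()
text_offset, curr_pos = [], 0
for text in tokens_text:
    text_offset.append(curr_pos)
    curr_pos += len(text)
print(len(token_logprobs), text_offset)
```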
14 changes: 12 additions & 2 deletions fastchat/serve/openai_api_server.py
@@ -49,6 +49,7 @@
EmbeddingsRequest,
EmbeddingsResponse,
ErrorResponse,
LogProbs,
ModelCard,
ModelList,
ModelPermission,
@@ -229,6 +230,11 @@ def process_input(model_name, inp):
return inp


def create_openai_logprobs(logprob_dict):
"""Create OpenAI-style logprobs."""
return LogProbs(**logprob_dict) if logprob_dict is not None else None


def _add_to_set(s, new_stop):
if not s:
return
@@ -250,6 +256,7 @@ async def get_gen_params(
frequency_penalty: Optional[float],
max_tokens: Optional[int],
echo: Optional[bool],
logprobs: Optional[int] = None,
stop: Optional[Union[str, List[str]]],
best_of: Optional[int] = None,
use_beam_search: Optional[bool] = None,
@@ -291,6 +298,7 @@
"model": model_name,
"prompt": prompt,
"temperature": temperature,
"logprobs": logprobs,
"top_p": top_p,
"top_k": top_k,
"presence_penalty": presence_penalty,
@@ -516,6 +524,7 @@ async def create_completion(request: CompletionRequest):
frequency_penalty=request.frequency_penalty,
presence_penalty=request.presence_penalty,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
echo=request.echo,
stop=request.stop,
best_of=request.best_of,
@@ -541,7 +550,7 @@ async def create_completion(request: CompletionRequest):
CompletionResponseChoice(
index=i,
text=content["text"],
logprobs=content.get("logprobs", None),
logprobs=create_openai_logprobs(content.get("logprobs", None)),
finish_reason=content.get("finish_reason", "stop"),
)
)
@@ -573,6 +582,7 @@ async def generate_completion_stream_generator(
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
echo=request.echo,
stop=request.stop,
)
@@ -592,7 +602,7 @@
choice_data = CompletionResponseStreamChoice(
index=i,
text=delta_text,
logprobs=content.get("logprobs", None),
logprobs=create_openai_logprobs(content.get("logprobs", None)),
finish_reason=content.get("finish_reason", None),
)
chunk = CompletionStreamResponse(
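With `logprobs` threaded through `get_gen_params` and the worker payload wrapped by `create_openai_logprobs`, the `/v1/completions` endpoint returns an OpenAI-style `logprobs` object per choice. A hedged end-to-end sketch using a raw HTTP request; the host, port, and model name are assumptions about a local deployment, not part of this PR:

```python
# Raw request to the OpenAI-compatible completions endpoint of a locally
# running openai_api_server. URL and model name are assumptions.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "vicuna-7b-v1.5",
        "prompt": "Once upon a time",
        "max_tokens": 16,
        "logprobs": 1,   # only logprobs=1 is supported (see the FIXME in inference.py)
        "echo": False,
    },
    timeout=60,
)
choice = resp.json()["choices"][0]
print(choice["text"])
print(choice["logprobs"]["tokens"])
print(choice["logprobs"]["token_logprobs"])
```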
1 change: 0 additions & 1 deletion docs/commands/test_process.md → tests/README.md
@@ -1,5 +1,4 @@
## Unit tests for FastChat
The scripts are under [FastChat/tests](../../tests).

### Test CLI Inference

18 changes: 14 additions & 4 deletions tests/test_openai_api.py
@@ -19,10 +19,14 @@ def test_list_models():
return names


def test_completion(model):
def test_completion(model, logprob):
prompt = "Once upon a time"
completion = openai.Completion.create(model=model, prompt=prompt, max_tokens=64)
print(prompt + completion.choices[0].text)
completion = openai.Completion.create(
model=model, prompt=prompt, logprobs=logprob, max_tokens=64
)
print(f"full text: {prompt + completion.choices[0].text}", flush=True)
if completion.choices[0].logprobs is not None:
print(f"logprobs: {completion.choices[0].logprobs.token_logprobs}", flush=True)


def test_completion_stream(model):
@@ -104,7 +108,13 @@ def test_openai_curl():

for model in models:
print(f"===== Test {model} ======")
test_completion(model)

if model in ["fastchat-t5-3b-v1.0"]:
logprob = None
else:
logprob = 1

test_completion(model, logprob)
test_completion_stream(model)
test_chat_completion(model)
test_chat_completion_stream(model)
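The streaming choice type also carries the new `LogProbs` object, and since the worker rebuilds the full lists on every yield, each chunk appears to carry the cumulative logprobs rather than a per-chunk delta. A sketch of consuming the stream with the pre-1.0 `openai` SDK that the tests use; the base URL, API key handling, and model name are assumptions:

```python
# Consume the streaming completions endpoint and keep the latest logprobs.
# api_base and model are assumptions for a local FastChat server.
import openai

openai.api_key = "EMPTY"                      # placeholder; assumes no key enforcement
openai.api_base = "http://localhost:8000/v1"

token_logprobs = []
for chunk in openai.Completion.create(
    model="vicuna-7b-v1.5",
    prompt="Once upon a time",
    max_tokens=32,
    logprobs=1,
    stream=True,
):
    choice = chunk.choices[0]
    print(choice.text, end="", flush=True)
    if choice.logprobs is not None:
        token_logprobs = choice.logprobs.token_logprobs   # cumulative so far
print(f"\ncollected {len(token_logprobs)} token logprobs")
```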