OpenAI o1 model compatibility #719

Merged: 2 commits, Sep 17, 2024
47 changes: 45 additions & 2 deletions autorag/nodes/generator/openai_llm.py
@@ -12,10 +12,16 @@
logger = logging.getLogger("AutoRAG")

MAX_TOKEN_DICT = { # model name : token limit
"o1-preview": 128_000,
"o1-preview-2024-09-12": 128_000,
"o1-mini": 128_000,
"o1-mini-2024-09-12": 128_000,
"gpt-4o-mini": 128_000,
"gpt-4o-mini-2024-07-18": 128_000,
"gpt-4o": 128_000,
"gpt-4o-2024-08-06": 128_000,
"gpt-4o-2024-05-13": 128_000,
"chatgpt-4o-latest": 128_000,
"gpt-4-turbo": 128_000,
"gpt-4-turbo-2024-04-09": 128_000,
"gpt-4-turbo-preview": 128_000,
@@ -82,7 +88,11 @@ def openai_llm(
kwargs.pop("n")
logger.warning("parameter n does not effective. It always set to 1.")

tokenizer = tiktoken.encoding_for_model(llm)
# TODO: fix this after updating tiktoken for the o1 model. It is not yet supported yet.
if llm.startswith("o1"):
tokenizer = tiktoken.get_encoding("o200k_base")
else:
tokenizer = tiktoken.encoding_for_model(llm)
if truncate:
max_token_size = MAX_TOKEN_DICT.get(llm) - 7 # because of chat token usage
if max_token_size is None:
@@ -99,7 +109,15 @@

    client = AsyncOpenAI(api_key=api_key)
    loop = get_event_loop()
    tasks = [get_result(prompt, client, llm, tokenizer, **kwargs) for prompt in prompts]
    if llm.startswith("o1"):
        tasks = [
            get_result_o1(prompt, client, llm, tokenizer, **kwargs)
            for prompt in prompts
        ]
    else:
        tasks = [
            get_result(prompt, client, llm, tokenizer, **kwargs) for prompt in prompts
        ]
    result = loop.run_until_complete(process_batch(tasks, batch))
    answer_result = list(map(lambda x: x[0], result))
    token_result = list(map(lambda x: x[1], result))
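For context, process_batch is AutoRAG's own batching helper and its implementation is not part of this diff. The sketch below is only an illustrative stand-in, an assumption rather than the project's code, showing the kind of semaphore-limited gather such a helper typically performs over the per-prompt coroutines built above.

# Hypothetical stand-in for AutoRAG's process_batch, for illustration only.
import asyncio
from typing import Awaitable, List, TypeVar

T = TypeVar("T")


async def process_batch(tasks: List[Awaitable[T]], batch_size: int) -> List[T]:
    # Limit how many coroutines run concurrently, then gather results in order.
    semaphore = asyncio.Semaphore(batch_size)

    async def run_one(task: Awaitable[T]) -> T:
        async with semaphore:
            return await task

    return await asyncio.gather(*[run_one(t) for t in tasks])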
@@ -132,6 +150,31 @@ async def get_result(
    return answer, tokens, logprobs


async def get_result_o1(
    prompt: str, client: AsyncOpenAI, model: str, tokenizer: Encoding, **kwargs
):
    assert model.startswith("o1"), "This function only supports o1 models."
    # o1 models currently only support temperature=1, which is also the default.
    # See https://platform.openai.com/docs/guides/reasoning for the beta limitations of o1 models.
    kwargs["temperature"] = 1
    kwargs["top_p"] = 1
    kwargs["presence_penalty"] = 0
    kwargs["frequency_penalty"] = 0
    response = await client.chat.completions.create(
        model=model,
        messages=[
            {"role": "user", "content": prompt},
        ],
        logprobs=False,
        n=1,
        **kwargs,
    )
    answer = response.choices[0].message.content
    tokens = tokenizer.encode(answer, allowed_special="all")
    pseudo_log_probs = [0.5] * len(tokens)
    return answer, tokens, pseudo_log_probs


def truncate_by_token(prompt: str, tokenizer: Encoding, max_token_size: int):
    tokens = tokenizer.encode(prompt, allowed_special="all")
    return tokenizer.decode(tokens[:max_token_size])
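As a usage note, a minimal driver for the new o1 path might look like the following sketch. It assumes tiktoken and the openai package are installed, that OPENAI_API_KEY is set in the environment, and that get_result_o1 is imported from autorag.nodes.generator.openai_llm as added above; the model name and prompt are placeholders.

import asyncio
import os

import tiktoken
from openai import AsyncOpenAI

from autorag.nodes.generator.openai_llm import get_result_o1


async def main() -> None:
    client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
    # o1 model names are not yet in tiktoken's model map, so use o200k_base directly.
    tokenizer = tiktoken.get_encoding("o200k_base")
    answer, tokens, pseudo_log_probs = await get_result_o1(
        "Explain retrieval-augmented generation in two sentences.",
        client,
        "o1-mini",
        tokenizer,
    )
    print(answer)
    print(len(tokens), pseudo_log_probs[:3])  # pseudo log probs are constant 0.5 placeholders


if __name__ == "__main__":
    asyncio.run(main())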
2 changes: 1 addition & 1 deletion requirements.txt
@@ -37,7 +37,7 @@ llama-index-readers-file
llama-index-embeddings-openai
llama-index-embeddings-huggingface
# LLMs
llama-index-llms-openai>=0.1.26
llama-index-llms-openai>=0.2.7
llama-index-llms-huggingface
llama-index-llms-openai-like
llama-index-llms-ollama
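The llama-index-llms-openai bump to >=0.2.7 presumably tracks upstream handling of the o1 model names. A hedged sketch of constructing such a model through llama-index is shown below; the model name and prompt are placeholders, and whether 0.2.7 accepts o1 names is an assumption not verified here.

# Illustrative only; assumes llama-index-llms-openai>=0.2.7 recognizes o1 model names.
import os

from llama_index.llms.openai import OpenAI

llm = OpenAI(model="o1-mini", api_key=os.environ["OPENAI_API_KEY"])
response = llm.complete("Summarize why o1 models restrict sampling parameters.")
print(response.text)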