From d9065227bd61359e4ab71a499c33ace7f61520f8 Mon Sep 17 00:00:00 2001 From: vvycaaa <147325516+vvycaaa@users.noreply.github.com> Date: Fri, 22 Dec 2023 09:42:58 +0800 Subject: [PATCH] fix(core): Move the last user's information to the end (#960) --- dbgpt/core/interface/message.py | 13 +++-- dbgpt/core/interface/tests/test_message.py | 57 ++++++++++++++++++++++ dbgpt/model/proxy/llms/bard.py | 13 +++-- dbgpt/model/proxy/llms/chatgpt.py | 13 +++-- 4 files changed, 75 insertions(+), 21 deletions(-) mode change 100644 => 100755 dbgpt/core/interface/message.py mode change 100644 => 100755 dbgpt/core/interface/tests/test_message.py mode change 100644 => 100755 dbgpt/model/proxy/llms/bard.py mode change 100644 => 100755 dbgpt/model/proxy/llms/chatgpt.py diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py old mode 100644 new mode 100755 index bd06f0dc7..2b1439c6d --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -157,14 +157,13 @@ def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]: else: pass # Move the last user's information to the end - temp_his = history[::-1] - last_user_input = None - for m in temp_his: - if m["role"] == "user": - last_user_input = m + last_user_input_index = None + for i in range(len(history) - 1, -1, -1): + if history[i]["role"] == "user": + last_user_input_index = i break - if last_user_input: - history.remove(last_user_input) + if last_user_input_index: + last_user_input = history.pop(last_user_input_index) history.append(last_user_input) return history diff --git a/dbgpt/core/interface/tests/test_message.py b/dbgpt/core/interface/tests/test_message.py old mode 100644 new mode 100755 index 425f268af..41f5f36c5 --- a/dbgpt/core/interface/tests/test_message.py +++ b/dbgpt/core/interface/tests/test_message.py @@ -67,6 +67,23 @@ def conversation_with_messages(): return conv +@pytest.fixture +def human_model_message(): + return ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello") + + +@pytest.fixture +def ai_model_message(): + return ModelMessage(role=ModelMessageRoleType.AI, content="Hi there") + + +@pytest.fixture +def system_model_message(): + return ModelMessage( + role=ModelMessageRoleType.SYSTEM, content="You are a helpful chatbot!" + ) + + def test_init(basic_conversation): assert basic_conversation.chat_mode == "chat_normal" assert basic_conversation.user_name == "user1" @@ -305,3 +322,43 @@ def test_load_from_storage(storage_conversation, in_memory_storage): assert new_conversation.messages[1].content == "AI response" assert isinstance(new_conversation.messages[0], HumanMessage) assert isinstance(new_conversation.messages[1], AIMessage) + + +def test_to_openai_messages( + human_model_message, ai_model_message, system_model_message +): + none_messages = ModelMessage.to_openai_messages([]) + assert none_messages == [] + + single_messages = ModelMessage.to_openai_messages([human_model_message]) + assert single_messages == [{"role": "user", "content": human_model_message.content}] + + normal_messages = ModelMessage.to_openai_messages( + [ + system_model_message, + human_model_message, + ai_model_message, + human_model_message, + ] + ) + assert normal_messages == [ + {"role": "system", "content": system_model_message.content}, + {"role": "user", "content": human_model_message.content}, + {"role": "assistant", "content": ai_model_message.content}, + {"role": "user", "content": human_model_message.content}, + ] + + shuffle_messages = ModelMessage.to_openai_messages( + [ + system_model_message, + human_model_message, + human_model_message, + ai_model_message, + ] + ) + assert shuffle_messages == [ + {"role": "system", "content": system_model_message.content}, + {"role": "user", "content": human_model_message.content}, + {"role": "assistant", "content": ai_model_message.content}, + {"role": "user", "content": human_model_message.content}, + ] diff --git a/dbgpt/model/proxy/llms/bard.py b/dbgpt/model/proxy/llms/bard.py old mode 100644 new mode 100755 index 5b6ed26a0..fc398fe8b --- a/dbgpt/model/proxy/llms/bard.py +++ b/dbgpt/model/proxy/llms/bard.py @@ -25,14 +25,13 @@ def bard_generate_stream( else: pass - temp_his = history[::-1] - last_user_input = None - for m in temp_his: - if m["role"] == "user": - last_user_input = m + last_user_input_index = None + for i in range(len(history) - 1, -1, -1): + if history[i]["role"] == "user": + last_user_input_index = i break - if last_user_input: - history.remove(last_user_input) + if last_user_input_index: + last_user_input = history.pop(last_user_input_index) history.append(last_user_input) msgs = [] diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py old mode 100644 new mode 100755 index e9229da44..d81626e7a --- a/dbgpt/model/proxy/llms/chatgpt.py +++ b/dbgpt/model/proxy/llms/chatgpt.py @@ -110,14 +110,13 @@ def _build_request(model: ProxyModel, params): pass # Move the last user's information to the end - temp_his = history[::-1] - last_user_input = None - for m in temp_his: - if m["role"] == "user": - last_user_input = m + last_user_input_index = None + for i in range(len(history) - 1, -1, -1): + if history[i]["role"] == "user": + last_user_input_index = i break - if last_user_input: - history.remove(last_user_input) + if last_user_input_index: + last_user_input = history.pop(last_user_input_index) history.append(last_user_input) payloads = {