Skip to content

Commit

Permalink
fix(core): Move the last user's information to the end (#960)
Browse files Browse the repository at this point in the history
  • Loading branch information
vvycaaa authored Dec 22, 2023
1 parent 6b982e2 commit d906522
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 21 deletions.
13 changes: 6 additions & 7 deletions dbgpt/core/interface/message.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,13 @@ def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
else:
pass
# Move the last user's information to the end
temp_his = history[::-1]
last_user_input = None
for m in temp_his:
if m["role"] == "user":
last_user_input = m
last_user_input_index = None
for i in range(len(history) - 1, -1, -1):
if history[i]["role"] == "user":
last_user_input_index = i
break
if last_user_input:
history.remove(last_user_input)
if last_user_input_index:
last_user_input = history.pop(last_user_input_index)
history.append(last_user_input)
return history

Expand Down
57 changes: 57 additions & 0 deletions dbgpt/core/interface/tests/test_message.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,23 @@ def conversation_with_messages():
return conv


@pytest.fixture
def human_model_message():
return ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")


@pytest.fixture
def ai_model_message():
return ModelMessage(role=ModelMessageRoleType.AI, content="Hi there")


@pytest.fixture
def system_model_message():
return ModelMessage(
role=ModelMessageRoleType.SYSTEM, content="You are a helpful chatbot!"
)


def test_init(basic_conversation):
assert basic_conversation.chat_mode == "chat_normal"
assert basic_conversation.user_name == "user1"
Expand Down Expand Up @@ -305,3 +322,43 @@ def test_load_from_storage(storage_conversation, in_memory_storage):
assert new_conversation.messages[1].content == "AI response"
assert isinstance(new_conversation.messages[0], HumanMessage)
assert isinstance(new_conversation.messages[1], AIMessage)


def test_to_openai_messages(
human_model_message, ai_model_message, system_model_message
):
none_messages = ModelMessage.to_openai_messages([])
assert none_messages == []

single_messages = ModelMessage.to_openai_messages([human_model_message])
assert single_messages == [{"role": "user", "content": human_model_message.content}]

normal_messages = ModelMessage.to_openai_messages(
[
system_model_message,
human_model_message,
ai_model_message,
human_model_message,
]
)
assert normal_messages == [
{"role": "system", "content": system_model_message.content},
{"role": "user", "content": human_model_message.content},
{"role": "assistant", "content": ai_model_message.content},
{"role": "user", "content": human_model_message.content},
]

shuffle_messages = ModelMessage.to_openai_messages(
[
system_model_message,
human_model_message,
human_model_message,
ai_model_message,
]
)
assert shuffle_messages == [
{"role": "system", "content": system_model_message.content},
{"role": "user", "content": human_model_message.content},
{"role": "assistant", "content": ai_model_message.content},
{"role": "user", "content": human_model_message.content},
]
13 changes: 6 additions & 7 deletions dbgpt/model/proxy/llms/bard.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@ def bard_generate_stream(
else:
pass

temp_his = history[::-1]
last_user_input = None
for m in temp_his:
if m["role"] == "user":
last_user_input = m
last_user_input_index = None
for i in range(len(history) - 1, -1, -1):
if history[i]["role"] == "user":
last_user_input_index = i
break
if last_user_input:
history.remove(last_user_input)
if last_user_input_index:
last_user_input = history.pop(last_user_input_index)
history.append(last_user_input)

msgs = []
Expand Down
13 changes: 6 additions & 7 deletions dbgpt/model/proxy/llms/chatgpt.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,13 @@ def _build_request(model: ProxyModel, params):
pass

# Move the last user's information to the end
temp_his = history[::-1]
last_user_input = None
for m in temp_his:
if m["role"] == "user":
last_user_input = m
last_user_input_index = None
for i in range(len(history) - 1, -1, -1):
if history[i]["role"] == "user":
last_user_input_index = i
break
if last_user_input:
history.remove(last_user_input)
if last_user_input_index:
last_user_input = history.pop(last_user_input_index)
history.append(last_user_input)

payloads = {
Expand Down

0 comments on commit d906522

Please sign in to comment.