Conversation get_messages gets code_message (#533)
jakethekoenig authored Feb 27, 2024
1 parent 9178325 commit 7b04721
Showing 10 changed files with 120 additions and 128 deletions.
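In short: call sites across mentat previously assembled the code-message system prompt themselves (fetching it from CodeContext, counting tokens, and splicing it into their own message lists). This commit centralizes that work in Conversation.get_messages, which becomes async and gains system_prompt and include_code_message parameters, and adds a Conversation.count_tokens helper on top of it. A sketch of the resulting interface, read directly off the conversation.py diff below (bodies and docstrings elided):

    from typing import Optional

    from openai.types.chat import ChatCompletionMessageParam


    class Conversation:
        async def count_tokens(
            self,
            system_prompt: Optional[list[ChatCompletionMessageParam]] = None,
            include_code_message: bool = False,
        ) -> int: ...

        async def get_messages(
            self,
            system_prompt: Optional[list[ChatCompletionMessageParam]] = None,
            include_parsed_llm_responses: bool = False,
            include_code_message: bool = False,
        ) -> list[ChatCompletionMessageParam]: ...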
15 changes: 4 additions & 11 deletions mentat/agent_handler.py
@@ -9,7 +9,6 @@
     ChatCompletionSystemMessageParam,
 )

-from mentat.llm_api_handler import prompt_tokens
 from mentat.prompts.prompts import read_prompt
 from mentat.session_context import SESSION_CONTEXT
 from mentat.session_input import ask_yes_no, collect_user_input
@@ -42,7 +41,7 @@ async def enable_agent_mode(self):
             "Finding files to determine how to test changes...", style="info"
         )
         features = ctx.code_context.get_all_features(split_intervals=False)
-        messages: List[ChatCompletionMessageParam] = [
+        messages: list[ChatCompletionMessageParam] = [
             ChatCompletionSystemMessageParam(
                 role="system", content=self.agent_file_selection_prompt
             ),
@@ -85,21 +84,15 @@ async def _determine_commands(self) -> List[str]:
         ctx = SESSION_CONTEXT.get()

         model = ctx.config.model
-        messages = [
+        system_prompt: list[ChatCompletionMessageParam] = [
             ChatCompletionSystemMessageParam(
                 role="system", content=self.agent_command_prompt
             ),
             ChatCompletionSystemMessageParam(
                 role="system", content=self.agent_file_message
             ),
-        ] + ctx.conversation.get_messages(include_system_prompt=False)
-        code_message = await ctx.code_context.get_code_message(
-            prompt_tokens=prompt_tokens(messages, model)
-        )
-        code_message = ChatCompletionSystemMessageParam(
-            role="system", content=code_message
-        )
-        messages.insert(1, code_message)
+        ]
+        messages = await ctx.conversation.get_messages(system_prompt=system_prompt)

         try:
             # TODO: Should this even be a separate call or should we collect commands in the edit call?
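Worth noting at this call site: only system_prompt is passed, and include_code_message defaults to False, so command determination now relies on the agent_file_message built earlier rather than on a separately fetched code message (the old messages.insert(1, code_message) splice is simply gone). The resulting assembly, per the new get_messages (a sketch, assuming an active session context):

    # Agent system messages first, then conversation history; no code message here.
    messages = await ctx.conversation.get_messages(system_prompt=system_prompt)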
57 changes: 16 additions & 41 deletions mentat/code_context.py
@@ -5,8 +5,6 @@
 from pathlib import Path
 from typing import Dict, Iterable, List, Optional, Set, TypedDict, Union

-from openai.types.chat import ChatCompletionSystemMessageParam
-
 from mentat.code_feature import (
     CodeFeature,
     get_code_message_from_features,
@@ -27,12 +25,7 @@
     validate_and_format_path,
 )
 from mentat.interval import parse_intervals, split_intervals_from_path
-from mentat.llm_api_handler import (
-    count_tokens,
-    get_max_tokens,
-    prompt_tokens,
-    raise_if_context_exceeds_max,
-)
+from mentat.llm_api_handler import count_tokens, get_max_tokens
 from mentat.session_context import SESSION_CONTEXT
 from mentat.session_stream import SessionStream

@@ -68,7 +61,7 @@ def __init__(
         self.ignore_files: Set[Path] = set()
         self.auto_features: List[CodeFeature] = []

-    def refresh_context_display(self):
+    async def refresh_context_display(self):
         """
         Sends a message to the client with the code context. It is called in the main loop.
         """
@@ -87,24 +80,7 @@ def refresh_context_display(self):
         git_diff_paths = [str(p) for p in self.diff_context.diff_files()]
         git_untracked_paths = [str(p) for p in self.diff_context.untracked_files()]

-        messages = ctx.conversation.get_messages()
-        code_message = get_code_message_from_features(
-            [
-                feature
-                for file_features in self.include_files.values()
-                for feature in file_features
-            ]
-            + self.auto_features
-        )
-        total_tokens = prompt_tokens(
-            messages
-            + [
-                ChatCompletionSystemMessageParam(
-                    role="system", content="\n".join(code_message)
-                )
-            ],
-            ctx.config.model,
-        )
+        total_tokens = await ctx.conversation.count_tokens(include_code_message=True)

         total_cost = ctx.cost_tracker.total_cost

@@ -126,7 +102,6 @@ async def get_code_message(
         prompt_tokens: int,
         prompt: Optional[str] = None,
         expected_edits: Optional[list[str]] = None,  # for training/benchmarking
-        suppress_context_check: bool = False,
     ) -> str:
         """
         Retrieves the current code message.
@@ -154,29 +129,29 @@
         ]

         code_message += ["Code Files:\n"]
-        meta_tokens = count_tokens("\n".join(code_message), model, full_message=True)

-        # Calculate user included features token size
         include_features = [
             feature
             for file_features in self.include_files.values()
             for feature in file_features
         ]
-        include_files_message = get_code_message_from_features(include_features)
-        include_files_tokens = count_tokens(
-            "\n".join(include_files_message), model, full_message=False
-        )

-        tokens_used = prompt_tokens + meta_tokens + include_files_tokens
-        if not suppress_context_check:
-            raise_if_context_exceeds_max(tokens_used)
-        auto_tokens = min(
-            get_max_tokens() - tokens_used - config.token_buffer,
-            config.auto_context_tokens,
-        )

         # Get auto included features
         if config.auto_context_tokens > 0 and prompt:
+            meta_tokens = count_tokens(
+                "\n".join(code_message), model, full_message=True
+            )
+            include_files_message = get_code_message_from_features(include_features)
+            include_files_tokens = count_tokens(
+                "\n".join(include_files_message), model, full_message=False
+            )
+
+            tokens_used = prompt_tokens + meta_tokens + include_files_tokens
+            auto_tokens = min(
+                get_max_tokens() - tokens_used - config.token_buffer,
+                config.auto_context_tokens,
+            )
             features = self.get_all_features()
             feature_filter = DefaultFilter(
                 auto_tokens,
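The context-limit check (raise_if_context_exceeds_max) leaves this module entirely; it resurfaces in Conversation.get_model_response below. The budget arithmetic itself now runs only when auto-context is active. A worked example of that formula with purely illustrative numbers (none of these values come from the source):

    max_tokens = 8192            # stand-in for get_max_tokens()
    tokens_used = 6000           # prompt + meta + user-included file tokens
    token_buffer = 1000          # stand-in for config.token_buffer
    auto_context_tokens = 5000   # stand-in for config.auto_context_tokens

    # Same min() the diff computes: whatever headroom remains, capped by config
    auto_tokens = min(max_tokens - tokens_used - token_buffer, auto_context_tokens)
    assert auto_tokens == 1192   # tokens the feature filter may spend on auto context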
116 changes: 62 additions & 54 deletions mentat/conversation.py
@@ -20,6 +20,7 @@
     count_tokens,
     get_max_tokens,
     prompt_tokens,
+    raise_if_context_exceeds_max,
 )
 from mentat.parsers.file_edit import FileEdit
 from mentat.parsers.parser import ParsedLLMResponse
@@ -45,21 +46,7 @@ async def display_token_count(self):
         config = session_context.config
         code_context = session_context.code_context

-        messages = self.get_messages()
-        code_message = await code_context.get_code_message(
-            prompt_tokens(
-                messages,
-                config.model,
-            ),
-            suppress_context_check=True,
-        )
-        messages.append(
-            ChatCompletionSystemMessageParam(
-                role="system",
-                content=code_message,
-            )
-        )
-        tokens = prompt_tokens(messages, config.model)
+        tokens = await self.count_tokens(include_code_message=True)

         context_size = get_max_tokens()
         if tokens + config.token_buffer > context_size:
@@ -130,17 +117,28 @@ def add_message(self, message: ChatCompletionMessageParam):
         """Used for adding messages to the models conversation. Does not add a left-side message to the transcript!"""
         self._messages.append(message)

-    def get_messages(
+    async def count_tokens(
         self,
-        include_system_prompt: bool = True,
+        system_prompt: Optional[list[ChatCompletionMessageParam]] = None,
+        include_code_message: bool = False,
+    ) -> int:
+        _messages = await self.get_messages(
+            system_prompt=system_prompt, include_code_message=include_code_message
+        )
+        model = SESSION_CONTEXT.get().config.model
+        return prompt_tokens(_messages, model)
+
+    async def get_messages(
+        self,
+        system_prompt: Optional[list[ChatCompletionMessageParam]] = None,
         include_parsed_llm_responses: bool = False,
+        include_code_message: bool = False,
     ) -> list[ChatCompletionMessageParam]:
         """Returns the messages in the conversation. The system message may change throughout
         the conversation and messages may contain additional metadata not supported by the API,
         so it is important to access the messages through this method.
         """
-        session_context = SESSION_CONTEXT.get()
-        config = session_context.config
+        ctx = SESSION_CONTEXT.get()

         _messages = [
             (  # Remove metadata from messages by default
@@ -153,18 +151,45 @@ def get_messages(
             for msg in self._messages.copy()
         ]

-        if config.no_parser_prompt or not include_system_prompt:
-            return _messages
+        if len(_messages) > 0 and _messages[-1].get("role") == "user":
+            prompt = _messages[-1].get("content")
+            if isinstance(prompt, list):
+                text_prompts = [
+                    p.get("text", "") for p in prompt if p.get("type") == "text"
+                ]
+                prompt = " ".join(text_prompts)
         else:
-            parser = config.parser
-            prompt = parser.get_system_prompt()
-            prompt_message: ChatCompletionMessageParam = (
+            prompt = ""
+
+        if include_code_message:
+            code_message = await ctx.code_context.get_code_message(
+                prompt_tokens(_messages, ctx.config.model),
+                prompt=(
+                    prompt  # Prompt can be image as well as text
+                    if isinstance(prompt, str)
+                    else ""
+                ),
+            )
+            _messages = [
                 ChatCompletionSystemMessageParam(
                     role="system",
-                    content=prompt,
+                    content=code_message,
                 )
-            )
-            return [prompt_message] + _messages
+            ] + _messages
+
+        if system_prompt is None:
+            if ctx.config.no_parser_prompt:
+                system_prompt = []
+            else:
+                parser = ctx.config.parser
+                system_prompt = [
+                    ChatCompletionSystemMessageParam(
+                        role="system",
+                        content=parser.get_system_prompt(),
+                    )
+                ]
+
+        return system_prompt + _messages

     def clear_messages(self) -> None:
         """Clears the messages in the conversation"""
@@ -246,29 +271,10 @@ async def get_model_response(self) -> ParsedLLMResponse:
         session_context = SESSION_CONTEXT.get()
         stream = session_context.stream
         config = session_context.config
-        code_context = session_context.code_context

-        messages_snapshot = self.get_messages()
-
-        # Get current code message
-        prompt = messages_snapshot[-1].get("content")
-        if isinstance(prompt, list):
-            text_prompts = [
-                p.get("text", "") for p in prompt if p.get("type") == "text"
-            ]
-            prompt = " ".join(text_prompts)
-        code_message = await code_context.get_code_message(
-            prompt_tokens(messages_snapshot, config.model),
-            prompt=(
-                prompt  # Prompt can be image as well as text
-                if isinstance(prompt, str)
-                else ""
-            ),
-        )
-        messages_snapshot.insert(
-            0 if config.no_parser_prompt else 1,
-            ChatCompletionSystemMessageParam(role="system", content=code_message),
-        )
+        messages_snapshot = await self.get_messages(include_code_message=True)
+        tokens_used = prompt_tokens(messages_snapshot, config.model)
+        raise_if_context_exceeds_max(tokens_used)

         try:
             response = await self._stream_model_response(messages_snapshot)
@@ -282,18 +288,20 @@ async def get_model_response(self) -> ParsedLLMResponse:
             return ParsedLLMResponse("", "", list[FileEdit]())
         return response

-    def remaining_context(self) -> int | None:
+    async def remaining_context(self) -> int | None:
         ctx = SESSION_CONTEXT.get()
-        return get_max_tokens() - prompt_tokens(self.get_messages(), ctx.config.model)
+        return get_max_tokens() - prompt_tokens(
+            await self.get_messages(), ctx.config.model
+        )

-    def can_add_to_context(self, message: str) -> bool:
+    async def can_add_to_context(self, message: str) -> bool:
         """
         Whether or not the model has enough context remaining to add this message.
         Will take token buffer into account and uses full_message=True.
         """
         ctx = SESSION_CONTEXT.get()

-        remaining_context = self.remaining_context()
+        remaining_context = await self.remaining_context()
         return (
             remaining_context is not None
             and remaining_context
@@ -339,7 +347,7 @@ async def run_command(self, command: list[str]) -> bool:
         output = "".join(output)
         message = f"Command ran:\n{' '.join(command)}\nCommand output:\n{output}"

-        if self.can_add_to_context(message):
+        if await self.can_add_to_context(message):
             self.add_message(
                 ChatCompletionSystemMessageParam(role="system", content=message)
             )
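Taken together, get_messages now builds the outgoing list as system prompt, then code message (when include_code_message=True), then stored history, extracting the text of the last user message to steer auto-context selection. A minimal usage sketch (assumes an active session context inside an async function; names are from the diff above):

    ctx = SESSION_CONTEXT.get()

    # Full prompt as sent to the model: [system prompt][code message][history]
    messages = await ctx.conversation.get_messages(include_code_message=True)

    # Token accounting over the same assembly, via the new helper
    tokens = await ctx.conversation.count_tokens(include_code_message=True)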
2 changes: 1 addition & 1 deletion mentat/revisor/revisor.py
@@ -54,7 +54,7 @@ async def revise_edit(file_edit: FileEdit):
     user_message = list(
         filter(
             lambda message: message["role"] == "user",
-            ctx.conversation.get_messages(),
+            await ctx.conversation.get_messages(),
         )
     )[-1]
     user_message["content"] = f"User Request:\n{user_message.get('content')}"
3 changes: 2 additions & 1 deletion mentat/sampler/sampler.py
@@ -118,7 +118,8 @@ async def create_sample(self) -> Sample:
         message_history: list[dict[str, str]] = []
         message_prompt = ""
         response_edit: None | ParsedLLMResponse = None
-        for m in conversation.get_messages(include_parsed_llm_responses=True)[::-1]:
+        messages = await conversation.get_messages(include_parsed_llm_responses=True)
+        for m in messages[::-1]:
             response: str | ParsedLLMResponse | None = None
             role, content = m["role"], m.get("content")
             if role == "user":
2 changes: 1 addition & 1 deletion mentat/session.py
@@ -151,7 +151,7 @@ async def _main(self):
         stream.send("Type 'q' or use Ctrl-C to quit at any time.")
         need_user_request = True
         while True:
-            code_context.refresh_context_display()
+            await code_context.refresh_context_display()
             try:
                 if need_user_request:
                     # Normally, the code_file_manager pushes the edits; but when agent mode is on, we want all
2 changes: 1 addition & 1 deletion mentat/session_input.py
@@ -60,7 +60,7 @@ async def collect_input_with_commands() -> StreamMessage:
             arguments = shlex.split(" ".join(response.data.split(" ")[1:]))
             command = Command.create_command(response.data[1:].split(" ")[0])
             await command.apply(*arguments)
-            ctx.code_context.refresh_context_display()
+            await ctx.code_context.refresh_context_display()
         except ValueError as e:
             ctx.stream.send(f"Error processing command arguments: {e}", style="error")
         response = await collect_user_input(command_autocomplete=True)
3 changes: 0 additions & 3 deletions tests/code_context_test.py
@@ -8,7 +8,6 @@

 from mentat.code_context import CodeContext
 from mentat.config import Config
-from mentat.errors import ReturnToUser
 from mentat.feature_filters.default_filter import DefaultFilter
 from mentat.git_handler import get_non_gitignored_files
 from mentat.include_files import is_file_text_encoded
@@ -222,8 +221,6 @@ async def _count_max_tokens_where(tokens_used: int) -> int:
         return count_tokens(code_message, "gpt-4", full_message=True)

     assert await _count_max_tokens_where(0) == 89  # Code
-    with pytest.raises(ReturnToUser):
-        await _count_max_tokens_where(1e6)


 @pytest.mark.clear_testbed
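The dropped assertion tracks the behavior change above: get_code_message no longer raises when the context limit is exceeded; that check now lives with raise_if_context_exceeds_max in Conversation.get_model_response.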
3 changes: 2 additions & 1 deletion tests/commands_test.py
@@ -246,7 +246,8 @@ async def test_clear_command(temp_testbed, mock_collect_user_input, mock_call_ll
     await session.stream.recv(channel="client_exit")

     conversation = SESSION_CONTEXT.get().conversation
-    assert len(conversation.get_messages()) == 1
+    messages = await conversation.get_messages()
+    assert len(messages) == 1


 # TODO: test without git