From 7b047216301e79eeaa892fea0e25cac50a271d9c Mon Sep 17 00:00:00 2001
From: Jake Koenig
Date: Tue, 27 Feb 2024 08:19:12 -0800
Subject: [PATCH] Conversation get_messages gets code_message (#533)
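
Conversation.get_messages is now async and is the single place where the
full prompt is assembled: it optionally prepends the current code message
(include_code_message=True) and accepts a caller-supplied system_prompt,
replacing the removed include_system_prompt flag. A new
Conversation.count_tokens helper replaces the prompt_tokens bookkeeping
that callers previously duplicated, and the context-size check
(raise_if_context_exceeds_max) moves out of CodeContext.get_code_message
into get_model_response. Callers of get_messages, remaining_context,
can_add_to_context, and refresh_context_display must now await them.

A minimal sketch of the new call pattern (hypothetical caller; assumes an
active SESSION_CONTEXT):

    ctx = SESSION_CONTEXT.get()

    # System prompt + current code message + conversation history.
    messages = await ctx.conversation.get_messages(include_code_message=True)

    # Token accounting for that same assembled prompt.
    tokens = await ctx.conversation.count_tokens(include_code_message=True)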
""" @@ -87,24 +80,7 @@ def refresh_context_display(self): git_diff_paths = [str(p) for p in self.diff_context.diff_files()] git_untracked_paths = [str(p) for p in self.diff_context.untracked_files()] - messages = ctx.conversation.get_messages() - code_message = get_code_message_from_features( - [ - feature - for file_features in self.include_files.values() - for feature in file_features - ] - + self.auto_features - ) - total_tokens = prompt_tokens( - messages - + [ - ChatCompletionSystemMessageParam( - role="system", content="\n".join(code_message) - ) - ], - ctx.config.model, - ) + total_tokens = await ctx.conversation.count_tokens(include_code_message=True) total_cost = ctx.cost_tracker.total_cost @@ -126,7 +102,6 @@ async def get_code_message( prompt_tokens: int, prompt: Optional[str] = None, expected_edits: Optional[list[str]] = None, # for training/benchmarking - suppress_context_check: bool = False, ) -> str: """ Retrieves the current code message. @@ -154,7 +129,6 @@ async def get_code_message( ] code_message += ["Code Files:\n"] - meta_tokens = count_tokens("\n".join(code_message), model, full_message=True) # Calculate user included features token size include_features = [ @@ -162,21 +136,22 @@ async def get_code_message( for file_features in self.include_files.values() for feature in file_features ] - include_files_message = get_code_message_from_features(include_features) - include_files_tokens = count_tokens( - "\n".join(include_files_message), model, full_message=False - ) - - tokens_used = prompt_tokens + meta_tokens + include_files_tokens - if not suppress_context_check: - raise_if_context_exceeds_max(tokens_used) - auto_tokens = min( - get_max_tokens() - tokens_used - config.token_buffer, - config.auto_context_tokens, - ) # Get auto included features if config.auto_context_tokens > 0 and prompt: + meta_tokens = count_tokens( + "\n".join(code_message), model, full_message=True + ) + include_files_message = get_code_message_from_features(include_features) + include_files_tokens = count_tokens( + "\n".join(include_files_message), model, full_message=False + ) + + tokens_used = prompt_tokens + meta_tokens + include_files_tokens + auto_tokens = min( + get_max_tokens() - tokens_used - config.token_buffer, + config.auto_context_tokens, + ) features = self.get_all_features() feature_filter = DefaultFilter( auto_tokens, diff --git a/mentat/conversation.py b/mentat/conversation.py index e4cc39573..f551cc927 100644 --- a/mentat/conversation.py +++ b/mentat/conversation.py @@ -20,6 +20,7 @@ count_tokens, get_max_tokens, prompt_tokens, + raise_if_context_exceeds_max, ) from mentat.parsers.file_edit import FileEdit from mentat.parsers.parser import ParsedLLMResponse @@ -45,21 +46,7 @@ async def display_token_count(self): config = session_context.config code_context = session_context.code_context - messages = self.get_messages() - code_message = await code_context.get_code_message( - prompt_tokens( - messages, - config.model, - ), - suppress_context_check=True, - ) - messages.append( - ChatCompletionSystemMessageParam( - role="system", - content=code_message, - ) - ) - tokens = prompt_tokens(messages, config.model) + tokens = await self.count_tokens(include_code_message=True) context_size = get_max_tokens() if tokens + config.token_buffer > context_size: @@ -130,17 +117,28 @@ def add_message(self, message: ChatCompletionMessageParam): """Used for adding messages to the models conversation. 
        Does not add a left-side message to the transcript!"""
         self._messages.append(message)
 
-    def get_messages(
+    async def count_tokens(
         self,
-        include_system_prompt: bool = True,
+        system_prompt: Optional[list[ChatCompletionMessageParam]] = None,
+        include_code_message: bool = False,
+    ) -> int:
+        _messages = await self.get_messages(
+            system_prompt=system_prompt, include_code_message=include_code_message
+        )
+        model = SESSION_CONTEXT.get().config.model
+        return prompt_tokens(_messages, model)
+
+    async def get_messages(
+        self,
+        system_prompt: Optional[list[ChatCompletionMessageParam]] = None,
         include_parsed_llm_responses: bool = False,
+        include_code_message: bool = False,
     ) -> list[ChatCompletionMessageParam]:
         """Returns the messages in the conversation. The system message may change
         throughout the conversation and messages may contain additional metadata not
         supported by the API, so it is important to access the messages through this
         method.
         """
-        session_context = SESSION_CONTEXT.get()
-        config = session_context.config
+        ctx = SESSION_CONTEXT.get()
         _messages = [
             (
                 # Remove metadata from messages by default
@@ -153,18 +151,45 @@ def get_messages(
             for msg in self._messages.copy()
         ]
 
-        if config.no_parser_prompt or not include_system_prompt:
-            return _messages
+        if len(_messages) > 0 and _messages[-1].get("role") == "user":
+            prompt = _messages[-1].get("content")
+            if isinstance(prompt, list):
+                text_prompts = [
+                    p.get("text", "") for p in prompt if p.get("type") == "text"
+                ]
+                prompt = " ".join(text_prompts)
         else:
-            parser = config.parser
-            prompt = parser.get_system_prompt()
-            prompt_message: ChatCompletionMessageParam = (
+            prompt = ""
+
+        if include_code_message:
+            code_message = await ctx.code_context.get_code_message(
+                prompt_tokens(_messages, ctx.config.model),
+                prompt=(
+                    prompt  # Prompt can be image as well as text
+                    if isinstance(prompt, str)
+                    else ""
+                ),
+            )
+            _messages = [
                 ChatCompletionSystemMessageParam(
                     role="system",
-                    content=prompt,
+                    content=code_message,
                 )
-            )
-            return [prompt_message] + _messages
+            ] + _messages
+
+        if system_prompt is None:
+            if ctx.config.no_parser_prompt:
+                system_prompt = []
+            else:
+                parser = ctx.config.parser
+                system_prompt = [
+                    ChatCompletionSystemMessageParam(
+                        role="system",
+                        content=parser.get_system_prompt(),
+                    )
+                ]
+
+        return system_prompt + _messages
 
     def clear_messages(self) -> None:
         """Clears the messages in the conversation"""
@@ -246,29 +271,10 @@ async def get_model_response(self) -> ParsedLLMResponse:
         session_context = SESSION_CONTEXT.get()
         stream = session_context.stream
         config = session_context.config
-        code_context = session_context.code_context
-
-        messages_snapshot = self.get_messages()
-
-        # Get current code message
-        prompt = messages_snapshot[-1].get("content")
-        if isinstance(prompt, list):
-            text_prompts = [
-                p.get("text", "") for p in prompt if p.get("type") == "text"
-            ]
-            prompt = " ".join(text_prompts)
-        code_message = await code_context.get_code_message(
-            prompt_tokens(messages_snapshot, config.model),
-            prompt=(
-                prompt  # Prompt can be image as well as text
-                if isinstance(prompt, str)
-                else ""
-            ),
-        )
-        messages_snapshot.insert(
-            0 if config.no_parser_prompt else 1,
-            ChatCompletionSystemMessageParam(role="system", content=code_message),
-        )
+        messages_snapshot = await self.get_messages(include_code_message=True)
+        tokens_used = prompt_tokens(messages_snapshot, config.model)
+        raise_if_context_exceeds_max(tokens_used)
 
         try:
             response = await self._stream_model_response(messages_snapshot)
@@ -282,18 +288,20 @@ async def get_model_response(self) -> ParsedLLMResponse:
             return ParsedLLMResponse("", "", list[FileEdit]())
         return response
 
-    def remaining_context(self) -> int | None:
+    async def remaining_context(self) -> int | None:
         ctx = SESSION_CONTEXT.get()
-        return get_max_tokens() - prompt_tokens(self.get_messages(), ctx.config.model)
+        return get_max_tokens() - prompt_tokens(
+            await self.get_messages(), ctx.config.model
+        )
 
-    def can_add_to_context(self, message: str) -> bool:
+    async def can_add_to_context(self, message: str) -> bool:
         """
         Whether or not the model has enough context remaining to add this message.
         Will take token buffer into account and uses full_message=True.
         """
         ctx = SESSION_CONTEXT.get()
-        remaining_context = self.remaining_context()
+        remaining_context = await self.remaining_context()
         return (
             remaining_context is not None
             and remaining_context
@@ -339,7 +347,7 @@ async def run_command(self, command: list[str]) -> bool:
         output = "".join(output)
         message = f"Command ran:\n{' '.join(command)}\nCommand output:\n{output}"
 
-        if self.can_add_to_context(message):
+        if await self.can_add_to_context(message):
             self.add_message(
                 ChatCompletionSystemMessageParam(role="system", content=message)
             )
diff --git a/mentat/revisor/revisor.py b/mentat/revisor/revisor.py
index cc9686688..74a10765f 100644
--- a/mentat/revisor/revisor.py
+++ b/mentat/revisor/revisor.py
@@ -54,7 +54,7 @@ async def revise_edit(file_edit: FileEdit):
     user_message = list(
         filter(
             lambda message: message["role"] == "user",
-            ctx.conversation.get_messages(),
+            await ctx.conversation.get_messages(),
         )
     )[-1]
     user_message["content"] = f"User Request:\n{user_message.get('content')}"
diff --git a/mentat/sampler/sampler.py b/mentat/sampler/sampler.py
index e5c7cd5e8..1a1d72ab8 100644
--- a/mentat/sampler/sampler.py
+++ b/mentat/sampler/sampler.py
@@ -118,7 +118,8 @@ async def create_sample(self) -> Sample:
         message_history: list[dict[str, str]] = []
         message_prompt = ""
         response_edit: None | ParsedLLMResponse = None
-        for m in conversation.get_messages(include_parsed_llm_responses=True)[::-1]:
+        messages = await conversation.get_messages(include_parsed_llm_responses=True)
+        for m in messages[::-1]:
             response: str | ParsedLLMResponse | None = None
             role, content = m["role"], m.get("content")
             if role == "user":
diff --git a/mentat/session.py b/mentat/session.py
index a2fff5354..813e29c42 100644
--- a/mentat/session.py
+++ b/mentat/session.py
@@ -151,7 +151,7 @@ async def _main(self):
         stream.send("Type 'q' or use Ctrl-C to quit at any time.")
         need_user_request = True
         while True:
-            code_context.refresh_context_display()
+            await code_context.refresh_context_display()
             try:
                 if need_user_request:
                     # Normally, the code_file_manager pushes the edits; but when agent mode is on, we want all
diff --git a/mentat/session_input.py b/mentat/session_input.py
index b546c67c3..25b7b9b47 100644
--- a/mentat/session_input.py
+++ b/mentat/session_input.py
@@ -60,7 +60,7 @@ async def collect_input_with_commands() -> StreamMessage:
             arguments = shlex.split(" ".join(response.data.split(" ")[1:]))
             command = Command.create_command(response.data[1:].split(" ")[0])
             await command.apply(*arguments)
-            ctx.code_context.refresh_context_display()
+            await ctx.code_context.refresh_context_display()
         except ValueError as e:
             ctx.stream.send(f"Error processing command arguments: {e}", style="error")
         response = await collect_user_input(command_autocomplete=True)
diff --git a/tests/code_context_test.py b/tests/code_context_test.py
index 99fc8681f..1e00980a7 100644
--- a/tests/code_context_test.py
+++ b/tests/code_context_test.py
@@ -8,7 +8,6 @@
 from mentat.code_context import CodeContext
 from mentat.config import Config
-from mentat.errors import ReturnToUser
 from mentat.feature_filters.default_filter import DefaultFilter
 from mentat.git_handler import get_non_gitignored_files
 from mentat.include_files import is_file_text_encoded
@@ -222,8 +221,6 @@ async def _count_max_tokens_where(tokens_used: int) -> int:
         return count_tokens(code_message, "gpt-4", full_message=True)
 
     assert await _count_max_tokens_where(0) == 89  # Code
-    with pytest.raises(ReturnToUser):
-        await _count_max_tokens_where(1e6)
 
 
 @pytest.mark.clear_testbed
diff --git a/tests/commands_test.py b/tests/commands_test.py
index 3ad243ca9..a9da97d7b 100644
--- a/tests/commands_test.py
+++ b/tests/commands_test.py
@@ -246,7 +246,8 @@ async def test_clear_command(temp_testbed, mock_collect_user_input, mock_call_ll
     await session.stream.recv(channel="client_exit")
 
     conversation = SESSION_CONTEXT.get().conversation
-    assert len(conversation.get_messages()) == 1
+    messages = await conversation.get_messages()
+    assert len(messages) == 1
 
 
 # TODO: test without git
diff --git a/tests/conversation_test.py b/tests/conversation_test.py
index 6106c44b8..ba7aba088 100644
--- a/tests/conversation_test.py
+++ b/tests/conversation_test.py
@@ -1,36 +1,43 @@
+import pytest
+
+from mentat.errors import ReturnToUser
 from mentat.parsers.block_parser import BlockParser
 from mentat.parsers.replacement_parser import ReplacementParser
 from mentat.session_context import SESSION_CONTEXT
 
 
-def test_midconveration_parser_change(mock_call_llm_api):
+@pytest.mark.asyncio
+async def test_midconveration_parser_change(mock_call_llm_api):
     session_context = SESSION_CONTEXT.get()
     config = session_context.config
     conversation = session_context.conversation
 
     config.parser = "block"
-    assert (
-        conversation.get_messages()[0]["content"] == BlockParser().get_system_prompt()
-    )
+    messages = await conversation.get_messages()
+    assert messages[0]["content"] == BlockParser().get_system_prompt()
 
     config.parser = "replacement"
-    assert (
-        conversation.get_messages()[0]["content"]
-        == ReplacementParser().get_system_prompt()
-    )
+    messages = await conversation.get_messages()
+    assert messages[0]["content"] == ReplacementParser().get_system_prompt()
 
 
-def test_no_parser_prompt(mock_call_llm_api):
+@pytest.mark.asyncio
+async def test_no_parser_prompt(mock_call_llm_api):
     session_context = SESSION_CONTEXT.get()
     config = session_context.config
     conversation = session_context.conversation
 
-    assert len(conversation.get_messages()) == 1
+    messages = await conversation.get_messages(include_code_message=True)
+    assert len(messages) == 2
+    messages = await conversation.get_messages()
+    assert len(messages) == 1
 
     config.no_parser_prompt = True
-    assert len(conversation.get_messages()) == 0
+    messages = await conversation.get_messages()
+    assert len(messages) == 0
 
 
-def test_add_user_message_with_and_without_image(mock_call_llm_api):
+@pytest.mark.asyncio
+async def test_add_user_message_with_and_without_image(mock_call_llm_api):
     session_context = SESSION_CONTEXT.get()
     conversation = session_context.conversation
 
     # Test with image
     test_message = "Hello, World!"
test_image_url = "http://example.com/image.png" conversation.add_user_message(test_message, test_image_url) - messages_with_image = conversation.get_messages() + messages_with_image = await conversation.get_messages() assert len(messages_with_image) == 2 # System prompt + user message user_message_content_with_image = messages_with_image[-1]["content"] assert len(user_message_content_with_image) == 2 # Text + image @@ -50,7 +57,17 @@ def test_add_user_message_with_and_without_image(mock_call_llm_api): # Test without image conversation.clear_messages() conversation.add_user_message(test_message) - messages_without_image = conversation.get_messages() + messages_without_image = await conversation.get_messages() assert len(messages_without_image) == 2 # System prompt + user message user_message_content_without_image = messages_without_image[-1]["content"] assert user_message_content_without_image == test_message + + +@pytest.mark.asyncio +async def test_raise_if_context_exceeded(): + session_context = SESSION_CONTEXT.get() + config = session_context.config + config.maximum_context = 0 + conversation = session_context.conversation + with pytest.raises(ReturnToUser): + await conversation.get_model_response()