Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add load/save context commands #534

Merged
merged 21 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/source/user/commands.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ Show information on available commands.

Add files to context.

/load [context file path]
------------

Load context from a file.

/save [context file path]
------------

Save context to a file.

/redo
-----

Expand Down
26 changes: 26 additions & 0 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,3 +434,29 @@ async def search(
return all_features_sorted
else:
return all_features_sorted[:max_results]

galer7 marked this conversation as resolved.
Show resolved Hide resolved
galer7 marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
simple_dict[str(path.absolute())] = [

This change ensures compatibility with different operating systems and file path formats.

def to_simple_context_dict(self) -> dict[str, list[str]]:
"""Return a simple dictionary representation of the code context"""

simple_dict: dict[str, list[str]] = {}
for path, features in self.include_files.items():
simple_dict[str(path.absolute())] = [str(feature) for feature in features]
return simple_dict

def from_simple_context_dict(self, simple_dict: dict[str, list[str]]):
"""Load the code context from a simple dictionary representation"""

for path_str, features_str in simple_dict.items():
path = Path(path_str)
features_for_path: List[CodeFeature] = []

for feature_str in features_str:
feature_path = Path(feature_str)

# feature_path is already absolute, so cwd doesn't matter
current_features = get_code_features_for_path(
feature_path, cwd=Path("/")
)
features_for_path += list(current_features)

self.include_files[path] = features_for_path
2 changes: 2 additions & 0 deletions mentat/command/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from .exclude import ExcludeCommand
from .help import HelpCommand
from .include import IncludeCommand
from .load import LoadCommand
galer7 marked this conversation as resolved.
Show resolved Hide resolved
galer7 marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's great to see the inclusion of the load and save commands in the __init__.py file, ensuring they are recognized and can be utilized within the application.

from .redo import RedoCommand
from .run import RunCommand
from .sample import SampleCommand
from .save import SaveCommand
from .screenshot import ScreenshotCommand
from .search import SearchCommand
from .talk import TalkCommand
Expand Down
69 changes: 69 additions & 0 deletions mentat/command/commands/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import json
from pathlib import Path
from typing import List

from typing_extensions import override

from mentat.auto_completer import get_command_filename_completions
from mentat.command.command import Command, CommandArgument
from mentat.errors import PathValidationError
from mentat.session_context import SESSION_CONTEXT
from mentat.utils import mentat_dir_path


class LoadCommand(Command, command_name="load"):
@override
async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
code_context = session_context.code_context
context_file_path = mentat_dir_path / "context.json"

if len(args) > 1:
stream.send(
"Only one context file can be loaded at a time", style="warning"
)
galer7 marked this conversation as resolved.
Show resolved Hide resolved
return

if args:
path_arg = args[0]
try:
context_file_path = Path(path_arg).expanduser().resolve()
except RuntimeError as e:
raise PathValidationError(
f"Invalid context file path provided: {path_arg}: {e}"
)

try:
galer7 marked this conversation as resolved.
Show resolved Hide resolved
with open(context_file_path, "r") as file:
parsed_include_files = json.load(file)
except FileNotFoundError:
stream.send(f"Context file not found at {context_file_path}", style="error")
return
except json.JSONDecodeError as e:
stream.send(
f"Failed to parse context file at {context_file_path}: {e}",
style="error",
)
return

code_context.from_simple_context_dict(parsed_include_files)

stream.send(f"Context loaded from {context_file_path}", style="success")

@override
@classmethod
def arguments(cls) -> List[CommandArgument]:
return [CommandArgument("optional", ["path"])]

@override
@classmethod
def argument_autocompletions(
cls, arguments: list[str], argument_position: int
) -> list[str]:
return get_command_filename_completions(arguments[-1])

@override
@classmethod
def help_message(cls) -> str:
return "Loads a context file."
56 changes: 56 additions & 0 deletions mentat/command/commands/save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import json
from pathlib import Path
from typing import List

from typing_extensions import override

from mentat.auto_completer import get_command_filename_completions
from mentat.command.command import Command, CommandArgument
from mentat.errors import PathValidationError
from mentat.session_context import SESSION_CONTEXT
from mentat.utils import mentat_dir_path


class SaveCommand(Command, command_name="save"):
@override
async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
code_context = session_context.code_context
context_file_path = mentat_dir_path / "context.json"

if len(args) > 1:
stream.send("Only one context file can be saved at a time", style="warning")
return

if len(args) == 1:
try:
context_file_path = Path(args[0]).expanduser().resolve()
except RuntimeError as e:
raise PathValidationError(
f"Invalid context file path provided: {args[0]}: {e}"
)

serializable_context = code_context.to_simple_context_dict()

with open(context_file_path, "w") as file:
json.dump(serializable_context, file)

stream.send(f"Context saved to {context_file_path}", style="success")

@override
@classmethod
def arguments(cls) -> List[CommandArgument]:
return [CommandArgument("optional", ["path"])]

@override
@classmethod
def argument_autocompletions(
cls, arguments: list[str], argument_position: int
) -> list[str]:
return get_command_filename_completions(arguments[-1])

@override
@classmethod
def help_message(cls) -> str:
return "Saves the current context to a file."
15 changes: 15 additions & 0 deletions mentat/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,18 @@ def __str__(self) -> str:
return f"{self.start}"
else:
return f"{self.start}-{self.end}"

@staticmethod
def from_string(interval_string: str) -> Interval:
print("interval_string:", interval_string)
try:
interval_parts = interval_string.split("-")
except ValueError:
return Interval(1, INTERVAL_FILE_END)

if len(interval_parts) != 2:
# corrupt interval string, make it whole file
return Interval(1, INTERVAL_FILE_END)

start, end = interval_parts
return Interval(int(start), int(end))
4 changes: 2 additions & 2 deletions mentat/session_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class SessionStream:
default: Any data sent to the client over this channel should be displayed. Valid kwargs: color, style

*session_exit: Sent by the client, suggesting that the session should exit whenever possible.
client_exit: Sent by the server, client should shut down when recieved.
session_stopped: Sent by the server directly before server shuts down. Server can't be contacted after recieved.
client_exit: Sent by the server, client should shut down when received.
galer7 marked this conversation as resolved.
Show resolved Hide resolved
session_stopped: Sent by the server directly before server shuts down. Server can't be contacted after received.

loading: Used to tell the client to display a loading bar. Valid kwargs: terminate

Expand Down
100 changes: 100 additions & 0 deletions tests/commands_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import subprocess
from pathlib import Path
from textwrap import dedent
Expand All @@ -7,6 +8,7 @@
from mentat.code_feature import CodeFeature
from mentat.command.command import Command, InvalidCommand
from mentat.command.commands.help import HelpCommand
from mentat.interval import Interval
from mentat.session import Session
from mentat.session_context import SESSION_CONTEXT

Expand Down Expand Up @@ -80,6 +82,104 @@ async def test_exclude_command(temp_testbed, mock_collect_user_input):
assert not code_context.include_files


@pytest.mark.asyncio
async def test_save_command(temp_testbed, mock_collect_user_input):
default_context_path = "context.json"
mock_collect_user_input.set_stream_messages(
[
"/include scripts",
f"/save {default_context_path}",
"q",
]
)

session = Session(cwd=temp_testbed)
session.start()
await session.stream.recv(channel="client_exit")

saved_code_context: dict[str, list[str]] = json.load(open(default_context_path))
calculator_script_path = Path(temp_testbed) / "scripts" / "calculator.py"
assert str(calculator_script_path) in (saved_code_context.keys())
assert [str(calculator_script_path)] in (saved_code_context.values())


@pytest.mark.asyncio
async def test_load_command_success(temp_testbed, mock_collect_user_input):
scripts_dir = Path(temp_testbed) / "scripts"
features = [
CodeFeature(scripts_dir / "calculator.py", Interval(1, 10)),
CodeFeature(scripts_dir / "echo.py"),
]
context_file_path = "context.json"

context_file_data = {}
with open(context_file_path, "w") as f:
for feature in features:
context_file_data[str(feature.path)] = [str(feature)]

json.dump(context_file_data, f)

mock_collect_user_input.set_stream_messages(
[
f"/load {context_file_path}",
"q",
]
)

session = Session(cwd=temp_testbed)
session.start()
await session.stream.recv(channel="client_exit")

code_context = SESSION_CONTEXT.get().code_context

assert scripts_dir / "calculator.py" in code_context.include_files.keys()
assert code_context.include_files[scripts_dir / "calculator.py"] == [
CodeFeature(scripts_dir / "calculator.py", Interval(1, 10)),
]
assert scripts_dir / "echo.py" in code_context.include_files.keys()
assert code_context.include_files[scripts_dir / "echo.py"] == [
CodeFeature(scripts_dir / "echo.py"),
]


@pytest.mark.asyncio
async def test_load_command_file_not_found(temp_testbed, mock_collect_user_input):
context_file_path = "context-f47e7a1c-84a2-40a9-9e40-255f976d3223.json"

mock_collect_user_input.set_stream_messages(
[
f"/load {context_file_path}",
"q",
]
)

session = Session(cwd=temp_testbed)
session.start()
await session.stream.recv(channel="client_exit")

assert "Context file not found" in session.stream.messages[4].data


@pytest.mark.asyncio
async def test_load_command_invalid_json(temp_testbed, mock_collect_user_input):
context_file_path = "context.json"
with open(context_file_path, "w") as f:
f.write("invalid json")

mock_collect_user_input.set_stream_messages(
[
f"/load {context_file_path}",
"q",
]
)

session = Session(cwd=temp_testbed)
session.start()
await session.stream.recv(channel="client_exit")

assert "Failed to parse context file" in session.stream.messages[4].data


@pytest.mark.asyncio
async def test_undo_command(temp_testbed, mock_collect_user_input, mock_call_llm_api):
temp_file_name = "temp.py"
Expand Down
Loading