Skip to content

Commit

Permalink
fix: correct context loading from session new/overwrite and resume (#180
Browse files Browse the repository at this point in the history
)
  • Loading branch information
lamchau authored Oct 24, 2024
1 parent e7f7434 commit 2f33514
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 24 deletions.
14 changes: 10 additions & 4 deletions src/goose/cli/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from pathlib import Path
from typing import Optional

from langfuse.decorators import langfuse_context
from exchange import Message, Text, ToolResult, ToolUse
from exchange.langfuse_wrapper import observe_wrapper, auth_check
from exchange.langfuse_wrapper import auth_check, observe_wrapper
from langfuse.decorators import langfuse_context
from rich import print
from rich.markdown import Markdown
from rich.panel import Panel
Expand All @@ -21,7 +21,7 @@
from goose.utils import droid, load_plugins
from goose.utils._cost_calculator import get_total_cost_message
from goose.utils._create_exchange import create_exchange
from goose.utils.session_file import is_empty_session, is_existing_session, read_or_create_file, log_messages
from goose.utils.session_file import is_empty_session, is_existing_session, log_messages, read_or_create_file

RESUME_MESSAGE = "I see we were interrupted. How can I help you?"

Expand Down Expand Up @@ -286,9 +286,15 @@ def _prompt_overwrite_session(self) -> None:
print(f"[yellow]Session already exists at {self.session_file_path}.[/]")

choice = OverwriteSessionPrompt.ask("Enter your choice", show_choices=False)
# during __init__ we load the previous context, so we need to
# explicitly clear it
self.exchange.messages.clear()

match choice:
case "y" | "yes":
print("Overwriting existing session")
with open(self.session_file_path, "w") as f:
f.write("")

case "n" | "no":
while True:
Expand All @@ -299,7 +305,7 @@ def _prompt_overwrite_session(self) -> None:
print(f"[yellow]Session '{new_session_name}' already exists[/]")

case "r" | "resume":
self.exchange.messages.extend(self.load_session())
self.exchange.messages.extend(self._get_initial_messages())

def _remove_empty_session(self) -> bool:
"""
Expand Down
146 changes: 126 additions & 20 deletions tests/cli/test_session.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from unittest.mock import MagicMock, patch
import os
from typing import Union
from unittest.mock import MagicMock, mock_open, patch

import pytest
from exchange import Message, ToolResult, ToolUse
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt
from goose.cli.prompt.user_input import PromptAction, UserInput
from goose.cli.session import Session
from prompt_toolkit import PromptSession
Expand All @@ -11,27 +14,70 @@
SESSION_NAME = "test"


@pytest.fixture(scope="module", autouse=True)
def set_openai_api_key():
key = "OPENAI_API_KEY"
value = "test_api_key"

original_api_key = os.environ.get(key)
os.environ[key] = value

yield

if original_api_key is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_api_key


@pytest.fixture
def mock_specified_session_name():
with patch.object(PromptSession, "prompt", return_value=SPECIFIED_SESSION_NAME) as specified_session_name:
yield specified_session_name
@patch.object(PromptSession, "prompt", return_value=SPECIFIED_SESSION_NAME)
def mock_specified_session_name(specified_session_name):
yield specified_session_name


@pytest.fixture
def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory):
with (
patch("goose.cli.session.create_exchange") as mock_exchange,
patch("goose.cli.session.load_profile", return_value=profile_factory()),
patch("goose.cli.session.SessionNotifier") as mock_session_notifier,
patch("goose.cli.session.load_provider", return_value="provider"),
):
mock_session_notifier.return_value = MagicMock()
mock_exchange.return_value = exchange_factory()
@patch("goose.cli.session.create_exchange", name="mock_exchange")
@patch("goose.cli.session.load_profile", name="mock_load_profile")
@patch("goose.cli.session.SessionNotifier", name="mock_session_notifier")
@patch("goose.cli.session.load_provider", name="mock_load_provider")
def create_session_with_mock_configs(
mock_load_provider,
mock_session_notifier,
mock_load_profile,
mock_exchange,
mock_sessions_path,
exchange_factory,
profile_factory,
):
mock_load_provider.return_value = "provider"
mock_session_notifier.return_value = MagicMock()
mock_load_profile.return_value = profile_factory()
mock_exchange.return_value = exchange_factory()

def create_session(session_attributes: dict = {}):
return Session(**session_attributes)
def create_session(session_attributes: dict = {}):
return Session(**session_attributes)

yield create_session
return create_session


@pytest.fixture
def session_factory(create_session_with_mock_configs):
def factory(
name=SESSION_NAME,
overwrite_prompt=None,
is_existing_session=None,
get_initial_messages=None,
file_opener=open,
):
session = create_session_with_mock_configs({"name": name})
session.overwrite_prompt = overwrite_prompt or OverwriteSessionPrompt()
session.is_existing_session = is_existing_session or (lambda _: False)
session._get_initial_messages = get_initial_messages or (lambda: [])
session.file_opener = file_opener
return session

return factory


def test_session_does_not_extend_last_user_text_message_on_init(
Expand Down Expand Up @@ -123,18 +169,33 @@ def test_log_log_cost(create_session_with_mock_configs):
mock_logger.info.assert_called_once_with(cost_message)


@patch("goose.cli.session.droid", return_value="generated_session_name", name="mock_droid")
def test_set_generated_session_name(mock_droid, create_session_with_mock_configs, mock_sessions_path):
@patch("goose.cli.session.droid", return_value="generated_session_name")
@patch("goose.cli.session.load_provider")
def test_set_generated_session_name(
mock_load_provider, mock_droid, create_session_with_mock_configs, mock_sessions_path
):
mock_provider = MagicMock()
mock_load_provider.return_value = mock_provider

session = create_session_with_mock_configs({"name": None})

assert session.name == "generated_session_name"


@patch("goose.cli.session.is_existing_session", name="mock_is_existing")
@patch("goose.cli.session.Session._prompt_overwrite_session", name="mock_prompt")
def test_existing_session_prompt(mock_prompt, mock_is_existing, create_session_with_mock_configs):
def test_existing_session_prompt(
mock_prompt,
mock_is_existing,
create_session_with_mock_configs,
):
session = create_session_with_mock_configs({"name": SESSION_NAME})

def check_prompt_behavior(is_existing, new_session, should_prompt):
def check_prompt_behavior(
is_existing: bool,
new_session: Union[bool, None],
should_prompt: bool,
) -> None:
mock_is_existing.return_value = is_existing
if new_session is None:
session.run()
Expand All @@ -151,3 +212,48 @@ def check_prompt_behavior(is_existing, new_session, should_prompt):
check_prompt_behavior(is_existing=False, new_session=None, should_prompt=False)
check_prompt_behavior(is_existing=True, new_session=True, should_prompt=True)
check_prompt_behavior(is_existing=False, new_session=False, should_prompt=False)


def test_prompt_overwrite_session(session_factory):
def check_overwrite_behavior(choice: str, expected_messages: list[Message]) -> None:
session = session_factory()

with (
patch.object(OverwriteSessionPrompt, "ask", return_value=choice),
patch.object(session, "is_existing_session", return_value=True),
patch.object(
session,
"_get_initial_messages",
return_value=[Message.user(text="duck duck"), Message.user(text="goose")],
),
patch("rich.prompt.Prompt.ask", return_value="new_session_name"),
patch("builtins.open", mock_open()) as mock_file,
):
session._prompt_overwrite_session()

if choice in ["y", "yes"]:
mock_file.assert_called_once_with(session.session_file_path, "w")
mock_file().write.assert_called_once_with("")
elif choice in ["n", "no"]:
assert session.name == "new_session_name"
elif choice in ["r", "resume"]:
# this is tested comparing the contents of the array
pass

# because the messages are created with an id and creation date, we only want to check the text
actual_messages = [message.text for message in session.exchange.messages]
expected_messages = [message.text for message in expected_messages]
assert actual_messages == expected_messages

check_overwrite_behavior(choice="yes", expected_messages=[])
check_overwrite_behavior(choice="y", expected_messages=[])
check_overwrite_behavior(choice="no", expected_messages=[])
check_overwrite_behavior(choice="n", expected_messages=[])
check_overwrite_behavior(
choice="resume",
expected_messages=[Message.user(text="duck duck"), Message.user(text="goose")],
)
check_overwrite_behavior(
choice="r",
expected_messages=[Message.user(text="duck duck"), Message.user(text="goose")],
) #

0 comments on commit 2f33514

Please sign in to comment.