Skip to content

Commit

Permalink
Updated tests and fixed bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiannis128 committed Sep 15, 2024
1 parent bcc8022 commit 0a073f0
Show file tree
Hide file tree
Showing 12 changed files with 168 additions and 223 deletions.
2 changes: 1 addition & 1 deletion esbmc_ai/ai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def add_safeguards(content: str, char: str, allowed_keys: list[str]) -> str:
if content.find("}", look_pointer) != -1:
# Do it in reverse with reverse keys.
content = add_safeguards(content[::-1], "}", reversed_keys)[::-1]
new_msg = msg.copy()
new_msg = msg.model_copy()
new_msg.content = content
result.append(new_msg)
return result
Expand Down
5 changes: 5 additions & 0 deletions esbmc_ai/esbmc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional

from esbmc_ai.solution import SourceFile
from esbmc_ai.config import default_scenario


class ESBMCUtil:
Expand Down Expand Up @@ -48,6 +49,10 @@ def esbmc_get_error_type(cls, esbmc_output: str) -> str:
scenario: str = from_loc_error_msg[scenario_index + 1 :]
scenario_end_l_index: int = scenario.find("\n")
scenario = scenario[:scenario_end_l_index].strip()

if not scenario:
return default_scenario

return scenario

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
(SystemMessage(content='System message'), AIMessage(content='OK'))
[AIMessage(content='Test 1'), HumanMessage(content='Test 2'), SystemMessage(content='Test 3')]
system: System message
ai: OK
ai: Test 1
human: Test 2
system: Test 3
21 changes: 10 additions & 11 deletions tests/regtest/test_base_chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from esbmc_ai.ai_models import AIModel
from esbmc_ai.chat_response import ChatResponse
from esbmc_ai.chats.base_chat_interface import BaseChatInterface
from esbmc_ai.config import AIAgentConversation, ChatPromptSettings


@pytest.fixture
Expand All @@ -24,11 +23,7 @@ def setup():
]

chat: BaseChatInterface = BaseChatInterface(
ai_model_agent=ChatPromptSettings(
initial_prompt="",
system_messages=AIAgentConversation.from_seq(system_messages),
temperature=1.0,
),
system_messages=system_messages,
ai_model=ai_model,
llm=llm,
)
Expand All @@ -50,8 +45,11 @@ def test_push_message_stack(regtest, setup) -> None:
chat.push_to_message_stack(messages[2])

with regtest:
print(chat.ai_model_agent.system_messages.messages)
print(chat.messages)
for msg in chat._system_messages:
print(f"{msg.type}: {msg.content}")

for msg in chat.messages:
print(f"{msg.type}: {msg.content}")


def test_send_message(regtest, setup) -> None:
Expand All @@ -65,12 +63,13 @@ def test_send_message(regtest, setup) -> None:

with regtest:
print("System Messages:")
for m in chat.ai_model_agent.system_messages.messages:
for m in chat._system_messages:
print(f"{m.type}: {m.content}")
print("Chat Messages:")
for m in chat.messages:
print(f"{m.type}: {m.content}")
print("Responses:")
for m in chat_responses:
print(f"{m.message.type}({m.total_tokens} - {m.finish_reason}): {m.message.content}")

print(
f"{m.message.type}({m.total_tokens} - {m.finish_reason}): {m.message.content}"
)
28 changes: 15 additions & 13 deletions tests/test_ai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,22 @@
from pytest import raises

from esbmc_ai.ai_models import (
AIModelOpenAI,
add_custom_ai_model,
is_valid_ai_model,
AIModel,
_AIModels,
get_ai_model_by_name,
OllamaAIModel,
_get_openai_model_max_tokens,
)

"""TODO Find a way to mock the OpenAI API and test GPT LLM code."""


def test_is_valid_ai_model() -> None:
assert is_valid_ai_model(_AIModels.FALCON_7B.value)
assert is_valid_ai_model(_AIModels.STARCHAT_BETA.value)
assert is_valid_ai_model("falcon-7b")
# def test_is_valid_ai_model() -> None:
# assert is_valid_ai_model(_AIModels.FALCON_7B.value)
# assert is_valid_ai_model(_AIModels.STARCHAT_BETA.value)
# assert is_valid_ai_model("falcon-7b")


def test_is_not_valid_ai_model() -> None:
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_add_custom_ai_model() -> None:

def test_get_ai_model_by_name() -> None:
# Try with first class AI
assert get_ai_model_by_name("falcon-7b")
# assert get_ai_model_by_name("falcon-7b")

# Try with custom AI.
# Add custom AI model if not added by previous tests.
Expand Down Expand Up @@ -142,12 +142,14 @@ def test_escape_messages() -> None:


def test__get_openai_model_max_tokens() -> None:
assert _get_openai_model_max_tokens("gpt-4o") == 128000
assert _get_openai_model_max_tokens("gpt-4-turbo") == 8192
assert _get_openai_model_max_tokens("gpt-3.5-turbo") == 16385
assert _get_openai_model_max_tokens("gpt-3.5-turbo-instruct") == 4096
assert _get_openai_model_max_tokens("gpt-3.5-turbo-aaaaaa") == 16385
assert _get_openai_model_max_tokens("gpt-3.5-turbo-instruct-bbb") == 4096
assert AIModelOpenAI.get_openai_model_max_tokens("gpt-4o") == 128000
assert AIModelOpenAI.get_openai_model_max_tokens("gpt-4-turbo") == 8192
assert AIModelOpenAI.get_openai_model_max_tokens("gpt-3.5-turbo") == 16385
assert AIModelOpenAI.get_openai_model_max_tokens("gpt-3.5-turbo-instruct") == 4096
assert AIModelOpenAI.get_openai_model_max_tokens("gpt-3.5-turbo-aaaaaa") == 16385
assert (
AIModelOpenAI.get_openai_model_max_tokens("gpt-3.5-turbo-instruct-bbb") == 4096
)

with raises(ValueError):
_get_openai_model_max_tokens("aaaaa")
AIModelOpenAI.get_openai_model_max_tokens("aaaaa")
23 changes: 6 additions & 17 deletions tests/test_base_chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from esbmc_ai.ai_models import AIModel
from esbmc_ai.chats.base_chat_interface import BaseChatInterface
from esbmc_ai.chat_response import ChatResponse
from esbmc_ai.config import AIAgentConversation, ChatPromptSettings


@pytest.fixture(scope="module")
Expand All @@ -28,16 +27,14 @@ def test_push_message_stack(setup) -> None:
ai_model, system_messages = setup

chat: BaseChatInterface = BaseChatInterface(
ai_model_agent=ChatPromptSettings(
AIAgentConversation.from_seq(system_messages),
initial_prompt="",
temperature=1.0,
),
system_messages=system_messages,
ai_model=ai_model,
llm=llm,
)

assert chat.ai_model_agent.system_messages.messages == tuple(system_messages)
for msg, chat_msg in zip(system_messages, chat._system_messages):
assert msg.type == chat_msg.type
assert msg.content == chat_msg.content

messages: list[BaseMessage] = [
AIMessage(content="Test 1"),
Expand All @@ -61,11 +58,7 @@ def test_send_message(setup) -> None:
ai_model, system_messages = setup

chat: BaseChatInterface = BaseChatInterface(
ai_model_agent=ChatPromptSettings(
AIAgentConversation.from_seq(system_messages),
initial_prompt="",
temperature=1.0,
),
system_messages=system_messages,
ai_model=ai_model,
llm=llm,
)
Expand Down Expand Up @@ -101,11 +94,7 @@ def test_apply_template() -> None:
llm: FakeListChatModel = FakeListChatModel(responses=responses)

chat: BaseChatInterface = BaseChatInterface(
ai_model_agent=ChatPromptSettings(
AIAgentConversation.from_seq(system_messages),
initial_prompt="{source_code}{esbmc_output}",
temperature=1.0,
),
system_messages=system_messages,
ai_model=ai_model,
llm=llm,
)
Expand Down
70 changes: 1 addition & 69 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,80 +7,12 @@
from esbmc_ai.ai_models import is_valid_ai_model


def test_load_config_value() -> None:
result, ok = config._load_config_value(
{
"test": "value",
},
"test",
)
assert ok and result == "value"


def test_load_config_value_default_value() -> None:
result, ok = config._load_config_value(
{
"test": "value",
},
"test",
"wrong",
)
assert ok and result == "value"


def test_load_config_value_default_value_not_exists() -> None:
result, ok = config._load_config_value(
{},
"test2",
"wrong",
)
assert not ok and result == "wrong"


def test_load_config_real_number() -> None:
result = config._load_config_real_number(
{
"test": 1.0,
},
"test",
)
assert result == 1.0


def test_load_config_real_number_default_value() -> None:
result = config._load_config_real_number({}, "test", 1.1)
assert result == 1.1


def test_load_config_real_number_wrong_value() -> None:
with raises(TypeError):
result = config._load_config_real_number(
{
"test": "wrong value",
},
"test",
)
assert result == None


def test_load_config_real_number_wrong_value_default() -> None:
with raises(TypeError):
result = config._load_config_real_number(
{
"test": "wrong value",
},
"test",
1.0,
)
assert result == None


def test_load_custom_ai() -> None:
custom_ai_config: dict = {
"example_ai": {
"max_tokens": 4096,
"url": "www.example.com",
"server_type": "ollama"
"server_type": "ollama",
}
}

Expand Down
Loading

0 comments on commit 0a073f0

Please sign in to comment.