Skip to content

Commit

Permalink
ESBMC-AI can now be used with VERIFICATION SUCCESSFUL samples
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiannis128 committed Sep 26, 2023
1 parent 88cf072 commit 97e815b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 17 deletions.
24 changes: 18 additions & 6 deletions esbmc_ai_lib/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def init_commands() -> None:
fix_code_command.on_solution_signal.add_listener(chat.set_solution)
fix_code_command.on_solution_signal.add_listener(update_solution)

optimize_code_command.on_solution_signal.add_listener(chat.set_solution)
optimize_code_command.on_solution_signal.add_listener(update_solution)


def _run_command_mode(
command: ChatCommand,
Expand All @@ -174,14 +177,15 @@ def _run_command_mode(
sys.exit(1)
else:
print(solution)
# elif command == verify_code_command:
# raise NotImplementedError()
elif command == optimize_code_command:
optimize_code_command.execute(
error, solution = optimize_code_command.execute(
file_path=get_main_source_file_path(),
source_code=source_code,
function_names=args,
)

print(solution)
sys.exit(1 if error else 0)
else:
command.execute()
sys.exit(0)
Expand Down Expand Up @@ -302,12 +306,13 @@ def main() -> None:

# ESBMC will output 0 for verification success and 1 for verification
# failed, if anything else gets thrown, it's an ESBMC error.
if exit_code == 0:
if not config.allow_successful and exit_code == 0:
print("Success!")
print(esbmc_output)
sys.exit(0)
elif exit_code != 1:
elif exit_code != 0 and exit_code != 1:
print(f"ESBMC exit code: {exit_code}")
print(f"ESBMC Output:\n\n{esbmc_err_output}")
sys.exit(1)

# Command mode: Check if command is called and call it.
Expand Down Expand Up @@ -396,11 +401,18 @@ def main() -> None:
continue
elif command == optimize_code_command.command_name:
# Optimize Code command
optimize_code_command.execute(
error, solution = optimize_code_command.execute(
file_path=get_main_source_file_path(),
source_code=get_main_source_file().content,
function_names=command_args,
)

if error:
# Print error
print("\n" + solution + "\n")
else:
print(f"\nOptimizations Completed:\n```c\n{solution}```\n")

continue
else:
# Commands without parameters or returns are handled automatically.
Expand Down
33 changes: 22 additions & 11 deletions esbmc_ai_lib/commands/optimize_code_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys
from os import get_terminal_size
from typing import Iterable, Optional
from typing import Iterable, Optional, Tuple
from typing_extensions import override
from string import Template
from random import randint
Expand All @@ -13,6 +13,7 @@
from esbmc_ai_lib.frontend.c_types import is_primitive_type
from esbmc_ai_lib.frontend.esbmc_code_generator import ESBMCCodeGenerator
from esbmc_ai_lib.esbmc_util import esbmc_load_source_code
from esbmc_ai_lib.msg_bus import Signal
from esbmc_ai_lib.solution_generator import SolutionGenerator
from .chat_command import ChatCommand
from .. import config
Expand All @@ -31,6 +32,7 @@ def __init__(self) -> None:
command_name="optimize-code",
help_message="(EXPERIMENTAL) Optimizes the code of a specific function or the entire file if a function is not specified. Usage: optimize-code [function_name]",
)
self.on_solution_signal: Signal = Signal()

def _get_functions_list(
self,
Expand Down Expand Up @@ -280,8 +282,19 @@ def get_function_from_collection(

@override
def execute(
self, file_path: str, source_code: str, function_names: list[str]
) -> None:
self,
file_path: str,
source_code: str,
function_names: list[str],
) -> Tuple[bool, str]:
"""Executes the Optimize Code command. The command takes the following inputs:
* file_path: The path of the source code file.
* source_code: The source code file contents.
* function_names: List of function names to optimize. Main is always excluded.
Returns a `Tuple[bool, str]` which is the flag if there was an error, and the
source code from the LLM.
"""
clang_ast: ast.ClangAST = ast.ClangAST(
file_path=file_path,
source_code=source_code,
Expand Down Expand Up @@ -323,7 +336,7 @@ def execute(
function_name=fn_name,
)

new_source_code: str = SolutionGenerator.get_code_from_solution(
optimized_source_code: str = SolutionGenerator.get_code_from_solution(
response.message.content
)

Expand All @@ -335,19 +348,17 @@ def execute(
# Check equivalence
equal: bool = self.check_function_pequivalence(
original_source_code=source_code,
new_source_code=new_source_code,
new_source_code=optimized_source_code,
function_name=fn_name,
)

if equal:
new_source_code = response.message.content
# If equal, then return with explanation.
new_source_code = optimized_source_code
break
elif attempt == max_retries - 1:
print("Failed all attempts...")
return
return True, "Failed all attempts..."
else:
print("Failed attempt", attempt)

print("\nOptimizations Completed:\n")
print(new_source_code)
print()
return False, new_source_code
13 changes: 13 additions & 0 deletions esbmc_ai_lib/user_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ def set_solution(self, source_code: str) -> None:
message=AIMessage(content="Understood"), protected=True
)

def set_optimized_solution(self, source_code: str) -> None:
self.solution = source_code
self.push_to_message_stack(
message=HumanMessage(
content=f"Here is the optimized code:\n\n{source_code}"
),
protected=True,
)

self.push_to_message_stack(
message=AIMessage(content="Understood"), protected=True
)

@override
def compress_message_stack(self) -> None:
"""Uses ConversationSummaryMemory from Langchain to summarize the conversation of all the non-protected
Expand Down

0 comments on commit 97e815b

Please sign in to comment.