diff --git a/esbmc_ai_lib/__main__.py b/esbmc_ai_lib/__main__.py index 8940fab..856b263 100755 --- a/esbmc_ai_lib/__main__.py +++ b/esbmc_ai_lib/__main__.py @@ -14,7 +14,12 @@ from langchain.base_language import BaseLanguageModel import esbmc_ai_lib.config as config -from esbmc_ai_lib.frontend.solution import set_main_source_file, get_main_source_file +from esbmc_ai_lib.frontend.solution import ( + SourceFile, + get_main_source_file, + set_main_source_file, + get_main_source_file_path, +) from esbmc_ai_lib.commands import ( ChatCommand, @@ -136,6 +141,10 @@ def init_commands_list() -> None: command_names = [command.command_name for command in commands] +def update_solution(source_code: str) -> None: + set_main_source_file(SourceFile(get_main_source_file_path(), source_code)) + + def init_commands() -> None: """# Bus Signals Function that handles initializing commands. Each command needs to be added @@ -144,7 +153,7 @@ def init_commands() -> None: # Let the AI model know about the corrected code. fix_code_command.on_solution_signal.add_listener(chat.set_solution) - fix_code_command.on_solution_signal.add_listener(optimize_code_command.set_solution) + fix_code_command.on_solution_signal.add_listener(update_solution) def _run_command_mode( @@ -155,7 +164,7 @@ def _run_command_mode( ) -> None: if command == fix_code_command: error, solution = fix_code_command.execute( - file_name=get_main_source_file(), + file_name=get_main_source_file_path(), source_code=source_code, esbmc_output=esbmc_output, ) @@ -169,7 +178,7 @@ def _run_command_mode( # raise NotImplementedError() elif command == optimize_code_command: optimize_code_command.execute( - file_path=get_main_source_file(), + file_path=get_main_source_file_path(), source_code=source_code, function_names=args, ) @@ -262,9 +271,6 @@ def main() -> None: config.load_config(config.cfg_path) config.load_args(args) - # Add the main source file to the solution explorer. - set_main_source_file(args.filename) - check_health() anim: LoadingWidget = create_loading_widget() @@ -274,12 +280,13 @@ def main() -> None: print(f"Running ESBMC with {config.esbmc_params}\n") # Read source code - with open(get_main_source_file(), mode="r") as file: - source_code: str = file.read() + with open(get_main_source_file_path(), mode="r") as file: + # Add the main source file to the solution explorer. + set_main_source_file(SourceFile(args.filename, file.read())) anim.start("ESBMC is processing... Please Wait") exit_code, esbmc_output, esbmc_err_output = esbmc( - path=get_main_source_file(), + path=get_main_source_file_path(), esbmc_params=config.esbmc_params, ) anim.stop() @@ -311,7 +318,7 @@ def main() -> None: _run_command_mode( command=commands[idx], args=[], # NOTE: Currently not supported... - source_code=source_code, + source_code=get_main_source_file().content, esbmc_output=esbmc_output, ) sys.exit(0) @@ -330,7 +337,7 @@ def main() -> None: ), ai_model=config.ai_model, llm=chat_llm, - source_code=source_code, + source_code=get_main_source_file().content, esbmc_output=esbmc_output, ) @@ -373,8 +380,8 @@ def main() -> None: print("ESBMC-AI will generate a fix for the code...") error, solution = fix_code_command.execute( - file_name=get_main_source_file(), - source_code=source_code, + file_name=get_main_source_file_path(), + source_code=get_main_source_file().content, esbmc_output=esbmc_output, ) @@ -387,8 +394,8 @@ def main() -> None: elif command == optimize_code_command.command_name: # Optimize Code command optimize_code_command.execute( - file_path=get_main_source_file(), - source_code=source_code, + file_path=get_main_source_file_path(), + source_code=get_main_source_file().content, function_names=command_args, ) continue diff --git a/esbmc_ai_lib/frontend/solution.py b/esbmc_ai_lib/frontend/solution.py index 029c349..eb69ea3 100644 --- a/esbmc_ai_lib/frontend/solution.py +++ b/esbmc_ai_lib/frontend/solution.py @@ -3,25 +3,39 @@ """# Solution Keeps track of all the source files that ESBMC-AI is targeting.""" -_main_source_file: str = "" -_source_files: set[str] = set() +from typing import NamedTuple -def add_source_file(source_file: str) -> None: +class SourceFile(NamedTuple): + file_path: str + content: str + + +_main_source_file: SourceFile = SourceFile("", "") +_source_files: set[SourceFile] = set() + + +def add_source_file(source_file: SourceFile) -> None: global _source_files _source_files.add(source_file) -def set_main_source_file(source_file: str) -> None: +def set_main_source_file(source_file: SourceFile) -> None: add_source_file(source_file) global _main_source_file _main_source_file = source_file -def get_main_source_file() -> str: +def get_main_source_file_path() -> str: + global _main_source_file + return _main_source_file.file_path + + +def get_main_source_file() -> SourceFile: global _main_source_file return _main_source_file -def get_source_files() -> list[str]: +def get_source_files() -> list[SourceFile]: + global _source_files return list(_source_files)