diff --git a/esbmc_ai/__main__.py b/esbmc_ai/__main__.py index 33f8e40..d24c200 100755 --- a/esbmc_ai/__main__.py +++ b/esbmc_ai/__main__.py @@ -186,6 +186,7 @@ def _execute_fix_code_command(source_file: SourceFile) -> FixCodeCommandResult: source_code_format=Config.get_value("source_code_format"), esbmc_output_format=Config.get_value("esbmc.output_type"), scenarios=Config.get_fix_code_scenarios(), + output_dir=Config.output_dir, ) @@ -310,12 +311,13 @@ def main() -> None: help="Generate patch files and place them in the same folder as the source files.", ) - # parser.add_argument( - # "--generate-default-config", - # action="store_true", - # default=False, - # help="Will generate and save the default config to the current working directory as 'esbmcai.toml'." - # ) + parser.add_argument( + "-o", + "--output-dir", + default="", + help="Store the result at the following dir. Specifying the same directory will " + + "overwrite the original file.", + ) args: argparse.Namespace = parser.parse_args() @@ -378,7 +380,6 @@ def main() -> None: assert len(solution.files) == 1 source_file: SourceFile = solution.files[0] - assert source_file.file_path esbmc_output: str = _run_esbmc(source_file, anim) diff --git a/esbmc_ai/ai_models.py b/esbmc_ai/ai_models.py index 6c1f228..648aa20 100644 --- a/esbmc_ai/ai_models.py +++ b/esbmc_ai/ai_models.py @@ -248,8 +248,9 @@ def is_valid_ai_model( name: str = ai_model.name if isinstance(ai_model, AIModel) else ai_model # Try accessing openai api and checking if there is a model defined. - # NOTE: This is not tested as no way to mock API currently. - if api_keys and api_keys.openai: + # Will only work on models that start with gpt- to avoid spamming API and + # getting blocked. NOTE: This is not tested as no way to mock API currently. + if name.startswith("gpt-") and api_keys and api_keys.openai: try: from openai import Client diff --git a/esbmc_ai/chats/base_chat_interface.py b/esbmc_ai/chats/base_chat_interface.py index 5d2ab07..5dd9418 100644 --- a/esbmc_ai/chats/base_chat_interface.py +++ b/esbmc_ai/chats/base_chat_interface.py @@ -2,6 +2,7 @@ from abc import abstractmethod from typing import Optional +import traceback from langchain.schema import ( BaseMessage, @@ -89,6 +90,7 @@ def send_message(self, message: Optional[str] = None) -> ChatResponse: # Check if token limit has been exceeded. all_messages.append(response_message) + # FIXME This causes a warning new_tokens: int = self.llm.get_num_tokens_from_messages( messages=all_messages, ) @@ -106,6 +108,7 @@ def send_message(self, message: Optional[str] = None) -> ChatResponse: ) except Exception as e: print(f"There was an unkown error when generating a response: {e}") + traceback.print_exc() exit(1) return response diff --git a/esbmc_ai/commands/fix_code_command.py b/esbmc_ai/commands/fix_code_command.py index f925942..0ee8b1b 100644 --- a/esbmc_ai/commands/fix_code_command.py +++ b/esbmc_ai/commands/fix_code_command.py @@ -1,5 +1,6 @@ # Author: Yiannis Charalambous +from pathlib import Path import sys from typing import Any, Optional, Tuple from typing_extensions import override @@ -64,7 +65,6 @@ def print_raw_conversation() -> None: # Handle kwargs source_file: SourceFile = kwargs["source_file"] - assert source_file.file_path generate_patches: bool = ( kwargs["generate_patches"] if "generate_patches" in kwargs else False @@ -89,6 +89,9 @@ def print_raw_conversation() -> None: raw_conversation: bool = ( kwargs["raw_conversation"] if "raw_conversation" in kwargs else False ) + output_dir: Optional[Path] = ( + kwargs["output_dir"] if "output_dir" in kwargs else None + ) # End of handle kwargs match message_history: @@ -158,6 +161,7 @@ def print_raw_conversation() -> None: if finish_reason == FinishReason.length: solution_generator.compress_message_stack() else: + # Update the source file state source_file.update_content(llm_solution) break @@ -204,6 +208,12 @@ def print_raw_conversation() -> None: else: returned_source = source_file.latest_content + if output_dir: + assert ( + output_dir.is_dir() + ), "FixCodeCommand: Output directory needs to be valid" + with open(output_dir / source_file.file_path.name, "w") as file: + file.write(source_file.latest_content) return FixCodeCommandResult(True, returned_source) try: diff --git a/esbmc_ai/config.py b/esbmc_ai/config.py index d107bb6..1cde5ee 100644 --- a/esbmc_ai/config.py +++ b/esbmc_ai/config.py @@ -94,6 +94,7 @@ class Config: raw_conversation: bool = False cfg_path: Path generate_patches: bool + output_dir: Optional[Path] = None _fields: List[ConfigField] = [ ConfigField( @@ -431,6 +432,17 @@ def _load_args(cls, args) -> None: Config.raw_conversation = args.raw_conversation Config.generate_patches = args.generate_patches + if args.output_dir: + path: Path = Path(args.output_dir).expanduser() + if path.is_dir(): + Config.output_dir = path + else: + print( + "Error while parsing arguments: output_dir: dir does not exist:", + Config.output_dir, + ) + sys.exit(1) + @classmethod def _flatten_dict(cls, d, parent_key="", sep="."): """Recursively flattens a nested dictionary.""" diff --git a/esbmc_ai/solution.py b/esbmc_ai/solution.py index 05cf2a2..75fbc2c 100644 --- a/esbmc_ai/solution.py +++ b/esbmc_ai/solution.py @@ -39,9 +39,9 @@ def apply_line_patch( return "\n".join(lines) def __init__( - self, file_path: Optional[Path], content: str, file_ext: Optional[str] = None + self, file_path: Path, content: str, file_ext: Optional[str] = None ) -> None: - self._file_path: Optional[Path] = file_path + self._file_path: Path = file_path # Content file shows the file throughout the repair process. Index 0 is # the orignial. self._content: list[str] = [content] @@ -50,7 +50,7 @@ def __init__( self._file_ext: Optional[str] = file_ext @property - def file_path(self) -> Optional[Path]: + def file_path(self) -> Path: """Returns the file path of this source file.""" return self._file_path @@ -162,7 +162,7 @@ def save_file( the saved file in /tmp and use the file_path file name only.""" file_name: Optional[str] = None - dir_path: Optional[Path] = None + dir_path: Path if file_path: # If file path is a file, then use the name and directory. If not # then use a temporary name and just store the folder. @@ -172,19 +172,15 @@ def save_file( else: dir_path = file_path else: - if not self._file_path: - raise ValueError( - "Source code file does not have a name or file_path to save to" - ) # Just store the file and use the temp dir. file_name = self._file_path.name - if temp_dir: - dir_path = Path(gettempdir()) + if not temp_dir: + raise ValueError( + "Need to enable temporary directory or provide file path to store to." + ) - assert ( - dir_path - ), "dir_path could not be retrieved: file_path or temp_dir need to be set." + dir_path = Path(gettempdir()) # Create path if it does not exist. if not os.path.exists(dir_path): @@ -241,29 +237,17 @@ def files(self) -> tuple[SourceFile, ...]: @property def files_mapped(self) -> dict[Path, SourceFile]: """Will return the files mapped to their directory. Returns by value.""" - return { - source_file.file_path: source_file - for source_file in self._files - if source_file.file_path - } - - def add_source_file( - self, file_path: Optional[Path], content: Optional[str] - ) -> None: - """Add a source file to the solution.""" - if file_path: - if content: - self._files.append(SourceFile(file_path, content)) - else: - with open(file_path, "r") as file: - self._files.append(SourceFile(file_path, file.read())) - return + return {source_file.file_path: source_file for source_file in self._files} + def add_source_file(self, file_path: Path, content: Optional[str]) -> None: + """Add a source file to the solution. If content is provided then it will + not be loaded.""" + assert file_path if content: self._files.append(SourceFile(file_path, content)) - return - - raise RuntimeError("file_path and content cannot be both invalid!") + else: + with open(file_path, "r") as file: + self._files.append(SourceFile(file_path, file.read())) # Define a global solution (is not required to be used) diff --git a/tests/test_solution.py b/tests/test_solution.py index f5d189f..4ab3c98 100644 --- a/tests/test_solution.py +++ b/tests/test_solution.py @@ -14,24 +14,34 @@ def solution() -> Solution: def test_add_source_file(solution) -> None: - src: str = "int main(int argc, char** argv) {return 0;}" - solution.add_source_file(None, src) + src = '#include int main(int argc, char** argv) { printf("hello world\n"); return 0;}' + solution.add_source_file("Testfile1", src) + solution.add_source_file("Testfile2", src) + solution.add_source_file("Testfile3", src) + + assert len(solution.files) == 3 + assert ( - len(solution.files) == 1 - and solution.files[0].file_path == None + solution.files[0].file_path == "Testfile1" and solution.files[0].latest_content == src ) - src = '#include int main(int argc, char** argv) { printf("hello world\n"); return 0;}' - solution.add_source_file("Testfile1", src) assert ( - len(solution.files) == 2 - and solution.files[1].file_path == "Testfile1" + solution.files[1].file_path == "Testfile2" and solution.files[1].latest_content == src ) assert ( - len(solution.files_mapped) == 1 + solution.files[2].file_path == "Testfile3" + and solution.files[2].latest_content == src + ) + + assert ( + len(solution.files_mapped) == 3 and solution.files_mapped["Testfile1"].file_path == "Testfile1" and solution.files_mapped["Testfile1"].initial_content == src + and solution.files_mapped["Testfile2"].file_path == "Testfile2" + and solution.files_mapped["Testfile2"].initial_content == src + and solution.files_mapped["Testfile3"].file_path == "Testfile3" + and solution.files_mapped["Testfile3"].initial_content == src )