Skip to content

Commit

Permalink
Add argument to write result of fix code to directory
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiannis128 committed Sep 16, 2024
1 parent 090cdf7 commit 834cde8
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 52 deletions.
15 changes: 8 additions & 7 deletions esbmc_ai/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def _execute_fix_code_command(source_file: SourceFile) -> FixCodeCommandResult:
source_code_format=Config.get_value("source_code_format"),
esbmc_output_format=Config.get_value("esbmc.output_type"),
scenarios=Config.get_fix_code_scenarios(),
output_dir=Config.output_dir,
)


Expand Down Expand Up @@ -310,12 +311,13 @@ def main() -> None:
help="Generate patch files and place them in the same folder as the source files.",
)

# parser.add_argument(
# "--generate-default-config",
# action="store_true",
# default=False,
# help="Will generate and save the default config to the current working directory as 'esbmcai.toml'."
# )
parser.add_argument(
"-o",
"--output-dir",
default="",
help="Store the result at the following dir. Specifying the same directory will "
+ "overwrite the original file.",
)

args: argparse.Namespace = parser.parse_args()

Expand Down Expand Up @@ -378,7 +380,6 @@ def main() -> None:
assert len(solution.files) == 1

source_file: SourceFile = solution.files[0]
assert source_file.file_path

esbmc_output: str = _run_esbmc(source_file, anim)

Expand Down
5 changes: 3 additions & 2 deletions esbmc_ai/ai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,9 @@ def is_valid_ai_model(
name: str = ai_model.name if isinstance(ai_model, AIModel) else ai_model

# Try accessing openai api and checking if there is a model defined.
# NOTE: This is not tested as no way to mock API currently.
if api_keys and api_keys.openai:
# Will only work on models that start with gpt- to avoid spamming API and
# getting blocked. NOTE: This is not tested as no way to mock API currently.
if name.startswith("gpt-") and api_keys and api_keys.openai:
try:
from openai import Client

Expand Down
3 changes: 3 additions & 0 deletions esbmc_ai/chats/base_chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from abc import abstractmethod
from typing import Optional
import traceback

from langchain.schema import (
BaseMessage,
Expand Down Expand Up @@ -89,6 +90,7 @@ def send_message(self, message: Optional[str] = None) -> ChatResponse:

# Check if token limit has been exceeded.
all_messages.append(response_message)
# FIXME This causes a warning
new_tokens: int = self.llm.get_num_tokens_from_messages(
messages=all_messages,
)
Expand All @@ -106,6 +108,7 @@ def send_message(self, message: Optional[str] = None) -> ChatResponse:
)
except Exception as e:
print(f"There was an unkown error when generating a response: {e}")
traceback.print_exc()
exit(1)

return response
12 changes: 11 additions & 1 deletion esbmc_ai/commands/fix_code_command.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Author: Yiannis Charalambous

from pathlib import Path
import sys
from typing import Any, Optional, Tuple
from typing_extensions import override
Expand Down Expand Up @@ -64,7 +65,6 @@ def print_raw_conversation() -> None:

# Handle kwargs
source_file: SourceFile = kwargs["source_file"]
assert source_file.file_path

generate_patches: bool = (
kwargs["generate_patches"] if "generate_patches" in kwargs else False
Expand All @@ -89,6 +89,9 @@ def print_raw_conversation() -> None:
raw_conversation: bool = (
kwargs["raw_conversation"] if "raw_conversation" in kwargs else False
)
output_dir: Optional[Path] = (
kwargs["output_dir"] if "output_dir" in kwargs else None
)
# End of handle kwargs

match message_history:
Expand Down Expand Up @@ -158,6 +161,7 @@ def print_raw_conversation() -> None:
if finish_reason == FinishReason.length:
solution_generator.compress_message_stack()
else:
# Update the source file state
source_file.update_content(llm_solution)
break

Expand Down Expand Up @@ -204,6 +208,12 @@ def print_raw_conversation() -> None:
else:
returned_source = source_file.latest_content

if output_dir:
assert (
output_dir.is_dir()
), "FixCodeCommand: Output directory needs to be valid"
with open(output_dir / source_file.file_path.name, "w") as file:
file.write(source_file.latest_content)
return FixCodeCommandResult(True, returned_source)

try:
Expand Down
12 changes: 12 additions & 0 deletions esbmc_ai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class Config:
raw_conversation: bool = False
cfg_path: Path
generate_patches: bool
output_dir: Optional[Path] = None

_fields: List[ConfigField] = [
ConfigField(
Expand Down Expand Up @@ -431,6 +432,17 @@ def _load_args(cls, args) -> None:
Config.raw_conversation = args.raw_conversation
Config.generate_patches = args.generate_patches

if args.output_dir:
path: Path = Path(args.output_dir).expanduser()
if path.is_dir():
Config.output_dir = path
else:
print(
"Error while parsing arguments: output_dir: dir does not exist:",
Config.output_dir,
)
sys.exit(1)

@classmethod
def _flatten_dict(cls, d, parent_key="", sep="."):
"""Recursively flattens a nested dictionary."""
Expand Down
50 changes: 17 additions & 33 deletions esbmc_ai/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def apply_line_patch(
return "\n".join(lines)

def __init__(
self, file_path: Optional[Path], content: str, file_ext: Optional[str] = None
self, file_path: Path, content: str, file_ext: Optional[str] = None
) -> None:
self._file_path: Optional[Path] = file_path
self._file_path: Path = file_path
# Content file shows the file throughout the repair process. Index 0 is
# the orignial.
self._content: list[str] = [content]
Expand All @@ -50,7 +50,7 @@ def __init__(
self._file_ext: Optional[str] = file_ext

@property
def file_path(self) -> Optional[Path]:
def file_path(self) -> Path:
"""Returns the file path of this source file."""
return self._file_path

Expand Down Expand Up @@ -162,7 +162,7 @@ def save_file(
the saved file in /tmp and use the file_path file name only."""

file_name: Optional[str] = None
dir_path: Optional[Path] = None
dir_path: Path
if file_path:
# If file path is a file, then use the name and directory. If not
# then use a temporary name and just store the folder.
Expand All @@ -172,19 +172,15 @@ def save_file(
else:
dir_path = file_path
else:
if not self._file_path:
raise ValueError(
"Source code file does not have a name or file_path to save to"
)
# Just store the file and use the temp dir.
file_name = self._file_path.name

if temp_dir:
dir_path = Path(gettempdir())
if not temp_dir:
raise ValueError(
"Need to enable temporary directory or provide file path to store to."
)

assert (
dir_path
), "dir_path could not be retrieved: file_path or temp_dir need to be set."
dir_path = Path(gettempdir())

# Create path if it does not exist.
if not os.path.exists(dir_path):
Expand Down Expand Up @@ -241,29 +237,17 @@ def files(self) -> tuple[SourceFile, ...]:
@property
def files_mapped(self) -> dict[Path, SourceFile]:
"""Will return the files mapped to their directory. Returns by value."""
return {
source_file.file_path: source_file
for source_file in self._files
if source_file.file_path
}

def add_source_file(
self, file_path: Optional[Path], content: Optional[str]
) -> None:
"""Add a source file to the solution."""
if file_path:
if content:
self._files.append(SourceFile(file_path, content))
else:
with open(file_path, "r") as file:
self._files.append(SourceFile(file_path, file.read()))
return
return {source_file.file_path: source_file for source_file in self._files}

def add_source_file(self, file_path: Path, content: Optional[str]) -> None:
"""Add a source file to the solution. If content is provided then it will
not be loaded."""
assert file_path
if content:
self._files.append(SourceFile(file_path, content))
return

raise RuntimeError("file_path and content cannot be both invalid!")
else:
with open(file_path, "r") as file:
self._files.append(SourceFile(file_path, file.read()))


# Define a global solution (is not required to be used)
Expand Down
28 changes: 19 additions & 9 deletions tests/test_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,34 @@ def solution() -> Solution:


def test_add_source_file(solution) -> None:
src: str = "int main(int argc, char** argv) {return 0;}"
solution.add_source_file(None, src)
src = '#include <stdio.h> int main(int argc, char** argv) { printf("hello world\n"); return 0;}'
solution.add_source_file("Testfile1", src)
solution.add_source_file("Testfile2", src)
solution.add_source_file("Testfile3", src)

assert len(solution.files) == 3

assert (
len(solution.files) == 1
and solution.files[0].file_path == None
solution.files[0].file_path == "Testfile1"
and solution.files[0].latest_content == src
)
src = '#include <stdio.h> int main(int argc, char** argv) { printf("hello world\n"); return 0;}'
solution.add_source_file("Testfile1", src)
assert (
len(solution.files) == 2
and solution.files[1].file_path == "Testfile1"
solution.files[1].file_path == "Testfile2"
and solution.files[1].latest_content == src
)
assert (
len(solution.files_mapped) == 1
solution.files[2].file_path == "Testfile3"
and solution.files[2].latest_content == src
)

assert (
len(solution.files_mapped) == 3
and solution.files_mapped["Testfile1"].file_path == "Testfile1"
and solution.files_mapped["Testfile1"].initial_content == src
and solution.files_mapped["Testfile2"].file_path == "Testfile2"
and solution.files_mapped["Testfile2"].initial_content == src
and solution.files_mapped["Testfile3"].file_path == "Testfile3"
and solution.files_mapped["Testfile3"].initial_content == src
)


Expand Down

0 comments on commit 834cde8

Please sign in to comment.