# CI workflow: runs the pytest suite on every push and every pull request.
name: Python PyTest Workflow

on: [push, pull_request]

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      # v2 of the official actions targets a retired Node runtime; use current majors.
      - uses: actions/checkout@v4

      - name: Set up Python 3.12
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML always treats the version as a string, never a float
          # (an unquoted 3.10 would be read as 3.1).
          python-version: "3.12"

      - name: Install pytest
        run: |
          pip install pytest
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

      # Export the checkout directory as PYTHONPATH for the following steps so
      # tests can import top-level packages such as `utils`.
      - name: Set PYTHONPATH
        run: echo "PYTHONPATH=$GITHUB_WORKSPACE" >> "$GITHUB_ENV"

      - name: Run pytest
        run: pytest
def correct_hunk_headers(diff_text: str) -> str:
    """Recompute the line counts in every hunk header of a unified diff.

    Recognised header forms (the counts are optional on either side):
        1: @@ -start1 +start2 @@
        2: @@ -start1,lines1 +start2 @@
        3: @@ -start1 +start2,lines2 @@
        4: @@ -start1,lines1 +start2,lines2 @@

    The start line numbers are kept; the counts are replaced by the numbers
    actually found in the hunk body and are always written out explicitly.
    In the unified format the old-side count is context + removed lines and
    the new-side count is context + added lines, so context lines (leading
    space) contribute to both counts.

    Args:
        diff_text: Unified diff text. Lines before the first hunk header
            (for example the ``---``/``+++`` file headers) are copied as is.

    Returns:
        The diff with corrected hunk headers, joined with ``"\\n"``. As the
        text is rebuilt from ``splitlines()``, a trailing newline in the
        input is not preserved.
    """
    hunk_pattern: re.Pattern[str] = re.compile(
        pattern=r"^@@ -(\d+)(?:,\d+)? \+(\d+)(?:,\d+)? @@"
    )
    lines: list[str] = diff_text.splitlines()
    updated_lines: list[str] = []

    i = 0
    while i < len(lines):
        line: str = lines[i]
        match: re.Match[str] | None = hunk_pattern.match(string=line)

        # Anything that is not a hunk header is passed through untouched.
        if match is None:
            updated_lines.append(line)
            i += 1
            continue

        old_start, new_start = (int(group) for group in match.groups())
        # Text after the closing "@@" is an optional section heading; keep it.
        heading: str = line[match.end():]
        i += 1

        # Walk the hunk body (up to the next header or the end of the diff)
        # and count the lines that belong to each side.
        body_start: int = i
        old_count, new_count = 0, 0
        while i < len(lines) and not lines[i].startswith("@@"):
            marker: str = lines[i][:1]
            if marker == " ":
                # Context lines exist in both the old and the new file.
                old_count += 1
                new_count += 1
            elif marker == "-":
                old_count += 1
            elif marker == "+":
                new_count += 1
            # Other lines (e.g. "\ No newline at end of file") count for neither side.
            i += 1

        updated_lines.append(
            f"@@ -{old_start},{old_count} +{new_start},{new_count} @@{heading}"
        )
        updated_lines.extend(lines[body_start:i])

    return "\n".join(updated_lines)