Skip to content

Commit

Permalink
Merge pull request #50 from gitautoai/hiroshi
Browse files Browse the repository at this point in the history
Always correct hunk header line count.
  • Loading branch information
nikitamalinov authored Mar 20, 2024
2 parents 50d19df + 5118cdd commit c3fa42f
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 6 deletions.
26 changes: 26 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: Python PyTest Workflow

on: [push, pull_request]

jobs:
build:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2

- name: Set up Python 3.12
uses: actions/setup-python@v2
with:
python-version: 3.12

- name: Install pytest
run: |
pip install pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Set PYTHONPATH
run: echo "PYTHONPATH=$GITHUB_WORKSPACE" >> $GITHUB_ENV

- name: Run pytest
run: pytest
Binary file modified requirements.txt
Binary file not shown.
7 changes: 5 additions & 2 deletions services/openai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from services.openai.functions import GET_REMOTE_FILE_CONTENT, functions
from services.openai.init import create_openai_client
from services.openai.instructions import SYSTEM_INSTRUCTION_FOR_AGENT
from utils.file_manager import clean_specific_lines, split_diffs
from utils.file_manager import clean_specific_lines, correct_hunk_headers, split_diffs


def create_assistant() -> Assistant:
Expand Down Expand Up @@ -100,9 +100,12 @@ def run_assistant(
# Clean the diff text and split it
diff: str = clean_specific_lines(text=value)
text_diffs: list[str] = split_diffs(diff_text=diff)
output: list[str] = []
for diff in text_diffs:
diff = correct_hunk_headers(diff_text=diff)
print(f"Diff: {repr(diff)}\n")
return text_diffs
output.append(diff)
return output


def submit_message(thread: Thread, user_message: str) -> Run:
Expand Down
35 changes: 35 additions & 0 deletions tests/utils/test_file_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from utils.file_manager import correct_hunk_headers


def test_correct_hunk_headers_1() -> None:
"""Test the case where line counts are wrong"""
diff_text = """--- a/example.txt
+++ b/example.txt
@@ -1,3 +1,3 @@
- old line 1
+ new line 1
- old line 2
+ new line 2"""
expected_result = """--- a/example.txt
+++ b/example.txt
@@ -1,2 +1,2 @@
- old line 1
+ new line 1
- old line 2
+ new line 2"""
assert correct_hunk_headers(diff_text=diff_text) == expected_result


def test_correct_hunk_headers_2() -> None:
"""Test the case where the hunk header does not have a line count"""
diff_text = """--- a/example2.txt
+++ b/example2.txt
@@ -1 +1 @@
- old line 1
+ new line 1"""
expected_result = """--- a/example2.txt
+++ b/example2.txt
@@ -1,1 +1,1 @@
- old line 1
+ new line 1"""
assert correct_hunk_headers(diff_text=diff_text) == expected_result
55 changes: 51 additions & 4 deletions utils/file_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
import os
import re
import subprocess
import tempfile
import logging


def apply_patch(original_text: str, diff_text: str) -> str:
Expand Down Expand Up @@ -51,12 +51,12 @@ def apply_patch(original_text: str, diff_text: str) -> str:
print("Failed to apply patch.")
print(f"stdout: {e.stdout}")
print(f"stderr: {e.stderr}\n")
logging.error(f"apply_patch stderr: {e.stderr}")
logging.error(msg=f"apply_patch stderr: {e.stderr}") # pylint: disable=no-member
print(f"Command: {' '.join(e.cmd)}")
print(f"Exit status: {e.returncode}")
return ""
except Exception as e:
logging.error(f"Error: {e}")
except Exception as e: # pylint: disable=broad-except
logging.error(msg=f"Error: {e}") # pylint: disable=no-member
return ""
finally:
os.remove(path=original_file_name)
Expand All @@ -76,6 +76,53 @@ def clean_specific_lines(text: str) -> str:
).strip()


def correct_hunk_headers(diff_text: str) -> str:
"""
Match following patterns:
1: @@ -start1 +start2 @@
2: @@ -start1,lines1 +start2 @@
3: @@ -start1 +start2,lines2 @@
4: @@ -start1,lines1 +start2,lines2 @@
"""
# Split the diff into lines
lines: list[str] = diff_text.splitlines()
updated_lines: list[str] = []
hunk_pattern: re.Pattern[str] = re.compile(
pattern=r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@')

i = 0
while i < len(lines):
line: str = lines[i]
match: re.Match[str] | None = hunk_pattern.match(string=line)

# Add the line to the updated diff if it's not a hunk header
if not match:
updated_lines.append(line)
i += 1
continue

# Correct the hunk header if match is not None
l1, _s1, l2, _s2 = (int(x) if x is not None else 0 for x in match.groups())
s1_actual, s2_actual = 0, 0
i += 1

# Count actual number of lines changed
start_index: int = i
while i < len(lines) and not lines[i].startswith('@@'):
if lines[i].startswith('+'):
s2_actual += 1
if lines[i].startswith('-'):
s1_actual += 1
i += 1

# Update the hunk header with actual numbers
updated_hunk_header: str = f'@@ -{l1},{s1_actual} +{l2},{s2_actual} @@'
updated_lines.append(updated_hunk_header)
updated_lines.extend(lines[start_index:i])

return '\n'.join(updated_lines)


def extract_file_name(diff_text: str) -> str:
match = re.search(pattern=r"^\+\+\+ (.+)$", string=diff_text, flags=re.MULTILINE)
if match:
Expand Down

0 comments on commit c3fa42f

Please sign in to comment.