Skip to content

Commit

Permalink
Fix code and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
neubig committed Aug 29, 2024
1 parent efd019c commit 87469cd
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 69 deletions.
13 changes: 9 additions & 4 deletions github_resolver/io_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import json
from typing import Iterable
from github_resolver.resolver_output import ResolverOutput


def load_all_resolver_outputs(output_jsonl: str) -> Iterable[ResolverOutput]:
    """Lazily yield every resolver output stored in a JSON Lines file.

    Args:
        output_jsonl: Path to a file holding one JSON-encoded resolver
            output per line.

    Yields:
        One validated ``ResolverOutput`` per non-blank line, in file order.
    """
    with open(output_jsonl, "r") as f:
        for line in f:
            # A stray blank line would make json.loads raise; skip it rather
            # than abort the whole load.
            if not line.strip():
                continue
            yield ResolverOutput.model_validate(json.loads(line))


def load_single_resolver_output(output_jsonl: str, issue_number: int) -> ResolverOutput:
    """Return the resolver output recorded for one specific issue.

    Args:
        output_jsonl: Path to the JSON Lines file of resolver outputs.
        issue_number: Number of the issue whose output is wanted.

    Returns:
        The first resolver output whose ``issue.number`` equals
        ``issue_number``.

    Raises:
        ValueError: If no output in the file matches ``issue_number``.
    """
    matching = (
        candidate
        for candidate in load_all_resolver_outputs(output_jsonl)
        if candidate.issue.number == issue_number
    )
    # Only the first hit is consumed, so reading stops at the first match.
    found = next(matching, None)
    if found is None:
        raise ValueError(f"Issue number {issue_number} not found in {output_jsonl}")
    return found
111 changes: 75 additions & 36 deletions github_resolver/send_pull_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
import os
import shutil
from github_resolver.github_issue import GithubIssue
from github_resolver.io_utils import load_resolver_output
from github_resolver.io_utils import (
load_all_resolver_outputs,
load_single_resolver_output,
)
from github_resolver.patching import parse_patch, apply_diff
import requests
import subprocess

from github_resolver.resolver_output import ResolverOutput


def apply_patch(repo_dir: str, patch: str) -> None:
diffs = parse_patch(patch)
Expand All @@ -17,35 +22,35 @@ def apply_patch(repo_dir: str, patch: str) -> None:
continue

old_path = (
os.path.join(repo_dir, diff.header.old_path.lstrip('b/'))
if diff.header.old_path and diff.header.old_path != '/dev/null'
os.path.join(repo_dir, diff.header.old_path.lstrip("b/"))
if diff.header.old_path and diff.header.old_path != "/dev/null"
else None
)
new_path = os.path.join(repo_dir, diff.header.new_path.lstrip('b/'))
new_path = os.path.join(repo_dir, diff.header.new_path.lstrip("b/"))

if old_path:
# Open the file in binary mode to detect line endings
with open(old_path, 'rb') as f:
with open(old_path, "rb") as f:
original_content = f.read()

# Detect line endings
if b'\r\n' in original_content:
newline = '\r\n'
elif b'\n' in original_content:
newline = '\n'
if b"\r\n" in original_content:
newline = "\r\n"
elif b"\n" in original_content:
newline = "\n"
else:
newline = None # Let Python decide

with open(old_path, 'r', newline=newline) as f:
with open(old_path, "r", newline=newline) as f:
split_content = [x.strip(newline) for x in f.readlines()]
else:
newline = '\n'
newline = "\n"
split_content = []

new_content = apply_diff(diff, split_content)

# Write the new content using the detected line endings
with open(new_path, 'w', newline=newline) as f:
with open(new_path, "w", newline=newline) as f:
for line in new_content:
print(line, file=f)

Expand Down Expand Up @@ -108,7 +113,6 @@ def send_pull_request(
pr_type: str,
fork_owner: str | None = None,
) -> str:

if pr_type not in ["branch", "draft", "ready"]:
raise ValueError(f"Invalid pr_type: {pr_type}")

Expand Down Expand Up @@ -154,7 +158,6 @@ def send_pull_request(
print(f"Error pushing changes\n{push_command}\n{result.stderr}")
raise RuntimeError("Failed to push changes to the remote repository")


pr_title = f"Fix issue #{github_issue.number}: {github_issue.title}"
pr_body = (
f"This pull request fixes #{github_issue.number}."
Expand Down Expand Up @@ -182,21 +185,25 @@ def send_pull_request(
response.raise_for_status()
pr_data = response.json()

url = pr_data['html_url']
url = pr_data["html_url"]

print(f"{pr_type} created: {url}\n\n--- Title: {pr_title}\n\n--- Body:\n{pr_body}")

return url


def process_single_issue(output_dir: str, issue_number: int, github_token: str, github_username: str, pr_type: str, fork_owner: str | None) -> None:
resolver_output = load_resolver_output(
os.path.join(output_dir, "output.jsonl"),
issue_number,
)

if not resolver_output.issue.resolution_successful:
print(f"Issue {issue_number} was not successfully resolved. Skipping PR creation.")
def process_single_issue(
output_dir: str,
resolver_output: ResolverOutput,
github_token: str,
github_username: str,
pr_type: str,
fork_owner: str | None,
) -> None:
if not resolver_output.issue.success:
print(
f"Issue {issue_number} was not successfully resolved. Skipping PR creation."
)
return

patched_repo_dir = initialize_repo(
Expand All @@ -216,16 +223,29 @@ def process_single_issue(output_dir: str, issue_number: int, github_token: str,
fork_owner=fork_owner,
)

def process_all_successful_issues(
    output_dir: str,
    github_token: str,
    github_username: str,
    pr_type: str,
    fork_owner: str | None,
) -> None:
    """Run ``process_single_issue`` for every successfully resolved issue.

    Reads ``output.jsonl`` inside ``output_dir`` and hands each resolver
    output whose ``success`` flag is truthy to ``process_single_issue``.
    Outputs that were not successful are skipped silently.

    Args:
        output_dir: Directory containing ``output.jsonl``; also forwarded to
            ``process_single_issue``.
        github_token: Forwarded unchanged to ``process_single_issue``.
        github_username: Forwarded unchanged to ``process_single_issue``.
        pr_type: Forwarded unchanged to ``process_single_issue``.
        fork_owner: Forwarded unchanged to ``process_single_issue``; may be
            ``None``.
    """
    output_path = os.path.join(output_dir, "output.jsonl")
    for resolver_output in load_all_resolver_outputs(output_path):
        if not resolver_output.success:
            continue
        print(f"Processing issue {resolver_output.issue.number}")
        process_single_issue(
            output_dir,
            resolver_output,
            github_token,
            github_username,
            pr_type,
            fork_owner,
        )


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Send a pull request to Github.")
parser.add_argument(
"--github-token",
Expand Down Expand Up @@ -270,22 +290,41 @@ def process_all_successful_issues(output_dir: str, github_token: str, github_use
my_args.github_token if my_args.github_token else os.getenv("GITHUB_TOKEN")
)
if not github_token:
raise ValueError("Github token is not set, set via --github-token or GITHUB_TOKEN environment variable.")
raise ValueError(
"Github token is not set, set via --github-token or GITHUB_TOKEN environment variable."
)
github_username = (
my_args.github_username
if my_args.github_username
else os.getenv("GITHUB_USERNAME")
)
if not github_username:
raise ValueError("Github username is not set, set via --github-username or GITHUB_USERNAME environment variable.")
raise ValueError(
"Github username is not set, set via --github-username or GITHUB_USERNAME environment variable."
)

if not os.path.exists(my_args.output_dir):
raise ValueError(f"Output directory {my_args.output_dir} does not exist.")

if my_args.issue_number == 'all_successful':
process_all_successful_issues(my_args.output_dir, github_token, github_username, my_args.pr_type, my_args.fork_owner)
if my_args.issue_number == "all_successful":
process_all_successful_issues(
my_args.output_dir,
github_token,
github_username,
my_args.pr_type,
my_args.fork_owner,
)
else:
if not my_args.issue_number.isdigit():
raise ValueError(f"Issue number {my_args.issue_number} is not a number.")
issue_number = int(my_args.issue_number)
process_single_issue(my_args.output_dir, issue_number, github_token, github_username, my_args.pr_type, my_args.fork_owner)
output_path = os.path.join(my_args.output_dir, "output.jsonl")
resolver_output = load_single_resolver_output(output_path, issue_number)
process_single_issue(
my_args.output_dir,
resolver_output,
github_token,
github_username,
my_args.pr_type,
my_args.fork_owner,
)
4 changes: 2 additions & 2 deletions github_resolver/visualize_resolver_output.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import argparse
import os
from github_resolver.io_utils import load_resolver_output
from github_resolver.io_utils import load_single_resolver_output


def visualize_resolver_output(issue_number: int, output_dir: str, vis_method: str):
output_jsonl = os.path.join(output_dir, "output.jsonl")
resolver_output = load_resolver_output(output_jsonl, issue_number)
resolver_output = load_single_resolver_output(output_jsonl, issue_number)
if vis_method == "json":
print(resolver_output.model_dump_json(indent=4))
else:
Expand Down
73 changes: 46 additions & 27 deletions tests/test_send_pull_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

from github_resolver.send_pull_request import (
apply_patch,
load_resolver_output,
load_single_resolver_output,
initialize_repo,
send_pull_request,
process_single_issue,
process_all_successful_issues,
)
from github_resolver.resolver_output import ResolverOutput, GithubIssue
Expand Down Expand Up @@ -40,11 +39,11 @@ def mock_github_issue():
)


def test_load_resolver_output():
def test_load_single_resolver_output():
mock_output_jsonl = 'tests/mock_output/output.jsonl'

# Test loading an existing issue
resolver_output = load_resolver_output(mock_output_jsonl, 5)
resolver_output = load_single_resolver_output(mock_output_jsonl, 5)
assert isinstance(resolver_output, ResolverOutput)
assert resolver_output.issue.number == 5
assert resolver_output.issue.title == "Add MIT license"
Expand All @@ -53,7 +52,7 @@ def test_load_resolver_output():

# Test loading a non-existent issue
with pytest.raises(ValueError):
load_resolver_output(mock_output_jsonl, 999)
load_single_resolver_output(mock_output_jsonl, 999)


def test_apply_patch(mock_output_dir):
Expand Down Expand Up @@ -325,30 +324,50 @@ def test_send_pull_request_permission_error(
mock_post.assert_called_once()


@patch('github_resolver.send_pull_request.load_resolver_output')
@patch('github_resolver.send_pull_request.load_all_resolver_outputs')
@patch('github_resolver.send_pull_request.process_single_issue')
def test_process_all_successful_issues(mock_process_single_issue, mock_load_resolver_output):
# Mock the load_resolver_output function to return different ResolverOutput objects
mock_all_issues = MagicMock()
mock_all_issues.issues = {1: MagicMock(), 2: MagicMock(), 3: MagicMock()}

mock_resolver_output_1 = MagicMock()
mock_resolver_output_1.resolution_successful = True
mock_resolver_output_1.issue.number = 1
def test_process_all_successful_issues(mock_process_single_issue, mock_load_all_resolver_outputs):
# Create ResolverOutput objects with properly initialized GithubIssue instances
resolver_output_1 = ResolverOutput(
issue=GithubIssue(owner="test-owner", repo="test-repo", number=1, title="Issue 1", body="Body 1"),
instruction="Test instruction 1",
base_commit="def456",
git_patch="Test patch 1",
history=[],
metrics={},
success=True,
success_explanation="Test success 1",
error=None
)

mock_resolver_output_2 = MagicMock()
mock_resolver_output_2.resolution_successful = False
mock_resolver_output_2.issue.number = 2
resolver_output_2 = ResolverOutput(
issue=GithubIssue(owner="test-owner", repo="test-repo", number=2, title="Issue 2", body="Body 2"),
instruction="Test instruction 2",
base_commit="ghi789",
git_patch="Test patch 2",
history=[],
metrics={},
success=False,
success_explanation="",
error="Test error 2"
)

mock_resolver_output_3 = MagicMock()
mock_resolver_output_3.resolution_successful = True
mock_resolver_output_3.issue.number = 3
resolver_output_3 = ResolverOutput(
issue=GithubIssue(owner="test-owner", repo="test-repo", number=3, title="Issue 3", body="Body 3"),
instruction="Test instruction 3",
base_commit="jkl012",
git_patch="Test patch 3",
history=[],
metrics={},
success=True,
success_explanation="Test success 3",
error=None
)

mock_load_resolver_output.side_effect = [
mock_all_issues,
mock_resolver_output_1,
mock_resolver_output_2,
mock_resolver_output_3
mock_load_all_resolver_outputs.return_value = [
resolver_output_1,
resolver_output_2,
resolver_output_3
]

# Call the function
Expand All @@ -359,8 +378,8 @@ def test_process_all_successful_issues(mock_process_single_issue, mock_load_reso

# Check that the function was called with the correct arguments for successful issues
mock_process_single_issue.assert_has_calls([
call("output_dir", 1, "github_token", "github_username", "draft", None),
call("output_dir", 3, "github_token", "github_username", "draft", None)
call("output_dir", resolver_output_1, "github_token", "github_username", "draft", None),
call("output_dir", resolver_output_3, "github_token", "github_username", "draft", None)
])

# Add more assertions as needed to verify the behavior of the function

0 comments on commit 87469cd

Please sign in to comment.