From d0a471ec3779bee17818acf14b50a25a742d6715 Mon Sep 17 00:00:00 2001 From: Hiroshi Nishio <4620828+hiroshinishio@users.noreply.github.com> Date: Tue, 20 Feb 2024 12:43:26 -0800 Subject: [PATCH] Improved types and unified and replaced jwt with PyJWT --- main.py | 2 +- requirements.txt | 5 +- services/github/github_manager.py | 44 ++++++----- services/github/webhook_handler.py | 105 ++++++++++++++------------ services/supabase/supabase_manager.py | 38 +++++----- 5 files changed, 107 insertions(+), 87 deletions(-) diff --git a/main.py b/main.py index 1d1b4d51..3e27c52e 100644 --- a/main.py +++ b/main.py @@ -6,9 +6,9 @@ from mangum import Mangum # Local imports +from config import GITHUB_APP_ID, GITHUB_PRIVATE_KEY, GITHUB_WEBHOOK_SECRET from services.github.github_manager import GitHubManager from services.github.webhook_handler import handle_webhook_event -from config import GITHUB_APP_ID, GITHUB_PRIVATE_KEY, GITHUB_WEBHOOK_SECRET # Create FastAPI instance app = FastAPI() diff --git a/requirements.txt b/requirements.txt index 11f49646..349fba5f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ anyio==4.2.0 attrs==23.2.0 backoff==2.2.1 beautifulsoup4==4.12.3 +black==24.2.0 build==1.0.3 certifi==2024.2.2 cffi==1.16.0 @@ -26,11 +27,11 @@ httpx==0.25.2 idna==3.6 jsonschema==4.21.1 jsonschema-specifications==2023.12.1 -jwt==1.3.1 lxml==4.9.3 mangum==0.17.0 markdown-it-py==3.0.0 mdurl==0.1.2 +mypy-extensions==1.0.0 networkx==3.2.1 numpy==1.26.4 openai==1.12.0 @@ -38,6 +39,7 @@ packaging==23.2 pathspec==0.12.1 Pillow==9.5.0 pip-tools==7.3.0 +platformdirs==4.2.0 playwright==1.41.2 postgrest==0.15.0 prompt-toolkit==3.0.43 @@ -46,6 +48,7 @@ pydantic==2.6.1 pydantic_core==2.16.2 pyee==11.0.1 Pygments==2.17.2 +PyJWT==2.8.0 pypandoc==1.12 pyproject_hooks==1.0.0 python-dateutil==2.8.2 diff --git a/services/github/github_manager.py b/services/github/github_manager.py index 25f67088..34e2dfb7 100644 --- a/services/github/github_manager.py +++ b/services/github/github_manager.py @@ -1,56 +1,64 @@ +# Standard imports import hashlib # For HMAC (Hash-based Message Authentication Code) signatures import hmac # For HMAC (Hash-based Message Authentication Code) signatures -import jwt # For generating JWTs (JSON Web Tokens) import logging import requests import time +# Third-party imports +from fastapi import Request +import jwt # For generating JWTs (JSON Web Tokens) + class GitHubManager: # Constructor to initialize the GitHub App ID and private key to this instance - def __init__(self, app_id, private_key): - self.app_id = app_id - self.private_key = private_key + def __init__(self, app_id: str, private_key: bytes) -> None: + self.app_id: str = app_id + self.private_key: bytes = private_key # Generate a JWT (JSON Web Token) for GitHub App authentication - def create_jwt(self): + def create_jwt(self) -> str: now = int(time.time()) - payload = { + payload: dict[str, int | str] = { "iat": now, # Issued at time "exp": now + 600, # JWT expires in 10 minutes "iss": self.app_id, # Issuer } # The reason we use RS256 is that GitHub requires it for JWTs - return jwt.encode(payload, self.private_key, algorithm="RS256") + return jwt.encode(payload=payload, key=self.private_key, algorithm="RS256") # Verify the webhook signature for security - async def verify_webhook_signature(self, request, secret): - signature = request.headers.get("X-Hub-Signature-256") - body = await request.body() + async def verify_webhook_signature(self, request: Request, secret: str) -> None: + signature: str | None = request.headers.get("X-Hub-Signature-256") + if signature is None: + raise ValueError("Missing webhook signature") + body: bytes = await request.body() # Compare the computed signature with the one in the headers - expected_signature = "sha256=" + hmac.new(secret.encode(), body, hashlib.sha256).hexdigest() + hmac_key: bytes = secret.encode() + hmac_signature: str = hmac.new(key=hmac_key, msg=body, digestmod=hashlib.sha256).hexdigest() + expected_signature: str = "sha256=" + hmac_signature if not hmac.compare_digest(signature, expected_signature): raise ValueError("Invalid webhook signature") # Get an access token for the installed GitHub App - def get_installation_access_token(self, installation_id): + def get_installation_access_token(self, installation_id: int) -> tuple[str, str]: try: - jwt_token = self.create_jwt() - headers = { + jwt_token: str = self.create_jwt() + headers: dict[str, str] = { "Authorization": f"Bearer {jwt_token}", "Accept": "application/vnd.github.v3+json", "X-GitHub-Api-Version": "2022-11-28" } - url = f"https://api.github.com/app/installations/{installation_id}/access_tokens" + url: str = f"https://api.github.com/app/installations/{installation_id}/access_tokens" - response = requests.post(url, headers=headers) + response = requests.post(url=url, headers=headers) response.raise_for_status() # Raises HTTPError for bad responses json = response.json() return json["token"], json["expires_at"] except requests.exceptions.HTTPError as e: - logging.error(f"HTTP Error: {e.response.status_code} - {e.response.text}") + logging.error(msg=f"HTTP Error: {e.response.status_code} - {e.response.text}") raise except Exception as e: - logging.error(f"Error: {e}") + logging.error(msg=f"Error: {e}") raise diff --git a/services/github/webhook_handler.py b/services/github/webhook_handler.py index 998ee5f4..d6d395a2 100644 --- a/services/github/webhook_handler.py +++ b/services/github/webhook_handler.py @@ -4,6 +4,7 @@ import sys import time import uuid +from pathlib import Path # Third-party imports import git @@ -24,21 +25,21 @@ supabase_manager = InstallationTokenManager(url=SUPABASE_URL, key=SUPABASE_SERVICE_ROLE_KEY) -async def handle_installation_created(payload): - installation_id = payload["installation"]["id"] - account_login = payload["installation"]["account"]["login"] - html_url = payload["installation"]["account"]["html_url"] - action = payload.get("action") +async def handle_installation_created(payload: GitHubInstallationPayload) -> None: + installation_id: int = payload["installation"]["id"] + account_login: str = payload["installation"]["account"]["login"] + html_url: str = payload["installation"]["account"]["html_url"] + action: str = payload.get("action") repositories = [] repository_ids = [] if action == 'created': - repositories = [obj.get('full_name') for obj in payload["repositories"]] - repository_ids = [obj.get('id') for obj in payload["repositories"]] + repositories: list[str] = [obj.get('full_name') for obj in payload["repositories"]] + repository_ids: list[int] = [obj.get('id') for obj in payload["repositories"]] if action == 'added': repositories = [obj.get('full_name') for obj in payload["repositories_added"]] repository_ids = [obj.get('id') for obj in payload["repositories_added"]] - supabase_manager.save_installation_token(installation_id, account_login, html_url, repositories, repository_ids) + supabase_manager.save_installation_token(installation_id=installation_id, account_login=account_login, html_url=html_url, repositories=repositories, repository_ids=repository_ids) async def handle_installation_deleted(payload: GitHubInstallationPayload) -> None: @@ -47,43 +48,45 @@ async def handle_installation_deleted(payload: GitHubInstallationPayload) -> Non # Handle the issue labeled event -async def handle_issue_labeled(payload): - label = payload["label"]["name"] +async def handle_issue_labeled(payload: GitHubLabeledPayload): + # Extract label and validate it + label: str = payload["label"]["name"] if label != LABEL: return - issue = payload["issue"] - url = issue["html_url"] - repository_id = payload["repository"]["id"] - installation_id = supabase_manager.get_installation_id(repository_id) - print("Installation ID: ", installation_id) + # Extract issue and repository information + issue: IssueInfo = payload["issue"] + # url: str = issue["html_url"] + repository_id: int = payload["repository"]["id"] - with open('privateKey.pem', 'rb') as pem_file: - signing_key = pem_file.read() + # Retrieve the installation ID from Supabase + installation_id: str = supabase_manager.get_installation_id(repository_id=repository_id) - new_uuid = uuid.uuid4() - print("UUID: ", new_uuid) + # Read the private key for JWT + # https://docs.github.com/en/apps/creating-github-apps/authenticating-with-a-github-app/generating-a-json-web-token-jwt-for-a-github-app + with open(file='privateKey.pem', mode='rb') as pem_file: + signing_key: bytes = pem_file.read() + + # Create a JWT token for authentication + now = int(time.time()) payload = { - 'iat': int(time.time()), - 'exp': int(time.time()) + 600, + 'iat': now, + 'exp': now + 600, # JWT expires in 10 minutes 'iss': GITHUB_APP_ID } - - jwt_instance = JWT() - encoded_jwt = jwt_instance.encode(payload, jwk_from_pem(signing_key), alg='RS256') - - print(f"JWT: {encoded_jwt}") - - headers = { + encoded_jwt: str = jwt.encode(payload=payload, key=signing_key, algorithm='RS256') + headers: dict[str, str] = { "Authorization": f"Bearer {encoded_jwt}", "Content-Type": "application/json" - } + } - response = requests.post(f'https://api.github.com/app/installations/{installation_id}/access_tokens', headers=headers) - token = response.json().get('token') + response = requests.post(url=f'https://api.github.com/app/installations/{installation_id}/access_tokens', headers=headers) + token: str = response.json().get('token') - git.Repo.clone_from(f'https://x-access-token:{token}@github.com/nikitamalinov/lalager', f'./tmp/{new_uuid}') - + new_uuid = uuid.uuid4() + git.Repo.clone_from(url=f'https://x-access-token:{token}@github.com/nikitamalinov/lalager', to_path=f'./tmp/{new_uuid}') + + # Initialize the OpenAI API io = InputOutput( pretty=True, yes=True, @@ -96,9 +99,12 @@ async def handle_issue_labeled(payload): tool_error_color="red", encoding="utf-8", dry_run=False, - ) + ) + + # Print the tool output io.tool_output(*sys.argv, log_only=True) + git_dname = str(Path.cwd() / f'tmp/{new_uuid}') openai_api_key = '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' @@ -108,6 +114,7 @@ async def handle_issue_labeled(payload): main_model = Model.create('gpt-4-1106-preview', client) + # Create a new coder instance try: coder = Coder.create( main_model=main_model, @@ -134,40 +141,42 @@ async def handle_issue_labeled(payload): except ValueError as err: print(err) return 1 - io.tool_output("Use /help to see in-chat commands, run with --help to see cmd line args") - io.add_to_input_history("add header with tag 'Hello World' to homepage") + # Run the coder + io.tool_output("Use /help to see in-chat commands, run with --help to see cmd line args") + io.add_to_input_history(inp="add header with tag 'Hello World' to homepage") io.tool_output() coder.run(with_message="add header with tag 'Hello World' to homepage") - repo_path = Path.cwd() / f'tmp/{new_uuid}' - original_path = os.getcwd() - os.chdir(repo_path) + # Create a new branch and push to it + repo_path: Path = Path.cwd() / f'tmp/{new_uuid}' # cwd stands for current working directory + original_path: str = os.getcwd() + os.chdir(path=repo_path) - str_uuid = str(new_uuid) + str_uuid = str(object=new_uuid) # Create a new branch and push to it - repo = git.Repo(repo_path) - branch = str_uuid - repo.create_head(branch) + repo = git.Repo(path=repo_path) + branch: str = str_uuid + repo.create_head(path=branch) repo.git.push('origin', branch) # Push to branch to create PR - remote_url = repo.remotes.origin.url - repo_name = remote_url.split('/')[-1].replace('.git', '') - repo_owner = remote_url.split('/')[-2] + remote_url: str = repo.remotes.origin.url + repo_name: str = remote_url.split(sep='/')[-1].replace('.git', '') + repo_owner: str = remote_url.split(sep='/')[-2] url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/pulls" headers = { "Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}", "X-GitHub-Api-Version": "2022-11-28", } - data = { + data: dict[str, str] = { "title": issue['title'], "body": "World", "head": f"nikitamalinov:{str_uuid}", "base": 'main', } - response = requests.post(url, headers=headers, json=data) + response = requests.post(url=url, headers=headers, json=data) os.chdir(original_path) diff --git a/services/supabase/supabase_manager.py b/services/supabase/supabase_manager.py index 7f8f8dc7..0941207c 100644 --- a/services/supabase/supabase_manager.py +++ b/services/supabase/supabase_manager.py @@ -5,28 +5,28 @@ # Manager class to handle installation tokens class InstallationTokenManager: # Initialize Supabase client when the manager is created - def __init__(self, url, key): - self.client: Client = create_client(url, key) + def __init__(self, url: str, key: str) -> None: + self.client: Client = create_client(supabase_url=url, supabase_key=key) - def save_installation_token(self, installation_id, account_login, html_url, repositories, repository_ids): - data, _ = self.client.table("repo_info").select("*").eq("installation_id", installation_id).execute() + def save_installation_token(self, installation_id, account_login: str, html_url: str, repositories, repository_ids) -> None: + data, _ = self.client.table(table_name="repo_info").select("*").eq(column="installation_id", value=installation_id).execute() if (len(data[1]) > 0): - self.client.table("repo_info").update({ - "installation_id": installation_id, - "login": account_login, - 'html_url': html_url, - "repositories": repositories, - "repository_ids": repository_ids, - "deleted_at": None, - }).eq("installation_id", installation_id).execute() + self.client.table(table_name="repo_info").update(json={ + "installation_id": installation_id, + "login": account_login, + 'html_url': html_url, + "repositories": repositories, + "repository_ids": repository_ids, + "deleted_at": None, + }).eq(column="installation_id", value=installation_id).execute() else: - self.client.table("repo_info").insert({ - "installation_id": installation_id, - "login": account_login, - 'html_url': html_url, - "repositories": repositories, - "repository_ids": repository_ids, - }).execute() + self.client.table(table_name="repo_info").insert(json={ + "installation_id": installation_id, + "login": account_login, + 'html_url': html_url, + "repositories": repositories, + "repository_ids": repository_ids, + }).execute() def get_installation_id(self, repository_id: int) -> str: data, _ = self.client.table(table_name="repo_info").select("installation_id").contains(column='repository_ids', value=[str(object=repository_id)]).execute()