From 28ead2c1fef1988f6bcf073df7d78859af7064b6 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Fri, 4 Aug 2023 10:25:57 -0700 Subject: [PATCH] expand support to Claude, Cohere, Llama2 and Azure --- pentestgpt/utils/APIs/chatgpt_api.py | 16 +++++++++------- requirements.txt | 3 ++- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/pentestgpt/utils/APIs/chatgpt_api.py b/pentestgpt/utils/APIs/chatgpt_api.py index e6a0699..1eaeaf4 100644 --- a/pentestgpt/utils/APIs/chatgpt_api.py +++ b/pentestgpt/utils/APIs/chatgpt_api.py @@ -6,7 +6,8 @@ from typing import Any, Dict, List, Tuple from tenacity import * from pentestgpt.utils.llm_api import LLMAPI - +import litellm +from litellm import completion import loguru import openai, tiktoken @@ -44,7 +45,8 @@ def __eq__(self, other): class ChatGPTAPI(LLMAPI): def __init__(self, config_class): self.name = str(config_class.model) - openai.api_key = os.getenv("OPENAI_KEY", None) + litellm.openai_key = os.getenv("OPENAI_KEY", None) # set a key for openai + litellm.api_base = config_class.api_base openai.api_base = config_class.api_base self.model = config_class.model self.log_dir = config_class.log_dir @@ -61,7 +63,7 @@ def _chat_completion( model = "gpt-4" # otherwise, just use the default model (because it is cheaper lol) try: - response = openai.ChatCompletion.create( + response = completion( model=model, messages=history, temperature=temperature, @@ -74,7 +76,7 @@ def _chat_completion( ) logger.log("Connection Error: ", e) time.sleep(self.error_wait_time) - response = openai.ChatCompletion.create( + response = completion( model=model, messages=history, temperature=temperature, @@ -87,7 +89,7 @@ def _chat_completion( ) logger.error("Rate Limit Error: ", e) time.sleep(self.config.error_wait_time) - response = openai.ChatCompletion.create( + response = completion( model=model, messages=history, temperature=temperature, @@ -103,7 +105,7 @@ def _chat_completion( self.history_length -= 1 ## update the history history = history[-self.history_length :] - response = openai.ChatCompletion.create( + response = completion( model=model, messages=history, temperature=temperature, @@ -114,7 +116,7 @@ def _chat_completion( logger.warning("Response is not valid. Waiting for 5 seconds") try: time.sleep(5) - response = openai.ChatCompletion.create( + response = completion( model=model, messages=history, temperature=temperature, diff --git a/requirements.txt b/requirements.txt index 19b13bb..0291135 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,4 +15,5 @@ langchain tiktoken pycookiecheat tenacity -gpt4all \ No newline at end of file +gpt4all +litellm \ No newline at end of file