
Commit 28ead2c
expand support to Claude, Cohere, Llama2 and Azure
krrishdholakia authored Aug 4, 2023
1 parent 776aac1 commit 28ead2c
Showing 2 changed files with 11 additions and 8 deletions.
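
The substance of the change is a switch from calling openai.ChatCompletion.create directly to litellm's completion(), which keeps the OpenAI-style request and response shape while routing to other providers based on the model name. A minimal sketch of that idea, assuming litellm's documented provider key conventions (the model identifiers and keys below are illustrative, not taken from this commit):

import os
from litellm import completion  # same call shape as openai.ChatCompletion.create

messages = [{"role": "user", "content": "Hello!"}]

# OpenAI model: litellm reads OPENAI_API_KEY (or a key set via litellm.openai_key)
response = completion(model="gpt-3.5-turbo", messages=messages)

# Anthropic Claude: litellm reads ANTHROPIC_API_KEY
# response = completion(model="claude-2", messages=messages)

# Cohere: litellm reads COHERE_API_KEY
# response = completion(model="command-nightly", messages=messages)

# Responses keep the OpenAI schema regardless of provider
print(response["choices"][0]["message"]["content"])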
16 changes: 9 additions & 7 deletions pentestgpt/utils/APIs/chatgpt_api.py
@@ -6,7 +6,8 @@
 from typing import Any, Dict, List, Tuple
 from tenacity import *
 from pentestgpt.utils.llm_api import LLMAPI

+import litellm
+from litellm import completion
 import loguru
 import openai, tiktoken

@@ -44,7 +45,8 @@ def __eq__(self, other):
 class ChatGPTAPI(LLMAPI):
     def __init__(self, config_class):
         self.name = str(config_class.model)
-        openai.api_key = os.getenv("OPENAI_KEY", None)
+        litellm.openai_key = os.getenv("OPENAI_KEY", None)  # set a key for openai
+        litellm.api_base = config_class.api_base
         openai.api_base = config_class.api_base
         self.model = config_class.model
         self.log_dir = config_class.log_dir
@@ -61,7 +63,7 @@ def _chat_completion(
             model = "gpt-4"
         # otherwise, just use the default model (because it is cheaper lol)
         try:
-            response = openai.ChatCompletion.create(
+            response = completion(
                 model=model,
                 messages=history,
                 temperature=temperature,
@@ -74,7 +76,7 @@
             )
             logger.log("Connection Error: ", e)
             time.sleep(self.error_wait_time)
-            response = openai.ChatCompletion.create(
+            response = completion(
                 model=model,
                 messages=history,
                 temperature=temperature,
@@ -87,7 +89,7 @@
             )
             logger.error("Rate Limit Error: ", e)
             time.sleep(self.config.error_wait_time)
-            response = openai.ChatCompletion.create(
+            response = completion(
                 model=model,
                 messages=history,
                 temperature=temperature,
@@ -103,7 +105,7 @@
             self.history_length -= 1
             ## update the history
             history = history[-self.history_length :]
-            response = openai.ChatCompletion.create(
+            response = completion(
                 model=model,
                 messages=history,
                 temperature=temperature,
@@ -114,7 +116,7 @@
         logger.warning("Response is not valid. Waiting for 5 seconds")
         try:
             time.sleep(5)
-            response = openai.ChatCompletion.create(
+            response = completion(
                 model=model,
                 messages=history,
                 temperature=temperature,
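Every hunk above makes the same one-line substitution inside the retry logic of _chat_completion, so the existing error handling is unchanged. A sketch of the pattern the hunks imply, assuming the openai-0.x exception classes this file already imports (litellm's own exception mapping may differ by version):

import time
import openai
from litellm import completion

def chat_once(history, model="gpt-4", temperature=0.5, error_wait_time=5):
    # Retry once after a connection or rate-limit error, mirroring the hunks above.
    try:
        return completion(model=model, messages=history, temperature=temperature)
    except (openai.error.APIConnectionError, openai.error.RateLimitError):
        time.sleep(error_wait_time)
        return completion(model=model, messages=history, temperature=temperature)
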
3 changes: 2 additions & 1 deletion requirements.txt
@@ -15,4 +15,5 @@ langchain
 tiktoken
 pycookiecheat
 tenacity
-gpt4all
+gpt4all
+litellm
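
With litellm added to requirements.txt, enabling one of the new providers should only require its API key in the environment; the call sites above stay the same. For example (key names are litellm's documented environment variables; the values are placeholders):

import os

os.environ["OPENAI_KEY"] = "sk-..."             # read by ChatGPTAPI via os.getenv("OPENAI_KEY")
os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # Claude models through litellm
os.environ["COHERE_API_KEY"] = "..."            # Cohere models through litellm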
