Skip to content

Commit

Permalink
fixed bug: function call is added repeatedly because the json request…
Browse files Browse the repository at this point in the history
… body update does not use deep copy
  • Loading branch information
yym68686 committed Dec 17, 2023
1 parent ebdc93a commit ce26978
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 42 deletions.
59 changes: 32 additions & 27 deletions test/test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import json
# my_list = [
# {"role": "admin", "content": "This is admin content."},
# {"role": "user", "content": "This is user content."}
# ]
a = {"role": "admin"}
b = {"content": "This is user content."}
b = {"content": []}
c = {"role": "admin"}
a.update(b)
# print(a)
a["content"].append(c)
print(b)
print(a)
# print(json.dumps(str(a), indent=4))

# content_list = [item["content"] for item in my_list]
# print(content_list)
Expand Down Expand Up @@ -33,28 +38,28 @@

# print(json.dumps(function_call_list["web_search"], indent=4))

class openaiAPI:
def __init__(
self,
api_url: str = (os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions"),
):
from urllib.parse import urlparse, urlunparse
self.source_api_url: str = api_url
parsed_url = urlparse(self.source_api_url)
self.base_url: str = urlunparse(parsed_url[:2] + ("",) * 4)
self.v1_url: str = urlunparse(parsed_url[:2] + ("/v1",) + ("",) * 3)
self.chat_url: str = urlunparse(parsed_url[:2] + ("/v1/chat/completions",) + ("",) * 3)
self.image_url: str = urlunparse(parsed_url[:2] + ("/v1/images/generations",) + ("",) * 3)


a = openaiAPI()
print(a.v1_url)

def getddgsearchurl(result, numresults=3):
# print("ddg-search", result)
search = DuckDuckGoSearchResults(num_results=numresults)
webresult = search.run(result)
# print("ddgwebresult", webresult)
urls = re.findall(r"(https?://\S+)\]", webresult, re.MULTILINE)
# print("duckduckgo urls", urls)
return urls
# class openaiAPI:
# def __init__(
# self,
# api_url: str = (os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions"),
# ):
# from urllib.parse import urlparse, urlunparse
# self.source_api_url: str = api_url
# parsed_url = urlparse(self.source_api_url)
# self.base_url: str = urlunparse(parsed_url[:2] + ("",) * 4)
# self.v1_url: str = urlunparse(parsed_url[:2] + ("/v1",) + ("",) * 3)
# self.chat_url: str = urlunparse(parsed_url[:2] + ("/v1/chat/completions",) + ("",) * 3)
# self.image_url: str = urlunparse(parsed_url[:2] + ("/v1/images/generations",) + ("",) * 3)


# a = openaiAPI()
# print(a.v1_url)

# def getddgsearchurl(result, numresults=3):
# # print("ddg-search", result)
# search = DuckDuckGoSearchResults(num_results=numresults)
# webresult = search.run(result)
# # print("ddgwebresult", webresult)
# urls = re.findall(r"(https?://\S+)\]", webresult, re.MULTILINE)
# # print("duckduckgo urls", urls)
# return urls
31 changes: 16 additions & 15 deletions utils/chatgpt2api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re
import json
import copy
from pathlib import Path
from typing import AsyncGenerator

Expand Down Expand Up @@ -411,13 +412,13 @@ def truncate_conversation(
break
return json_post, message_token

def clear_function_call(self, convo_id: str = "default"):
self.conversation[convo_id] = [item for item in self.conversation[convo_id] if '@Trash@' not in item['content']]
function_call_items = [item for item in self.conversation[convo_id] if 'function' in item['role']]
function_call_num = len(function_call_items)
if function_call_num > 50:
for i in range(function_call_num - 25):
self.conversation[convo_id].remove(function_call_items[i])
# def clear_function_call(self, convo_id: str = "default"):
# self.conversation[convo_id] = [item for item in self.conversation[convo_id] if '@Trash@' not in item['content']]
# function_call_items = [item for item in self.conversation[convo_id] if 'function' in item['role']]
# function_call_num = len(function_call_items)
# if function_call_num > 50:
# for i in range(function_call_num - 25):
# self.conversation[convo_id].remove(function_call_items[i])

# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def get_token_count(self, convo_id: str = "default") -> int:
Expand Down Expand Up @@ -485,7 +486,7 @@ def get_post_body(
pass_history: bool = True,
**kwargs,
):
json_post = {
json_post_body = {
"model": os.environ.get("MODEL_NAME") or model or self.engine,
"messages": self.conversation[convo_id] if pass_history else [{"role": "system","content": self.system_prompt},{"role": role, "content": prompt}],
"stream": True,
Expand All @@ -504,12 +505,12 @@ def get_post_body(
"user": role,
"max_tokens": 5000,
}
json_post.update(function_call_list["base"])
json_post_body.update(copy.deepcopy(function_call_list["base"]))
if config.SEARCH_USE_GPT:
json_post["functions"].append(function_call_list["web_search"])
json_post["functions"].append(function_call_list["url_fetch"])
json_post_body["functions"].append(function_call_list["web_search"])
json_post_body["functions"].append(function_call_list["url_fetch"])

return json_post
return json_post_body

def get_max_tokens(self, convo_id: str) -> int:
"""
Expand All @@ -536,8 +537,8 @@ def ask_stream(
self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name)
json_post, message_token = self.truncate_conversation(prompt, role, convo_id, model, pass_history, **kwargs)
print(json_post)
print(self.conversation[convo_id])
print(json.dumps(json_post, indent=4))
# print(self.conversation[convo_id])

if self.engine == "gpt-4-1106-preview" or self.engine == "gpt-3.5-turbo-1106":
model_max_tokens = kwargs.get("max_tokens", self.max_tokens)
Expand Down Expand Up @@ -642,7 +643,7 @@ def ask_stream(
else:
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
self.function_calls_counter = {}
self.clear_function_call(convo_id=convo_id)
# self.clear_function_call(convo_id=convo_id)
self.encode_web_text_list = []
# total_tokens = self.get_token_count(convo_id)

Expand Down

0 comments on commit ce26978

Please sign in to comment.