From ce26978c22e62920a6b09c7192f77f9c2b87ac2b Mon Sep 17 00:00:00 2001 From: yym68686 Date: Sun, 17 Dec 2023 16:46:19 +0800 Subject: [PATCH] fixed bug: function call is added repeatedly because the json request body update does not use deep copy --- test/test.py | 59 ++++++++++++++++++++++++-------------------- utils/chatgpt2api.py | 31 ++++++++++++----------- 2 files changed, 48 insertions(+), 42 deletions(-) diff --git a/test/test.py b/test/test.py index 9071bf12..d021e5bf 100644 --- a/test/test.py +++ b/test/test.py @@ -1,11 +1,16 @@ +import json # my_list = [ # {"role": "admin", "content": "This is admin content."}, # {"role": "user", "content": "This is user content."} # ] a = {"role": "admin"} -b = {"content": "This is user content."} +b = {"content": []} +c = {"role": "admin"} a.update(b) -# print(a) +a["content"].append(c) +print(b) +print(a) +# print(json.dumps(str(a), indent=4)) # content_list = [item["content"] for item in my_list] # print(content_list) @@ -33,28 +38,28 @@ # print(json.dumps(function_call_list["web_search"], indent=4)) -class openaiAPI: - def __init__( - self, - api_url: str = (os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions"), - ): - from urllib.parse import urlparse, urlunparse - self.source_api_url: str = api_url - parsed_url = urlparse(self.source_api_url) - self.base_url: str = urlunparse(parsed_url[:2] + ("",) * 4) - self.v1_url: str = urlunparse(parsed_url[:2] + ("/v1",) + ("",) * 3) - self.chat_url: str = urlunparse(parsed_url[:2] + ("/v1/chat/completions",) + ("",) * 3) - self.image_url: str = urlunparse(parsed_url[:2] + ("/v1/images/generations",) + ("",) * 3) - - -a = openaiAPI() -print(a.v1_url) - -def getddgsearchurl(result, numresults=3): - # print("ddg-search", result) - search = DuckDuckGoSearchResults(num_results=numresults) - webresult = search.run(result) - # print("ddgwebresult", webresult) - urls = re.findall(r"(https?://\S+)\]", webresult, re.MULTILINE) - # print("duckduckgo urls", urls) - return urls \ No newline at end of file +# class openaiAPI: +# def __init__( +# self, +# api_url: str = (os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions"), +# ): +# from urllib.parse import urlparse, urlunparse +# self.source_api_url: str = api_url +# parsed_url = urlparse(self.source_api_url) +# self.base_url: str = urlunparse(parsed_url[:2] + ("",) * 4) +# self.v1_url: str = urlunparse(parsed_url[:2] + ("/v1",) + ("",) * 3) +# self.chat_url: str = urlunparse(parsed_url[:2] + ("/v1/chat/completions",) + ("",) * 3) +# self.image_url: str = urlunparse(parsed_url[:2] + ("/v1/images/generations",) + ("",) * 3) + + +# a = openaiAPI() +# print(a.v1_url) + +# def getddgsearchurl(result, numresults=3): +# # print("ddg-search", result) +# search = DuckDuckGoSearchResults(num_results=numresults) +# webresult = search.run(result) +# # print("ddgwebresult", webresult) +# urls = re.findall(r"(https?://\S+)\]", webresult, re.MULTILINE) +# # print("duckduckgo urls", urls) +# return urls \ No newline at end of file diff --git a/utils/chatgpt2api.py b/utils/chatgpt2api.py index 1090e13f..8754e711 100644 --- a/utils/chatgpt2api.py +++ b/utils/chatgpt2api.py @@ -1,6 +1,7 @@ import os import re import json +import copy from pathlib import Path from typing import AsyncGenerator @@ -411,13 +412,13 @@ def truncate_conversation( break return json_post, message_token - def clear_function_call(self, convo_id: str = "default"): - self.conversation[convo_id] = [item for item in self.conversation[convo_id] if '@Trash@' not in item['content']] - function_call_items = [item for item in self.conversation[convo_id] if 'function' in item['role']] - function_call_num = len(function_call_items) - if function_call_num > 50: - for i in range(function_call_num - 25): - self.conversation[convo_id].remove(function_call_items[i]) + # def clear_function_call(self, convo_id: str = "default"): + # self.conversation[convo_id] = [item for item in self.conversation[convo_id] if '@Trash@' not in item['content']] + # function_call_items = [item for item in self.conversation[convo_id] if 'function' in item['role']] + # function_call_num = len(function_call_items) + # if function_call_num > 50: + # for i in range(function_call_num - 25): + # self.conversation[convo_id].remove(function_call_items[i]) # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb def get_token_count(self, convo_id: str = "default") -> int: @@ -485,7 +486,7 @@ def get_post_body( pass_history: bool = True, **kwargs, ): - json_post = { + json_post_body = { "model": os.environ.get("MODEL_NAME") or model or self.engine, "messages": self.conversation[convo_id] if pass_history else [{"role": "system","content": self.system_prompt},{"role": role, "content": prompt}], "stream": True, @@ -504,12 +505,12 @@ def get_post_body( "user": role, "max_tokens": 5000, } - json_post.update(function_call_list["base"]) + json_post_body.update(copy.deepcopy(function_call_list["base"])) if config.SEARCH_USE_GPT: - json_post["functions"].append(function_call_list["web_search"]) - json_post["functions"].append(function_call_list["url_fetch"]) + json_post_body["functions"].append(function_call_list["web_search"]) + json_post_body["functions"].append(function_call_list["url_fetch"]) - return json_post + return json_post_body def get_max_tokens(self, convo_id: str) -> int: """ @@ -536,8 +537,8 @@ def ask_stream( self.reset(convo_id=convo_id, system_prompt=self.system_prompt) self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name) json_post, message_token = self.truncate_conversation(prompt, role, convo_id, model, pass_history, **kwargs) - print(json_post) - print(self.conversation[convo_id]) + print(json.dumps(json_post, indent=4)) + # print(self.conversation[convo_id]) if self.engine == "gpt-4-1106-preview" or self.engine == "gpt-3.5-turbo-1106": model_max_tokens = kwargs.get("max_tokens", self.max_tokens) @@ -642,7 +643,7 @@ def ask_stream( else: self.add_to_conversation(full_response, response_role, convo_id=convo_id) self.function_calls_counter = {} - self.clear_function_call(convo_id=convo_id) + # self.clear_function_call(convo_id=convo_id) self.encode_web_text_list = [] # total_tokens = self.get_token_count(convo_id)